diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py index 68aff793ae6aa..76f6d7aeca0d8 100644 --- a/.buildkite/check-wheel-size.py +++ b/.buildkite/check-wheel-size.py @@ -5,11 +5,11 @@ import os import sys import zipfile -# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 400 MiB -# Note that we have 400 MiB quota, please use it wisely. -# See https://github.com/pypi/support/issues/3792 . +# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 450 MiB +# Note that we have 800 MiB quota, please use it wisely. +# See https://github.com/pypi/support/issues/6326 . # Please also sync the value with the one in Dockerfile. -VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 400)) +VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 450)) def print_top_10_largest_files(zip_file): diff --git a/.buildkite/nightly-benchmarks/nightly-descriptions.md b/.buildkite/nightly-benchmarks/nightly-descriptions.md index 37e2980eea974..2ef36089b6afb 100644 --- a/.buildkite/nightly-benchmarks/nightly-descriptions.md +++ b/.buildkite/nightly-benchmarks/nightly-descriptions.md @@ -8,7 +8,7 @@ This benchmark aims to: Latest results: [results link](https://blog.vllm.ai/2024/09/05/perf-update.html), scroll to the end. -Latest reproduction guilde: [github issue link](https://github.com/vllm-project/vllm/issues/8176) +Latest reproduction guide: [github issue link](https://github.com/vllm-project/vllm/issues/8176) ## Setup diff --git a/.buildkite/nightly-benchmarks/scripts/compare-json-results.py b/.buildkite/nightly-benchmarks/scripts/compare-json-results.py index 50431d0cd4c5e..5ea5a50a258a4 100644 --- a/.buildkite/nightly-benchmarks/scripts/compare-json-results.py +++ b/.buildkite/nightly-benchmarks/scripts/compare-json-results.py @@ -218,7 +218,7 @@ if __name__ == "__main__": "--xaxis", type=str, default="# of max concurrency.", - help="column name to use as X Axis in comparision graph", + help="column name to use as X Axis in comparison graph", ) args = parser.parse_args() diff --git a/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc2.json b/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc2.json index 2d88a0b30c4f8..f758097e098e4 100644 --- a/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc2.json +++ b/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc2.json @@ -1,6 +1,6 @@ [ { - "test_name": "serving_llama8B_tp1_sharegpt", + "test_name": "serving_llama8B_bf16_tp1_sharegpt", "qps_list": ["inf"], "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], "server_environment_variables": { @@ -32,7 +32,7 @@ } }, { - "test_name": "serving_llama8B_tp2_sharegpt", + "test_name": "serving_llama8B_bf16_tp2_sharegpt", "qps_list": ["inf"], "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], "server_environment_variables": { @@ -64,7 +64,7 @@ } }, { - "test_name": "serving_llama8B_tp4_sharegpt", + "test_name": "serving_llama8B_bf16_tp4_sharegpt", "qps_list": ["inf"], "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], "server_environment_variables": { @@ -96,7 +96,7 @@ } }, { - "test_name": "serving_llama8B_tp1_random_128_128", + "test_name": "serving_llama8B_bf16_tp1_random_128_128", "qps_list": ["inf"], "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], "server_environment_variables": { @@ -131,7 +131,7 @@ } }, { - "test_name": "serving_llama8B_tp2_random_128_128", + "test_name": "serving_llama8B_bf16_tp2_random_128_128", "qps_list": ["inf"], "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], 
"server_environment_variables": { @@ -166,7 +166,7 @@ } }, { - "test_name": "serving_llama8B_tp4_random_128_128", + "test_name": "serving_llama8B_bf16_tp4_random_128_128", "qps_list": ["inf"], "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], "server_environment_variables": { @@ -198,5 +198,413 @@ "random-output-len": 128, "num_prompts": 1000 } + }, + { + "test_name": "serving_llama8B_int8_tp1_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int8_tp2_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int8_tp4_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 4, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int8_tp1_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + 
"VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int8_tp2_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int8_tp4_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 4, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int4_tp1_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": 
"sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int4_tp2_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int4_tp4_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 4, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int4_tp1_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int4_tp2_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": 
"awq", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int4_tp4_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 4, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } } ] diff --git a/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc3.json b/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc3.json index 823abbaa99f86..ce396d6e54f27 100644 --- a/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc3.json +++ b/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc3.json @@ -1,6 +1,6 @@ [ { - "test_name": "serving_llama8B_pp1_sharegpt", + "test_name": "serving_llama8B_bf16_pp1_sharegpt", "qps_list": ["inf"], "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], "server_environment_variables": { @@ -32,7 +32,39 @@ } }, { - "test_name": "serving_llama8B_pp3_sharegpt", + "test_name": "serving_llama8B_bf16_tp2_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_bf16_pp3_sharegpt", "qps_list": ["inf"], "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], "server_environment_variables": { @@ -64,7 +96,7 @@ } }, { - "test_name": "serving_llama8B_tp2pp3_sharegpt", + "test_name": "serving_llama8B_bf16_tp2pp3_sharegpt", "qps_list": ["inf"], "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], "server_environment_variables": { @@ -97,7 +129,7 @@ } 
}, { - "test_name": "serving_llama8B_pp1_random_128_128", + "test_name": "serving_llama8B_bf16_pp1_random_128_128", "qps_list": ["inf"], "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], "server_environment_variables": { @@ -132,7 +164,42 @@ } }, { - "test_name": "serving_llama8B_pp3_random_128_128", + "test_name": "serving_llama8B_bf16_tp2_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_bf16_pp3_random_128_128", "qps_list": ["inf"], "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], "server_environment_variables": { @@ -167,7 +234,7 @@ } }, { - "test_name": "serving_llama8B_tp2pp3_random_128_128", + "test_name": "serving_llama8B_bf16_tp2pp3_random_128_128", "qps_list": ["inf"], "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], "server_environment_variables": { @@ -201,5 +268,553 @@ "ignore-eos": "", "num_prompts": 1000 } + }, + { + "test_name": "serving_llama8B_int8_pp1_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "pipeline_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int8_tp2_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": 
"vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int8_pp3_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "pipeline_parallel_size": 3, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int8_tp2pp3_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 2, + "pipeline_parallel_size": 3, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int8_pp1_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "pipeline_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int8_tp2_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 2, + 
"dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int8_pp3_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "pipeline_parallel_size": 3, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int8_tp2pp3_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 2, + "pipeline_parallel_size": 3, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int4_pp1_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "pipeline_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + 
"test_name": "serving_llama8B_int4_tp2_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int4_pp3_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "pipeline_parallel_size": 3, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int4_tp2pp3_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 2, + "pipeline_parallel_size": 3, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_int4_pp1_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "pipeline_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + 
"trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int4_tp2_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int4_pp3_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "pipeline_parallel_size": 3, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, + { + "test_name": "serving_llama8B_int4_tp2pp3_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 2, + "pipeline_parallel_size": 3, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } 
} ] diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 92a1bcada3879..8c6ef7817aaf8 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -1,21 +1,22 @@ steps: - # aarch64 + CUDA builds - - label: "Build arm64 wheel - CUDA 12.8" - id: build-wheel-arm64-cuda-12-8 + # aarch64 + CUDA builds. PyTorch 2.8 aarch64 + CUDA wheel is only available on CUDA 12.9 + - label: "Build arm64 wheel - CUDA 12.9" + depends_on: ~ + id: build-wheel-arm64-cuda-12-9 agents: queue: arm64_cpu_queue_postmerge commands: # #NOTE: torch_cuda_arch_list is derived from upstream PyTorch build files here: # https://github.com/pytorch/pytorch/blob/main/.ci/aarch64_linux/aarch64_ci_build.sh#L7 - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg VLLM_MAIN_CUDA_VERSION=12.9 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - "bash .buildkite/scripts/upload-wheels.sh" env: DOCKER_BUILDKIT: "1" - # x86 + CUDA builds - label: "Build wheel - CUDA 12.8" + depends_on: ~ id: build-wheel-cuda-12-8 agents: queue: cpu_queue_postmerge @@ -27,12 +28,8 @@ steps: env: DOCKER_BUILDKIT: "1" - - block: "Build CUDA 12.6 wheel" - key: block-build-cu126-wheel - depends_on: ~ - - label: "Build wheel - CUDA 12.6" - depends_on: block-build-cu126-wheel + depends_on: ~ id: build-wheel-cuda-12-6 agents: queue: cpu_queue_postmerge @@ -44,18 +41,14 @@ steps: env: DOCKER_BUILDKIT: "1" - # Note(simon): We can always build CUDA 11.8 wheel to ensure the build is working. - # However, this block can be uncommented to save some compute hours. - # - block: "Build CUDA 11.8 wheel" - # key: block-build-cu118-wheel - - - label: "Build wheel - CUDA 11.8" - # depends_on: block-build-cu118-wheel - id: build-wheel-cuda-11-8 + # x86 + CUDA builds + - label: "Build wheel - CUDA 12.9" + depends_on: ~ + id: build-wheel-cuda-12-9 agents: queue: cpu_queue_postmerge commands: - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=11.8.0 --build-arg torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." 
- "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - "bash .buildkite/scripts/upload-wheels.sh" @@ -75,6 +68,7 @@ steps: - "docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" + # PyTorch 2.8 aarch64 + CUDA wheel is only available on CUDA 12.9 - label: "Build release image (arm64)" depends_on: ~ id: build-release-image-arm64 @@ -82,7 +76,7 @@ steps: queue: arm64_cpu_queue_postmerge commands: - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ." - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m)" # Add job to create multi-arch manifest @@ -102,8 +96,6 @@ steps: depends_on: - create-multi-arch-manifest - build-wheel-cuda-12-8 - - build-wheel-cuda-12-6 - - build-wheel-cuda-11-8 id: annotate-release-workflow agents: queue: cpu_queue_postmerge @@ -150,18 +142,24 @@ steps: env: DOCKER_BUILDKIT: "1" - - block: "Build Neuron release image" - key: block-neuron-release-image-build - depends_on: ~ - - - label: "Build and publish Neuron release image" - depends_on: block-neuron-release-image-build + - label: "Build and publish nightly multi-arch image to DockerHub" + depends_on: + - create-multi-arch-manifest + if: build.env("NIGHTLY") == "1" agents: - queue: neuron-postmerge + queue: cpu_queue_postmerge commands: - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest --progress plain -f docker/Dockerfile.neuron ." 
- - "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest" - - "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version)" + - "docker pull public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" + - "docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT vllm/vllm-openai:nightly" + - "docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT vllm/vllm-openai:nightly-$BUILDKITE_COMMIT" + - "docker push vllm/vllm-openai:nightly" + - "docker push vllm/vllm-openai:nightly-$BUILDKITE_COMMIT" + # Clean up old nightly builds (keep only last 14) + - "bash .buildkite/scripts/cleanup-nightly-builds.sh" + plugins: + - docker-login#v3.0.0: + username: vllmbot + password-env: DOCKERHUB_TOKEN env: DOCKER_BUILDKIT: "1" diff --git a/.buildkite/scripts/annotate-release.sh b/.buildkite/scripts/annotate-release.sh index 94e0ac2398f34..fde48603ad3cd 100755 --- a/.buildkite/scripts/annotate-release.sh +++ b/.buildkite/scripts/annotate-release.sh @@ -14,18 +14,33 @@ buildkite-agent annotate --style 'info' --context 'release-workflow' << EOF To download the wheel: \`\`\` aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}/vllm-${RELEASE_VERSION}-cp38-abi3-manylinux1_x86_64.whl . +aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}/vllm-${RELEASE_VERSION}-cp38-abi3-manylinux2014_aarch64.whl . + aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}+cu126/vllm-${RELEASE_VERSION}+cu126-cp38-abi3-manylinux1_x86_64.whl . -aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}+cu118/vllm-${RELEASE_VERSION}+cu118-cp38-abi3-manylinux1_x86_64.whl . +aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}+cu129/vllm-${RELEASE_VERSION}+cu129-cp38-abi3-manylinux1_x86_64.whl . \`\`\` To download and upload the image: \`\`\` -docker pull public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT} -docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT} vllm/vllm-openai -docker tag vllm/vllm-openai vllm/vllm-openai:latest -docker tag vllm/vllm-openai vllm/vllm-openai:v${RELEASE_VERSION} -docker push vllm/vllm-openai:latest -docker push vllm/vllm-openai:v${RELEASE_VERSION} +docker pull public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT}-x86_64 +docker pull public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT}-aarch64 + +docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT}-x86_64 vllm/vllm-openai:x86_64 +docker tag vllm/vllm-openai:x86_64 vllm/vllm-openai:latest-x86_64 +docker tag vllm/vllm-openai:x86_64 vllm/vllm-openai:v${RELEASE_VERSION}-x86_64 +docker push vllm/vllm-openai:latest-x86_64 +docker push vllm/vllm-openai:v${RELEASE_VERSION}-x86_64 + +docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT}-aarch64 vllm/vllm-openai:aarch64 +docker tag vllm/vllm-openai:aarch64 vllm/vllm-openai:latest-aarch64 +docker tag vllm/vllm-openai:aarch64 vllm/vllm-openai:v${RELEASE_VERSION}-aarch64 +docker push vllm/vllm-openai:latest-aarch64 +docker push vllm/vllm-openai:v${RELEASE_VERSION}-aarch64 + +docker manifest create vllm/vllm-openai:latest vllm/vllm-openai:latest-x86_64 vllm/vllm-openai:latest-aarch64 --amend +docker manifest create vllm/vllm-openai:v${RELEASE_VERSION} vllm/vllm-openai:v${RELEASE_VERSION}-x86_64 vllm/vllm-openai:v${RELEASE_VERSION}-aarch64 --amend +docker manifest push vllm/vllm-openai:latest +docker manifest push vllm/vllm-openai:v${RELEASE_VERSION} \`\`\` EOF \ No newline at end of file diff --git a/.buildkite/scripts/cleanup-nightly-builds.sh 
b/.buildkite/scripts/cleanup-nightly-builds.sh new file mode 100755 index 0000000000000..1a82f7d085233 --- /dev/null +++ b/.buildkite/scripts/cleanup-nightly-builds.sh @@ -0,0 +1,97 @@ +#!/bin/bash + +set -ex + +# Clean up old nightly builds from DockerHub, keeping only the last 14 builds +# This script uses DockerHub API to list and delete old tags with "nightly-" prefix + +# DockerHub API endpoint for vllm/vllm-openai repository +REPO_API_URL="https://hub.docker.com/v2/repositories/vllm/vllm-openai/tags" + +# Get DockerHub token from environment +if [ -z "$DOCKERHUB_TOKEN" ]; then + echo "Error: DOCKERHUB_TOKEN environment variable is not set" + exit 1 +fi + +# Function to get all tags from DockerHub +get_all_tags() { + local page=1 + local all_tags="" + + while true; do + local response=$(curl -s -H "Authorization: Bearer $DOCKERHUB_TOKEN" \ + "$REPO_API_URL?page=$page&page_size=100") + + # Get both last_updated timestamp and tag name, separated by | + local tags=$(echo "$response" | jq -r '.results[] | select(.name | startswith("nightly-")) | "\(.last_updated)|\(.name)"') + + if [ -z "$tags" ]; then + break + fi + + all_tags="$all_tags$tags"$'\n' + page=$((page + 1)) + done + + # Sort by timestamp (newest first) and extract just the tag names + echo "$all_tags" | sort -r | cut -d'|' -f2 +} + +delete_tag() { + local tag_name="$1" + echo "Deleting tag: $tag_name" + + local delete_url="https://hub.docker.com/v2/repositories/vllm/vllm-openai/tags/$tag_name" + local response=$(curl -s -X DELETE -H "Authorization: Bearer $DOCKERHUB_TOKEN" "$delete_url") + + if echo "$response" | jq -e '.detail' > /dev/null 2>&1; then + echo "Warning: Failed to delete tag $tag_name: $(echo "$response" | jq -r '.detail')" + else + echo "Successfully deleted tag: $tag_name" + fi +} + +# Get all nightly- prefixed tags, sorted by last_updated timestamp (newest first) +echo "Fetching all tags from DockerHub..." +all_tags=$(get_all_tags) + +if [ -z "$all_tags" ]; then + echo "No tags found to clean up" + exit 0 +fi + +# Count total tags +total_tags=$(echo "$all_tags" | wc -l) +echo "Found $total_tags tags" + +# Keep only the last 14 builds (including the current one) +tags_to_keep=14 +tags_to_delete=$((total_tags - tags_to_keep)) + +if [ $tags_to_delete -le 0 ]; then + echo "No tags need to be deleted (only $total_tags tags found, keeping $tags_to_keep)" + exit 0 +fi + +echo "Will delete $tags_to_delete old tags, keeping the newest $tags_to_keep" + +# Get tags to delete (skip the first $tags_to_keep tags) +tags_to_delete_list=$(echo "$all_tags" | tail -n +$((tags_to_keep + 1))) + +if [ -z "$tags_to_delete_list" ]; then + echo "No tags to delete" + exit 0 +fi + +# Delete old tags +echo "Deleting old tags..." 
+while IFS= read -r tag; do + if [ -n "$tag" ]; then + delete_tag "$tag" + # Add a small delay to avoid rate limiting + sleep 1 + fi +done <<< "$tags_to_delete_list" + +echo "Cleanup completed successfully" diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh index c395011a24485..7f90181048d0f 100755 --- a/.buildkite/scripts/hardware_ci/run-amd-test.sh +++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh @@ -167,12 +167,6 @@ if [[ $commands == *" entrypoints/llm "* ]]; then --ignore=entrypoints/llm/test_prompt_validation.py "} fi -#Obsolete currently -##ignore certain Entrypoints/llm tests -#if [[ $commands == *" && pytest -v -s entrypoints/llm/test_guided_generate.py"* ]]; then -# commands=${commands//" && pytest -v -s entrypoints/llm/test_guided_generate.py"/" "} -#fi - # --ignore=entrypoints/openai/test_encoder_decoder.py \ # --ignore=entrypoints/openai/test_embedding.py \ # --ignore=entrypoints/openai/test_oot_registration.py diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index 0f734763f13fd..64943d2a15a79 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -66,7 +66,6 @@ function cpu_tests() { pytest -x -v -s tests/models/language/pooling -m cpu_model pytest -x -v -s tests/models/multimodal/generation \ - --ignore=tests/models/multimodal/generation/test_mllama.py \ --ignore=tests/models/multimodal/generation/test_pixtral.py \ -m cpu_model" diff --git a/.buildkite/scripts/hardware_ci/run-neuron-test.sh b/.buildkite/scripts/hardware_ci/run-neuron-test.sh deleted file mode 100644 index a397457c83261..0000000000000 --- a/.buildkite/scripts/hardware_ci/run-neuron-test.sh +++ /dev/null @@ -1,64 +0,0 @@ -#!/bin/bash - -# This script build the Neuron docker image and run the API server inside the container. -# It serves a sanity check for compilation and basic model usage. -set -e -set -v - -image_name="neuron/vllm-ci" -container_name="neuron_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)" - -HF_CACHE="$(realpath ~)/huggingface" -mkdir -p "${HF_CACHE}" -HF_MOUNT="/root/.cache/huggingface" -HF_TOKEN=$(aws secretsmanager get-secret-value --secret-id "ci/vllm-neuron/hf-token" --region us-west-2 --query 'SecretString' --output text | jq -r .VLLM_NEURON_CI_HF_TOKEN) - -NEURON_COMPILE_CACHE_URL="$(realpath ~)/neuron_compile_cache" -mkdir -p "${NEURON_COMPILE_CACHE_URL}" -NEURON_COMPILE_CACHE_MOUNT="/root/.cache/neuron_compile_cache" - -# Try building the docker image -aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws - -# prune old image and containers to save disk space, and only once a day -# by using a timestamp file in tmp. -if [ -f /tmp/neuron-docker-build-timestamp ]; then - last_build=$(cat /tmp/neuron-docker-build-timestamp) - current_time=$(date +%s) - if [ $((current_time - last_build)) -gt 86400 ]; then - # Remove dangling images (those that are not tagged and not used by any container) - docker image prune -f - # Remove unused volumes / force the system prune for old images as well. - docker volume prune -f && docker system prune -f - echo "$current_time" > /tmp/neuron-docker-build-timestamp - fi -else - date "+%s" > /tmp/neuron-docker-build-timestamp -fi - -docker build -t "${image_name}" -f docker/Dockerfile.neuron . 
- -# Setup cleanup -remove_docker_container() { - docker image rm -f "${image_name}" || true; -} -trap remove_docker_container EXIT - -# Run the image -docker run --rm -it --device=/dev/neuron0 --network bridge \ - -v "${HF_CACHE}:${HF_MOUNT}" \ - -e "HF_HOME=${HF_MOUNT}" \ - -e "HF_TOKEN=${HF_TOKEN}" \ - -v "${NEURON_COMPILE_CACHE_URL}:${NEURON_COMPILE_CACHE_MOUNT}" \ - -e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \ - --name "${container_name}" \ - ${image_name} \ - /bin/bash -c " - set -e; # Exit on first error - python3 /workspace/vllm/examples/offline_inference/neuron.py; - python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys; - for f in /workspace/vllm/tests/neuron/2_core/*.py; do - echo \"Running test file: \$f\"; - python3 -m pytest \$f -v --capture=tee-sys; - done - " \ No newline at end of file diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh index 73f3e63fbf5f6..8c9b00990e995 100644 --- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh @@ -30,10 +30,12 @@ docker run \ bash -c ' set -e echo $ZE_AFFINITY_MASK - VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager - VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE - VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray - VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp + pip install tblib==3.1.0 + python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager + python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE + python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray + python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp + VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager cd tests pytest -v -s v1/core pytest -v -s v1/engine diff --git a/.buildkite/scripts/upload-wheels.sh b/.buildkite/scripts/upload-wheels.sh index 745f285c008ad..43aa8c47be299 100644 --- a/.buildkite/scripts/upload-wheels.sh +++ b/.buildkite/scripts/upload-wheels.sh @@ -58,14 +58,15 @@ python3 .buildkite/generate_index.py --wheel "$normal_wheel" aws s3 cp "$wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/" aws s3 cp "$normal_wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/" -if [[ $normal_wheel == *"cu118"* ]]; then - # if $normal_wheel matches cu118, do not upload the index.html - echo "Skipping index files for cu118 wheels" -elif [[ $normal_wheel == *"cu126"* ]]; then +if [[ $normal_wheel == *"cu126"* ]]; then # if $normal_wheel matches cu126, do not upload the index.html echo "Skipping index files for cu126 wheels" +elif [[ $normal_wheel == *"cu128"* ]]; then + # if $normal_wheel matches cu128, do not upload the index.html + echo "Skipping index files for cu128 wheels" else - # only upload index.html for cu128 wheels (default wheels) + 
# only upload index.html for cu129 wheels (default wheels) as it + # is available on both x86 and arm64 aws s3 cp index.html "s3://vllm-wheels/$BUILDKITE_COMMIT/vllm/index.html" aws s3 cp "s3://vllm-wheels/nightly/index.html" "s3://vllm-wheels/$BUILDKITE_COMMIT/index.html" fi @@ -74,14 +75,15 @@ fi aws s3 cp "$wheel" "s3://vllm-wheels/nightly/" aws s3 cp "$normal_wheel" "s3://vllm-wheels/nightly/" -if [[ $normal_wheel == *"cu118"* ]]; then - # if $normal_wheel matches cu118, do not upload the index.html - echo "Skipping index files for cu118 wheels" -elif [[ $normal_wheel == *"cu126"* ]]; then +if [[ $normal_wheel == *"cu126"* ]]; then # if $normal_wheel matches cu126, do not upload the index.html echo "Skipping index files for cu126 wheels" +elif [[ $normal_wheel == *"cu128"* ]]; then + # if $normal_wheel matches cu128, do not upload the index.html + echo "Skipping index files for cu128 wheels" else - # only upload index.html for cu128 wheels (default wheels) + # only upload index.html for cu129 wheels (default wheels) as it + # is available on both x86 and arm64 aws s3 cp index.html "s3://vllm-wheels/nightly/vllm/index.html" fi diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 55349e0ac9321..66dfc990805f2 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -41,29 +41,27 @@ steps: commands: - bash standalone_tests/pytorch_nightly_dependency.sh -- label: Async Engine, Inputs, Utils, Worker Test # 24min +- label: Async Engine, Inputs, Utils, Worker Test # 36min + timeout_in_minutes: 50 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - - tests/mq_llm_engine - - tests/async_engine - tests/test_inputs.py - tests/test_outputs.py - tests/multimodal - tests/utils_ - - tests/worker - tests/standalone_tests/lazy_imports.py + - tests/transformers_utils commands: - python3 standalone_tests/lazy_imports.py - - pytest -v -s mq_llm_engine # MQLLMEngine - - pytest -v -s async_engine # AsyncLLMEngine - pytest -v -s test_inputs.py - pytest -v -s test_outputs.py - pytest -v -s multimodal - pytest -v -s utils_ # Utils - - pytest -v -s worker # Worker + - pytest -v -s transformers_utils # transformers_utils -- label: Python-only Installation Test +- label: Python-only Installation Test # 10min + timeout_in_minutes: 20 mirror_hardwares: [amdexperimental] source_file_dependencies: - tests/standalone_tests/python_only_compile.sh @@ -71,7 +69,8 @@ steps: commands: - bash standalone_tests/python_only_compile.sh -- label: Basic Correctness Test # 30min +- label: Basic Correctness Test # 20min + timeout_in_minutes: 30 mirror_hardwares: [amdexperimental] fast_check: true torch_nightly: true @@ -79,26 +78,26 @@ steps: - vllm/ - tests/basic_correctness/test_basic_correctness - tests/basic_correctness/test_cpu_offload - - tests/basic_correctness/test_preemption - tests/basic_correctness/test_cumem.py commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -v -s basic_correctness/test_cumem.py - pytest -v -s basic_correctness/test_basic_correctness.py - pytest -v -s basic_correctness/test_cpu_offload.py - - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py -- label: Core Test # 10min - mirror_hardwares: [amdexperimental] +- label: Entrypoints Unit Tests # 5min + timeout_in_minutes: 10 + working_dir: "/vllm-workspace/tests" fast_check: true source_file_dependencies: - - vllm/core - - vllm/distributed - - tests/core + - vllm/entrypoints + - tests/entrypoints/ commands: - - pytest -v -s core + - 
pytest -v -s entrypoints/openai/tool_parsers + - pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling -- label: Entrypoints Test (LLM) # 40min +- label: Entrypoints Integration Test (LLM) # 30min + timeout_in_minutes: 40 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" fast_check: true @@ -109,12 +108,12 @@ steps: - tests/entrypoints/offline_mode commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py - - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process + - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process - VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests -- label: Entrypoints Test (API Server) # 40min +- label: Entrypoints Integration Test (API Server) # 100min + timeout_in_minutes: 130 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" fast_check: true @@ -126,10 +125,24 @@ steps: commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension - - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py + - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py --ignore=entrypoints/openai/tool_parsers/ - pytest -v -s entrypoints/test_chat_utils.py -- label: Distributed Tests (4 GPUs) # 10min +- label: Entrypoints Integration Test (Pooling) + timeout_in_minutes: 50 + mirror_hardwares: [amdexperimental] + working_dir: "/vllm-workspace/tests" + fast_check: true + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/entrypoints/pooling + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -v -s entrypoints/pooling + +- label: Distributed Tests (4 GPUs) # 35min + timeout_in_minutes: 50 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 4 @@ -172,7 +185,8 @@ steps: - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py - popd -- label: EPLB Algorithm Test +- label: EPLB Algorithm Test # 5min + timeout_in_minutes: 15 working_dir: "/vllm-workspace/tests" source_file_dependencies: - vllm/distributed/eplb @@ -181,6 +195,7 @@ steps: - pytest -v -s distributed/test_eplb_algo.py - label: EPLB Execution Test # 5min + timeout_in_minutes: 15 working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -189,26 +204,26 @@ steps: commands: - pytest -v -s distributed/test_eplb_execute.py -- label: Metrics, Tracing Test # 10min +- label: Metrics, Tracing Test # 12min + timeout_in_minutes: 20 mirror_hardwares: [amdexperimental] num_gpus: 2 
source_file_dependencies: - vllm/ - - tests/metrics - - tests/tracing + - tests/v1/tracing commands: - - pytest -v -s metrics - "pip install \ 'opentelemetry-sdk>=1.26.0' \ 'opentelemetry-api>=1.26.0' \ 'opentelemetry-exporter-otlp>=1.26.0' \ 'opentelemetry-semantic-conventions-ai>=0.4.1'" - - pytest -v -s tracing + - pytest -v -s v1/tracing ##### fast check tests ##### ##### 1 GPU test ##### -- label: Regression Test # 5min +- label: Regression Test # 7min + timeout_in_minutes: 20 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ @@ -218,7 +233,8 @@ steps: - pytest -v -s test_regression.py working_dir: "/vllm-workspace/tests" # optional -- label: Engine Test # 10min +- label: Engine Test # 25min + timeout_in_minutes: 40 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ @@ -233,7 +249,8 @@ steps: # OOM in the CI unless we run this separately - pytest -v -s tokenization -- label: V1 Test e2e + engine +- label: V1 Test e2e + engine # 30min + timeout_in_minutes: 45 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ @@ -244,7 +261,8 @@ steps: - pytest -v -s v1/e2e - pytest -v -s v1/engine -- label: V1 Test entrypoints +- label: V1 Test entrypoints # 35min + timeout_in_minutes: 50 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ @@ -252,7 +270,8 @@ steps: commands: - pytest -v -s v1/entrypoints -- label: V1 Test others +- label: V1 Test others # 42min + timeout_in_minutes: 60 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ @@ -276,7 +295,8 @@ steps: - pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine -- label: Examples Test # 25min +- label: Examples Test # 30min + timeout_in_minutes: 45 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/examples" source_file_dependencies: @@ -294,14 +314,14 @@ steps: - python3 offline_inference/vision_language_pooling.py --seed 0 - python3 offline_inference/vision_language_multi_image.py --seed 0 - VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - - python3 offline_inference/encoder_decoder.py - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 - python3 offline_inference/basic/classify.py - python3 offline_inference/basic/embed.py - python3 offline_inference/basic/score.py - VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2 -- label: Platform Tests (CUDA) +- label: Platform Tests (CUDA) # 4min + timeout_in_minutes: 15 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ @@ -309,7 +329,8 @@ steps: commands: - pytest -v -s cuda/test_cuda_context.py -- label: Samplers Test # 36min +- label: Samplers Test # 56min + timeout_in_minutes: 75 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/model_executor/layers @@ -320,15 +341,23 @@ steps: - pytest -v -s samplers - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers -- label: LoRA Test %N # 15min each +- label: LoRA Test %N # 20min each + timeout_in_minutes: 30 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/lora - tests/lora - command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB 
--num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_llm_with_multi_loras.py + commands: + - pytest -v -s lora \ + --shard-id=$$BUILDKITE_PARALLEL_JOB \ + --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \ + --ignore=lora/test_chatglm3_tp.py \ + --ignore=lora/test_llama_tp.py \ + --ignore=lora/test_llm_with_multi_loras.py parallelism: 4 -- label: PyTorch Compilation Unit Tests +- label: PyTorch Compilation Unit Tests # 15min + timeout_in_minutes: 30 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: @@ -343,8 +372,10 @@ steps: - pytest -v -s compile/test_async_tp.py - pytest -v -s compile/test_fusion_all_reduce.py - pytest -v -s compile/test_decorator.py + - pytest -v -s compile/test_noop_elimination.py -- label: PyTorch Fullgraph Smoke Test # 9min +- label: PyTorch Fullgraph Smoke Test # 15min + timeout_in_minutes: 30 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: @@ -352,13 +383,10 @@ steps: - tests/compile commands: - pytest -v -s compile/test_basic_correctness.py - # these tests need to be separated, cannot combine - - pytest -v -s compile/piecewise/test_simple.py - - pytest -v -s compile/piecewise/test_toy_llama.py - - pytest -v -s compile/piecewise/test_full_cudagraph.py - - pytest -v -s compile/piecewise/test_multiple_graphs.py + - pytest -v -s compile/piecewise/ -- label: PyTorch Fullgraph Test # 18min +- label: PyTorch Fullgraph Test # 20min + timeout_in_minutes: 30 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: @@ -367,7 +395,8 @@ steps: commands: - pytest -v -s compile/test_full_graph.py -- label: Kernels Core Operation Test +- label: Kernels Core Operation Test # 48min + timeout_in_minutes: 75 mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/ @@ -375,7 +404,8 @@ steps: commands: - pytest -v -s kernels/core -- label: Kernels Attention Test %N +- label: Kernels Attention Test %N # 23min + timeout_in_minutes: 35 mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/attention/ @@ -386,7 +416,8 @@ steps: - pytest -v -s kernels/attention --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 2 -- label: Kernels Quantization Test %N +- label: Kernels Quantization Test %N # 64min + timeout_in_minutes: 90 mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/quantization/ @@ -396,7 +427,8 @@ steps: - pytest -v -s kernels/quantization --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 2 -- label: Kernels MoE Test %N +- label: Kernels MoE Test %N # 40min + timeout_in_minutes: 60 mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/quantization/cutlass_w8a8/moe/ @@ -408,7 +440,8 @@ steps: - pytest -v -s kernels/moe --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 2 -- label: Kernels Mamba Test +- label: Kernels Mamba Test # 31min + timeout_in_minutes: 45 mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/mamba/ @@ -416,7 +449,8 @@ steps: commands: - pytest -v -s kernels/mamba -- label: Tensorizer Test # 11min +- label: Tensorizer Test # 14min + timeout_in_minutes: 25 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/model_executor/model_loader @@ -428,7 +462,8 @@ steps: - pytest -v -s tensorizer_loader - pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py -- label: Model Executor Test +- 
label: Model Executor Test # 7min + timeout_in_minutes: 20 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/model_executor @@ -438,7 +473,8 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -v -s model_executor -- label: Benchmarks # 9min +- label: Benchmarks # 11min + timeout_in_minutes: 20 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/.buildkite" source_file_dependencies: @@ -446,7 +482,8 @@ steps: commands: - bash scripts/run-benchmarks.sh -- label: Benchmarks CLI Test # 10min +- label: Benchmarks CLI Test # 7min + timeout_in_minutes: 20 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ @@ -454,7 +491,8 @@ steps: commands: - pytest -v -s benchmarks/ -- label: Quantization Test +- label: Quantization Test # 70min + timeout_in_minutes: 90 mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/ @@ -463,10 +501,15 @@ steps: commands: # temporary install here since we need nightly, will move to requirements/test.in # after torchao 0.12 release, and pin a working version of torchao nightly here + + # since torchao nightly is only compatible with torch nightly currently + # https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now + # we can only upgrade after this is resolved - pip install --pre torchao==0.13.0.dev20250814 --index-url https://download.pytorch.org/whl/nightly/cu128 - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization - label: LM Eval Small Models # 53min + timeout_in_minutes: 75 mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/ @@ -474,7 +517,8 @@ steps: commands: - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 -- label: OpenAI API correctness +- label: OpenAI API correctness # 22min + timeout_in_minutes: 30 mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/ @@ -483,15 +527,8 @@ steps: commands: # LMEval+Transcription WER check - pytest -s entrypoints/openai/correctness/ -- label: Encoder Decoder tests # 5min - mirror_hardwares: [amdexperimental] - source_file_dependencies: - - vllm/ - - tests/encoder_decoder - commands: - - pytest -v -s encoder_decoder - -- label: OpenAI-Compatible Tool Use # 20 min +- label: OpenAI-Compatible Tool Use # 23 min + timeout_in_minutes: 35 mirror_hardwares: [amdexperimental] fast_check: false source_file_dependencies: @@ -504,30 +541,82 @@ steps: ##### models test ##### -- label: Basic Models Test # 24min +- label: Basic Models Tests (Initialization) + timeout_in_minutes: 45 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: - vllm/ - - tests/models + - tests/models/test_initialization.py commands: - - pytest -v -s models/test_transformers.py - - pytest -v -s models/test_registry.py - - pytest -v -s models/test_utils.py - - pytest -v -s models/test_vision.py - - pytest -v -s models/test_initialization.py + # Run a subset of model initialization tests + - pytest -v -s models/test_initialization.py::test_can_initialize_small_subset -- label: Language Models Test (Standard) +- label: Basic Models Tests (Extra Initialization) %N + timeout_in_minutes: 45 + mirror_hardwares: [amdexperimental] + torch_nightly: true + source_file_dependencies: + - vllm/model_executor/models/ + - tests/models/test_initialization.py + commands: + # Only when vLLM model source is modified - test initialization of a large + # subset of supported models (the complement of the small subset in the above + # test.) 
Also run if model initialization test file is modified + - pytest -v -s models/test_initialization.py \ + -k 'not test_can_initialize_small_subset' \ + --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \ + --shard-id=$$BUILDKITE_PARALLEL_JOB + parallelism: 2 + +- label: Basic Models Tests (Other) + timeout_in_minutes: 45 + mirror_hardwares: [amdexperimental] + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/models/test_transformers.py + - tests/models/test_registry.py + - tests/models/test_utils.py + - tests/models/test_vision.py + commands: + - pytest -v -s models/test_transformers.py \ + models/test_registry.py \ + models/test_utils.py \ + models/test_vision.py + +- label: Language Models Tests (Standard) + timeout_in_minutes: 25 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: - vllm/ - tests/models/language commands: + # Test standard language models, excluding a subset of slow tests - pip freeze | grep -E 'torch' - - pytest -v -s models/language -m core_model + - pytest -v -s models/language -m 'core_model and (not slow_test)' -- label: Language Models Test (Hybrid) # 35 min +- label: Language Models Tests (Extra Standard) %N + timeout_in_minutes: 45 + mirror_hardwares: [amdexperimental] + torch_nightly: true + source_file_dependencies: + - vllm/model_executor/models/ + - tests/models/language/pooling/test_embedding.py + - tests/models/language/generation/test_common.py + - tests/models/language/pooling/test_classification.py + commands: + # Shard slow subset of standard language models tests. Only run when model + # source is modified, or when specified test files are modified + - pip freeze | grep -E 'torch' + - pytest -v -s models/language -m 'core_model and slow_test' \ + --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \ + --shard-id=$$BUILDKITE_PARALLEL_JOB + parallelism: 2 + +- label: Language Models Tests (Hybrid) %N + timeout_in_minutes: 75 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: @@ -538,9 +627,15 @@ steps: # Note: also needed to run plamo2 model in vLLM - uv pip install --system --no-build-isolation 'git+https://github.com/state-spaces/mamba@v2.2.5' - uv pip install --system --no-build-isolation 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.2' - - pytest -v -s models/language/generation -m hybrid_model + # Shard hybrid language model tests + - pytest -v -s models/language/generation \ + -m hybrid_model \ + --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \ + --shard-id=$$BUILDKITE_PARALLEL_JOB + parallelism: 2 -- label: Language Models Test (Extended Generation) # 1hr20min +- label: Language Models Test (Extended Generation) # 80min + timeout_in_minutes: 110 mirror_hardwares: [amdexperimental] optional: true source_file_dependencies: @@ -551,7 +646,18 @@ steps: - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8' - pytest -v -s models/language/generation -m '(not core_model) and (not hybrid_model)' +- label: Language Models Test (PPL) + timeout_in_minutes: 110 + mirror_hardwares: [amdexperimental] + optional: true + source_file_dependencies: + - vllm/ + - tests/models/language/generation_ppl_test + commands: + - pytest -v -s models/language/generation_ppl_test + - label: Language Models Test (Extended Pooling) # 36min + timeout_in_minutes: 50 mirror_hardwares: [amdexperimental] optional: true source_file_dependencies: @@ -560,7 +666,18 @@ steps: commands: - pytest -v -s models/language/pooling -m 'not core_model' -- label: Multi-Modal Processor Test +- label: Language 
Models Test (MTEB) + timeout_in_minutes: 110 + mirror_hardwares: [amdexperimental] + optional: true + source_file_dependencies: + - vllm/ + - tests/models/language/pooling_mteb_test + commands: + - pytest -v -s models/language/pooling_mteb_test + +- label: Multi-Modal Processor Test # 44min + timeout_in_minutes: 60 source_file_dependencies: - vllm/ - tests/models/multimodal @@ -568,7 +685,8 @@ steps: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pytest -v -s models/multimodal/processing -- label: Multi-Modal Models Test (Standard) +- label: Multi-Modal Models Test (Standard) # 60min + timeout_in_minutes: 80 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: @@ -578,7 +696,7 @@ steps: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pip freeze | grep -E 'torch' - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing - - cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work + - cd .. && VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work - label: Multi-Modal Models Test (Extended) 1 mirror_hardwares: [amdexperimental] @@ -610,7 +728,8 @@ steps: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model' -- label: Quantized Models Test +- label: Quantized Models Test # 45 min + timeout_in_minutes: 60 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/model_executor/layers/quantization @@ -640,7 +759,8 @@ steps: - python3 examples/offline_inference/audio_language.py --model-type whisper - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl -- label: Blackwell Test +- label: Blackwell Test # 38 min + timeout_in_minutes: 60 working_dir: "/vllm-workspace/" gpu: b200 # optional: true @@ -662,11 +782,12 @@ steps: # num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353 - pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2' - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py - - pytest -v -s tests/kernels/test_cutlass_mla_decode.py + - pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py + - pytest -v -s tests/kernels/attention/test_flashinfer_mla_decode.py # Quantization - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8' - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py - - pytest -v -s tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py + - pytest -v -s tests/kernels/quantization/test_silu_mul_nvfp4_quant.py - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py - pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py @@ -678,10 +799,25 @@ steps: - pytest -v -s tests/kernels/moe/test_flashinfer.py - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py +- label: GPT-OSS Eval (Blackwell) + timeout_in_minutes: 60 + working_dir: "/vllm-workspace/" + gpu: b200 + optional: true # disable while debugging + source_file_dependencies: + - tests/evals/gpt_oss + - vllm/model_executor/models/gpt_oss.py + - vllm/model_executor/layers/quantization/mxfp4.py + - vllm/v1/attention/backends/flashinfer.py + commands: + - uv pip 
install --system 'gpt-oss[eval]==0.0.5' + - pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58 --server-args '--tensor-parallel-size 2' + ##### 1 GPU test ##### ##### multi gpus test ##### - label: Distributed Comm Ops Test # 7min + timeout_in_minutes: 20 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 @@ -691,8 +827,11 @@ steps: commands: - pytest -v -s distributed/test_comm_ops.py - pytest -v -s distributed/test_shm_broadcast.py + - pytest -v -s distributed/test_shm_buffer.py + - pytest -v -s distributed/test_shm_storage.py - label: 2 Node Tests (4 GPUs in total) # 16min + timeout_in_minutes: 30 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 @@ -716,7 +855,8 @@ steps: - NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code -- label: Distributed Tests (2 GPUs) # 40min +- label: Distributed Tests (2 GPUs) # 110min + timeout_in_minutes: 150 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 @@ -747,7 +887,8 @@ steps: # Avoid importing model tests that cause CUDA reinitialization error - pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)' - pytest models/language -v -s -m 'distributed(num_gpus=2)' - - pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' + - pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py + - VLLM_WORKER_MULTIPROC_METHOD=spawn pytest models/multimodal/generation/test_whisper.py -v -s -m 'distributed(num_gpus=2)' # test sequence parallel - pytest -v -s distributed/test_sequence_parallel.py # this test fails consistently. 
@@ -757,6 +898,7 @@ steps: - pytest -v -s models/multimodal/generation/test_maverick.py - label: Plugin Tests (2 GPUs) # 40min + timeout_in_minutes: 60 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 @@ -772,7 +914,7 @@ steps: # begin io_processor plugins test, all the code in between uses the prithvi_io_processor plugin - pip install -e ./plugins/prithvi_io_processor_plugin - pytest -v -s plugins_tests/test_io_processor_plugins.py - - pip uninstall prithvi_io_processor_plugin -y + - pip uninstall prithvi_io_processor_plugin -y # end io_processor plugins test # other tests continue here: - pytest -v -s plugins_tests/test_scheduler_plugins.py @@ -782,7 +924,8 @@ steps: - pytest -v -s models/test_oot_registration.py # it needs a clean process - pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins -- label: Pipeline Parallelism Test # 45min +- label: Pipeline + Context Parallelism Test # 45min + timeout_in_minutes: 60 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 4 @@ -796,7 +939,8 @@ steps: - pytest -v -s distributed/test_pp_cudagraph.py - pytest -v -s distributed/test_pipeline_parallel.py -- label: LoRA TP Test (Distributed) +- label: LoRA TP Test (Distributed) # 17 min + timeout_in_minutes: 30 mirror_hardwares: [amdexperimental] num_gpus: 4 source_file_dependencies: @@ -814,9 +958,10 @@ steps: - label: Weight Loading Multiple GPU Test # 33min + timeout_in_minutes: 45 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" - num_gpus: 2 + num_gpus: 2 optional: true source_file_dependencies: - vllm/ @@ -866,9 +1011,21 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4 -- label: Qwen MoE EP Test # optional +##### H200 test ##### +- label: Distributed Tests (H200) # optional gpu: h200 optional: true + working_dir: "/vllm-workspace/" num_gpus: 2 commands: - - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 /vllm-workspace/examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 + - pytest -v -s tests/distributed/test_context_parallel.py + - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 + +##### B200 test ##### +- label: Distributed Tests (B200) # optional + gpu: b200 + optional: true + working_dir: "/vllm-workspace/" + num_gpus: 2 + commands: + - pytest -v -s tests/distributed/test_context_parallel.py diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000000000..bc6342956109b --- /dev/null +++ b/.coveragerc @@ -0,0 +1,32 @@ +[run] +source = vllm +omit = + */tests/* + */test_* + */__pycache__/* + */build/* + */dist/* + */vllm.egg-info/* + */third_party/* + */examples/* + */benchmarks/* + */docs/* + +[report] +exclude_lines = + pragma: no cover + def __repr__ + if self.debug: + if settings.DEBUG + raise AssertionError + raise NotImplementedError + if 0: + if __name__ == .__main__.: + class .*\bProtocol\): + @(abc\.)?abstractmethod + +[html] +directory = htmlcov + +[xml] +output = coverage.xml diff --git a/.github/.bc-linter.yml b/.github/.bc-linter.yml new file mode 100644 index 0000000000000..443dfa45af22c --- /dev/null +++ 
b/.github/.bc-linter.yml @@ -0,0 +1,24 @@ +# doc: https://github.com/pytorch/test-infra/blob/main/tools/stronghold/docs/bc_linter_config.md +version: 1 +paths: +# We temporarily disable globally, and will only enable with `annotations.include` +# include: +# - "vllm/v1/attetion/*.py" +# - "vllm/v1/core/*.py" +exclude: + - "**/*.py" + +scan: + functions: true # check free functions and methods + classes: true # check classes/dataclasses + public_only: true # ignore names starting with "_" at any level + +annotations: + include: # decorators that force‑include a symbol + - name: "bc_linter_include" # matched by simple name or dotted suffix + propagate_to_members: false # for classes, include methods/inner classes + exclude: # decorators that force‑exclude a symbol + - name: "bc_linter_skip" # matched by simple name or dotted suffix + propagate_to_members: true # for classes, exclude methods/inner classes + +excluded_violations: [] # e.g. ["ParameterRenamed", "FieldTypeChanged"] diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index c087fd555c661..08717cdde643a 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -2,21 +2,27 @@ # for more info about CODEOWNERS file # This lists cover the "core" components of vLLM that require careful review +/vllm/attention @LucasWilkinson /vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill /vllm/core @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill /vllm/engine/llm_engine.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill -/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill -/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill +/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn +/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn /vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill -/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill +/vllm/model_executor/layers/fused_moe @mgoin +/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @NickLucche /vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 /vllm/model_executor/layers/mamba @tdoublep -/vllm/multimodal @DarkLight1337 @ywang96 +/vllm/model_executor/model_loader @22quinn +/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche +/vllm/v1/attention @LucasWilkinson +/vllm/v1/sample @22quinn @houseroad /vllm/vllm_flash_attn @LucasWilkinson /vllm/lora @jeejeelee -/vllm/reasoning @aarnphm -/vllm/entrypoints @aarnphm +/vllm/reasoning @aarnphm @chaunceyjiang +/vllm/entrypoints @aarnphm @chaunceyjiang /vllm/compilation @zou3519 @youkaichao @ProExpertProg +/vllm/distributed/kv_transfer @NickLucche @ApostaC CMakeLists.txt @tlrmchlsmth @LucasWilkinson # Any change to the VllmConfig changes can have a large user-facing impact, @@ -25,27 +31,39 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson # vLLM V1 /vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat -/vllm/v1/structured_output @mgoin @russellb @aarnphm +/vllm/v1/structured_output @mgoin @russellb @aarnphm @benchislett +/vllm/v1/spec_decode @benchislett @luccafong +/vllm/v1/attention/backends/flashinfer.py @mgoin /vllm/v1/attention/backends/triton_attn.py @tdoublep +/vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat 
@heheda12345 @ApostaC +/vllm/v1/kv_cache_interface.py @heheda12345 +/vllm/v1/offloading @ApostaC # Test ownership /.buildkite/lm-eval-harness @mgoin @simon-mo -/tests/async_engine @njhill @robertgshaw2-redhat @simon-mo /tests/distributed/test_multi_node_assignment.py @youkaichao /tests/distributed/test_pipeline_parallel.py @youkaichao /tests/distributed/test_same_node.py @youkaichao -/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo @aarnphm -/tests/kernels @tlrmchlsmth @WoosukKwon @yewentao256 +/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo @aarnphm @NickLucche +/tests/evals @mgoin +/tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256 /tests/models @DarkLight1337 @ywang96 -/tests/multimodal @DarkLight1337 @ywang96 -/tests/prefix_caching @comaniac @KuntaiDu +/tests/multimodal @DarkLight1337 @ywang96 @NickLucche /tests/quantization @mgoin @robertgshaw2-redhat @yewentao256 /tests/test_inputs.py @DarkLight1337 @ywang96 /tests/v1/entrypoints/llm/test_struct_output_generate.py @mgoin @russellb @aarnphm /tests/v1/structured_output @mgoin @russellb @aarnphm +/tests/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat @heheda12345 @ApostaC /tests/weight_loading @mgoin @youkaichao @yewentao256 /tests/lora @jeejeelee /tests/models/language/generation/test_hybrid.py @tdoublep +/tests/v1/kv_connector/nixl_integration @NickLucche +/tests/v1/kv_connector @ApostaC +/tests/v1/offloading @ApostaC + +# Transformers backend +/vllm/model_executor/models/transformers.py @hmellor +/tests/models/test_transformers.py @hmellor # Docs /docs @hmellor @@ -67,6 +85,9 @@ mkdocs.yaml @hmellor /vllm/attention/backends/dual_chunk_flash_attn.py @sighingnow /vllm/model_executor/models/qwen* @sighingnow +# MTP-specific files +/vllm/model_executor/models/deepseek_mtp.py @luccafong + # Mistral-specific files /vllm/model_executor/models/mistral*.py @patrickvonplaten /vllm/model_executor/models/mixtral*.py @patrickvonplaten @@ -86,3 +107,11 @@ mkdocs.yaml @hmellor /vllm/attention/ops/rocm*.py @gshtras /vllm/model_executor/layers/fused_moe/rocm*.py @gshtras +# TPU +/vllm/v1/worker/tpu* @NickLucche +/vllm/platforms/tpu.py @NickLucche +/vllm/v1/sample/tpu @NickLucche +/vllm/tests/v1/tpu @NickLucche + +# KVConnector installation files +/requirements/kv_connectors.txt @NickLucche diff --git a/.github/mergify.yml b/.github/mergify.yml index 495d207d44260..75ee3e3c55b46 100644 --- a/.github/mergify.yml +++ b/.github/mergify.yml @@ -124,9 +124,16 @@ pull_request_rules: - or: - files~=^examples/.*gpt[-_]?oss.*\.py - files~=^tests/.*gpt[-_]?oss.*\.py + - files~=^tests/entrypoints/openai/test_response_api_with_harmony.py + - files~=^tests/entrypoints/test_context.py - files~=^vllm/model_executor/models/.*gpt[-_]?oss.*\.py - files~=^vllm/model_executor/layers/.*gpt[-_]?oss.*\.py + - files~=^vllm/entrypoints/harmony_utils.py + - files~=^vllm/entrypoints/tool_server.py + - files~=^vllm/entrypoints/tool.py + - files~=^vllm/entrypoints/context.py - title~=(?i)gpt[-_]?oss + - title~=(?i)harmony actions: label: add: @@ -164,7 +171,7 @@ pull_request_rules: - files=examples/online_serving/openai_chat_completion_structured_outputs.py - files=examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py - files~=^tests/v1/structured_output/ - - files=tests/v1/entrypoints/llm/test_guided_generate.py + - files=tests/v1/entrypoints/llm/test_struct_output_generate.py - files~=^vllm/v1/structured_output/ actions: label: @@ -273,6 +280,20 @@ pull_request_rules: 
users: - "sangstar" +- name: assign reviewer for modelopt changes + conditions: + - or: + - files~=^vllm/model_executor/layers/quantization/modelopt\.py$ + - files~=^vllm/model_executor/layers/quantization/__init__\.py$ + - files~=^tests/models/quantization/test_modelopt\.py$ + - files~=^tests/quantization/test_modelopt\.py$ + - files~=^tests/models/quantization/test_nvfp4\.py$ + - files~=^docs/features/quantization/modelopt\.md$ + actions: + assign: + users: + - "Edwardf0t1" + - name: remove 'needs-rebase' label when conflict is resolved conditions: - -conflict @@ -281,3 +302,20 @@ pull_request_rules: label: remove: - needs-rebase + +- name: label-kv-connector + description: Automatically apply kv-connector label + conditions: + - or: + - files~=^examples/online_serving/disaggregated[^/]*/.* + - files~=^examples/offline_inference/disaggregated[^/]*/.* + - files~=^examples/others/lmcache/ + - files~=^tests/v1/kv_connector/ + - files~=^vllm/distributed/kv_transfer/ + - title~=(?i)\bP/?D\b + - title~=(?i)NIXL + - title~=(?i)LMCache + actions: + label: + add: + - kv-connector \ No newline at end of file diff --git a/.github/workflows/add_label_automerge.yml b/.github/workflows/add_label_automerge.yml index 315042fbf5cf4..d8bbedef3174b 100644 --- a/.github/workflows/add_label_automerge.yml +++ b/.github/workflows/add_label_automerge.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Add label - uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: script: | github.rest.issues.addLabels({ diff --git a/.github/workflows/bc-lint.yml b/.github/workflows/bc-lint.yml new file mode 100644 index 0000000000000..823695a921321 --- /dev/null +++ b/.github/workflows/bc-lint.yml @@ -0,0 +1,29 @@ +name: BC Lint + +on: + pull_request: + types: + - opened + - synchronize + - reopened + - labeled + - unlabeled + +jobs: + bc_lint: + if: github.repository_owner == 'vllm-project' + runs-on: ubuntu-latest + steps: + - name: Run BC Lint Action + uses: pytorch/test-infra/.github/actions/bc-lint@main + with: + repo: ${{ github.event.pull_request.head.repo.full_name }} + base_sha: ${{ github.event.pull_request.base.sha }} + head_sha: ${{ github.event.pull_request.head.sha }} + suppression: ${{ contains(github.event.pull_request.labels.*.name, 'suppress-bc-linter') }} + docs_link: 'https://github.com/pytorch/test-infra/wiki/BC-Linter' + config_dir: .github + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} + cancel-in-progress: true diff --git a/.github/workflows/cleanup_pr_body.yml b/.github/workflows/cleanup_pr_body.yml index d5c6b8d43a6ef..c3e132a536a42 100644 --- a/.github/workflows/cleanup_pr_body.yml +++ b/.github/workflows/cleanup_pr_body.yml @@ -16,7 +16,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: '3.12' diff --git a/.github/workflows/issue_autolabel.yml b/.github/workflows/issue_autolabel.yml index e0ab3872d8fa3..c2b17abe811cd 100644 --- a/.github/workflows/issue_autolabel.yml +++ b/.github/workflows/issue_autolabel.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Label issues based on keywords - uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + uses: 
actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: script: | // Configuration: Add new labels and keywords here diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 195579f206a2f..e21d13b8161f3 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -17,7 +17,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: "3.12" - run: echo "::add-matcher::.github/workflows/matchers/actionlint.json" diff --git a/.github/workflows/reminder_comment.yml b/.github/workflows/reminder_comment.yml index 1ee605dc7bb0d..8884359fa0ce4 100644 --- a/.github/workflows/reminder_comment.yml +++ b/.github/workflows/reminder_comment.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Remind to run full CI on PR - uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: script: | try { diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 656f3d3fa7bc4..82844810a633a 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -13,7 +13,7 @@ jobs: actions: write runs-on: ubuntu-latest steps: - - uses: actions/stale@5bef64f19d7facfb25b37b414482c7164d639639 # v9.1.0 + - uses: actions/stale@3a9db7e6a41a89f618792c92c0e97cc736e1b13f # v10.0.0 with: # Increasing this value ensures that changes to this workflow # propagate to all issues and PRs in days rather than months diff --git a/.gitignore b/.gitignore index 465935d488f84..b1df673e83ca8 100644 --- a/.gitignore +++ b/.gitignore @@ -4,7 +4,7 @@ # vllm-flash-attn built from source vllm/vllm_flash_attn/* -# triton jit +# triton jit .triton # Byte-compiled / optimized / DLL files @@ -177,6 +177,14 @@ cython_debug/ # VSCode .vscode/ +# Claude +CLAUDE.md +.claude/ + +# Codex +AGENTS.md +.codex/ + # DS Store .DS_Store @@ -209,4 +217,4 @@ shellcheck*/ csrc/moe/marlin_moe_wna16/kernel_* # Ignore ep_kernels_workspace folder -ep_kernels_workspace/ \ No newline at end of file +ep_kernels_workspace/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c16bdeeecd07a..13ad3af97d839 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -164,9 +164,7 @@ repos: name: Validate configuration has default values and that each field has a docstring entry: python tools/validate_config.py language: python - types: [python] - pass_filenames: true - files: vllm/config.py|tests/test_config.py|vllm/entrypoints/openai/cli_args.py + additional_dependencies: [regex] # Keep `suggestion` last - id: suggestion name: Suggestion diff --git a/.yapfignore b/.yapfignore index 2d6dcf8380cac..38158259032a6 100644 --- a/.yapfignore +++ b/.yapfignore @@ -1 +1,2 @@ collect_env.py +vllm/model_executor/layers/fla/ops/*.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 3f1f9a781a07a..180b896a7abac 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,6 +13,10 @@ cmake_minimum_required(VERSION 3.26) # cmake --install . --component _C project(vllm_extensions LANGUAGES CXX) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + + # CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... 
(used by setup.py) set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM") message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") @@ -171,6 +175,16 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}") endif() +# +# Set CUDA include flags for CXX compiler. +# +if(VLLM_GPU_LANG STREQUAL "CUDA") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I${CUDA_TOOLKIT_ROOT_DIR}/include") + if(CUDA_VERSION VERSION_GREATER_EQUAL 13.0) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I${CUDA_TOOLKIT_ROOT_DIR}/include/cccl") + endif() +endif() + # # Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process. # setup.py will override FETCHCONTENT_BASE_DIR to play nicely with sccache. @@ -294,7 +308,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu" "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" "csrc/cutlass_extensions/common.cpp" - "csrc/attention/mla/cutlass_mla_entry.cu" "csrc/quantization/fp8/per_token_group_quant.cu") set_gencode_flags_for_srcs( @@ -581,7 +594,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS) set(SRCS - "csrc/attention/mla/cutlass_mla_kernels.cu" "csrc/attention/mla/sm100_cutlass_mla_kernel.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" @@ -779,6 +791,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() + # Hadacore kernels + cuda_archs_loose_intersection(HADACORE_ARCHS "8.0;8.9;9.0" "${CUDA_ARCHS}") + if(HADACORE_ARCHS) + set(SRCS "csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${HADACORE_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + message(STATUS "Building hadacore") + endif() + # if CUDA endif endif() diff --git a/MANIFEST.in b/MANIFEST.in index 82fd22b845f09..fb3cccbb4a9c1 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,7 +2,6 @@ include LICENSE include requirements/common.txt include requirements/cuda.txt include requirements/rocm.txt -include requirements/neuron.txt include requirements/cpu.txt include CMakeLists.txt diff --git a/README.md b/README.md index 8812aac4ea266..0c6e5aa6b31d2 100644 --- a/README.md +++ b/README.md @@ -14,19 +14,24 @@ Easy, fast, and cheap LLM serving for everyone | Documentation | Blog | Paper | Twitter/X | User Forum | Developer Slack |

+--- +Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundation.org/pytorch-conference/) and [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco for our latest updates on vLLM and to meet the vLLM team! Register now for the largest vLLM community events of the year! + --- *Latest News* 🔥 +- [2025/08] We hosted [vLLM Shenzhen Meetup](https://mp.weixin.qq.com/s/k8ZBO1u2_2odgiKWH_GVTQ) focusing on the ecosystem around vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Ua2SVKVSu-wp5vou_6ElraDt2bnKhiEA). +- [2025/08] We hosted [vLLM Singapore Meetup](https://www.sginnovate.com/event/vllm-sg-meet). We shared V1 updates, disaggregated serving and MLLM speedups with speakers from Embedded LLM, AMD, WekaIO, and A*STAR. Please find the meetup slides [here](https://drive.google.com/drive/folders/1ncf3GyqLdqFaB6IeB834E5TZJPLAOiXZ?usp=sharing). - [2025/08] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg) focusing on building, developing, and integrating with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH). -- [2025/08] We hosted [vLLM Korea Meetup](https://luma.com/cgcgprmh) with Red Hat and Rebellions! We shared the latest advancements in vLLM along with project spotlights from the vLLM Korea community. Please find the meetup slides [here](https://drive.google.com/file/d/1bcrrAE1rxUgx0mjIeOWT6hNe2RefC5Hm/view). -- [2025/08] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA) focusing on large-scale LLM deployment! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) and the recording [here](https://www.chaspark.com/#/live/1166916873711665152). - [2025/05] vLLM is now a hosted project under PyTorch Foundation! Please find the announcement [here](https://pytorch.org/blog/pytorch-foundation-welcomes-vllm/). - [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).
Previous News +- [2025/08] We hosted [vLLM Korea Meetup](https://luma.com/cgcgprmh) with Red Hat and Rebellions! We shared the latest advancements in vLLM along with project spotlights from the vLLM Korea community. Please find the meetup slides [here](https://drive.google.com/file/d/1bcrrAE1rxUgx0mjIeOWT6hNe2RefC5Hm/view). +- [2025/08] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA) focusing on large-scale LLM deployment! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) and the recording [here](https://www.chaspark.com/#/live/1166916873711665152). - [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing). - [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing). - [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing). @@ -76,7 +81,7 @@ vLLM is flexible and easy to use with: - Tensor, pipeline, data and expert parallelism support for distributed inference - Streaming outputs - OpenAI-compatible API server -- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron +- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend. - Prefix caching support - Multi-LoRA support diff --git a/benchmarks/README.md b/benchmarks/README.md index 38072152b653b..269a4d51ec2ef 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -1,802 +1,20 @@ -# Benchmarking vLLM +# Benchmarks -This README guides you through running benchmark tests with the extensive -datasets supported on vLLM. It’s a living document, updated as new features and datasets -become available. +This directory used to contain vLLM's benchmark scripts and utilities for performance testing and evaluation. -## Dataset Overview +## Contents - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
-| Dataset | Online | Offline | Data Path |
-|---------|--------|---------|-----------|
-| ShareGPT | ✅ | ✅ | `wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json` |
-| ShareGPT4V (Image) | ✅ | ✅ | `wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/blob/main/sharegpt4v_instruct_gpt4-vision_cap100k.json` Note that the images need to be downloaded separately. For example, to download COCO's 2017 Train images: `wget http://images.cocodataset.org/zips/train2017.zip` |
-| ShareGPT4Video (Video) | ✅ | ✅ | `git clone https://huggingface.co/datasets/ShareGPT4Video/ShareGPT4Video` |
-| BurstGPT | ✅ | ✅ | `wget https://github.com/HPMLL/BurstGPT/releases/download/v1.1/BurstGPT_without_fails_2.csv` |
-| Sonnet (deprecated) | ✅ | ✅ | Local file: `benchmarks/sonnet.txt` |
-| Random | ✅ | ✅ | synthetic |
-| RandomMultiModal (Image/Video) | 🟡 | 🚧 | synthetic |
-| Prefix Repetition | ✅ | ✅ | synthetic |
-| HuggingFace-VisionArena | ✅ | ✅ | `lmarena-ai/VisionArena-Chat` |
-| HuggingFace-InstructCoder | ✅ | ✅ | `likaixin/InstructCoder` |
-| HuggingFace-AIMO | ✅ | ✅ | `AI-MO/aimo-validation-aime`, `AI-MO/NuminaMath-1.5`, `AI-MO/NuminaMath-CoT` |
-| HuggingFace-Other | ✅ | ✅ | `lmms-lab/LLaVA-OneVision-Data`, `Aeala/ShareGPT_Vicuna_unfiltered` |
-| Custom | ✅ | ✅ | Local file: `data.jsonl` |
+- **Serving benchmarks**: Scripts for testing online inference performance (latency, throughput) +- **Throughput benchmarks**: Scripts for testing offline batch inference performance +- **Specialized benchmarks**: Tools for testing specific features like structured output, prefix caching, long document QA, request prioritization, and multi-modal inference +- **Dataset utilities**: Framework for loading and sampling from various benchmark datasets (ShareGPT, HuggingFace datasets, synthetic data, etc.) -✅: supported +## Usage -🟡: Partial support +For detailed usage instructions, examples, and dataset information, see the [Benchmark CLI documentation](https://docs.vllm.ai/en/latest/contributing/benchmarks.html#benchmark-cli). -🚧: to be supported +For full CLI reference see: -**Note**: HuggingFace dataset's `dataset-name` should be set to `hf` - -## 🚀 Example - Online Benchmark - -
-<details>
-<summary>Show more</summary>
-
-<br>
- -First start serving your model - -```bash -vllm serve NousResearch/Hermes-3-Llama-3.1-8B -``` - -Then run the benchmarking script - -```bash -# download dataset -# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -vllm bench serve \ - --backend vllm \ - --model NousResearch/Hermes-3-Llama-3.1-8B \ - --endpoint /v1/completions \ - --dataset-name sharegpt \ - --dataset-path /ShareGPT_V3_unfiltered_cleaned_split.json \ - --num-prompts 10 -``` - -If successful, you will see the following output - -```text -============ Serving Benchmark Result ============ -Successful requests: 10 -Benchmark duration (s): 5.78 -Total input tokens: 1369 -Total generated tokens: 2212 -Request throughput (req/s): 1.73 -Output token throughput (tok/s): 382.89 -Total Token throughput (tok/s): 619.85 ----------------Time to First Token---------------- -Mean TTFT (ms): 71.54 -Median TTFT (ms): 73.88 -P99 TTFT (ms): 79.49 ------Time per Output Token (excl. 1st token)------ -Mean TPOT (ms): 7.91 -Median TPOT (ms): 7.96 -P99 TPOT (ms): 8.03 ----------------Inter-token Latency---------------- -Mean ITL (ms): 7.74 -Median ITL (ms): 7.70 -P99 ITL (ms): 8.39 -================================================== -``` - -### Custom Dataset - -If the dataset you want to benchmark is not supported yet in vLLM, even then you can benchmark on it using `CustomDataset`. Your data needs to be in `.jsonl` format and needs to have "prompt" field per entry, e.g., data.jsonl - -```json -{"prompt": "What is the capital of India?"} -{"prompt": "What is the capital of Iran?"} -{"prompt": "What is the capital of China?"} -``` - -```bash -# start server -VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct -``` - -```bash -# run benchmarking script -vllm bench serve --port 9001 --save-result --save-detailed \ - --backend vllm \ - --model meta-llama/Llama-3.1-8B-Instruct \ - --endpoint /v1/completions \ - --dataset-name custom \ - --dataset-path \ - --custom-skip-chat-template \ - --num-prompts 80 \ - --max-concurrency 1 \ - --temperature=0.3 \ - --top-p=0.75 \ - --result-dir "./log/" -``` - -You can skip applying chat template if your data already has it by using `--custom-skip-chat-template`. 
- -### VisionArena Benchmark for Vision Language Models - -```bash -# need a model with vision capability here -vllm serve Qwen/Qwen2-VL-7B-Instruct -``` - -```bash -vllm bench serve \ - --backend openai-chat \ - --endpoint-type openai-chat \ - --model Qwen/Qwen2-VL-7B-Instruct \ - --endpoint /v1/chat/completions \ - --dataset-name hf \ - --dataset-path lmarena-ai/VisionArena-Chat \ - --hf-split train \ - --num-prompts 1000 -``` - -### InstructCoder Benchmark with Speculative Decoding - -``` bash -VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \ - --speculative-config $'{"method": "ngram", - "num_speculative_tokens": 5, "prompt_lookup_max": 5, - "prompt_lookup_min": 2}' -``` - -``` bash -vllm bench serve \ - --model meta-llama/Meta-Llama-3-8B-Instruct \ - --dataset-name hf \ - --dataset-path likaixin/InstructCoder \ - --num-prompts 2048 -``` - -### Other HuggingFaceDataset Examples - -```bash -vllm serve Qwen/Qwen2-VL-7B-Instruct -``` - -`lmms-lab/LLaVA-OneVision-Data`: - -```bash -vllm bench serve \ - --backend openai-chat \ - --endpoint-type openai-chat \ - --model Qwen/Qwen2-VL-7B-Instruct \ - --endpoint /v1/chat/completions \ - --dataset-name hf \ - --dataset-path lmms-lab/LLaVA-OneVision-Data \ - --hf-split train \ - --hf-subset "chart2text(cauldron)" \ - --num-prompts 10 -``` - -`Aeala/ShareGPT_Vicuna_unfiltered`: - -```bash -vllm bench serve \ - --backend openai-chat \ - --endpoint-type openai-chat \ - --model Qwen/Qwen2-VL-7B-Instruct \ - --endpoint /v1/chat/completions \ - --dataset-name hf \ - --dataset-path Aeala/ShareGPT_Vicuna_unfiltered \ - --hf-split train \ - --num-prompts 10 -``` - -`AI-MO/aimo-validation-aime`: - -``` bash -vllm bench serve \ - --model Qwen/QwQ-32B \ - --dataset-name hf \ - --dataset-path AI-MO/aimo-validation-aime \ - --num-prompts 10 \ - --seed 42 -``` - -`philschmid/mt-bench`: - -``` bash -vllm bench serve \ - --model Qwen/QwQ-32B \ - --dataset-name hf \ - --dataset-path philschmid/mt-bench \ - --num-prompts 80 -``` - -### Running With Sampling Parameters - -When using OpenAI-compatible backends such as `vllm`, optional sampling -parameters can be specified. Example client command: - -```bash -vllm bench serve \ - --backend vllm \ - --model NousResearch/Hermes-3-Llama-3.1-8B \ - --endpoint /v1/completions \ - --dataset-name sharegpt \ - --dataset-path /ShareGPT_V3_unfiltered_cleaned_split.json \ - --top-k 10 \ - --top-p 0.9 \ - --temperature 0.5 \ - --num-prompts 10 -``` - -### Running With Ramp-Up Request Rate - -The benchmark tool also supports ramping up the request rate over the -duration of the benchmark run. This can be useful for stress testing the -server or finding the maximum throughput that it can handle, given some latency budget. - -Two ramp-up strategies are supported: - -- `linear`: Increases the request rate linearly from a start value to an end value. -- `exponential`: Increases the request rate exponentially. - -The following arguments can be used to control the ramp-up: - -- `--ramp-up-strategy`: The ramp-up strategy to use (`linear` or `exponential`). -- `--ramp-up-start-rps`: The request rate at the beginning of the benchmark. -- `--ramp-up-end-rps`: The request rate at the end of the benchmark. - -
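Putting the pieces together, a linear ramp-up run could look like the sketch below; the model, dataset, and rate values are illustrative (reused from the ShareGPT example above), and only the three `--ramp-up-*` flags listed here are the point:

```bash
# Sketch: ramp the request rate linearly from 1 RPS to 10 RPS over the run.
vllm bench serve \
  --backend vllm \
  --model NousResearch/Hermes-3-Llama-3.1-8B \
  --endpoint /v1/completions \
  --dataset-name sharegpt \
  --dataset-path <path to ShareGPT_V3_unfiltered_cleaned_split.json> \
  --num-prompts 1000 \
  --ramp-up-strategy linear \
  --ramp-up-start-rps 1 \
  --ramp-up-end-rps 10
```

Swapping `linear` for `exponential` keeps the same start and end rates but applies an exponential schedule instead.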
-</details>
-
-## 📈 Example - Offline Throughput Benchmark
-
-<details>
-<summary>Show more</summary>
-
-<br>
- -```bash -vllm bench throughput \ - --model NousResearch/Hermes-3-Llama-3.1-8B \ - --dataset-name sonnet \ - --dataset-path vllm/benchmarks/sonnet.txt \ - --num-prompts 10 -``` - -If successful, you will see the following output - -```text -Throughput: 7.15 requests/s, 4656.00 total tokens/s, 1072.15 output tokens/s -Total num prompt tokens: 5014 -Total num output tokens: 1500 -``` - -### VisionArena Benchmark for Vision Language Models - -```bash -vllm bench throughput \ - --model Qwen/Qwen2-VL-7B-Instruct \ - --backend vllm-chat \ - --dataset-name hf \ - --dataset-path lmarena-ai/VisionArena-Chat \ - --num-prompts 1000 \ - --hf-split train -``` - -The `num prompt tokens` now includes image token counts - -```text -Throughput: 2.55 requests/s, 4036.92 total tokens/s, 326.90 output tokens/s -Total num prompt tokens: 14527 -Total num output tokens: 1280 -``` - -### InstructCoder Benchmark with Speculative Decoding - -``` bash -VLLM_WORKER_MULTIPROC_METHOD=spawn \ -VLLM_USE_V1=1 \ -vllm bench throughput \ - --dataset-name=hf \ - --dataset-path=likaixin/InstructCoder \ - --model=meta-llama/Meta-Llama-3-8B-Instruct \ - --input-len=1000 \ - --output-len=100 \ - --num-prompts=2048 \ - --async-engine \ - --speculative-config $'{"method": "ngram", - "num_speculative_tokens": 5, "prompt_lookup_max": 5, - "prompt_lookup_min": 2}' -``` - -```text -Throughput: 104.77 requests/s, 23836.22 total tokens/s, 10477.10 output tokens/s -Total num prompt tokens: 261136 -Total num output tokens: 204800 -``` - -### Other HuggingFaceDataset Examples - -`lmms-lab/LLaVA-OneVision-Data`: - -```bash -vllm bench throughput \ - --model Qwen/Qwen2-VL-7B-Instruct \ - --backend vllm-chat \ - --dataset-name hf \ - --dataset-path lmms-lab/LLaVA-OneVision-Data \ - --hf-split train \ - --hf-subset "chart2text(cauldron)" \ - --num-prompts 10 -``` - -`Aeala/ShareGPT_Vicuna_unfiltered`: - -```bash -vllm bench throughput \ - --model Qwen/Qwen2-VL-7B-Instruct \ - --backend vllm-chat \ - --dataset-name hf \ - --dataset-path Aeala/ShareGPT_Vicuna_unfiltered \ - --hf-split train \ - --num-prompts 10 -``` - -`AI-MO/aimo-validation-aime`: - -```bash -vllm bench throughput \ - --model Qwen/QwQ-32B \ - --backend vllm \ - --dataset-name hf \ - --dataset-path AI-MO/aimo-validation-aime \ - --hf-split train \ - --num-prompts 10 -``` - -Benchmark with LoRA adapters: - -``` bash -# download dataset -# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -vllm bench throughput \ - --model meta-llama/Llama-2-7b-hf \ - --backend vllm \ - --dataset_path /ShareGPT_V3_unfiltered_cleaned_split.json \ - --dataset_name sharegpt \ - --num-prompts 10 \ - --max-loras 2 \ - --max-lora-rank 8 \ - --enable-lora \ - --lora-path yard1/llama-2-7b-sql-lora-test - ``` - -
- -## 🛠️ Example - Structured Output Benchmark - -
-Show more - -
- -Benchmark the performance of structured output generation (JSON, grammar, regex). - -### Server Setup - -```bash -vllm serve NousResearch/Hermes-3-Llama-3.1-8B -``` - -### JSON Schema Benchmark - -```bash -python3 benchmarks/benchmark_serving_structured_output.py \ - --backend vllm \ - --model NousResearch/Hermes-3-Llama-3.1-8B \ - --dataset json \ - --structured-output-ratio 1.0 \ - --request-rate 10 \ - --num-prompts 1000 -``` - -### Grammar-based Generation Benchmark - -```bash -python3 benchmarks/benchmark_serving_structured_output.py \ - --backend vllm \ - --model NousResearch/Hermes-3-Llama-3.1-8B \ - --dataset grammar \ - --structure-type grammar \ - --request-rate 10 \ - --num-prompts 1000 -``` - -### Regex-based Generation Benchmark - -```bash -python3 benchmarks/benchmark_serving_structured_output.py \ - --backend vllm \ - --model NousResearch/Hermes-3-Llama-3.1-8B \ - --dataset regex \ - --request-rate 10 \ - --num-prompts 1000 -``` - -### Choice-based Generation Benchmark - -```bash -python3 benchmarks/benchmark_serving_structured_output.py \ - --backend vllm \ - --model NousResearch/Hermes-3-Llama-3.1-8B \ - --dataset choice \ - --request-rate 10 \ - --num-prompts 1000 -``` - -### XGrammar Benchmark Dataset - -```bash -python3 benchmarks/benchmark_serving_structured_output.py \ - --backend vllm \ - --model NousResearch/Hermes-3-Llama-3.1-8B \ - --dataset xgrammar_bench \ - --request-rate 10 \ - --num-prompts 1000 -``` - -
- -## 📚 Example - Long Document QA Benchmark - -
-Show more - -
- -Benchmark the performance of long document question-answering with prefix caching. - -### Basic Long Document QA Test - -```bash -python3 benchmarks/benchmark_long_document_qa_throughput.py \ - --model meta-llama/Llama-2-7b-chat-hf \ - --enable-prefix-caching \ - --num-documents 16 \ - --document-length 2000 \ - --output-len 50 \ - --repeat-count 5 -``` - -### Different Repeat Modes - -```bash -# Random mode (default) - shuffle prompts randomly -python3 benchmarks/benchmark_long_document_qa_throughput.py \ - --model meta-llama/Llama-2-7b-chat-hf \ - --enable-prefix-caching \ - --num-documents 8 \ - --document-length 3000 \ - --repeat-count 3 \ - --repeat-mode random - -# Tile mode - repeat entire prompt list in sequence -python3 benchmarks/benchmark_long_document_qa_throughput.py \ - --model meta-llama/Llama-2-7b-chat-hf \ - --enable-prefix-caching \ - --num-documents 8 \ - --document-length 3000 \ - --repeat-count 3 \ - --repeat-mode tile - -# Interleave mode - repeat each prompt consecutively -python3 benchmarks/benchmark_long_document_qa_throughput.py \ - --model meta-llama/Llama-2-7b-chat-hf \ - --enable-prefix-caching \ - --num-documents 8 \ - --document-length 3000 \ - --repeat-count 3 \ - --repeat-mode interleave -``` - -
- -## 🗂️ Example - Prefix Caching Benchmark - -
-Show more - -
- -Benchmark the efficiency of automatic prefix caching. - -### Fixed Prompt with Prefix Caching - -```bash -python3 benchmarks/benchmark_prefix_caching.py \ - --model meta-llama/Llama-2-7b-chat-hf \ - --enable-prefix-caching \ - --num-prompts 1 \ - --repeat-count 100 \ - --input-length-range 128:256 -``` - -### ShareGPT Dataset with Prefix Caching - -```bash -# download dataset -# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json - -python3 benchmarks/benchmark_prefix_caching.py \ - --model meta-llama/Llama-2-7b-chat-hf \ - --dataset-path /path/ShareGPT_V3_unfiltered_cleaned_split.json \ - --enable-prefix-caching \ - --num-prompts 20 \ - --repeat-count 5 \ - --input-length-range 128:256 -``` - -### Prefix Repetition Dataset - -```bash -vllm bench serve \ - --backend openai \ - --model meta-llama/Llama-2-7b-chat-hf \ - --dataset-name prefix_repetition \ - --num-prompts 100 \ - --prefix-repetition-prefix-len 512 \ - --prefix-repetition-suffix-len 128 \ - --prefix-repetition-num-prefixes 5 \ - --prefix-repetition-output-len 128 -``` - -
- -## ⚡ Example - Request Prioritization Benchmark - -
-Show more - -
- -Benchmark the performance of request prioritization in vLLM. - -### Basic Prioritization Test - -```bash -python3 benchmarks/benchmark_prioritization.py \ - --model meta-llama/Llama-2-7b-chat-hf \ - --input-len 128 \ - --output-len 64 \ - --num-prompts 100 \ - --scheduling-policy priority -``` - -### Multiple Sequences per Prompt - -```bash -python3 benchmarks/benchmark_prioritization.py \ - --model meta-llama/Llama-2-7b-chat-hf \ - --input-len 128 \ - --output-len 64 \ - --num-prompts 100 \ - --scheduling-policy priority \ - --n 2 -``` - -
- -## 👁️ Example - Multi-Modal Benchmark - -
-Show more - -
- -Benchmark the performance of multi-modal requests in vLLM. - -### Images (ShareGPT4V) - -Start vLLM: - -```bash -python -m vllm.entrypoints.openai.api_server \ - --model Qwen/Qwen2.5-VL-7B-Instruct \ - --dtype bfloat16 \ - --limit-mm-per-prompt '{"image": 1}' \ - --allowed-local-media-path /path/to/sharegpt4v/images -``` - -Send requests with images: - -```bash -python benchmarks/benchmark_serving.py \ - --backend openai-chat \ - --model Qwen/Qwen2.5-VL-7B-Instruct \ - --dataset-name sharegpt \ - --dataset-path /path/to/ShareGPT4V/sharegpt4v_instruct_gpt4-vision_cap100k.json \ - --num-prompts 100 \ - --save-result \ - --result-dir ~/vllm_benchmark_results \ - --save-detailed \ - --endpoint /v1/chat/completion -``` - -### Videos (ShareGPT4Video) - -Start vLLM: - -```bash -python -m vllm.entrypoints.openai.api_server \ - --model Qwen/Qwen2.5-VL-7B-Instruct \ - --dtype bfloat16 \ - --limit-mm-per-prompt '{"video": 1}' \ - --allowed-local-media-path /path/to/sharegpt4video/videos -``` - -Send requests with videos: - -```bash -python benchmarks/benchmark_serving.py \ - --backend openai-chat \ - --model Qwen/Qwen2.5-VL-7B-Instruct \ - --dataset-name sharegpt \ - --dataset-path /path/to/ShareGPT4Video/llava_v1_5_mix665k_with_video_chatgpt72k_share4video28k.json \ - --num-prompts 100 \ - --save-result \ - --result-dir ~/vllm_benchmark_results \ - --save-detailed \ - --endpoint /v1/chat/completion -``` - -### Synthetic Random Images (random-mm) - -Generate synthetic image inputs alongside random text prompts to stress-test vision models without external datasets. - -Notes: - -- Works only with online benchmark via the OpenAI backend (`--backend openai-chat`) and endpoint `/v1/chat/completions`. -- Video sampling is not yet implemented. - -Start the server (example): - -```bash -vllm serve Qwen/Qwen2.5-VL-3B-Instruct \ - --dtype bfloat16 \ - --max-model-len 16384 \ - --limit-mm-per-prompt '{"image": 3, "video": 0}' \ - --mm-processor-kwargs max_pixels=1003520 -``` - -Benchmark. It is recommended to use the flag `--ignore-eos` to simulate real responses. You can set the size of the output via the arg `random-output-len`. - -Ex.1: Fixed number of items and a single image resolution, enforcing generation of approx 40 tokens: - -```bash -vllm bench serve \ - --backend openai-chat \ - --model Qwen/Qwen2.5-VL-3B-Instruct \ - --endpoint /v1/chat/completions \ - --dataset-name random-mm \ - --num-prompts 100 \ - --max-concurrency 10 \ - --random-prefix-len 25 \ - --random-input-len 300 \ - --random-output-len 40 \ - --random-range-ratio 0.2 \ - --random-mm-base-items-per-request 2 \ - --random-mm-limit-mm-per-prompt '{"image": 3, "video": 0}' \ - --random-mm-bucket-config '{(224, 224, 1): 1.0}' \ - --request-rate inf \ - --ignore-eos \ - --seed 42 -``` - -The number of items per request can be controlled by passing multiple image buckets: - -```bash - --random-mm-base-items-per-request 2 \ - --random-mm-num-mm-items-range-ratio 0.5 \ - --random-mm-limit-mm-per-prompt '{"image": 4, "video": 0}' \ - --random-mm-bucket-config '{(256, 256, 1): 0.7, (720, 1280, 1): 0.3}' \ -``` - -Flags specific to `random-mm`: - -- `--random-mm-base-items-per-request`: base number of multimodal items per request. -- `--random-mm-num-mm-items-range-ratio`: vary item count uniformly in the closed integer range [floor(n·(1−r)), ceil(n·(1+r))]. Set r=0 to keep it fixed; r=1 allows 0 items. -- `--random-mm-limit-mm-per-prompt`: per-modality hard caps, e.g. '{"image": 3, "video": 0}'. 
-- `--random-mm-bucket-config`: dict mapping (H, W, T) → probability. Entries with probability 0 are removed; remaining probabilities are renormalized to sum to 1. Use T=1 for images. Set any T>1 for videos (video sampling not yet supported). - -Behavioral notes: - -- If the requested base item count cannot be satisfied under the provided per-prompt limits, the tool raises an error rather than silently clamping. - -How sampling works: - -- Determine per-request item count k by sampling uniformly from the integer range defined by `--random-mm-base-items-per-request` and `--random-mm-num-mm-items-range-ratio`, then clamp k to at most the sum of per-modality limits. -- For each of the k items, sample a bucket (H, W, T) according to the normalized probabilities in `--random-mm-bucket-config`, while tracking how many items of each modality have been added. -- If a modality (e.g., image) reaches its limit from `--random-mm-limit-mm-per-prompt`, all buckets of that modality are excluded and the remaining bucket probabilities are renormalized before continuing. -This should be seen as an edge case, and if this behavior can be avoided by setting `--random-mm-limit-mm-per-prompt` to a large number. Note that this might result in errors due to engine config `--limit-mm-per-prompt`. -- The resulting request contains synthetic image data in `multi_modal_data` (OpenAI Chat format). When `random-mm` is used with the OpenAI Chat backend, prompts remain text and MM content is attached via `multi_modal_data`. - -
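To make the sampling procedure described above concrete, here is a rough standalone Python sketch. It is not the vLLM implementation; the function name, the modality-from-`T` helper, and the overall structure are assumptions made purely for illustration:

```python
import math
import random


def sample_mm_items(base_items, range_ratio, limits, bucket_config, rng=random):
    """Sketch of random-mm item sampling: pick k items, then draw (H, W, T)
    buckets under per-modality limits with probability renormalization."""

    # Modality is inferred from the temporal dimension T: T == 1 -> image, T > 1 -> video.
    def modality(shape):
        return "image" if shape[2] == 1 else "video"

    # 1) Sample the per-request item count k from the closed integer range
    #    [floor(n*(1-r)), ceil(n*(1+r))], clamped by the sum of per-modality limits.
    lo = math.floor(base_items * (1 - range_ratio))
    hi = math.ceil(base_items * (1 + range_ratio))
    k = min(rng.randint(lo, hi), sum(limits.values()))

    counts = {m: 0 for m in limits}  # items drawn so far, per modality
    chosen = []
    for _ in range(k):
        # 2) Drop zero-probability buckets and buckets whose modality has hit
        #    its limit, then renormalize the remaining probabilities.
        live = {
            s: p
            for s, p in bucket_config.items()
            if p > 0 and counts[modality(s)] < limits.get(modality(s), 0)
        }
        if not live:
            break
        total = sum(live.values())
        shapes = list(live)
        weights = [live[s] / total for s in shapes]
        # 3) Sample one (H, W, T) bucket and record its modality.
        shape = rng.choices(shapes, weights=weights, k=1)[0]
        counts[modality(shape)] += 1
        chosen.append(shape)
    return chosen


# Example mirroring the multi-bucket flags shown above.
print(
    sample_mm_items(
        base_items=2,
        range_ratio=0.5,
        limits={"image": 4, "video": 0},
        bucket_config={(256, 256, 1): 0.7, (720, 1280, 1): 0.3},
    )
)
```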
+- +- +- diff --git a/benchmarks/auto_tune/README.md b/benchmarks/auto_tune/README.md index 9aad51df6e003..d1bdb4c43f10b 100644 --- a/benchmarks/auto_tune/README.md +++ b/benchmarks/auto_tune/README.md @@ -31,6 +31,12 @@ cd vllm You must set the following variables at the top of the script before execution. + Note: You can also override the default values below via environment variables when running the script. + +```bash +MODEL=meta-llama/Llama-3.3-70B-Instruct SYSTEM=TPU TP=8 DOWNLOAD_DIR='' INPUT_LEN=128 OUTPUT_LEN=2048 MAX_MODEL_LEN=2300 MIN_CACHE_HIT_PCT=0 MAX_LATENCY_ALLOWED_MS=100000000000 NUM_SEQS_LIST="128 256" NUM_BATCHED_TOKENS_LIST="1024 2048 4096" VLLM_LOGGING_LEVEL=DEBUG bash auto_tune.sh +``` + | Variable | Description | Example Value | | --- | --- | --- | | `BASE` | **Required.** The absolute path to the parent directory of your vLLM repository directory. | `"$HOME"` | @@ -143,3 +149,70 @@ The script follows a systematic process to find the optimal parameters: 4. **Track Best Result**: Throughout the process, the script tracks the parameter combination that has yielded the highest valid throughput so far. 5. **Profile Collection**: For the best-performing run, the script saves the vLLM profiler output, which can be used for deep-dive performance analysis with tools like TensorBoard. + +## Batched `auto_tune` + +The `batch_auto_tune.sh` script allows you to run multiple `auto_tune.sh` experiments sequentially from a single configuration file. It iterates through a list of parameter sets, executes `auto_tune.sh` for each, and records the results back into the input file. + +### Prerequisites + +- **jq**: This script requires `jq` to parse the JSON configuration file. +- **gcloud**: If you plan to upload results to Google Cloud Storage, the `gcloud` CLI must be installed and authenticated. + +### How to Run + +1. **Create a JSON configuration file**: Create a file (e.g., `runs_config.json`) containing an array of JSON objects. Each object defines the parameters for a single `auto_tune.sh` run. + +2. **Execute the script**: + + ```bash + bash batch_auto_tune.sh <config_file> [gcs_upload_path] + ``` + + - `<config_file>`: **Required.** Path to your JSON configuration file. + - `[gcs_upload_path]`: **Optional.** A GCS path (e.g., `gs://my-bucket/benchmark-results`) where the detailed results and profiles for each run will be uploaded. If this is empty, the results will be available on the local filesystem (see the log for `RESULT_FILE=/path/to/results/file.txt`). + +### Configuration File + +The JSON configuration file should contain an array of objects. Each object's keys correspond to the configuration variables for `auto_tune.sh` (see the [Configuration table above](#configuration)). These keys will be converted to uppercase environment variables for each run.
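For instance (an illustrative sketch, not a literal excerpt from the script), a key such as `max_latency_allowed_ms` with value `500` is exported as `MAX_LATENCY_ALLOWED_MS=500`, so a single configuration entry is run roughly as:

```bash
# Simplified sketch of what batch_auto_tune.sh does for one configuration entry:
# each JSON key is upper-cased and passed to auto_tune.sh as an environment variable.
env MODEL=meta-llama/Llama-3.1-8B-Instruct TP=8 MAX_LATENCY_ALLOWED_MS=500 bash auto_tune.sh
```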
+ +Here is an example `runs_config.json` with two benchmark configurations (the `system` value can be either `TPU` or `GPU`): + +```json +[ + { + "base": "/home/user", + "model": "meta-llama/Llama-3.1-8B-Instruct", + "system": "TPU", + "tp": 8, + "input_len": 128, + "output_len": 2048, + "max_model_len": 2300, + "num_seqs_list": "128 256", + "num_batched_tokens_list": "8192 16384" + }, + { + "base": "/home/user", + "model": "meta-llama/Llama-3.1-70B-Instruct", + "system": "TPU", + "tp": 8, + "input_len": 4000, + "output_len": 16, + "max_model_len": 4096, + "num_seqs_list": "64 128", + "num_batched_tokens_list": "4096 8192", + "max_latency_allowed_ms": 500 + } +] +``` + +### Output + +The script modifies the input JSON file in place, adding the results of each run to the corresponding object. The following fields are added: + +- `run_id`: A unique identifier for the run, derived from the timestamp. + - `status`: The outcome of the run (`SUCCESS`, `FAILURE`, or `WARNING_NO_RESULT_FILE`). +- `results`: The content of the `result.txt` file from the `auto_tune.sh` run. +- `gcs_results`: The GCS URL where the run's artifacts are stored (if a GCS path was provided). + +A summary of successful and failed runs is also printed to the console upon completion. diff --git a/benchmarks/auto_tune/auto_tune.sh b/benchmarks/auto_tune/auto_tune.sh index 82c20ffa6554c..ed3679b66f805 100644 --- a/benchmarks/auto_tune/auto_tune.sh +++ b/benchmarks/auto_tune/auto_tune.sh @@ -5,25 +5,41 @@ TAG=$(date +"%Y_%m_%d_%H_%M") SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) -BASE="$SCRIPT_DIR/../../.." -MODEL="meta-llama/Llama-3.1-8B-Instruct" -SYSTEM="TPU" -TP=1 -DOWNLOAD_DIR="" -INPUT_LEN=4000 -OUTPUT_LEN=16 -MAX_MODEL_LEN=4096 -MIN_CACHE_HIT_PCT=0 -MAX_LATENCY_ALLOWED_MS=100000000000 -NUM_SEQS_LIST="128 256" -NUM_BATCHED_TOKENS_LIST="512 1024 2048 4096" +VLLM_LOGGING_LEVEL=${VLLM_LOGGING_LEVEL:-INFO} +BASE=${BASE:-"$SCRIPT_DIR/../../.."} +MODEL=${MODEL:-"meta-llama/Llama-3.1-8B-Instruct"} +SYSTEM=${SYSTEM:-"TPU"} +TP=${TP:-1} +DOWNLOAD_DIR=${DOWNLOAD_DIR:-""} +INPUT_LEN=${INPUT_LEN:-4000} +OUTPUT_LEN=${OUTPUT_LEN:-16} +MAX_MODEL_LEN=${MAX_MODEL_LEN:-4096} +MIN_CACHE_HIT_PCT=${MIN_CACHE_HIT_PCT:-0} +MAX_LATENCY_ALLOWED_MS=${MAX_LATENCY_ALLOWED_MS:-100000000000} +NUM_SEQS_LIST=${NUM_SEQS_LIST:-"128 256"} +NUM_BATCHED_TOKENS_LIST=${NUM_BATCHED_TOKENS_LIST:-"512 1024 2048 4096"} LOG_FOLDER="$BASE/auto-benchmark/$TAG" RESULT="$LOG_FOLDER/result.txt" PROFILE_PATH="$LOG_FOLDER/profile" -echo "result file: $RESULT" -echo "model: $MODEL" +echo "====================== AUTO TUNE PARAMETERS ====================" +echo "SCRIPT_DIR=$SCRIPT_DIR" +echo "BASE=$BASE" +echo "MODEL=$MODEL" +echo "SYSTEM=$SYSTEM" +echo "TP=$TP" +echo "DOWNLOAD_DIR=$DOWNLOAD_DIR" +echo "INPUT_LEN=$INPUT_LEN" +echo "OUTPUT_LEN=$OUTPUT_LEN" +echo "MAX_MODEL_LEN=$MAX_MODEL_LEN" +echo "MIN_CACHE_HIT_PCT=$MIN_CACHE_HIT_PCT" +echo "MAX_LATENCY_ALLOWED_MS=$MAX_LATENCY_ALLOWED_MS" +echo "NUM_SEQS_LIST=$NUM_SEQS_LIST" +echo "NUM_BATCHED_TOKENS_LIST=$NUM_BATCHED_TOKENS_LIST" +echo "VLLM_LOGGING_LEVEL=$VLLM_LOGGING_LEVEL" +echo "RESULT_FILE=$RESULT" +echo "====================== AUTO TUNE PARAMETERS ====================" rm -rf $LOG_FOLDER rm -rf $PROFILE_PATH @@ -213,7 +229,7 @@ run_benchmark() { pkill -if vllm sleep 10 - printf '=%.0s' $(seq 1 20) + echo "====================" return 0 } diff --git a/benchmarks/auto_tune/batch_auto_tune.sh b/benchmarks/auto_tune/batch_auto_tune.sh new file mode 100755 index 0000000000000..57ef20daf6b71 --- /dev/null +++ 
b/benchmarks/auto_tune/batch_auto_tune.sh @@ -0,0 +1,128 @@ +#!/bin/bash + +INPUT_JSON="$1" +GCS_PATH="$2" # Optional GCS path for uploading results for each run + +SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd) +AUTOTUNE_SCRIPT="$SCRIPT_DIR/auto_tune.sh" + +if [[ -z "$INPUT_JSON" ]]; then + echo "Error: Input JSON file not provided." + echo "Usage: $0 [gcs_upload_path]" + exit 1 +fi + +if [[ ! -f "$INPUT_JSON" ]]; then + echo "Error: File not found at '$INPUT_JSON'" + exit 1 +fi + +if ! command -v jq &> /dev/null; then + echo "Error: 'jq' command not found. Please install jq to process the JSON input." + exit 1 +fi + +if [[ -n "$GCS_PATH" ]] && ! command -v gcloud &> /dev/null; then + echo "Error: 'gcloud' command not found, but a GCS_PATH was provided." + exit 1 +fi + +SUCCESS_COUNT=0 +FAILURE_COUNT=0 +FAILED_RUNS=() +SCRIPT_START_TIME=$(date +%s) + +json_content=$(cat "$INPUT_JSON") +if ! num_runs=$(echo "$json_content" | jq 'length'); then + echo "Error: Invalid JSON in $INPUT_JSON. 'jq' failed to get array length." >&2 + exit 1 +fi + +echo "Found $num_runs benchmark configurations in $INPUT_JSON." +echo "Starting benchmark runs..." +echo "--------------------------------------------------" + +for i in $(seq 0 $(($num_runs - 1))); do + run_object=$(echo "$json_content" | jq ".[$i]") + + RUN_START_TIME=$(date +%s) + ENV_VARS_ARRAY=() + # Dynamically create env vars from the JSON object's keys + for key in $(echo "$run_object" | jq -r 'keys_unsorted[]'); do + value=$(echo "$run_object" | jq -r ".$key") + var_name=$(echo "$key" | tr '[:lower:]' '[:upper:]' | tr -cd 'A-Z0-9_') + ENV_VARS_ARRAY+=("${var_name}=${value}") + done + + echo "Executing run #$((i+1))/$num_runs with parameters: ${ENV_VARS_ARRAY[*]}" + + # Execute auto_tune.sh and capture output + RUN_OUTPUT_FILE=$(mktemp) + if env "${ENV_VARS_ARRAY[@]}" bash "$AUTOTUNE_SCRIPT" > >(tee -a "$RUN_OUTPUT_FILE") 2>&1; then + STATUS="SUCCESS" + ((SUCCESS_COUNT++)) + else + STATUS="FAILURE" + ((FAILURE_COUNT++)) + FAILED_RUNS+=("Run #$((i+1)): $(echo $run_object | jq -c .)") + fi + + RUN_OUTPUT=$(<"$RUN_OUTPUT_FILE") + rm "$RUN_OUTPUT_FILE" + + # Parse results and optionally upload them to GCS + RUN_ID="" + RESULTS="" + GCS_RESULTS_URL="" + if [[ "$STATUS" == "SUCCESS" ]]; then + RESULT_FILE_PATH=$(echo "$RUN_OUTPUT" | grep 'RESULT_FILE=' | tail -n 1 | cut -d'=' -f2 | tr -s '/' || true) + + if [[ -n "$RESULT_FILE_PATH" && -f "$RESULT_FILE_PATH" ]]; then + RUN_ID=$(basename "$(dirname "$RESULT_FILE_PATH")") + RESULT_DIR=$(dirname "$RESULT_FILE_PATH") + RESULTS=$(cat "$RESULT_FILE_PATH") + + if [[ -n "$GCS_PATH" ]]; then + GCS_RESULTS_URL="${GCS_PATH}/${RUN_ID}" + echo "Uploading results to GCS..." + if gcloud storage rsync --recursive "$RESULT_DIR/" "$GCS_RESULTS_URL"; then + echo "GCS upload successful." + else + echo "Warning: GCS upload failed for RUN_ID $RUN_ID." + fi + fi + else + echo "Warning: Could not find result file for a successful run." + STATUS="WARNING_NO_RESULT_FILE" + fi + fi + + # Add the results back into the JSON object for this run + json_content=$(echo "$json_content" | jq --argjson i "$i" --arg run_id "$RUN_ID" --arg status "$STATUS" --arg results "$RESULTS" --arg gcs_results "$GCS_RESULTS_URL" \ + '.[$i] += {run_id: $run_id, status: $status, results: $results, gcs_results: $gcs_results}') + + RUN_END_TIME=$(date +%s) + echo "Run finished in $((RUN_END_TIME - RUN_START_TIME)) seconds. 
Status: $STATUS" + echo "--------------------------------------------------" + + # Save intermediate progress back to the file + echo "$json_content" > "$INPUT_JSON.tmp" && mv "$INPUT_JSON.tmp" "$INPUT_JSON" + +done + +SCRIPT_END_TIME=$(date +%s) +echo "All benchmark runs completed in $((SCRIPT_END_TIME - SCRIPT_START_TIME)) seconds." +echo +echo "====================== SUMMARY ======================" +echo "Successful runs: $SUCCESS_COUNT" +echo "Failed runs: $FAILURE_COUNT" +echo "===================================================" + +if [[ $FAILURE_COUNT -gt 0 ]]; then + echo "Details of failed runs (see JSON file for full parameters):" + for failed in "${FAILED_RUNS[@]}"; do + echo " - $failed" + done +fi + +echo "Updated results have been saved to '$INPUT_JSON'." diff --git a/benchmarks/benchmark_block_pool.py b/benchmarks/benchmark_block_pool.py index fd363c2ad0514..eae8d9927ea39 100644 --- a/benchmarks/benchmark_block_pool.py +++ b/benchmarks/benchmark_block_pool.py @@ -57,7 +57,7 @@ def invoke_main() -> None: "--num-iteration", type=int, default=1000, - help="Number of iterations to run to stablize final data readings", + help="Number of iterations to run to stabilize final data readings", ) parser.add_argument( "--allocate-blocks", diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py deleted file mode 100644 index 2ea4f9ccaff2b..0000000000000 --- a/benchmarks/benchmark_dataset.py +++ /dev/null @@ -1,1288 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -This module defines a framework for sampling benchmark requests from various -datasets. Each dataset subclass of BenchmarkDataset must implement sample -generation. Supported dataset types include: - - ShareGPT - - Random (synthetic) - - Sonnet - - BurstGPT - - HuggingFace - - VisionArena -""" - -import base64 -import io -import json -import logging -import random -from abc import ABC, abstractmethod -from collections.abc import Mapping -from copy import deepcopy -from dataclasses import dataclass -from functools import cache -from io import BytesIO -from typing import Any, Callable, Optional, Union - -import numpy as np -import pandas as pd -from datasets import load_dataset -from PIL import Image -from transformers import PreTrainedTokenizerBase - -from vllm.lora.request import LoRARequest -from vllm.lora.utils import get_adapter_absolute_path -from vllm.multimodal import MultiModalDataDict -from vllm.multimodal.image import convert_image_mode -from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer - -logger = logging.getLogger(__name__) - -# ----------------------------------------------------------------------------- -# Data Classes -# ----------------------------------------------------------------------------- - - -@dataclass -class SampleRequest: - """ - Represents a single inference request for benchmarking. 
- """ - - prompt: Union[str, Any] - prompt_len: int - expected_output_len: int - multi_modal_data: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None - lora_request: Optional[LoRARequest] = None - request_id: Optional[str] = None - - -# ----------------------------------------------------------------------------- -# Benchmark Dataset Base Class -# ----------------------------------------------------------------------------- - - -class BenchmarkDataset(ABC): - DEFAULT_SEED = 0 - IS_MULTIMODAL = False - - def __init__( - self, - dataset_path: Optional[str] = None, - random_seed: int = DEFAULT_SEED, - ) -> None: - """ - Initialize the BenchmarkDataset with an optional dataset path and random - seed. Args: - dataset_path (Optional[str]): Path to the dataset. If None, it - indicates that a default or random dataset might be used. - random_seed (int): Seed value for reproducible shuffling or - sampling. Defaults to DEFAULT_SEED. - """ - self.dataset_path = dataset_path - # Set the random seed, ensuring that a None value is replaced with the - # default seed. - self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED - self.data = None - - def apply_multimodal_chat_transformation( - self, prompt: str, mm_content: Optional[MultiModalDataDict] = None - ) -> list[dict]: - """ - Transform a prompt and optional multimodal content into a chat format. - This method is used for chat models that expect a specific conversation - format. - """ - content = [{"text": prompt, "type": "text"}] - if mm_content is not None: - content.append(mm_content) - return [{"role": "user", "content": content}] - - def load_data(self) -> None: - """ - Load data from the dataset path into self.data. - - This method must be overridden by subclasses since the method to load - data will vary depending on the dataset format and source. - - Raises: - NotImplementedError: If a subclass does not implement this method. - """ - # TODO (jenniferzhao): add support for downloading data - raise NotImplementedError("load_data must be implemented in subclasses.") - - def get_random_lora_request( - self, - tokenizer: PreTrainedTokenizerBase, - max_loras: Optional[int] = None, - lora_path: Optional[str] = None, - ) -> tuple[Optional[LoRARequest], AnyTokenizer]: - """ - Optionally select a random LoRA request and return its associated - tokenizer. - - This method is used when LoRA parameters are provided. It randomly - selects a LoRA based on max_loras and retrieves a cached tokenizer for - that LoRA if available. Otherwise, it returns the base tokenizer. - - Args: - tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no - LoRA is selected. max_loras (Optional[int]): The maximum number of - LoRAs available. If None, LoRA is not used. lora_path - (Optional[str]): Path to the LoRA parameters on disk. If None, LoRA - is not used. - - Returns: - tuple[Optional[LoRARequest], AnyTokenizer]: A tuple where the first - element is a LoRARequest (or None if not applicable) and the second - element is the tokenizer associated with the LoRA request (or the - base tokenizer). - """ - if max_loras is None or lora_path is None: - return None, tokenizer - - # Generate a random LoRA ID in the range [1, max_loras]. 
- lora_id = random.randint(1, max_loras) - lora_request = LoRARequest( - lora_name=str(lora_id), - lora_int_id=lora_id, - lora_path=lora_path_on_disk(lora_path), - ) - if lora_id not in lora_tokenizer_cache: - lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request) - # Return lora_request and the cached tokenizer if available; otherwise, - # return the base tokenizer - return lora_request, lora_tokenizer_cache[lora_id] or tokenizer - - @abstractmethod - def sample( - self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - request_id_prefix: str = "", - ) -> list[SampleRequest]: - """ - Abstract method to generate sample requests from the dataset. - - Subclasses must override this method to implement dataset-specific logic - for generating a list of SampleRequest objects. - - Args: - tokenizer (PreTrainedTokenizerBase): The tokenizer to be used - for processing the dataset's text. - num_requests (int): The number of sample requests to generate. - request_id_prefix (str) The prefix of request_id. - - Returns: - list[SampleRequest]: A list of sample requests generated from the - dataset. - """ - raise NotImplementedError("sample must be implemented in subclasses.") - - def maybe_oversample_requests( - self, - requests: list[SampleRequest], - num_requests: int, - request_id_prefix: str = "", - ) -> None: - """ - Oversamples the list of requests if its size is less than the desired - number. - - Args: - requests (List[SampleRequest]): The current list of sampled - requests. - num_requests (int): The target number of requests. - request_id_prefix (str) The prefix of the request ids. - """ - if len(requests) < num_requests: - random.seed(self.random_seed) - additional = deepcopy( - random.choices(requests, k=num_requests - len(requests)) - ) - for i in range(len(additional)): - req = additional[i] - req.request_id = request_id_prefix + str(len(requests) + i) - requests.extend(additional) - logger.info("Oversampled requests to reach %d total samples.", num_requests) - - -# ----------------------------------------------------------------------------- -# Utility Functions and Global Caches -# ----------------------------------------------------------------------------- - - -def is_valid_sequence( - prompt_len: int, - output_len: int, - min_len: int = 4, - max_prompt_len: int = 1024, - max_total_len: int = 2048, - skip_min_output_len_check: bool = False, -) -> bool: - """ - Validate a sequence based on prompt and output lengths. - - Default pruning criteria are copied from the original `sample_hf_requests` - and `sample_sharegpt_requests` functions in benchmark_serving.py, as well as - from `sample_requests` in benchmark_throughput.py. - """ - # Check for invalid conditions - prompt_too_short = prompt_len < min_len - output_too_short = (not skip_min_output_len_check) and (output_len < min_len) - prompt_too_long = prompt_len > max_prompt_len - combined_too_long = (prompt_len + output_len) > max_total_len - - # Return True if none of the invalid conditions are met - return not ( - prompt_too_short or output_too_short or prompt_too_long or combined_too_long - ) - - -@cache -def lora_path_on_disk(lora_path: str) -> str: - return get_adapter_absolute_path(lora_path) - - -# Global cache for LoRA tokenizers. -lora_tokenizer_cache: dict[int, AnyTokenizer] = {} - - -def process_image(image: Any) -> Mapping[str, Any]: - """ - Process a single image input and return a multimedia content dictionary. - - Supports three input types: - - 1. 
Dictionary with raw image bytes: - Expects a dict with a 'bytes' key - containing raw image data. - Loads the bytes as a PIL.Image.Image. - - 2. PIL.Image.Image input: - Converts the image to RGB. - Saves the image as - a JPEG in memory. - Encodes the JPEG data as a base64 string. - Returns - a dictionary with the image as a base64 data URL. - - 3. String input: - Treats the string as a URL or local file path. - - Prepends "file://" if the string doesn't start with "http://" or - "file://". - Returns a dictionary with the image URL. - - Raises: - ValueError: If the input is not a supported type. - """ - if isinstance(image, dict) and "bytes" in image: - image = Image.open(BytesIO(image["bytes"])) - if isinstance(image, Image.Image): - image = convert_image_mode(image, "RGB") - with io.BytesIO() as image_data: - image.save(image_data, format="JPEG") - image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8") - return { - "type": "image_url", - "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}, - } - - if isinstance(image, str): - image_url = ( - image if image.startswith(("http://", "file://")) else f"file://{image}" - ) - return {"type": "image_url", "image_url": {"url": image_url}} - - raise ValueError( - f"Invalid image input {image}. Must be a PIL.Image.Image" - " or str or dictionary with raw image bytes." - ) - - -def process_video(video: Any) -> Mapping[str, Any]: - """ - Process a single video input and return a multimedia content dictionary. - - Supports the following input types: - - 1. Dictionary with raw video bytes: - Expects a dict with a 'bytes' key - containing raw video data. - - 2. String input: - Treats the string as a URL or local file path. - - Prepends "file://" if the string doesn't start with "http://" or - "file://". - Returns a dictionary with the image URL. - - Raises: - ValueError: If the input is not a supported type. - """ - if isinstance(video, dict) and "bytes" in video: - video_bytes = video["bytes"] - video_base64 = base64.b64encode(video_bytes).decode("utf-8") - return { - "type": "video_url", - "video_url": {"url": f"data:video/mp4;base64,{video_base64}"}, - } - - if isinstance(video, str): - video_url = ( - video if video.startswith(("http://", "file://")) else f"file://{video}" - ) - return {"type": "video_url", "video_url": {"url": video_url}} - - raise ValueError( - f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`." # noqa: E501 - ) - - -# ----------------------------------------------------------------------------- -# Random Dataset Implementation (Synthetic Data) -# ----------------------------------------------------------------------------- - - -class RandomDataset(BenchmarkDataset): - # Default values copied from benchmark_serving.py for the random dataset. 
- DEFAULT_PREFIX_LEN = 0 - DEFAULT_RANGE_RATIO = 0.0 - DEFAULT_INPUT_LEN = 1024 - DEFAULT_OUTPUT_LEN = 128 - - def __init__( - self, - **kwargs, - ) -> None: - super().__init__(**kwargs) - - def sample( - self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - prefix_len: int = DEFAULT_PREFIX_LEN, - range_ratio: float = DEFAULT_RANGE_RATIO, - input_len: int = DEFAULT_INPUT_LEN, - output_len: int = DEFAULT_OUTPUT_LEN, - request_id_prefix: str = "", - **kwargs, - ) -> list[SampleRequest]: - # Enforce range_ratio < 1 - assert range_ratio < 1.0, ( - "random_range_ratio must be < 1.0 to ensure a valid sampling range" - ) - - vocab_size = tokenizer.vocab_size - num_special_tokens = tokenizer.num_special_tokens_to_add() - real_input_len = input_len - num_special_tokens - - prefix_token_ids = ( - np.random.randint(0, vocab_size, size=prefix_len).tolist() - if prefix_len > 0 - else [] - ) - - # New sampling logic: [X * (1 - b), X * (1 + b)] - input_low = int(real_input_len * (1 - range_ratio)) - input_high = int(real_input_len * (1 + range_ratio)) - output_low = int(output_len * (1 - range_ratio)) - # Ensure the lower bound for output length is at least 1 to prevent - # sampling 0 tokens, which can cause request failures. - output_low = max(output_low, 1) - output_high = int(output_len * (1 + range_ratio)) - - # Add logging for debugging - logger.info("Sampling input_len from [%s, %s]", input_low, input_high) - logger.info("Sampling output_len from [%s, %s]", output_low, output_high) - - input_lens = np.random.randint(input_low, input_high + 1, size=num_requests) - output_lens = np.random.randint(output_low, output_high + 1, size=num_requests) - offsets = np.random.randint(0, vocab_size, size=num_requests) - - requests = [] - for i in range(num_requests): - inner_seq = ( - (offsets[i] + i + np.arange(input_lens[i])) % vocab_size - ).tolist() - token_sequence = prefix_token_ids + inner_seq - prompt = tokenizer.decode(token_sequence) - # After decoding the prompt we have to encode and decode it again. - # This is done because in some cases N consecutive tokens - # give a string tokenized into != N number of tokens. - # For example for GPT2Tokenizer: - # [6880, 6881] -> ['Ġcalls', 'here'] -> - # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] - # To avoid uncontrolled change of the prompt length, - # the encoded sequence is truncated before being decode again. - total_input_len = prefix_len + int(input_lens[i]) - re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[ - :total_input_len - ] - prompt = tokenizer.decode(re_encoded_sequence) - total_input_len = len(re_encoded_sequence) - requests.append( - SampleRequest( - prompt=prompt, - prompt_len=total_input_len, - expected_output_len=int(output_lens[i]), - request_id=request_id_prefix + str(i), - ) - ) - - return requests - - -# ----------------------------------------------------------------------------- -# ShareGPT Dataset Implementation -# ----------------------------------------------------------------------------- - - -class ShareGPTDataset(BenchmarkDataset): - """ - Implements the ShareGPT dataset. Loads data from a JSON file and generates - sample requests based on conversation turns. 
- """ - - def __init__(self, **kwargs) -> None: - super().__init__(**kwargs) - self.load_data() - - def load_data(self) -> None: - if self.dataset_path is None: - raise ValueError("dataset_path must be provided for loading data.") - - with open(self.dataset_path, encoding="utf-8") as f: - self.data = json.load(f) - # Filter entries with at least two conversation turns. - self.data = [ - entry - for entry in self.data - if "conversations" in entry and len(entry["conversations"]) >= 2 - ] - random.seed(self.random_seed) - random.shuffle(self.data) - - def sample( - self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - lora_path: Optional[str] = None, - max_loras: Optional[int] = None, - output_len: Optional[int] = None, - enable_multimodal_chat: bool = False, - request_id_prefix: str = "", - **kwargs, - ) -> list: - samples: list = [] - ind = 0 - for entry in self.data: - if len(samples) >= num_requests: - break - prompt, completion = ( - entry["conversations"][0]["value"], - entry["conversations"][1]["value"], - ) - - lora_request, tokenizer = self.get_random_lora_request( - tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path - ) - prompt_ids = tokenizer(prompt).input_ids - completion_ids = tokenizer(completion).input_ids - prompt_len = len(prompt_ids) - new_output_len = len(completion_ids) if output_len is None else output_len - if not is_valid_sequence( - prompt_len, - new_output_len, - skip_min_output_len_check=output_len is not None, - ): - continue - if image_path := entry.get("image"): - mm_content = process_image(image_path) - elif video_path := entry.get("video"): - mm_content = process_video(video_path) - else: - mm_content = None - if enable_multimodal_chat: - prompt = self.apply_multimodal_chat_transformation(prompt, mm_content) - samples.append( - SampleRequest( - prompt=prompt, - prompt_len=prompt_len, - expected_output_len=new_output_len, - lora_request=lora_request, - multi_modal_data=mm_content, - request_id=request_id_prefix + str(ind), - ) - ) - ind += 1 - self.maybe_oversample_requests(samples, num_requests, request_id_prefix) - return samples - - -# ----------------------------------------------------------------------------- -# Custom Dataset Implementation -# ----------------------------------------------------------------------------- - - -class CustomDataset(BenchmarkDataset): - """ - Implements the Custom dataset. Loads data from a JSONL file and generates - sample requests based on conversation turns. E.g., - ``` - {"prompt": "What is the capital of India?"} - {"prompt": "What is the capital of Iran?"} - {"prompt": "What is the capital of China?"} - ``` - """ - - def __init__(self, **kwargs) -> None: - super().__init__(**kwargs) - self.load_data() - - def load_data(self) -> None: - if self.dataset_path is None: - raise ValueError("dataset_path must be provided for loading data.") - - # self.data will be a list of dictionaries - # e.g., [{"prompt": "What is the capital of India?"}, ...] - # This will be the standardized format which load_data() - # has to convert into depending on the filetype of dataset_path. 
- # sample() will assume this standardized format of self.data - self.data = [] - - # Load the JSONL file - if self.dataset_path.endswith(".jsonl"): - jsonl_data = pd.read_json(path_or_buf=self.dataset_path, lines=True) - - # check if the JSONL file has a 'prompt' column - if "prompt" not in jsonl_data.columns: - raise ValueError("JSONL file must contain a 'prompt' column.") - - # Convert each row to a dictionary and append to self.data - # This will convert the DataFrame to a list of dictionaries - # where each dictionary corresponds to a row in the DataFrame. - # This is the standardized format we want for self.data - for _, row in jsonl_data.iterrows(): - self.data.append(row.to_dict()) - else: - raise NotImplementedError( - "Only JSONL format is supported for CustomDataset." - ) - - random.seed(self.random_seed) - random.shuffle(self.data) - - def sample( - self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - lora_path: Optional[str] = None, - max_loras: Optional[int] = None, - output_len: Optional[int] = None, - enable_multimodal_chat: bool = False, - skip_chat_template: bool = False, - request_id_prefix: str = "", - **kwargs, - ) -> list: - sampled_requests = [] - for i, item in enumerate(self.data): - if len(sampled_requests) >= num_requests: - break - prompt = item["prompt"] - - # apply template - if not skip_chat_template: - prompt = tokenizer.apply_chat_template( - [{"role": "user", "content": prompt}], - add_generation_prompt=True, - tokenize=False, - ) - - prompt_len = len(tokenizer(prompt).input_ids) - sampled_requests.append( - SampleRequest( - prompt=prompt, - prompt_len=prompt_len, - expected_output_len=output_len, - request_id=request_id_prefix + str(i), - ) - ) - self.maybe_oversample_requests( - sampled_requests, num_requests, request_id_prefix - ) - - return sampled_requests - - -# ----------------------------------------------------------------------------- -# Sonnet Dataset Implementation -# ----------------------------------------------------------------------------- - - -class SonnetDataset(BenchmarkDataset): - """ - Simplified implementation of the Sonnet dataset. Loads poem lines from a - text file and generates sample requests. Default values here copied from - `benchmark_serving.py` for the sonnet dataset. - """ - - DEFAULT_PREFIX_LEN = 200 - DEFAULT_INPUT_LEN = 550 - DEFAULT_OUTPUT_LEN = 150 - - def __init__( - self, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.load_data() - - def load_data(self) -> None: - if not self.dataset_path: - raise ValueError("dataset_path must be provided.") - with open(self.dataset_path, encoding="utf-8") as f: - self.data = f.readlines() - - def sample( - self, - tokenizer, - num_requests: int, - prefix_len: int = DEFAULT_PREFIX_LEN, - input_len: int = DEFAULT_INPUT_LEN, - output_len: int = DEFAULT_OUTPUT_LEN, - return_prompt_formatted: bool = False, - request_id_prefix: str = "", - **kwargs, - ) -> list: - # Calculate average token length for a poem line. - tokenized_lines = [tokenizer(line).input_ids for line in self.data] - avg_len = sum(len(tokens) for tokens in tokenized_lines) / len(tokenized_lines) - - # Build the base prompt. 
- base_prompt = "Pick as many lines as you can from these poem lines:\n" - base_msg = [{"role": "user", "content": base_prompt}] - base_fmt = tokenizer.apply_chat_template( - base_msg, add_generation_prompt=True, tokenize=False - ) - base_offset = len(tokenizer(base_fmt).input_ids) - if input_len <= base_offset: - raise ValueError( - f"'input_len' must be higher than the base prompt length " - f"({base_offset})." - ) - - # Determine how many poem lines to use. - num_input_lines = round((input_len - base_offset) / avg_len) - num_prefix_lines = max(round((prefix_len - base_offset) / avg_len), 0) - prefix_lines = self.data[:num_prefix_lines] - - samples = [] - ind = 0 - while len(samples) < num_requests: - extra_lines = random.choices( - self.data, k=num_input_lines - num_prefix_lines - ) - prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}" - msg = [{"role": "user", "content": prompt}] - prompt_formatted = tokenizer.apply_chat_template( - msg, add_generation_prompt=True, tokenize=False - ) - prompt_len = len(tokenizer(prompt_formatted).input_ids) - - if prompt_len <= input_len: - samples.append( - SampleRequest( - prompt=prompt_formatted if return_prompt_formatted else prompt, - prompt_len=prompt_len, - expected_output_len=output_len, - request_id=request_id_prefix + str(ind), - ) - ) - ind += 1 - return samples - - -# ----------------------------------------------------------------------------- -# BurstGPT Dataset Implementation -# ----------------------------------------------------------------------------- - - -class BurstGPTDataset(BenchmarkDataset): - """ - Implements the BurstGPT dataset. Loads data from a CSV file and generates - sample requests based on synthetic prompt generation. Only rows with Model - "GPT-4" and positive response tokens are used. - """ - - def __init__(self, **kwargs) -> None: - super().__init__(**kwargs) - self.load_data() - - def load_data( - self, - ): - if self.dataset_path is None: - raise ValueError("dataset_path must be provided for loading data.") - - df = pd.read_csv(self.dataset_path) - # Filter to keep only GPT-4 rows. - gpt4_df = df[df["Model"] == "GPT-4"] - # Remove failed requests (where Response tokens is 0 or less). - gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0] - # Sample the desired number of rows. - self.data = gpt4_df - - def _sample_loaded_data(self, num_requests: int) -> list: - if num_requests <= len(self.data): - data = self.data.sample(n=num_requests, random_state=self.random_seed) - else: - data = self.data.sample( - n=num_requests, - random_state=self.random_seed, - replace=True, - ) - # Convert the dataframe to a list of lists. - return data.values.tolist() - - def sample( - self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - max_loras: Optional[int] = None, - lora_path: Optional[str] = None, - request_id_prefix: str = "", - **kwargs, - ) -> list[SampleRequest]: - samples = [] - data = self._sample_loaded_data(num_requests=num_requests) - for i in range(num_requests): - input_len = int(data[i][2]) - output_len = int(data[i][3]) - lora_req, tokenizer = self.get_random_lora_request( - tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path - ) - vocab_size = tokenizer.vocab_size - # Generate a synthetic prompt: a list of token IDs computed as (i + - # j) modulo vocab_size. 
- token_ids = [(i + j) % vocab_size for j in range(input_len)] - prompt = tokenizer.decode(token_ids) - samples.append( - SampleRequest( - prompt=prompt, - prompt_len=input_len, - expected_output_len=output_len, - lora_request=lora_req, - request_id=request_id_prefix + str(i), - ) - ) - return samples - - -# ----------------------------------------------------------------------------- -# HuggingFace Dataset Base Implementation -# ----------------------------------------------------------------------------- -class HuggingFaceDataset(BenchmarkDataset): - """Base class for datasets hosted on HuggingFace.""" - - SUPPORTED_DATASET_PATHS: Union[set[str], dict[str, Callable]] = set() - - def __init__( - self, - dataset_path: str, - dataset_split: str, - no_stream: bool = False, - dataset_subset: Optional[str] = None, - **kwargs, - ) -> None: - super().__init__(dataset_path=dataset_path, **kwargs) - - self.dataset_split = dataset_split - self.dataset_subset = dataset_subset - self.load_stream = not no_stream - self.load_data() - - def load_data(self) -> None: - """Load data from HuggingFace datasets.""" - self.data = load_dataset( - self.dataset_path, - name=self.dataset_subset, - split=self.dataset_split, - streaming=self.load_stream, - ) - self.data = self.data.shuffle(seed=self.random_seed) - - -# ----------------------------------------------------------------------------- -# Conversation Dataset Implementation -# ----------------------------------------------------------------------------- - - -class ConversationDataset(HuggingFaceDataset): - """Dataset for conversation data with multimodal support.""" - - SUPPORTED_DATASET_PATHS = { - "lmms-lab/LLaVA-OneVision-Data", - "Aeala/ShareGPT_Vicuna_unfiltered", - } - IS_MULTIMODAL = True - - def sample( - self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - enable_multimodal_chat: bool = False, - request_id_prefix: str = "", - **kwargs, - ) -> list: - # Filter examples with at least 2 conversations - filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2) - sampled_requests = [] - dynamic_output = output_len is None - ind = 0 - - for item in filtered_data: - if len(sampled_requests) >= num_requests: - break - conv = item["conversations"] - prompt, completion = conv[0]["value"], conv[1]["value"] - - prompt_ids = tokenizer(prompt).input_ids - completion_ids = tokenizer(completion).input_ids - prompt_len = len(prompt_ids) - completion_len = len(completion_ids) - output_len = completion_len if dynamic_output else output_len - assert isinstance(output_len, int) and output_len > 0 - if dynamic_output and not is_valid_sequence(prompt_len, completion_len): - continue - mm_content = process_image(item["image"]) if "image" in item else None - if enable_multimodal_chat: - # Note: when chat is enabled the request prompt_len is no longer - # accurate and we will be using request output to count the - # actual prompt len and output len - prompt = self.apply_multimodal_chat_transformation(prompt, mm_content) - sampled_requests.append( - SampleRequest( - prompt=prompt, - prompt_len=prompt_len, - expected_output_len=output_len, - multi_modal_data=mm_content, - request_id=request_id_prefix + str(ind), - ) - ) - ind += 1 - self.maybe_oversample_requests( - sampled_requests, num_requests, request_id_prefix - ) - return sampled_requests - - -# ----------------------------------------------------------------------------- -# Vision Arena Dataset Implementation -# 
----------------------------------------------------------------------------- - - -class VisionArenaDataset(HuggingFaceDataset): - """ - Vision Arena Dataset. - """ - - DEFAULT_OUTPUT_LEN = 128 - SUPPORTED_DATASET_PATHS = { - "lmarena-ai/VisionArena-Chat": lambda x: x["conversation"][0][0]["content"], - "lmarena-ai/vision-arena-bench-v0.1": lambda x: x["turns"][0][0]["content"], - } - IS_MULTIMODAL = True - - def sample( - self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - enable_multimodal_chat: bool = False, - request_id_prefix: str = "", - **kwargs, - ) -> list: - output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN - sampled_requests = [] - for i, item in enumerate(self.data): - if len(sampled_requests) >= num_requests: - break - parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) - if parser_fn is None: - raise ValueError(f"Unsupported dataset path: {self.dataset_path}") - prompt = parser_fn(item) - mm_content = process_image(item["images"][0]) - prompt_len = len(tokenizer(prompt).input_ids) - if enable_multimodal_chat: - # Note: when chat is enabled the request prompt_len is no longer - # accurate and we will be using request output to count the - # actual prompt len - prompt = self.apply_multimodal_chat_transformation(prompt, mm_content) - sampled_requests.append( - SampleRequest( - prompt=prompt, - prompt_len=prompt_len, - expected_output_len=output_len, - multi_modal_data=mm_content, - request_id=request_id_prefix + str(i), - ) - ) - self.maybe_oversample_requests( - sampled_requests, num_requests, request_id_prefix - ) - return sampled_requests - - -# ----------------------------------------------------------------------------- -# Instruct Coder Dataset Implementation -# ----------------------------------------------------------------------------- - - -class InstructCoderDataset(HuggingFaceDataset): - """ - InstructCoder Dataset. - https://huggingface.co/datasets/likaixin/InstructCoder - - InstructCoder is the dataset designed for general code editing. It consists - of 114,239 instruction-input-output triplets, and covers multiple distinct - code editing scenario. - """ - - DEFAULT_OUTPUT_LEN = 200 # this is the average default output length - SUPPORTED_DATASET_PATHS = { - "likaixin/InstructCoder", - } - - def sample( - self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - enable_multimodal_chat: bool = False, - request_id_prefix: str = "", - **kwargs, - ) -> list: - output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN - sampled_requests = [] - for i, item in enumerate(self.data): - if len(sampled_requests) >= num_requests: - break - prompt = ( - f"{item['input']}\n\n{item['instruction']} Just output " - "the code, do not include any explanation." 
- ) - - # apply template - prompt = tokenizer.apply_chat_template( - [{"role": "user", "content": prompt}], - add_generation_prompt=True, - tokenize=False, - ) - prompt_len = len(tokenizer(prompt).input_ids) - sampled_requests.append( - SampleRequest( - prompt=prompt, - prompt_len=prompt_len, - expected_output_len=output_len, - request_id=request_id_prefix + str(i), - ) - ) - self.maybe_oversample_requests( - sampled_requests, num_requests, request_id_prefix - ) - return sampled_requests - - -# ----------------------------------------------------------------------------- -# MT-Bench Dataset Implementation -# ----------------------------------------------------------------------------- - - -class MTBenchDataset(HuggingFaceDataset): - """ - MT-Bench Dataset. - https://huggingface.co/datasets/philschmid/mt-bench - - We create a single turn dataset for MT-Bench. - This is similar to Spec decoding benchmark setup in vLLM - https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18 - """ # noqa: E501 - - DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM - SUPPORTED_DATASET_PATHS = { - "philschmid/mt-bench", - } - - def sample( - self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - enable_multimodal_chat: bool = False, - request_id_prefix: str = "", - **kwargs, - ) -> list: - output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN - sampled_requests = [] - - for i, item in enumerate(self.data): - if len(sampled_requests) >= num_requests: - break - prompt = item["turns"][0] - - # apply template - prompt = tokenizer.apply_chat_template( - [{"role": "user", "content": prompt}], - add_generation_prompt=True, - tokenize=False, - ) - - prompt_len = len(tokenizer(prompt).input_ids) - sampled_requests.append( - SampleRequest( - prompt=prompt, - prompt_len=prompt_len, - expected_output_len=output_len, - request_id=request_id_prefix + str(i), - ) - ) - self.maybe_oversample_requests( - sampled_requests, num_requests, request_id_prefix - ) - return sampled_requests - - -# ----------------------------------------------------------------------------- -# AIMO Dataset Implementation -# ----------------------------------------------------------------------------- - - -class AIMODataset(HuggingFaceDataset): - """ - Dataset class for processing a AIMO dataset with reasoning questions. 
- """ - - SUPPORTED_DATASET_PATHS = { - "AI-MO/aimo-validation-aime", - "AI-MO/NuminaMath-1.5", - "AI-MO/NuminaMath-CoT", - } - - def sample( - self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - request_id_prefix: str = "", - **kwargs, - ) -> list: - sampled_requests = [] - dynamic_output = output_len is None - ind = 0 - - for item in self.data: - if len(sampled_requests) >= num_requests: - break - prompt, completion = item["problem"], item["solution"] - - prompt_ids = tokenizer(prompt).input_ids - completion_ids = tokenizer(completion).input_ids - prompt_len = len(prompt_ids) - completion_len = len(completion_ids) - output_len = completion_len if dynamic_output else output_len - assert isinstance(output_len, int) and output_len > 0 - if dynamic_output and not is_valid_sequence( - prompt_len, completion_len, max_prompt_len=2048, max_total_len=32000 - ): - continue - sampled_requests.append( - SampleRequest( - prompt=prompt, - prompt_len=prompt_len, - expected_output_len=output_len, - multi_modal_data=None, - request_id=request_id_prefix + str(ind), - ) - ) - ind += 1 - self.maybe_oversample_requests( - sampled_requests, num_requests, request_id_prefix - ) - return sampled_requests - - -# ----------------------------------------------------------------------------- -# Next Edit Prediction Dataset Implementation -# ----------------------------------------------------------------------------- - - -zeta_prompt = """### Instruction: -You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location. - -### User Edits: - -{} - -### User Excerpt: - -{} - -### Response: - -""" # noqa: E501 - - -def _format_zeta_prompt( - sample: dict, original_start_marker: str = "<|editable_region_start|>" -) -> dict: - """Format the zeta prompt for the Next Edit Prediction (NEP) dataset. - - This function formats examples from the NEP dataset - into prompts and expected outputs. It could be - further extended to support more NEP datasets. - - Args: - sample: The dataset sample containing events, - inputs, and outputs. - original_start_marker: The marker indicating the - start of the editable region. Defaults to - "<|editable_region_start|>". - - Returns: - A dictionary with the formatted prompts and expected outputs. - """ - events = sample["events"] - input = sample["input"] - output = sample["output"] - prompt = zeta_prompt.format(events, input) - - # following the original implementation, extract the focused region - # from the raw output - output_start_index = output.find(original_start_marker) - output_focused_region = output[output_start_index:] - expected_output = output_focused_region - - return {"prompt": prompt, "expected_output": expected_output} - - -class NextEditPredictionDataset(HuggingFaceDataset): - """ - Dataset class for processing a Next Edit Prediction dataset. 
- """ - - SUPPORTED_DATASET_PATHS = { - "zed-industries/zeta", - } - MAPPING_PROMPT_FUNCS = { - "zed-industries/zeta": _format_zeta_prompt, - } - - def sample( - self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - request_id_prefix: str = "", - **kwargs, - ): - formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path) - if formatting_prompt_func is None: - raise ValueError(f"Unsupported dataset path: {self.dataset_path}") - samples = [] - for i, sample in enumerate(self.data): - sample = formatting_prompt_func(sample) - samples.append( - SampleRequest( - prompt=sample["prompt"], - prompt_len=len(tokenizer(sample["prompt"]).input_ids), - expected_output_len=len( - tokenizer(sample["expected_output"]).input_ids - ), - request_id=request_id_prefix + str(i), - ) - ) - if len(samples) >= num_requests: - break - self.maybe_oversample_requests(samples, num_requests, request_id_prefix) - return samples - - -# ----------------------------------------------------------------------------- -# ASR Dataset Implementation -# ----------------------------------------------------------------------------- - - -class ASRDataset(HuggingFaceDataset): - """ - Dataset class for processing a ASR dataset for transcription. - Tested on the following set: - - +----------------+----------------------------------------+--------------------------+-----------------------------+ - | Dataset | Domain | Speaking Style | hf-subset | - +----------------+----------------------------------------+--------------------------+-----------------------------+ - | TED-LIUM | TED talks | Oratory | release1, release2, release3| - | | | | release3-speaker-adaptation | - | VoxPopuli | European Parliament | Oratory | en, de, it, fr, ... | - | LibriSpeech | Audiobook | Narrated | "LIUM/tedlium" | - | GigaSpeech | Audiobook, podcast, YouTube | Narrated, spontaneous | xs, s, m, l, xl, dev, test | - | SPGISpeech | Financial meetings | Oratory, spontaneous | S, M, L, dev, test | - | AMI | Meetings | Spontaneous | ihm, sdm | - +----------------+----------------------------------------+--------------------------+-----------------------------+ - - """ # noqa: E501 - - SUPPORTED_DATASET_PATHS = { - "openslr/librispeech_asr", - "facebook/voxpopuli", - "LIUM/tedlium", - "edinburghcstr/ami", - "speechcolab/gigaspeech", - "kensho/spgispeech", - } - - DEFAULT_OUTPUT_LEN = 128 - IS_MULTIMODAL = True - - # TODO Whisper-specific. Abstract interface when more models are supported. 
- TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>" - skip_long_audios: bool = True - - def sample( - self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - request_id_prefix: str = "", - **kwargs, - ) -> list: - import librosa - - output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN - prompt = ASRDataset.TRANSCRIPTION_PREAMBLE - prompt_len = len(tokenizer(prompt).input_ids) - sampled_requests = [] - skipped = 0 - ind = 0 - for item in self.data: - if len(sampled_requests) >= num_requests: - break - audio = item["audio"] - y, sr = audio["array"], audio["sampling_rate"] - duration_s = librosa.get_duration(y=y, sr=sr) - # Whisper max supported duration - if self.skip_long_audios and duration_s > 30: - skipped += 1 - continue - - mm_content = {"audio": (y, sr)} - sampled_requests.append( - SampleRequest( - prompt=prompt, - prompt_len=prompt_len, - expected_output_len=output_len, - multi_modal_data=mm_content, - request_id=request_id_prefix + str(ind), - ) - ) - ind += 1 - if skipped: - logger.warning( - "%d samples discarded from dataset due to" - " their length being greater than" - " what Whisper supports.", - skipped, - ) - self.maybe_oversample_requests( - sampled_requests, num_requests, request_id_prefix - ) - return sampled_requests diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index d8b960edaa468..a7892f3f71243 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -1,191 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Benchmark the latency of processing a single batch of requests.""" - -import argparse -import dataclasses -import json -import os -import time -from typing import Any, Optional - -import numpy as np -from tqdm import tqdm -from typing_extensions import deprecated - -import vllm.envs as envs -from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json -from vllm import LLM, SamplingParams -from vllm.engine.arg_utils import EngineArgs -from vllm.inputs import PromptType -from vllm.sampling_params import BeamSearchParams -from vllm.utils import FlexibleArgumentParser - - -def save_to_pytorch_benchmark_format( - args: argparse.Namespace, results: dict[str, Any] -) -> None: - pt_records = convert_to_pytorch_benchmark_format( - args=args, - metrics={"latency": results["latencies"]}, - extra_info={k: results[k] for k in ["avg_latency", "percentiles"]}, - ) - if pt_records: - pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" - write_to_json(pt_file, pt_records) - - -@deprecated( - "benchmark_latency.py is deprecated and will be removed in a " - "future version. Please use 'vllm bench latency' instead.", -) -def main(args: argparse.Namespace): - print(args) - - engine_args = EngineArgs.from_cli_args(args) - - # NOTE(woosuk): If the request cannot be processed in a single batch, - # the engine will automatically process the request in multiple batches. - llm = LLM(**dataclasses.asdict(engine_args)) - assert llm.llm_engine.model_config.max_model_len >= ( - args.input_len + args.output_len - ), ( - "Please ensure that max_model_len is greater than" - " the sum of input_len and output_len." 
- ) - - sampling_params = SamplingParams( - n=args.n, - temperature=1.0, - top_p=1.0, - ignore_eos=True, - max_tokens=args.output_len, - detokenize=not args.disable_detokenize, - ) - print(sampling_params) - dummy_prompt_token_ids = np.random.randint( - 10000, size=(args.batch_size, args.input_len) - ) - dummy_prompts: list[PromptType] = [ - {"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist() - ] - - def llm_generate(): - if not args.use_beam_search: - llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) - else: - llm.beam_search( - dummy_prompts, - BeamSearchParams( - beam_width=args.n, - max_tokens=args.output_len, - ignore_eos=True, - ), - ) - - def run_to_completion(profile_dir: Optional[str] = None): - if profile_dir: - llm.start_profile() - llm_generate() - llm.stop_profile() - else: - start_time = time.perf_counter() - llm_generate() - end_time = time.perf_counter() - latency = end_time - start_time - return latency - - print("Warming up...") - for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"): - run_to_completion(profile_dir=None) - - if args.profile: - profile_dir = envs.VLLM_TORCH_PROFILER_DIR - print(f"Profiling (results will be saved to '{profile_dir}')...") - run_to_completion(profile_dir=profile_dir) - return - - # Benchmark. - latencies = [] - for _ in tqdm(range(args.num_iters), desc="Profiling iterations"): - latencies.append(run_to_completion(profile_dir=None)) - latencies = np.array(latencies) - percentages = [10, 25, 50, 75, 90, 99] - percentiles = np.percentile(latencies, percentages) - print(f"Avg latency: {np.mean(latencies)} seconds") - for percentage, percentile in zip(percentages, percentiles): - print(f"{percentage}% percentile latency: {percentile} seconds") - - # Output JSON results if specified - if args.output_json: - results = { - "avg_latency": np.mean(latencies), - "latencies": latencies.tolist(), - "percentiles": dict(zip(percentages, percentiles.tolist())), - } - with open(args.output_json, "w") as f: - json.dump(results, f, indent=4) - save_to_pytorch_benchmark_format(args, results) - - -def create_argument_parser(): - parser = FlexibleArgumentParser( - description="Benchmark the latency of processing a single batch of " - "requests till completion." - ) - parser.add_argument("--input-len", type=int, default=32) - parser.add_argument("--output-len", type=int, default=128) - parser.add_argument("--batch-size", type=int, default=8) - parser.add_argument( - "--n", - type=int, - default=1, - help="Number of generated sequences per prompt.", - ) - parser.add_argument("--use-beam-search", action="store_true") - parser.add_argument( - "--num-iters-warmup", - type=int, - default=10, - help="Number of iterations to run for warmup.", - ) - parser.add_argument( - "--num-iters", type=int, default=30, help="Number of iterations to run." - ) - parser.add_argument( - "--profile", - action="store_true", - help="profile the generation process of a single batch", - ) - parser.add_argument( - "--output-json", - type=str, - default=None, - help="Path to save the latency results in JSON format.", - ) - parser.add_argument( - "--disable-detokenize", - action="store_true", - help=( - "Do not detokenize responses (i.e. do not include " - "detokenization time in the latency measurement)" - ), - ) - - parser = EngineArgs.add_cli_args(parser) - # V1 enables prefix caching by default which skews the latency - # numbers. We need to disable prefix caching by default. 
- parser.set_defaults(enable_prefix_caching=False) - - return parser - +import sys if __name__ == "__main__": - parser = create_argument_parser() - args = parser.parse_args() - if args.profile and not envs.VLLM_TORCH_PROFILER_DIR: - raise OSError( - "The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. " - "Please set it to a valid path to use torch profiler." - ) - main(args) + print("""DEPRECATED: This script has been moved to the vLLM CLI. + +Please use the following command instead: + vllm bench latency + +For help with the new command, run: + vllm bench latency --help + +Alternatively, you can run the new command directly with: + python -m vllm.entrypoints.cli.main bench latency --help +""") + sys.exit(1) diff --git a/benchmarks/benchmark_ngram_proposer.py b/benchmarks/benchmark_ngram_proposer.py index c60040d05ab7a..11833fa1b3c8b 100644 --- a/benchmarks/benchmark_ngram_proposer.py +++ b/benchmarks/benchmark_ngram_proposer.py @@ -77,7 +77,7 @@ def invoke_main() -> None: "--num-iteration", type=int, default=100, - help="Number of iterations to run to stablize final data readings", + help="Number of iterations to run to stabilize final data readings", ) parser.add_argument( "--num-req", type=int, default=128, help="Number of requests in the batch" diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 02f5f585c0c16..76cf51498020b 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -1,1324 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -r"""Benchmark online serving throughput. - -On the server side, run one of the following commands: - vLLM OpenAI API server - vllm serve <your_model> \ - --swap-space 16 - -On the client side, run: - python benchmarks/benchmark_serving.py \ - --backend <backend> \ - --model <your_model> \ - --dataset-name sharegpt \ - --dataset-path <path to dataset> \ - --request-rate <request_rate> \ # By default <request_rate> is inf - --num-prompts <num_prompts> # By default <num_prompts> is 1000 - - when using tgi backend, add - --endpoint /generate_stream - to the end of the command above. 
-""" - -import argparse -import asyncio -import gc -import json -import os -import random -import time -import warnings -from collections.abc import Iterable -from dataclasses import dataclass -from datetime import datetime -from typing import Any, Literal, Optional - -import numpy as np -from tqdm.asyncio import tqdm -from transformers import PreTrainedTokenizerBase -from typing_extensions import deprecated - -from backend_request_func import ( - ASYNC_REQUEST_FUNCS, - OPENAI_COMPATIBLE_BACKENDS, - RequestFuncInput, - RequestFuncOutput, -) - -try: - from vllm.transformers_utils.tokenizer import get_tokenizer -except ImportError: - from backend_request_func import get_tokenizer - -try: - from vllm.utils import FlexibleArgumentParser -except ImportError: - from argparse import ArgumentParser as FlexibleArgumentParser - -from benchmark_dataset import ( - AIMODataset, - ASRDataset, - BurstGPTDataset, - ConversationDataset, - CustomDataset, - HuggingFaceDataset, - InstructCoderDataset, - MTBenchDataset, - NextEditPredictionDataset, - RandomDataset, - SampleRequest, - ShareGPTDataset, - SonnetDataset, - VisionArenaDataset, -) -from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json -from vllm.benchmarks.serve import get_request - -MILLISECONDS_TO_SECONDS_CONVERSION = 1000 - - -@dataclass -class BenchmarkMetrics: - completed: int - total_input: int - total_output: int - request_throughput: float - request_goodput: float - output_throughput: float - total_token_throughput: float - mean_ttft_ms: float - median_ttft_ms: float - std_ttft_ms: float - percentiles_ttft_ms: list[tuple[float, float]] - mean_tpot_ms: float - median_tpot_ms: float - std_tpot_ms: float - percentiles_tpot_ms: list[tuple[float, float]] - mean_itl_ms: float - median_itl_ms: float - std_itl_ms: float - percentiles_itl_ms: list[tuple[float, float]] - # E2EL stands for end-to-end latency per request. - # It is the time taken on the client side from sending - # a request to receiving a complete response. 
- mean_e2el_ms: float - median_e2el_ms: float - std_e2el_ms: float - percentiles_e2el_ms: list[tuple[float, float]] - - -def calculate_metrics( - input_requests: list[SampleRequest], - outputs: list[RequestFuncOutput], - dur_s: float, - tokenizer: PreTrainedTokenizerBase, - selected_percentile_metrics: list[str], - selected_percentiles: list[float], - goodput_config_dict: dict[str, float], -) -> tuple[BenchmarkMetrics, list[int]]: - actual_output_lens: list[int] = [] - total_input = 0 - completed = 0 - good_completed = 0 - itls: list[float] = [] - tpots: list[float] = [] - all_tpots: list[float] = [] - ttfts: list[float] = [] - e2els: list[float] = [] - for i in range(len(outputs)): - if outputs[i].success: - output_len = outputs[i].output_tokens - - if not output_len: - # We use the tokenizer to count the number of output tokens - # for some serving backends instead of looking at - # len(outputs[i].itl) since multiple output tokens may be - # bundled together - # Note : this may inflate the output token count slightly - output_len = len( - tokenizer( - outputs[i].generated_text, add_special_tokens=False - ).input_ids - ) - actual_output_lens.append(output_len) - total_input += input_requests[i].prompt_len - tpot = 0 - if output_len > 1: - latency_minus_ttft = outputs[i].latency - outputs[i].ttft - tpot = latency_minus_ttft / (output_len - 1) - tpots.append(tpot) - # Note: if output_len <= 1, we regard tpot as 0 for goodput - all_tpots.append(tpot) - itls += outputs[i].itl - ttfts.append(outputs[i].ttft) - e2els.append(outputs[i].latency) - completed += 1 - else: - actual_output_lens.append(0) - - if goodput_config_dict: - valid_metrics = [] - slo_values = [] - - if "ttft" in goodput_config_dict: - valid_metrics.append(ttfts) - slo_values.append( - goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION - ) - if "tpot" in goodput_config_dict: - valid_metrics.append(all_tpots) - slo_values.append( - goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION - ) - if "e2el" in goodput_config_dict: - valid_metrics.append(e2els) - slo_values.append( - goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION - ) - - for req_metric in zip(*valid_metrics): - is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) - if is_good_req: - good_completed += 1 - - if completed == 0: - warnings.warn( - "All requests failed. 
This is likely due to a misconfiguration " - "on the benchmark arguments.", - stacklevel=2, - ) - metrics = BenchmarkMetrics( - completed=completed, - total_input=total_input, - total_output=sum(actual_output_lens), - request_throughput=completed / dur_s, - request_goodput=good_completed / dur_s, - output_throughput=sum(actual_output_lens) / dur_s, - total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, - mean_ttft_ms=np.mean(ttfts or 0) - * 1000, # ttfts is empty if streaming is not supported by backend - std_ttft_ms=np.std(ttfts or 0) * 1000, - median_ttft_ms=np.median(ttfts or 0) * 1000, - percentiles_ttft_ms=[ - (p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles - ], - mean_tpot_ms=np.mean(tpots or 0) * 1000, - std_tpot_ms=np.std(tpots or 0) * 1000, - median_tpot_ms=np.median(tpots or 0) * 1000, - percentiles_tpot_ms=[ - (p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles - ], - mean_itl_ms=np.mean(itls or 0) * 1000, - std_itl_ms=np.std(itls or 0) * 1000, - median_itl_ms=np.median(itls or 0) * 1000, - percentiles_itl_ms=[ - (p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles - ], - mean_e2el_ms=np.mean(e2els or 0) * 1000, - std_e2el_ms=np.std(e2els or 0) * 1000, - median_e2el_ms=np.median(e2els or 0) * 1000, - percentiles_e2el_ms=[ - (p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles - ], - ) - - return metrics, actual_output_lens - - -async def benchmark( - backend: str, - api_url: str, - base_url: str, - model_id: str, - model_name: str, - tokenizer: PreTrainedTokenizerBase, - input_requests: list[SampleRequest], - logprobs: Optional[int], - request_rate: float, - burstiness: float, - disable_tqdm: bool, - profile: bool, - selected_percentile_metrics: list[str], - selected_percentiles: list[float], - ignore_eos: bool, - goodput_config_dict: dict[str, float], - max_concurrency: Optional[int], - lora_modules: Optional[Iterable[str]], - extra_body: Optional[dict], - ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None, - ramp_up_start_rps: Optional[int] = None, - ramp_up_end_rps: Optional[int] = None, -): - if backend in ASYNC_REQUEST_FUNCS: - request_func = ASYNC_REQUEST_FUNCS[backend] - else: - raise ValueError(f"Unknown backend: {backend}") - - print("Starting initial single prompt test run...") - test_prompt, test_prompt_len, test_output_len, test_mm_content = ( - input_requests[0].prompt, - input_requests[0].prompt_len, - input_requests[0].expected_output_len, - input_requests[0].multi_modal_data, - ) - - assert ( - test_mm_content is None - or isinstance(test_mm_content, dict) - or ( - isinstance(test_mm_content, list) - and all(isinstance(item, dict) for item in test_mm_content) - ) - ), "multi_modal_data must be a dict or list[dict]" - test_input = RequestFuncInput( - model=model_id, - model_name=model_name, - prompt=test_prompt, - api_url=api_url, - prompt_len=test_prompt_len, - output_len=test_output_len, - logprobs=logprobs, - multi_modal_content=test_mm_content, - ignore_eos=ignore_eos, - extra_body=extra_body, - ) - - test_output = await request_func(request_func_input=test_input) - if not test_output.success: - raise ValueError( - "Initial test run failed - Please make sure benchmark arguments " - f"are correctly specified. Error: {test_output.error}" - ) - else: - print("Initial test run completed. Starting main benchmark run...") - - if lora_modules: - # For each input request, choose a LoRA module at random. 
- lora_modules = iter( - [random.choice(lora_modules) for _ in range(len(input_requests))] - ) - - if profile: - print("Starting profiler...") - profile_input = RequestFuncInput( - model=model_id, - model_name=model_name, - prompt=test_prompt, - api_url=base_url + "/start_profile", - prompt_len=test_prompt_len, - output_len=test_output_len, - logprobs=logprobs, - multi_modal_content=test_mm_content, - ignore_eos=ignore_eos, - extra_body=extra_body, - ) - profile_output = await request_func(request_func_input=profile_input) - if profile_output.success: - print("Profiler started") - - distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution" - - if ramp_up_strategy is not None: - print( - f"Traffic ramp-up strategy: {ramp_up_strategy}. Will increase " - f"RPS from {ramp_up_start_rps} to {ramp_up_end_rps} RPS over " - "the duration of the benchmark." - ) - else: - print(f"Traffic request rate: {request_rate} RPS.") - - print(f"Burstiness factor: {burstiness} ({distribution})") - print(f"Maximum request concurrency: {max_concurrency}") - - pbar = None if disable_tqdm else tqdm(total=len(input_requests)) - - # This can be used once the minimum Python version is 3.10 or higher, - # and it will simplify the code in limited_request_func. - # semaphore = (asyncio.Semaphore(max_concurrency) - # if max_concurrency else contextlib.nullcontext()) - semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None - - async def limited_request_func(request_func_input, pbar): - if semaphore is None: - return await request_func(request_func_input=request_func_input, pbar=pbar) - async with semaphore: - return await request_func(request_func_input=request_func_input, pbar=pbar) - - benchmark_start_time = time.perf_counter() - tasks: list[asyncio.Task] = [] - - rps_change_events = [] - last_int_rps = -1 - if ramp_up_strategy is not None and ramp_up_start_rps is not None: - last_int_rps = ramp_up_start_rps - rps_change_events.append( - { - "rps": last_int_rps, - "timestamp": datetime.now().isoformat(), - } - ) - - async for request, current_request_rate in get_request( - input_requests, - request_rate, - burstiness, - ramp_up_strategy, - ramp_up_start_rps, - ramp_up_end_rps, - ): - if ramp_up_strategy is not None: - current_int_rps = int(current_request_rate) - if current_int_rps > last_int_rps: - timestamp = datetime.now().isoformat() - for rps_val in range(last_int_rps + 1, current_int_rps + 1): - rps_change_events.append({"rps": rps_val, "timestamp": timestamp}) - last_int_rps = current_int_rps - - prompt, prompt_len, output_len, mm_content, request_id = ( - request.prompt, - request.prompt_len, - request.expected_output_len, - request.multi_modal_data, - request.request_id, - ) - req_model_id, req_model_name = model_id, model_name - if lora_modules: - req_lora_module = next(lora_modules) - req_model_id, req_model_name = req_lora_module, req_lora_module - - request_func_input = RequestFuncInput( - model=req_model_id, - model_name=req_model_name, - prompt=prompt, - api_url=api_url, - prompt_len=prompt_len, - output_len=output_len, - logprobs=logprobs, - multi_modal_content=mm_content, - ignore_eos=ignore_eos, - extra_body=extra_body, - request_id=request_id, - ) - task = limited_request_func(request_func_input=request_func_input, pbar=pbar) - tasks.append(asyncio.create_task(task)) - outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) - - if pbar is not None: - pbar.close() - - benchmark_duration = time.perf_counter() - benchmark_start_time - - metrics, 
actual_output_lens = calculate_metrics( - input_requests=input_requests, - outputs=outputs, - dur_s=benchmark_duration, - tokenizer=tokenizer, - selected_percentile_metrics=selected_percentile_metrics, - selected_percentiles=selected_percentiles, - goodput_config_dict=goodput_config_dict, - ) - - print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="=")) - print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) - if max_concurrency is not None: - print("{:<40} {:<10}".format("Maximum request concurrency:", max_concurrency)) - if request_rate != float("inf"): - print("{:<40} {:<10.2f}".format("Request rate configured (RPS):", request_rate)) - print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) - print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) - print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output)) - print( - "{:<40} {:<10.2f}".format( - "Request throughput (req/s):", metrics.request_throughput - ) - ) - if goodput_config_dict: - print( - "{:<40} {:<10.2f}".format( - "Request goodput (req/s):", metrics.request_goodput - ) - ) - print( - "{:<40} {:<10.2f}".format( - "Output token throughput (tok/s):", metrics.output_throughput - ) - ) - print( - "{:<40} {:<10.2f}".format( - "Total Token throughput (tok/s):", metrics.total_token_throughput - ) - ) - - result = { - "duration": benchmark_duration, - "completed": metrics.completed, - "total_input_tokens": metrics.total_input, - "total_output_tokens": metrics.total_output, - "request_throughput": metrics.request_throughput, - "request_goodput": metrics.request_goodput if goodput_config_dict else None, - "output_throughput": metrics.output_throughput, - "total_token_throughput": metrics.total_token_throughput, - "input_lens": [output.prompt_len for output in outputs], - "output_lens": actual_output_lens, - "ttfts": [output.ttft for output in outputs], - "itls": [output.itl for output in outputs], - "generated_texts": [output.generated_text for output in outputs], - "errors": [output.error for output in outputs], - } - - if rps_change_events: - result["rps_change_events"] = rps_change_events - - def process_one_metric( - # E.g., "ttft" - metric_attribute_name: str, - # E.g., "TTFT" - metric_name: str, - # E.g., "Time to First Token" - metric_header: str, - ): - # This function prints and adds statistics of the specified - # metric. - if metric_attribute_name not in selected_percentile_metrics: - return - print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-")) - print( - "{:<40} {:<10.2f}".format( - f"Mean {metric_name} (ms):", - getattr(metrics, f"mean_{metric_attribute_name}_ms"), - ) - ) - print( - "{:<40} {:<10.2f}".format( - f"Median {metric_name} (ms):", - getattr(metrics, f"median_{metric_attribute_name}_ms"), - ) - ) - result[f"mean_{metric_attribute_name}_ms"] = getattr( - metrics, f"mean_{metric_attribute_name}_ms" - ) - result[f"median_{metric_attribute_name}_ms"] = getattr( - metrics, f"median_{metric_attribute_name}_ms" - ) - result[f"std_{metric_attribute_name}_ms"] = getattr( - metrics, f"std_{metric_attribute_name}_ms" - ) - for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"): - p_word = str(int(p)) if int(p) == p else str(p) - print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value)) - result[f"p{p_word}_{metric_attribute_name}_ms"] = value - - process_one_metric("ttft", "TTFT", "Time to First Token") - process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 
1st token)") - process_one_metric("itl", "ITL", "Inter-token Latency") - process_one_metric("e2el", "E2EL", "End-to-end Latency") - - print("=" * 50) - - if profile: - print("Stopping profiler...") - profile_input = RequestFuncInput( - model=model_id, - prompt=test_prompt, - api_url=base_url + "/stop_profile", - prompt_len=test_prompt_len, - output_len=test_output_len, - logprobs=logprobs, - ) - profile_output = await request_func(request_func_input=profile_input) - if profile_output.success: - print("Profiler stopped") - - return result - - -def check_goodput_args(args): - # Check and parse goodput arguments - goodput_config_dict = {} - VALID_NAMES = ["ttft", "tpot", "e2el"] - if args.goodput: - goodput_config_dict = parse_goodput(args.goodput) - for slo_name, slo_val in goodput_config_dict.items(): - if slo_name not in VALID_NAMES: - raise ValueError( - f"Invalid metric name found, {slo_name}: {slo_val}. " - "The service level objective name should be one of " - f"{str(VALID_NAMES)}. " - ) - if slo_val < 0: - raise ValueError( - f"Invalid value found, {slo_name}: {slo_val}. " - "The service level objective value should be " - "non-negative." - ) - return goodput_config_dict - - -def parse_goodput(slo_pairs): - goodput_config_dict = {} - try: - for slo_pair in slo_pairs: - slo_name, slo_val = slo_pair.split(":") - goodput_config_dict[slo_name] = float(slo_val) - except ValueError as err: - raise argparse.ArgumentTypeError( - "Invalid format found for service level objectives. " - 'Specify service level objectives for goodput as "KEY:VALUE" ' - "pairs, where the key is a metric name, and the value is a " - "number in milliseconds." - ) from err - return goodput_config_dict - - -def save_to_pytorch_benchmark_format( - args: argparse.Namespace, results: dict[str, Any], file_name: str -) -> None: - metrics = [ - "median_ttft_ms", - "mean_ttft_ms", - "std_ttft_ms", - "p99_ttft_ms", - "mean_tpot_ms", - "median_tpot_ms", - "std_tpot_ms", - "p99_tpot_ms", - "median_itl_ms", - "mean_itl_ms", - "std_itl_ms", - "p99_itl_ms", - ] - # These raw data might be useful, but they are rather big. They can be added - # later if needed - ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"] - pt_records = convert_to_pytorch_benchmark_format( - args=args, - metrics={k: [results[k]] for k in metrics}, - extra_info={ - k: results[k] - for k in results - if k not in metrics and k not in ignored_metrics - }, - ) - if pt_records: - # Don't use json suffix here as we don't want CI to pick it up - pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json" - write_to_json(pt_file, pt_records) - - -@deprecated( - "benchmark_serving.py is deprecated and will be removed in a future " - "version. Please use 'vllm bench serve' instead.", -) -def main(args: argparse.Namespace): - print(args) - random.seed(args.seed) - np.random.seed(args.seed) - - backend = args.backend - model_id = args.model - model_name = args.served_model_name - tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model - tokenizer_mode = args.tokenizer_mode - - # Validate ramp-up arguments - if args.ramp_up_strategy is not None: - if args.request_rate != float("inf"): - raise ValueError( - "When using ramp-up, do not specify --request-rate. " - "The request rate will be controlled by ramp-up parameters. " - "Please remove the --request-rate argument." 
- ) - if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None: - raise ValueError( - "When using --ramp-up-strategy, both --ramp-up-start-rps and " - "--ramp-up-end-rps must be specified" - ) - if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0: - raise ValueError("Ramp-up start and end RPS must be non-negative") - if args.ramp_up_start_rps > args.ramp_up_end_rps: - raise ValueError("Ramp-up start RPS must be less than end RPS") - if args.ramp_up_strategy == "exponential" and args.ramp_up_start_rps == 0: - raise ValueError("For exponential ramp-up, the start RPS cannot be 0.") - - if args.base_url is not None: - api_url = f"{args.base_url}{args.endpoint}" - base_url = f"{args.base_url}" - else: - api_url = f"http://{args.host}:{args.port}{args.endpoint}" - base_url = f"http://{args.host}:{args.port}" - - tokenizer = get_tokenizer( - tokenizer_id, - tokenizer_mode=tokenizer_mode, - trust_remote_code=args.trust_remote_code, - ) - - if args.dataset_name is None: - raise ValueError( - "Please specify '--dataset-name' and the corresponding " - "'--dataset-path' if required." - ) - - if args.dataset_name == "custom": - dataset = CustomDataset(dataset_path=args.dataset_path) - input_requests = dataset.sample( - num_requests=args.num_prompts, - tokenizer=tokenizer, - output_len=args.custom_output_len, - skip_chat_template=args.custom_skip_chat_template, - request_id_prefix=args.request_id_prefix, - ) - - elif args.dataset_name == "sonnet": - dataset = SonnetDataset(dataset_path=args.dataset_path) - # For the "sonnet" dataset, formatting depends on the backend. - if args.backend == "openai-chat": - input_requests = dataset.sample( - num_requests=args.num_prompts, - input_len=args.sonnet_input_len, - output_len=args.sonnet_output_len, - prefix_len=args.sonnet_prefix_len, - tokenizer=tokenizer, - return_prompt_formatted=False, - request_id_prefix=args.request_id_prefix, - ) - else: - assert tokenizer.chat_template or tokenizer.default_chat_template, ( - "Tokenizer/model must have chat template for sonnet dataset." 
- ) - input_requests = dataset.sample( - num_requests=args.num_prompts, - input_len=args.sonnet_input_len, - output_len=args.sonnet_output_len, - prefix_len=args.sonnet_prefix_len, - tokenizer=tokenizer, - return_prompt_formatted=True, - request_id_prefix=args.request_id_prefix, - ) - - elif args.dataset_name == "hf": - # all following datasets are implemented from the - # HuggingFaceDataset base class - if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: - dataset_class = VisionArenaDataset - args.hf_split = "train" - args.hf_subset = None - elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: - dataset_class = InstructCoderDataset - args.hf_split = "train" - elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS: - dataset_class = MTBenchDataset - args.hf_split = "train" - elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: - dataset_class = ConversationDataset - elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: - dataset_class = AIMODataset - args.hf_split = "train" - elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS: # noqa: E501 - dataset_class = NextEditPredictionDataset - args.hf_split = "train" - elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS: - dataset_class = ASRDataset - args.hf_split = "train" - else: - supported_datasets = set( - [ - dataset_name - for cls in HuggingFaceDataset.__subclasses__() - for dataset_name in cls.SUPPORTED_DATASET_PATHS - ] - ) - raise ValueError( - f"Unsupported dataset path: {args.dataset_path}. " - "Huggingface dataset only supports dataset_path" - f" from one of following: {supported_datasets}. " - "Please consider contributing if you would " - "like to add support for additional dataset formats." - ) - - if dataset_class.IS_MULTIMODAL and backend not in [ - "openai-chat", - "openai-audio", - ]: - # multi-modal benchmark is only available on OpenAI Chat backend. - raise ValueError( - "Multi-modal content is only supported on 'openai-chat' and " - "'openai-audio' backend." - ) - input_requests = dataset_class( - dataset_path=args.dataset_path, - dataset_subset=args.hf_subset, - dataset_split=args.hf_split, - random_seed=args.seed, - no_stream=args.no_stream, - ).sample( - num_requests=args.num_prompts, - tokenizer=tokenizer, - output_len=args.hf_output_len, - request_id_prefix=args.request_id_prefix, - ) - - else: - # For datasets that follow a similar structure, use a mapping. - dataset_mapping = { - "sharegpt": lambda: ShareGPTDataset( - random_seed=args.seed, dataset_path=args.dataset_path - ).sample( - tokenizer=tokenizer, - num_requests=args.num_prompts, - output_len=args.sharegpt_output_len, - request_id_prefix=args.request_id_prefix, - ), - "burstgpt": lambda: BurstGPTDataset( - random_seed=args.seed, dataset_path=args.dataset_path - ).sample( - tokenizer=tokenizer, - num_requests=args.num_prompts, - request_id_prefix=args.request_id_prefix, - ), - "random": lambda: RandomDataset(dataset_path=args.dataset_path).sample( - tokenizer=tokenizer, - num_requests=args.num_prompts, - prefix_len=args.random_prefix_len, - input_len=args.random_input_len, - output_len=args.random_output_len, - range_ratio=args.random_range_ratio, - request_id_prefix=args.request_id_prefix, - ), - } - - try: - input_requests = dataset_mapping[args.dataset_name]() - except KeyError as err: - raise ValueError(f"Unknown dataset: {args.dataset_name}") from err - goodput_config_dict = check_goodput_args(args) - - # Collect the sampling parameters. 
- sampling_params = { - k: v - for k, v in { - "top_p": args.top_p, - "top_k": args.top_k, - "min_p": args.min_p, - "temperature": args.temperature, - }.items() - if v is not None - } - - # Sampling parameters are only supported by openai-compatible backend. - if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS: - raise ValueError( - "Sampling parameters are only supported by openai-compatible backends." - ) - - if "temperature" not in sampling_params: - sampling_params["temperature"] = 0.0 # Default to greedy decoding. - - if args.backend == "llama.cpp": - # Disable prompt caching in llama.cpp backend - sampling_params["cache_prompt"] = False - - # Avoid GC processing "static" data - reduce pause times. - gc.collect() - gc.freeze() - - benchmark_result = asyncio.run( - benchmark( - backend=backend, - api_url=api_url, - base_url=base_url, - model_id=model_id, - model_name=model_name, - tokenizer=tokenizer, - input_requests=input_requests, - logprobs=args.logprobs, - request_rate=args.request_rate, - burstiness=args.burstiness, - disable_tqdm=args.disable_tqdm, - profile=args.profile, - selected_percentile_metrics=args.percentile_metrics.split(","), - selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")], - ignore_eos=args.ignore_eos, - goodput_config_dict=goodput_config_dict, - max_concurrency=args.max_concurrency, - lora_modules=args.lora_modules, - extra_body=sampling_params, - ramp_up_strategy=args.ramp_up_strategy, - ramp_up_start_rps=args.ramp_up_start_rps, - ramp_up_end_rps=args.ramp_up_end_rps, - ) - ) - - # Save config and results to json - if args.save_result or args.append_result: - result_json: dict[str, Any] = {} - - # Setup - current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") - result_json["date"] = current_dt - result_json["backend"] = backend - result_json["model_id"] = model_id - result_json["tokenizer_id"] = tokenizer_id - result_json["num_prompts"] = args.num_prompts - - # Metadata - if args.metadata: - for item in args.metadata: - if "=" in item: - kvstring = item.split("=") - result_json[kvstring[0].strip()] = kvstring[1].strip() - else: - raise ValueError( - "Invalid metadata format. Please use KEY=VALUE format." 
- ) - # Traffic - result_json["request_rate"] = ( - args.request_rate if args.request_rate < float("inf") else "inf" - ) - result_json["burstiness"] = args.burstiness - result_json["max_concurrency"] = args.max_concurrency - - if args.ramp_up_strategy is not None: - result_json["ramp_up_strategy"] = args.ramp_up_strategy - result_json["ramp_up_start_rps"] = args.ramp_up_start_rps - result_json["ramp_up_end_rps"] = args.ramp_up_end_rps - - # Merge with benchmark result - result_json = {**result_json, **benchmark_result} - - if not args.save_detailed: - # Remove fields with too many data points - for field in [ - "input_lens", - "output_lens", - "ttfts", - "itls", - "generated_texts", - "errors", - ]: - if field in result_json: - del result_json[field] - if field in benchmark_result: - del benchmark_result[field] - - # Save to file - base_model_id = model_id.split("/")[-1] - max_concurrency_str = ( - f"-concurrency{args.max_concurrency}" - if args.max_concurrency is not None - else "" - ) - if args.ramp_up_strategy is not None: - file_name = f"{backend}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa - else: - file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa - if args.result_filename: - file_name = args.result_filename - if args.result_dir: - os.makedirs(args.result_dir, exist_ok=True) - file_name = os.path.join(args.result_dir, file_name) - with open( - file_name, mode="a+" if args.append_result else "w", encoding="utf-8" - ) as outfile: - # Append a newline. - if args.append_result and outfile.tell() != 0: - outfile.write("\n") - json.dump(result_json, outfile) - save_to_pytorch_benchmark_format(args, result_json, file_name) - - -def create_argument_parser(): - parser = FlexibleArgumentParser( - description="Benchmark the online serving throughput." - ) - parser.add_argument( - "--backend", - type=str, - default="vllm", - choices=list(ASYNC_REQUEST_FUNCS.keys()), - ) - parser.add_argument( - "--base-url", - type=str, - default=None, - help="Server or API base url if not using http host and port.", - ) - # Use 127.0.0.1 here instead of localhost to force the use of ipv4 - parser.add_argument("--host", type=str, default="127.0.0.1") - parser.add_argument("--port", type=int, default=8000) - parser.add_argument( - "--endpoint", - type=str, - default="/v1/completions", - help="API endpoint.", - ) - parser.add_argument( - "--dataset-name", - type=str, - default="sharegpt", - choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"], - help="Name of the dataset to benchmark on.", - ) - parser.add_argument( - "--dataset-path", - type=str, - default=None, - help="Path to the sharegpt/sonnet dataset. " - "Or the huggingface dataset ID if using HF dataset.", - ) - parser.add_argument( - "--no-stream", - action="store_true", - help="Do not load the dataset in streaming mode.", - ) - parser.add_argument( - "--max-concurrency", - type=int, - default=None, - help="Maximum number of concurrent requests. This can be used " - "to help simulate an environment where a higher level component " - "is enforcing a maximum number of concurrent requests. While the " - "--request-rate argument controls the rate at which requests are " - "initiated, this argument will control how many are actually allowed " - "to execute at a time. 
This means that when used in combination, the " - "actual request rate may be lower than specified with --request-rate, " - "if the server is not processing requests fast enough to keep up.", - ) - - parser.add_argument( - "--model", - type=str, - required=True, - help="Name of the model.", - ) - parser.add_argument( - "--tokenizer", - type=str, - help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 - ) - parser.add_argument("--use-beam-search", action="store_true") - parser.add_argument( - "--num-prompts", - type=int, - default=1000, - help="Number of prompts to process.", - ) - parser.add_argument( - "--logprobs", - type=int, - default=None, - help=( - "Number of logprobs-per-token to compute & return as part of " - "the request. If unspecified, then either (1) if beam search " - "is disabled, no logprobs are computed & a single dummy " - "logprob is returned for each token; or (2) if beam search " - "is enabled 1 logprob per token is computed" - ), - ) - parser.add_argument( - "--request-rate", - type=float, - default=float("inf"), - help="Number of requests per second. If this is inf, " - "then all the requests are sent at time 0. " - "Otherwise, we use Poisson process or gamma distribution " - "to synthesize the request arrival times.", - ) - parser.add_argument( - "--burstiness", - type=float, - default=1.0, - help="Burstiness factor of the request generation. " - "Only take effect when request_rate is not inf. " - "Default value is 1, which follows Poisson process. " - "Otherwise, the request intervals follow a gamma distribution. " - "A lower burstiness value (0 < burstiness < 1) results in more " - "bursty requests. A higher burstiness value (burstiness > 1) " - "results in a more uniform arrival of requests.", - ) - parser.add_argument("--seed", type=int, default=0) - parser.add_argument( - "--trust-remote-code", - action="store_true", - help="Trust remote code from huggingface", - ) - parser.add_argument( - "--disable-tqdm", - action="store_true", - help="Specify to disable tqdm progress bar.", - ) - parser.add_argument( - "--profile", - action="store_true", - help="Use Torch Profiler. The endpoint must be launched with " - "VLLM_TORCH_PROFILER_DIR to enable profiler.", - ) - parser.add_argument( - "--save-result", - action="store_true", - help="Specify to save benchmark results to a json file", - ) - parser.add_argument( - "--save-detailed", - action="store_true", - help="When saving the results, whether to include per request " - "information such as response, error, ttfs, tpots, etc.", - ) - parser.add_argument( - "--append-result", - action="store_true", - help="Append the benchmark result to the existing json file.", - ) - parser.add_argument( - "--metadata", - metavar="KEY=VALUE", - nargs="*", - help="Key-value pairs (e.g, --metadata version=0.3.3 tp=1) " - "for metadata of this run to be saved in the result JSON file " - "for record keeping purposes.", - ) - parser.add_argument( - "--result-dir", - type=str, - default=None, - help="Specify directory to save benchmark json results." - "If not specified, results are saved in the current directory.", - ) - parser.add_argument( - "--result-filename", - type=str, - default=None, - help="Specify the filename to save benchmark json results." 
- "If not specified, results will be saved in " - "{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" - " format.", - ) - parser.add_argument( - "--ignore-eos", - action="store_true", - help="Set ignore_eos flag when sending the benchmark request." - "Warning: ignore_eos is not supported in deepspeed_mii and tgi.", - ) - parser.add_argument( - "--percentile-metrics", - type=str, - default="ttft,tpot,itl", - help="Comma-separated list of selected metrics to report percentils. " - "This argument specifies the metrics to report percentiles. " - 'Allowed metric names are "ttft", "tpot", "itl", "e2el". ' - 'Default value is "ttft,tpot,itl".', - ) - parser.add_argument( - "--metric-percentiles", - type=str, - default="99", - help="Comma-separated list of percentiles for selected metrics. " - 'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". ' - 'Default value is "99". ' - 'Use "--percentile-metrics" to select metrics.', - ) - parser.add_argument( - "--goodput", - nargs="+", - required=False, - help='Specify service level objectives for goodput as "KEY:VALUE" ' - "pairs, where the key is a metric name, and the value is in " - 'milliseconds. Multiple "KEY:VALUE" pairs can be provided, ' - "separated by spaces. Allowed request level metric names are " - '"ttft", "tpot", "e2el". For more context on the definition of ' - "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " - "and the blog: https://hao-ai-lab.github.io/blogs/distserve", - ) - parser.add_argument( - "--request-id-prefix", - type=str, - required=False, - default="benchmark-serving", - help="Specify the prefix of request id.", - ) - - # group for dataset specific arguments - custom_group = parser.add_argument_group("custom dataset options") - custom_group.add_argument( - "--custom-output-len", - type=int, - default=256, - help="Number of output tokens per request, used only for custom dataset.", - ) - custom_group.add_argument( - "--custom-skip-chat-template", - action="store_true", - help="Skip applying chat template to prompt, used only for custom dataset.", - ) - - sonnet_group = parser.add_argument_group("sonnet dataset options") - sonnet_group.add_argument( - "--sonnet-input-len", - type=int, - default=550, - help="Number of input tokens per request, used only for sonnet dataset.", - ) - sonnet_group.add_argument( - "--sonnet-output-len", - type=int, - default=150, - help="Number of output tokens per request, used only for sonnet dataset.", - ) - sonnet_group.add_argument( - "--sonnet-prefix-len", - type=int, - default=200, - help="Number of prefix tokens per request, used only for sonnet dataset.", - ) - - sharegpt_group = parser.add_argument_group("sharegpt dataset options") - sharegpt_group.add_argument( - "--sharegpt-output-len", - type=int, - default=None, - help="Output length for each request. Overrides the output length " - "from the ShareGPT dataset.", - ) - - random_group = parser.add_argument_group("random dataset options") - random_group.add_argument( - "--random-input-len", - type=int, - default=1024, - help="Number of input tokens per request, used only for random sampling.", - ) - random_group.add_argument( - "--random-output-len", - type=int, - default=128, - help="Number of output tokens per request, used only for random sampling.", - ) - random_group.add_argument( - "--random-range-ratio", - type=float, - default=0.0, - help="Range ratio for sampling input/output length, " - "used only for random sampling. 
Must be in the range [0, 1) to define " - "a symmetric sampling range" - "[length * (1 - range_ratio), length * (1 + range_ratio)].", - ) - random_group.add_argument( - "--random-prefix-len", - type=int, - default=0, - help=( - "Number of fixed prefix tokens before the random context " - "in a request. " - "The total input length is the sum of `random-prefix-len` and " - "a random " - "context length sampled from [input_len * (1 - range_ratio), " - "input_len * (1 + range_ratio)]." - ), - ) - - hf_group = parser.add_argument_group("hf dataset options") - hf_group.add_argument( - "--hf-subset", type=str, default=None, help="Subset of the HF dataset." - ) - hf_group.add_argument( - "--hf-split", type=str, default=None, help="Split of the HF dataset." - ) - hf_group.add_argument( - "--hf-output-len", - type=int, - default=None, - help="Output length for each request. Overrides the output lengths " - "from the sampled HF dataset.", - ) - - sampling_group = parser.add_argument_group("sampling parameters") - sampling_group.add_argument( - "--top-p", - type=float, - default=None, - help="Top-p sampling parameter. Only has effect on openai-compatible backends.", - ) - sampling_group.add_argument( - "--top-k", - type=int, - default=None, - help="Top-k sampling parameter. Only has effect on openai-compatible backends.", - ) - sampling_group.add_argument( - "--min-p", - type=float, - default=None, - help="Min-p sampling parameter. Only has effect on openai-compatible backends.", - ) - sampling_group.add_argument( - "--temperature", - type=float, - default=None, - help="Temperature sampling parameter. Only has effect on " - "openai-compatible backends. If not specified, default to greedy " - "decoding (i.e. temperature==0.0).", - ) - - parser.add_argument( - "--tokenizer-mode", - type=str, - default="auto", - choices=["auto", "slow", "mistral", "custom"], - help='The tokenizer mode.\n\n* "auto" will use the ' - 'fast tokenizer if available.\n* "slow" will ' - "always use the slow tokenizer. \n* " - '"mistral" will always use the `mistral_common` tokenizer. \n*' - '"custom" will use --tokenizer to select the preregistered tokenizer.', - ) - - parser.add_argument( - "--served-model-name", - type=str, - default=None, - help="The model name used in the API. " - "If not specified, the model name will be the " - "same as the ``--model`` argument. ", - ) - - parser.add_argument( - "--lora-modules", - nargs="+", - default=None, - help="A subset of LoRA module names passed in when " - "launching the server. For each request, the " - "script chooses a LoRA module at random.", - ) - - parser.add_argument( - "--ramp-up-strategy", - type=str, - default=None, - choices=["linear", "exponential"], - help="The ramp-up strategy. This would be used to " - "ramp up the request rate from initial RPS to final " - "RPS rate (specified by --ramp-up-start-rps and --ramp-up-end-rps). " - "over the duration of the benchmark.", - ) - parser.add_argument( - "--ramp-up-start-rps", - type=int, - default=None, - help="The starting request rate for ramp-up (RPS). " - "Needs to be specified when --ramp-up-strategy is used.", - ) - parser.add_argument( - "--ramp-up-end-rps", - type=int, - default=None, - help="The ending request rate for ramp-up (RPS). " - "Needs to be specified when --ramp-up-strategy is used.", - ) - - return parser - +import sys if __name__ == "__main__": - parser = create_argument_parser() - args = parser.parse_args() - main(args) + print("""DEPRECATED: This script has been moved to the vLLM CLI. 
+ +Please use the following command instead: + vllm bench serve + +For help with the new command, run: + vllm bench serve --help + +Alternatively, you can run the new command directly with: + python -m vllm.entrypoints.cli.main bench serve --help +""") + sys.exit(1) diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py index ca6843a72aa36..73b4aa5a87e07 100644 --- a/benchmarks/benchmark_serving_structured_output.py +++ b/benchmarks/benchmark_serving_structured_output.py @@ -696,11 +696,11 @@ def evaluate(ret, args): return re.match(args.regex, actual) is not None def _eval_correctness(expected, actual): - if args.structure_type == "guided_json": + if args.structure_type == "json": return _eval_correctness_json(expected, actual) - elif args.structure_type == "guided_regex": + elif args.structure_type == "regex": return _eval_correctness_regex(expected, actual) - elif args.structure_type == "guided_choice": + elif args.structure_type == "choice": return _eval_correctness_choice(expected, actual) else: return None @@ -780,18 +780,18 @@ def main(args: argparse.Namespace): ) if args.dataset == "grammar": - args.structure_type = "guided_grammar" + args.structure_type = "grammar" elif args.dataset == "regex": - args.structure_type = "guided_regex" + args.structure_type = "regex" elif args.dataset == "choice": - args.structure_type = "guided_choice" + args.structure_type = "choice" else: - args.structure_type = "guided_json" + args.structure_type = "json" if args.no_structured_output: args.structured_output_ratio = 0 if args.save_results: - result_file_name = f"{args.structured_output_ratio}guided" + result_file_name = f"{args.structured_output_ratio}so" result_file_name += f"_{backend}" result_file_name += f"_{args.request_rate}qps" result_file_name += f"_{args.model.split('/')[-1]}" @@ -998,7 +998,7 @@ def create_argument_parser(): "--percentile-metrics", type=str, default="ttft,tpot,itl", - help="Comma-separated list of selected metrics to report percentils. " + help="Comma-separated list of selected metrics to report percentiles. " "This argument specifies the metrics to report percentiles. " 'Allowed metric names are "ttft", "tpot", "itl", "e2el". 
' 'Default value is "ttft,tpot,itl".', diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 6b24b8c8f3c67..b6dc0918fd4d1 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -1,741 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Benchmark offline inference throughput.""" - -import argparse -import dataclasses -import json -import os -import random -import time -import warnings -from typing import Any, Optional, Union - -import torch -import uvloop -from tqdm import tqdm -from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase -from typing_extensions import deprecated - -from benchmark_dataset import ( - AIMODataset, - BurstGPTDataset, - ConversationDataset, - InstructCoderDataset, - RandomDataset, - SampleRequest, - ShareGPTDataset, - SonnetDataset, - VisionArenaDataset, -) -from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json -from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs -from vllm.entrypoints.openai.api_server import ( - build_async_engine_client_from_engine_args, -) -from vllm.inputs import TextPrompt, TokensPrompt -from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput -from vllm.sampling_params import BeamSearchParams -from vllm.utils import FlexibleArgumentParser, merge_async_iterators - - -def run_vllm( - requests: list[SampleRequest], - n: int, - engine_args: EngineArgs, - disable_detokenize: bool = False, -) -> tuple[float, Optional[list[RequestOutput]]]: - from vllm import LLM, SamplingParams - - llm = LLM(**dataclasses.asdict(engine_args)) - assert all( - llm.llm_engine.model_config.max_model_len - >= (request.prompt_len + request.expected_output_len) - for request in requests - ), ( - "Please ensure that max_model_len is greater than the sum of" - " prompt_len and expected_output_len for all requests." - ) - # Add the requests to the engine. - prompts: list[Union[TextPrompt, TokensPrompt]] = [] - sampling_params: list[SamplingParams] = [] - for request in requests: - prompts.append( - TokensPrompt( - prompt_token_ids=request.prompt["prompt_token_ids"], - multi_modal_data=request.multi_modal_data, - ) - if "prompt_token_ids" in request.prompt - else TextPrompt( - prompt=request.prompt, multi_modal_data=request.multi_modal_data - ) - ) - sampling_params.append( - SamplingParams( - n=n, - temperature=1.0, - top_p=1.0, - ignore_eos=True, - max_tokens=request.expected_output_len, - detokenize=not disable_detokenize, - ) - ) - lora_requests: Optional[list[LoRARequest]] = None - if engine_args.enable_lora: - lora_requests = [request.lora_request for request in requests] - - use_beam_search = False - - outputs = None - if not use_beam_search: - start = time.perf_counter() - outputs = llm.generate( - prompts, sampling_params, lora_request=lora_requests, use_tqdm=True - ) - end = time.perf_counter() - else: - assert lora_requests is None, "BeamSearch API does not support LoRA" - # output_len should be the same for all requests. 
- output_len = requests[0].expected_output_len - for request in requests: - assert request.expected_output_len == output_len - start = time.perf_counter() - llm.beam_search( - prompts, - BeamSearchParams( - beam_width=n, - max_tokens=output_len, - ignore_eos=True, - ), - ) - end = time.perf_counter() - return end - start, outputs - - -def run_vllm_chat( - requests: list[SampleRequest], - n: int, - engine_args: EngineArgs, - disable_detokenize: bool = False, -) -> tuple[float, list[RequestOutput]]: - """ - Run vLLM chat benchmark. This function is recommended ONLY for benchmarking - multimodal models as it properly handles multimodal inputs and chat - formatting. For non-multimodal models, use run_vllm() instead. - """ - from vllm import LLM, SamplingParams - - llm = LLM(**dataclasses.asdict(engine_args)) - - assert all( - llm.llm_engine.model_config.max_model_len - >= (request.prompt_len + request.expected_output_len) - for request in requests - ), ( - "Please ensure that max_model_len is greater than the sum of " - "prompt_len and expected_output_len for all requests." - ) - - prompts = [] - sampling_params: list[SamplingParams] = [] - for request in requests: - prompts.append(request.prompt) - sampling_params.append( - SamplingParams( - n=n, - temperature=1.0, - top_p=1.0, - ignore_eos=True, - max_tokens=request.expected_output_len, - detokenize=not disable_detokenize, - ) - ) - start = time.perf_counter() - outputs = llm.chat(prompts, sampling_params, use_tqdm=True) - end = time.perf_counter() - return end - start, outputs - - -async def run_vllm_async( - requests: list[SampleRequest], - n: int, - engine_args: AsyncEngineArgs, - disable_frontend_multiprocessing: bool = False, - disable_detokenize: bool = False, -) -> float: - from vllm import SamplingParams - - async with build_async_engine_client_from_engine_args( - engine_args, - disable_frontend_multiprocessing=disable_frontend_multiprocessing, - ) as llm: - model_config = await llm.get_model_config() - assert all( - model_config.max_model_len - >= (request.prompt_len + request.expected_output_len) - for request in requests - ), ( - "Please ensure that max_model_len is greater than the sum of" - " prompt_len and expected_output_len for all requests." - ) - - # Add the requests to the engine. 
- prompts: list[Union[TextPrompt, TokensPrompt]] = [] - sampling_params: list[SamplingParams] = [] - lora_requests: list[Optional[LoRARequest]] = [] - for request in requests: - prompts.append( - TokensPrompt( - prompt_token_ids=request.prompt["prompt_token_ids"], - multi_modal_data=request.multi_modal_data, - ) - if "prompt_token_ids" in request.prompt - else TextPrompt( - prompt=request.prompt, multi_modal_data=request.multi_modal_data - ) - ) - sampling_params.append( - SamplingParams( - n=n, - temperature=1.0, - top_p=1.0, - ignore_eos=True, - max_tokens=request.expected_output_len, - detokenize=not disable_detokenize, - ) - ) - lora_requests.append(request.lora_request) - - generators = [] - start = time.perf_counter() - for i, (prompt, sp, lr) in enumerate( - zip(prompts, sampling_params, lora_requests) - ): - generator = llm.generate(prompt, sp, lora_request=lr, request_id=f"test{i}") - generators.append(generator) - all_gens = merge_async_iterators(*generators) - async for i, res in all_gens: - pass - end = time.perf_counter() - return end - start - - -def run_hf( - requests: list[SampleRequest], - model: str, - tokenizer: PreTrainedTokenizerBase, - n: int, - max_batch_size: int, - trust_remote_code: bool, - disable_detokenize: bool = False, -) -> float: - llm = AutoModelForCausalLM.from_pretrained( - model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code - ) - if llm.config.model_type == "llama": - # To enable padding in the HF backend. - tokenizer.pad_token = tokenizer.eos_token - llm = llm.cuda() - - pbar = tqdm(total=len(requests)) - start = time.perf_counter() - batch: list[str] = [] - max_prompt_len = 0 - max_output_len = 0 - for i in range(len(requests)): - prompt = requests[i].prompt - prompt_len = requests[i].prompt_len - output_len = requests[i].expected_output_len - # Add the prompt to the batch. - batch.append(prompt) - max_prompt_len = max(max_prompt_len, prompt_len) - max_output_len = max(max_output_len, output_len) - if len(batch) < max_batch_size and i != len(requests) - 1: - # Check if we can add more requests to the batch. - next_prompt_len = requests[i + 1].prompt_len - next_output_len = requests[i + 1].expected_output_len - if ( - max(max_prompt_len, next_prompt_len) - + max(max_output_len, next_output_len) - ) <= 2048: - # We can add more requests to the batch. - continue - - # Generate the sequences. - input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids - llm_outputs = llm.generate( - input_ids=input_ids.cuda(), - do_sample=True, - num_return_sequences=n, - temperature=1.0, - top_p=1.0, - use_cache=True, - max_new_tokens=max_output_len, - ) - if not disable_detokenize: - # Include the decoding time. - tokenizer.batch_decode(llm_outputs, skip_special_tokens=True) - pbar.update(len(batch)) - - # Clear the batch. 
- batch = [] - max_prompt_len = 0 - max_output_len = 0 - end = time.perf_counter() - return end - start - - -def run_mii( - requests: list[SampleRequest], - model: str, - tensor_parallel_size: int, - output_len: int, -) -> float: - from mii import client, serve - - llm = serve(model, tensor_parallel=tensor_parallel_size) - prompts = [request.prompt for request in requests] - - start = time.perf_counter() - llm.generate(prompts, max_new_tokens=output_len) - end = time.perf_counter() - client = client(model) - client.terminate_server() - return end - start - - -def save_to_pytorch_benchmark_format( - args: argparse.Namespace, results: dict[str, Any] -) -> None: - pt_records = convert_to_pytorch_benchmark_format( - args=args, - metrics={ - "requests_per_second": [results["requests_per_second"]], - "tokens_per_second": [results["tokens_per_second"]], - }, - extra_info={ - k: results[k] for k in ["elapsed_time", "num_requests", "total_num_tokens"] - }, - ) - if pt_records: - # Don't use json suffix here as we don't want CI to pick it up - pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" - write_to_json(pt_file, pt_records) - - -def get_requests(args, tokenizer): - # Common parameters for all dataset types. - common_kwargs = { - "dataset_path": args.dataset_path, - "random_seed": args.seed, - } - sample_kwargs = { - "tokenizer": tokenizer, - "lora_path": args.lora_path, - "max_loras": args.max_loras, - "num_requests": args.num_prompts, - "input_len": args.input_len, - "output_len": args.output_len, - } - - if args.dataset_path is None or args.dataset_name == "random": - sample_kwargs["range_ratio"] = args.random_range_ratio - sample_kwargs["prefix_len"] = args.prefix_len - dataset_cls = RandomDataset - elif args.dataset_name == "sharegpt": - dataset_cls = ShareGPTDataset - if args.backend == "vllm-chat": - sample_kwargs["enable_multimodal_chat"] = True - elif args.dataset_name == "sonnet": - assert tokenizer.chat_template or tokenizer.default_chat_template, ( - "Tokenizer/model must have chat template for sonnet dataset." - ) - dataset_cls = SonnetDataset - sample_kwargs["prefix_len"] = args.prefix_len - sample_kwargs["return_prompt_formatted"] = True - elif args.dataset_name == "burstgpt": - dataset_cls = BurstGPTDataset - elif args.dataset_name == "hf": - common_kwargs["no_stream"] = args.no_stream - if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: - dataset_cls = VisionArenaDataset - common_kwargs["dataset_subset"] = None - common_kwargs["dataset_split"] = "train" - sample_kwargs["enable_multimodal_chat"] = True - elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: - dataset_cls = InstructCoderDataset - common_kwargs["dataset_split"] = "train" - elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: - dataset_cls = ConversationDataset - common_kwargs["dataset_subset"] = args.hf_subset - common_kwargs["dataset_split"] = args.hf_split - sample_kwargs["enable_multimodal_chat"] = True - elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: - dataset_cls = AIMODataset - common_kwargs["dataset_subset"] = None - common_kwargs["dataset_split"] = "train" - else: - raise ValueError(f"Unknown dataset name: {args.dataset_name}") - # Remove None values - sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None} - return dataset_cls(**common_kwargs).sample(**sample_kwargs) - - -@deprecated( - "benchmark_throughput.py is deprecated and will be removed in a " - "future version. 
Please use 'vllm bench throughput' instead.", -) -def main(args: argparse.Namespace): - if args.seed is None: - args.seed = 0 - print(args) - random.seed(args.seed) - # Sample the requests. - tokenizer = AutoTokenizer.from_pretrained( - args.tokenizer, trust_remote_code=args.trust_remote_code - ) - requests = get_requests(args, tokenizer) - is_multi_modal = any(request.multi_modal_data is not None for request in requests) - request_outputs: Optional[list[RequestOutput]] = None - if args.backend == "vllm": - if args.async_engine: - elapsed_time = uvloop.run( - run_vllm_async( - requests, - args.n, - AsyncEngineArgs.from_cli_args(args), - args.disable_frontend_multiprocessing, - args.disable_detokenize, - ) - ) - else: - elapsed_time, request_outputs = run_vllm( - requests, - args.n, - EngineArgs.from_cli_args(args), - args.disable_detokenize, - ) - elif args.backend == "hf": - assert args.tensor_parallel_size == 1 - elapsed_time = run_hf( - requests, - args.model, - tokenizer, - args.n, - args.hf_max_batch_size, - args.trust_remote_code, - args.disable_detokenize, - ) - elif args.backend == "mii": - elapsed_time = run_mii( - requests, args.model, args.tensor_parallel_size, args.output_len - ) - elif args.backend == "vllm-chat": - elapsed_time, request_outputs = run_vllm_chat( - requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize - ) - else: - raise ValueError(f"Unknown backend: {args.backend}") - - if request_outputs: - # Note: with the vllm and vllm-chat backends, - # we have request_outputs, which we use to count tokens. - total_prompt_tokens = 0 - total_output_tokens = 0 - for ro in request_outputs: - if not isinstance(ro, RequestOutput): - continue - total_prompt_tokens += ( - len(ro.prompt_token_ids) if ro.prompt_token_ids else 0 - ) - total_output_tokens += sum(len(o.token_ids) for o in ro.outputs if o) - total_num_tokens = total_prompt_tokens + total_output_tokens - else: - total_num_tokens = sum(r.prompt_len + r.expected_output_len for r in requests) - total_output_tokens = sum(r.expected_output_len for r in requests) - total_prompt_tokens = total_num_tokens - total_output_tokens - - if is_multi_modal and args.backend != "vllm-chat": - print( - "\033[91mWARNING\033[0m: Multi-modal request with " - f"{args.backend} backend detected. The " - "following metrics are not accurate because image tokens are not" - " counted. See vllm-project/vllm/issues/9778 for details." - ) - # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length. - # vllm-chat backend counts the image tokens now - - print( - f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " - f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " - f"{total_output_tokens / elapsed_time:.2f} output tokens/s" - ) - print(f"Total num prompt tokens: {total_prompt_tokens}") - print(f"Total num output tokens: {total_output_tokens}") - - # Output JSON results if specified - if args.output_json: - results = { - "elapsed_time": elapsed_time, - "num_requests": len(requests), - "total_num_tokens": total_num_tokens, - "requests_per_second": len(requests) / elapsed_time, - "tokens_per_second": total_num_tokens / elapsed_time, - } - with open(args.output_json, "w") as f: - json.dump(results, f, indent=4) - save_to_pytorch_benchmark_format(args, results) - - -def validate_args(args): - """ - Validate command-line arguments. - """ - - # === Deprecation and Defaulting === - if args.dataset is not None: - warnings.warn( - "The '--dataset' argument will be deprecated in the next release. 
" - "Please use '--dataset-name' and '--dataset-path' instead.", - stacklevel=2, - ) - args.dataset_path = args.dataset - - if not getattr(args, "tokenizer", None): - args.tokenizer = args.model - - # === Backend Validation === - valid_backends = {"vllm", "hf", "mii", "vllm-chat"} - if args.backend not in valid_backends: - raise ValueError(f"Unsupported backend: {args.backend}") - - # === Dataset Configuration === - if not args.dataset and not args.dataset_path: - print("When dataset path is not set, it will default to random dataset") - args.dataset_name = "random" - if args.input_len is None: - raise ValueError("input_len must be provided for a random dataset") - - # === Dataset Name Specific Checks === - # --hf-subset and --hf-split: only used - # when dataset_name is 'hf' - if args.dataset_name != "hf" and ( - getattr(args, "hf_subset", None) is not None - or getattr(args, "hf_split", None) is not None - ): - warnings.warn( - "--hf-subset and --hf-split will be ignored \ - since --dataset-name is not 'hf'.", - stacklevel=2, - ) - elif args.dataset_name == "hf": - if args.dataset_path in ( - VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys() - | ConversationDataset.SUPPORTED_DATASET_PATHS - ): - assert args.backend == "vllm-chat", ( - f"{args.dataset_path} needs to use vllm-chat as the backend." - ) # noqa: E501 - elif args.dataset_path in ( - InstructCoderDataset.SUPPORTED_DATASET_PATHS - | AIMODataset.SUPPORTED_DATASET_PATHS - ): - assert args.backend == "vllm", ( - f"{args.dataset_path} needs to use vllm as the backend." - ) # noqa: E501 - else: - raise ValueError(f"{args.dataset_path} is not supported by hf dataset.") - - # --random-range-ratio: only used when dataset_name is 'random' - if args.dataset_name != "random" and args.random_range_ratio is not None: - warnings.warn( - "--random-range-ratio will be ignored since \ - --dataset-name is not 'random'.", - stacklevel=2, - ) - - # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not - # set. - if ( - args.dataset_name not in {"random", "sonnet", None} - and args.prefix_len is not None - ): - warnings.warn( - "--prefix-len will be ignored since --dataset-name\ - is not 'random', 'sonnet', or not set.", - stacklevel=2, - ) - - # === LoRA Settings === - if getattr(args, "enable_lora", False) and args.backend != "vllm": - raise ValueError("LoRA benchmarking is only supported for vLLM backend") - if getattr(args, "enable_lora", False) and args.lora_path is None: - raise ValueError("LoRA path must be provided when enable_lora is True") - - # === Backend-specific Validations === - if args.backend == "hf" and args.hf_max_batch_size is None: - raise ValueError("HF max batch size is required for HF backend") - if args.backend != "hf" and args.hf_max_batch_size is not None: - raise ValueError("HF max batch size is only for HF backend.") - - if ( - args.backend in {"hf", "mii"} - and getattr(args, "quantization", None) is not None - ): - raise ValueError("Quantization is only for vLLM backend.") - - if args.backend == "mii" and args.dtype != "auto": - raise ValueError("dtype must be auto for MII backend.") - if args.backend == "mii" and args.n != 1: - raise ValueError("n must be 1 for MII backend.") - if args.backend == "mii" and args.tokenizer != args.model: - raise ValueError("Tokenizer must be the same as the model for MII backend.") - - # --data-parallel is not supported currently. 
- # https://github.com/vllm-project/vllm/issues/16222 - if args.data_parallel_size > 1: - raise ValueError( - "Data parallel is not supported in offline benchmark, " - "please use benchmark serving instead" - ) - - -def create_argument_parser(): - parser = FlexibleArgumentParser(description="Benchmark the throughput.") - parser.add_argument( - "--backend", - type=str, - choices=["vllm", "hf", "mii", "vllm-chat"], - default="vllm", - ) - parser.add_argument( - "--dataset-name", - type=str, - choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"], - help="Name of the dataset to benchmark on.", - default="sharegpt", - ) - parser.add_argument( - "--no-stream", - action="store_true", - help="Do not load the dataset in streaming mode.", - ) - parser.add_argument( - "--dataset", - type=str, - default=None, - help="Path to the ShareGPT dataset, will be deprecated in\ - the next release. The dataset is expected to " - "be a json in form of list[dict[..., conversations: " - "list[dict[..., value: ]]]]", - ) - parser.add_argument( - "--dataset-path", type=str, default=None, help="Path to the dataset" - ) - parser.add_argument( - "--input-len", - type=int, - default=None, - help="Input prompt length for each request", - ) - parser.add_argument( - "--output-len", - type=int, - default=None, - help="Output length for each request. Overrides the " - "output length from the dataset.", - ) - parser.add_argument( - "--n", type=int, default=1, help="Number of generated sequences per prompt." - ) - parser.add_argument( - "--num-prompts", type=int, default=1000, help="Number of prompts to process." - ) - parser.add_argument( - "--hf-max-batch-size", - type=int, - default=None, - help="Maximum batch size for HF backend.", - ) - parser.add_argument( - "--output-json", - type=str, - default=None, - help="Path to save the throughput results in JSON format.", - ) - parser.add_argument( - "--async-engine", - action="store_true", - default=False, - help="Use vLLM async engine rather than LLM class.", - ) - parser.add_argument( - "--disable-frontend-multiprocessing", - action="store_true", - default=False, - help="Disable decoupled async engine frontend.", - ) - parser.add_argument( - "--disable-detokenize", - action="store_true", - help=( - "Do not detokenize the response (i.e. do not include " - "detokenization time in the measurement)" - ), - ) - # LoRA - parser.add_argument( - "--lora-path", - type=str, - default=None, - help="Path to the LoRA adapters to use. This can be an absolute path, " - "a relative path, or a Hugging Face model identifier.", - ) - parser.add_argument( - "--prefix-len", - type=int, - default=None, - help=f"Number of prefix tokens to be used in RandomDataset " - "and SonnetDataset. For RandomDataset, the total input " - "length is the sum of prefix-len (default: " - f"{RandomDataset.DEFAULT_PREFIX_LEN}) and a random context length " - "sampled from [input_len * (1 - range_ratio), " - "input_len * (1 + range_ratio)]. For SonnetDataset, " - f"prefix_len (default: {SonnetDataset.DEFAULT_PREFIX_LEN}) " - "controls how much of the input is fixed lines versus " - "random lines, but the total input length remains approximately " - "input_len tokens.", - ) - # random dataset - parser.add_argument( - "--random-range-ratio", - type=float, - default=None, - help=f"Range ratio (default : {RandomDataset.DEFAULT_RANGE_RATIO}) " - "for sampling input/output length, " - "used only for RandomDataset. 
Must be in the range [0, 1) to " - "define a symmetric sampling range " - "[length * (1 - range_ratio), length * (1 + range_ratio)].", - ) - - # hf dtaset - parser.add_argument( - "--hf-subset", type=str, default=None, help="Subset of the HF dataset." - ) - parser.add_argument( - "--hf-split", type=str, default=None, help="Split of the HF dataset." - ) - - parser = AsyncEngineArgs.add_cli_args(parser) - - return parser - +import sys if __name__ == "__main__": - parser = create_argument_parser() - args = parser.parse_args() - if args.tokenizer is None: - args.tokenizer = args.model - validate_args(args) - main(args) + print("""DEPRECATED: This script has been moved to the vLLM CLI. + +Please use the following command instead: + vllm bench throughput + +For help with the new command, run: + vllm bench throughput --help + +Alternatively, you can run the new command directly with: + python -m vllm.entrypoints.cli.main bench throughput --help +""") + sys.exit(1) diff --git a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh index 92f97ffabea2a..2c72941cf7e51 100644 --- a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh @@ -62,7 +62,7 @@ benchmark() { --max-model-len 10000 \ --gpu-memory-utilization 0.6 \ --kv-transfer-config \ - '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & CUDA_VISIBLE_DEVICES=1 python3 \ @@ -72,7 +72,7 @@ benchmark() { --max-model-len 10000 \ --gpu-memory-utilization 0.6 \ --kv-transfer-config \ - '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' & wait_for_server 8100 wait_for_server 8200 diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh index af2bcba3ea57a..0bbf7cd2b1c81 100644 --- a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh +++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh @@ -69,7 +69,7 @@ launch_disagg_prefill() { --max-model-len 10000 \ --gpu-memory-utilization 0.6 \ --kv-transfer-config \ - '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & CUDA_VISIBLE_DEVICES=1 python3 \ -m vllm.entrypoints.openai.api_server \ @@ -78,7 +78,7 @@ launch_disagg_prefill() { --max-model-len 10000 \ --gpu-memory-utilization 0.6 \ --kv-transfer-config \ - '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' & wait_for_server 8100 wait_for_server 8200 diff --git a/benchmarks/kernels/bench_block_fp8_gemm.py b/benchmarks/kernels/bench_block_fp8_gemm.py index 9663503e9baa0..f1e504499eaf6 100644 --- a/benchmarks/kernels/bench_block_fp8_gemm.py +++ b/benchmarks/kernels/bench_block_fp8_gemm.py @@ -4,7 +4,10 @@ import torch from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - 
w8a8_block_fp8_matmul, + apply_w8a8_block_fp8_linear, +) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + CUTLASS_BLOCK_FP8_SUPPORTED, ) from vllm.platforms import current_platform from vllm.triton_utils import triton as vllm_triton @@ -29,7 +32,7 @@ DEEPSEEK_V3_SHAPES = [ ] -def build_w8a8_block_fp8_runner(M, N, K, block_size, device): +def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass): """Build runner function for w8a8 block fp8 matmul.""" factor_for_scale = 1e-2 @@ -37,37 +40,54 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device): fp8_max, fp8_min = fp8_info.max, fp8_info.min # Create random FP8 tensors - A_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max - A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + A_ref = (torch.rand(M, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max - B_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max - B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + B_ref = (torch.rand(N, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max + B = B_ref.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) # Create scales block_n, block_k = block_size[0], block_size[1] n_tiles = (N + block_n - 1) // block_n k_tiles = (K + block_k - 1) // block_k - As = torch.rand(M, k_tiles, dtype=torch.float32, device=device) * factor_for_scale Bs = ( torch.rand(n_tiles, k_tiles, dtype=torch.float32, device=device) * factor_for_scale ) + # SM90 CUTLASS requires row-major format for scales + if use_cutlass and current_platform.is_device_capability(90): + Bs = Bs.T.contiguous() + def run(): - return w8a8_block_fp8_matmul(A, B, As, Bs, block_size, torch.bfloat16) + if use_cutlass: + return apply_w8a8_block_fp8_linear( + A_ref, B, block_size, Bs, cutlass_block_fp8_supported=True + ) + else: + return apply_w8a8_block_fp8_linear( + A_ref, B, block_size, Bs, cutlass_block_fp8_supported=False + ) return run +# Determine available providers +available_providers = ["torch-bf16", "w8a8-block-fp8-triton"] +plot_title = "BF16 vs W8A8 Block FP8 GEMMs" + +if CUTLASS_BLOCK_FP8_SUPPORTED: + available_providers.append("w8a8-block-fp8-cutlass") + + @vllm_triton.testing.perf_report( vllm_triton.testing.Benchmark( x_names=["batch_size"], x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], x_log=False, line_arg="provider", - line_vals=["torch-bf16", "w8a8-block-fp8"], - line_names=["torch-bf16", "w8a8-block-fp8"], + line_vals=available_providers, + line_names=available_providers, ylabel="TFLOP/s (larger is better)", plot_name="BF16 vs W8A8 Block FP8 GEMMs", args={}, @@ -85,11 +105,22 @@ def benchmark_tflops(batch_size, provider, N, K, block_size=(128, 128)): ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph( lambda: torch.nn.functional.linear(a, b), quantiles=quantiles ) - else: # w8a8-block-fp8 - run_w8a8 = build_w8a8_block_fp8_runner(M, N, K, block_size, device) - ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph( - lambda: run_w8a8(), quantiles=quantiles + elif provider == "w8a8-block-fp8-triton": + run_w8a8_triton = build_w8a8_block_fp8_runner( + M, N, K, block_size, device, use_cutlass=False ) + ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph( + lambda: run_w8a8_triton(), quantiles=quantiles + ) + elif provider == "w8a8-block-fp8-cutlass": + run_w8a8_cutlass = build_w8a8_block_fp8_runner( + M, N, K, block_size, device, use_cutlass=True + ) + ms, min_ms, max_ms = 
vllm_triton.testing.do_bench_cudagraph( + lambda: run_w8a8_cutlass(), quantiles=quantiles + ) + else: + raise ValueError(f"Unknown provider: {provider}") to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3) return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms) diff --git a/benchmarks/kernels/bench_per_token_quant_fp8.py b/benchmarks/kernels/bench_per_token_quant_fp8.py index 923d678f1f2db..9170361e974b6 100644 --- a/benchmarks/kernels/bench_per_token_quant_fp8.py +++ b/benchmarks/kernels/bench_per_token_quant_fp8.py @@ -2,14 +2,25 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools from typing import Callable +from unittest.mock import patch +import pandas as pd import torch -from vllm import _custom_ops as ops -from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.triton_utils import triton +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser + + +def with_triton_mode(fn): + """Temporarily force the Triton fallback path""" + + def wrapped(*args, **kwargs): + with patch("vllm.platforms.current_platform.is_cuda", return_value=False): + return fn(*args, **kwargs) + + return wrapped # TODO(luka): use standalone_compile utility @@ -21,78 +32,236 @@ def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int): return inner -torch._dynamo.config.recompile_limit = 8888 -compilation_config = CompilationConfig(custom_ops=["none"]) -with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)): - torch_per_token_quant_fp8 = torch.compile( - QuantFP8(False, GroupShape.PER_TOKEN), - fullgraph=True, - dynamic=False, # recompile for different shapes - ) +def bench_compile(fn: Callable): + # recompile for different shapes + fwd = torch.compile(fn, fullgraph=True, dynamic=False) # First dim is explicitly dynamic to simulate vLLM usage - torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0) + return with_dyn_arg(fwd, 0, 0) -def cuda_per_token_quant_fp8( - input: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - return ops.scaled_fp8_quant(input) +torch._dynamo.config.recompile_limit = 8888 -def calculate_diff(batch_size: int, seq_len: int): - """Calculate difference between Triton and CUDA implementations.""" +def calculate_diff( + batch_size: int, + hidden_size: int, + group_shape: GroupShape, + dtype: torch.dtype, +): + """Calculate the difference between Inductor and CUDA implementations.""" device = torch.device("cuda") - x = torch.rand((batch_size * seq_len, 4096), dtype=torch.float16, device=device) + x = torch.rand((batch_size * hidden_size, 4096), dtype=dtype, device=device) - torch_out, torch_scale = torch_per_token_quant_fp8(x) - cuda_out, cuda_scale = cuda_per_token_quant_fp8(x) + quant_fp8 = QuantFP8(False, group_shape, column_major_scales=False) - if torch.allclose( - cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5 - ) and torch.allclose(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5): + torch_out, torch_scale = bench_compile(quant_fp8.forward_native)(x) + torch_eager_out, torch_eager_scale = quant_fp8.forward_native(x) + cuda_out, cuda_scale = quant_fp8.forward_cuda(x) + + out_allclose = lambda o1, o2: torch.allclose( + o1.to(torch.float32), + o2.to(torch.float32), + rtol=1e-3, + atol=1e-5, + ) + scale_allclose = lambda s1, s2: torch.allclose(s1, s2, 
rtol=1e-3, atol=1e-5) + + if ( + out_allclose(cuda_out, torch_out) + and scale_allclose(cuda_scale, torch_scale) + and out_allclose(cuda_out, torch_eager_out) + and scale_allclose(cuda_scale, torch_eager_scale) + ): print("✅ All implementations match") else: print("❌ Implementations differ") -batch_size_range = [1, 16, 32, 64, 128] -seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096] - -configs = list(itertools.product(batch_size_range, seq_len_range)) +configs = [] -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["batch_size", "seq_len"], - x_vals=configs, - line_arg="provider", - line_vals=["torch", "cuda"], - line_names=["Torch", "CUDA"], - styles=[("blue", "-"), ("green", "-")], - ylabel="us", - plot_name="per-token-dynamic-quant-fp8-performance", - args={}, - ) -) -def benchmark_quantization(batch_size, seq_len, provider): - dtype = torch.float16 +def benchmark_quantization( + batch_size, + hidden_size, + provider, + group_shape: GroupShape, + col_major: bool, + dtype: torch.dtype, +): device = torch.device("cuda") - x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype) + x = torch.randn(batch_size * hidden_size, 4096, device=device, dtype=dtype) quantiles = [0.5, 0.2, 0.8] + quant_fp8 = QuantFP8(False, group_shape, column_major_scales=col_major) if provider == "torch": - fn = lambda: torch_per_token_quant_fp8(x.clone()) + fn = lambda: bench_compile(quant_fp8.forward_native)(x.clone()) elif provider == "cuda": - fn = lambda: cuda_per_token_quant_fp8(x.clone()) + fn = lambda: quant_fp8.forward_cuda(x.clone()) + elif provider == "triton": + if not group_shape.is_per_group(): + # Triton only supported for per-group + return 0, 0, 0 + + fn = lambda: with_triton_mode(quant_fp8.forward_cuda)(x.clone()) ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) return 1000 * ms, 1000 * max_ms, 1000 * min_ms +# TODO(luka) extract to utils +def compute_geomean_speedups( + df: pd.DataFrame, + baseline_col: str, + speedup_cols: list[str], + groupby_cols: list[str] | None = None, +) -> pd.DataFrame: + """ + Compute geometric mean speedups over a baseline column. + + Args: + df: Input dataframe + baseline_col: Column to use as baseline + speedup_cols: Columns to compute speedups for + groupby_cols: Columns to group by. If None, compute over entire df. 
+ + Returns: + pd.DataFrame with geometric mean speedups + """ + from scipy.stats import gmean + + def geo_speedup(group: pd.DataFrame) -> pd.Series: + ratios = { + col: (group[baseline_col] / group[col]).values for col in speedup_cols + } + return pd.Series({col: gmean(vals) for col, vals in ratios.items()}) + + if groupby_cols is None: + result = geo_speedup(df).to_frame().T + else: + result = ( + df.groupby(groupby_cols) + .apply(geo_speedup, include_groups=False) + .reset_index() + ) + + return result + + if __name__ == "__main__": - calculate_diff(batch_size=4, seq_len=4096) - benchmark_quantization.run(print_data=True) + parser = FlexibleArgumentParser( + description="Benchmark the various implementations of QuantFP8 (dynamic-only)" + ) + parser.add_argument("-c", "--check", action="store_true") + parser.add_argument( + "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half" + ) + parser.add_argument( + "--hidden-sizes", + type=int, + nargs="+", + default=None, + help="Hidden sizes to benchmark (default: 1,16,64,128,256,512,1024,2048,4096)", + ) + parser.add_argument( + "--batch-sizes", + type=int, + nargs="+", + default=None, + help="Batch sizes to benchmark (default: 1,16,32,64,128)", + ) + parser.add_argument( + "--group-sizes", + type=int, + nargs="+", + default=None, + help="Group sizes for GroupShape(1,N) to benchmark. " + "Use 0 for PER_TENSOR, -1 for PER_TOKEN (default: 0,-1,64,128)", + ) + parser.add_argument( + "--no-column-major", + action="store_true", + help="Disable column-major scales testing", + ) + + args = parser.parse_args() + assert args + + dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype] + + hidden_sizes = args.hidden_sizes or [1, 16, 64, 128, 256, 512, 1024, 2048, 4096] + batch_sizes = args.batch_sizes or [1, 16, 32, 64, 128] + + if args.group_sizes is not None: + group_shapes = [] + for size in args.group_sizes: + if size == 0: + group_shapes.append(GroupShape.PER_TENSOR) + elif size == -1: + group_shapes.append(GroupShape.PER_TOKEN) + else: + group_shapes.append(GroupShape(1, size)) + else: + group_shapes = [ + GroupShape.PER_TENSOR, + GroupShape.PER_TOKEN, + GroupShape(1, 64), + GroupShape(1, 128), + ] + + column_major_scales = [False] if args.no_column_major else [True, False] + + config_gen = itertools.product( + group_shapes, + column_major_scales, + batch_sizes, + hidden_sizes, + ) + + # filter out column-major scales for non-group, reverse order + configs.extend(c[::-1] for c in config_gen if (c[0].is_per_group() or not c[1])) + + print(f"Running {len(configs)} configurations:") + print(f" Hidden sizes: {hidden_sizes}") + print(f" Batch sizes: {batch_sizes}") + print(f" Group shapes: {[str(g) for g in group_shapes]}") + print(f" Column major scales: {column_major_scales}") + print() + + if args.check: + for group_shape in group_shapes: + group_size = group_shape[1] + print(f"{group_size=}") + calculate_diff( + batch_size=4, hidden_size=4096, group_shape=group_shape, dtype=dtype + ) + + benchmark = triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["hidden_size", "batch_size", "col_major", "group_shape"], + x_vals=configs, + line_arg="provider", + line_vals=["torch", "cuda", "triton"], + line_names=["Torch (Compiled)", "CUDA", "Triton"], + styles=[("blue", "-"), ("green", "-"), ("black", "-")], + ylabel="us", + plot_name="QuantFP8 performance", + args={}, + ) + )(benchmark_quantization) + + df = benchmark.run(print_data=True, dtype=dtype, return_df=True) + + # Print geomean speedups + geo_table_grouped = 
compute_geomean_speedups( + df, + baseline_col="Torch (Compiled)", + speedup_cols=["CUDA", "Triton"], + groupby_cols=["col_major", "group_shape"], + ) + + print("Speedup over Torch (Compiled)") + print(geo_table_grouped.to_string(index=False)) diff --git a/benchmarks/kernels/benchmark_activation.py b/benchmarks/kernels/benchmark_activation.py new file mode 100644 index 0000000000000..93edbcc9391fc --- /dev/null +++ b/benchmarks/kernels/benchmark_activation.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# benchmark custom activation op performance +import itertools + +import torch + +import vllm.model_executor.layers.activation # noqa F401 +from vllm.model_executor.custom_op import CustomOp +from vllm.platforms import current_platform +from vllm.triton_utils import triton +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser + +batch_size_range = [1, 16, 32, 64, 128] +seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096] +intermediate_size = [3072, 9728, 12288] +configs = list(itertools.product(batch_size_range, seq_len_range, intermediate_size)) + + +def benchmark_activation( + batch_size: int, + seq_len: int, + intermediate_size: int, + provider: str, + func_name: str, + dtype: torch.dtype, +): + device = "cuda" + num_tokens = batch_size * seq_len + dim = intermediate_size + current_platform.seed_everything(42) + torch.set_default_device(device) + + if func_name == "gelu_and_mul": + layer = CustomOp.op_registry[func_name](approximate="none") + elif func_name == "gelu_and_mul_tanh": + layer = CustomOp.op_registry["gelu_and_mul"](approximate="tanh") + elif func_name == "fatrelu_and_mul": + threshold = 0.5 + layer = CustomOp.op_registry[func_name](threshold) + else: + layer = CustomOp.op_registry[func_name]() + + x = torch.randn(num_tokens, dim, dtype=dtype, device=device) + compiled_layer = torch.compile(layer.forward_native) + + if provider == "custom": + fn = lambda: layer(x) + elif provider == "compiled": + fn = lambda: compiled_layer(x) + + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + fn, quantiles=[0.5, 0.2, 0.8] + ) + return ms, max_ms, min_ms + + +if __name__ == "__main__": + parser = FlexibleArgumentParser(description="Benchmark the custom activation op.") + parser.add_argument( + "--func-name", + type=str, + choices=[ + "mul_and_silu", + "silu_and_mul", + "gelu_and_mul", + "gelu_and_mul_tanh", + "fatrelu_and_mul", + "swigluoai_and_mul", + "gelu_new", + "gelu_fast", + "quick_gelu", + ], + default="silu_and_mul", + ) + parser.add_argument( + "--dtype", type=str, choices=["half", "bfloat16", "float"], default="bfloat16" + ) + args = parser.parse_args() + assert args + + func_name = args.func_name + dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype] + + perf_report = triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "seq_len", "intermediate_size"], + x_vals=configs, + line_arg="provider", + line_vals=["custom", "compiled"], + line_names=["Custom OP", "Compiled"], + styles=[("blue", "-"), ("green", "-")], + ylabel="ms", + plot_name=f"{func_name}-op-performance", + args={}, + ) + ) + + perf_report( + lambda batch_size, seq_len, intermediate_size, provider: benchmark_activation( + batch_size, seq_len, intermediate_size, provider, func_name, dtype + ) + ).run(print_data=True) diff --git a/benchmarks/kernels/benchmark_cutlass_fp4_moe.py b/benchmarks/kernels/benchmark_cutlass_fp4_moe.py index 35c20ee41b9a9..726a2a371d109 100644 --- 
a/benchmarks/kernels/benchmark_cutlass_fp4_moe.py +++ b/benchmarks/kernels/benchmark_cutlass_fp4_moe.py @@ -13,6 +13,10 @@ import torch.utils.benchmark as benchmark from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.config import ( + fp8_w8a8_moe_quant_config, + nvfp4_moe_quant_config, +) from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk from vllm.scalar_type import scalar_types @@ -140,6 +144,12 @@ def bench_run( a_fp8_scale: torch.Tensor, num_repeats: int, ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_fp8_scale, + ) + for _ in range(num_repeats): fused_experts( a, @@ -147,10 +157,7 @@ def bench_run( w2, topk_weights, topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_fp8_scale, + quant_config=quant_config, ) def run_cutlass_moe_fp4( @@ -172,25 +179,27 @@ def bench_run( device: torch.device, num_repeats: int, ): + quant_config = nvfp4_moe_quant_config( + a1_gscale=a1_gs, + a2_gscale=a2_gs, + w1_scale=w1_blockscale, + w2_scale=w2_blockscale, + g1_alphas=w1_gs, + g2_alphas=w2_gs, + ) for _ in range(num_repeats): with nvtx.annotate("cutlass_moe_fp4", color="green"): cutlass_moe_fp4( a=a, - a1_gscale=a1_gs, - a2_gscale=a2_gs, w1_fp4=w1_fp4, - w1_blockscale=w1_blockscale, - w1_alphas=w1_gs, w2_fp4=w2_fp4, - w2_blockscale=w2_blockscale, - w2_alphas=w2_gs, topk_weights=topk_weights, topk_ids=topk_ids, m=m, n=n, k=k, e=num_experts, - device=device, + quant_config=quant_config, ) def run_cutlass_from_graph( @@ -211,26 +220,29 @@ def bench_run( e: int, device: torch.device, ): + quant_config = nvfp4_moe_quant_config( + a1_gscale=a1_gs, + a2_gscale=a2_gs, + w1_scale=w1_blockscale, + w2_scale=w2_blockscale, + g1_alphas=w1_gs, + g2_alphas=w2_gs, + ) + with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) ): return cutlass_moe_fp4( a=a, - a1_gscale=a1_gs, w1_fp4=w1_fp4, - w1_blockscale=w1_blockscale, - w1_alphas=w1_alphas, - a2_gscale=a2_gs, w2_fp4=w2_fp4, - w2_blockscale=w2_blockscale, - w2_alphas=w2_alphas, topk_weights=topk_weights, topk_ids=topk_ids, m=m, n=n, k=k, e=num_experts, - device=device, + quant_config=quant_config, ) def run_triton_from_graph( @@ -246,16 +258,18 @@ def bench_run( with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_fp8_scale, + ) return fused_experts( a, w1, w2, topk_weights, topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_fp8_scale, + quant_config=quant_config, ) def replay_graph(graph, num_repeats): diff --git a/benchmarks/kernels/benchmark_device_communicators.py b/benchmarks/kernels/benchmark_device_communicators.py new file mode 100644 index 0000000000000..a61c17edc1e28 --- /dev/null +++ b/benchmarks/kernels/benchmark_device_communicators.py @@ -0,0 +1,486 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Benchmark script for device communicators: +CustomAllreduce (oneshot, twoshot), PyNcclCommunicator, +and SymmMemCommunicator (multimem, two-shot). 
+
+Usage:
+    torchrun --nproc_per_node=<num_gpus> benchmark_device_communicators.py [options]
+
+Example:
+    torchrun --nproc_per_node=2 benchmark_device_communicators.py
+    --sequence-lengths 512 1024 2048 --num-warmup 10 --num-trials 100
+"""
+
+import json
+import os
+import time
+from contextlib import nullcontext
+from typing import Callable, Optional
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
+from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator
+from vllm.logger import init_logger
+from vllm.utils import FlexibleArgumentParser
+
+logger = init_logger(__name__)
+
+# Default sequence lengths to benchmark
+DEFAULT_SEQUENCE_LENGTHS = [128, 512, 1024, 2048, 4096, 8192]
+
+# Fixed hidden size and dtype for all benchmarks
+HIDDEN_SIZE = 8192
+BENCHMARK_DTYPE = torch.bfloat16
+
+# CUDA graph settings
+CUDA_GRAPH_CAPTURE_CYCLES = 10
+
+
+class CommunicatorBenchmark:
+    """Benchmark class for testing device communicators."""
+
+    def __init__(
+        self,
+        rank: int,
+        world_size: int,
+        device: torch.device,
+        cpu_group: ProcessGroup,
+        sequence_lengths: list[int],
+    ):
+        self.rank = rank
+        self.world_size = world_size
+        self.device = device
+        self.cpu_group = cpu_group
+
+        # Calculate max_size_override based on largest sequence length
+        max_seq_len = max(sequence_lengths)
+        max_tensor_elements = max_seq_len * HIDDEN_SIZE
+        self.max_size_override = max_tensor_elements * BENCHMARK_DTYPE.itemsize + 1
+
+        # Initialize communicators
+        self.custom_allreduce = None
+        self.pynccl_comm = None
+        self.symm_mem_comm = None
+        self.symm_mem_comm_multimem = None
+        self.symm_mem_comm_two_shot = None
+
+        self._init_communicators()
+
+    def _init_communicators(self):
+        """Initialize all available communicators."""
+        try:
+            self.custom_allreduce = CustomAllreduce(
+                group=self.cpu_group,
+                device=self.device,
+                max_size=self.max_size_override,
+            )
+            if not self.custom_allreduce.disabled:
+                logger.info("Rank %s: CustomAllreduce initialized", self.rank)
+            else:
+                logger.info("Rank %s: CustomAllreduce disabled", self.rank)
+        except Exception as e:
+            logger.warning(
+                "Rank %s: Failed to initialize CustomAllreduce: %s", self.rank, e
+            )
+            self.custom_allreduce = None
+
+        try:
+            self.pynccl_comm = PyNcclCommunicator(
+                group=self.cpu_group, device=self.device
+            )
+            if not self.pynccl_comm.disabled:
+                logger.info("Rank %s: PyNcclCommunicator initialized", self.rank)
+            else:
+                logger.info("Rank %s: PyNcclCommunicator disabled", self.rank)
+                self.pynccl_comm = None
+        except Exception as e:
+            logger.warning(
+                "Rank %s: Failed to initialize PyNcclCommunicator: %s", self.rank, e
+            )
+            self.pynccl_comm = None
+
+        # Initialize variants for SymmMemCommunicator
+        try:
+            self.symm_mem_comm_multimem = SymmMemCommunicator(
+                group=self.cpu_group,
+                device=self.device,
+                force_multimem=True,
+                max_size_override=self.max_size_override,
+            )
+            if not self.symm_mem_comm_multimem.disabled:
+                logger.info(
+                    "Rank %s: SymmMemCommunicator (multimem) initialized", self.rank
+                )
+            else:
+                self.symm_mem_comm_multimem = None
+        except Exception as e:
+            logger.warning(
+                "Rank %s: Failed to initialize SymmMemCommunicator (multimem): %s",
+                self.rank,
+                e,
+            )
+            self.symm_mem_comm_multimem = None
+
+        try:
+            self.symm_mem_comm_two_shot = SymmMemCommunicator(
+                group=self.cpu_group,
+                device=self.device,
+
force_multimem=False, + max_size_override=self.max_size_override, + ) + if not self.symm_mem_comm_two_shot.disabled: + logger.info( + "Rank %s: SymmMemCommunicator (two_shot) initialized", self.rank + ) + else: + self.symm_mem_comm_two_shot = None + except Exception as e: + logger.warning( + "Rank %s: Failed to initialize SymmMemCommunicator (two_shot): %s", + self.rank, + e, + ) + self.symm_mem_comm_two_shot = None + + def benchmark_allreduce( + self, sequence_length: int, num_warmup: int, num_trials: int + ) -> dict[str, float]: + """Benchmark allreduce operations for all available communicators.""" + + results = {} + + # Define communicators with their benchmark functions + communicators = [] + + if self.custom_allreduce is not None: + comm = self.custom_allreduce + # CustomAllreduce one-shot + communicators.append( + ( + "ca_1stage", + lambda t, c=comm: c.custom_all_reduce(t), + lambda t, c=comm: c.should_custom_ar(t), + comm.capture(), + "1stage", # env variable value + ) + ) + # CustomAllreduce two-shot + communicators.append( + ( + "ca_2stage", + lambda t, c=comm: c.custom_all_reduce(t), + lambda t, c=comm: c.should_custom_ar(t), + comm.capture(), + "2stage", # env variable value + ) + ) + + if self.pynccl_comm is not None: + comm = self.pynccl_comm + communicators.append( + ( + "pynccl", + lambda t, c=comm: c.all_reduce(t), + lambda t: True, # Always available if initialized + nullcontext(), + None, # no env variable needed + ) + ) + + if self.symm_mem_comm_multimem is not None: + comm = self.symm_mem_comm_multimem + communicators.append( + ( + "symm_mem_multimem", + lambda t, c=comm: c.all_reduce(t), + lambda t, c=comm: c.should_use_symm_mem(t), + nullcontext(), + None, # no env variable needed + ) + ) + + if self.symm_mem_comm_two_shot is not None: + comm = self.symm_mem_comm_two_shot + communicators.append( + ( + "symm_mem_two_shot", + lambda t, c=comm: c.all_reduce(t), + lambda t, c=comm: c.should_use_symm_mem(t), + nullcontext(), + None, # no env variable needed + ) + ) + + # Benchmark each communicator + for name, allreduce_fn, should_use_fn, context, env_var in communicators: + # Set environment variable if needed + if env_var is not None: + os.environ["VLLM_CUSTOM_ALLREDUCE_ALGO"] = env_var + else: + # Clear the environment variable to avoid interference + os.environ.pop("VLLM_CUSTOM_ALLREDUCE_ALGO", None) + + latency = self.benchmark_allreduce_single( + sequence_length, + allreduce_fn, + should_use_fn, + context, + num_warmup, + num_trials, + ) + if latency is not None: + results[name] = latency + + return results + + def benchmark_allreduce_single( + self, + sequence_length: int, + allreduce_fn: Callable[[torch.Tensor], Optional[torch.Tensor]], + should_use_fn: Callable[[torch.Tensor], bool], + context, + num_warmup: int, + num_trials: int, + ) -> Optional[float]: + """Benchmark method with CUDA graph optimization.""" + try: + # Create test tensor (2D: sequence_length x hidden_size) + tensor = torch.randn( + sequence_length, HIDDEN_SIZE, dtype=BENCHMARK_DTYPE, device=self.device + ) + if not should_use_fn(tensor): + return None + + torch.cuda.synchronize() + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + graph_input = tensor.clone() + + # Warmup before capture + for _ in range(3): + allreduce_fn(graph_input) + + # Capture the graph using context manager + with context: + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + for _ in range(CUDA_GRAPH_CAPTURE_CYCLES): + allreduce_fn(graph_input) + + torch.cuda.synchronize() + for _ in 
range(num_warmup): + graph.replay() + torch.cuda.synchronize() + + torch.cuda.synchronize() + start_time = time.perf_counter() + + for _ in range(num_trials): + graph.replay() + torch.cuda.synchronize() + + end_time = time.perf_counter() + + # Convert to ms and divide by CUDA_GRAPH_CAPTURE_CYCLES + return ( + (end_time - start_time) / num_trials / CUDA_GRAPH_CAPTURE_CYCLES * 1000 + ) + + except Exception as e: + logger.error("CUDA graph benchmark failed: %s", e) + raise RuntimeError( + f"CUDA graph benchmark failed for communicator: {e}" + ) from e + + +def _calculate_speedup_info(comm_results: dict[str, float]) -> str: + """Calculate speedup information for a single tensor size.""" + if not comm_results: + return "N/A" + + # Find the fastest communicator + fastest_comm = min(comm_results.keys(), key=lambda k: comm_results[k]) + fastest_time = comm_results[fastest_comm] + + # Calculate speedup vs PyNccl if available + if "pynccl" in comm_results: + pynccl_time = comm_results["pynccl"] + speedup = pynccl_time / fastest_time + return f"{fastest_comm} ({speedup:.2f}x)" + else: + return f"{fastest_comm} (N/A)" + + +def print_results( + results: dict[str, dict[str, float]], sequence_lengths: list[int], world_size: int +): + """Print benchmark results in a formatted table.""" + + print(f"\n{'=' * 130}") + print("Device Communicator Benchmark Results") + print( + f"World Size: {world_size}, Data Type: {BENCHMARK_DTYPE}, " + f"Hidden Size: {HIDDEN_SIZE}" + ) + print(f"{'=' * 130}") + + # Get all communicator names + all_comms = set() + for size_results in results.values(): + all_comms.update(size_results.keys()) + + all_comms = sorted(list(all_comms)) + + # Print header + header = f"{'Tensor Shape':<20}{'Tensor Size':<15}" + for comm in all_comms: + header += f"{comm:<20}" + header += f"{'Best (Speedup vs PyNccl)':<30}" + print(header) + print("-" * len(header)) + + # Print results for each sequence length + for seq_len in sequence_lengths: + if seq_len in results: + # Calculate tensor size in elements and bytes + tensor_elements = seq_len * HIDDEN_SIZE + tensor_bytes = tensor_elements * BENCHMARK_DTYPE.itemsize + + # Format tensor size (MB) + tensor_size_mb = tensor_bytes / (1024 * 1024) + tensor_size_str = f"{tensor_size_mb:.2f} MB" + + # Format tensor shape + tensor_shape = f"({seq_len}, {HIDDEN_SIZE})" + + row = f"{tensor_shape:<20}{tensor_size_str:<15}" + for comm in all_comms: + if comm in results[seq_len]: + row += f"{results[seq_len][comm]:<20.3f}" + else: + row += f"{'N/A':<20}" + + # Calculate speedup information + speedup_info = _calculate_speedup_info(results[seq_len]) + row += f"{speedup_info:<30}" + + print(row) + + print(f"{'=' * 130}") + print("All times are in milliseconds (ms) per allreduce operation") + print("Speedup column shows: fastest_algorithm (speedup_vs_pynccl)") + + +def main(): + parser = FlexibleArgumentParser(description="Benchmark device communicators") + + parser.add_argument( + "--sequence-lengths", + type=int, + nargs="+", + default=DEFAULT_SEQUENCE_LENGTHS, + help="Sequence lengths to benchmark (tensor shape: seq_len x hidden_size)", + ) + + parser.add_argument( + "--num-warmup", type=int, default=5, help="Number of warmup iterations" + ) + + parser.add_argument( + "--num-trials", type=int, default=50, help="Number of benchmark trials" + ) + + parser.add_argument("--output-json", type=str, help="Output results to JSON file") + + args = parser.parse_args() + + # Initialize distributed + if not dist.is_initialized(): + dist.init_process_group(backend="gloo") + 
rank = dist.get_rank() + world_size = dist.get_world_size() + + # Set device + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + + # Get CPU process group + cpu_group = dist.new_group(backend="gloo") + + # Disable USE_SYMM_MEM to avoid affecting the max_sizes + # in symm_mem and custom_all_reduce for benchmark + os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0" + + # Initialize benchmark + benchmark = CommunicatorBenchmark( + rank, world_size, device, cpu_group, args.sequence_lengths + ) + + # Run benchmarks + all_results = {} + + for seq_len in args.sequence_lengths: + if rank == 0: + logger.info( + "Benchmarking sequence length: %s (tensor shape: %s x %s)", + seq_len, + seq_len, + HIDDEN_SIZE, + ) + + results = benchmark.benchmark_allreduce( + sequence_length=seq_len, + num_warmup=args.num_warmup, + num_trials=args.num_trials, + ) + + all_results[seq_len] = results + + # Synchronize between ranks + dist.barrier() + + # Print results (only rank 0) + if rank == 0: + print_results(all_results, args.sequence_lengths, world_size) + + # Save to JSON if requested + if args.output_json: + # Add speedup information to results + enhanced_results = {} + for seq_len, comm_results in all_results.items(): + enhanced_results[seq_len] = { + "timings": comm_results, + "speedup_info": _calculate_speedup_info(comm_results), + } + + output_data = { + "world_size": world_size, + "dtype": str(BENCHMARK_DTYPE), + "hidden_size": HIDDEN_SIZE, + "sequence_lengths": args.sequence_lengths, + "num_warmup": args.num_warmup, + "num_trials": args.num_trials, + "cuda_graph_capture_cycles": CUDA_GRAPH_CAPTURE_CYCLES, + "results": enhanced_results, + } + + with open(args.output_json, "w") as f: + json.dump(output_data, f, indent=2) + + logger.info("Results saved to %s", args.output_json) + + # Cleanup + if cpu_group != dist.group.WORLD: + dist.destroy_process_group(cpu_group) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index a6b42406b5cb0..14330ae6f03c5 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -7,6 +7,7 @@ from benchmark_shapes import WEIGHT_SHAPES_MOE from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8 from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, @@ -96,6 +97,11 @@ def bench_run( a_scale: torch.Tensor, num_repeats: int, ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale, + ) for _ in range(num_repeats): fused_experts( a, @@ -103,10 +109,7 @@ def bench_run( w2, topk_weights, topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_scale, + quant_config=quant_config, ) def run_cutlass_moe( @@ -125,6 +128,12 @@ def bench_run( per_act_token: bool, num_repeats: int, ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + per_act_token_quant=per_act_token, + ) + for _ in range(num_repeats): cutlass_moe_fp8( a, @@ -132,14 +141,11 @@ def bench_run( w2, topk_weights, topk_ids, - w1_scale, - w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, - per_act_token, - a1_scale=None, + quant_config=quant_config, ) def run_cutlass_from_graph( @@ 
-156,6 +162,12 @@ def bench_run( topk_weights: torch.Tensor, topk_ids: torch.Tensor, ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + per_act_token_quant=per_act_token, + ) + with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) ): @@ -165,14 +177,11 @@ def bench_run( w2_q, topk_weights, topk_ids, - w1_scale, - w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, - per_act_token, - a1_scale=None, + quant_config=quant_config, ) def run_triton_from_graph( @@ -185,6 +194,11 @@ def bench_run( w2_scale: torch.Tensor, a_scale: torch.Tensor, ): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale, + ) with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) ): @@ -194,10 +208,7 @@ def bench_run( w2, topk_weights, topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_scale, + quant_config=quant_config, ) def replay_graph(graph, num_repeats): diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index 3d38d4b3534e8..debb29744bfaa 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -464,7 +464,11 @@ class BenchmarkTensors: for field_name in LoRAKernelMeta.__dataclass_fields__: field = getattr(self.lora_kernel_meta, field_name) assert isinstance(field, torch.Tensor) - setattr(self.lora_kernel_meta, field_name, to_device(field)) + setattr( + self.lora_kernel_meta, + field_name, + to_device(field) if field_name != "no_lora_flag_cpu" else field, + ) def metadata(self) -> tuple[int, int, int]: """ @@ -512,6 +516,7 @@ class BenchmarkTensors: "lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc, "lora_ids": self.lora_kernel_meta.active_lora_ids, "scaling": 1.0, + "no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu, } def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]: @@ -552,6 +557,7 @@ class BenchmarkTensors: "lora_ids": self.lora_kernel_meta.active_lora_ids, "offset_start": 0, "add_inputs": add_inputs, + "no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu, } def bench_fn_kwargs( @@ -637,7 +643,7 @@ def bench_optype( # Clear LoRA optimization hash-maps. 
_LORA_A_PTR_DICT.clear() _LORA_B_PTR_DICT.clear() - # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are setup + # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up for kwargs in kwargs_list: op_type.bench_fn()(**kwargs) torch.cuda.synchronize() diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 710d30adfd846..d2beb28f70233 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -14,6 +14,10 @@ import ray import torch from ray.experimental.tqdm_ray import tqdm +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, + _get_config_dtype_str, +) from vllm.model_executor.layers.fused_moe.fused_moe import * from vllm.platforms import current_platform from vllm.transformers_utils.config import get_config @@ -134,43 +138,36 @@ def benchmark_config( def run(): from vllm.model_executor.layers.fused_moe import override_config + if use_fp8_w8a8: + quant_dtype = torch.float8_e4m3fn + elif use_int8_w8a16: + quant_dtype = torch.int8 + else: + quant_dtype = None + + quant_config = FusedMoEQuantConfig.make( + quant_dtype=quant_dtype, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_quant_shape, + ) + with override_config(config): - if use_deep_gemm: - topk_weights, topk_ids, token_expert_indices = fused_topk( - x, input_gating, topk, False - ) - return fused_experts( - x, - w1, - w2, - topk_weights, - topk_ids, - inplace=True, - use_fp8_w8a8=use_fp8_w8a8, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=block_quant_shape, - allow_deep_gemm=True, - ) - else: - fused_moe( - x, - w1, - w2, - input_gating, - topk, - renormalize=True, - inplace=True, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=block_quant_shape, - ) + topk_weights, topk_ids, token_expert_indices = fused_topk( + x, input_gating, topk, renormalize=not use_deep_gemm + ) + return fused_experts( + x, + w1, + w2, + topk_weights, + topk_ids, + inplace=True, + quant_config=quant_config, + allow_deep_gemm=use_deep_gemm, + ) # JIT compilation & warmup run() @@ -414,7 +411,7 @@ class BenchmarkWorker: use_deep_gemm: bool = False, ) -> tuple[dict[str, int], float]: current_platform.seed_everything(self.seed) - dtype_str = get_config_dtype_str( + dtype_str = _get_config_dtype_str( dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 ) # NOTE(woosuk): The current naming convention uses w2.shape[2], which @@ -547,7 +544,7 @@ def save_configs( block_quant_shape: list[int], save_dir: str, ) -> None: - dtype_str = get_config_dtype_str( + dtype_str = _get_config_dtype_str( dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 ) @@ -560,7 +557,7 @@ def save_configs( filename = os.path.join(save_dir, filename) print(f"Writing best config to {filename}...") with open(filename, "w") as f: - json.dump(configs, f, indent=4) + json.dump({"triton_version": triton.__version__, **configs}, f, indent=4) f.write("\n") @@ -594,7 +591,11 @@ def main(args: argparse.Namespace): E = config.n_routed_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size - elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"): + elif config.architectures[0] in ( + "Qwen2MoeForCausalLM", + "Qwen3MoeForCausalLM", + "Qwen3NextForCausalLM", + ): E = 
config.num_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size @@ -678,7 +679,11 @@ def main(args: argparse.Namespace): is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) search_space = get_configs_compute_bound(is_fp16, block_quant_shape) print(f"Start tuning over {len(search_space)} configurations...") - + if use_deep_gemm: + raise ValueError( + "Tuning with --use-deep-gemm is not supported as it only tunes Triton " + "kernels. Please remove the flag." + ) start = time.time() configs = _distribute( "tune", diff --git a/benchmarks/kernels/benchmark_polynorm.py b/benchmarks/kernels/benchmark_polynorm.py new file mode 100644 index 0000000000000..9ac8f5e6594e4 --- /dev/null +++ b/benchmarks/kernels/benchmark_polynorm.py @@ -0,0 +1,155 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import itertools + +import torch + +from vllm import _custom_ops as vllm_ops +from vllm.triton_utils import triton + + +def polynorm_naive( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + + def norm(x, eps: float): + return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + x = x.float() + return ( + ( + weight[0] * norm(x**3, eps) + + weight[1] * norm(x**2, eps) + + weight[2] * norm(x, eps) + + bias + ) + .to(weight.dtype) + .view(orig_shape) + ) + + +def polynorm_vllm( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-6, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + + out = torch.empty_like(x) + vllm_ops.poly_norm(out, x, weight, bias, eps) + output = out + + output = output.view(orig_shape) + return output + + +def calculate_diff(batch_size, seq_len, hidden_dim): + dtype = torch.bfloat16 + x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda") + weight = torch.ones(3, dtype=dtype, device="cuda") + bias = torch.ones(1, dtype=dtype, device="cuda") + + output_naive = polynorm_naive(x, weight, bias) + output_vllm = polynorm_vllm(x, weight, bias) + + if torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2): + print("✅ All implementations match") + else: + print("❌ Implementations differ") + + +batch_size_range = [2**i for i in range(0, 7, 2)] +seq_length_range = [2**i for i in range(6, 11, 1)] +dim_range = [2048, 4096] +configs = list(itertools.product(dim_range, batch_size_range, seq_length_range)) + + +def get_benchmark(): + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["dim", "batch_size", "seq_len"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["naive", "vllm"], + line_names=["Naive", "vLLM"], + styles=[("blue", "-"), ("red", "-")], + ylabel="us", + plot_name="polynorm-perf", + args={}, + ) + ) + def benchmark(dim, batch_size, seq_len, provider): + dtype = torch.bfloat16 + hidden_dim = dim * 4 + + x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda") + weight = torch.ones(3, dtype=dtype, device="cuda") + bias = torch.ones(1, dtype=dtype, device="cuda") + + quantiles = [0.5, 0.2, 0.8] + + if provider == "naive": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: polynorm_naive(x, weight, bias), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: polynorm_vllm(x, weight, bias), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + return benchmark + + +if __name__ == "__main__": + import 
argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--batch-size", + type=int, + default=4, + help="Batch size", + ) + parser.add_argument( + "--seq-len", + type=int, + default=128, + help="Sequence length", + ) + parser.add_argument( + "--hidden-dim", + type=int, + default=8192, + help="Intermediate size of MLP", + ) + parser.add_argument( + "--save-path", + type=str, + default="./configs/polnorm/", + help="Path to save polnorm benchmark results", + ) + + args = parser.parse_args() + + # Run correctness test + calculate_diff( + batch_size=args.batch_size, + seq_len=args.seq_len, + hidden_dim=args.hidden_dim, + ) + + benchmark = get_benchmark() + # Run performance benchmark + benchmark.run(print_data=True, save_path=args.save_path) diff --git a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py index 0650cbf3cc18e..c7a4066b39d70 100644 --- a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py +++ b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py @@ -1,77 +1,675 @@ -#!/usr/bin/env python3 # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import time +from collections.abc import Callable +import matplotlib.pyplot as plt +import numpy as np import torch from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - silu_mul_fp8_quant_deep_gemm, + silu_mul_fp8_quant_deep_gemm_cuda, ) from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used -def benchmark(E, T, H, G=128, runs=50): - current_platform.seed_everything(42) - y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda") - tokens_per_expert = torch.randint( - T // 2, T, size=(E,), dtype=torch.int32, device="cuda" +@triton.jit +def _silu_mul_fp8_quant_deep_gemm( + # Pointers ------------------------------------------------------------ + input_ptr, # 16-bit activations (E, T, 2*H) + y_q_ptr, # fp8 quantized activations (E, T, H) + y_s_ptr, # 16-bit scales (E, T, G) + counts_ptr, # int32 num tokens per expert (E) + # Sizes --------------------------------------------------------------- + H: tl.constexpr, # hidden dimension (per output) + GROUP_SIZE: tl.constexpr, # elements per group (usually 128) + # Strides for input (elements) --------------------------------------- + stride_i_e, + stride_i_t, + stride_i_h, + # Strides for y_q (elements) ----------------------------------------- + stride_yq_e, + stride_yq_t, + stride_yq_h, + # Strides for y_s (elements) ----------------------------------------- + stride_ys_e, + stride_ys_t, + stride_ys_g, + # Stride for counts (elements) + stride_counts_e, + # Numeric params ------------------------------------------------------ + eps: tl.constexpr, + fp8_min: tl.constexpr, + fp8_max: tl.constexpr, + use_ue8m0: tl.constexpr, + # Meta --------------------------------------------------------------- + BLOCK: tl.constexpr, + NUM_STAGES: tl.constexpr, +): + G = H // GROUP_SIZE + + # map program id -> (e, g) + pid = tl.program_id(0) + e = pid // G + g = pid % G + + e = e.to(tl.int64) + g = g.to(tl.int64) + + # number of valid tokens for this expert + n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64) + + cols = tl.arange(0, BLOCK).to(tl.int64) + mask = cols < BLOCK + + base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h + base_gate_offset = base_input_offset + cols * stride_i_h + base_up_offset = base_input_offset + H * stride_i_h + cols * 
stride_i_h + base_yq_offset = e * stride_yq_e + g * GROUP_SIZE * stride_yq_h + cols * stride_yq_h + base_ys_offset = e * stride_ys_e + g * stride_ys_g + + for t in tl.range(0, n_tokens, num_stages=NUM_STAGES): + gate = tl.load( + input_ptr + base_gate_offset + t * stride_i_t, mask=mask, other=0.0 + ).to(tl.float32) + up = tl.load(input_ptr + base_up_offset + t * stride_i_t, mask=mask, other=0.0) + + gate = gate * (1.0 / (1.0 + tl.exp(-gate))) + y = gate * up + + y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max + if use_ue8m0: + y_s = tl.exp2(tl.ceil(tl.log2(y_s))) + + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask) + tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s) + + +def silu_mul_fp8_quant_deep_gemm_triton( + y: torch.Tensor, # (E, T, 2*H) + tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert + num_parallel_tokens, + group_size: int = 128, + eps: float = 1e-10, +) -> tuple[torch.Tensor, torch.Tensor]: + """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales + + y has shape (E, T, 2*H). The first half of the last dimension is + silu-activated, multiplied by the second half, then quantized into FP8. + + Returns `(y_q, y_s)` where + * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H] + * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T) + """ + assert y.ndim == 3, "y must be (E, T, 2*H)" + E, T, H2 = y.shape + assert H2 % 2 == 0, "last dim of y must be even (2*H)" + H = H2 // 2 + G = (H + group_size - 1) // group_size + assert H % group_size == 0, "H must be divisible by group_size" + assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, ( + "tokens_per_expert must be shape (E,)" + ) + tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32) + + # allocate outputs + fp8_dtype = torch.float8_e4m3fn + y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device) + + # strides (elements) + stride_i_e, stride_i_t, stride_i_h = y.stride() + stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride() + + # desired scale strides (elements): (T*G, 1, T) + stride_ys_e = T * G + stride_ys_t = 1 + stride_ys_g = T + y_s = torch.empty_strided( + (E, T, G), + (stride_ys_e, stride_ys_t, stride_ys_g), + dtype=torch.float32, + device=y.device, ) + stride_cnt_e = tokens_per_expert.stride()[0] + + # Static grid over experts and H-groups. 
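+    # Each kernel instance covers one (expert, group) pair (pid -> e = pid // G, g = pid % G).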
+ # A loop inside the kernel handles the token dim + grid = (E * G,) + + f_info = torch.finfo(fp8_dtype) + fp8_max = f_info.max + fp8_min = f_info.min + + _silu_mul_fp8_quant_deep_gemm[grid]( + y, + y_q, + y_s, + tokens_per_expert, + H, + group_size, + stride_i_e, + stride_i_t, + stride_i_h, + stride_yq_e, + stride_yq_t, + stride_yq_h, + stride_ys_e, + stride_ys_t, + stride_ys_g, + stride_cnt_e, + eps, + fp8_min, + fp8_max, + is_deep_gemm_e8m0_used(), + BLOCK=group_size, + NUM_STAGES=4, + num_warps=1, + ) + + return y_q, y_s + + +# Parse generation strategies +strategies = ["uniform", "max_t", "first_t"] + + +def benchmark( + kernel: Callable, + E: int, + T: int, + H: int, + total_tokens: int, + num_parallel_tokens: int = 64, + G: int = 128, + runs: int = 200, + num_warmups: int = 20, + gen_strategy: str = "default", + iterations_per_run: int = 20, +): + def generate_data(seed_offset=0): + """Generate input data with given seed offset""" + current_platform.seed_everything(42 + seed_offset) + y = torch.rand((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous() + + if gen_strategy == "uniform": + r = torch.rand(size=(E,), device="cuda") + r /= r.sum() + r *= total_tokens + tokens_per_expert = r.int() + tokens_per_expert = torch.minimum( + tokens_per_expert, + torch.ones((E,), device=r.device, dtype=torch.int) * T, + ) + elif gen_strategy == "max_t": + tokens_per_expert = torch.empty(size=(E,), dtype=torch.int32, device="cuda") + tokens_per_expert.fill_(total_tokens / E) + elif gen_strategy == "first_t": + tokens_per_expert = torch.zeros(size=(E,), dtype=torch.int32, device="cuda") + tokens_per_expert[0] = min(T, total_tokens) + else: + raise ValueError(f"Unknown generation strategy: {gen_strategy}") + return y, tokens_per_expert + + dataset_count = 4 + # Pre-generate different input matrices for each iteration to avoid cache effects + data_sets = [generate_data(i) for i in range(dataset_count)] + # Warmup - for _ in range(10): - silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G) - torch.cuda.synchronize() + y, tokens_per_expert = data_sets[0] + for _ in range(num_warmups): + kernel( + y, tokens_per_expert, num_parallel_tokens=num_parallel_tokens, group_size=G + ) + torch.cuda.synchronize() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) # Benchmark - torch.cuda.synchronize() - start = time.perf_counter() + latencies: list[float] = [] for _ in range(runs): - silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G) - torch.cuda.synchronize() + torch.cuda.synchronize() - avg_time = (time.perf_counter() - start) / runs * 1000 + start_event.record() + for i in range(iterations_per_run): + y, tokens_per_expert = data_sets[i % dataset_count] + kernel( + y, + tokens_per_expert, + num_parallel_tokens=num_parallel_tokens, + group_size=G, + ) + end_event.record() + end_event.synchronize() - # Calculate actual work done (only count valid tokens) + total_time_ms = start_event.elapsed_time(end_event) + per_iter_time_ms = total_time_ms / iterations_per_run + latencies.append(per_iter_time_ms) + + # Use median instead of average for better outlier handling + median_time_ms = np.median(latencies) + median_time_s = median_time_ms / 1000 + + # Calculate actual work done (using first dataset for consistency) + _, tokens_per_expert = data_sets[0] actual_tokens = tokens_per_expert.sum().item() actual_elements = actual_tokens * H # GFLOPS: operations per element = exp + 3 muls + 1 div + quantization ops ≈ 8 ops 
ops_per_element = 8 total_ops = actual_elements * ops_per_element - gflops = total_ops / (avg_time / 1000) / 1e9 + gflops = total_ops / median_time_s / 1e9 # Memory bandwidth: bfloat16 inputs (2 bytes), fp8 output (1 byte), scales (4 bytes) input_bytes = actual_tokens * 2 * H * 2 # 2*H bfloat16 inputs output_bytes = actual_tokens * H * 1 # H fp8 outputs scale_bytes = actual_tokens * (H // G) * 4 # scales in float32 total_bytes = input_bytes + output_bytes + scale_bytes - memory_bw = total_bytes / (avg_time / 1000) / 1e9 + memory_bw = total_bytes / median_time_s / 1e9 - return avg_time, gflops, memory_bw + HOPPER_BANDWIDTH_TBPS = 3.35 + return ( + median_time_ms, + gflops, + memory_bw, + (memory_bw / (HOPPER_BANDWIDTH_TBPS * 1024)) * 100, + ) +def create_comparison_plot( + ratio, cuda_times, baseline_times, config_labels, strategy_name, id +): + """Create a comparison plot for a specific generation strategy""" + fig, ax = plt.subplots(1, 1, figsize=(16, 6)) + + # Configure x-axis positions + x = np.arange(len(config_labels)) + width = 0.35 + + # Execution Time plot (lower is better) + ax.bar( + x - width / 2, cuda_times, width, label="CUDA Kernel", alpha=0.8, color="blue" + ) + ax.bar( + x + width / 2, + baseline_times, + width, + label="Baseline", + alpha=0.8, + color="orange", + ) + + # Add speedup labels over each bar pair + for i in range(len(x)): + speedup = ratio[i] + max_height = max(cuda_times[i], baseline_times[i]) + ax.text( + x[i], + max_height + max_height * 0.02, + f"{speedup:.2f}x", + ha="center", + va="bottom", + fontweight="bold", + fontsize=9, + ) + + ax.set_xlabel("Configuration") + ax.set_ylabel("% Utilization") + ax.set_title( + f"Memory Bandwidth Utilization (%) - {strategy_name}\n(Higher is Better)" + ) + ax.set_xticks(x) + ax.set_xticklabels(config_labels, rotation=45, ha="right") + ax.legend() + ax.grid(True, alpha=0.3) + + plt.tight_layout() + return fig, ax + + +def create_combined_plot(all_results): + """Create a combined plot with all strategies in one PNG""" + num_strategies = len(all_results) + fig, axes = plt.subplots(num_strategies, 1, figsize=(20, 6 * num_strategies)) + + if num_strategies == 1: + axes = [axes] + + for idx, ( + strategy_name, + ratio, + cuda_times, + baseline_times, + config_labels, + ) in enumerate(all_results): + ax = axes[idx] + + # Configure x-axis positions + x = np.arange(len(config_labels)) + width = 0.35 + + # Execution Time plot (lower is better) + ax.bar( + x - width / 2, + cuda_times, + width, + label="CUDA Kernel", + alpha=0.8, + color="blue", + ) + ax.bar( + x + width / 2, + baseline_times, + width, + label="Baseline", + alpha=0.8, + color="orange", + ) + + # Add speedup labels over each bar pair + for i in range(len(x)): + speedup = ratio[i] + max_height = max(cuda_times[i], baseline_times[i]) + ax.text( + x[i], + max_height + max_height * 0.02, + f"{speedup:.2f}x", + ha="center", + va="bottom", + fontweight="bold", + fontsize=9, + ) + + ax.set_xlabel("Configuration") + ax.set_ylabel("% Utilization") + ax.set_title( + f"Memory Bandwidth Utilization (%) - {strategy_name}\n(Higher is Better)" + ) + ax.set_xticks(x) + ax.set_xticklabels(config_labels, rotation=45, ha="right") + ax.legend() + ax.grid(True, alpha=0.3) + + plt.tight_layout() + filename = "../../silu_bench/silu_benchmark_combined.png" + plt.savefig(filename, dpi=300, bbox_inches="tight") + plt.show() + + return filename + + +outer_dim = 7168 configs = [ - (8, 32, 1024), - (16, 64, 2048), - (32, 128, 4096), # DeepSeekV3 Configs - (256, 16, 7168), - (256, 32, 7168), - 
(256, 64, 7168), - (256, 128, 7168), - (256, 256, 7168), - (256, 512, 7168), + (8, 1024, 7168), + # DeepSeekV3 Configs + (32, 1024, 7168), + # DeepSeekV3 Configs (256, 1024, 7168), ] -print(f"GPU: {torch.cuda.get_device_name()}") -print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}") -print("-" * 50) +runs = 100 +num_warmups = 20 -for E, T, H in configs: - try: - time_ms, gflops, gbps = benchmark(E, T, H) - print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}") - except Exception: - print(f"E={E:3d},T={T:4d},H={H:4d} FAILED") +strategy_descriptions = { + "uniform": "Uniform Random", + "max_t": "Even Assignment", + "first_t": "experts[0] = T, experts[1:] = 0", +} + +print(f"GPU: {torch.cuda.get_device_name()}") +print(f"Testing strategies: {', '.join(strategies)}") +print(f"Configurations: {len(configs)} configs") + +all_results = [] + +# Run benchmarks for each strategy +for id, strategy in enumerate(strategies): + print(f"\n{'=' * 60}") + print(f"Testing strategy: {strategy_descriptions[strategy]}") + print(f"{'=' * 60}") + + # Collect benchmark data for both algorithms + config_labels = [] + config_x_axis = [] + all_cuda_results = [] + all_baseline_results = [] + all_ratios = [] + + for E, T, H in configs: + total_tokens_config = [8 * E, 16 * E, 32 * E, 64 * E, 128 * E, 256 * E] + config_x_axis.append(total_tokens_config) + + cuda_results = [] + baseline_results = [] + ratios = [] + + for total_tokens in total_tokens_config: + config_label = f"E={E},T={T},H={H},TT={total_tokens}" + config_labels.append(config_label) + + # CUDA kernel results + time_ms_cuda, gflops, gbps, perc = benchmark( + silu_mul_fp8_quant_deep_gemm_cuda, + E, + T, + H, + total_tokens, + runs=runs, + num_warmups=num_warmups, + gen_strategy=strategy, + ) + cuda_results.append((time_ms_cuda, gflops, gbps, perc)) + + # Baseline results + time_ms_triton, gflops, gbps, perc = benchmark( + silu_mul_fp8_quant_deep_gemm_triton, + E, + T, + H, + total_tokens, + runs=runs, + num_warmups=num_warmups, + gen_strategy=strategy, + ) + baseline_results.append((time_ms_triton, gflops, gbps, perc)) + ratios.append(time_ms_triton / time_ms_cuda) + + print(f"Completed: {config_label}") + all_cuda_results.append(cuda_results) + all_baseline_results.append(baseline_results) + all_ratios.append(ratios) + + # Store results for combined plotting + all_results.append( + ( + strategy_descriptions[strategy], + all_ratios, + all_cuda_results, + all_baseline_results, + config_labels, + config_x_axis, + ) + ) + + # Print summary table for this strategy + print(f"\nSummary Table - {strategy_descriptions[strategy]}:") + print(f"{'Config':<20} {'CUDA Time(ms)':<12} {'Base Time(ms)':<12} {'Speedup':<8}") + print("-" * 60) + + for i, (E, T, H) in enumerate(configs): + speedup = baseline_results[i][0] / cuda_results[i][0] + config_label = f"E={E:3d},T={T:4d},H={H:4d}" + print( + f"{config_label:<20} {cuda_results[i][0]:8.5f} " + f"{baseline_results[i][0]:8.5f} {speedup:6.2f}x" + ) + + +def create_total_tokens_plot(all_results): + num_strategies = len(all_results) + num_configs = len(configs) + + # Create side-by-side subplots: 2 columns for speedup and bandwidth percentage + fig, axs = plt.subplots( + num_strategies, num_configs * 2, figsize=(28, 6 * num_strategies) + ) + + # Add main title to the entire figure + fig.suptitle( + "Performance Analysis: Speedup vs Bandwidth Utilization (Triton & CUDA)", + fontsize=16, + fontweight="bold", + y=0.98, + ) + + # Handle single strategy case + if num_strategies == 1: + axs 
= axs.reshape(1, -1) + + # Handle single config case + if num_configs == 1: + axs = axs.reshape(-1, 2) + + for strategy_idx, result in enumerate(all_results): + ( + strategy_name, + all_ratios, + all_cuda_results, + all_baseline_results, + config_labels, + config_x_axis, + ) = result + + for config_idx in range(num_configs): + # Speedup plot (left column) + ax_speedup = axs[strategy_idx, config_idx * 2] + # Bandwidth plot (right column) + ax_bandwidth = axs[strategy_idx, config_idx * 2 + 1] + + E, T, H = configs[config_idx] + ratios = all_ratios[config_idx] + total_tokens_values = config_x_axis[config_idx] + + # Extract CUDA and Triton bandwidth percentages + cuda_bandwidth_percentages = [ + result[3] for result in all_cuda_results[config_idx] + ] + triton_bandwidth_percentages = [ + result[3] for result in all_baseline_results[config_idx] + ] + + # Plot speedup ratios vs total tokens (left plot) + ax_speedup.plot( + total_tokens_values, ratios, "bo-", linewidth=3, markersize=8 + ) + ax_speedup.set_title( + f"{strategy_name}\nSpeedup (CUDA/Triton)\nE={E}, T={T}, H={H}", + fontsize=12, + fontweight="bold", + ) + ax_speedup.set_xlabel("Total Tokens", fontweight="bold", fontsize=11) + ax_speedup.set_ylabel("Speedup Ratio", fontweight="bold", fontsize=11) + ax_speedup.grid(True, alpha=0.3) + + ax_bandwidth.plot( + total_tokens_values, + cuda_bandwidth_percentages, + "ro-", + linewidth=3, + markersize=8, + label="CUDA", + ) + ax_bandwidth.plot( + total_tokens_values, + triton_bandwidth_percentages, + "go-", + linewidth=3, + markersize=8, + label="Triton", + ) + ax_bandwidth.set_title( + f"{strategy_name}\nBandwidth Utilization (Hopper)\nE={E}, T={T}, H={H}", + fontsize=12, + fontweight="bold", + ) + ax_bandwidth.set_xlabel("Total Tokens", fontweight="bold", fontsize=11) + ax_bandwidth.set_ylabel( + "% of Peak Bandwidth", fontweight="bold", fontsize=11 + ) + ax_bandwidth.legend(prop={"weight": "bold"}) + ax_bandwidth.grid(True, alpha=0.3) + + # Format x-axis labels for both plots + for ax in [ax_speedup, ax_bandwidth]: + ax.set_xticks(total_tokens_values) + ax.set_xticklabels( + [ + f"{tt // 1000}K" if tt >= 1000 else str(tt) + for tt in total_tokens_values + ], + fontweight="bold", + ) + # Make tick labels bold + for label in ax.get_xticklabels() + ax.get_yticklabels(): + label.set_fontweight("bold") + + # Add value labels on speedup points + for x, y in zip(total_tokens_values, ratios): + ax_speedup.annotate( + f"{y:.2f}x", + (x, y), + textcoords="offset points", + xytext=(0, 12), + ha="center", + fontsize=10, + fontweight="bold", + bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7), + ) + + # Add value labels on CUDA bandwidth points + for x, y in zip(total_tokens_values, cuda_bandwidth_percentages): + ax_bandwidth.annotate( + f"{y:.1f}%", + (x, y), + textcoords="offset points", + xytext=(0, 12), + ha="center", + fontsize=9, + fontweight="bold", + bbox=dict(boxstyle="round,pad=0.2", facecolor="red", alpha=0.3), + ) + + # Add value labels on Triton bandwidth points + for x, y in zip(total_tokens_values, triton_bandwidth_percentages): + ax_bandwidth.annotate( + f"{y:.1f}%", + (x, y), + textcoords="offset points", + xytext=(0, -15), + ha="center", + fontsize=9, + fontweight="bold", + bbox=dict(boxstyle="round,pad=0.2", facecolor="green", alpha=0.3), + ) + + plt.tight_layout() + plt.subplots_adjust(top=0.93) # Make room for main title + filename = "silu_benchmark_total_tokens.png" + plt.savefig(filename, dpi=300, bbox_inches="tight") + plt.show() + + return filename + + +# Create 
combined plot with all strategies +combined_plot_filename = create_total_tokens_plot(all_results) + +print(f"\n{'=' * 60}") +print("Benchmark Complete!") +print(f"Generated combined plot: {combined_plot_filename}") +print(f"{'=' * 60}") diff --git a/benchmarks/kernels/benchmark_trtllm_decode_attention.py b/benchmarks/kernels/benchmark_trtllm_decode_attention.py index 603ce5ecf0d2c..6ddab46214577 100644 --- a/benchmarks/kernels/benchmark_trtllm_decode_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_decode_attention.py @@ -259,6 +259,7 @@ if __name__ == "__main__": # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) (None, None, None), (None, FP8_DTYPE, None), + (FP8_DTYPE, FP8_DTYPE, None), (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), ] diff --git a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py index 40903c6c3444f..131df74c7de1b 100644 --- a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py @@ -274,6 +274,7 @@ if __name__ == "__main__": quant_dtypes = [ # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) (None, None, None), + (FP8_DTYPE, FP8_DTYPE, None), (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), ] diff --git a/benchmarks/kernels/benchmark_w8a8_block_fp8.py b/benchmarks/kernels/benchmark_w8a8_block_fp8.py index 98bde9d83c82d..df2b713e46dc4 100644 --- a/benchmarks/kernels/benchmark_w8a8_block_fp8.py +++ b/benchmarks/kernels/benchmark_w8a8_block_fp8.py @@ -56,7 +56,7 @@ def w8a8_block_matmul( Bs: The per-block quantization scale for `B`. block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128]. - output_dytpe: The dtype of the returned tensor. + output_dtype: The dtype of the returned tensor. Returns: torch.Tensor: The result of matmul. diff --git a/benchmarks/multi_turn/README.md b/benchmarks/multi_turn/README.md index 7adf97bcf5622..f5b5c6c97d484 100644 --- a/benchmarks/multi_turn/README.md +++ b/benchmarks/multi_turn/README.md @@ -55,6 +55,107 @@ output_num_chunks 166.0 99.01 11.80 79.00 90.00 98.00 108.75 ---------------------------------------------------------------------------------------------------- ``` +### JSON configuration file for synthetic conversations generation + +The input flag `--input-file` is used to determine the input conversations for the benchmark.
+When the input is a JSON file with the field `"filetype": "generate_conversations"`, the tool will generate synthetic multi-turn conversations (questions and answers). + +The file `generate_multi_turn.json` is an example file. + +The file must contain the sections `prompt_input` and `prompt_output`. + +The `prompt_input` section must contain `num_turns`, `prefix_num_tokens`, and `num_tokens`: + +* `num_turns` - Number of total turns in the conversation (both user & assistant).
+The final value will always be rounded to an even number so each user turn has a reply. +* `prefix_num_tokens` - Tokens added at the start of only the **first user turn** in a conversation (unique per conversation). +* `num_tokens` - Total token length of each **user** message (one turn). + +The `prompt_output` section must contain `num_tokens`: + +* `num_tokens` - Total token length of each **assistant** message (one turn). + +### Random distributions for synthetic conversations generation + +When creating an input JSON file (such as `generate_multi_turn.json`),
+every numeric field (such as `num_turns` or `num_tokens`) requires a distribution.
+The distribution determines how to randomly sample values for the field. + +The available distributions are listed below. + +**Note:** The optional `max` field (for lognormal, zipf, and poisson) can be used to cap sampled values at an upper bound.
+Can be used to make sure that the total number of tokens in every request does not exceed `--max-model-len`. + +#### constant + +```json +{ + "distribution": "constant", + "value": 500 +} +``` + +* `value` - the fixed integer value (always returns the same number). + +#### uniform + +```json +{ + "distribution": "uniform", + "min": 12, + "max": 18 +} +``` + +* `min` - minimum value (inclusive). +* `max` - maximum value (inclusive), should be equal or larger than min. + +#### lognormal + +```json +{ + "distribution": "lognormal", + "average": 1000, + "max": 5000 +} +``` + +You can parameterize the lognormal distribution in one of two ways: + +Using the average and optional median ratio: + +* `average` - target average value of the distribution. +* `median_ratio` - the ratio of the median to the average; controls the skewness. Must be in the range (0, 1). + +Using the parameters of the underlying normal distribution: + +* `mean` - mean of the underlying normal distribution. +* `sigma` - standard deviation of the underlying normal distribution. + +#### zipf + +```json +{ + "distribution": "zipf", + "alpha": 1.2, + "max": 100 +} +``` + +* `alpha` - skew parameter (> 1). Larger values produce stronger skew toward smaller integers. + +#### poisson + +```json +{ + "distribution": "poisson", + "alpha": 10, + "max": 50 +} +``` + +* `alpha` - expected value (λ). Also the variance of the distribution. + ## ShareGPT Conversations To run with the ShareGPT data, download the following ShareGPT dataset: diff --git a/benchmarks/multi_turn/bench_dataset.py b/benchmarks/multi_turn/bench_dataset.py index 411b89dd23dc6..67b937930d58c 100644 --- a/benchmarks/multi_turn/bench_dataset.py +++ b/benchmarks/multi_turn/bench_dataset.py @@ -99,21 +99,105 @@ class PoissonDistribution(Distribution): class LognormalDistribution(Distribution): def __init__( - self, mean: float, sigma: float, max_val: Optional[int] = None + self, + mean: Optional[float] = None, + sigma: Optional[float] = None, + average: Optional[int] = None, + median_ratio: Optional[float] = None, + max_val: Optional[int] = None, ) -> None: + self.average = average + self.median_ratio = median_ratio + self.max_val = max_val + + if average is not None: + if average < 1: + raise ValueError("Lognormal average must be positive") + + if mean or sigma: + raise ValueError( + "When using lognormal average, you can't provide mean/sigma" + ) + + if self.median_ratio is None: + # Default value that provides relatively wide range of values + self.median_ratio = 0.85 + + # Calculate mean/sigma of np.random.lognormal based on the average + mean, sigma = self._generate_lognormal_by_median( + target_average=self.average, median_ratio=self.median_ratio + ) + else: + if mean is None or sigma is None: + raise ValueError( + "Must provide both mean and sigma if average is not used" + ) + + if mean <= 0 or sigma < 0: + raise ValueError( + "Lognormal mean must be positive and sigma must be non-negative" + ) + + # Mean and standard deviation of the underlying normal distribution + # Based on numpy.random.lognormal self.mean = mean self.sigma = sigma - self.max_val = max_val + + @staticmethod + def _generate_lognormal_by_median( + target_average: int, median_ratio: float + ) -> tuple[float, float]: + """ + Compute (mu, sigma) for a lognormal distribution given: + - a target average (mean of the distribution) + - a ratio of median / mean (controls skewness), assume mean > median + + Background: + If Z ~ Normal(mu, sigma^2), then X = exp(Z) ~ LogNormal(mu, sigma). 
+ * mean(X) = exp(mu + sigma^2 / 2) + * median(X) = exp(mu) + + So: + median / mean = exp(mu) / exp(mu + sigma^2 / 2) + = exp(-sigma^2 / 2) + + Rearranging: + sigma^2 = 2 * ln(mean / median) + mu = ln(median) + + This gives a unique (mu, sigma) for any valid mean and median. + """ + # Check input validity: median must be smaller than mean + if median_ratio <= 0 or median_ratio >= 1: + raise ValueError("median_ratio must be in range (0, 1)") + + target_median = target_average * median_ratio + + # Solve sigma^2 = 2 * ln(mean / median) + sigma = np.sqrt(2 * np.log(target_average / target_median)) + mu = np.log(target_median) + + return mu, sigma def sample(self, size: int = 1) -> np.ndarray: samples = np.random.lognormal(mean=self.mean, sigma=self.sigma, size=size) + + if self.average is not None: + # Scale to average + samples *= self.average / samples.mean() + if self.max_val: samples = np.minimum(samples, self.max_val) return np.round(samples).astype(int) def __repr__(self) -> str: - return f"LognormalDistribution[{self.mean}, {self.sigma}]" + if self.average: + return ( + f"LognormalDistribution[{self.average}, " + f"{self.median_ratio}, {self.max_val}]" + ) + return f"LognormalDistribution[{self.mean}, {self.sigma}, {self.max_val}]" class GenConvArgs(NamedTuple): @@ -173,10 +257,21 @@ def get_random_distribution( return PoissonDistribution(conf["alpha"], max_val=max_val) elif distribution == "lognormal": + max_val = conf.get("max", None) + + if "average" in conf: + # Infer lognormal mean/sigma (numpy) from input average + median_ratio = conf.get("median_ratio", None) + return LognormalDistribution( + average=conf["average"], median_ratio=median_ratio, max_val=max_val + ) + + # Use mean/sigma directly (for full control over the distribution) verify_field_exists(conf, "mean", section, subsection) verify_field_exists(conf, "sigma", section, subsection) - max_val = conf.get("max", None) - return LognormalDistribution(conf["mean"], conf["sigma"], max_val=max_val) + return LognormalDistribution( + mean=conf["mean"], sigma=conf["sigma"], max_val=max_val + ) elif distribution == "uniform": verify_field_exists(conf, "min", section, subsection) diff --git a/benchmarks/multi_turn/benchmark_serving_multi_turn.py b/benchmarks/multi_turn/benchmark_serving_multi_turn.py index d23b7b6e4571d..66d85eaf51312 100644 --- a/benchmarks/multi_turn/benchmark_serving_multi_turn.py +++ b/benchmarks/multi_turn/benchmark_serving_multi_turn.py @@ -962,7 +962,7 @@ async def main_mp( # At this point all the clients finished, # collect results (TTFT, TPOT, etc.) from all the clients. - # This needs to happens before calling join on the clients + # This needs to happen before calling join on the clients # (result_queue should be emptied). 
while not result_queue.empty(): client_metrics.append(result_queue.get()) diff --git a/benchmarks/multi_turn/generate_multi_turn.json b/benchmarks/multi_turn/generate_multi_turn.json index 274d03c2bdb2b..03cfc7d63e8aa 100644 --- a/benchmarks/multi_turn/generate_multi_turn.json +++ b/benchmarks/multi_turn/generate_multi_turn.json @@ -15,9 +15,8 @@ }, "prefix_num_tokens": { "distribution": "lognormal", - "mean": 6, - "sigma": 4, - "max": 1500 + "average": 1000, + "max": 5000 }, "num_tokens": { "distribution": "uniform", diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 52bfd82c7fcfe..06494463223bd 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -88,6 +88,7 @@ is_avx512_disabled(AVX512_DISABLED) if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") message(STATUS "Apple Silicon Detected") + set(APPLE_SILICON_FOUND TRUE) set(ENABLE_NUMA OFF) check_sysctl(hw.optional.neon ASIMD_FOUND) check_sysctl(hw.optional.arm.FEAT_BF16 ARM_BF16_FOUND) @@ -189,7 +190,7 @@ else() set(USE_ACL OFF) endif() -if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND) +if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND) FetchContent_Declare( oneDNN GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index 49defccbb1fa4..3d32121f13ac2 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 57b4e68b9f9d94750b46de8f8dbd2bfcc86edd4f + GIT_TAG ee4d25bd84e0cbc7e0b9b9685085fd5db2dcb62a GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 9c0ed1d09572e..8558976e2c392 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -480,7 +480,6 @@ function (define_gpu_extension_target GPU_MOD_NAME) ${GPU_LANGUAGE}_ARCHITECTURES "${GPU_ARCHITECTURES}") endif() - set_property(TARGET ${GPU_MOD_NAME} PROPERTY CXX_STANDARD 17) target_compile_options(${GPU_MOD_NAME} PRIVATE $<$:${GPU_COMPILE_FLAGS}>) diff --git a/csrc/attention/mla/cutlass_mla_entry.cu b/csrc/attention/mla/cutlass_mla_entry.cu deleted file mode 100644 index 0319d1daf302f..0000000000000 --- a/csrc/attention/mla/cutlass_mla_entry.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include - -#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA -void cutlass_mla_decode_sm100a(torch::Tensor const& out, - torch::Tensor const& q_nope, - torch::Tensor const& q_pe, - torch::Tensor const& kv_c_and_k_pe_cache, - torch::Tensor const& seq_lens, - torch::Tensor const& page_table, double scale); -#endif - -void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope, - torch::Tensor const& q_pe, - torch::Tensor const& kv_c_and_k_pe_cache, - torch::Tensor const& seq_lens, - torch::Tensor const& page_table, double scale) { -#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA - return cutlass_mla_decode_sm100a(out, q_nope, q_pe, kv_c_and_k_pe_cache, - seq_lens, page_table, scale); -#endif - TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled cutlass MLA"); -} diff --git a/csrc/attention/mla/cutlass_mla_kernels.cu b/csrc/attention/mla/cutlass_mla_kernels.cu deleted file mode 100644 index 9d05d910dd81f..0000000000000 --- a/csrc/attention/mla/cutlass_mla_kernels.cu +++ /dev/null @@ -1,225 +0,0 @@ -/* - * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include - -#include "cute/tensor.hpp" - -#include "cutlass/cutlass.h" -#include "cutlass/kernel_hardware_info.h" - -#include "cutlass_extensions/common.hpp" - -#include "device/sm100_mla.hpp" -#include "kernel/sm100_mla_tile_scheduler.hpp" - -using namespace cute; -using namespace cutlass::fmha::kernel; - -template -struct MlaSm100 { - using Element = T; - using ElementAcc = float; - using ElementOut = T; - - using TileShape = Shape<_128, _128, Shape<_512, _64>>; - using TileShapeH = cute::tuple_element_t<0, TileShape>; - using TileShapeD = cute::tuple_element_t<2, TileShape>; - - // H K (D_latent D_rope) B - using ProblemShape = cute::tuple; - - using StrideQ = cute::tuple; // H D B - using StrideK = cute::tuple; // K D B - using StrideO = StrideK; // H D B - using StrideLSE = cute::tuple<_1, int>; // H B - - using TileScheduler = - std::conditional_t; - - using FmhaKernel = - cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized< - TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler, - /*kIsCpAsync=*/true>; - using Fmha = cutlass::fmha::device::MLA; -}; - -template -typename T::Fmha::Arguments args_from_options( - at::Tensor const& out, at::Tensor const& q_nope, at::Tensor const& q_pe, - at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens, - at::Tensor const& page_table, double scale) { - cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = q_nope.device().index(); - hw_info.sm_count = - cutlass::KernelHardwareInfo::query_device_multiprocessor_count( - hw_info.device_id); - - int batches = q_nope.sizes()[0]; - int page_count_per_seq = page_table.sizes()[1]; - int page_count_total = kv_c_and_k_pe_cache.sizes()[0]; - int page_size = kv_c_and_k_pe_cache.sizes()[1]; - int max_seq_len = page_size * page_count_per_seq; - using TileShapeH = typename T::TileShapeH; - using 
TileShapeD = typename T::TileShapeD; - auto problem_shape = - cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches); - - auto [H, K, D, B] = problem_shape; - auto [D_latent, D_rope] = D; - - using StrideQ = typename T::StrideQ; - using StrideK = typename T::StrideK; - using StrideO = typename T::StrideO; - using StrideLSE = typename T::StrideLSE; - - StrideQ stride_Q_latent = cute::make_tuple( - static_cast(D_latent), _1{}, static_cast(H * D_latent)); - StrideQ stride_Q_rope = cute::make_tuple(static_cast(D_rope), _1{}, - static_cast(H * D_rope)); - StrideK stride_C = - cute::make_tuple(static_cast(D_latent + D_rope), _1{}, - static_cast(page_size * (D_latent + D_rope))); - StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq); - StrideLSE stride_LSE = cute::make_tuple(_1{}, static_cast(H)); - StrideO stride_O = cute::make_tuple(static_cast(D_latent), _1{}, - static_cast(H * D_latent)); - - using Element = typename T::Element; - using ElementOut = typename T::ElementOut; - using ElementAcc = typename T::ElementAcc; - auto Q_latent_ptr = static_cast(q_nope.data_ptr()); - auto Q_rope_ptr = static_cast(q_pe.data_ptr()); - auto C_ptr = static_cast(kv_c_and_k_pe_cache.data_ptr()); - auto scale_f = static_cast(scale); - typename T::Fmha::Arguments arguments{ - problem_shape, - {scale_f, Q_latent_ptr, stride_Q_latent, Q_rope_ptr, stride_Q_rope, C_ptr, - stride_C, C_ptr + D_latent, stride_C, - static_cast(seq_lens.data_ptr()), - static_cast(page_table.data_ptr()), stride_PT, page_count_total, - page_size}, - {static_cast(out.data_ptr()), stride_O, - static_cast(nullptr), stride_LSE}, - hw_info, - 1, // split_kv - nullptr, // is_var_split_kv - }; - // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute - // split_kv automatically based on batch size and sequence length to balance - // workload across available SMs. Consider using var_split_kv for manual - // control if needed. 
- T::Fmha::set_split_kv(arguments); - return arguments; -} - -template -void runMla(at::Tensor const& out, at::Tensor const& q_nope, - at::Tensor const& q_pe, at::Tensor const& kv_c_and_k_pe_cache, - at::Tensor const& seq_lens, at::Tensor const& page_table, - float scale, cudaStream_t stream) { - using MlaSm100Type = MlaSm100; - typename MlaSm100Type::Fmha fmha; - auto arguments = args_from_options( - out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, scale); - size_t workspace_size = MlaSm100Type::Fmha::get_workspace_size(arguments); - auto const workspace_options = - torch::TensorOptions().dtype(torch::kUInt8).device(q_nope.device()); - auto workspace = torch::empty(workspace_size, workspace_options); - - CUTLASS_CHECK(fmha.can_implement(arguments)); - - CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream)); - - CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream)); -} - -void cutlass_mla_decode_sm100a(torch::Tensor const& out, - torch::Tensor const& q_nope, - torch::Tensor const& q_pe, - torch::Tensor const& kv_c_and_k_pe_cache, - torch::Tensor const& seq_lens, - torch::Tensor const& page_table, double scale) { - TORCH_CHECK(q_nope.device().is_cuda(), "q_nope must be on CUDA"); - TORCH_CHECK(q_nope.dim() == 3, "q_nope must be a 3D tensor"); - TORCH_CHECK(q_pe.dim() == 3, "q_pe must be a 3D tensor"); - TORCH_CHECK(kv_c_and_k_pe_cache.dim() == 3, - "kv_c_and_k_pe_cache must be a 3D tensor"); - TORCH_CHECK(seq_lens.dim() == 1, "seq_lens must be a 1D tensor"); - TORCH_CHECK(page_table.dim() == 2, "page_table must be a 2D tensor"); - TORCH_CHECK(out.dim() == 3, "out must be a 3D tensor"); - - auto B_q_nope = q_nope.size(0); - auto H_q_nope = q_nope.size(1); - auto D_q_nope = q_nope.size(2); - auto B_q_pe = q_pe.size(0); - auto H_q_pe = q_pe.size(1); - auto D_q_pe = q_pe.size(2); - auto B_pt = page_table.size(0); - auto PAGE_NUM = page_table.size(1); - auto PAGE_SIZE = kv_c_and_k_pe_cache.size(1); - auto D_ckv = kv_c_and_k_pe_cache.size(2); - auto B_o = out.size(0); - auto H_o = out.size(1); - auto D_o = out.size(2); - - TORCH_CHECK(D_q_nope == 512, "D_q_nope must be equal to 512"); - TORCH_CHECK(D_q_pe == 64, "D_q_pe must be equal to 64"); - TORCH_CHECK(D_ckv == 576, "D_ckv must be equal to 576"); - TORCH_CHECK(H_q_nope == H_q_pe && H_q_nope == H_o && H_o == 128, - "H_q_nope, H_q_pe, and H_o must be equal to 128"); - TORCH_CHECK(PAGE_SIZE > 0 && (PAGE_SIZE & (PAGE_SIZE - 1)) == 0, - "PAGE_SIZE must be a power of 2"); - TORCH_CHECK( - B_q_nope == B_q_pe && B_q_nope == B_pt && B_q_nope == B_o, - "Batch dims must be same for page_table, q_nope and q_pe, and out"); - TORCH_CHECK(PAGE_NUM % (128 / PAGE_SIZE) == 0, - "PAGE_NUM must be divisible by 128 / PAGE_SIZE"); - TORCH_CHECK(D_o == 512, "D_o must be equal to 512"); - - TORCH_CHECK(q_nope.dtype() == at::ScalarType::Half || - q_nope.dtype() == at::ScalarType::BFloat16 || - q_nope.dtype() == at::ScalarType::Float8_e4m3fn, - "q_nope must be a half, bfloat16, or float8_e4m3fn tensor"); - TORCH_CHECK(kv_c_and_k_pe_cache.dtype() == q_nope.dtype() && - q_nope.dtype() == q_pe.dtype(), - "kv_c_and_k_pe_cache, q_nope, and q_pe must be the same type"); - TORCH_CHECK(seq_lens.dtype() == torch::kInt32, - "seq_lens must be a 32-bit integer tensor"); - TORCH_CHECK(page_table.dtype() == torch::kInt32, - "page_table must be a 32-bit integer tensor"); - - auto in_dtype = q_nope.dtype(); - const at::cuda::OptionalCUDAGuard device_guard(device_of(q_nope)); - const cudaStream_t stream = - 
at::cuda::getCurrentCUDAStream(q_nope.get_device()); - if (in_dtype == at::ScalarType::Half) { - runMla(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, - page_table, scale, stream); - } else if (in_dtype == at::ScalarType::BFloat16) { - runMla(out, q_nope, q_pe, kv_c_and_k_pe_cache, - seq_lens, page_table, scale, stream); - } else if (in_dtype == at::ScalarType::Float8_e4m3fn) { - runMla(out, q_nope, q_pe, kv_c_and_k_pe_cache, - seq_lens, page_table, scale, stream); - } else { - TORCH_CHECK(false, "Unsupported input data type of MLA"); - } -} diff --git a/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp b/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp index 95e32559cd540..fbbc2e588c326 100644 --- a/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp +++ b/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp @@ -133,6 +133,14 @@ public: // printf(" sm_count = %d\n", sm_count); int max_splits = ceil_div(K, 128); max_splits = min(16, max_splits); + + // TODO: This avoids a hang when the batch size larger than 1 and + // there is more than 4 kv_splits. + // Discuss with NVIDIA how this can be fixed. + if (B > 1) { + max_splits = min(2, max_splits); + } + // printf(" max_splits = %d\n", max_splits); int sms_per_batch = max(1, sm_count / B); // printf(" sms_per_batch = %d\n", sms_per_batch); diff --git a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu index 6dd6f269f3dc9..d1874515cc8fd 100644 --- a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu +++ b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu @@ -36,12 +36,14 @@ limitations under the License. #if !defined(CUDA_VERSION) || CUDA_VERSION < 12040 void sm100_cutlass_mla_decode( torch::Tensor const& out, + torch::Tensor const& lse, torch::Tensor const& q_nope, torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache, torch::Tensor const& seq_lens, torch::Tensor const& page_table, torch::Tensor const& workspace, + double sm_scale, int64_t num_kv_splits) { TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode"); } @@ -64,11 +66,11 @@ struct IsPersistent { static const bool value = v; }; -template > +template > struct MlaSm100 { using Element = T; using ElementAcc = float; - using ElementOut = T; + using ElementOut = TOut; using TileShape = Shape<_128, _128, Shape<_512, _64>>; using TileShapeH = cute::tuple_element_t<0, TileShape>; @@ -99,6 +101,7 @@ struct MlaSm100 { template typename T::Fmha::Arguments args_from_options( at::Tensor const& out, + at::Tensor const& lse, at::Tensor const& q_nope, at::Tensor const& q_pe, at::Tensor const& kv_c_and_k_pe_cache, @@ -162,7 +165,10 @@ typename T::Fmha::Arguments args_from_options( stride_PT, page_count_total, page_size}, - {static_cast(out.data_ptr()), stride_O, static_cast(nullptr), stride_LSE}, + {static_cast(out.data_ptr()), + stride_O, + static_cast(lse.defined() ? lse.data_ptr() : nullptr), + stride_LSE}, hw_info, // TODO(trevor-m): Change split_kv back to -1 when // https://github.com/NVIDIA/cutlass/issues/2274 is fixed. 
Split_kv=1 will @@ -178,9 +184,10 @@ typename T::Fmha::Arguments args_from_options( return arguments; } -template +template void runMla( at::Tensor const& out, + at::Tensor const& lse, at::Tensor const& q_nope, at::Tensor const& q_pe, at::Tensor const& kv_c_and_k_pe_cache, @@ -190,9 +197,9 @@ void runMla( double sm_scale, int64_t num_kv_splits, cudaStream_t stream) { - using MlaSm100Type = MlaSm100; + using MlaSm100Type = MlaSm100; typename MlaSm100Type::Fmha fmha; - auto arguments = args_from_options(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits); + auto arguments = args_from_options(out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits); CUTLASS_CHECK(fmha.can_implement(arguments)); @@ -214,6 +221,7 @@ void runMla( void sm100_cutlass_mla_decode( torch::Tensor const& out, + torch::Tensor const& lse, torch::Tensor const& q_nope, torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache, @@ -233,14 +241,14 @@ void sm100_cutlass_mla_decode( DISPATCH_BOOL(page_size == 128, IsPaged128, [&] { DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] { if (in_dtype == at::ScalarType::Half) { - runMla>( - out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); + runMla>( + out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); } else if (in_dtype == at::ScalarType::BFloat16) { - runMla>( - out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); + runMla>( + out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); } else if (in_dtype == at::ScalarType::Float8_e4m3fn) { - runMla>( - out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); + runMla>( + out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); } else { TORCH_CHECK(false, "Unsupported input data type of MLA"); } @@ -253,7 +261,7 @@ void sm100_cutlass_mla_decode( int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) { // Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc) // which are float, so Element type here doesn't matter. - using MlaSm100Type = MlaSm100; + using MlaSm100Type = MlaSm100; // Get split kv. Requires problem shape and sm_count only. 
typename MlaSm100Type::Fmha::Arguments arguments; diff --git a/csrc/cache.h b/csrc/cache.h index e8e069aefd9c5..fd230bec27fca 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -36,13 +36,6 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe, const std::string& kv_cache_dtype, torch::Tensor& scale); -void cp_fused_concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe, - torch::Tensor& cp_local_token_select_indices, - torch::Tensor& kv_cache, - torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, - torch::Tensor& scale); - // Just for unittest void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, const double scale, const std::string& kv_cache_dtype); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index fbb022464ef27..80b4c47c55476 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -396,51 +396,6 @@ __global__ void concat_and_cache_mla_kernel( copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank); } -template -__global__ void cp_fused_concat_and_cache_mla_kernel( - const scalar_t* __restrict__ kv_c, // [num_full_tokens, kv_lora_rank] - const scalar_t* __restrict__ k_pe, // [num_full_tokens, pe_dim] - const int64_t* __restrict__ cp_local_token_select_indices, // [num_tokens] - cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank - // + pe_dim)] - const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int block_stride, // - const int entry_stride, // - const int kv_c_stride, // - const int k_pe_stride, // - const int kv_lora_rank, // - const int pe_dim, // - const int block_size, // - const float* scale // -) { - const int64_t token_idx = cp_local_token_select_indices[blockIdx.x]; - const int64_t slot_idx = slot_mapping[blockIdx.x]; - // NOTE: slot_idx can be -1 if the token is padded - if (slot_idx < 0) { - return; - } - const int64_t block_idx = slot_idx / block_size; - const int64_t block_offset = slot_idx % block_size; - - auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst, - int src_stride, int dst_stride, int size, int offset) { - for (int i = threadIdx.x; i < size; i += blockDim.x) { - const int64_t src_idx = token_idx * src_stride + i; - const int64_t dst_idx = - block_idx * block_stride + block_offset * entry_stride + i + offset; - if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { - dst[dst_idx] = src[src_idx]; - } else { - dst[dst_idx] = - fp8::scaled_convert(src[src_idx], *scale); - } - } - }; - - copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0); - copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank); -} - } // namespace vllm // KV_T is the data type of key and value tensors. @@ -554,20 +509,6 @@ void reshape_and_cache_flash( kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \ reinterpret_cast(scale.data_ptr())); -// KV_T is the data type of key and value tensors. -// CACHE_T is the stored data type of kv-cache. -// KV_DTYPE is the real data type of kv-cache. 
-#define CALL_CP_FUSED_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \ - vllm::cp_fused_concat_and_cache_mla_kernel \ - <<>>( \ - reinterpret_cast(kv_c.data_ptr()), \ - reinterpret_cast(k_pe.data_ptr()), \ - cp_local_token_select_indices.data_ptr(), \ - reinterpret_cast(kv_cache.data_ptr()), \ - slot_mapping.data_ptr(), block_stride, entry_stride, \ - kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \ - reinterpret_cast(scale.data_ptr())); - void concat_and_cache_mla( torch::Tensor& kv_c, // [num_tokens, kv_lora_rank] torch::Tensor& k_pe, // [num_tokens, pe_dim] @@ -606,50 +547,6 @@ void concat_and_cache_mla( CALL_CONCAT_AND_CACHE_MLA); } -// Note(hc): cp_fused_concat_and_cache_mla fuses the following three kernel -// calls into one: -// k_c_normed.index_select(0, cp_local_token_select_indices) + \ -// k_pe.squeeze(1).index_select(0, cp_local_token_select_indices) + \ -// concat_and_cache_mla. -void cp_fused_concat_and_cache_mla( - torch::Tensor& kv_c, // [num_total_tokens, kv_lora_rank] - torch::Tensor& k_pe, // [num_total_tokens, pe_dim] - torch::Tensor& cp_local_token_select_indices, // [num_tokens] - torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank + - // pe_dim)] - torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] - const std::string& kv_cache_dtype, torch::Tensor& scale) { - // NOTE(woosuk): In vLLM V1, key.size(0) can be different from - // slot_mapping.size(0) because of padding for CUDA graphs. - // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because - // both include padding. - // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0) - // since key includes padding for CUDA graphs, while slot_mapping does not. - // In this case, slot_mapping.size(0) represents the actual number of tokens - // before padding. - // For compatibility with both cases, we use slot_mapping.size(0) as the - // number of tokens. 
- int num_tokens = slot_mapping.size(0); - int kv_lora_rank = kv_c.size(1); - int pe_dim = k_pe.size(1); - int block_size = kv_cache.size(1); - - TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim); - - int kv_c_stride = kv_c.stride(0); - int k_pe_stride = k_pe.stride(0); - int block_stride = kv_cache.stride(0); - int entry_stride = kv_cache.stride(1); - - dim3 grid(num_tokens); - dim3 block(std::min(kv_lora_rank, 512)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, - CALL_CP_FUSED_CONCAT_AND_CACHE_MLA); -} - namespace vllm { template diff --git a/csrc/cpu/cpu_types.hpp b/csrc/cpu/cpu_types.hpp index 17bbe04eef94a..c3a21796881c9 100644 --- a/csrc/cpu/cpu_types.hpp +++ b/csrc/cpu/cpu_types.hpp @@ -17,4 +17,8 @@ #warning "unsupported vLLM cpu implementation" #endif +#ifdef _OPENMP + #include +#endif + #endif \ No newline at end of file diff --git a/csrc/cpu/cpu_types_vxe.hpp b/csrc/cpu/cpu_types_vxe.hpp index ab8cbbbf4ec4f..51bca37e699b9 100644 --- a/csrc/cpu/cpu_types_vxe.hpp +++ b/csrc/cpu/cpu_types_vxe.hpp @@ -12,7 +12,7 @@ namespace vec_op { #define vec_sub(a, b) ((a) - (b)) #define vec_mul(a, b) ((a) * (b)) #define vec_div(a, b) ((a) / (b)) -#define vec_sr(a, b) ((a) >> (b)) // Vector Shift Right Algebaic +#define vec_sr(a, b) ((a) >> (b)) // Vector Shift Right Algebraic #define vec_sl(a, b) ((a) << (b)) // Vector Shift Left // FIXME: FP16 is not fully supported in Torch-CPU diff --git a/csrc/cpu/dnnl_helper.cpp b/csrc/cpu/dnnl_helper.cpp index f3f00edb36068..6def0e061fa96 100644 --- a/csrc/cpu/dnnl_helper.cpp +++ b/csrc/cpu/dnnl_helper.cpp @@ -22,6 +22,23 @@ void release_dnnl_matmul_handler(int64_t handler) { delete ptr; } +DNNLScratchPadManager::DNNLScratchPadManager() : size_(0), ptr_(nullptr) { + this->realloc(allocation_unit * 128); +} + +void DNNLScratchPadManager::realloc(size_t new_size) { + new_size = round(new_size); + if (new_size > size_) { + ptr_ = std::aligned_alloc(64, new_size); + size_ = new_size; + } +} + +DNNLScratchPadManager* DNNLScratchPadManager::get_dnnl_scratchpad_manager() { + static DNNLScratchPadManager manager; + return &manager; +} + template class DNNLPrimitiveCache { public: @@ -166,6 +183,23 @@ struct hash { hash()(static_cast(val.bias_type)); } }; + +template <> +struct hash { + size_t operator()( + const MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const { + return hash()(val.b_n_size) ^ hash()(val.b_k_size); + } +}; + +template <> +struct hash { + size_t operator()(const MatMulPrimitiveHandler::MSizeCacheKey& val) const { + return hash()(val.a_m_size) ^ + hash()(val.a_m_stride) ^ hash()(val.use_bias) ^ + hash()(static_cast(val.bias_type)); + } +}; } // namespace std bool operator==(const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& l, @@ -181,6 +215,17 @@ bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l, l.bias_type == r.bias_type; } +bool operator==(const MatMulPrimitiveHandler::ClassMatmulCacheKey& l, + const MatMulPrimitiveHandler::ClassMatmulCacheKey& r) { + return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size; +} + +bool operator==(const MatMulPrimitiveHandler::MSizeCacheKey& l, + const MatMulPrimitiveHandler::MSizeCacheKey& r) { + return l.a_m_size == r.a_m_size && l.a_m_stride == r.a_m_stride && + l.use_bias == r.use_bias && l.bias_type == r.bias_type; +} + static std::shared_ptr get_w8a8_class_primitive_cache( const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& 
key, @@ -239,6 +284,11 @@ void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) { } dnnl::matmul matmul = get_matmul_cache(args); + + auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(5); + scratchpad_storage->set_data_handle( + DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data()); + matmul.execute(default_stream(), memory_cache_); default_stream().wait(); } @@ -257,6 +307,8 @@ dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache( return m_size_cache_->get_or_create(key, [&]() { dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false); + auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager(); + manager->realloc(desc.scratchpad_desc().get_size()); return dnnl::matmul(desc); }); } @@ -300,6 +352,11 @@ void W8A8MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) { dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}}, default_engine(), nullptr); set_runtime_memory_ptr(4, memory_cache_[DNNL_ARG_BIAS].get()); + + memory_cache_[DNNL_ARG_SCRATCHPAD] = + dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}}, + default_engine(), nullptr); + set_runtime_memory_ptr(5, memory_cache_[DNNL_ARG_SCRATCHPAD].get()); } dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc( @@ -319,6 +376,9 @@ dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc( dnnl::memory::format_tag::ab); dnnl::primitive_attr attr; + + attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + // For PER_TOKEN, scales will be applied in outside epilogue if (a_qs_ == QuantizationStrategy::PER_TENSOR) { attr.set_scales_mask(DNNL_ARG_SRC, 0); @@ -344,3 +404,120 @@ dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc( attr); } } + +MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args) + : DNNLMatMulPrimitiveHandler( + static_cast(args), args.ab_type), + m_size_cache_(nullptr) { + assert(ab_type_ == dnnl::memory::data_type::f32 || + ab_type_ == dnnl::memory::data_type::bf16 || + ab_type_ == dnnl::memory::data_type::f16); + prepack_weight(args.b_ptr, + create_primitive_desc( + MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL, + .a_m_stride = DNNL_RUNTIME_DIM_VAL, + .use_bias = false, + .bias_type = dnnl::memory::data_type::undef}, + true) + .weights_desc()); + init_runtime_memory_cache(args); +} + +static std::shared_ptr +get_matul_class_primitive_cache( + const MatMulPrimitiveHandler::ClassMatmulCacheKey& key, + int64_t cache_size) { + static MatMulPrimitiveHandler::ClassMatmulCache cache(128); + assert(cache_size > 0); + return cache.get_or_create(key, [&]() { + return std::make_shared(cache_size); + }); +} + +void MatMulPrimitiveHandler::execute(ExecArgs& args) { + auto&& [a_storage, a_mem_desc] = get_runtime_memory_ptr(0); + auto&& [c_storage, c_mem_desc] = get_runtime_memory_ptr(1); + a_storage->set_data_handle((void*)args.a_ptr); + a_mem_desc->dims[0] = args.a_m_size; + a_mem_desc->format_desc.blocking.strides[0] = args.a_m_stride; + c_storage->set_data_handle((void*)args.c_ptr); + c_mem_desc->dims[0] = args.a_m_size; + + if (args.use_bias) { + auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(2); + bias_storage->set_data_handle((void*)args.bias_ptr); + } + + dnnl::matmul matmul = get_matmul_cache(args); + + auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(3); + scratchpad_storage->set_data_handle( + DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data()); + + matmul.execute(default_stream(), 
memory_cache_); + default_stream().wait(); +} + +dnnl::matmul MatMulPrimitiveHandler::get_matmul_cache( + const MSizeCacheKey& key) { + if (m_size_cache_.get() == nullptr) { + ClassMatmulCacheKey key = {.b_n_size = b_n_size_, .b_k_size = b_k_size_}; + m_size_cache_ = get_matul_class_primitive_cache(key, primitive_cache_size_); + } + return m_size_cache_->get_or_create(key, [&]() { + dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false); + auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager(); + manager->realloc(desc.scratchpad_desc().get_size()); + return dnnl::matmul(desc); + }); +} + +dnnl::matmul::primitive_desc MatMulPrimitiveHandler::create_primitive_desc( + const MSizeCacheKey& key, bool first_time) { + dnnl::memory::desc a_md; + dnnl::memory::desc b_md; + if (first_time) { + a_md = dnnl::memory::desc({key.a_m_size, b_k_size_}, b_type_, + dnnl::memory::format_tag::ab); + b_md = dnnl::memory::desc({b_k_size_, b_n_size_}, b_type_, + dnnl::memory::format_tag::any); + } else { + a_md = dnnl::memory::desc({key.a_m_size, b_k_size_}, b_type_, + {key.a_m_stride, 1}); + b_md = b_target_mem_desc_; + } + dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_, + dnnl::memory::format_tag::ab); + + dnnl::primitive_attr attr; + attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + + if (key.use_bias) { + dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1}); + return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md, + c_md, attr); + } else { + return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md, + attr); + } +} + +void MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) { + memory_cache_[DNNL_ARG_SRC] = dnnl::memory( + {{1, b_k_size_}, b_type_, {b_k_size_, 1}}, default_engine(), nullptr); + set_runtime_memory_ptr(0, memory_cache_[DNNL_ARG_SRC].get()); + memory_cache_[DNNL_ARG_DST] = + dnnl::memory({{1, b_n_size_}, c_type_, dnnl::memory::format_tag::ab}, + default_engine(), nullptr); + set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get()); + + memory_cache_[DNNL_ARG_BIAS] = + dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}}, + default_engine(), nullptr); + set_runtime_memory_ptr(2, memory_cache_[DNNL_ARG_BIAS].get()); + + memory_cache_[DNNL_ARG_SCRATCHPAD] = + dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}}, + default_engine(), nullptr); + set_runtime_memory_ptr(3, memory_cache_[DNNL_ARG_SCRATCHPAD].get()); +} diff --git a/csrc/cpu/dnnl_helper.h b/csrc/cpu/dnnl_helper.h index 54ceefced9e98..ad6773d2b9fd6 100644 --- a/csrc/cpu/dnnl_helper.h +++ b/csrc/cpu/dnnl_helper.h @@ -59,6 +59,30 @@ constexpr inline dnnl::memory::data_type get_dnnl_type() { return DNNLType>::type; } +class DNNLScratchPadManager { + public: + static constexpr size_t allocation_unit = 4 * 1024 * 1024; // 4KB + + static DNNLScratchPadManager* get_dnnl_scratchpad_manager(); + + DNNLScratchPadManager(); + + template + T* get_data() { + return reinterpret_cast(ptr_); + } + + static size_t round(size_t size) { + return ((size + allocation_unit - 1) / allocation_unit) * allocation_unit; + } + + void realloc(size_t new_size); + + private: + size_t size_; + void* ptr_; +}; + class DNNLMatMulPrimitiveHandler { public: virtual ~DNNLMatMulPrimitiveHandler() = default; @@ -166,4 +190,54 @@ class W8A8MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler { std::shared_ptr m_size_cache_; }; +class MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler { + public: + struct Args : public 
DNNLMatMulPrimitiveHandler::Args { + dnnl::memory::data_type ab_type; + }; + + struct ClassMatmulCacheKey { + dnnl_dim_t b_n_size; + dnnl_dim_t b_k_size; + + friend bool operator==(const ClassMatmulCacheKey& l, + const ClassMatmulCacheKey& r); + }; + + struct MSizeCacheKey { + dnnl_dim_t a_m_size; + dnnl_dim_t a_m_stride; + bool use_bias; + dnnl::memory::data_type bias_type; + + friend bool operator==(const MSizeCacheKey& l, const MSizeCacheKey& r); + }; + + using MSizeCache = DNNLPrimitiveCache; + using ClassMatmulCache = + DNNLPrimitiveCache>; + + struct ExecArgs : public MSizeCacheKey { + const void* a_ptr; + const void* bias_ptr; + void* c_ptr; + }; + + public: + MatMulPrimitiveHandler(const Args& args); + + void execute(ExecArgs& args); + + private: + dnnl::matmul::primitive_desc create_primitive_desc(const MSizeCacheKey& key, + bool first_time); + + void init_runtime_memory_cache(const Args& args); + + dnnl::matmul get_matmul_cache(const MSizeCacheKey& key); + + private: + std::shared_ptr m_size_cache_; +}; + #endif diff --git a/csrc/cpu/dnnl_kernels.cpp b/csrc/cpu/dnnl_kernels.cpp index acc3b9ecde143..1c42a75bc2d61 100644 --- a/csrc/cpu/dnnl_kernels.cpp +++ b/csrc/cpu/dnnl_kernels.cpp @@ -145,7 +145,8 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, } } - float scale_val, azp_val; + float scale_val; + float azp_val = 0.0f; if constexpr (AZP) { float max_scalar = max_value.reduce_max(); float min_scalar = min_value.reduce_min(); @@ -379,6 +380,7 @@ void onednn_scaled_mm( exec_args.a_ptr = a.data_ptr(); exec_args.a_m_size = a.size(0); exec_args.bias_ptr = nullptr; + exec_args.bias_type = get_dnnl_type(); exec_args.use_bias = false; exec_args.a_scales_ptr = nullptr; exec_args.a_zero_points_ptr = nullptr; @@ -492,3 +494,56 @@ void dynamic_scaled_int8_quant( } }); } + +int64_t create_onednn_mm_handler(const torch::Tensor& b, + int64_t primitive_cache_size) { + TORCH_CHECK(b.dim() == 2); + + MatMulPrimitiveHandler::Args args; + args.primitive_cache_size = primitive_cache_size; + + args.b_k_size = b.size(0); + args.b_k_stride = b.stride(0); + args.b_n_size = b.size(1); + args.b_n_stride = b.stride(1); + args.b_ptr = b.data_ptr(); + + VLLM_DISPATCH_FLOATING_TYPES(b.scalar_type(), "create_onednn_mm_handler", + [&] { + args.c_type = get_dnnl_type(); + args.ab_type = get_dnnl_type(); + }); + + return reinterpret_cast(new MatMulPrimitiveHandler(args)); +} + +void onednn_mm(torch::Tensor& c, // [M, OC], row-major + const torch::Tensor& a, // [M, IC], row-major + const std::optional& bias, int64_t handler) { + CPU_KERNEL_GUARD_IN(onednn_mm) + TORCH_CHECK(a.dim() == 2); + TORCH_CHECK(a.stride(-1) == 1); + TORCH_CHECK(c.stride(-1) == 1); + MatMulPrimitiveHandler* ptr = + reinterpret_cast(handler); + + MatMulPrimitiveHandler::ExecArgs exec_args; + exec_args.a_m_size = a.size(0); + exec_args.a_m_stride = a.stride(0); + + VLLM_DISPATCH_FLOATING_TYPES(a.scalar_type(), "onednn_mm", [&] { + if (bias.has_value()) { + exec_args.use_bias = true; + exec_args.bias_type = get_dnnl_type(); + exec_args.bias_ptr = bias->data_ptr(); + } else { + exec_args.use_bias = false; + exec_args.bias_type = get_dnnl_type(); + exec_args.bias_ptr = nullptr; + } + exec_args.a_ptr = a.data_ptr(); + exec_args.c_ptr = c.data_ptr(); + + ptr->execute(exec_args); + }); +} diff --git a/csrc/cpu/sgl-kernels/moe.cpp b/csrc/cpu/sgl-kernels/moe.cpp index beeccff783ea0..94b24c2f13a06 100644 --- a/csrc/cpu/sgl-kernels/moe.cpp +++ b/csrc/cpu/sgl-kernels/moe.cpp @@ -215,7 +215,7 @@ int moe_align_block_size( 
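For context on how the two new CPU ops added above are meant to fit together — `create_onednn_mm_handler` prepacks the weight once and returns an opaque handle, `onednn_mm` then reuses the cached primitives and the shared scratchpad on every call, and `release_dnnl_matmul_handler` frees the handle — here is a rough C++ sketch. The tensor names, sizes, and the `example_linear_forward` wrapper are illustrative only and are not part of the patch.

```cpp
// Minimal sketch of the intended lifecycle of the new oneDNN GEMM handler.
// Assumes linking against the vLLM CPU extension that defines these symbols.
#include <torch/torch.h>
#include <cstdint>
#include <optional>

// Declarations matching the ops registered in csrc/cpu/torch_bindings.cpp.
int64_t create_onednn_mm_handler(const torch::Tensor& b, int64_t primitive_cache_size);
void onednn_mm(torch::Tensor& c, const torch::Tensor& a,
               const std::optional<torch::Tensor>& bias, int64_t handler);
void release_dnnl_matmul_handler(int64_t handler);

void example_linear_forward() {
  const int64_t M = 8, K = 4096, N = 11008;  // illustrative shapes
  auto opts = torch::TensorOptions().dtype(torch::kBFloat16).device(torch::kCPU);

  // Weight is [IC, OC] and is prepacked once when the handler is created.
  torch::Tensor weight = torch::randn({K, N}, opts);
  int64_t handler = create_onednn_mm_handler(weight, /*primitive_cache_size=*/128);

  torch::Tensor a = torch::randn({M, K}, opts);  // activations, row-major
  torch::Tensor c = torch::empty({M, N}, opts);  // output buffer
  onednn_mm(c, a, /*bias=*/std::nullopt, handler);  // c = a @ weight

  release_dnnl_matmul_handler(handler);  // frees the prepacked weight and caches
}
```

The primitive cache is keyed on the runtime M size, stride, and bias, so repeated calls with the same activation shape reuse the same oneDNN primitive and the grow-only scratchpad.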
offsets[mb + 1] = sorted_id_size(sorted_ids + mb * BLOCK_M); } }); - // TODO: do we need to vecterize this ? + // TODO: do we need to vectorize this ? for (int mb = 0; mb < num_token_blocks; ++mb) { offsets[mb + 1] += offsets[mb]; } diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index c9f426bdf618a..98c3ebc5a75f8 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -21,6 +21,12 @@ void onednn_scaled_mm(torch::Tensor& c, const torch::Tensor& a, const std::optional& bias, int64_t handler); +int64_t create_onednn_mm_handler(const torch::Tensor& b, + int64_t primitive_cache_size); + +void onednn_mm(torch::Tensor& c, const torch::Tensor& a, + const std::optional& bias, int64_t handler); + void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query, torch::Tensor& kv_cache, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens); @@ -153,6 +159,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("release_dnnl_matmul_handler(int handler) -> ()", &release_dnnl_matmul_handler); + // Create oneDNN GEMM handler + ops.def( + "create_onednn_mm_handler(Tensor b, int " + "primitive_cache_size) -> int", + &create_onednn_mm_handler); + + // oneDNN GEMM + ops.def( + "onednn_mm(Tensor! c, Tensor a, Tensor? bias, " + "int handler) -> ()"); + ops.impl("onednn_mm", torch::kCPU, &onednn_mm); + // Create oneDNN W8A8 handler ops.def( "create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType " diff --git a/csrc/cub_helpers.h b/csrc/cub_helpers.h new file mode 100644 index 0000000000000..470a63a22cab0 --- /dev/null +++ b/csrc/cub_helpers.h @@ -0,0 +1,17 @@ +#pragma once + +#ifndef USE_ROCM + #include + #if CUB_VERSION >= 200800 + #include +using CubAddOp = cuda::std::plus<>; +using CubMaxOp = cuda::maximum<>; + #else // if CUB_VERSION < 200800 +using CubAddOp = cub::Sum; +using CubMaxOp = cub::Max; + #endif // CUB_VERSION +#else + #include +using CubAddOp = cub::Sum; +using CubMaxOp = cub::Max; +#endif // USE_ROCM diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 44709b4597765..58926f6429dd3 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -15,6 +15,8 @@ typedef __hip_bfloat16 nv_bfloat16; #include #include #include +#include +#include namespace vllm { #define CUDACHECK(cmd) \ @@ -555,22 +557,47 @@ class CustomAllreduce { size /= d; auto bytes = size * sizeof(typename packed_t::P); int blocks = std::min(block_limit, (size + threads - 1) / threads); + + // Check environment variable once + const char* env_algo = std::getenv("VLLM_CUSTOM_ALLREDUCE_ALGO"); + bool force_1stage = false; + bool force_2stage = false; + if (env_algo != nullptr) { + if (std::strcmp(env_algo, "1stage") == 0 || + std::strcmp(env_algo, "oneshot") == 0) { + force_1stage = true; + } else if (std::strcmp(env_algo, "2stage") == 0 || + std::strcmp(env_algo, "twoshot") == 0) { + force_2stage = true; + } else { + throw std::runtime_error( + "Invalid VLLM_CUSTOM_ALLREDUCE_ALGO: " + std::string(env_algo) + + ". 
Valid values: 1stage, oneshot, 2stage, twoshot"); + } + } + #define KL(ngpus, name) \ name<<>>(ptrs, sg_, self_sg_, output, \ rank_, size); -#define REDUCE_CASE(ngpus) \ - case ngpus: { \ - if (world_size_ == 2) { \ - KL(ngpus, cross_device_reduce_1stage); \ - } else if (fully_connected_) { \ - if ((world_size_ <= 4 && bytes < 512 * 1024) || \ - (world_size_ <= 8 && bytes < 256 * 1024)) { \ - KL(ngpus, cross_device_reduce_1stage); \ - } else { \ - KL(ngpus, cross_device_reduce_2stage); \ - } \ - } \ - break; \ +#define REDUCE_CASE(ngpus) \ + case ngpus: { \ + if (force_1stage) { \ + KL(ngpus, cross_device_reduce_1stage); \ + } else if (force_2stage) { \ + KL(ngpus, cross_device_reduce_2stage); \ + } else { \ + if (world_size_ == 2) { \ + KL(ngpus, cross_device_reduce_1stage); \ + } else if (fully_connected_) { \ + if ((world_size_ <= 4 && bytes < 512 * 1024) || \ + (world_size_ <= 8 && bytes < 256 * 1024)) { \ + KL(ngpus, cross_device_reduce_1stage); \ + } else { \ + KL(ngpus, cross_device_reduce_2stage); \ + } \ + } \ + } \ + break; \ } switch (world_size_) { diff --git a/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp b/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp deleted file mode 100644 index ec75c29e54f4d..0000000000000 --- a/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp +++ /dev/null @@ -1,123 +0,0 @@ -// Modified from: cutlass/gemm/collective/builders/sm90_gmma_builder.inl -// clang-format off -#pragma once - -#include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl" - -#include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp" - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::collective { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA_TMA_WS_SS (BlockScaled Builders) -template < - class ElementA, - class GmemLayoutATag, - int AlignmentA, - class ElementB, - class GmemLayoutBTag, - int AlignmentB, - class ElementAccumulator, - class TileShape_MNK, - class ClusterShape_MNK, - class StageCountType, - int ScaleGranularityM -> -struct CollectiveBuilder< - arch::Sm90, - arch::OpClassTensorOp, - ElementA, - GmemLayoutATag, - AlignmentA, - ElementB, - GmemLayoutBTag, - AlignmentB, - ElementAccumulator, - TileShape_MNK, - ClusterShape_MNK, - StageCountType, - KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum, - cute::enable_if_t< - not detail::is_use_rmem_A()> -> { - using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum; - - static_assert(is_static::value); - static_assert(is_static::value); -#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED - static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); -#endif - static_assert(detail::is_aligned(), - "Should meet TMA alignment requirement\n"); - - static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v); - static constexpr bool IsFP8Input = detail::is_input_fp8(); - static_assert((!IsFP8Input || !IsArrayOfPointersGemm), - "KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now."); - - // For fp32 types, map to tf32 MMA value type - using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; - using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>; - - static constexpr cute::GMMA::Major GmmaMajorA = 
detail::gmma_ss_tag_to_major_A(); - static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); - - static constexpr bool IsCooperative = cute::is_any_of_v>; - using AtomLayoutMNK = cute::conditional_t>, Layout>>; - - using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< - ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); - - using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); - using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); - - using SmemLayoutAtomA = decltype(detail::ss_smem_selector< - GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutAtomB = decltype(detail::ss_smem_selector< - GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - - static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0; - static constexpr int KernelSmemCarveout = static_cast(TensorMapStorage); - - static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); - using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8; - - using SmemCopyAtomA = void; - using SmemCopyAtomB = void; - - using CollectiveOp = CollectiveMma< - DispatchPolicy, - TileShape_MNK, - ElementA, - TagToStrideA_t, - ElementB, - TagToStrideB_t, - TiledMma, - GmemTiledCopyA, - SmemLayoutAtomA, - SmemCopyAtomA, - cute::identity, - GmemTiledCopyB, - SmemLayoutAtomB, - SmemCopyAtomB, - cute::identity - >; -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp b/csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp deleted file mode 100644 index 13b90e998625e..0000000000000 --- a/csrc/cutlass_extensions/gemm/collective/fp8_accumulation.hpp +++ /dev/null @@ -1,183 +0,0 @@ -// clang-format off -// adapted from: https://github.com/soundOfDestiny/cutlass/blob/a4208aa6958864923505cade9c63eb2a6daf16e5/include/cutlass/gemm/collective/fp8_accumulation.hpp - -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. 
- * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include "cute/algorithm/clear.hpp" -#include "cute/tensor.hpp" - -////////////////////////////////////////////////////////////////////////////// -///////////////////////////////////FP8 Accumulation/////////////////////////// -////////////////////////////////////////////////////////////////////////////// -/// This class provides API to promote (add) or scale (multiply_add) the results -/// from the tensor core accumulators to the main accumulators when the number -/// of MMAs reaches the max number of MMA interval specified by user, after that -/// the tensor core accumulators are zeroed. -////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::collective { - -template < - class EngineAccum, - class LayoutAccum> -struct GmmaFP8AccumulationWithScale { - using TensorAccum = cute::Tensor; - using ElementAccumulator = typename EngineAccum::value_type; - - static_assert(is_static::value, "Accumulator Layout should be static"); - static_assert(is_rmem::value , "Accumulator tensor must be rmem resident."); - -private: - TensorAccum& accum_; - TensorAccum accum_temp_; - - uint32_t accum_promotion_interval_; // defines the max num of executed MMAs after which accum should be promoted. - uint32_t mma_count_per_mainloop_iteration_; // num of MMAs per k_tile of mainloop - uint32_t mma_count_; // current executed MMAs - uint32_t reset_accum_flag_; // accum needs to be zeroed or not. - - // promote or `add` the partial accumulators to main accumulator (FADD). - CUTLASS_DEVICE - void promote_core() { - warpgroup_wait<0>(); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(accum_); ++i) { - accum_(i) += accum_temp_(i); - } - } - - // `multiply` scale the partial accumulators and `add` to main accumulator (FFMA). 
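The `GmmaFP8AccumulationWithScale` helper being deleted here (the patch switches `vllm_collective_builder.cuh` to the stock CUTLASS collective builder further down) keeps partial FP8 tensor-core sums in a temporary accumulator and folds them into the main accumulator every `accum_promotion_interval` MMAs, either as a plain add (FADD) or multiplied by a per-block scale (FFMA). A scalar reference sketch of that numeric scheme follows; it is purely illustrative and collapses the per-fragment, per-warp-group logic into a 1-D loop, with `block_scale` standing in for the combined per-k-tile A/B scales.

```cpp
// Scalar model of blockwise-scaled FP8 accumulation: accumulate in a low-
// precision temporary, periodically promote into a high-precision main
// accumulator with a per-block scale, then handle the residue at the end.
#include <cstddef>
#include <vector>

float fp8_blockwise_dot(const std::vector<float>& a, const std::vector<float>& b,
                        const std::vector<float>& block_scale,  // one scale per k-block
                        std::size_t interval) {                 // MMAs between promotions
  float main_acc = 0.0f;   // high-precision accumulator
  float temp_acc = 0.0f;   // stands in for the FP8 MMA accumulator
  std::size_t mma_count = 0;

  for (std::size_t k = 0; k < a.size(); ++k) {
    temp_acc += a[k] * b[k];         // partial accumulation, limited precision
    if (++mma_count == interval) {   // "scale_if_needed": promote with FFMA
      main_acc += temp_acc * block_scale[k / interval];
      temp_acc = 0.0f;               // zero the partial accumulator
      mma_count = 0;
    }
  }
  if (mma_count > 0) {               // "scale_residue_if_needed"
    main_acc += temp_acc * block_scale[(a.size() - 1) / interval];
  }
  return main_acc;
}
```

Promoting on a fixed interval bounds the error that can build up in the FP8 accumulator while still amortizing the cost of the higher-precision adds.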
- template < - class EngineScale, - class LayoutScale> - CUTLASS_DEVICE - void scale_core(const cute::Tensor &scale) { - using TensorScale = cute::Tensor; - - static_assert(is_static::value, "Scale Layout should be static"); - static_assert(is_rmem::value , "Scale tensor must be rmem resident."); - - static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape."); - - warpgroup_wait<0>(); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(accum_); ++i) { - accum_(i) += accum_temp_(i) * scale(i); - } - } - -public: - CUTLASS_DEVICE - GmmaFP8AccumulationWithScale( - TensorAccum &accum, - uint32_t accum_promotion_interval, - uint32_t mma_count_per_mainloop_iteration) - : accum_(accum), - accum_promotion_interval_(accum_promotion_interval), - mma_count_per_mainloop_iteration_(mma_count_per_mainloop_iteration), - mma_count_(0), - reset_accum_flag_(0) - { - accum_temp_ = cute::make_fragment_like(accum); - } - - // - // Methods (Common) - // - - CUTLASS_DEVICE - TensorAccum& operator()() { - return accum_temp_; - } - - /// prepare the MMA accumulators when initialization or zeroing is required. - CUTLASS_DEVICE - bool prepare_if_needed() { - return reset_accum_flag_; - } - - // - // Methods (for FADD version) - // - - /// promote (add) the results from the MMA accumulators to main accumulator if needed. - CUTLASS_DEVICE - void promote_if_needed() { - mma_count_ += mma_count_per_mainloop_iteration_; - reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); - if (reset_accum_flag_) { - promote_core(); - mma_count_ = 0; - } - } - - /// promote (add) the residue results from the MMA accumulators to main accumulator if needed. - CUTLASS_DEVICE - void promote_residue_if_needed() { - if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { - promote_core(); - } - } - - // - // Methods (for FFMA version) - // - - /// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed. - template < - class EngineScale, - class LayoutScale> - CUTLASS_DEVICE - void scale_if_needed(const cute::Tensor &scale) { - mma_count_ += mma_count_per_mainloop_iteration_; - reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); - if (reset_accum_flag_) { - scale_core(scale); - mma_count_ = 0; - } - } - - /// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed. - template < - class EngineScale, - class LayoutScale> - CUTLASS_DEVICE - void scale_residue_if_needed(const cute::Tensor &scale) { - if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { - scale_core(scale); - } - } -}; - -} // namespace cutlass::gemm::collective diff --git a/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp deleted file mode 100644 index ce7f47cf72337..0000000000000 --- a/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +++ /dev/null @@ -1,729 +0,0 @@ -// clang-format off -// Adapted (Heavily) from: https://github.com/soundOfDestiny/cutlass/blob/9d997ce0dea4c5fa1a617db6b7ff29aa9235822c/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp - -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
- * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/trace.h" -#include "cutlass/numeric_types.h" - -#include "cute/arch/cluster_sm90.hpp" -#include "cute/arch/copy_sm80.hpp" -#include "cute/arch/copy_sm90.hpp" -#include "cute/algorithm/functional.hpp" -#include "cute/atom/mma_atom.hpp" -#include "cute/algorithm/gemm.hpp" -#include "cute/numeric/arithmetic_tuple.hpp" - -#include "cutlass_extensions/gemm/dispatch_policy.hpp" -#include "cutlass_extensions/gemm/collective/fp8_accumulation.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::collective { -using namespace cute; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// WarpSpecialized Mainloop -template < - int Stages, - class ClusterShape, - class KernelSchedule, - int ScaleGranularityM_, - class TileShape_, - class ElementA_, - class StrideA_, - class ElementB_, - class StrideB_, - class TiledMma_, - class GmemTiledCopyA_, - class SmemLayoutAtomA_, - class SmemCopyAtomA_, - class TransformA_, - class GmemTiledCopyB_, - class SmemLayoutAtomB_, - class SmemCopyAtomB_, - class TransformB_> -struct CollectiveMma< - MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8, - TileShape_, - ElementA_, - StrideA_, - ElementB_, - StrideB_, - TiledMma_, - GmemTiledCopyA_, - SmemLayoutAtomA_, - SmemCopyAtomA_, - TransformA_, - GmemTiledCopyB_, - SmemLayoutAtomB_, - SmemCopyAtomB_, - TransformB_> -{ - // - // Type Aliases - // - using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8; - using TileShape = TileShape_; - using ElementA = ElementA_; - using StrideA = StrideA_; - using ElementB = ElementB_; - using StrideB = StrideB_; - using TiledMma = TiledMma_; - using ElementAccumulator 
= typename TiledMma::ValTypeC; - using ElementBlockScale = ElementAccumulator; - using GmemTiledCopyA = GmemTiledCopyA_; - using GmemTiledCopyB = GmemTiledCopyB_; - using SmemLayoutAtomA = SmemLayoutAtomA_; - using SmemLayoutAtomB = SmemLayoutAtomB_; - using SmemCopyAtomA = SmemCopyAtomA_; - using SmemCopyAtomB = SmemCopyAtomB_; - using TransformA = TransformA_; - using TransformB = TransformB_; - using ArchTag = typename DispatchPolicy::ArchTag; - - using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); - using MainloopPipeline = cutlass::PipelineTmaAsync; - using PipelineState = cutlass::PipelineState; - using PipelineParams = typename MainloopPipeline::Params; - - // Two threads per CTA are producers (1 for operand tile and 32 for scales) - static constexpr int NumProducerThreadEvents = 33; - - static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_; - static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; - - static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); - static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - - static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); - static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - - static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M."); - - // Tile along modes in a way that maximizes the TMA box size. - using SmemLayoutA = decltype(tile_to_shape( - SmemLayoutAtomA{}, - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), - cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); - using SmemLayoutB = decltype(tile_to_shape( - SmemLayoutAtomB{}, - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), - cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); - - // Block scaling gmem-to-smem copy atom - using SmemBlockScalingCopyAtomA = Copy_Atom, ElementBlockScale>; - using SmemBlockScalingCopyAtomB = Copy_Atom, ElementBlockScale>; - - // Block scaling smem layout - using SmemLayoutScaleA = Layout, Int>>; - using SmemLayoutScaleB = Layout>, Stride<_1>>; // `ScaleNsPerTile` is always 1. 
- - static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); - static_assert(cute::is_base_of::value && - cute::is_base_of::value, - "MMA atom must source both A and B operand from smem_desc for this mainloop."); - static_assert(cute::is_same_v || cute::is_same_v, - "GmemTiledCopy - invalid SM90 TMA copy atom specified."); - static_assert(cute::is_same_v || cute::is_same_v, - "GmemTiledCopy - invalid SM90 TMA copy atom specified."); - static_assert(cute::is_same_v, - "ElementAccumulator and ElementBlockScale should be same datatype"); - - struct SharedStorage - { - struct TensorStorage : cute::aligned_struct<128> { - cute::array_aligned> smem_A; // mxk - cute::array_aligned> smem_B; // nxk - cute::array_aligned> smem_scale_A; // ScaleMsPerTile x k - cute::array_aligned> smem_scale_B; // 1xk - } tensors; - - using PipelineStorage = typename MainloopPipeline::SharedStorage; - PipelineStorage pipeline; - }; - using TensorStorage = typename SharedStorage::TensorStorage; - using PipelineStorage = typename SharedStorage::PipelineStorage; - - // Host side kernel arguments - struct Arguments { - ElementA const* ptr_A; - StrideA dA; - ElementB const* ptr_B; - StrideB dB; - ElementBlockScale const* ptr_scale_A; - ElementBlockScale const* ptr_scale_B; - }; - - // Device side kernel params - struct Params { - // Assumption: StrideA is congruent with Problem_MK - using TMA_A = decltype(make_tma_copy_A_sm90( - GmemTiledCopyA{}, - make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), - SmemLayoutA{}(_,_,0), - TileShape{}, - ClusterShape{})); - // Assumption: StrideB is congruent with Problem_NK - using TMA_B = decltype(make_tma_copy_B_sm90( - GmemTiledCopyB{}, - make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), - SmemLayoutB{}(_,_,0), - TileShape{}, - ClusterShape{})); - TMA_A tma_load_a; - TMA_B tma_load_b; - uint32_t tma_transaction_bytes = TmaTransactionBytes; - uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; - uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; - // Block scaling factors for A and B - ElementBlockScale const* ptr_scale_A; - ElementBlockScale const* ptr_scale_B; - }; - - // - // Methods - // - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - (void) workspace; - - // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M,N,K,L] = problem_shape_MNKL; - - auto ptr_A = reinterpret_cast(args.ptr_A); - auto ptr_B = reinterpret_cast(args.ptr_B); - - Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); - Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); - typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( - GmemTiledCopyA{}, - tensor_a, - SmemLayoutA{}(_,_,cute::Int<0>{}), - TileShape{}, - ClusterShape{}); - typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90( - GmemTiledCopyB{}, - tensor_b, - SmemLayoutB{}(_,_,cute::Int<0>{}), - TileShape{}, - ClusterShape{}); - uint32_t transaction_bytes_mk = TmaTransactionBytesMK; - uint32_t transaction_bytes_nk = TmaTransactionBytesNK; - uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk; - - return { - tma_load_a, - tma_load_b, - transaction_bytes, - transaction_bytes_mk, - transaction_bytes_nk, - args.ptr_scale_A, - args.ptr_scale_B - }; - } - - 
template - static bool - can_implement( - ProblemShape const& problem_shape, - [[maybe_unused]] Arguments const& args) { - constexpr int tma_alignment_bits = 128; - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M,N,K,L] = problem_shape_MNKL; - - bool implementable = true; - constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); - constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); - - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); - } - return implementable; - } - - static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; - static constexpr int K_PIPE_MMAS = 1; - static constexpr uint32_t TmaTransactionBytesMK = - cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)); - static constexpr uint32_t TmaTransactionBytesNK = - cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); - static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; - - /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance - CUTLASS_DEVICE - static void prefetch_tma_descriptors(Params const& mainloop_params) - { - cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); - cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); - } - - /// Set up the data needed by this collective for load and mma. - /// Returns a tuple of tensors. 
The collective and the kernel layer have the contract - /// Returned tuple must contain at least two elements, with the first two elements being: - /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) - /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) - template - CUTLASS_DEVICE auto - load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { - using X = Underscore; - // Separate out problem shape for convenience - auto [M,N,K,L] = problem_shape_MNKL; - - // TMA requires special handling of strides to deal with coord codomain mapping - // Represent the full tensors -- get these from TMA - Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) - Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) - - // Make tiled views, defer the slice - Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) - - constexpr auto scales_m = Int{}; - auto tM = get<2>(gA_mkl.shape()); - auto tN = get<2>(gB_nkl.shape()); - auto tK = get<3>(gA_mkl.shape()); - - // Make the tiled views of scale tensors - auto scaleA_shape = make_shape(M / ScaleGranularityM, tK, L); // (scale_m,k,l) - auto scaleA_layout = make_ordered_layout(scaleA_shape, Step<_0, _1, _2>{}); - auto scaleB_shape = make_shape(tN, tK, L); // (n,k,l) - auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_1, _0, _2>{}); - - // Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and - // gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl. 
- Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (scale_m,k,l) - Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l) - - return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl); - } - - /// Perform a collective-scoped matrix multiply-accumulate - /// Producer Perspective - template < - class TensorA, class TensorB, - class TensorScaleA, class TensorScaleB, - class KTileIterator, class BlockCoord - > - CUTLASS_DEVICE void - load( - Params const& mainloop_params, - MainloopPipeline pipeline, - PipelineState smem_pipe_write, - cute::tuple const& load_inputs, - BlockCoord const& blk_coord, - KTileIterator k_tile_iter, int k_tile_count, - int thread_idx, - uint32_t block_rank_in_cluster, - TensorStorage& shared_tensors) { - int lane_predicate = cute::elect_one_sync(); - - // Blockscaling: Tma loads for load_input and CpAsync for load_scale - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k) - Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k) - - // - // Prepare the TMA loads for A and B - // - - constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); - uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - - Tensor gA_mkl = get<0>(load_inputs); - Tensor gB_nkl = get<1>(load_inputs); - - auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); - auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); - - // Partition the inputs based on the current block coordinates. 
- auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; - Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) - - - // Block scaling: load_scale has scaling tensors in global memory which are not tiled - Tensor mScaleA_mkl = get<2>(load_inputs); - Tensor mScaleB_nkl = get<3>(load_inputs); - auto scales_m = get<0>(mScaleA_mkl.shape()); - - Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape()); - - Tensor gScaleA = local_tile( - mScaleA_mkl, make_tile(Int{}), - make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1) - Tensor cScaleA = local_tile( - cScaleA_mkl, make_tile(Int{}), - make_coord(m_coord,_,l_coord)); - Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1) - - // TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128 - TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{}, - Layout>{}, Layout>{}); // (1,1,1) - TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, - Layout>{}, Layout>{}); // (1,1,1) - ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x); - ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x); - - Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA); - Tensor tAcA_ScaleA = thr_scale_copy_a.partition_S(cScaleA); - Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA); - - Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB); - Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB); - - // Applies the mapping from block_tma_a - Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) - Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) - - Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) - Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) - - uint16_t mcast_mask_a = 0; - uint16_t mcast_mask_b = 0; - - // Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors - // Maps the tile -> block, value - if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int n = 0; n < size<1>(block_layout); ++n) { - mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); - } - } - - if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int m = 0; m < size<0>(block_layout); ++m) { - mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); - } - } - - // Allocate predicate tensors for a_scales (since we can't guarantee that - // all scales are valid, since we could have a partial tiles along M) - Tensor tApA_ScaleA = make_tensor(shape(tAsA_ScaleA(_,_,0))); - #pragma unroll - for (int i = 0; i < size(tApA_ScaleA); ++i) { - tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) < scales_m; - } - - // Mainloop - CUTLASS_PRAGMA_NO_UNROLL - for ( ; k_tile_count > 0; --k_tile_count) { - // LOCK smem_pipe_write for _writing_ - pipeline.producer_acquire(smem_pipe_write); - - // - // Copy gmem to smem for *k_tile_iter - // - int write_stage = smem_pipe_write.index(); - using BarrierType = typename MainloopPipeline::ProducerBarrierType; - BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); - - // Copy operands A and B from global memory to shared memory - if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); - if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); - - 
// Copy scale tensors from global memory to shared memory - copy_if(scale_copy_a, tApA_ScaleA, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage)); - copy(scale_copy_b, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage)); - pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc); - - ++k_tile_iter; - - // Advance smem_pipe_write - ++smem_pipe_write; - } - } - - /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster - CUTLASS_DEVICE void - load_tail( - MainloopPipeline pipeline, - PipelineState smem_pipe_write) { - int lane_predicate = cute::elect_one_sync(); - - // Issue the epilogue waits - if (lane_predicate) { - /* This helps avoid early exit of blocks in Cluster - * Waits for all stages to either be released (all - * Consumer UNLOCKs), or if the stage was never used - * then would just be acquired since the phase was - * still inverted from make_producer_start_state - */ - pipeline.producer_tail(smem_pipe_write); - } - } - - /// Perform a collective-scoped matrix multiply-accumulate - /// Consumer Perspective - template < - class FrgTensorC - > - CUTLASS_DEVICE void - mma(MainloopPipeline pipeline, - PipelineState smem_pipe_read, - FrgTensorC& accum, - int k_tile_count, - int thread_idx, - TensorStorage& shared_tensors, - Params const& mainloop_params) { - - - static_assert(is_rmem::value, "C tensor must be rmem resident."); - static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); - static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); - static_assert(cute::is_void_v, - "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); - static_assert(cute::is_void_v, - "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); - - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - - // Block scaling - Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), - Layout< - Shape, Int>, cute::tuple_element_t<1, TileShape>, Int>, - Stride, _0, Int> - >{}); // ((ScaleGranularityM,ScaleMsPerTile),n,k) - Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k) - - // - // Define C accumulators and A/B partitioning - // - - // Layout of warp group to thread mapping - - static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and - stride<0>(typename TiledMma::BLayout{}) == 0 and - size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and - size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, - "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); - - constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; - Layout warp_group_thread_layout = make_layout(Int{}, - Int{}); - - int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); - - TiledMma tiled_mma; - auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); - - Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C. 
- - Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) - - // Allocate "fragments/descriptors" - Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) - - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE - - // - // PIPELINED MAIN LOOP - // - static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), - "ERROR : Incorrect number of MMAs in flight"); - - // We release buffers to producer warps(dma load) with some mmas in flight - PipelineState smem_pipe_release = smem_pipe_read; - - // Per block scale values for operand A and B - - using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout())); // `make_layout_like` makes a compact layout. - using RegLayoutScaleAEssential = decltype(filter_zeros(RegLayoutScaleAViewAsC{}.stride(), RegLayoutScaleAViewAsC{}.shape())); // an interface to traverse the underlying storage for the compact layout mentioned above - - Tensor tCrScaleAViewAsC = make_tensor(RegLayoutScaleAViewAsC{}); // (MMA,MMA_M,MMA_N) - ElementBlockScale scale_b; - - // Prologue GMMAs - int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); - - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - - GmmaFP8AccumulationWithScale accumulation(accum, size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{}), size<2>(tCrA)); - warpgroup_fence_operand(accumulation()); - CUTLASS_PRAGMA_UNROLL - for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) - { - // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) - auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - pipeline.consumer_wait(smem_pipe_read, barrier_token); - - if (accumulation.prepare_if_needed()) { - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - } - - int read_stage = smem_pipe_read.index(); - - // Load per block scale values from shared memory to registers. - scale_b = sScaleB[read_stage]; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { - tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{})); - } - if constexpr (ScaleMsPerTile == 1) { - static_assert(size(RegLayoutScaleAEssential{}) == 1); - tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`. 
- } else { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { - tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b; - } - } - - warpgroup_arrive(); - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - warpgroup_commit_batch(); - - // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` - accumulation.scale_if_needed(tCrScaleAViewAsC); - - ++smem_pipe_read; - } - - warpgroup_fence_operand(accumulation()); - // Mainloop GMMAs - k_tile_count -= prologue_mma_count; - - CUTLASS_PRAGMA_NO_UNROLL - for ( ; k_tile_count > 0; --k_tile_count) - { - // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) - auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - pipeline.consumer_wait(smem_pipe_read, barrier_token); - - // - // Compute on k_tile - // - - int read_stage = smem_pipe_read.index(); - - // Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N) - scale_b = sScaleB[read_stage]; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { - tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{})); - } - if constexpr (ScaleMsPerTile == 1) { - static_assert(size(RegLayoutScaleAEssential{}) == 1); - tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`. 
- } else { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { - tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b; - } - } - - if (accumulation.prepare_if_needed()) { - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - } - - warpgroup_fence_operand(accumulation()); - warpgroup_arrive(); - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - warpgroup_commit_batch(); - - /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed - warpgroup_wait(); - warpgroup_fence_operand(accumulation()); - - // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` - accumulation.scale_if_needed(tCrScaleAViewAsC); - - pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it - - // Advance smem_pipe_read and smem_pipe_release - ++smem_pipe_read; - ++smem_pipe_release; - } - - accumulation.scale_residue_if_needed(tCrScaleAViewAsC); - - warpgroup_fence_operand(accumulation()); - } - - /// Perform a Consumer Epilogue to release all buffers - CUTLASS_DEVICE void - mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { - // Prologue GMMAs - int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); - k_tile_count -= prologue_mma_count; - - smem_pipe_release.advance(k_tile_count); - - // Wait on all GMMAs to complete - warpgroup_wait<0>(); - - for (int count = 0; count < prologue_mma_count; ++count) { - pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it - ++smem_pipe_release; - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/gemm/dispatch_policy.hpp b/csrc/cutlass_extensions/gemm/dispatch_policy.hpp deleted file mode 100644 index df809e27a3efe..0000000000000 --- a/csrc/cutlass_extensions/gemm/dispatch_policy.hpp +++ /dev/null @@ -1,39 +0,0 @@ -#pragma once - -#include "cutlass/gemm/dispatch_policy.hpp" - -namespace cutlass::gemm { - -////////////////////////////////////////////////////////////////////////////// - -// FP8 related policies (including Blocked Scaled Accumulation) -// `ScaleGranularityM` specifies scaling granularity along M, while zero-value -// `ScaleGranularityM` indicates that scaling granularity is -// `size<0>(TileShape_MNK{})` along M. -template -struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum - : KernelTmaWarpSpecializedCooperative {}; - -// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp -// specialized dynamic schedule For FP8 kernels with Block Scaling -template , - class KernelSchedule = KernelTmaWarpSpecialized, - int ScaleGranularityM = - 0 // `ScaleGranularityM` specifies scaling granularity along M, - // while zero-value `ScaleGranularityM` indicates that scaling - // granularity is `size<0>(TileShape_MNK{})` along M. 
- > -struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8 - : MainloopSm90TmaGmmaWarpSpecialized { - static_assert( - cute::is_same_v< - KernelSchedule, - KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum< - ScaleGranularityM>>, - "KernelSchedule must be one of the warp specialized policies"); -}; - -////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm \ No newline at end of file diff --git a/csrc/cutlass_extensions/vllm_collective_builder.cuh b/csrc/cutlass_extensions/vllm_collective_builder.cuh index e7fbba4cd4b0d..085ee1290031f 100644 --- a/csrc/cutlass_extensions/vllm_collective_builder.cuh +++ b/csrc/cutlass_extensions/vllm_collective_builder.cuh @@ -1,6 +1,6 @@ #pragma once -#include "cutlass_extensions/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" namespace cutlass::gemm::collective { using namespace cute; diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 2728aa81f0c9f..995374a50b037 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -52,15 +52,6 @@ #define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__)) -#define AT_DISPATCH_BYTE_CASE(enum_type, ...) \ - AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, byte_t, __VA_ARGS__) - -#define VLLM_DISPATCH_CASE_BYTE_TYPES(...) \ - AT_DISPATCH_BYTE_CASE(at::ScalarType::Byte, __VA_ARGS__) - -#define VLLM_DISPATCH_BYTE_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_BYTE_TYPES(__VA_ARGS__)) - #define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__)) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index f051eb0702228..93c73d58390e1 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -1,15 +1,10 @@ #include "type_convert.cuh" #include "dispatch_utils.h" +#include "cub_helpers.h" #include #include -#ifndef USE_ROCM - #include -#else - #include -#endif - namespace vllm { // TODO(woosuk): Further optimize this kernel. @@ -30,7 +25,7 @@ __global__ void rms_norm_kernel( using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); @@ -85,7 +80,7 @@ fused_add_rms_norm_kernel( using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); @@ -126,7 +121,7 @@ fused_add_rms_norm_kernel( using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); @@ -140,6 +135,211 @@ fused_add_rms_norm_kernel( } } +/* Function specialization in the case of FP16/BF16 tensors. 
+ Additional optimizations we can make in this case are + packed and vectorized operations, which help with the + memory latency bottleneck. + + _f16VecPN struct extends _f16Vec to add operations specifically required for + polynomial normalization (poly norm). + The original _f16Vec does not include the sum-of-powers computation or + in-place polynomial normalization logic. */ +template +struct alignas(16) _f16VecPN : _f16Vec { + using Base = _f16Vec; + using Converter = typename Base::Converter; + using T1 = typename Base::T1; + using T2 = typename Base::T2; + using Base::data; + + __device__ auto sum_pows() const { + float s2 = 0.0f, s4 = 0.0f, s6 = 0.0f; + +#pragma unroll + for (int i = 0; i < width; i += 2) { + float2 z = Converter::convert(T2{data[i], data[i + 1]}); + float x2 = z.x * z.x; + float x4 = x2 * x2; + float x6 = x4 * x2; + + float y2 = z.y * z.y; + float y4 = y2 * y2; + float y6 = y4 * y2; + + s2 += x2 + y2; + s4 += x4 + y4; + s6 += x6 + y6; + } + return std::make_tuple(s2, s4, s6); + } + + __device__ void poly_norm_inplace(const float w2_inv_std, + const float w1_inv_std2, + const float w0_inv_std3, const float bias) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + float2 z = Converter::convert(T2{data[i], data[i + 1]}); + + float x2 = z.x * z.x; + float x3 = x2 * z.x; + z.x = w2_inv_std * z.x + w1_inv_std2 * x2 + w0_inv_std3 * x3 + bias; + + float y2 = z.y * z.y; + float y3 = y2 * z.y; + z.y = w2_inv_std * z.y + w1_inv_std2 * y2 + w0_inv_std3 * y3 + bias; + + auto out = Converter::convert(z); + data[i] = out.x; + data[i + 1] = out.y; + } + } +}; + +template +__global__ std::enable_if_t<(width > 0) && _typeConvert::exists> +poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [3] + const scalar_t* __restrict__ bias, // [1] + const float epsilon, const int hidden_size) { + // Sanity checks on our vector struct and type-punned pointer arithmetic + static_assert(std::is_pod_v<_f16VecPN>); + static_assert(sizeof(_f16VecPN) == sizeof(scalar_t) * width); + + /* These and the argument pointers are all declared `restrict` as they are + not aliased in practice. 
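An aside for reviewers: the packed struct above is only a vectorized form of simple scalar math. sum_pows() accumulates the per-element sums of x^2, x^4 and x^6 (two lanes at a time), and poly_norm_inplace() then applies a cubic polynomial using per-row factors. A minimal scalar sketch of that per-element transform (hypothetical helper name, plain float math, not part of the kernel) looks like this:

```cpp
// Illustrative scalar form of poly_norm_inplace (sketch only): given the
// per-row factors w2/std, w1/std2, w0/std3 and the bias, each element x
// becomes a cubic polynomial of itself.
inline float poly_norm_elem_ref(float x, float w2_inv_std, float w1_inv_std2,
                                float w0_inv_std3, float bias) {
  float x2 = x * x;
  float x3 = x2 * x;
  return w2_inv_std * x + w1_inv_std2 * x2 + w0_inv_std3 * x3 + bias;
}
```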
Argument pointers should not be dereferenced + in this kernel as that would be undefined behavior */ + auto* __restrict__ input_v = + reinterpret_cast*>(input); + const int vec_hidden_size = hidden_size / width; + float variance = 0.0f; + float variance2 = 0.0f; + float variance3 = 0.0f; + + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + int id = blockIdx.x * vec_hidden_size + idx; + _f16VecPN temp = input_v[id]; + auto [x2, x4, x6] = temp.sum_pows(); + + variance += x2; + variance2 += x4; + variance3 += x6; + } + + float3 thread_variances = make_float3(variance, variance2, variance3); + + struct SumOp { + __device__ float3 operator()(const float3& a, const float3& b) const { + return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); + } + }; + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + float3 block_variances = + BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x); + + variance = block_variances.x; + variance2 = block_variances.y; + variance3 = block_variances.z; + + __shared__ float s_w2_inv_std; + __shared__ float s_w1_inv_std2; + __shared__ float s_w0_inv_std3; + __shared__ float s_bias; + + if (threadIdx.x == 0) { + float w0 = (float)weight[0]; + float w1 = (float)weight[1]; + float w2 = (float)weight[2]; + s_bias = (float)bias[0]; + + s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon); + s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon); + s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon); + } + __syncthreads(); + + auto* __restrict__ out_v = reinterpret_cast<_f16VecPN*>(out); + + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + int id = blockIdx.x * vec_hidden_size + idx; + _f16VecPN temp = input_v[id]; + temp.poly_norm_inplace(s_w2_inv_std, s_w1_inv_std2, s_w0_inv_std3, s_bias); + out_v[id] = temp; + } +} + +/* Generic poly_norm_kernel + The width field is not used here but necessary for other specializations. 
+ */ +template +__global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> +poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [3] + const scalar_t* __restrict__ bias, // [1] + const float epsilon, const int hidden_size) { + float variance = 0.0f; + float variance2 = 0.0f; + float variance3 = 0.0f; + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float x = (float)input[blockIdx.x * hidden_size + idx]; + float x2 = x * x; + float x4 = x2 * x2; + float x6 = x4 * x2; + + variance += x2; + variance2 += x4; + variance3 += x6; + } + + float3 thread_variances = make_float3(variance, variance2, variance3); + + struct SumOp { + __device__ float3 operator()(const float3& a, const float3& b) const { + return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); + } + }; + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + float3 block_variances = + BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x); + + variance = block_variances.x; + variance2 = block_variances.y; + variance3 = block_variances.z; + + __shared__ float s_w2_inv_std; + __shared__ float s_w1_inv_std2; + __shared__ float s_w0_inv_std3; + __shared__ float s_bias; + + if (threadIdx.x == 0) { + float w0 = (float)weight[0]; + float w1 = (float)weight[1]; + float w2 = (float)weight[2]; + s_bias = (float)bias[0]; + + s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon); + s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon); + s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon); + } + __syncthreads(); + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float x = (float)input[blockIdx.x * hidden_size + idx]; + float x2 = x * x; + float x3 = x2 * x; + + out[blockIdx.x * hidden_size + idx] = + (scalar_t)(x * s_w2_inv_std + x2 * s_w1_inv_std2 + x3 * s_w0_inv_std3 + + s_bias); + } +} + } // namespace vllm void rms_norm(torch::Tensor& out, // [..., hidden_size] @@ -219,3 +419,49 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] LAUNCH_FUSED_ADD_RMS_NORM(0); } } + +#define LAUNCH_FUSED_POLY_NORM(width) \ + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "poly_norm_kernel", [&] { \ + vllm::poly_norm_kernel<<>>( \ + out.data_ptr(), input.data_ptr(), \ + weight.data_ptr(), bias.data_ptr(), epsilon, \ + hidden_size); \ + }); + +void poly_norm(torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& weight, // [3] + torch::Tensor& bias, // [1] + double epsilon) { + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.data_ptr() != input.data_ptr()); + + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + /* This kernel is memory-latency bound in many scenarios. + When num_tokens is large, a smaller block size allows + for increased block occupancy on CUs and better latency + hiding on global mem ops. */ + const int max_block_size = (num_tokens < 256) ? 1024 : 256; + dim3 block(std::min(hidden_size, max_block_size)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + /*If the tensor types are FP16/BF16, try to use the optimized kernel + with packed + vectorized ops. 
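For reviewers who want a test oracle: both specializations of poly_norm_kernel compute the same per-row function, so a few lines of host code can validate either path. This is a sketch assuming weight = [w0, w1, w2] and a scalar bias, as the argument comments state; the name is illustrative and this is not the launcher itself:

```cpp
// CPU reference for one row of poly_norm (sketch; mirrors the kernel's math).
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> poly_norm_row_ref(const std::vector<float>& x,
                                     float w0, float w1, float w2,
                                     float bias, float eps) {
  const float n = static_cast<float>(x.size());
  float s2 = 0.f, s4 = 0.f, s6 = 0.f;
  for (float v : x) {
    float v2 = v * v;
    s2 += v2;            // sum of x^2  -> variance
    s4 += v2 * v2;       // sum of x^4  -> variance2
    s6 += v2 * v2 * v2;  // sum of x^6  -> variance3
  }
  // The kernel folds the weights into the normalization factors once per row.
  const float w2_inv_std  = w2 * (1.f / std::sqrt(s2 / n + eps));
  const float w1_inv_std2 = w1 * (1.f / std::sqrt(s4 / n + eps));
  const float w0_inv_std3 = w0 * (1.f / std::sqrt(s6 / n + eps));

  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    float v = x[i], v2 = v * v, v3 = v2 * v;
    out[i] = v * w2_inv_std + v2 * w1_inv_std2 + v3 * w0_inv_std3 + bias;
  }
  return out;
}
```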
+ Max optimization is achieved with a width-8 vector of FP16/BF16s + since we can load at most 128 bits at once in a global memory op. + However, this requires each tensor's data to be aligned to 16 + bytes. + */ + auto inp_ptr = reinterpret_cast(input.data_ptr()); + auto out_ptr = reinterpret_cast(out.data_ptr()); + bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0; + if (ptrs_are_aligned && hidden_size % 8 == 0) { + LAUNCH_FUSED_POLY_NORM(8); + } else { + LAUNCH_FUSED_POLY_NORM(0); + } +} diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu index 0fd5849d9626c..be134089bd6d4 100644 --- a/csrc/layernorm_quant_kernels.cu +++ b/csrc/layernorm_quant_kernels.cu @@ -8,16 +8,11 @@ #include "type_convert.cuh" #include "quantization/fp8/common.cuh" #include "dispatch_utils.h" +#include "cub_helpers.h" #include #include -#ifndef USE_ROCM - #include -#else - #include -#endif - namespace vllm { // TODO(woosuk): Further optimize this kernel. @@ -39,7 +34,7 @@ __global__ void rms_norm_static_fp8_quant_kernel( using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); @@ -100,7 +95,7 @@ fused_add_rms_norm_static_fp8_quant_kernel( using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); @@ -149,7 +144,7 @@ fused_add_rms_norm_static_fp8_quant_kernel( using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu index 78f7b3cc1aa25..b5321f748e6be 100644 --- a/csrc/moe/grouped_topk_kernels.cu +++ b/csrc/moe/grouped_topk_kernels.cu @@ -21,6 +21,7 @@ #include #include #include +#include #include #include namespace cg = cooperative_groups; @@ -410,14 +411,21 @@ __device__ inline float cuda_cast(__nv_bfloat16 val) { return __bfloat162float(val); } +template +__device__ inline T neg_inf() { + // cuda::std::numeric_limits::infinity() returns `0` for [T=bf16 or fp16] + // so we need to cast from fp32 + return cuda_cast(-cuda::std::numeric_limits::infinity()); +} + template __device__ void topk_with_k2(T* output, T const* input, cg::thread_block_tile<32> const& tile, int32_t const lane_id, int const num_experts_per_group) { // Get the top2 per thread - T largest = -INFINITY; - T second_largest = -INFINITY; + T largest = neg_inf(); + T second_largest = neg_inf(); if (num_experts_per_group > WARP_SIZE) { for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { @@ -512,8 +520,8 @@ __global__ void group_idx_and_topk_idx_kernel( warp_id * topk; s_topk_idx += warp_id * topk; - T value = cuda::std::numeric_limits::min(); - T topk_group_value = cuda::std::numeric_limits::min(); + T value = neg_inf(); + T topk_group_value = neg_inf(); int32_t num_equalto_topkth_group; #if 
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) @@ -524,11 +532,8 @@ __global__ void group_idx_and_topk_idx_kernel( if (case_id < num_tokens) { // calculate group_idx int32_t target_num_min = WARP_SIZE - n_group + topk_group; - if (lane_id < n_group && - (isfinite(cuda_cast( - group_scores[lane_id])))) // The check is necessary to avoid - // abnormal input - { + // The check is necessary to avoid abnormal input + if (lane_id < n_group && cuda::std::isfinite(group_scores[lane_id])) { value = group_scores[lane_id]; } @@ -539,11 +544,11 @@ __global__ void group_idx_and_topk_idx_kernel( __syncwarp(); // Ensure all threads have valid data before reduction topk_group_value = cg::reduce(tile, value, cg::greater()); if (value == topk_group_value) { - value = cuda::std::numeric_limits::min(); + value = neg_inf(); } pre_count_equal_to_top_value = count_equal_to_top_value; - count_equal_to_top_value = __popc(__ballot_sync( - FULL_WARP_MASK, (value == cuda::std::numeric_limits::min()))); + count_equal_to_top_value = + __popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf()))); } num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value; } @@ -551,11 +556,10 @@ __global__ void group_idx_and_topk_idx_kernel( warp_topk::WarpSelect - queue((int32_t)topk, -INFINITY); + queue((int32_t)topk, neg_inf()); int count_equalto_topkth_group = 0; - bool if_proceed_next_topk = - (topk_group_value != cuda::std::numeric_limits::min()); + bool if_proceed_next_topk = topk_group_value != neg_inf(); if (case_id < num_tokens && if_proceed_next_topk) { for (int i_group = 0; i_group < n_group; i_group++) { if ((group_scores[i_group] > topk_group_value) || @@ -565,10 +569,10 @@ __global__ void group_idx_and_topk_idx_kernel( for (int32_t i = lane_id; i < align_num_experts_per_group; i += WARP_SIZE) { T candidates = - (i < num_experts_per_group) && isfinite(cuda_cast( - scores_with_bias[offset + i])) + (i < num_experts_per_group) && + cuda::std::isfinite(scores_with_bias[offset + i]) ? scores_with_bias[offset + i] - : cuda::std::numeric_limits::min(); + : neg_inf(); queue.add(candidates, offset + i); } if (group_scores[i_group] == topk_group_value) { @@ -597,7 +601,8 @@ __global__ void group_idx_and_topk_idx_kernel( if (i < topk) { s_topk_value[i] = value; } - topk_sum += reduce(tile, cuda_cast(value), cg::plus()); + topk_sum += + cg::reduce(tile, cuda_cast(value), cg::plus()); } } diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index cd80bfda7dfde..53573ada86ba9 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -20,17 +20,7 @@ #include #include #include "../cuda_compat.h" - -#ifndef USE_ROCM - #include - #include - #include - using AddOp = cuda::std::plus; -#else - #include - #include - using AddOp = cub::Sum; -#endif +#include "../cub_helpers.h" #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? 
(a) : (b)) @@ -79,7 +69,7 @@ __launch_bounds__(TPB) __global__ threadData = max(static_cast(input[idx]), threadData); } - const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, CubMaxOp()); if (threadIdx.x == 0) { float_max = maxElem; @@ -94,7 +84,7 @@ __launch_bounds__(TPB) __global__ threadData += exp((static_cast(input[idx]) - float_max)); } - const auto Z = BlockReduce(tmpStorage).Reduce(threadData, AddOp()); + const auto Z = BlockReduce(tmpStorage).Reduce(threadData, CubAddOp()); if (threadIdx.x == 0) { diff --git a/csrc/ops.h b/csrc/ops.h index 7a176a5c00322..fd9c55b948959 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -92,6 +92,9 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double epsilon); +void poly_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, + torch::Tensor& bias, double epsilon); + void apply_repetition_penalties_(torch::Tensor& logits, const torch::Tensor& prompt_mask, const torch::Tensor& output_mask, @@ -119,24 +122,23 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, std::optional key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); -void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, - std::optional key, - int64_t head_size, torch::Tensor& cos_sin_cache, - bool is_neox, int64_t rot_dim, - torch::Tensor& cos_sin_cache_offsets); - void silu_and_mul(torch::Tensor& out, torch::Tensor& input); void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& scale); -#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ - (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) +#ifndef USE_ROCM void silu_and_mul_nvfp4_quant(torch::Tensor& out, torch::Tensor& output_block_scale, torch::Tensor& input, torch::Tensor& input_global_scale); #endif +void silu_mul_fp8_quant_deep_gemm_cuda( + const at::Tensor& input, // (E, T, 2*H) + const at::Tensor& counts, // (E) + at::Tensor& y_q, // (E, T, H) [OUT] + at::Tensor& y_s, // (E, T, H//group_size) [OUT] + int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens); void mul_and_silu(torch::Tensor& out, torch::Tensor& input); @@ -345,6 +347,8 @@ std::tuple allocate_shared_buffer_and_handle( int64_t open_mem_handle(torch::Tensor& mem_handle); void free_shared_buffer(int64_t buffer); +torch::Tensor hadacore_transform(torch::Tensor& x, bool inplace); + #ifdef USE_ROCM fptr_t init_custom_qr(int64_t rank, int64_t world_size, std::optional qr_max_size = std::nullopt); @@ -354,4 +358,4 @@ void qr_open_handles(fptr_t _fa, const std::vector& handles); void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half = false); int64_t qr_max_size(); -#endif \ No newline at end of file +#endif diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index 266f2a0667a24..b5645b33b9073 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -99,35 +99,6 @@ __global__ void rotary_embedding_kernel( token_idx, query_stride, key_stride, head_stride); } -template -__global__ void batched_rotary_embedding_kernel( - const int64_t* __restrict__ positions, // [batch_size, seq_len] or - // [num_tokens] - scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, - // head_size] or [num_tokens, num_heads, - // head_size] - scalar_t* __restrict__ 
key, // nullptr or - // [batch_size, seq_len, num_kv_heads, - // head_size] or [num_tokens, num_kv_heads, - // head_size] - const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // - // 2] - const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] - const int rot_dim, const int64_t query_stride, const int64_t key_stride, - const int64_t head_stride, const int num_heads, const int num_kv_heads, - const int head_size) { - // Each thread block is responsible for one token. - const int token_idx = blockIdx.x; - int64_t pos = positions[token_idx]; - int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx]; - const scalar_t* cache_ptr = - cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; - - apply_rotary_embedding( - query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, - token_idx, query_stride, key_stride, head_stride); -} - } // namespace vllm void rotary_embedding( @@ -211,96 +182,3 @@ void rotary_embedding( } }); } - -/* -Batched version of rotary embedding, pack multiple LoRAs together -and process in batched manner. -*/ -void batched_rotary_embedding( - torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] - torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or - // [num_tokens, num_heads * head_size] or - // [batch_size, seq_len, num_heads, head_size] or - // [num_tokens, num_heads, head_size] - std::optional - key, // null or - // [batch_size, seq_len, num_kv_heads * head_size] or - // [num_tokens, num_kv_heads * head_size] or - // [batch_size, seq_len, num_heads, head_size] or - // [num_tokens, num_heads, head_size] - int64_t head_size, - torch::Tensor& cos_sin_cache, // [max_position, rot_dim] - bool is_neox, int64_t rot_dim, - torch::Tensor& cos_sin_cache_offsets // [num_tokens] or [batch_size] -) { - // num_tokens = batch_size * seq_len - int64_t num_tokens = cos_sin_cache_offsets.size(0); - TORCH_CHECK( - positions.size(0) == num_tokens || positions.numel() == num_tokens, - "positions must have the same num_tokens or batch_size as " - "cos_sin_cache_offsets"); - - int positions_ndim = positions.dim(); - // Make sure num_tokens dim is consistent across positions, query, and key - TORCH_CHECK( - positions_ndim == 1 || positions_ndim == 2, - "positions must have shape [num_tokens] or [batch_size, seq_len]"); - if (positions_ndim == 1) { - TORCH_CHECK(query.size(0) == positions.size(0) && - (!key.has_value() || key->size(0) == positions.size(0)), - "query, key and positions must have the same number of tokens"); - } - if (positions_ndim == 2) { - TORCH_CHECK( - query.size(0) == positions.size(0) && - (!key.has_value() || key->size(0) == positions.size(0)) && - query.size(1) == positions.size(1) && - (!key.has_value() || key->size(1) == positions.size(1)), - "query, key and positions must have the same batch_size and seq_len"); - } - - // Make sure head_size is valid for query and key - int query_hidden_size = query.numel() / num_tokens; - int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0; - TORCH_CHECK(query_hidden_size % head_size == 0); - TORCH_CHECK(key_hidden_size % head_size == 0); - - // Make sure query and key have concistent number of heads - int num_heads = query_hidden_size / head_size; - int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads; - TORCH_CHECK(num_heads % num_kv_heads == 0); - - int seq_dim_idx = positions_ndim - 1; - int64_t query_stride = query.stride(seq_dim_idx); - int64_t key_stride = key.has_value() ? 
key->stride(seq_dim_idx) : 0; - // Determine head stride: for [*, heads, head_size] use stride of last dim; - // for flat [*, heads*head_size], heads blocks are contiguous of size - // head_size - int query_ndim = query.dim(); - int64_t head_stride = - (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size; - - dim3 grid(num_tokens); - dim3 block(std::min(num_heads * rot_dim / 2, 512)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { - if (is_neox) { - vllm::batched_rotary_embedding_kernel - <<>>( - positions.data_ptr(), query.data_ptr(), - key.has_value() ? key->data_ptr() : nullptr, - cos_sin_cache.data_ptr(), - cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, - key_stride, head_stride, num_heads, num_kv_heads, head_size); - } else { - vllm::batched_rotary_embedding_kernel - <<>>( - positions.data_ptr(), query.data_ptr(), - key.has_value() ? key->data_ptr() : nullptr, - cos_sin_cache.data_ptr(), - cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, - key_stride, head_stride, num_heads, num_kv_heads, head_size); - } - }); -} diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index 8bc2b9bff3d5a..9aa1411b4a25c 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -9,6 +9,26 @@ #include "quantization/fp8/common.cuh" +#include + +#ifndef USE_ROCM + #include + #include + #include +#else + #include + #include + #include + +typedef __hip_bfloat162 __nv_bfloat162; +typedef __hip_bfloat16 __nv_bfloat16; +typedef __hip_bfloat16_raw __nv_bfloat16_raw; + +typedef __hip_fp8_e4m3 __nv_fp8_e4m3; +typedef __hip_fp8x4_e4m3 __nv_fp8x4_e4m3; +#endif + +#include "core/registration.h" namespace vllm { template @@ -87,6 +107,336 @@ __global__ void act_and_mul_quant_kernel( } } } + +__device__ __forceinline__ float silu(float x) { + return (__fdividef(x, (1.f + expf(-x)))); +} + +__device__ __forceinline__ float2 silu2(float2 x) { + return make_float2(silu(x.x), silu(x.y)); +} + +#ifndef USE_ROCM +__device__ __forceinline__ float warp_max(float v) { + static constexpr unsigned FULL_MASK = 0xffffffffu; + for (int offset = 1; offset < WARP_SIZE; offset *= 2) { + v = fmaxf(v, __shfl_xor_sync(FULL_MASK, v, offset)); + } + return v; +} + +__device__ __forceinline__ __nv_bfloat16 warp_max(__nv_bfloat16 v) { + static constexpr unsigned FULL_MASK = 0xffffffffu; + for (int offset = 1; offset < WARP_SIZE; offset *= 2) { + v = __hmax(v, __shfl_xor_sync(FULL_MASK, v, offset)); + } + return v; +} +#endif + +template +__device__ __forceinline__ void cp_async4(T* _smem_ptr, const U* _glob_ptr) { +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + auto smem_ptr = reinterpret_cast(_smem_ptr); + auto glob_ptr = reinterpret_cast(_glob_ptr); + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +#else + _smem_ptr[0] = _glob_ptr[0]; +#endif +} + +__device__ __forceinline__ void cp_async_fence() { +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + asm volatile("cp.async.commit_group;\n" ::); +#else +#endif +} + +template +__device__ __forceinline__ void cp_async_wait() { +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +#else 
+#endif +} + +template <> +__device__ __forceinline__ void cp_async_wait<0>() { +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + asm volatile("cp.async.wait_all;\n" ::); +#else +#endif +} + +__device__ __forceinline__ float clip(float v, float mmin, float mmax) { +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + return fminf(mmax, fmaxf(v, mmin)); +#else +#endif +} + +__device__ __forceinline__ __nv_bfloat16 clip(__nv_bfloat16 v, + __nv_bfloat16 mmin, + __nv_bfloat16 mmax) { + return __hmin(mmax, __hmax(v, mmin)); +} + +__device__ __forceinline__ __nv_bfloat162 clip(__nv_bfloat162 v, + __nv_bfloat162 mmin, + __nv_bfloat162 mmax) { + return __hmin2(mmax, __hmax2(v, mmin)); +} + +// We use the following values for fp8 min/max: +// __nv_fp8_e4m3 = (-448, +448) +// __nv_fp8_e4m3uz = (-240.0, +240.0) +// It is currently assumed that only +template +constexpr __nv_bfloat16 get_fp8_max() { + static_assert(std::is_same_v || + std::is_same_v); + if constexpr (std::is_same_v) { + return __nv_bfloat16(__nv_bfloat16_raw{.x = 17376}); + } else { + return __nv_bfloat16(__nv_bfloat16_raw{.x = 17264}); + } +} + +template +constexpr __nv_bfloat16 get_fp8_min() { + static_assert(std::is_same_v || + std::is_same_v); + if constexpr (std::is_same_v) { + return __nv_bfloat16(__nv_bfloat16_raw{.x = 50144}); + } else { + return __nv_bfloat16(__nv_bfloat16_raw{.x = 50032}); + } +} +#ifndef USE_ROCM +template +__global__ void silu_mul_fp8_quant_deep_gemm_kernel( + const __nv_bfloat16* __restrict__ _input, fp8_type* __restrict__ _y_q, + float* __restrict__ _y_s, const int32_t* __restrict__ counts, + + // sizes + int H, int G, + + // strides (in elements) + Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e, + Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t, + Idx_t stride_ys_g, Idx_t stride_counts_e) { + static constexpr __nv_bfloat16 fp8_min = get_fp8_min(); + static constexpr __nv_bfloat16 fp8_max = get_fp8_max(); + // We assign EPS with its 16-bit unsigned counterpart to allow constexpr. + static constexpr __nv_bfloat16 EPS = (__nv_bfloat16_raw{.x = 11996}); + + // We pack 8 16-bit bfloat16 values into a 128-bit __int128_t. + static constexpr int32_t BFLOAT16_PER_GROUP = 8; + + // We split the shared memory in half, corresponding to gate and up matrices: + // [...gate_i, ...up_i] where 0 <= i < stages. + static constexpr int32_t S_NUM_128 = + 2u * (GROUP_SIZE / BFLOAT16_PER_GROUP) * NUM_WARPS * NUM_STAGES; + static constexpr auto THREAD_COUNT = NUM_WARPS * WARP_SIZE; + static constexpr int HALF_THREAD_COUNT = THREAD_COUNT / 2; + static constexpr int32_t S_NUM_64 = S_NUM_128 * 2; + __shared__ __int128_t __align__(16) s_buff_128[S_NUM_128]; + + const int32_t tid = threadIdx.x; + const int32_t warp_id = tid / WARP_SIZE; + const int32_t lane_id = tid % WARP_SIZE; + + auto s_buff_compute_32 = reinterpret_cast<__nv_bfloat162*>(s_buff_128); + + // block handles one (expert e, group g) + int32_t pid = blockIdx.x; + int32_t e = pid / G; + int32_t g = pid % G; + + const int32_t n_tokens = counts[e * stride_counts_e]; + + if (!n_tokens) { + return; // Exit ASAP. + } + + const Idx_t stride_i_t_128 = stride_i_t / 8u; + + int32_t n_tokens_lower, n_tokens_upper; + + // Each block i iterates over tokens of a slice of n_tokens = + // expert_counts[i], with the size of chunk being + // (n_tokens / NUM_PARALLEL_TOKENS) + residual, instead of + // updiv(n_tokens, NUM_PARALLEL_TOKENS) for better scheduling. 
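A note on the token partitioning used below: each (expert, group) block splits its expert's n_tokens across gridDim.y chunks so that the first `residual` chunks get one extra token, rather than padding every chunk to ceil(n_tokens / P). A host-side sketch of the same arithmetic (hypothetical name; the small-n_tokens case is handled by giving idle chunks an empty range, which the kernel expresses as an early return):

```cpp
// Host sketch of the [lower, upper) token range assigned to one chunk id.
#include <algorithm>
#include <utility>

std::pair<int, int> token_chunk_ref(int n_tokens, int num_parallel_tokens,
                                    int chunk_id) {
  if (n_tokens < num_parallel_tokens) {
    // One token per chunk; chunks beyond n_tokens get an empty range.
    return {std::min(chunk_id, n_tokens), std::min(chunk_id + 1, n_tokens)};
  }
  int chunk_size = n_tokens / num_parallel_tokens;
  int residual = n_tokens - chunk_size * num_parallel_tokens;
  auto calc_id = [&](int id) {
    return id < residual ? std::min(n_tokens, id * (chunk_size + 1))
                         : std::min(n_tokens, id * chunk_size + residual);
  };
  return {calc_id(chunk_id), calc_id(chunk_id + 1)};
}
```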
+ if (n_tokens < NUM_PARALLEL_TOKENS && blockIdx.y < n_tokens) { + // Specialize this, but it can likely be fused. + if (blockIdx.y >= NUM_PARALLEL_TOKENS) { + return; + } + n_tokens_lower = blockIdx.y; + n_tokens_upper = blockIdx.y + 1; + } else { + auto chunk_size = n_tokens / NUM_PARALLEL_TOKENS; + auto residual = n_tokens - chunk_size * NUM_PARALLEL_TOKENS; + auto calc_id = [&](int32_t id) { + if (id < residual) { + return min(n_tokens, id * (chunk_size + 1)); + } else { + return min(n_tokens, id * chunk_size + residual); + } + }; + n_tokens_lower = calc_id(blockIdx.y); + n_tokens_upper = calc_id(blockIdx.y + 1); + } + + if (n_tokens_lower >= n_tokens_upper) { + return; + } + + // We do calculations here, using constexpr wherever possible. + const Idx_t base_i = e * stride_i_e + NUM_WARPS * g * GROUP_SIZE * stride_i_h; + const Idx_t base_ys = e * stride_ys_e + NUM_WARPS * g * stride_ys_g; + const Idx_t base_yq = + e * stride_yq_e + NUM_WARPS * g * GROUP_SIZE * stride_yq_h; + Idx_t gate_off_128 = (base_i / static_cast(8u)); + auto input_128_ptr = reinterpret_cast(_input); + auto gate_128_ptr = input_128_ptr + gate_off_128 + (tid % HALF_THREAD_COUNT) + + stride_i_t_128 * n_tokens_lower; + auto up_128_ptr = gate_128_ptr + (H * stride_i_h) / 8u; + auto y_s_ptr = + _y_s + base_ys + warp_id * stride_ys_g + n_tokens_lower * stride_ys_t; + auto y_q_ptr = _y_q + base_yq + warp_id * GROUP_SIZE + + stride_yq_t * n_tokens_lower + 4 * lane_id; + int32_t t_load = n_tokens_lower, load_stage_id = 0; + auto s_buff_gate_load_128 = s_buff_128 + (tid % HALF_THREAD_COUNT); + auto s_buff_up_load_128 = s_buff_gate_load_128 + S_NUM_128 / 2u; + int32_t stage_offset{}; + + static constexpr int32_t LOAD_STAGE_SIZE = (NUM_WARPS * WARP_SIZE / 2); + static constexpr int32_t LOAD_STAGE_MOD = + NUM_STAGES * (NUM_WARPS * WARP_SIZE / 2); + + // Two halves of all threads in a block conduct global loads for gate and up, + // respectively. + auto load_and_advance_y_pred = [&] { + if (t_load < n_tokens_upper) { + auto s_gate_stage_128_staged_ptr = s_buff_gate_load_128 + stage_offset; + auto s_up_stage_128_staged_ptr = s_buff_up_load_128 + stage_offset; + + // It is very important that LOAD_STAGE_SIZE is constexpr to avoid + // unnecessary ALU ops. + stage_offset += LOAD_STAGE_SIZE; + stage_offset %= LOAD_STAGE_MOD; + + if (tid < HALF_THREAD_COUNT) { + cp_async4(s_gate_stage_128_staged_ptr, gate_128_ptr); + gate_128_ptr += stride_i_t_128; + } else { + cp_async4(s_up_stage_128_staged_ptr, up_128_ptr); + up_128_ptr += stride_i_t_128; + } + ++t_load; + ++load_stage_id; + } + // We fence even if there is nothing to load to simplify pipelining. + cp_async_fence(); + }; + + #pragma unroll + for (int i = 0; i < NUM_STAGES - 1; i++) { + load_and_advance_y_pred(); + } + + __int64_t* s_gate_ptr = reinterpret_cast<__int64_t*>( + s_buff_compute_32 + warp_id * (GROUP_SIZE / 2)) + + lane_id; + __int64_t* s_up_ptr = s_gate_ptr + S_NUM_64 / 2; + + static constexpr int32_t STAGE_SIZE = (GROUP_SIZE * NUM_WARPS) / 4u; + static constexpr int32_t STAGE_MOD = STAGE_SIZE * NUM_STAGES; + + int32_t compute_pipeline_offset_64 = 0; + + for (int32_t t = n_tokens_lower; t < n_tokens_upper; ++t) { + __nv_bfloat162 results_bf162[2]; + + cp_async_wait(); + __syncthreads(); + + // We double-buffer pipelined loads so that the next load will + // concurrently run with compute without overwrites.
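The load and compute cursors below walk a ring of NUM_STAGES shared-memory stages, with the load cursor primed NUM_STAGES - 1 stages ahead by the prologue. A toy CPU model of that bookkeeping (a sketch, not the kernel; the real code advances stage_offset and compute_pipeline_offset_64 by LOAD_STAGE_SIZE and STAGE_SIZE and wraps them with a modulo) shows why an in-flight load can never clobber the stage currently being consumed:

```cpp
#include <cassert>

int main() {
  constexpr int NUM_STAGES = 4;  // illustrative stage count
  int load_stage = 0;            // stage the next async load will write
  int compute_stage = 0;         // stage the next compute step will read

  // Prologue: prime NUM_STAGES - 1 loads, like the unrolled calls to
  // load_and_advance_y_pred() above.
  for (int i = 0; i < NUM_STAGES - 1; ++i)
    load_stage = (load_stage + 1) % NUM_STAGES;

  for (int t = 0; t < 32; ++t) {  // 32 "tokens"
    // (kernel: cp_async_wait<NUM_STAGES - 2>() guarantees the data for
    //  compute_stage has landed before it is read)
    int write_stage = load_stage;          // this iteration's load target
    load_stage = (load_stage + 1) % NUM_STAGES;
    assert(write_stage != compute_stage);  // the new load never overwrites
                                           // the stage being consumed
    compute_stage = (compute_stage + 1) % NUM_STAGES;
  }
  return 0;
}
```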
+ load_and_advance_y_pred(); + + auto s_gate_compute_64 = s_gate_ptr + compute_pipeline_offset_64; + auto s_up_compute_64 = s_up_ptr + compute_pipeline_offset_64; + + // STAGE_SIZE must also be constexpr! + compute_pipeline_offset_64 += STAGE_SIZE; + compute_pipeline_offset_64 %= STAGE_MOD; + + // Each thread loads (gate/up) 2X 4X bfloat16 values into registers. + __int64_t gate64 = *s_gate_compute_64; + __nv_bfloat162* s_gate_compute_32 = + reinterpret_cast<__nv_bfloat162*>(&gate64); + + __int64_t up64 = *s_up_compute_64; + __nv_bfloat162* s_up_compute_32 = reinterpret_cast<__nv_bfloat162*>(&up64); + + #pragma unroll + for (int i = 0; i < 2; i++) { + // For silu, we make sure that div is emitted. + float2 gate = silu2(__bfloat1622float2(s_gate_compute_32[i])); + results_bf162[i] = __float22bfloat162_rn(gate); + } + + #pragma unroll + for (int i = 0; i < 2; i++) { + results_bf162[i] = __hmul2(results_bf162[i], s_up_compute_32[i]); + } + + auto _y_max2 = + __hmax2(__habs2(results_bf162[0]), __habs2(results_bf162[1])); + + __nv_bfloat16 y_max_bf16 = __hmax(EPS, __hmax(_y_max2.x, _y_max2.y)); + + // An entire group is assigned to a single warp, so a simple warp reduce + // is used. + __nv_bfloat16 y_s = warp_max(y_max_bf16) / fp8_max; + + if constexpr (USE_UE8M0) { + y_s = hexp2(hceil(hlog2(y_s))); + } + + auto inv_y = __float2bfloat16_rn(1.f) / y_s; + + auto y_s2 = make_bfloat162(inv_y, inv_y); + + #pragma unroll + for (int32_t i = 0; i < 2; ++i) { + results_bf162[i] = + clip(__hmul2(results_bf162[i], y_s2), __bfloat162bfloat162(fp8_min), + __bfloat162bfloat162(fp8_max)); + } + + auto fp8x4 = __nv_fp8x4_e4m3(results_bf162[0], results_bf162[1]); + *reinterpret_cast<__nv_fp8x4_e4m3*>(y_q_ptr) = fp8x4; + y_q_ptr += stride_yq_t; + + if (lane_id == 0) { + *y_s_ptr = y_s; + y_s_ptr += stride_ys_t; + } + } +} +#endif + } // namespace vllm // Launch activation, gating, and quantize kernel. @@ -119,3 +469,117 @@ void silu_and_mul_quant(torch::Tensor& out, // [..., d] TORCH_CHECK(input.size(-1) % 2 == 0); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); } + +void silu_mul_fp8_quant_deep_gemm_cuda( + const at::Tensor& input, // (E, T, 2*H) + const at::Tensor& counts, // (E) + at::Tensor& y_q, // (E, T, H) [OUT] + at::Tensor& y_s, // (E, T, H//group_size) [OUT] + int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens) { +#ifndef USE_ROCM + // This kernel relies heavily on cp.async and fp8 support. + // This kernel currently only supports H % 128 == 0 and assumes a + // fixed GROUP_SIZE of 128. + TORCH_CHECK(input.dtype() == torch::kBFloat16); + TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn || + y_q.dtype() == torch::kFloat8_e4m3fnuz); + TORCH_CHECK(y_s.dtype() == torch::kFloat32); + TORCH_CHECK(input.size(-1) % 256 == 0); + + // Check that num_parallel_tokens is of power of 2 and between 1 and 64. 
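For reference, the per-group math above is: y = silu(gate) * up, a warp-wide max of |y| clamped below by EPS, a scale of y_max / fp8_max (the raw bf16 constants in get_fp8_max/get_fp8_min encode +-448 for e4m3 and +-240 for the fnuz variant), an optional rounding of the scale up to a power of two for UE8M0, then clipping and casting to fp8. A minimal CPU sketch in plain float math (illustrative eps and helper name; the real kernel works in bf16 and writes fp8):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

void silu_mul_quant_group_ref(const std::vector<float>& gate,
                              const std::vector<float>& up, bool use_ue8m0,
                              std::vector<float>& q,  // values to cast to fp8
                              float& scale) {
  const float fp8_max = 448.f;   // e4m3 max
  const float eps = 1e-10f;      // illustrative lower bound on the group max
  auto silu = [](float x) { return x / (1.f + std::exp(-x)); };

  std::vector<float> y(gate.size());
  float amax = eps;
  for (std::size_t i = 0; i < gate.size(); ++i) {
    y[i] = silu(gate[i]) * up[i];
    amax = std::max(amax, std::fabs(y[i]));
  }
  scale = amax / fp8_max;
  if (use_ue8m0)  // UE8M0: round the scale up to a power of two
    scale = std::exp2(std::ceil(std::log2(scale)));

  q.resize(y.size());
  for (std::size_t i = 0; i < y.size(); ++i)
    q[i] = std::clamp(y[i] / scale, -fp8_max, fp8_max);  // then cast to fp8
}
```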
+ TORCH_CHECK(1 <= num_parallel_tokens && num_parallel_tokens <= 64); + TORCH_CHECK(!(num_parallel_tokens & (num_parallel_tokens - 1))); + + using Idx_t = int64_t; + + Idx_t E = input.size(0); + Idx_t T = input.size(1); + Idx_t H = input.size(2) / 2; + Idx_t stride_i_e = input.stride(0); + Idx_t stride_i_t = input.stride(1); + Idx_t stride_i_h = input.stride(2); + Idx_t stride_yq_e = y_q.stride(0); + Idx_t stride_yq_t = y_q.stride(1); + Idx_t stride_yq_h = y_q.stride(2); + Idx_t stride_ys_e = y_s.stride(0); + Idx_t stride_ys_t = y_s.stride(1); + Idx_t stride_ys_g = y_s.stride(2); + + Idx_t stride_counts_e = counts.stride(0); + + static constexpr int GROUP_SIZE = 128; + + #define KERNEL_FN \ + if (use_ue8m0) { \ + vllm::silu_mul_fp8_quant_deep_gemm_kernel \ + <<>>( \ + reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \ + (fp8_t*)y_q.data_ptr(), y_s.data_ptr(), \ + reinterpret_cast(counts.data_ptr()), H, G, \ + stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \ + stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \ + stride_counts_e); \ + } else { \ + vllm::silu_mul_fp8_quant_deep_gemm_kernel \ + <<>>( \ + reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \ + (fp8_t*)y_q.data_ptr(), y_s.data_ptr(), \ + reinterpret_cast(counts.data_ptr()), H, G, \ + stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \ + stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \ + stride_counts_e); \ + } + + #define KERNEL_CALL_H \ + if (H % (4 * GROUP_SIZE) == 0) { \ + static constexpr int NUM_WARPS = 4; \ + populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \ + KERNEL_FN \ + } else { \ + static constexpr int NUM_WARPS = 1; \ + populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \ + KERNEL_FN \ + } + + #define KERNEL_CALL_TOP_LEVEL \ + if (num_parallel_tokens == 1) { \ + static constexpr int NUM_PARALLEL_TOKENS = 1; \ + KERNEL_CALL_H \ + } else if (num_parallel_tokens == 2) { \ + static constexpr int NUM_PARALLEL_TOKENS = 2; \ + KERNEL_CALL_H \ + } else if (num_parallel_tokens == 4) { \ + static constexpr int NUM_PARALLEL_TOKENS = 4; \ + KERNEL_CALL_H \ + } else if (num_parallel_tokens == 8) { \ + static constexpr int NUM_PARALLEL_TOKENS = 8; \ + KERNEL_CALL_H \ + } else if (num_parallel_tokens == 16) { \ + static constexpr int NUM_PARALLEL_TOKENS = 16; \ + KERNEL_CALL_H \ + } else if (num_parallel_tokens == 32) { \ + static constexpr int NUM_PARALLEL_TOKENS = 32; \ + KERNEL_CALL_H \ + } else if (num_parallel_tokens == 64) { \ + static constexpr int NUM_PARALLEL_TOKENS = 64; \ + KERNEL_CALL_H \ + } + + Idx_t G; + dim3 block, grid; + auto populate_launch_params = [&](int num_warps, int _num_parallel_tokens) { + G = H / Idx_t(group_size * num_warps); + grid = dim3(E * G, _num_parallel_tokens); + block = dim3(num_warps * WARP_SIZE); + }; + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + VLLM_DISPATCH_FP8_TYPES(y_q.scalar_type(), + "silu_mul_fp8_quant_deep_gemm_kernel", + [&] { KERNEL_CALL_TOP_LEVEL }); + +#endif +} diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index d8369108d0bd3..bcfde9fbcbbef 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -7,17 +7,10 @@ #include +#include "../../cub_helpers.h" #include "../../dispatch_utils.h" #include "../vectorization_utils.cuh" -#ifndef USE_ROCM - #include - #include -#else - 
#include - #include -#endif - static inline __device__ int8_t float_to_int8_rn(float x) { #ifdef USE_ROCM static constexpr auto i8_min = @@ -173,7 +166,7 @@ __global__ void dynamic_scaled_int8_quant_kernel( }); using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmp; - float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x); + float block_max = BlockReduce(tmp).Reduce(thread_max, CubMaxOp{}, blockDim.x); __shared__ float absmax; if (tid == 0) { absmax = block_max; diff --git a/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu b/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu index fdac47c425d61..2d1568b08651c 100644 --- a/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu +++ b/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu @@ -11,6 +11,7 @@ #include "core/registration.h" #include "cutlass/cutlass.h" +#include #include "cute/tensor.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" @@ -24,6 +25,8 @@ #include "cutlass_extensions/common.hpp" #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" +#include + namespace vllm::cutlass_w4a8 { using namespace cute; @@ -169,6 +172,11 @@ struct W4A8GemmKernel { int k = A.size(1); int n = B.size(1); + // safely cast group_size to int + TORCH_CHECK(group_size > 0 && group_size <= std::numeric_limits::max(), + "group_size out of supported range for int: ", group_size); + int const group_size_int = static_cast(group_size); + // Allocate output const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); auto device = A.device(); @@ -181,7 +189,7 @@ struct W4A8GemmKernel { auto A_ptr = static_cast(A.const_data_ptr()); auto B_ptr = static_cast(B.const_data_ptr()); auto D_ptr = static_cast(D.data_ptr()); - // can we avoid harcode the 8 here + // can we avoid hardcode the 8 here auto S_ptr = static_cast const*>( group_scales.const_data_ptr()); @@ -192,7 +200,7 @@ struct W4A8GemmKernel { cute::tile_to_shape(LayoutAtomQuant{}, shape_B); // strides - int const scale_k = cutlass::ceil_div(k, group_size); + int const scale_k = cutlass::ceil_div(k, group_size_int); StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); // Reverse stride here due to swap and transpose @@ -211,8 +219,8 @@ struct W4A8GemmKernel { using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments; MainloopArguments mainloop_arguments{ - B_ptr, layout_B_reordered, A_ptr, stride_A, - S_ptr, stride_S, group_size}; + B_ptr, layout_B_reordered, A_ptr, stride_A, + S_ptr, stride_S, group_size_int}; EpilogueArguments epilogue_arguments{ ChTokScalesEpilogue::prepare_args(channel_scales, token_scales), @@ -387,6 +395,71 @@ torch::Tensor pack_scale_fp8(torch::Tensor const& scales) { return packed_scales; } +/* + GPU-accelerated implementation of cutlass::unified_encode_int4b. + Constructs a lookup table in constant memory to map 8 bits + (two 4-bit values) at a time. Assumes memory is contiguous + and pointers are 16-byte aligned. 
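Since the GPU path below operates on 16-byte vectors, a plain byte-wise CPU reference is handy for testing it. This sketch (hypothetical name) applies the same nibble remap, 1..7 -> 8 - v with 0 and 8..15 left unchanged, through the identical 256-entry LUT, with no alignment or size requirements:

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <vector>

// CPU oracle for the byte remap that unified_encode_int4b_device performs;
// packed int4 data is treated as raw bytes, two nibbles per byte.
std::vector<uint8_t> unified_encode_ref(const std::vector<uint8_t>& in) {
  auto map_nib = [](uint8_t v) -> uint8_t {
    return (v == 0 || (v & 0x8)) ? v : uint8_t(8 - v);  // 1..7 -> 8 - v
  };
  std::array<uint8_t, 256> lut{};
  for (int b = 0; b < 256; ++b)
    lut[b] = uint8_t((map_nib(uint8_t((b >> 4) & 0xF)) << 4) |
                     map_nib(uint8_t(b & 0xF)));
  std::vector<uint8_t> out(in.size());
  for (std::size_t i = 0; i < in.size(); ++i) out[i] = lut[in[i]];
  return out;
}
```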
+*/ +__constant__ uint8_t kNibbleLUT[256]; + +__global__ void unified_encode_int4b_device(const uint8_t* in, uint8_t* out, + size_t nbytes) { + constexpr size_t V = sizeof(uint4); // 16 bytes + const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const size_t nthreads = size_t(gridDim.x) * blockDim.x; + const size_t nvec = nbytes / V; + + // 1-D grid-stride loop over 16-byte chunks + for (size_t vec = tid; vec < nvec; vec += nthreads) { + uint4 v = reinterpret_cast(in)[vec]; + uint8_t* b = reinterpret_cast(&v); +#pragma unroll + for (int i = 0; i < int(V); ++i) b[i] = kNibbleLUT[b[i]]; + reinterpret_cast(out)[vec] = v; + } +} + +static bool upload_lut() { + std::array lut{}; + auto map_nib = [](uint8_t v) -> uint8_t { + // 1..7 -> (8 - v); keep 0 and 8..15 + return (v == 0 || (v & 0x8)) ? v : uint8_t(8 - v); + }; + for (int b = 0; b < 256; ++b) { + uint8_t lo = b & 0xF; + uint8_t hi = (b >> 4) & 0xF; + lut[b] = uint8_t((map_nib(hi) << 4) | map_nib(lo)); + } + cudaError_t e = cudaMemcpyToSymbol(kNibbleLUT, lut.data(), lut.size(), + /*offset=*/0, cudaMemcpyHostToDevice); + + return (e == cudaSuccess); +} + +static bool unified_encode_int4b(cutlass::int4b_t const* in, + cutlass::int4b_t* out, size_t num_int4_elems) { + // Build/upload LUT + if (!upload_lut()) return false; + + static_assert(sizeof(typename cutlass::int4b_t::Storage) == 1, + "int4 storage must be 1 byte"); + const size_t nbytes = num_int4_elems >> 1; + + auto* in_bytes = reinterpret_cast(in); + auto* out_bytes = reinterpret_cast(out); + + // kernel launch params + constexpr int block = 256; + const size_t nvec = nbytes / sizeof(uint4); // # of 16B vectors + int grid = int((nvec + block - 1) / block); + if (grid == 0) grid = 1; // ensure we still cover the tail in the kernel + + unified_encode_int4b_device<<>>(in_bytes, out_bytes, nbytes); + cudaError_t err = cudaGetLastError(); + return (err == cudaSuccess); +} + torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) { TORCH_CHECK(B.dtype() == torch::kInt32); TORCH_CHECK(B.dim() == 2); @@ -395,6 +468,7 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) { int k = B.size(0) * PackFactor; // logical k int n = B.size(1); + TORCH_CHECK((n * k) % 32 == 0, "need multiples of 32 int4s for 16B chunks"); auto B_ptr = static_cast(B.const_data_ptr()); auto B_packed_ptr = static_cast(B_packed.data_ptr()); @@ -403,7 +477,9 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) { LayoutB_Reordered layout_B_reordered = cute::tile_to_shape(LayoutAtomQuant{}, shape_B); - cutlass::unified_encode_int4b(B_ptr, B_packed_ptr, n * k); + bool ok = + vllm::cutlass_w4a8::unified_encode_int4b(B_ptr, B_packed_ptr, n * k); + TORCH_CHECK(ok, "unified_encode_int4b failed"); cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered); return B_packed; diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh index c841125dbb734..dbf79a0651159 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh @@ -14,9 +14,6 @@ #include "cutlass/epilogue/dispatch_policy.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass_extensions/gemm/dispatch_policy.hpp" -#include "cutlass_extensions/gemm/collective/collective_builder.hpp" - #include "cutlass_gemm_caller.cuh" namespace vllm { @@ -149,6 +146,7 @@ void 
cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, using ElementAB = typename Gemm::ElementAB; using ElementD = typename Gemm::ElementD; + using ElementBlockScale = typename Gemm::ElementBlockScale; int32_t m = a.size(0), n = b.size(1), k = a.size(1); @@ -169,26 +167,29 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1)) : ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1)); - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - auto a_scales_ptr = static_cast(a_scales.data_ptr()); - auto b_scales_ptr = static_cast(b_scales.data_ptr()); + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto a_scales_ptr = static_cast(a_scales.data_ptr()); + auto b_scales_ptr = static_cast(b_scales.data_ptr()); - auto mainloop_args = [&](){ - // layout_SFA and layout_SFB cannot be swapped since they are deduced. - if (swap_ab) { - return typename GemmKernel::MainloopArguments{ - b_ptr, b_stride, a_ptr, a_stride, - b_scales_ptr, layout_SFA, a_scales_ptr, layout_SFB - }; - } - else { - return typename GemmKernel::MainloopArguments{ - a_ptr, a_stride, b_ptr, b_stride, - a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB - }; - } - }(); + typename GemmKernel::MainloopArguments mainloop_args{}; + mainloop_args.layout_SFA = layout_SFA; + mainloop_args.layout_SFB = layout_SFB; + if (swap_ab) { + mainloop_args.ptr_A = b_ptr; + mainloop_args.dA = b_stride; + mainloop_args.ptr_B = a_ptr; + mainloop_args.dB = a_stride; + mainloop_args.ptr_SFA = b_scales_ptr; + mainloop_args.ptr_SFB = a_scales_ptr; + } else { + mainloop_args.ptr_A = a_ptr; + mainloop_args.dA = a_stride; + mainloop_args.ptr_B = b_ptr; + mainloop_args.dB = b_stride; + mainloop_args.ptr_SFA = a_scales_ptr; + mainloop_args.ptr_SFB = b_scales_ptr; + } auto prob_shape = swap_ab ? 
cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1); auto c_ptr = static_cast(out.data_ptr()); diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh index d50a83ae1cd48..811741aee58b3 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh @@ -14,9 +14,6 @@ #include "cutlass/epilogue/dispatch_policy.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass_extensions/gemm/dispatch_policy.hpp" -#include "cutlass_extensions/gemm/collective/collective_builder.hpp" - #include "cutlass_gemm_caller.cuh" namespace vllm { @@ -128,6 +125,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, using ElementAB = typename Gemm::ElementAB; using ElementD = typename Gemm::ElementD; + using ElementBlockScale = typename Gemm::ElementBlockScale; int32_t m = a.size(0), n = b.size(1), k = a.size(1); @@ -146,17 +144,20 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, LayoutSFB layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1)); - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - auto a_scales_ptr = static_cast(a_scales.data_ptr()); - auto b_scales_ptr = static_cast(b_scales.data_ptr()); + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto a_scales_ptr = static_cast(a_scales.data_ptr()); + auto b_scales_ptr = static_cast(b_scales.data_ptr()); - auto mainloop_args = [&](){ - return typename GemmKernel::MainloopArguments{ - a_ptr, a_stride, b_ptr, b_stride, - a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB - }; - }(); + typename GemmKernel::MainloopArguments mainloop_args{}; + mainloop_args.ptr_A = a_ptr; + mainloop_args.dA = a_stride; + mainloop_args.ptr_B = b_ptr; + mainloop_args.dB = b_stride; + mainloop_args.ptr_SFA = a_scales_ptr; + mainloop_args.layout_SFA = layout_SFA; + mainloop_args.ptr_SFB = b_scales_ptr; + mainloop_args.layout_SFB = layout_SFB; auto prob_shape = cute::make_shape(m, n, k, 1); auto c_ptr = static_cast(out.data_ptr()); diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh index e089c3d4be2cc..147eb8efc0778 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh @@ -13,27 +13,18 @@ #include "cutlass/epilogue/dispatch_policy.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass_extensions/gemm/dispatch_policy.hpp" -#include "cutlass_extensions/gemm/collective/collective_builder.hpp" - #include "cutlass_gemm_caller.cuh" namespace vllm { using namespace cute; -template > +// clang-format off +template struct cutlass_3x_gemm_fp8_blockwise { - using GroupSizeM = Int; - using GroupSizeN = Int; - using GroupSizeK = Int; - using TileSizeM = Int; - - static_assert(TileSizeM_ % GroupSizeM_ == 0, - "TileSizeM must be a multiple of GroupSizeM"); - using ElementAB = cutlass::float_e4m3_t; using ElementA = ElementAB; @@ -45,52 +36,67 @@ struct cutlass_3x_gemm_fp8_blockwise { static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; using ElementD = OutType; - using StrideD = Stride, Int<0>>; + using LayoutD = 
cutlass::layout::RowMajor; static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; - using ElementC = void; - using StrideC = StrideD; + using ElementC = void; // TODO: support bias + using LayoutC = LayoutD; static constexpr int AlignmentC = AlignmentD; using ElementAccumulator = float; - using ElementBlockScale = float; using ElementCompute = float; + using ElementBlockScale = float; + + using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig< + ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>; + + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); + using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; - using TileShape = Shape; - using KernelSchedule = cutlass::gemm:: - KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum< - GroupSizeM_>; - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; - using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + using ElementScalar = float; + using DefaultOperation = cutlass::epilogue::fusion::LinearCombination; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + MmaTileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + AlignmentC, + ElementD, + LayoutD, + AlignmentD, + EpilogueScheduler, + DefaultOperation + >::CollectiveOp; - using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT< - cutlass::epilogue::fusion::Sm90AccFetch>; - - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, - ElementAccumulator, ElementCompute, ElementC, StrideC, AlignmentC, - ElementD, StrideD, AlignmentD, EpilogueSchedule, - StoreEpilogueCompute>::CollectiveOp; - - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, - LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout( - sizeof(typename CollectiveEpilogue::SharedStorage))>, - KernelSchedule>::CollectiveOp; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementA, + cute::tuple, + AlignmentA, + ElementB, + cute::tuple, + AlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopScheduler + >::CollectiveOp; using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, - SchedulerType>>; + Shape, CollectiveMainloop, CollectiveEpilogue>>; struct GemmKernel : public KernelType {}; - - using StrideA = typename GemmKernel::StrideA; - using StrideB = typename GemmKernel::StrideB; }; template @@ -99,76 +105,58 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& a_scales, torch::Tensor const& b_scales) { using GemmKernel = typename Gemm::GemmKernel; + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideD = typename Gemm::GemmKernel::StrideD; + using StrideC = typename Gemm::GemmKernel::StrideC; + using 
LayoutSFA = typename Gemm::LayoutSFA; + using LayoutSFB = typename Gemm::LayoutSFB; + using ScaleConfig = typename Gemm::ScaleConfig; using ElementAB = typename Gemm::ElementAB; using ElementD = typename Gemm::ElementD; + using ElementBlockScale = typename Gemm::ElementBlockScale; - auto prob_shape = c3x::get_problem_shape(a, b); - int32_t m = get<0>(prob_shape), n = get<1>(prob_shape), - k = get<2>(prob_shape); + int32_t m = a.size(0), n = b.size(1), k = a.size(1); - int64_t lda = a.stride(0); - int64_t ldb = b.stride(1); - int64_t ldc = out.stride(0); + TORCH_CHECK(m % 4 == 0, "m must be divisible by 4"); - using StrideA = Stride, int64_t>; - using StrideB = Stride, int64_t>; - using StrideC = typename Gemm::StrideC; + StrideA a_stride; + StrideB b_stride; + StrideC c_stride; + a_stride = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); + b_stride = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); + c_stride = + cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1)); - StrideA a_stride{lda, Int<1>{}, 0}; - StrideB b_stride{ldb, Int<1>{}, 0}; - StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + LayoutSFA layout_SFA = + ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1)); + LayoutSFB layout_SFB = + ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1)); - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - auto a_scales_ptr = static_cast(a_scales.data_ptr()); - auto b_scales_ptr = static_cast(b_scales.data_ptr()); + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto a_scales_ptr = static_cast(a_scales.data_ptr()); + auto b_scales_ptr = static_cast(b_scales.data_ptr()); - // Check is the t is contiguous and is 1D or 2D with one of the dimensions - // being 1 (i.e. 
a row or column vector) - auto is_contiguous_vector = [](const torch::Tensor& t) { - auto t_sizes = t.sizes(); - return t.is_contiguous() && - (t.dim() == 1 || - (t.dim() == 2 && - *std::min_element(t_sizes.begin(), t_sizes.end()) == 1)); - }; - - // TODO(lucas): lets clean-up the kernel so that we pass in Strides so - // we don't have to deal with enforcing implicit layouts - TORCH_CHECK(a_scales.size(0) == m / Gemm::GroupSizeM::value); - TORCH_CHECK(a_scales.size(1) == k / Gemm::GroupSizeK::value); - TORCH_CHECK(a_scales.stride(0) == 1 || is_contiguous_vector(a_scales), - "a_scales must be M major"); - TORCH_CHECK(b_scales.size(0) == k / Gemm::GroupSizeK::value); - TORCH_CHECK(b_scales.size(1) == n / Gemm::GroupSizeN::value); - TORCH_CHECK(b_scales.stride(0) == 1 || is_contiguous_vector(b_scales), - "b_scales must be K major"); - typename GemmKernel::MainloopArguments mainloop_args{ - a_ptr, a_stride, b_ptr, b_stride, a_scales_ptr, b_scales_ptr}; + typename GemmKernel::MainloopArguments mainloop_args{}; + mainloop_args.ptr_A = a_ptr; + mainloop_args.dA = a_stride; + mainloop_args.ptr_B = b_ptr; + mainloop_args.dB = b_stride; + mainloop_args.ptr_SFA = a_scales_ptr; + mainloop_args.layout_SFA = layout_SFA; + mainloop_args.ptr_SFB = b_scales_ptr; + mainloop_args.layout_SFB = layout_SFB; + auto prob_shape = cute::make_shape(m, n, k, 1); auto c_ptr = static_cast(out.data_ptr()); typename GemmKernel::EpilogueArguments epilogue_args{ {}, c_ptr, c_stride, c_ptr, c_stride}; - - typename GemmKernel::TileSchedulerArguments scheduler; - - static constexpr bool UsesStreamKScheduler = - cute::is_same_v; - - if constexpr (UsesStreamKScheduler) { - using DecompositionMode = typename cutlass::gemm::kernel::detail:: - PersistentTileSchedulerSm90StreamKParams::DecompositionMode; - using ReductionMode = typename cutlass::gemm::kernel::detail:: - PersistentTileSchedulerSm90StreamKParams::ReductionMode; - - scheduler.decomposition_mode = DecompositionMode::StreamK; - scheduler.reduction_mode = ReductionMode::Nondeterministic; - } - c3x::cutlass_gemm_caller(a.device(), prob_shape, mainloop_args, - epilogue_args, scheduler); + epilogue_args); } template @@ -177,18 +165,12 @@ void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales) { - auto k = a.size(1); - auto n = b.size(1); - - if (k > 3 * n) { - cutlass_gemm_caller_blockwise>( - out, a, b, a_scales, b_scales); - } else { - cutlass_gemm_caller_blockwise>( - out, a, b, a_scales, b_scales); - } + // TODO: better heuristics + cutlass_gemm_caller_blockwise, + Shape<_1, _2, _1>, cutlass::epilogue::TmaWarpSpecializedCooperative, + cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum>>( + out, a, b, a_scales, b_scales); } } // namespace vllm \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp index 2ee6a19407f92..3af59267bd60c 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp @@ -32,7 +32,7 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a, TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor."); TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor."); int32_t version_num = get_sm_version_num(); - if (version_num >= 100) { + if (version_num >= 90) { TORCH_CHECK( a.size(0) == a_scales.size(0) && cuda_utils::ceil_div(a.size(1), int64_t(128)) == 
a_scales.size(1), @@ -41,32 +41,6 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a, cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) && cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1), "b_scale_group_shape must be [128, 128]."); - } else { - // TODO: Remove this after using cutlass sm90 blockwise scaling gemm - // kernel, or introducing ceil_div to the load_init() of mainloop. - using GroupShape = std::array; - auto make_group_shape = [](torch::Tensor const& x, - torch::Tensor const& s) -> GroupShape { - TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D"); - return {cuda_utils::ceil_div(x.size(0), s.size(0)), - cuda_utils::ceil_div(x.size(1), s.size(1))}; - }; - - GroupShape a_scale_group_shape = make_group_shape(a, a_scales); - GroupShape b_scale_group_shape = make_group_shape(b, b_scales); - - // 1x128 per-token group scales for activations - // 128x128 blockwise scales for weights - TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} && - b_scale_group_shape == GroupShape{128, 128} && - a.dtype() == torch::kFloat8_e4m3fn && - b.dtype() == torch::kFloat8_e4m3fn), - "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n" - "a_scale_group_shape must be [1, 128]. Got: [", - a_scale_group_shape[0], ", ", a_scale_group_shape[1], - "]\n" - "b_scale_group_shape must be [128, 128]. Got: [", - b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]"); } TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm"); diff --git a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu index 9bbeb0334fb9a..74fde23782ce5 100644 --- a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu +++ b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu @@ -26,271 +26,47 @@ #include "dispatch_utils.h" #include "cuda_utils.h" +#include "nvfp4_utils.cuh" namespace vllm { -// Get type2 from type or vice versa (applied to half and bfloat16) -template -struct TypeConverter { - using Type = half2; -}; // keep for generality - -template <> -struct TypeConverter { - using Type = c10::Half; -}; - -template <> -struct TypeConverter { - using Type = half2; -}; - -template <> -struct TypeConverter<__nv_bfloat162> { - using Type = c10::BFloat16; -}; - -template <> -struct TypeConverter { - using Type = __nv_bfloat162; -}; - -#define ELTS_PER_THREAD 8 - -constexpr int CVT_FP4_ELTS_PER_THREAD = 8; -constexpr int CVT_FP4_SF_VEC_SIZE = 16; - -// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). -inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), - "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); - return val; -#else - return 0; -#endif +// silu in float32 +__device__ __forceinline__ float silu(float x) { + return __fdividef(x, (1.f + __expf(-x))); } -// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). 
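// A minimal host-side reference sketch for the float32 silu defined above:
// silu(x) = x * sigmoid(x) = x / (1 + exp(-x)). It is mathematically the same
// function as the tanh-based half2 formulation it replaces, since
// sigmoid(x) = 0.5f * (1.0f + tanhf(0.5f * x)). The names silu_ref and
// silu_mul_ref below are illustrative only, not part of the kernels in this diff.
#include <cmath>
inline float silu_ref(float x) { return x / (1.0f + std::exp(-x)); }
inline float silu_mul_ref(float gate, float up) { return silu_ref(gate) * up; }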
-inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), - "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); - return val; -#else - return 0; -#endif +__device__ __forceinline__ float2 silu2(float2 x) { + return make_float2(silu(x.x), silu(x.y)); } -// Fast reciprocal. -inline __device__ float reciprocal_approximate_ftz(float a) { - float b; - asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); - return b; -} - -template -__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, - int numCols, - SFType* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || - CVT_FP4_NUM_THREADS_PER_SF == 2); - - // One pair of threads write one SF to global memory. - // TODO: stage through smem for packed STG.32 - // is it better than STG.8 from 4 threads ? - if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { - // SF vector index (16 elements share one SF in the K dimension). - int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; - int32_t mIdx = rowIdx; - - // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] - // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] - - int32_t mTileIdx = mIdx / (32 * 4); - // SF vector size 16. - int factor = CVT_FP4_SF_VEC_SIZE * 4; - int32_t numKTiles = (numCols + factor - 1) / factor; - int64_t mTileStride = numKTiles * 32 * 4 * 4; - - int32_t kTileIdx = (kIdx / 4); - int64_t kTileStride = 32 * 4 * 4; - - // M tile layout [32, 4] is column-major. - int32_t outerMIdx = (mIdx % 32); - int64_t outerMStride = 4 * 4; - - int32_t innerMIdx = (mIdx % (32 * 4)) / 32; - int64_t innerMStride = 4; - - int32_t innerKIdx = (kIdx % 4); - int64_t innerKStride = 1; - - // Compute the global offset. - int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + - outerMIdx * outerMStride + innerMIdx * innerMStride + - innerKIdx * innerKStride; - - return reinterpret_cast(SFout) + SFOffset; - } -#endif - return nullptr; -} - -// Define a 16 bytes packed data type. 
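// The cvt.rn.satfinite.e2m1x2.f32 PTX used by fp32_vec_to_e2m1 rounds each
// float to a 4-bit e2m1 value (1 sign, 2 exponent, 1 mantissa bit), whose
// representable magnitudes are only {0, 0.5, 1, 1.5, 2, 3, 4, 6}; "satfinite"
// clamps anything larger to +/-6. A scalar reference sketch of that mapping,
// with tie-breaking simplified (the hardware rounds to nearest even):
#include <cmath>
inline float round_to_e2m1_ref(float x) {
  static const float kGrid[8] = {0.f, 0.5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f};
  float mag = std::fabs(x);
  float best = kGrid[0];
  for (float g : kGrid) {
    if (std::fabs(mag - g) < std::fabs(mag - best)) best = g;
  }
  return std::copysign(best, x);
}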
-template -struct PackedVec { - typename TypeConverter::Type elts[4]; -}; - -template <> -struct PackedVec<__nv_fp8_e4m3> { - __nv_fp8x2_e4m3 elts[8]; -}; - template -__inline__ __device__ PackedVec compute_silu(PackedVec& vec, - PackedVec& vec2) { +__inline__ __device__ PackedVec compute_silu_mul(PackedVec& vec, + PackedVec& vec2) { PackedVec result; + using packed_type = typename TypeConverter::Type; + #pragma unroll for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) { - if constexpr (std::is_same_v) { - half2 val(0.5f, 0.5f); - half2 t0 = __hmul2(vec.elts[i], val); - half2 t1 = __hfma2(h2tanh(t0), val, val); - half2 t2 = __hmul2(vec.elts[i], t1); - result.elts[i] = __hmul2(t2, vec2.elts[i]); + // silu_mul in float32 + if constexpr (std::is_same_v) { + float2 silu_vec = silu2(__half22float2(vec.elts[i])); + result.elts[i] = + __float22half2_rn(__fmul2_rn(silu_vec, __half22float2(vec2.elts[i]))); } else { - __nv_bfloat162 val(0.5f, 0.5f); - __nv_bfloat162 t0 = __hmul2(vec.elts[i], val); - __nv_bfloat162 t1 = __hfma2(h2tanh(t0), val, val); - __nv_bfloat162 t2 = __hmul2(vec.elts[i], t1); - result.elts[i] = __hmul2(t2, vec2.elts[i]); + float2 silu_vec = silu2(__bfloat1622float2(vec.elts[i])); + result.elts[i] = __float22bfloat162_rn( + __fmul2_rn(silu_vec, __bfloat1622float2(vec2.elts[i]))); } } return result; } -// Quantizes the provided PackedVec into the uint32_t output -template -__device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec& vec, - PackedVec& vec2, - float SFScaleVal, - uint8_t* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - PackedVec out_silu = compute_silu(vec, vec2); - // Get absolute maximum values among the local 8 values. - auto localMax = __habs2(out_silu.elts[0]); - - // Local maximum value. - #pragma unroll - for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { - localMax = __hmax2(localMax, __habs2(out_silu.elts[i])); - } - - // Get the absolute maximum among all 16 values (two threads). - localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); - // Get the final absolute maximum values. - float vecMax = float(__hmax(localMax.x, localMax.y)); - - // Get the SF (max value of the vector / max value of e2m1). - // maximum value of e2m1 = 6.0. - // TODO: use half as compute data type. - float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); - // 8 bits representation of the SF. - uint8_t fp8SFVal; - // Write the SF to global memory (STG.8). - if constexpr (UE8M0_SF) { - // Extract the 8 exponent bits from float32. - // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits. - uint32_t tmp = reinterpret_cast(SFValue) >> 23; - fp8SFVal = tmp & 0xff; - // Convert back to fp32. - reinterpret_cast(SFValue) = tmp << 23; - } else { - // Here SFValue is always positive, so E4M3 is the same as UE4M3. - __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); - reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; - // Convert back to fp32. - SFValue = float(tmp); - } - // Get the output scale. - // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * - // reciprocal(SFScaleVal)) - float outputScale = - SFValue != 0 ? reciprocal_approximate_ftz( - SFValue * reciprocal_approximate_ftz(SFScaleVal)) - : 0.0f; - - if (SFout) { - // Write the SF to global memory (STG.8). - *SFout = fp8SFVal; - } - - // Convert the input to float. 
- float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; - - #pragma unroll - for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { - if constexpr (std::is_same_v) { - fp2Vals[i] = __half22float2(out_silu.elts[i]); - } else { - fp2Vals[i] = __bfloat1622float2(out_silu.elts[i]); - } - fp2Vals[i].x *= outputScale; - fp2Vals[i].y *= outputScale; - } - - // Convert to e2m1 values. - uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); - - // Write the e2m1 values to global memory. - return e2m1Vec; -#else - return 0; -#endif -} - // Use UE4M3 by default. template -__global__ void -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -__launch_bounds__(1024, 4) silu_and_cvt_fp16_to_fp4( -#else -silu_and_cvt_fp16_to_fp4( -#endif - int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, - uint32_t* out, uint32_t* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__global__ void __launch_bounds__(1024, 4) + silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, + float const* SFScale, uint32_t* out, + uint32_t* SFout) { using PackedVec = PackedVec; static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); @@ -316,34 +92,39 @@ silu_and_cvt_fp16_to_fp4( // Get the output tensor offset. // Same as inOffset because 8 elements are packed into one uint32_t. int64_t outOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; - ; auto& out_pos = out[outOffset]; + // Compute silu and mul + PackedVec out_silu_mul = compute_silu_mul(in_vec, in_vec2); + auto sf_out = cvt_quant_to_fp4_get_sf_out_offset( rowIdx, colIdx, numCols, SFout); - out_pos = silu_and_cvt_warp_fp16_to_fp4( - in_vec, in_vec2, SFScaleVal, sf_out); + out_pos = cvt_warp_fp16_to_fp4(out_silu_mul, SFScaleVal, + sf_out); } } -#endif } } // namespace vllm -void silu_and_mul_nvfp4_quant(torch::Tensor& output, // [..., d] - torch::Tensor& output_sf, - torch::Tensor& input, // [..., 2 * d] - torch::Tensor& input_sf) { - TORCH_CHECK(input.dtype() == torch::kFloat16 || - input.dtype() == torch::kBFloat16); +void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d] + torch::Tensor& output_sf, + torch::Tensor& input, // [..., 2 * d] + torch::Tensor& input_sf) { int32_t m = input.size(0); int32_t n = input.size(1) / 2; + TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); + TORCH_CHECK(input.scalar_type() == at::ScalarType::Half || + input.scalar_type() == at::ScalarType::BFloat16, + "Unsupported input data type for quantize_to_fp4."); + int multiProcessorCount = get_device_attribute(cudaDevAttrMultiProcessorCount, -1); + auto input_sf_ptr = static_cast(input_sf.data_ptr()); auto sf_out = static_cast(output_sf.data_ptr()); auto output_ptr = static_cast(output.data_ptr()); @@ -352,17 +133,14 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& output, // [..., d] dim3 block(std::min(int(n / ELTS_PER_THREAD), 1024)); int const numBlocksPerSM = 2048 / block.x; dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM)); + VLLM_DISPATCH_HALF_TYPES( - input.scalar_type(), "act_and_mul_quant_kernel", [&] { - auto input_ptr = reinterpret_cast(input.data_ptr()); - VLLM_DISPATCH_BYTE_TYPES( - output.scalar_type(), "fused_act_and_mul_quant_kernel_nvfp4_type", - [&] { - vllm::silu_and_cvt_fp16_to_fp4 - <<>>( - m, n, input_ptr, input_sf_ptr, - reinterpret_cast(output_ptr), - reinterpret_cast(sf_out)); - }); + input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + auto input_ptr = 
static_cast(input.data_ptr()); + vllm::silu_mul_cvt_fp16_to_fp4<<>>( + m, n, input_ptr, input_sf_ptr, + reinterpret_cast(output_ptr), + reinterpret_cast(sf_out)); }); } diff --git a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu index 03db5cc196d59..2c8df6144bf4d 100644 --- a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu +++ b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu @@ -1,3 +1,19 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #include #include diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/quantization/fp4/nvfp4_experts_quant.cu index 190d66f318a83..ce3ba2c19b9eb 100644 --- a/csrc/quantization/fp4/nvfp4_experts_quant.cu +++ b/csrc/quantization/fp4/nvfp4_experts_quant.cu @@ -1,247 +1,42 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #include +#include +#include + #include #include -#include #include +#include "dispatch_utils.h" -template -struct TypeConverter { - using Type = half2; -}; // keep for generality +#include "nvfp4_utils.cuh" -template <> -struct TypeConverter { - using Type = half; -}; - -template <> -struct TypeConverter { - using Type = half2; -}; - -template <> -struct TypeConverter<__nv_bfloat162> { - using Type = __nv_bfloat16; -}; - -template <> -struct TypeConverter<__nv_bfloat16> { - using Type = __nv_bfloat162; -}; - -#define ELTS_PER_THREAD 8 - -constexpr int CVT_FP4_ELTS_PER_THREAD = 8; -constexpr int CVT_FP4_SF_VEC_SIZE = 16; - -// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). -inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), - "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); - return val; -#else - return 0; -#endif -} - -// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). 
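// The rewritten launchers in this diff (silu_and_mul_nvfp4_quant_sm1xxa above,
// and the experts/standard fp4 quant entry points below) replace per-dtype
// switch statements with VLLM_DISPATCH_HALF_TYPES plus vllm::CUDATypeConverter,
// which maps the torch scalar type onto the matching CUDA type. A minimal
// sketch of that mapping with illustrative names (the real converter is defined
// in nvfp4_utils.cuh later in this diff):
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <c10/util/Half.h>
#include <c10/util/BFloat16.h>
#include <type_traits>
template <typename T> struct cuda_type_of { using type = T; };
template <> struct cuda_type_of<c10::Half> { using type = half; };
template <> struct cuda_type_of<c10::BFloat16> { using type = __nv_bfloat16; };
static_assert(std::is_same_v<cuda_type_of<c10::Half>::type, half>);
static_assert(std::is_same_v<cuda_type_of<c10::BFloat16>::type, __nv_bfloat16>);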
-inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), - "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); - return val; -#else - return 0; -#endif -} - -// Fast reciprocal. -inline __device__ float reciprocal_approximate_ftz(float a) { - float b; - asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); - return b; -} - -template -__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, - int numCols, - SFType* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || - CVT_FP4_NUM_THREADS_PER_SF == 2); - - // One pair of threads write one SF to global memory. - // TODO: stage through smem for packed STG.32 - // is it better than STG.8 from 4 threads ? - if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { - // SF vector index (16 elements share one SF in the K dimension). - int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; - int32_t mIdx = rowIdx; - - // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] - // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] - - int32_t mTileIdx = mIdx / (32 * 4); - // SF vector size 16. - int factor = CVT_FP4_SF_VEC_SIZE * 4; - int32_t numKTiles = (numCols + factor - 1) / factor; - int64_t mTileStride = numKTiles * 32 * 4 * 4; - - int32_t kTileIdx = (kIdx / 4); - int64_t kTileStride = 32 * 4 * 4; - - // M tile layout [32, 4] is column-major. - int32_t outerMIdx = (mIdx % 32); - int64_t outerMStride = 4 * 4; - - int32_t innerMIdx = (mIdx % (32 * 4)) / 32; - int64_t innerMStride = 4; - - int32_t innerKIdx = (kIdx % 4); - int64_t innerKStride = 1; - - // Compute the global offset. - int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + - outerMIdx * outerMStride + innerMIdx * innerMStride + - innerKIdx * innerKStride; - - return reinterpret_cast(SFout) + SFOffset; - } -#endif - return nullptr; -} - -// Define a 16 bytes packed data type. -template -struct PackedVec { - typename TypeConverter::Type elts[4]; -}; - -template <> -struct PackedVec<__nv_fp8_e4m3> { - __nv_fp8x2_e4m3 elts[8]; -}; - -// Quantizes the provided PackedVec into the uint32_t output -template -__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, - uint8_t* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - // Get absolute maximum values among the local 8 values. - auto localMax = __habs2(vec.elts[0]); - - // Local maximum value. - #pragma unroll - for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { - localMax = __hmax2(localMax, __habs2(vec.elts[i])); - } - - // Get the absolute maximum among all 16 values (two threads). - localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); - // Get the final absolute maximum values. - float vecMax = float(__hmax(localMax.x, localMax.y)); - - // Get the SF (max value of the vector / max value of e2m1). - // maximum value of e2m1 = 6.0. - // TODO: use half as compute data type. 
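// Worked example of the scale recipe computed just below (assuming a global
// scale SFScaleVal = 1.0 and a per-16-element absmax vecMax = 3.0, values
// chosen only for illustration):
//   SFValue     = 1.0 * 3.0 / 6.0 = 0.5      (exactly representable as e4m3)
//   outputScale = 1 / (0.5 / 1.0) = 2.0
//   element 3.0 -> 3.0 * 2.0 = 6.0 -> e2m1 6.0 (the format's maximum)
//   element 0.7 -> 0.7 * 2.0 = 1.4 -> e2m1 1.5 -> dequantizes to 1.5 * 0.5 / 1.0 = 0.75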
- float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); - // 8 bits representation of the SF. - uint8_t fp8SFVal; - // Write the SF to global memory (STG.8). - if constexpr (UE8M0_SF) { - // Extract the 8 exponent bits from float32. - // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits. - uint32_t tmp = reinterpret_cast(SFValue) >> 23; - fp8SFVal = tmp & 0xff; - // Convert back to fp32. - reinterpret_cast(SFValue) = tmp << 23; - } else { - // Here SFValue is always positive, so E4M3 is the same as UE4M3. - __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); - reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; - // Convert back to fp32. - SFValue = float(tmp); - } - // Get the output scale. - // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * - // reciprocal(SFScaleVal)) - float outputScale = - SFValue != 0 ? reciprocal_approximate_ftz( - SFValue * reciprocal_approximate_ftz(SFScaleVal)) - : 0.0f; - - if (SFout) { - // Write the SF to global memory (STG.8). - *SFout = fp8SFVal; - } - - // Convert the input to float. - float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; - - #pragma unroll - for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { - if constexpr (std::is_same_v) { - fp2Vals[i] = __half22float2(vec.elts[i]); - } else { - fp2Vals[i] = __bfloat1622float2(vec.elts[i]); - } - fp2Vals[i].x *= outputScale; - fp2Vals[i].y *= outputScale; - } - - // Convert to e2m1 values. - uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); - - // Write the e2m1 values to global memory. - return e2m1Vec; -#else - return 0; -#endif -} +namespace vllm { // Use UE4M3 by default. template -__global__ void -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -__launch_bounds__(512, 4) cvt_fp16_to_fp4( -#else -cvt_fp16_to_fp4( -#endif - int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, - uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts, - uint32_t* output_scale_offset_by_experts, int n_experts, bool low_latency) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__global__ void __launch_bounds__(512, 4) + cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, + float const* SFScale, uint32_t* out, uint32_t* SFout, + uint32_t* input_offset_by_experts, + uint32_t* output_scale_offset_by_experts, int n_experts, + bool low_latency) { using PackedVec = PackedVec; static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); @@ -299,8 +94,8 @@ cvt_fp16_to_fp4( &input_offset_by_experts[chunk_start + 12])); local_offsets[16] = __ldca(&input_offset_by_experts[chunk_start + 16]); - // Check against the 16 loaded offsets - #pragma unroll +// Check against the 16 loaded offsets +#pragma unroll for (int i = 0; i < 16; i++) { if (rowIdx >= local_offsets[i] && rowIdx < local_offsets[i + 1]) { rowIdx_in_expert = rowIdx - local_offsets[i]; @@ -330,21 +125,15 @@ cvt_fp16_to_fp4( out_pos = cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); } -#endif } // Kernel for LARGE_M_TOPK = true (large m_topk optimized version) template -__global__ void -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -__launch_bounds__(1024, 4) cvt_fp16_to_fp4( -#else -cvt_fp16_to_fp4( -#endif - int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, - uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts, - uint32_t* output_scale_offset_by_experts, int n_experts) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__global__ void __launch_bounds__(1024, 4) + cvt_fp16_to_fp4(int32_t 
numRows, int32_t numCols, Type const* in, + float const* SFScale, uint32_t* out, uint32_t* SFout, + uint32_t* input_offset_by_experts, + uint32_t* output_scale_offset_by_experts, int n_experts) { using PackedVec = PackedVec; static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); @@ -425,7 +214,6 @@ cvt_fp16_to_fp4( out_pos = cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); } -#endif } template @@ -501,6 +289,8 @@ void quant_impl(void* output, void* output_scale, void* input, } } +} // namespace vllm + /*Quantization entry for fp4 experts quantization*/ #define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor") #define CHECK_CONTIGUOUS(x, m) \ @@ -560,23 +350,17 @@ void scaled_fp4_experts_quant_sm100a( // 4 means 4 fp8 values are packed into one int32 TORCH_CHECK(output_scale.size(1) * 4 == padded_k); - auto in_dtype = input.dtype(); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device()); - if (in_dtype == at::ScalarType::Half) { - quant_impl(output.data_ptr(), output_scale.data_ptr(), - input.data_ptr(), input_global_scale.data_ptr(), - input_offset_by_experts.data_ptr(), - output_scale_offset_by_experts.data_ptr(), m_topk, k, - n_experts, stream); - } else if (in_dtype == at::ScalarType::BFloat16) { - quant_impl<__nv_bfloat16>(output.data_ptr(), output_scale.data_ptr(), - input.data_ptr(), input_global_scale.data_ptr(), - input_offset_by_experts.data_ptr(), - output_scale_offset_by_experts.data_ptr(), m_topk, - k, n_experts, stream); - } else { - TORCH_CHECK(false, "Expected input data type to be half or bfloat16"); - } + + VLLM_DISPATCH_HALF_TYPES( + input.scalar_type(), "nvfp4_experts_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + vllm::quant_impl( + output.data_ptr(), output_scale.data_ptr(), input.data_ptr(), + input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(), + output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts, + stream); + }); } diff --git a/csrc/quantization/fp4/nvfp4_quant_entry.cu b/csrc/quantization/fp4/nvfp4_quant_entry.cu index 1b61bd4519fc3..c2b39e5438805 100644 --- a/csrc/quantization/fp4/nvfp4_quant_entry.cu +++ b/csrc/quantization/fp4/nvfp4_quant_entry.cu @@ -32,6 +32,14 @@ void scaled_fp4_experts_quant_sm100a( torch::Tensor const& output_scale_offset_by_experts); #endif +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) +void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, + torch::Tensor& output_sf, + torch::Tensor& input, + torch::Tensor& input_sf); +#endif + void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf) { #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ @@ -54,3 +62,13 @@ void scaled_fp4_experts_quant( TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel"); } + +void silu_and_mul_nvfp4_quant(torch::Tensor& output, torch::Tensor& output_sf, + torch::Tensor& input, torch::Tensor& input_sf) { +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) + return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf); +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, "No compiled silu_and_mul nvfp4 quantization kernel"); +} diff --git a/csrc/quantization/fp4/nvfp4_quant_kernels.cu 
b/csrc/quantization/fp4/nvfp4_quant_kernels.cu index 4e080de151648..0c1b9ef0664d7 100644 --- a/csrc/quantization/fp4/nvfp4_quant_kernels.cu +++ b/csrc/quantization/fp4/nvfp4_quant_kernels.cu @@ -23,245 +23,18 @@ #include #include +#include "dispatch_utils.h" #include "cuda_utils.h" +#include "nvfp4_utils.cuh" -// Get type2 from type or vice versa (applied to half and bfloat16) -template -struct TypeConverter { - using Type = half2; -}; // keep for generality - -template <> -struct TypeConverter { - using Type = half; -}; - -template <> -struct TypeConverter { - using Type = half2; -}; - -template <> -struct TypeConverter<__nv_bfloat162> { - using Type = __nv_bfloat16; -}; - -template <> -struct TypeConverter<__nv_bfloat16> { - using Type = __nv_bfloat162; -}; - -#define ELTS_PER_THREAD 8 - -constexpr int CVT_FP4_ELTS_PER_THREAD = 8; -constexpr int CVT_FP4_SF_VEC_SIZE = 16; - -// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). -inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), - "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); - return val; -#else - return 0; -#endif -} - -// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). -inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), - "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); - return val; -#else - return 0; -#endif -} - -// Fast reciprocal. -inline __device__ float reciprocal_approximate_ftz(float a) { - float b; - asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); - return b; -} - -template -__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, - int numCols, - SFType* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || - CVT_FP4_NUM_THREADS_PER_SF == 2); - - // One pair of threads write one SF to global memory. - // TODO: stage through smem for packed STG.32 - // is it better than STG.8 from 4 threads ? - if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { - // SF vector index (16 elements share one SF in the K dimension). - int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; - int32_t mIdx = rowIdx; - - // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] - // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] - - int32_t mTileIdx = mIdx / (32 * 4); - // SF vector size 16. 
- int factor = CVT_FP4_SF_VEC_SIZE * 4; - int32_t numKTiles = (numCols + factor - 1) / factor; - int64_t mTileStride = numKTiles * 32 * 4 * 4; - - int32_t kTileIdx = (kIdx / 4); - int64_t kTileStride = 32 * 4 * 4; - - // M tile layout [32, 4] is column-major. - int32_t outerMIdx = (mIdx % 32); - int64_t outerMStride = 4 * 4; - - int32_t innerMIdx = (mIdx % (32 * 4)) / 32; - int64_t innerMStride = 4; - - int32_t innerKIdx = (kIdx % 4); - int64_t innerKStride = 1; - - // Compute the global offset. - int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + - outerMIdx * outerMStride + innerMIdx * innerMStride + - innerKIdx * innerKStride; - - return reinterpret_cast(SFout) + SFOffset; - } -#endif - return nullptr; -} - -// Define a 16 bytes packed data type. -template -struct PackedVec { - typename TypeConverter::Type elts[4]; -}; - -template <> -struct PackedVec<__nv_fp8_e4m3> { - __nv_fp8x2_e4m3 elts[8]; -}; - -// Quantizes the provided PackedVec into the uint32_t output -template -__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, - uint8_t* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - // Get absolute maximum values among the local 8 values. - auto localMax = __habs2(vec.elts[0]); - - // Local maximum value. - #pragma unroll - for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { - localMax = __hmax2(localMax, __habs2(vec.elts[i])); - } - - // Get the absolute maximum among all 16 values (two threads). - localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); - // Get the final absolute maximum values. - float vecMax = float(__hmax(localMax.x, localMax.y)); - - // Get the SF (max value of the vector / max value of e2m1). - // maximum value of e2m1 = 6.0. - // TODO: use half as compute data type. - float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); - // 8 bits representation of the SF. - uint8_t fp8SFVal; - // Write the SF to global memory (STG.8). - if constexpr (UE8M0_SF) { - // Extract the 8 exponent bits from float32. - // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits. - uint32_t tmp = reinterpret_cast(SFValue) >> 23; - fp8SFVal = tmp & 0xff; - // Convert back to fp32. - reinterpret_cast(SFValue) = tmp << 23; - } else { - // Here SFValue is always positive, so E4M3 is the same as UE4M3. - __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); - reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; - // Convert back to fp32. - SFValue = float(tmp); - } - // Get the output scale. - // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * - // reciprocal(SFScaleVal)) - float outputScale = - SFValue != 0 ? reciprocal_approximate_ftz( - SFValue * reciprocal_approximate_ftz(SFScaleVal)) - : 0.0f; - - if (SFout) { - // Write the SF to global memory (STG.8). - *SFout = fp8SFVal; - } - - // Convert the input to float. - float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; - - #pragma unroll - for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { - if constexpr (std::is_same_v) { - fp2Vals[i] = __half22float2(vec.elts[i]); - } else { - fp2Vals[i] = __bfloat1622float2(vec.elts[i]); - } - fp2Vals[i].x *= outputScale; - fp2Vals[i].y *= outputScale; - } - - // Convert to e2m1 values. - uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); - - // Write the e2m1 values to global memory. - return e2m1Vec; -#else - return 0; -#endif -} +namespace vllm { // Use UE4M3 by default. 
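// The scale-factor swizzle removed above (and re-added in nvfp4_utils.cuh
// later in this diff) stores one e4m3 scale per 16 elements in a
// [numMTiles, numKTiles, 32, 4, 4] layout. A host-side restatement of the same
// offset arithmetic, with illustrative names (row is the token/row index,
// sf_vec_idx is the 16-element scale-vector index along K, num_cols is the
// element count per row):
#include <cstdint>
inline int64_t sf_offset_ref(int64_t row, int64_t sf_vec_idx, int64_t num_cols) {
  int64_t num_k_tiles = (num_cols + 63) / 64;  // 4 scale vectors of 16 elements per K tile
  int64_t m_tile = row / 128;                  // 128 rows (32 * 4) per M tile
  int64_t k_tile = sf_vec_idx / 4;
  int64_t outer_m = row % 32;
  int64_t inner_m = (row % 128) / 32;
  int64_t inner_k = sf_vec_idx % 4;
  return ((m_tile * num_k_tiles + k_tile) * 32 + outer_m) * 16 + inner_m * 4 + inner_k;
}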
template -__global__ void -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -__launch_bounds__(512, 4) cvt_fp16_to_fp4( -#else -cvt_fp16_to_fp4( -#endif - int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, - uint32_t* out, uint32_t* SFout) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__global__ void __launch_bounds__(512, 4) + cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, + float const* SFScale, uint32_t* out, uint32_t* SFout) { using PackedVec = PackedVec; static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); @@ -293,7 +66,6 @@ cvt_fp16_to_fp4( cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); } } -#endif } template @@ -332,6 +104,8 @@ template void invokeFP4Quantization(int m, int n, __nv_bfloat16 const* input, int multiProcessorCount, cudaStream_t stream); +} // namespace vllm + void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, torch::Tensor const& input, torch::Tensor const& output_sf, @@ -340,6 +114,9 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, int32_t n = input.size(1); TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); + TORCH_CHECK(input.scalar_type() == at::ScalarType::Half || + input.scalar_type() == at::ScalarType::BFloat16, + "Unsupported input data type for quantize_to_fp4."); int multiProcessorCount = get_device_attribute(cudaDevAttrMultiProcessorCount, -1); @@ -353,24 +130,10 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, // We don't support e8m0 scales at this moment. bool useUE8M0 = false; - switch (input.scalar_type()) { - case torch::kHalf: { - auto input_ptr = reinterpret_cast(input.data_ptr()); - invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out, - useUE8M0, multiProcessorCount, stream); - break; - } - case torch::kBFloat16: { - auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr()); - invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out, - useUE8M0, multiProcessorCount, stream); - break; - } - default: { - std::cerr << "Observing: " << input.scalar_type() - << " for the input datatype which is invalid"; - throw std::runtime_error( - "Unsupported input data type for quantize_to_fp4."); - } - } + VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + auto input_ptr = static_cast(input.data_ptr()); + vllm::invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, + sf_out, useUE8M0, multiProcessorCount, stream); + }); } diff --git a/csrc/quantization/fp4/nvfp4_utils.cuh b/csrc/quantization/fp4/nvfp4_utils.cuh new file mode 100644 index 0000000000000..48e4959de9793 --- /dev/null +++ b/csrc/quantization/fp4/nvfp4_utils.cuh @@ -0,0 +1,251 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include + +#define ELTS_PER_THREAD 8 + +constexpr int CVT_FP4_ELTS_PER_THREAD = 8; +constexpr int CVT_FP4_SF_VEC_SIZE = 16; + +namespace vllm { + +// Convert PyTorch cpp type to CUDA type +template +struct CUDATypeConverter { + using Type = T; +}; + +template <> +struct CUDATypeConverter { + using Type = half; +}; + +template <> +struct CUDATypeConverter { + using Type = __nv_bfloat16; +}; + +// Get type2 from type or vice versa (applied to half and bfloat16) +template +struct TypeConverter { + using Type = half2; +}; // keep for generality + +template <> +struct TypeConverter { + using Type = half; +}; + +template <> +struct TypeConverter { + using Type = half2; +}; + +template <> +struct TypeConverter<__nv_bfloat162> { + using Type = __nv_bfloat16; +}; + +template <> +struct TypeConverter<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; + +// Define a 16 bytes packed data type. +template +struct PackedVec { + typename TypeConverter::Type elts[4]; +}; + +template <> +struct PackedVec<__nv_fp8_e4m3> { + __nv_fp8x2_e4m3 elts[8]; +}; + +// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), + "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); + return val; +} + +// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), + "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); + return val; +} + +// Fast reciprocal. +inline __device__ float reciprocal_approximate_ftz(float a) { + float b; + asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); + return b; +} + +template +__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, + int numCols, + SFType* SFout) { + static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || + CVT_FP4_NUM_THREADS_PER_SF == 2); + + // One pair of threads write one SF to global memory. + // TODO: stage through smem for packed STG.32 + // is it better than STG.8 from 4 threads ? + if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { + // SF vector index (16 elements share one SF in the K dimension). + int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; + int32_t mIdx = rowIdx; + + // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] + // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] + + int32_t mTileIdx = mIdx / (32 * 4); + // SF vector size 16. 
+ int factor = CVT_FP4_SF_VEC_SIZE * 4; + int32_t numKTiles = (numCols + factor - 1) / factor; + int64_t mTileStride = numKTiles * 32 * 4 * 4; + + int32_t kTileIdx = (kIdx / 4); + int64_t kTileStride = 32 * 4 * 4; + + // M tile layout [32, 4] is column-major. + int32_t outerMIdx = (mIdx % 32); + int64_t outerMStride = 4 * 4; + + int32_t innerMIdx = (mIdx % (32 * 4)) / 32; + int64_t innerMStride = 4; + + int32_t innerKIdx = (kIdx % 4); + int64_t innerKStride = 1; + + // Compute the global offset. + int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + + outerMIdx * outerMStride + innerMIdx * innerMStride + + innerKIdx * innerKStride; + + return reinterpret_cast(SFout) + SFOffset; + } + return nullptr; +} + +// Quantizes the provided PackedVec into the uint32_t output +template +__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, + uint8_t* SFout) { + // Get absolute maximum values among the local 8 values. + auto localMax = __habs2(vec.elts[0]); + +// Local maximum value. +#pragma unroll + for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + localMax = __hmax2(localMax, __habs2(vec.elts[i])); + } + + // Get the absolute maximum among all 16 values (two threads). + localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + // Get the final absolute maximum values. + float vecMax = float(__hmax(localMax.x, localMax.y)); + + // Get the SF (max value of the vector / max value of e2m1). + // maximum value of e2m1 = 6.0. + // TODO: use half as compute data type. + float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); + // 8 bits representation of the SF. + uint8_t fp8SFVal; + // Write the SF to global memory (STG.8). + if constexpr (UE8M0_SF) { + // Extract the 8 exponent bits from float32. + // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits. + uint32_t tmp = reinterpret_cast(SFValue) >> 23; + fp8SFVal = tmp & 0xff; + // Convert back to fp32. + reinterpret_cast(SFValue) = tmp << 23; + } else { + // Here SFValue is always positive, so E4M3 is the same as UE4M3. + __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); + reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; + // Convert back to fp32. + SFValue = float(tmp); + } + // Get the output scale. + // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * + // reciprocal(SFScaleVal)) + float outputScale = + SFValue != 0 ? reciprocal_approximate_ftz( + SFValue * reciprocal_approximate_ftz(SFScaleVal)) + : 0.0f; + + if (SFout) { + // Write the SF to global memory (STG.8). + *SFout = fp8SFVal; + } + + // Convert the input to float. + float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; + +#pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same_v) { + fp2Vals[i] = __half22float2(vec.elts[i]); + } else { + fp2Vals[i] = __bfloat1622float2(vec.elts[i]); + } + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + // Convert to e2m1 values. + uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); + + // Write the e2m1 values to global memory. 
+ return e2m1Vec; +} + +} // namespace vllm diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index 5fe5dd04bd891..45d6d5082ce49 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -1,15 +1,10 @@ #include "common.cuh" #include "dispatch_utils.h" +#include "../../cub_helpers.h" #include "../vectorization_utils.cuh" #include #include -#ifndef USE_ROCM - #include -#else - #include -#endif - namespace vllm { template @@ -116,7 +111,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided( using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmp; const float block_max = - BlockReduce(tmp).Reduce(absmax_val, cub::Max{}, blockDim.x); + BlockReduce(tmp).Reduce(absmax_val, CubMaxOp{}, blockDim.x); __shared__ float token_scale; if (tid == 0) { diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/fp8/common.cuh index 1aad6330c44b8..7838f211c59db 100644 --- a/csrc/quantization/fp8/common.cuh +++ b/csrc/quantization/fp8/common.cuh @@ -5,7 +5,9 @@ #include -#ifdef USE_ROCM +#ifndef USE_ROCM + #include "nvidia/quant_utils.cuh" +#else #include "amd/quant_utils.cuh" #endif @@ -48,7 +50,9 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val, float r = fmaxf(-quant_type_max_v, fminf(x, quant_type_max_v)); #ifndef USE_ROCM - return static_cast(r); + // Use hardware cvt instruction for fp8 on nvidia + // Currently only support fp8_type = c10::Float8_e4m3fn + return fp8::vec_conversion(r); #else // Use hardware cvt instruction for fp8 on rocm return fp8::cvt_c10(r); diff --git a/csrc/quantization/fp8/nvidia/quant_utils.cuh b/csrc/quantization/fp8/nvidia/quant_utils.cuh index f8cd1dcba4ab3..5b9c2df8468cb 100644 --- a/csrc/quantization/fp8/nvidia/quant_utils.cuh +++ b/csrc/quantization/fp8/nvidia/quant_utils.cuh @@ -12,13 +12,26 @@ namespace vllm { namespace fp8 { #ifdef ENABLE_FP8 - #if 0 // Disable the following code to reduce the binary size. template -__inline__ __device__ Tout -vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) { +__inline__ __device__ Tout vec_conversion( + const Tin& x, const __nv_fp8_interpretation_t fp8_type = __NV_E4M3) { return x; } +// float -> c10::Float8_e4m3fn +template <> +__inline__ __device__ c10::Float8_e4m3fn +vec_conversion( + const float& a, const __nv_fp8_interpretation_t fp8_type) { + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return static_cast(a); + #else + return c10::Float8_e4m3fn(__nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type), + c10::Float8_e4m3fn::from_bits()); + #endif +} + + #if 0 // Disable the following code to reduce the binary size. 
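// The fp8 and fused-layernorm reductions in this diff now pull CubMaxOp{} /
// CubAddOp{} from "../../cub_helpers.h" instead of using cub::Max{} /
// cub::Sum{} directly. One plausible shape for that header (an assumption;
// the actual helper may differ) is a pair of portable functor aliases that
// cover both CUDA and ROCm builds:
#ifndef USE_ROCM
  #include <cub/cub.cuh>
using CubMaxOp = cub::Max;
using CubAddOp = cub::Sum;
#else
  #include <hipcub/hipcub.hpp>
using CubMaxOp = hipcub::Max;
using CubAddOp = hipcub::Sum;
#endif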
// fp8 -> half template <> __inline__ __device__ uint16_t vec_conversion( diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index 3f188872d80d3..2d2fd771205c7 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -8,11 +8,7 @@ #include "quantization/utils.cuh" #include "quant_conversions.cuh" -#ifndef USE_ROCM - #include -#else - #include -#endif +#include "../../cub_helpers.h" namespace vllm { @@ -36,7 +32,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x); + ss = BlockReduce(reduceStore).Reduce(ss, CubAddOp{}, blockDim.x); __shared__ float s_rms; if (threadIdx.x == 0) { @@ -73,7 +69,7 @@ __device__ void compute_dynamic_per_token_scales( __shared__ typename BlockReduce::TempStorage reduceStore; block_absmax_val_maybe = BlockReduce(reduceStore) - .Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x); + .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); __shared__ float s_token_scale; if (threadIdx.x == 0) { @@ -169,7 +165,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x); + ss = BlockReduce(reduceStore).Reduce(ss, CubAddOp{}, blockDim.x); __shared__ float s_rms; if (threadIdx.x == 0) { @@ -240,7 +236,7 @@ __device__ void compute_dynamic_per_token_scales( __shared__ typename BlockReduce::TempStorage reduceStore; block_absmax_val_maybe = BlockReduce(reduceStore) - .Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x); + .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); __shared__ float s_token_scale; if (threadIdx.x == 0) { diff --git a/csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu b/csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu new file mode 100644 index 0000000000000..5369d409f9b21 --- /dev/null +++ b/csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu @@ -0,0 +1,817 @@ +// clang-format off +// Adapted from: https://github.com/meta-pytorch/applied-ai/blob/main/kernels/cuda/inference/hadamard_transform/hadamard_transform_cuda.cu + +/*********** +Copyright 2024 Meta + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +***********/ + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "core/registration.h" +#include "dispatch_utils.h" + +namespace hadacore { + +#ifndef __CUDACC__ +#define __launch_bounds__(x,y) +#endif + +#define MAX_WARPS_PER_SM 48 + +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + +using b16 = uint16_t; +using b32 = uint32_t; + +constexpr int launch_configs_big[7][3] = { + // default + {2, 1, 24}, + {2, 2, 16}, + {2, 4, 8}, + {2, 8, 4}, + {2, 16, 3}, + {4, 16, 2}, + {8, 16, 1} + // // extra coalescing + // {2, 1, 24}, + // {2, 2, 16}, + // {2, 4, 8}, + // {2, 8, 4}, + // {4, 8, 3}, + // {8, 8, 2}, + // {16, 8, 1} + // // less coalescing + // {2, 1, 24}, + // {2, 2, 16}, + // {2, 4, 8}, + // {2, 8, 4}, + // {1, 32, 1}, + // {2, 32, 1}, + // {4, 32, 1} +}; + +// a 4x2, b 2x2, c 2x2 +template +__device__ __forceinline__ void mma_m16_n8_k16_b16_b16_b16_noacc(b32 a0, b32 a1, b32 a2, b32 a3, b32 b0, b32 b1, b32& c0, b32& c1){ + static_assert(dtype == torch::ScalarType::Half || dtype == torch::ScalarType::BFloat16); + // d, a, b, c + b32 zero = 0; + if constexpr(dtype == torch::ScalarType::Half) { + asm ( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n\t" + : "=r"(c0), "=r"(c1) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "r"(zero), "r"(zero) + ); + } else { + b32 temp0, temp1, temp2, temp3; + asm ( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n\t" + : "=r"(temp0), "=r"(temp1), "=r"(temp2), "=r"(temp3) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "r"(zero), "r"(zero), "r"(zero), "r"(zero) + ); + asm ("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c0) : "r"(temp1), "r"(temp0)); + asm ("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c1) : "r"(temp3), "r"(temp2)); + } +} + +// a 4x2, b 4x2, c 4x2 +template +__device__ __forceinline__ void mma_m16_n16_k16_b16_b16_b16_noacc(b32 a0, b32 a1, b32 a2, b32 a3, b32 b0, b32 b1, b32 b2, b32 b3, b32& c0, b32& c1, b32& c2, b32& c3){ + mma_m16_n8_k16_b16_b16_b16_noacc(a0, a1, a2, a3, b0, b1, c0, c1); + mma_m16_n8_k16_b16_b16_b16_noacc(a0, a1, a2, a3, b2, b3, c2, c3); +} + +__device__ __forceinline__ void matrix_transpose_m8_n8_b16_inplace(b32& a0) { + asm ( + "movmatrix.sync.aligned.m8n8.trans.b16 " + "%0, %1;\n\t" + : "=r"(a0) : "r"(a0) + ); +} + +#define p_p(i) ((val_1p[i] & 0x0000FFFF) | val_1p[i] << 16) +#define p_n(i) ((val_1p[i] & 0x0000FFFF) | val_1n[i] << 16) +#define n_p(i) ((val_1n[i] & 0x0000FFFF) | val_1p[i] << 16) +#define n_n(i) ((val_1n[i] & 0x0000FFFF) | val_1n[i] << 16) + +template +__global__ void __launch_bounds__(32 * warps_per_block, blocks_per_sm) +// a is column major, b is row major +hadamard_transform_kernel(b16* a, b16* out, int total_num_chunks) { + static_assert(dtype == torch::ScalarType::Half || dtype == torch::ScalarType::BFloat16, "Only fp16 and bf16 supported currently"); + + b32 b_frag_all[num_chunks][4]; // for all chunks, 
holds matrix fragment (which takes 4 regs of b16x2 * 32 threads) + + int64_t blockid = blockIdx.x * warps_per_block + threadIdx.x / 32; + int64_t threadid = threadIdx.x % 32; + extern __shared__ b32 bfrag_arr[]; // num_chunks * warps_per_block * 128 + int64_t real_num_chunks = ((blockid + 1) * num_chunks) > total_num_chunks ? (total_num_chunks - (blockid * num_chunks)) : num_chunks; + int64_t diff_num_chunks = real_num_chunks - num_chunks; + + b32* a_start_ptr = (b32*) (a + blockid * num_chunks * 256); // offset a to where this warp starts + b32* out_start_ptr = (b32*) (out + blockid * num_chunks * 256); + b32* a_ptr = a_start_ptr + threadid * 4; + b32* b_frag_ptr = bfrag_arr + (blockid % warps_per_block) * num_chunks * 128 + threadid * 4; + + #if (__CUDA_ARCH__ < 900) // SM80, SM89 + uint64_t cache_policy; + asm volatile( + "createpolicy.fractional.L2::evict_first.b64 %0, 1.0;\n" + : "=l"(cache_policy) + ); + #endif + + #pragma unroll + for (int64_t k = 0; k < num_chunks; k++) { + size_t shared_ptr = __cvta_generic_to_shared(b_frag_ptr); + #if (__CUDA_ARCH__ >= 900) // SM90 + asm volatile( + "cp.async.cg.shared.global [%0], [%1], 16;\n" + "cp.async.commit_group;\n" + :: "l"(shared_ptr), "l"(a_ptr) + ); + #else // SM80, SM89 + asm volatile( + "cp.async.cg.shared.global.L2::cache_hint.L2::256B [%0], [%1], 16, %2;\n" + "cp.async.commit_group;\n" + :: "l"(shared_ptr), "l"(a_ptr), "l"(cache_policy) + ); + #endif + + a_ptr += 128; + b_frag_ptr += 128; + } + + // generate hadamard 16x16 (up to 2 of them) + constexpr b16 fp16_1p[4] = {0b0011100110101000, 0b0011100000000000, 0b0011010110101000, 0b0011010000000000}; + constexpr b16 fp16_1n[4] = {0b1011100110101000, 0b1011100000000000, 0b1011010110101000, 0b1011010000000000}; + constexpr b16 bf16_1p[4] = {0b0011111100110101, 0b0011111100000000, 0b0011111010110101, 0b0011111010000000}; + constexpr b16 bf16_1n[4] = {0b1011111100110101, 0b1011111100000000, 0b1011111010110101, 0b1011111010000000}; + + #define val_type_1p(i) (((dtype) == torch::ScalarType::Half) ? (fp16_1p[i]) : (bf16_1p[i])) + #define val_type_1n(i) (((dtype) == torch::ScalarType::Half) ? 
(fp16_1n[i]) : (bf16_1n[i])) + constexpr b16 val_1p[4] = {val_type_1p(0), val_type_1p(1), val_type_1p(2), val_type_1p(3)}; + constexpr b16 val_1n[4] = {val_type_1n(0), val_type_1n(1), val_type_1n(2), val_type_1n(3)}; + + constexpr b32 p_p[4] = {p_p(0), p_p(1), p_p(2), p_p(3)}; + constexpr b32 p_n[4] = {p_n(0), p_n(1), p_n(2), p_n(3)}; + constexpr b32 n_p[4] = {n_p(0), n_p(1), n_p(2), n_p(3)}; + constexpr b32 n_n[4] = {n_n(0), n_n(1), n_n(2), n_n(3)}; + const b32 had_16_p1[4][4] = { + { + 0b10001000010001000010001000010001, + 0b00000000000000000000000000000000, + 0b00000000000000000000000000000000, + 0b10001000010001000010001000010001 + }, + { + 0b11001100100010000011001100100010, + 0b00000000000000000000000000000000, + 0b00000000000000000000000000000000, + 0b11001100100010000011001100100010 + }, + { + 0b11111111101010101100110010011001, + 0b00000000000000000000000000000000, + 0b00000000000000000000000000000000, + 0b11111111101010101100110010011001 + }, + { + 0b11111111101010101100110010011001, + 0b11111111101010101100110010011001, + 0b11111111101010101100110010011001, + 0b00000000010101010011001101100110 + } + }; + const b32 had_16_p2[4][4] = { + { + 0b10000000010000000010000000010000, + 0b00000000000000000000000000000000, + 0b00000000000000000000000000000000, + 0b10000000010000000010000000010000 + }, + { + 0b11000000100001000011000000100001, + 0b00000000000000000000000000000000, + 0b00000000000000000000000000000000, + 0b11000000100001000011000000100001 + }, + { + 0b11110000101001011100001110010110, + 0b00000000000000000000000000000000, + 0b00000000000000000000000000000000, + 0b11110000101001011100001110010110 + }, + { + 0b11110000101001011100001110010110, + 0b11110000101001011100001110010110, + 0b11110000101001011100001110010110, + 0b00001111010110100011110001101001 + } + }; + const b32 had_16_mask[3][4] = { + { + 0b10001000010001000010001000010001, + 0b00000000000000000000000000000000, + 0b00000000000000000000000000000000, + 0b10001000010001000010001000010001 + }, + { + 0b11001100110011000011001100110011, + 0b00000000000000000000000000000000, + 0b00000000000000000000000000000000, + 0b11001100110011000011001100110011 + }, + { + 0b11111111111111111111111111111111, + 0b00000000000000000000000000000000, + 0b00000000000000000000000000000000, + 0b11111111111111111111111111111111 + } + }; + b32 had_frag[8]; + #pragma unroll + for (int64_t i = 0; i < 2; i++) { + int64_t c_log_h = (i == 0) ? MIN(4, log_had_size) : log_had_size % 4; + #pragma unroll + for (int64_t j = 0; j < 4; j++) { + if (c_log_h < 4) { + bool mask = had_16_mask[c_log_h - 1][j] & (1 << (31 - threadid)); + if (!mask) { + had_frag[i * 4 + j] = 0; + continue; + } + } + bool pred1 = had_16_p1[c_log_h - 1][j] & (1 << (31 - threadid)); + bool pred2 = had_16_p2[c_log_h - 1][j] & (1 << (31 - threadid)); + b32 val = pred1 ? (pred2 ? p_p[c_log_h - 1] : p_n[c_log_h - 1]) : (pred2 ? 
n_p[c_log_h - 1] : n_n[c_log_h - 1]); + had_frag[i * 4 + j] = val; + } + if constexpr(log_had_size <= 4 || log_had_size % 4 == 0) break; + } + + // log had size above 8, only used for above 2^8 = 256 size + constexpr int64_t part8_log_had_size = log_had_size - 8; + + b32* a_chunk_ptr = a_start_ptr; // first chunk starts at this warp's data starts + b32* out_chunk_ptr = out_start_ptr; + + #pragma unroll + for (int64_t l = 0; l < 2; l++) { + if constexpr(log_had_size <= 8) { // l == 0 guaranteed, redundant simplified version of else body, to help compiler warnings + b_frag_ptr = bfrag_arr + (blockid % warps_per_block) * num_chunks * 128; + } else { + b_frag_ptr = bfrag_arr + (blockid % warps_per_block) * num_chunks * (l == 0 ? 128 : (128 >> part8_log_had_size)); + } + + if (l == 1) { + if constexpr(log_had_size > 8) { + __syncthreads(); // sync between first and second iterations if above size 256 + + if constexpr(log_had_size >= 12) { + // sizes 4k and above + + // a + threadblock offset + warp offset + // can then index into all chunks owned by this warp + b32* store = bfrag_arr + (128 >> part8_log_had_size) * (num_chunks * (blockid % warps_per_block)); + + #pragma unroll + for (int64_t j = 0; j < 4; j++) { + #pragma unroll + for (int64_t k = 0; k < num_chunks; k++) { + // here, j represents register, and k represents 8-offset/chunk + uint64_t real_chunk_num = (num_chunks - (threadid % num_chunks) + k) % num_chunks; // chunk at which you have target thread #'s data + + int64_t real_thread_id = (threadid / num_chunks) * num_chunks + k; // target thread # + int64_t chunk_idx = 128 * real_chunk_num; // index due to fetching from another chunk (chunk in which this thread has the target thread's original data) + int64_t thread_group_idx = (real_thread_id / 4) * 16; // index due to fetching from another group of num_chunk threads (since shuffle is between num_chunk threads) + int64_t thread_idx = (real_thread_id % 4) * 2; // index due to original thread's position within the group of num_chunk threads + int64_t reg_idx = (j / 2) * 8 + (j % 2); // index due to target register + int64_t idx = chunk_idx + thread_group_idx + thread_idx + reg_idx; // final index + + // fix idx for majorness + int64_t rowidx = idx % (1 << part8_log_had_size); + int64_t colidx = idx >> part8_log_had_size; + + // store[rowidx * 128 + colidx] = data; + b32 data = store[rowidx * 128 + colidx]; + + // compiler generates excessive instructions, so we manually do the if statement + #pragma unroll + for (uint64_t i = 0; i < num_chunks; i++) { + asm volatile ( + "{\n\t" + " .reg .pred p0;\n\t" + " setp.eq.s64 p0, %1, %2;\n\t" + " @p0 mov.b32 %0, %3;\n\t" + "}\n\t" + : "+r"(b_frag_all[i][j]) // Output operand %0 + : "l"(real_chunk_num), "l"(i), "r"(data) // Input operands %1, %2, %3 + ); + } + } + } + + #pragma unroll + for (int64_t j = 0; j < 4; j++) { + #pragma unroll + for (int64_t k = 1; k < num_chunks; k++) { + int64_t threadid_contig = threadid % num_chunks; + int64_t threadid_mul = threadid / num_chunks; + int64_t threadid2 = (threadid_contig + num_chunks - k) % num_chunks + threadid_mul * num_chunks; // thread to give your data to + b_frag_all[k][j] = __shfl_sync(0xFFFFFFFF, b_frag_all[k][j], threadid2); + } + } + } + } + } + + #pragma unroll + for (int64_t k = 0; k < num_chunks; k++) { + if constexpr(enable_mask) { + if (k >= real_num_chunks) + break; + } + if (l == 0) { + // bad fix for k not being recognized as a constexpr by compiler + // asm("cp.async.wait_group %0;\n" :: "n"(num_chunks - k - 1)); + #define 
SWITCH_WAIT_ASYNC_LOAD_GROUP(i) case i: asm volatile("cp.async.wait_group %0;\n" :: "n"(num_chunks - i - 1)); break; + if constexpr(enable_mask) { + switch(k + diff_num_chunks) { + SWITCH_WAIT_ASYNC_LOAD_GROUP(0) + SWITCH_WAIT_ASYNC_LOAD_GROUP(1) + SWITCH_WAIT_ASYNC_LOAD_GROUP(2) + SWITCH_WAIT_ASYNC_LOAD_GROUP(3) + SWITCH_WAIT_ASYNC_LOAD_GROUP(4) + SWITCH_WAIT_ASYNC_LOAD_GROUP(5) + SWITCH_WAIT_ASYNC_LOAD_GROUP(6) + SWITCH_WAIT_ASYNC_LOAD_GROUP(7) + SWITCH_WAIT_ASYNC_LOAD_GROUP(8) + SWITCH_WAIT_ASYNC_LOAD_GROUP(9) + SWITCH_WAIT_ASYNC_LOAD_GROUP(10) + SWITCH_WAIT_ASYNC_LOAD_GROUP(11) + SWITCH_WAIT_ASYNC_LOAD_GROUP(12) + SWITCH_WAIT_ASYNC_LOAD_GROUP(13) + SWITCH_WAIT_ASYNC_LOAD_GROUP(14) + SWITCH_WAIT_ASYNC_LOAD_GROUP(15) + SWITCH_WAIT_ASYNC_LOAD_GROUP(16) + SWITCH_WAIT_ASYNC_LOAD_GROUP(17) + SWITCH_WAIT_ASYNC_LOAD_GROUP(18) + SWITCH_WAIT_ASYNC_LOAD_GROUP(19) + SWITCH_WAIT_ASYNC_LOAD_GROUP(20) + SWITCH_WAIT_ASYNC_LOAD_GROUP(21) + SWITCH_WAIT_ASYNC_LOAD_GROUP(22) + SWITCH_WAIT_ASYNC_LOAD_GROUP(23) + SWITCH_WAIT_ASYNC_LOAD_GROUP(24) + SWITCH_WAIT_ASYNC_LOAD_GROUP(25) + SWITCH_WAIT_ASYNC_LOAD_GROUP(26) + SWITCH_WAIT_ASYNC_LOAD_GROUP(27) + SWITCH_WAIT_ASYNC_LOAD_GROUP(28) + SWITCH_WAIT_ASYNC_LOAD_GROUP(29) + SWITCH_WAIT_ASYNC_LOAD_GROUP(30) + SWITCH_WAIT_ASYNC_LOAD_GROUP(31) + } + } else { + switch(k) { + SWITCH_WAIT_ASYNC_LOAD_GROUP(0) + SWITCH_WAIT_ASYNC_LOAD_GROUP(1) + SWITCH_WAIT_ASYNC_LOAD_GROUP(2) + SWITCH_WAIT_ASYNC_LOAD_GROUP(3) + SWITCH_WAIT_ASYNC_LOAD_GROUP(4) + SWITCH_WAIT_ASYNC_LOAD_GROUP(5) + SWITCH_WAIT_ASYNC_LOAD_GROUP(6) + SWITCH_WAIT_ASYNC_LOAD_GROUP(7) + SWITCH_WAIT_ASYNC_LOAD_GROUP(8) + SWITCH_WAIT_ASYNC_LOAD_GROUP(9) + SWITCH_WAIT_ASYNC_LOAD_GROUP(10) + SWITCH_WAIT_ASYNC_LOAD_GROUP(11) + SWITCH_WAIT_ASYNC_LOAD_GROUP(12) + SWITCH_WAIT_ASYNC_LOAD_GROUP(13) + SWITCH_WAIT_ASYNC_LOAD_GROUP(14) + SWITCH_WAIT_ASYNC_LOAD_GROUP(15) + SWITCH_WAIT_ASYNC_LOAD_GROUP(16) + SWITCH_WAIT_ASYNC_LOAD_GROUP(17) + SWITCH_WAIT_ASYNC_LOAD_GROUP(18) + SWITCH_WAIT_ASYNC_LOAD_GROUP(19) + SWITCH_WAIT_ASYNC_LOAD_GROUP(20) + SWITCH_WAIT_ASYNC_LOAD_GROUP(21) + SWITCH_WAIT_ASYNC_LOAD_GROUP(22) + SWITCH_WAIT_ASYNC_LOAD_GROUP(23) + SWITCH_WAIT_ASYNC_LOAD_GROUP(24) + SWITCH_WAIT_ASYNC_LOAD_GROUP(25) + SWITCH_WAIT_ASYNC_LOAD_GROUP(26) + SWITCH_WAIT_ASYNC_LOAD_GROUP(27) + SWITCH_WAIT_ASYNC_LOAD_GROUP(28) + SWITCH_WAIT_ASYNC_LOAD_GROUP(29) + SWITCH_WAIT_ASYNC_LOAD_GROUP(30) + SWITCH_WAIT_ASYNC_LOAD_GROUP(31) + } + } + } + + if (l == 0) { + // loading for the first iteration + + // thread 0 loads [t0r0, t16r1, t0r2, t16r3] + // thread 16 loads [t0r1, t16r0, t0r3, t16r2] + // allows full coalescing, same for t1/t17, t2/t18, etc. + #pragma unroll + for (int64_t j = 0; j < 4; j++) { + int64_t reg = ((threadid & 16) == 0) ? j : (j / 2 * 2 + (1 - j % 2)); + int64_t real_thread_id = (reg == 0 || reg == 2) ? threadid : (threadid ^ 16); + int64_t real_row = real_thread_id % 4; + int64_t real_col = real_thread_id / 4; + b_frag_all[k][j] = b_frag_ptr[(real_row + (reg % 2) * 4) + (real_col + (j / 2) * 8) * 8]; + } + + // for t16 swap r0/r1 and r2/r3 to have [t16r0, t0r1, t16r2, t0r3] + // so registers are in right order, same for t17, t18, etc. + if ((threadid & 16) != 0) { + b32 temp = b_frag_all[k][0]; + b_frag_all[k][0] = b_frag_all[k][1]; + b_frag_all[k][1] = temp; + + temp = b_frag_all[k][2]; + b_frag_all[k][2] = b_frag_all[k][3]; + b_frag_all[k][3] = temp; + } + + // t0 and t16 swap r1 and r3 to have their own data, + // same for t1/t17, t2/18, etc. 
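+ // note on the shuffle below: __shfl_xor_sync(0xFFFFFFFF, v, 16) returns v as held by
+ // lane (laneid ^ 16), so each lane pair (i, i^16) exchanges its r1/r3 registers in a
+ // single warp-wide shuffle; the 0xFFFFFFFF mask means all 32 lanes participate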
+ #pragma unroll + for (int64_t j = 1; j < 4; j += 2) { + b_frag_all[k][j] = __shfl_xor_sync(0xFFFFFFFF, b_frag_all[k][j], 16); + } + } else if constexpr(log_had_size > 8) { // condition is redundant to help compiler warnings + if constexpr(log_had_size < 12) { + // sizes 512, 1k, and 2k + + // for 512: + // thread 0 loads [t0r0, t0r1, t16r2, t16r3] + // thread 16 loads [t0r2, t0r3, t16r0, t16r1] + // same for t1/t17, t2/t18, etc. + // for 1k and 2k: + // thread 0 loads [t0r0, t0r1, t1r2, t1r3] + // thread 1 loads [t0r2, t0r3, t1r0, t1r1] + // same for t2/t3, t4/t5, etc. + // allows full coalescing for 512 and 1k, 16x coalescing for 2k + constexpr int64_t xor_val = log_had_size == 9 ? 16 : 1; + + #pragma unroll + for (int64_t j = 0; j < 4; j++) { + int64_t reg = ((threadid & xor_val) == 0) ? j : (j + 2) % 4; + int64_t real_thread_id = reg < 2 ? threadid : (threadid ^ xor_val); + int64_t idx = (real_thread_id / 4 * 16) + (real_thread_id % 4 * 2) + (reg / 2 * 8) + (reg % 2); + int64_t rowidx = idx % (1 << part8_log_had_size); + int64_t colidx = idx >> part8_log_had_size; + b_frag_all[k][j] = b_frag_ptr[rowidx * 128 + colidx]; + } + + if ((threadid & xor_val) != 0) { + b32 temp = b_frag_all[k][0]; + b_frag_all[k][0] = b_frag_all[k][2]; + b_frag_all[k][2] = temp; + + temp = b_frag_all[k][1]; + b_frag_all[k][1] = b_frag_all[k][3]; + b_frag_all[k][3] = temp; + } + + #pragma unroll + for (int64_t j = 2; j < 4; j++) { + b_frag_all[k][j] = __shfl_xor_sync(0xFFFFFFFF, b_frag_all[k][j], xor_val); + } + } + } + + if (l == 1) { + // for second iteration, we load 2 consecutive b16s (1 b32) per register, + // but tensor core register layout requires 2 b16s that are in the + // same column/consecutive rows to be in the same register, so do the swap + b32 f0 = ((b_frag_all[k][1] & 0xFFFF) << 16) | (b_frag_all[k][0] & 0xFFFF); + b32 f1 = ((b_frag_all[k][3] & 0xFFFF) << 16) | (b_frag_all[k][2] & 0xFFFF); + b32 f2 = (b_frag_all[k][1] & 0xFFFF0000) | (b_frag_all[k][0] >> 16); + b32 f3 = (b_frag_all[k][3] & 0xFFFF0000) | (b_frag_all[k][2] >> 16); + b_frag_all[k][0] = f0; + b_frag_all[k][1] = f1; + b_frag_all[k][2] = f2; + b_frag_all[k][3] = f3; + } + + #pragma unroll + for(int64_t i = 0, remaining_log_had_size = log_had_size - l * 8; i < 2 && remaining_log_had_size > 0; i++) { + int64_t had_off = ((remaining_log_had_size < 4) && !(log_had_size <= 4 || log_had_size % 4 == 0)) ? 
4 : 0; + mma_m16_n16_k16_b16_b16_b16_noacc(had_frag[had_off + 0], had_frag[had_off + 1], had_frag[had_off + 2], had_frag[had_off + 3], b_frag_all[k][0], b_frag_all[k][1], b_frag_all[k][2], b_frag_all[k][3], b_frag_all[k][0], b_frag_all[k][1], b_frag_all[k][2], b_frag_all[k][3]); + + remaining_log_had_size -= 4; + if (remaining_log_had_size <= 0 && i == 0) { + // TODO: consider different storing so no need for transpose + matrix_transpose_m8_n8_b16_inplace(b_frag_all[k][0]); + matrix_transpose_m8_n8_b16_inplace(b_frag_all[k][1]); + matrix_transpose_m8_n8_b16_inplace(b_frag_all[k][2]); + matrix_transpose_m8_n8_b16_inplace(b_frag_all[k][3]); + } else { + // swap and use output directly as b_frag for next iteration as an actually free transpose + b32 temp = b_frag_all[k][1]; + b_frag_all[k][1] = b_frag_all[k][2]; + b_frag_all[k][2] = temp; + } + } + + if (l == 1) { + // invert swap from above for second iteration + b32 f0 = ((b_frag_all[k][2] & 0xFFFF) << 16) | (b_frag_all[k][0] & 0xFFFF); + b32 f1 = (b_frag_all[k][2] & 0xFFFF0000) | (b_frag_all[k][0] >> 16); + b32 f2 = ((b_frag_all[k][3] & 0xFFFF) << 16) | (b_frag_all[k][1] & 0xFFFF); + b32 f3 = (b_frag_all[k][3] & 0xFFFF0000) | (b_frag_all[k][1] >> 16); + b_frag_all[k][0] = f0; + b_frag_all[k][1] = f1; + b_frag_all[k][2] = f2; + b_frag_all[k][3] = f3; + } + + if (l == 0) { + // inverse of coalesced load for first iteration to store result + #pragma unroll + for (int64_t j = 1; j < 4; j += 2) { + b_frag_all[k][j] = __shfl_xor_sync(0xFFFFFFFF, b_frag_all[k][j], 16); + } + + if ((threadid & 16) != 0) { + b32 temp = b_frag_all[k][0]; + b_frag_all[k][0] = b_frag_all[k][1]; + b_frag_all[k][1] = temp; + + temp = b_frag_all[k][2]; + b_frag_all[k][2] = b_frag_all[k][3]; + b_frag_all[k][3] = temp; + } + + // if only going up to 256 size, store directly back to global memory, + // otherwise store back to shared memory for next iteration + b32* store = (log_had_size <= 8) ? out_chunk_ptr : b_frag_ptr; + + #pragma unroll + for (int64_t j = 0; j < 4; j++) { + int64_t reg = ((threadid & 16) == 0) ? j : (j / 2 * 2 + (1 - j % 2)); + int64_t real_thread_id = (reg == 0 || reg == 2) ? threadid : (threadid ^ 16); + int64_t real_row = real_thread_id % 4; + int64_t real_col = real_thread_id / 4; + store[(real_row + (reg % 2) * 4) + (real_col + (reg / 2) * 8) * 8] = b_frag_all[k][j]; + } + } else if constexpr(log_had_size > 8) { // condition is redundant to help compiler warnings + if (log_had_size < 12) { + // inverse of coalesced load for sizes 512, 1k and 2k to store result + constexpr int xor_val = log_had_size == 9 ? 16 : 1; + #pragma unroll + for (int64_t j = 2; j < 4; j++) { + b_frag_all[k][j] = __shfl_xor_sync(0xFFFFFFFF, b_frag_all[k][j], xor_val); + } + + if ((threadid & xor_val) != 0) { + b32 temp = b_frag_all[k][0]; + b_frag_all[k][0] = b_frag_all[k][2]; + b_frag_all[k][2] = temp; + + temp = b_frag_all[k][1]; + b_frag_all[k][1] = b_frag_all[k][3]; + b_frag_all[k][3] = temp; + } + + b32* store = (b32*)(out + (blockid / warps_per_block) * (num_chunks * warps_per_block) * 256 + (256 >> part8_log_had_size) * (num_chunks * (blockid % warps_per_block) + k)); + #pragma unroll + for (int64_t j = 0; j < 4; j++) { + int64_t reg = ((threadid & xor_val) == 0) ? j : (j + 2) % 4; + b32 data = b_frag_all[k][j]; + int64_t real_thread_id = reg < 2 ? 
threadid : (threadid ^ xor_val); + int64_t idx = (real_thread_id / 4 * 16) + (real_thread_id % 4 * 2) + (reg / 2 * 8) + (reg % 2); + int64_t rowidx = idx % (1 << part8_log_had_size); + int64_t colidx = idx >> part8_log_had_size; + store[rowidx * 128 + colidx] = data; + } + } + // for size 4k and above, wait to process all chunks so a final store can be performed coalesced + } + + a_chunk_ptr += 128; // (only affects first 256 size) move on to next chunk by skipping 256 elements in b16 (= 128 in b32) + out_chunk_ptr += 128; + if constexpr(log_had_size > 8) { + b_frag_ptr += (l == 0 ? 128 : (128 >> part8_log_had_size)); + } else { // else is redundant, simplified version of if body, to help compiler warnings + b_frag_ptr += 128; + } + } + if (log_had_size <= 8) + break; + } + + if constexpr(log_had_size >= 12) { + // for sizes 4k and above, perform final coalesced store after processing all chunks + #pragma unroll + for (int64_t j = 0; j < 4; j++) { + #pragma unroll + for (int64_t k = 1; k < num_chunks; k++) { + int64_t threadid_contig = threadid % num_chunks; + int64_t threadid_mul = threadid / num_chunks; + int64_t threadid2 = (threadid_contig + k) % num_chunks + threadid_mul * num_chunks; // thread to give your data to + b_frag_all[k][j] = __shfl_sync(0xFFFFFFFF, b_frag_all[k][j], threadid2); + } + } + + // a + threadblock offset + warp offset + // can then index into all chunks owned by this warp + b32* store = bfrag_arr + (128 >> part8_log_had_size) * (num_chunks * (blockid % warps_per_block)); + + #pragma unroll + for (int64_t j = 0; j < 4; j++) { + #pragma unroll + for (int64_t k = 0; k < num_chunks; k++) { + // here, j represents register, and k represents 8-offset/chunk + int64_t real_chunk_num = (num_chunks - (threadid % num_chunks) + k) % num_chunks; // chunk at which you have target thread #'s data + + // b32 data = b_frag_all[real_chunk_num][j]; // target thread data + b32 data; + #pragma unroll + for (int64_t i = 0; i < num_chunks; i++) { + if (real_chunk_num == i) data = b_frag_all[i][j]; + } + + int64_t real_thread_id = (threadid / num_chunks) * num_chunks + k; // target thread # + int64_t chunk_idx = 128 * real_chunk_num; // index due to fetching from another chunk (chunk in which this thread has the target thread's original data) + int64_t thread_group_idx = (real_thread_id / 4) * 16; // index due to fetching from another group of num_chunk threads (since shuffle is between num_chunk threads) + int64_t thread_idx = (real_thread_id % 4) * 2; // index due to original thread's position within the group of num_chunk threads + int64_t reg_idx = (j / 2) * 8 + (j % 2); // index due to target register + int64_t idx = chunk_idx + thread_group_idx + thread_idx + reg_idx; // final index + + // fix idx for majorness + int64_t rowidx = idx % (1 << part8_log_had_size); + int64_t colidx = idx >> part8_log_had_size; + + store[rowidx * 128 + colidx] = data; + } + } + + __syncthreads(); + store = ((b32*) out) + (blockid / warps_per_block) * (num_chunks * warps_per_block) * 128; + int4* store4 = (int4*) store; + int4* bfrag_arr4 = (int4*) bfrag_arr; + // flush smem, simply linearly write to store + // always divisible by 128*32b, so (32*4)*32b is ok + #pragma unroll + for (int64_t warp_off = 0; warp_off < (num_chunks * warps_per_block * 128 / 4); warp_off += 32 * warps_per_block) { + int64_t total_off = warp_off + threadid + (blockid % warps_per_block) * 32; + store4[total_off] = bfrag_arr4[total_off]; + } + } + +} + +constexpr int64_t ceil_div(int64_t a, int64_t b) { + return (a + b - 1) / 
b; +} + +template +void __forceinline__ run_kernel(b16* a_mat, b16* out, int64_t num_chunks, cudaStream_t stream) { + int64_t shared_size = chunks_per_warp * warps_per_block * 128 * 4; + dim3 block_size = 32 * warps_per_block; + + #define CHECK_SHARED_LIM() { \ + if (shared_size > 48 * 1024) { \ + C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536)); \ + } \ + } \ + + if constexpr(check_masking) { + if (num_chunks % (chunks_per_warp * warps_per_block) != 0) { + dim3 grid_size = ceil_div(ceil_div(num_chunks, chunks_per_warp), warps_per_block); + auto kernel = hadamard_transform_kernel; + CHECK_SHARED_LIM(); + kernel<<>>(a_mat, out, num_chunks); + } else { + dim3 grid_size = num_chunks / chunks_per_warp / warps_per_block; + auto kernel = hadamard_transform_kernel; + CHECK_SHARED_LIM(); + kernel<<>>(a_mat, out, num_chunks); + } + } else { + dim3 grid_size = num_chunks / chunks_per_warp / warps_per_block; + auto kernel = hadamard_transform_kernel; + CHECK_SHARED_LIM(); + kernel<<>>(a_mat, out, num_chunks); + } + + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void run_fht(void* a_mat_ptr, void* out_ptr, int64_t numel, int64_t had_size, cudaStream_t stream) { + int64_t num_chunks = numel / 256; // caller required to ensure divisible by 256 + // for size 256, use (2, 1) + // for size 32k use (8, 16) + constexpr int64_t chunks_per_warp_small = 1;// 8; + constexpr int64_t warps_per_block_small = 1;//2;//16; + constexpr int64_t blocks_per_sm_small = 24; + constexpr int64_t chunks_per_warp_large = 2; + constexpr int64_t warps_per_block_large = 1; + constexpr int64_t blocks_per_sm_large = 24; + + b16* a_mat = (b16*) a_mat_ptr; + b16* out = (b16*) out_ptr; + + if (numel <= 256) { + switch (had_size) { + case (1<<1): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<2): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<3): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<4): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<5): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<6): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<7): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<8): run_kernel(a_mat, out, num_chunks, stream); break; + } + } else { + switch (had_size) { + case (1<<1): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<2): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<3): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<4): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<5): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<6): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<7): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<8): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<9): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<10): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<11): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<12): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<13): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<14): run_kernel(a_mat, out, num_chunks, stream); break; + case (1<<15): run_kernel(a_mat, out, num_chunks, stream); break; + } + } +} + +template void run_fht(void* a_mat_ptr, void* out_ptr, int64_t numel, int64_t had_size, cudaStream_t stream); +template void run_fht(void* a_mat_ptr, void* out_ptr, int64_t numel, int64_t had_size, cudaStream_t 
stream); + +} // namespace hadacore + +constexpr bool is_power_of_two(int x) { return x && !(x & (x - 1)); } + +torch::Tensor hadacore_transform(torch::Tensor& x, bool inplace) { + auto dtype = x.scalar_type(); + TORCH_CHECK(dtype == torch::ScalarType::Half || dtype == torch::ScalarType::BFloat16, "Only fp16 and bf16 supported currently"); + TORCH_CHECK(x.is_cuda()); + + const int had_size = x.size(-1); + TORCH_CHECK(is_power_of_two(had_size) && (had_size <= (1U << 15)), + "Only power of two Hadamard sizes up to 2^15 are supported, got ", had_size); + + const auto res_shape = x.sizes(); + x = x.reshape({-1, had_size}); + + auto numel = x.numel(); + if (numel % 256 != 0) { + x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, 0, 0, (256 - numel % 256) / had_size})); + } + + if (x.stride(-1) != 1) { + x = x.contiguous(); + } + torch::Tensor out = inplace ? x : torch::empty_like(x); + + at::cuda::CUDAGuard device_guard{(char)x.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + VLLM_DISPATCH_HALF_TYPES(x.scalar_type(), "hadacore_transform_runfht", [&] { + auto constexpr SCALAR_TYPE = c10::CppTypeToScalarType::value; + hadacore::run_fht(x.data_ptr(), x.data_ptr(), x.numel(), had_size, stream); + }); + + if (numel % 256 != 0) { + out = out.index({torch::indexing::Slice(0, numel / had_size)}); + } + + if (inplace && out.data_ptr() != x.data_ptr()) { + x.copy_(out.view(res_shape)); + return x; + } + return out.reshape(res_shape); +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("hadacore_transform", &hadacore_transform); +} diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index 0d14ba15937c6..8fd536ef46e3d 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -417,7 +417,7 @@ def create_sources(impl_configs: list[ImplConfig], num_impl_files=8): )) def prepacked_type_key(prepack_type: PrepackTypeConfig): - # For now we we can just use the first accumulator type seen since + # For now, we can just use the first accumulator type seen since # the tensor core shapes/layouts don't vary based on accumulator # type so we can generate less code this way return (prepack_type.a, prepack_type.b_num_bits, prepack_type.convert) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index e3a0e15f5304f..dac9df6048f2a 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -30,6 +30,10 @@ #define __HIP__GFX9__ #endif +#if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__)) + #define __HIP__FP8MFMA__ +#endif + #if defined(__HIPCC__) && (defined(__gfx1100__) || defined(__gfx1101__)) #define __HIP__GFX11__ #endif @@ -51,6 +55,12 @@ #define MIN(a, b) ((a) < (b) ? 
(a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) +enum class MFMAType { + F16 = 0, + Fp8 = 1, + Fp4 = 2, +}; + #if defined(__HIP__GFX9__) #define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32 @@ -112,6 +122,21 @@ __device__ __forceinline__ floatx4 gcn_mfma16x16x16_instr(const _B16x4& inpA, } } +template +__device__ __forceinline__ floatx4 gcn_mfma16x16x32_instr(const long& inpA, + const long& inpB, + const floatx4& inpC) { + if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(inpA, inpB, inpC, absz, + cbid, blgp); + } else if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(inpA, inpB, inpC, absz, + cbid, blgp); + } else { + static_assert(false, "unsupported 8b dtype"); + } +} + template __device__ __forceinline__ float to_float(const T& inp) { if constexpr (std::is_same::value) { @@ -256,12 +281,44 @@ __device__ __forceinline__ _B16x8 convert_b8x8_custom(const _B8x8 input) { return ret; } +typedef union u64_cvt { + half f16x4[4]; + int16_t b16x4[4]; + _B8x8 b8x8; + _B16x4 b64; + int64_t i64; +} _T8x8; + +__device__ __forceinline__ _B8x8 convert_b16x8(const _B16x8& input, + _T8x8& Mtemp) { + _T8x8 Qtmp8x8; + + for (int i = 0; i < 2; i++) { + floatx4 q_out = {0, 0, 0, 0}; + q_out = gcn_mfma16x16x16_instr<_Float16, 0, 0, 0>(Mtemp.b64, input.xy[i], + q_out); + Qtmp8x8.b16x4[i * 2] = + __builtin_amdgcn_cvt_pk_fp8_f32(q_out[0], q_out[1], 0, false); + Qtmp8x8.b16x4[i * 2 + 1] = + __builtin_amdgcn_cvt_pk_fp8_f32(q_out[2], q_out[3], 0, false); + } + return Qtmp8x8.b8x8; +} + +__device__ float warpReduceMax(float val) { + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + val = max( + val, __shfl_down(val, offset, WARP_SIZE)); // Using max() for reduction + } + return val; +} + // grid (num_seqs, num_partitions,num_kv_heads) // block (256) // clang-format off template + int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO, MFMAType MFMA_TYPE> __global__ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] @@ -367,6 +424,10 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; int kphysical_block_number[TLOOP]; + #if defined(__HIP__FP8MFMA__) + float q_max = 0; + float q_scale = 1.0; + #endif // fetch k physical block numbers for (int token_depth = 0; token_depth < TLOOP; token_depth++) { @@ -416,6 +477,15 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( Qlocal[qkhe_depth][qkratio].xy[i] = shared_logits[qkhe_depth][rowid][lane16id % GQA_RATIO] [2 * qkratio + i]; + #if defined(__HIP__FP8MFMA__) + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto && + MFMA_TYPE == MFMAType::Fp8) { + scalar_t* qptr = + reinterpret_cast(&Qlocal[qkhe_depth][qkratio].xy[i]); + for (int k = 0; k < 4; k++) + q_max = fmax(fabs(to_float(qptr[k])), q_max); + } + #endif } } } @@ -515,6 +585,14 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { // multiply by k_scale if fp8 kv cache scale2 *= *k_scale; + #if defined(__HIP__FP8MFMA__) + q_max = warpReduceMax(q_max); + constexpr float FP8_E4M3_SCALE_TARGET = 224.0f; + if constexpr (MFMA_TYPE == MFMAType::Fp8) { + q_scale = q_max > 0 ? 
FP8_E4M3_SCALE_TARGET / q_max : 1.0f; + scale2 /= q_scale; + } + #endif } floatx4 d_out[TLOOP]; @@ -534,12 +612,41 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( auto Ktmp = Klocal[token_depth][qkhe_depth]; _B8x16 Ktmp8x16 = *reinterpret_cast<_B8x16*>(&Ktmp); for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { - _B8x8 Ktmp8x8 = Ktmp8x16.xy[qkratio]; - _B16x8 Klocaltmp = convert_b8x8_custom(Ktmp8x8); - for (int i = 0; i < 2; i++) { - d_out[token_depth] = gcn_mfma16x16x16_instr( - Klocaltmp.xy[i], Qlocal[qkhe_depth][qkratio].xy[i], - d_out[token_depth]); + if constexpr (MFMA_TYPE == MFMAType::F16) { + _B8x8 Ktmp8x8 = Ktmp8x16.xy[qkratio]; + _B16x8 Klocaltmp = convert_b8x8_custom(Ktmp8x8); + for (int i = 0; i < 2; i++) { + d_out[token_depth] = gcn_mfma16x16x16_instr( + Klocaltmp.xy[i], Qlocal[qkhe_depth][qkratio].xy[i], + d_out[token_depth]); + } + } else { + #if defined(__HIP__FP8MFMA__) + _T8x8 Ktmp8x8, Qtmp8x8; + Ktmp8x8.b8x8 = Ktmp8x16.xy[qkratio]; + + for (int n = 0; n < 2; n++) { + scalar_t* qptr = reinterpret_cast( + &Qlocal[qkhe_depth][qkratio].xy[n]); + + Qtmp8x8.b16x4[n * 2] = + vllm::fp8::scaled_vec_conversion( + make_float2(to_float(qptr[0]), + to_float(qptr[1])), + q_scale); + Qtmp8x8.b16x4[n * 2 + 1] = + vllm::fp8::scaled_vec_conversion( + make_float2(to_float(qptr[2]), + to_float(qptr[3])), + q_scale); + } + + d_out[token_depth] = + gcn_mfma16x16x32_instr<__hip_fp8_e4m3, 0, 0, 0>( + Ktmp8x8.i64, Qtmp8x8.i64, d_out[token_depth]); + #else + UNREACHABLE_CODE + #endif } } } @@ -629,17 +736,36 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( // disable rtz conversion due to its impact on accuracy. constexpr bool LOGITS_RTZ_CONVERSION = false; + #if defined(__HIP__FP8MFMA__) + int rowid_8x8 = rowid / 2; + int offset = rowid % 2; + #endif + // write logits to shared mem for (int token_depth = 0; token_depth < TLOOP; token_depth++) { d_out[token_depth] *= inv_sum_scale; - if constexpr (LOGITS_RTZ_CONVERSION) { - // use rtz conversion for better performance, with negligible impact on - // accuracy - shared_logits[warpid][token_depth][lane16id][rowid] = - from_floatx4_rtz(d_out[token_depth]); + if constexpr (MFMA_TYPE != MFMAType::Fp8) { + if constexpr (LOGITS_RTZ_CONVERSION) { + // use rtz conversion for better performance, with negligible impact on + // accuracy + shared_logits[warpid][token_depth][lane16id][rowid] = + from_floatx4_rtz(d_out[token_depth]); + } else { + shared_logits[warpid][token_depth][lane16id][rowid] = + from_floatx4(d_out[token_depth]); + } } else { - shared_logits[warpid][token_depth][lane16id][rowid] = - from_floatx4(d_out[token_depth]); + #if defined(__HIP__FP8MFMA__) + // cast _B16x4* to _B8x8* + _T8x8& logits_8x8 = *reinterpret_cast<_T8x8*>( + &shared_logits[warpid][token_depth][lane16id][rowid_8x8]); + logits_8x8.b16x4[offset * 2] = __builtin_amdgcn_cvt_pk_fp8_f32( + d_out[token_depth][0], d_out[token_depth][1], 0, false); + logits_8x8.b16x4[offset * 2 + 1] = __builtin_amdgcn_cvt_pk_fp8_f32( + d_out[token_depth][2], d_out[token_depth][3], 0, false); + #else + UNREACHABLE_CODE + #endif } } @@ -692,19 +818,42 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( _B8x16 Vtmp8x16 = *reinterpret_cast<_B8x16*>(&Vtmp); for (int j = 0; j < ELEMS16_ELEMS8_RATIO; j++) { _B8x8 Vtmp8x8 = Vtmp8x16.xy[j]; - _B16x8 Vlocaltmp = convert_b8x8_custom(Vtmp8x8); - for (int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) { - const int offset = - rowid * ELEMS16_ELEMS8_RATIO * 
ELEMS8_ELEMS4_RATIO + - j * ELEMS8_ELEMS4_RATIO + i; - const int offset1 = offset % ROWS_PER_WARP; - const int offset2 = offset / ROWS_PER_WARP; - // output format is 16 qheads across 16 lanes, 16 head elems - // spread across 4 rows - tmp_out = gcn_mfma16x16x16_instr( - Vlocaltmp.xy[i], - shared_logits[vtoken_depth][offset2][lane16id][offset1], - tmp_out); + if constexpr (MFMA_TYPE == MFMAType::F16) { + _B16x8 Vlocaltmp = convert_b8x8_custom(Vtmp8x8); + for (int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) { + const int offset = + rowid * ELEMS16_ELEMS8_RATIO * ELEMS8_ELEMS4_RATIO + + j * ELEMS8_ELEMS4_RATIO + i; + const int offset1 = offset % ROWS_PER_WARP; + const int offset2 = offset / ROWS_PER_WARP; + // output format is 16 qheads across 16 lanes, 16 head elems + // spread across 4 rows + tmp_out = gcn_mfma16x16x16_instr( + Vlocaltmp.xy[i], + shared_logits[vtoken_depth][offset2][lane16id][offset1], + tmp_out); + } + } else { + #if defined(__HIP__FP8MFMA__) + for (int i = 0; i < ELEMS8_ELEMS4_RATIO / 2; i++) { + const int offset = + rowid * ELEMS16_ELEMS8_RATIO * ELEMS8_ELEMS4_RATIO + + j * ELEMS8_ELEMS4_RATIO + i; + const int offset1 = (offset % ROWS_PER_WARP) / 2; + const int offset2 = offset / ROWS_PER_WARP; + // output format is 16 qheads across 16 lanes, 16 head elems + // spread across 4 rows + tmp_out = gcn_mfma16x16x32_instr<__hip_fp8_e4m3, 0, 0, 0>( + reinterpret_cast<_T8x8*>(&Vtmp8x8)->i64, + reinterpret_cast<_T8x8*>( + &shared_logits[vtoken_depth][offset2][lane16id] + [offset1]) + ->i64, + tmp_out); + } + #else + UNREACHABLE_CODE + #endif } } } @@ -1570,7 +1719,8 @@ __device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) { // clang-format off template + int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO, + MFMAType MFMA_TYPE> __global__ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] @@ -2337,7 +2487,8 @@ __device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) { // clang-format off template + int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO, + MFMAType MFMA_TYPE> __global__ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] @@ -2969,7 +3120,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( template + int GQA_RATIO, MFMAType MFMA_TYPE> __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] @@ -3041,7 +3192,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( #define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \ paged_attention_ll4mi_QKV_mfma16_kernel \ + GQA_RATIO, MFMA_TYPE> \ <<>>( \ query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ block_tables_ptr, seq_lens_ptr, query_start_loc_ptr, \ @@ -3069,7 +3220,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( template + bool ALIBI_ENABLED, MFMAType MFMA_TYPE> void paged_attention_custom_launcher( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, @@ -3225,7 +3376,7 @@ void paged_attention_custom_launcher( template + bool ALIBI_ENABLED, MFMAType MFMA_TYPE> void paged_attention_custom_launcher_navi( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& 
query, torch::Tensor& key_cache, @@ -3397,74 +3548,77 @@ void paged_attention_custom_launcher_navi( } #define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \ - PSIZE, ALIBI_ENABLED) \ + PSIZE, ALIBI_ENABLED, MFMA_TYPE) \ if (!is_navi) { \ paged_attention_custom_launcher( \ + OUTT, PSIZE, ALIBI_ENABLED, MFMA_TYPE>( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ num_kv_heads, scale, block_tables, seq_lens, query_start_loc, \ max_seq_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); \ } else { \ - paged_attention_custom_launcher_navi< \ - T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, ALIBI_ENABLED>( \ + paged_attention_custom_launcher_navi( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ num_kv_heads, scale, block_tables, seq_lens, query_start_loc, \ max_seq_len, alibi_slopes, k_scale, v_scale); \ } #define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ - OUTT, PSIZE) \ + OUTT, PSIZE, MFMA_TYPE) \ if (alibi_slopes) { \ CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \ - true); \ + true, MFMA_TYPE); \ } else { \ CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \ - false); \ + false, MFMA_TYPE); \ } #if defined(__HIPCC__) && defined(__gfx90a__) - #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ + #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ + MFMA_TYPE) \ if (fp8_out_scale) { \ TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a"); \ } else { \ CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \ - 256); \ + 256, MFMA_TYPE); \ } #else - #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ + #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ + MFMA_TYPE) \ if (fp8_out_scale) { \ CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ - uint8_t, 256); \ + uint8_t, 256, MFMA_TYPE); \ } else { \ CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \ - 256); \ + 256, MFMA_TYPE); \ } #endif -#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \ - switch (block_size) { \ - case 16: \ - CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \ - break; \ - case 32: \ - CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE, MFMA_TYPE) \ + switch (block_size) { \ + case 16: \ + CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE, MFMA_TYPE); \ + break; \ + case 32: \ + CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE, MFMA_TYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } -#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \ - switch (head_size) { \ - case 64: \ - CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); \ - break; \ - case 128: \ - CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported head size: ", head_size); \ - break; \ +#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE, MFMA_TYPE) \ + switch (head_size) { \ + case 64: \ + CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64, MFMA_TYPE); \ + break; \ + case 128: \ + CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128, MFMA_TYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported head size: ", head_size); \ + break; \ } bool is_navi_gpu() { @@ 
-3503,28 +3657,43 @@ void paged_attention( const std::optional& alibi_slopes, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, - const std::optional& fp8_out_scale) { + const std::optional& fp8_out_scale, + const std::string& mfma_type) { // clang-format on bool is_navi = is_navi_gpu(); - const int head_size = query.size(2); if (kv_cache_dtype == "auto") { if (query.dtype() == at::ScalarType::Half) { - CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16, - vllm::Fp8KVCacheDataType::kAuto); + CALL_CUSTOM_LAUNCHER_BLK_HEAD( + _Float16, _Float16, vllm::Fp8KVCacheDataType::kAuto, MFMAType::F16); } else if (query.dtype() == at::ScalarType::BFloat16) { CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, __hip_bfloat16, - vllm::Fp8KVCacheDataType::kAuto); + vllm::Fp8KVCacheDataType::kAuto, + MFMAType::F16); } else { TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); } } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") { if (query.dtype() == at::ScalarType::Half) { - CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t, - vllm::Fp8KVCacheDataType::kFp8E4M3); + if (mfma_type == "fp8") { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3, + MFMAType::Fp8); + } else { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3, + MFMAType::F16); + } } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t, - vllm::Fp8KVCacheDataType::kFp8E4M3); + if (mfma_type == "fp8") { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3, + MFMAType::Fp8); + } else { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3, + MFMAType::F16); + } } else { TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); } diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index 34dcc9401aae8..b6ee2656746c1 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -19,4 +19,5 @@ void paged_attention( const std::optional& query_start_loc, int64_t block_size, int64_t max_seq_len, const std::optional& alibi_slopes, const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale, const std::optional& fp8_out_scale); + torch::Tensor& v_scale, const std::optional& fp8_out_scale, + const std::string& mfma_type); diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index 66bdc448da3ca..c0c4daef64f05 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -48,7 +48,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { " Tensor? alibi_slopes," " str kv_cache_dtype," " Tensor k_scale, Tensor v_scale," - " Tensor? fp8_out_scale) -> ()"); + " Tensor? fp8_out_scale," + " str mfma_type) -> ()"); rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention); } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 56626a02c0277..bc096406c51ae 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -32,6 +32,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { #define stride_tag #endif + ops.def( + "silu_mul_fp8_quant_deep_gemm_cuda(Tensor input, Tensor counts, Tensor! " + "y_q, Tensor! 
y_s, int group_size, " + "bool use_ue8m0, int num_parallel_tokens) -> ()"); + ops.impl("silu_mul_fp8_quant_deep_gemm_cuda", torch::kCUDA, + &silu_mul_fp8_quant_deep_gemm_cuda); + ops.def("weak_ref_tensor(Tensor input) -> Tensor"); ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor); @@ -115,8 +122,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant); -#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ - (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) +#ifndef USE_ROCM ops.def( "silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, " "Tensor input, Tensor input_global_scale) -> ()"); @@ -169,6 +175,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "float epsilon) -> ()"); ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); + // Polynomial Normalization. + ops.def( + "poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, float " + "epsilon) -> ()"); + ops.impl("poly_norm", torch::kCUDA, &poly_norm); + // Apply repetition penalties to logits in-place ops.def( "apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, " @@ -209,16 +221,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor cos_sin_cache, bool is_neox) -> ()"); ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding); - // Apply GPT-NeoX or GPT-J style rotary embedding to query and key - // (supports multiple loras). - ops.def( - "batched_rotary_embedding(Tensor positions, Tensor! query," - " Tensor!? key, int head_size," - " Tensor cos_sin_cache, bool is_neox," - " int rot_dim," - " Tensor cos_sin_cache_offsets) -> ()"); - ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding); - // Quantization ops #ifndef USE_ROCM // Quantized GEMM for AWQ. @@ -508,19 +510,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("cutlass_sparse_compress(Tensor a) -> Tensor[]"); ops.impl("cutlass_sparse_compress", &cutlass_sparse_compress); - // CUTLASS MLA decode - ops.def( - "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe," - " Tensor kv_c_and_k_pe_cache, Tensor seq_lens," - " Tensor page_table, float scale) -> ()"); - ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode); - // SM100 CUTLASS MLA decode ops.def( - "sm100_cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe," - " Tensor kv_c_and_k_pe_cache, Tensor seq_lens," - " Tensor page_table, Tensor workspace, float " - "scale," + "sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope," + " Tensor q_pe, Tensor kv_c_and_k_pe_cache," + " Tensor seq_lens, Tensor page_table," + " Tensor workspace, float scale," " int num_kv_splits) -> ()"); // conditionally compiled so impl in source file @@ -611,6 +606,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "int pad_slot_id) -> ()"); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); + // Hadamard transforms + ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor"); + #ifndef USE_ROCM // Compute per-token-group FP8 quantized tensor and scaling factor. ops.def( @@ -694,16 +692,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor scale) -> ()"); cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla); - cache_ops.def( - "cp_fused_concat_and_cache_mla(Tensor kv_c, Tensor k_pe," - " Tensor cp_local_token_select_indices," - " Tensor! 
kv_cache," - " Tensor slot_mapping," - " str kv_cache_dtype," - " Tensor scale) -> ()"); - cache_ops.impl("cp_fused_concat_and_cache_mla", torch::kCUDA, - &cp_fused_concat_and_cache_mla); - // Convert the key and value cache to fp8 data type. cache_ops.def( "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, " diff --git a/docker/Dockerfile b/docker/Dockerfile index 75e8fa49f86c9..034f73736ca72 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -196,6 +196,7 @@ ARG SCCACHE_S3_NO_CREDENTIALS=0 # Flag to control whether to use pre-built vLLM wheels ARG VLLM_USE_PRECOMPILED="" +ARG VLLM_MAIN_CUDA_VERSION="" # if USE_SCCACHE is set, use sccache to speed up compilation RUN --mount=type=cache,target=/root/.cache/uv \ @@ -213,6 +214,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ && export SCCACHE_IDLE_TIMEOUT=0 \ && export CMAKE_BUILD_TYPE=Release \ && export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" \ + && export VLLM_MAIN_CUDA_VERSION="${VLLM_MAIN_CUDA_VERSION}" \ && export VLLM_DOCKER_BUILD_CONTEXT=1 \ && sccache --show-stats \ && python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38 \ @@ -237,7 +239,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \ # Check the size of the wheel if RUN_WHEEL_CHECK is true COPY .buildkite/check-wheel-size.py check-wheel-size.py # sync the default value with .buildkite/check-wheel-size.py -ARG VLLM_MAX_SIZE_MB=400 +ARG VLLM_MAX_SIZE_MB=450 ENV VLLM_MAX_SIZE_MB=$VLLM_MAX_SIZE_MB ARG RUN_WHEEL_CHECK=true RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \ @@ -261,6 +263,8 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match" # Use copy mode to avoid hardlink failures with Docker cache mounts ENV UV_LINK_MODE=copy +# Install libnuma-dev, required by fastsafetensors (fixes #20384) +RUN apt-get update && apt-get install -y libnuma-dev && rm -rf /var/lib/apt/lists/* COPY requirements/lint.txt requirements/lint.txt COPY requirements/test.txt requirements/test.txt COPY requirements/dev.txt requirements/dev.txt @@ -279,6 +283,10 @@ WORKDIR /vllm-workspace ENV DEBIAN_FRONTEND=noninteractive ARG TARGETPLATFORM +ARG GDRCOPY_CUDA_VERSION=12.8 +# Keep in line with FINAL_BASE_IMAGE +ARG GDRCOPY_OS_VERSION=Ubuntu22_04 + SHELL ["/bin/bash", "-c"] ARG DEADSNAKES_MIRROR_URL @@ -373,7 +381,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist # Install FlashInfer from source ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git" # Keep this in sync with "flashinfer" extra in setup.py -ARG FLASHINFER_GIT_REF="v0.2.14.post1" +ARG FLASHINFER_GIT_REF="v0.3.1" # Flag to control whether to compile FlashInfer AOT kernels # Set to "true" to enable AOT compilation: # docker build --build-arg FLASHINFER_AOT_COMPILE=true ... 
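# Several build args in this Dockerfile are new or updated in this patch
# (VLLM_MAIN_CUDA_VERSION, VLLM_MAX_SIZE_MB=450, FLASHINFER_GIT_REF=v0.3.1,
# FLASHINFER_AOT_COMPILE). A minimal illustrative invocation, assuming a local
# checkout; the image tag is a placeholder and the CUDA version shown is only an
# example value, not one set by this patch:
#   docker build -f docker/Dockerfile \
#     --build-arg VLLM_MAIN_CUDA_VERSION=12.8 \
#     --build-arg VLLM_MAX_SIZE_MB=450 \
#     --build-arg FLASHINFER_AOT_COMPILE=true \
#     -t vllm-custom .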
@@ -437,13 +445,21 @@ COPY tools/install_deepgemm.sh /tmp/install_deepgemm.sh RUN --mount=type=cache,target=/root/.cache/uv \ VLLM_DOCKER_BUILD_CONTEXT=1 /tmp/install_deepgemm.sh --cuda-version "${CUDA_VERSION}" ${DEEPGEMM_GIT_REF:+--ref "$DEEPGEMM_GIT_REF"} -# Install EP kernels(pplx-kernels and DeepEP), NixL +COPY tools/install_gdrcopy.sh install_gdrcopy.sh +RUN set -eux; \ + case "${TARGETPLATFORM}" in \ + linux/arm64) UUARCH="aarch64" ;; \ + linux/amd64) UUARCH="x64" ;; \ + *) echo "Unsupported TARGETPLATFORM: ${TARGETPLATFORM}" >&2; exit 1 ;; \ + esac; \ + ./install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "${GDRCOPY_CUDA_VERSION}" "${UUARCH}"; \ + rm ./install_gdrcopy.sh + +# Install EP kernels(pplx-kernels and DeepEP) COPY tools/ep_kernels/install_python_libraries.sh install_python_libraries.sh -COPY tools/install_nixl.sh install_nixl.sh ENV CUDA_HOME=/usr/local/cuda RUN export TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-9.0a+PTX}" \ - && bash install_python_libraries.sh \ - && bash install_nixl.sh --force + && bash install_python_libraries.sh #################### vLLM installation IMAGE #################### @@ -517,7 +533,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ else \ BITSANDBYTES_VERSION="0.46.1"; \ fi; \ - uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3] + uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' boto3 runai-model-streamer runai-model-streamer[s3] ENV VLLM_USAGE_SOURCE production-docker-image diff --git a/docker/Dockerfile.neuron b/docker/Dockerfile.neuron deleted file mode 100644 index 8bc23554718dc..0000000000000 --- a/docker/Dockerfile.neuron +++ /dev/null @@ -1,56 +0,0 @@ -# default base image -# https://gallery.ecr.aws/neuron/pytorch-inference-neuronx -ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.6.0-neuronx-py310-sdk2.23.0-ubuntu22.04" - -FROM $BASE_IMAGE - -RUN echo "Base image is $BASE_IMAGE" - -# Install some basic utilities -RUN apt-get update && \ - apt-get install -y \ - git \ - python3 \ - python3-pip \ - ffmpeg libsm6 libxext6 libgl1 - -### Mount Point ### -# When launching the container, mount the code directory to /workspace -ARG APP_MOUNT=/workspace -VOLUME [ ${APP_MOUNT} ] -WORKDIR ${APP_MOUNT}/vllm - -RUN python3 -m pip install --upgrade pip -RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas tenacity -RUN python3 -m pip install neuronx-cc==2.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U -RUN python3 -m pip install pytest - -# uninstall transformers-neuronx package explicitly to avoid version conflict -RUN python3 -m pip uninstall -y transformers-neuronx - -COPY . . -ARG GIT_REPO_CHECK=0 -RUN --mount=type=bind,source=.git,target=.git \ - if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi - -RUN python3 -m pip install -U \ - 'cmake>=3.26.1' ninja packaging 'setuptools-scm>=8' wheel jinja2 \ - -r requirements/neuron.txt - -ENV VLLM_TARGET_DEVICE neuron -RUN --mount=type=bind,source=.git,target=.git \ - pip install --no-build-isolation -v -e . 
- -# install development dependencies (for testing) -RUN python3 -m pip install -e tests/vllm_test_utils - -# install transformers-neuronx package as an optional dependencies (for V0) -# FIXME: `--no-deps` argument is temporarily added to resolve transformers package version conflict -RUN python3 -m pip install transformers-neuronx==0.13.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U --no-deps - -RUN python3 -m pip install sentencepiece transformers==4.48.0 -U - -# overwrite entrypoint to run bash script -RUN echo "import subprocess; import sys; subprocess.check_call(sys.argv[1:])" > /usr/local/bin/dockerd-entrypoint.py - -CMD ["/bin/bash"] diff --git a/docker/Dockerfile.nightly_torch b/docker/Dockerfile.nightly_torch index e147b97f0e056..ae12ed0f7cabb 100644 --- a/docker/Dockerfile.nightly_torch +++ b/docker/Dockerfile.nightly_torch @@ -246,7 +246,7 @@ RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2. # build flashinfer for torch nightly from source around 10 mins -# release version: v0.2.2.post1 +# release version: v0.3.1 # todo(elainewy): cache flashinfer build result for faster build ENV CCACHE_DIR=/root/.cache/ccache RUN --mount=type=cache,target=/root/.cache/ccache \ @@ -254,7 +254,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \ echo "git clone flashinfer..." \ && git clone --recursive https://github.com/flashinfer-ai/flashinfer.git \ && cd flashinfer \ - && git checkout v0.2.2.post1 \ + && git checkout v0.3.1 \ && git submodule update --init --recursive \ && echo "finish git clone flashinfer..." \ && rm -rf build \ diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index f164857325043..c8900212e5a1b 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -29,7 +29,10 @@ ARG VLLM_BRANCH="main" ONBUILD RUN git clone ${VLLM_REPO} \ && cd vllm \ && git fetch -v --prune -- origin ${VLLM_BRANCH} \ - && git checkout FETCH_HEAD + && git checkout FETCH_HEAD \ + && if [ ${VLLM_REPO} != "https://github.com/vllm-project/vllm.git" ] ; then \ + git remote add upstream "https://github.com/vllm-project/vllm.git" \ + && git fetch upstream ; fi FROM fetch_vllm_${REMOTE_VLLM} AS fetch_vllm # ----------------------- @@ -47,6 +50,7 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/requirements /requirements COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/benchmarks /benchmarks COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/tests /tests COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/examples /examples +COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/docker/Dockerfile.rocm /docker/ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/.buildkite /.buildkite # ----------------------- @@ -71,7 +75,7 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace RUN cd /vllm-workspace \ && rm -rf vllm \ && python3 -m pip install -e tests/vllm_test_utils \ - && python3 -m pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] \ + && python3 -m pip install lm-eval[api]==0.4.4 \ && python3 -m pip install pytest-shard # ----------------------- @@ -100,8 +104,10 @@ ARG COMMON_WORKDIR # Copy over the benchmark scripts as well COPY --from=export_vllm /benchmarks ${COMMON_WORKDIR}/vllm/benchmarks COPY --from=export_vllm /examples ${COMMON_WORKDIR}/vllm/examples +COPY --from=export_vllm /docker ${COMMON_WORKDIR}/vllm/docker ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 +ENV RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=1 ENV TOKENIZERS_PARALLELISM=false # ENV that can improve 
safe tensor loading, and end-to-end time diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base index 3414c0aa845cb..4973b57f76563 100644 --- a/docker/Dockerfile.rocm_base +++ b/docker/Dockerfile.rocm_base @@ -1,27 +1,23 @@ -ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:6.3.1-complete -ARG HIPBLASLT_BRANCH="db8e93b4" -ARG HIPBLAS_COMMON_BRANCH="7c1566b" -ARG LEGACY_HIPBLASLT_OPTION= -ARG RCCL_BRANCH="648a58d" -ARG RCCL_REPO="https://github.com/ROCm/rccl" -ARG TRITON_BRANCH="e5be006" -ARG TRITON_REPO="https://github.com/triton-lang/triton.git" -ARG PYTORCH_BRANCH="295f2ed4" -ARG PYTORCH_VISION_BRANCH="v0.21.0" -ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" +ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:7.0-complete +ARG TRITON_BRANCH="f9e5bf54" +ARG TRITON_REPO="https://github.com/ROCm/triton.git" +ARG PYTORCH_BRANCH="b2fb6885" +ARG PYTORCH_VISION_BRANCH="v0.23.0" +ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" -ARG FA_BRANCH="1a7f4dfa" +ARG FA_BRANCH="0e60e394" ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" -ARG AITER_BRANCH="916bf3c" +ARG AITER_BRANCH="2ab9f4cd" ARG AITER_REPO="https://github.com/ROCm/aiter.git" FROM ${BASE_IMAGE} AS base -ENV PATH=/opt/rocm/llvm/bin:$PATH +ENV PATH=/opt/rocm/llvm/bin:/opt/rocm/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin ENV ROCM_PATH=/opt/rocm ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib: -ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx1100;gfx1101;gfx1200;gfx1201 +ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx950;gfx1100;gfx1101;gfx1200;gfx1201 ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} +ENV AITER_ROCM_ARCH=gfx942;gfx950 ARG PYTHON_VERSION=3.12 @@ -45,38 +41,7 @@ RUN apt-get update -y \ && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \ && python3 --version && python3 -m pip --version -RUN pip install -U packaging 'cmake<4' ninja wheel setuptools pybind11 Cython - -FROM base AS build_hipblaslt -ARG HIPBLASLT_BRANCH -ARG HIPBLAS_COMMON_BRANCH -# Set to "--legacy_hipblas_direct" for ROCm<=6.2 -ARG LEGACY_HIPBLASLT_OPTION -RUN git clone https://github.com/ROCm/hipBLAS-common.git -RUN cd hipBLAS-common \ - && git checkout ${HIPBLAS_COMMON_BRANCH} \ - && mkdir build \ - && cd build \ - && cmake .. \ - && make package \ - && dpkg -i ./*.deb -RUN git clone https://github.com/ROCm/hipBLASLt -RUN cd hipBLASLt \ - && git checkout ${HIPBLASLT_BRANCH} \ - && apt-get install -y llvm-dev \ - && ./install.sh -dc --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} \ - && cd build/release \ - && make package -RUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install - -FROM base AS build_rccl -ARG RCCL_BRANCH -ARG RCCL_REPO -RUN git clone ${RCCL_REPO} -RUN cd rccl \ - && git checkout ${RCCL_BRANCH} \ - && ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH} -RUN mkdir -p /app/install && cp /app/rccl/build/release/*.deb /app/install +RUN pip install -U packaging 'cmake<4' ninja wheel 'setuptools<80' pybind11 Cython FROM base AS build_triton ARG TRITON_BRANCH @@ -84,9 +49,11 @@ ARG TRITON_REPO RUN git clone ${TRITON_REPO} RUN cd triton \ && git checkout ${TRITON_BRANCH} \ - && cd python \ - && python3 setup.py bdist_wheel --dist-dir=dist -RUN mkdir -p /app/install && cp /app/triton/python/dist/*.whl /app/install + && if [ ! 
-f setup.py ]; then cd python; fi \ + && python3 setup.py bdist_wheel --dist-dir=dist \ + && mkdir -p /app/install && cp dist/*.whl /app/install +RUN if [ -d triton/python/triton_kernels ]; then pip install build && cd triton/python/triton_kernels \ + && python3 -m build --wheel && cp dist/*.whl /app/install; fi FROM base AS build_amdsmi RUN cd /opt/rocm/share/amd_smi \ @@ -129,18 +96,21 @@ RUN cd aiter \ && git checkout ${AITER_BRANCH} \ && git submodule update --init --recursive \ && pip install -r requirements.txt -RUN pip install pyyaml && cd aiter && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py bdist_wheel --dist-dir=dist && ls /app/aiter/dist/*.whl +RUN pip install pyyaml && cd aiter && PREBUILD_KERNELS=1 GPU_ARCHS=${AITER_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist && ls /app/aiter/dist/*.whl RUN mkdir -p /app/install && cp /app/aiter/dist/*.whl /app/install +FROM base AS debs +RUN mkdir /app/debs +RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \ + cp /install/*.whl /app/debs +RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \ + cp /install/*.whl /app/debs +RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \ + cp /install/*.whl /app/debs +RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \ + cp /install/*.whl /app/debs + FROM base AS final -RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \ - dpkg -i /install/*deb \ - && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \ - && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status -RUN --mount=type=bind,from=build_rccl,src=/app/install/,target=/install \ - dpkg -i /install/*deb \ - && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \ - && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \ pip install /install/*.whl RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \ @@ -151,11 +121,6 @@ RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \ pip install /install/*.whl ARG BASE_IMAGE -ARG HIPBLAS_COMMON_BRANCH -ARG HIPBLASLT_BRANCH -ARG LEGACY_HIPBLASLT_OPTION -ARG RCCL_BRANCH -ARG RCCL_REPO ARG TRITON_BRANCH ARG TRITON_REPO ARG PYTORCH_BRANCH @@ -167,11 +132,6 @@ ARG FA_REPO ARG AITER_BRANCH ARG AITER_REPO RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \ - && echo "HIPBLAS_COMMON_BRANCH: ${HIPBLAS_COMMON_BRANCH}" >> /app/versions.txt \ - && echo "HIPBLASLT_BRANCH: ${HIPBLASLT_BRANCH}" >> /app/versions.txt \ - && echo "LEGACY_HIPBLASLT_OPTION: ${LEGACY_HIPBLASLT_OPTION}" >> /app/versions.txt \ - && echo "RCCL_BRANCH: ${RCCL_BRANCH}" >> /app/versions.txt \ - && echo "RCCL_REPO: ${RCCL_REPO}" >> /app/versions.txt \ && echo "TRITON_BRANCH: ${TRITON_BRANCH}" >> /app/versions.txt \ && echo "TRITON_REPO: ${TRITON_REPO}" >> /app/versions.txt \ && echo "PYTORCH_BRANCH: ${PYTORCH_BRANCH}" >> /app/versions.txt \ @@ -179,5 +139,6 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \ && echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \ && echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \ && echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \ + && echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt \ && echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \ - && echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt + && echo 
"AITER_REPO: ${AITER_REPO}" >> /app/versions.txt \ No newline at end of file diff --git a/docker/Dockerfile.s390x b/docker/Dockerfile.s390x index 9270b48c54d4b..9942b7626f81e 100644 --- a/docker/Dockerfile.s390x +++ b/docker/Dockerfile.s390x @@ -16,7 +16,8 @@ ENV LANG=C.UTF-8 \ RUN microdnf install -y \ which procps findutils tar vim git gcc gcc-gfortran g++ make patch zlib-devel \ libjpeg-turbo-devel libtiff-devel libpng-devel libwebp-devel freetype-devel harfbuzz-devel \ - openssl-devel openblas openblas-devel autoconf automake libtool cmake numpy libsndfile && \ + openssl-devel openblas openblas-devel autoconf automake libtool cmake numpy libsndfile \ + clang llvm-devel llvm-static clang-devel && \ microdnf clean all # Python Installation @@ -191,7 +192,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \ -DCOMPILER_RT_BUILD_ORC=OFF \ -DCOMPILER_RT_INCLUDE_TESTS=OFF \ ${CMAKE_ARGS} -GNinja ../llvm \ - && ninja install . && \ # build llvmlite cd ../../llvmlite && python setup.py bdist_wheel && \ @@ -200,6 +200,45 @@ RUN --mount=type=cache,target=/root/.cache/uv \ sed -i '/#include "internal\/pycore_atomic.h"/i\#include "dynamic_annotations.h"' numba/_dispatcher.cpp; \ fi && python setup.py bdist_wheel +# Edit aws-lc-sys to support s390x +FROM python-install AS aws-lc-sys-editor +WORKDIR /tmp +ENV CARGO_HOME=/root/.cargo +ENV RUSTUP_HOME=/root/.rustup +ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH" +ARG AWS_LC_VERSION=v0.30.0 +RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,from=rust,source=/root/.cargo,target=/root/.cargo,rw \ + --mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \ + git clone --recursive https://github.com/aws/aws-lc-rs.git && \ + cd aws-lc-rs && \ + git checkout tags/aws-lc-sys/${AWS_LC_VERSION} && \ + git submodule sync && \ + git submodule update --init --recursive && \ + cd aws-lc-sys && \ + sed -i '682 s/strncmp(buf, "-----END ", 9)/memcmp(buf, "-----END ", 9)/' aws-lc/crypto/pem/pem_lib.c && \ + sed -i '712 s/strncmp(buf, "-----END ", 9)/memcmp(buf, "-----END ", 9)/' aws-lc/crypto/pem/pem_lib.c && \ + sed -i '747 s/strncmp(buf, "-----END ", 9)/memcmp(buf, "-----END ", 9)/' aws-lc/crypto/pem/pem_lib.c + +# Build Outlines Core +FROM python-install AS outlines-core-builder +WORKDIR /tmp +ENV CARGO_HOME=/root/.cargo +ENV RUSTUP_HOME=/root/.rustup +ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH" +ARG OUTLINES_CORE_VERSION=0.2.10 +RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,from=rust,source=/root/.cargo,target=/root/.cargo,rw \ + --mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \ + --mount=type=bind,from=aws-lc-sys-editor,source=/tmp/aws-lc-rs/aws-lc-sys,target=/tmp/aws-lc-sys,rw \ + git clone https://github.com/dottxt-ai/outlines-core.git && \ + cd outlines-core && \ + git checkout tags/${OUTLINES_CORE_VERSION} && \ + sed -i "s/version = \"0.0.0\"/version = \"${OUTLINES_CORE_VERSION}\"/" Cargo.toml && \ + echo '[patch.crates-io]' >> Cargo.toml && \ + echo 'aws-lc-sys = { path = "/tmp/aws-lc-sys" }' >> Cargo.toml && \ + uv pip install maturin && \ + python -m maturin build --release --out dist # Final build stage FROM python-install AS vllm-cpu @@ -230,6 +269,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,from=torch,source=/tmp/pytorch/dist,target=/tmp/torch-wheels/ \ --mount=type=bind,from=numba-builder,source=/tmp/llvmlite/dist,target=/tmp/llvmlite-wheels/ \ --mount=type=bind,from=numba-builder,source=/tmp/numba/dist,target=/tmp/numba-wheels/ \ + 
--mount=type=bind,from=outlines-core-builder,source=/tmp/outlines-core/dist,target=/tmp/outlines-core/dist/ \ sed -i '/^torch/d' requirements/build.txt && \ ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl) && \ VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl) && \ @@ -237,6 +277,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ TORCH_WHL_FILE=$(ls /tmp/torch-wheels/*.whl) && \ LLVM_WHL_FILE=$(ls /tmp/llvmlite-wheels/*.whl) && \ NUMBA_WHL_FILE=$(ls /tmp/numba-wheels/*.whl) && \ + OUTLINES_CORE_WHL_FILE=$(ls /tmp/outlines-core/dist/*.whl) && \ uv pip install -v \ $ARROW_WHL_FILE \ $VISION_WHL_FILE \ @@ -244,6 +285,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ $TORCH_WHL_FILE \ $LLVM_WHL_FILE \ $NUMBA_WHL_FILE \ + $OUTLINES_CORE_WHL_FILE \ --index-strategy unsafe-best-match \ -r requirements/build.txt \ -r requirements/cpu.txt diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu index 65d2e5036b783..ef422352509a9 100644 --- a/docker/Dockerfile.xpu +++ b/docker/Dockerfile.xpu @@ -1,12 +1,10 @@ FROM intel/deep-learning-essentials:2025.1.3-0-devel-ubuntu24.04 AS vllm-base -RUN rm /etc/apt/sources.list.d/intel-graphics.list +RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \ + echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list && \ + add-apt-repository -y ppa:kobuk-team/intel-graphics RUN apt clean && apt-get update -y && \ - apt-get install -y software-properties-common && \ - add-apt-repository ppa:deadsnakes/ppa && \ - apt-get install -y python3.10 python3.10-distutils && \ - curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10 && \ apt-get install -y --no-install-recommends --fix-missing \ curl \ ffmpeg \ @@ -17,17 +15,29 @@ RUN apt clean && apt-get update -y && \ libgl1 \ lsb-release \ numactl \ - python3.10-dev \ - wget + wget \ + vim \ + python3.12 \ + python3.12-dev \ + python3-pip +RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1 +RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.12 1 -RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 -RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1 +RUN apt install -y libze1 libze-dev libze-intel-gpu1 intel-opencl-icd libze-intel-gpu-raytracing + +RUN wget https://github.com/uxlfoundation/oneCCL/releases/download/2021.15.4/intel-oneccl-2021.15.4.11_offline.sh +RUN bash intel-oneccl-2021.15.4.11_offline.sh -a --silent --eula accept && echo "source /opt/intel/oneapi/setvars.sh --force" >> /root/.bashrc +SHELL ["bash", "-c"] +CMD ["bash", "-c", "source /root/.bashrc && exec bash"] WORKDIR /workspace/vllm COPY requirements/xpu.txt /workspace/vllm/requirements/xpu.txt COPY requirements/common.txt /workspace/vllm/requirements/common.txt +# suppress the python externally managed environment error +RUN python3 -m pip config set global.break-system-packages true + RUN --mount=type=cache,target=/root/.cache/pip \ pip install --no-cache-dir \ -r requirements/xpu.txt @@ -54,8 +64,9 @@ FROM vllm-base AS vllm-openai RUN --mount=type=cache,target=/root/.cache/pip \ pip install accelerate hf_transfer pytest pytest_asyncio lm_eval[api] modelscope -ENV VLLM_USAGE_SOURCE production-docker-image \ - TRITON_XPU_PROFILE 1 +RUN --mount=type=cache,target=/root/.cache/pip \ + pip uninstall oneccl oneccl-devel -y 
+ # install development dependencies (for testing) RUN python3 -m pip install -e tests/vllm_test_utils ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] diff --git a/docs/.nav.yml b/docs/.nav.yml index dbac0e12f1bf2..c103ed476d76d 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -32,10 +32,7 @@ nav: - models/pooling_models.md - models/extensions - Hardware Supported Models: models/hardware_supported_models - - Features: - - features/compatibility_matrix.md - - features/* - - features/quantization + - Features: features - Developer Guide: - contributing/README.md - General: @@ -47,11 +44,12 @@ nav: - contributing/model/registration.md - contributing/model/tests.md - contributing/model/multimodal.md + - contributing/model/transcription.md - CI: contributing/ci - Design Documents: design - API Reference: - api/README.md - - api/vllm/* + - api/vllm - CLI Reference: cli - Community: - community/* diff --git a/docs/README.md b/docs/README.md index 683e1d37563f5..ae95717def4cd 100644 --- a/docs/README.md +++ b/docs/README.md @@ -56,7 +56,7 @@ vLLM is flexible and easy to use with: - Tensor, pipeline, data and expert parallelism support for distributed inference - Streaming outputs - OpenAI-compatible API server -- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs, Gaudi® accelerators and GPUs, IBM Power CPUs, TPU, and AWS Trainium and Inferentia Accelerators. +- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend. - Prefix caching support - Multi-LoRA support diff --git a/docs/api/README.md b/docs/api/README.md index 57142e8f5625d..86e310f567dd3 100644 --- a/docs/api/README.md +++ b/docs/api/README.md @@ -14,7 +14,7 @@ API documentation for vLLM's configuration classes. - [vllm.config.LoRAConfig][] - [vllm.config.MultiModalConfig][] - [vllm.config.PoolerConfig][] -- [vllm.config.DecodingConfig][] +- [vllm.config.StructuredOutputsConfig][] - [vllm.config.ObservabilityConfig][] - [vllm.config.KVTransferConfig][] - [vllm.config.CompilationConfig][] @@ -46,7 +46,6 @@ Engine classes for offline and online inference. Inference parameters for vLLM APIs. [](){ #sampling-params } -[](){ #pooling-params } - [vllm.SamplingParams][] - [vllm.PoolingParams][] diff --git a/docs/community/meetups.md b/docs/community/meetups.md index 221a7bd96213f..a3004249b758b 100644 --- a/docs/community/meetups.md +++ b/docs/community/meetups.md @@ -2,6 +2,8 @@ We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below: +- [vLLM Shenzhen Meetup](https://mp.weixin.qq.com/s/k8ZBO1u2_2odgiKWH_GVTQ), August 30th 2025. [[Slides]](https://drive.google.com/drive/folders/1Ua2SVKVSu-wp5vou_6ElraDt2bnKhiEA) +- [vLLM Singapore Meetup](https://www.sginnovate.com/event/vllm-sg-meet), August 27th 2025. [[Slides]](https://drive.google.com/drive/folders/1ncf3GyqLdqFaB6IeB834E5TZJPLAOiXZ?usp=sharing) - [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg), August 23rd 2025. [[Slides]](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH) - [vLLM Korea Meetup](https://luma.com/cgcgprmh), August 19th 2025. [[Slides]](https://drive.google.com/file/d/1bcrrAE1rxUgx0mjIeOWT6hNe2RefC5Hm/view). 
- [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA), August 2nd 2025. [[Slides]](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) [[Recording]](https://www.chaspark.com/#/live/1166916873711665152). diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index 0ab2ae58ad861..5564d8a81d937 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -175,6 +175,7 @@ Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to u Known supported models: - GLM-4.5V GLM-4.1V () +- InternVL () - Kimi-VL () - Llama4 () - MiniCPM-V-2.5 or above (, ) @@ -210,7 +211,7 @@ vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2 !!! note API server scale-out disables [multi-modal IPC caching](#ipc-caching) - because it requires a one-to-one correspondance between API and engine core processes. + because it requires a one-to-one correspondence between API and engine core processes. This does not impact [multi-modal processor caching](#processor-caching). @@ -227,9 +228,23 @@ to avoid repeatedly processing the same multi-modal inputs in `BaseMultiModalPro ### IPC Caching Multi-modal IPC caching is automatically enabled when -there is a one-to-one correspondance between API (`P0`) and engine core (`P1`) processes, +there is a one-to-one correspondence between API (`P0`) and engine core (`P1`) processes, to avoid repeatedly transferring the same multi-modal inputs between them. +#### Key-Replicated Cache + +By default, IPC caching uses a **key-replicated cache**, where cache keys exist +in both the API (`P0`) and engine core (`P1`) processes, but the actual cache +data resides only in `P1`. + +#### Shared Memory Cache + +When multiple worker processes are involved (e.g., when TP > 1), a +**shared-memory cache** is more efficient. This can be enabled by setting +`mm_processor_cache_type="shm"`. In this mode, cache keys are stored +on `P0`, while the cache data itself lives in shared memory accessible by all +processes. + ### Configuration You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` (default 4 GiB). @@ -244,6 +259,12 @@ Examples: llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", mm_processor_cache_gb=8) +# Use a shared-memory based IPC cache +llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", + tensor_parallel_size=2, + mm_processor_cache_type="shm", + mm_processor_cache_gb=8) + # Disable the cache llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", mm_processor_cache_gb=0) @@ -253,11 +274,12 @@ llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", Based on the configuration, the content of the multi-modal caches on `P0` and `P1` are as follows: -| Processor Caching | IPC Caching | `P0` Cache | `P1` Cache | Max. Memory | -|-------------------|-------------|------------|------------|-------------| -| ✅ | ✅ | K | K + V | `mm_processor_cache_gb * data_parallel_size` | -| ✅ | ❌ | K + V | N/A | `mm_processor_cache_gb * api_server_count` | -| ❌ | ❌ | N/A | N/A | `0` | +| mm_processor_cache_type | Cache Type | `P0` Cache | `P1` Engine Cache | `P1` Worker Cache | Max. 
Memory | +|-------------------|-------------|------------|------------|-------------|-------------| +| lru | Processor Caching | K + V | N/A | N/A | `mm_processor_cache_gb * data_parallel_size` | +| lru | Key-Replicated Caching | K | K + V | N/A | `mm_processor_cache_gb * api_server_count` | +| shm | Shared Memory Caching | K | N/A | V | `mm_processor_cache_gb * api_server_count` | +| N/A | Disabled | N/A | N/A | N/A | `0` | K: Stores the hashes of multi-modal items V: Stores the processed tensor data of multi-modal items diff --git a/docs/contributing/README.md b/docs/contributing/README.md index 5a2a70d57e85f..b0a95b3b3d3a5 100644 --- a/docs/contributing/README.md +++ b/docs/contributing/README.md @@ -26,113 +26,123 @@ See . ## Developing ---8<-- "docs/getting_started/installation/python_env_setup.inc.md" - -Depending on the kind of development you'd like to do (e.g. Python, CUDA), you can choose to build vLLM with or without compilation. -Check out the [building from source][build-from-source] documentation for details. - -For an optimized workflow when iterating on C++/CUDA kernels, see the [Incremental Compilation Workflow](./incremental_build.md) for recommendations. - -### Building the docs with MkDocs - -#### Introduction to MkDocs - -[MkDocs](https://github.com/mkdocs/mkdocs) is a fast, simple and downright gorgeous static site generator that's geared towards building project documentation. Documentation source files are written in Markdown, and configured with a single YAML configuration file. - -#### Install MkDocs and Plugins - -Install MkDocs along with the [plugins](https://github.com/vllm-project/vllm/blob/main/mkdocs.yaml) used in the vLLM documentation, as well as required dependencies: - -```bash -uv pip install -r requirements/docs.txt -``` - -!!! note - Ensure that your Python version is compatible with the plugins (e.g., `mkdocs-awesome-nav` requires Python 3.10+) - -#### Verify Installation - -Confirm that MkDocs is correctly installed: - -```bash -mkdocs --version -``` - -Example output: - -```console -mkdocs, version 1.6.1 from /opt/miniconda3/envs/mkdoc/lib/python3.10/site-packages/mkdocs (Python 3.10) -``` - -#### Clone the `vLLM` repository +The first step of contributing to vLLM is to clone the GitHub repository: ```bash git clone https://github.com/vllm-project/vllm.git cd vllm ``` -#### Start the Development Server +Then, configure your Python virtual environment. -MkDocs comes with a built-in dev-server that lets you preview your documentation as you work on it. Make sure you're in the same directory as the `mkdocs.yml` configuration file, and then start the server by running the `mkdocs serve` command: +--8<-- "docs/getting_started/installation/python_env_setup.inc.md" + +If you are only developing vLLM's Python code, install vLLM using: ```bash -mkdocs serve +VLLM_USE_PRECOMPILED=1 uv pip install -e . ``` -Example output: +If you are developing vLLM's Python and CUDA/C++ code, install vLLM using: -```console -INFO - Documentation built in 106.83 seconds -INFO - [22:02:02] Watching paths for changes: 'docs', 'mkdocs.yaml' -INFO - [22:02:02] Serving on http://127.0.0.1:8000/ +```bash +uv pip install -e . ``` -#### View in Your Browser +For more details about installing from source and installing for other hardware, check out the [installation instructions](../getting_started/installation/README.md) for your hardware and head to the "Build wheel from source" section. 
-Open up [http://127.0.0.1:8000/](http://127.0.0.1:8000/) in your browser to see a live preview:. - -#### Learn More - -For additional features and advanced configurations, refer to the official [MkDocs Documentation](https://www.mkdocs.org/). - -## Testing - -??? console "Commands" - - ```bash - # These commands are only for Nvidia CUDA platforms. - uv pip install -r requirements/common.txt -r requirements/dev.txt --torch-backend=auto - - # Linting, formatting and static type checking - pre-commit install - - # You can manually run pre-commit with - pre-commit run --all-files --show-diff-on-failure - - # To manually run something from CI that does not run - # locally by default, you can run: - pre-commit run mypy-3.9 --hook-stage manual --all-files - - # Unit tests - pytest tests/ - - # Run tests for a single test file with detailed output - pytest -s -v tests/test_logger.py - ``` +For an optimized workflow when iterating on C++/CUDA kernels, see the [Incremental Compilation Workflow](./incremental_build.md) for recommendations. !!! tip - Since the ships with Python 3.12, all tests in CI (except `mypy`) are run with Python 3.12. + vLLM is compatible with Python versions 3.9 to 3.12. However, vLLM's default [Dockerfile](gh-file:docker/Dockerfile) ships with Python 3.12 and tests in CI (except `mypy`) are run with Python 3.12. Therefore, we recommend developing with Python 3.12 to minimise the chance of your local environment clashing with our CI environment. -!!! note "Install python3-dev if Python.h is missing" +### Linting + +vLLM uses `pre-commit` to lint and format the codebase. See if `pre-commit` is new to you. Setting up `pre-commit` is as easy as: + +```bash +uv pip install pre-commit +pre-commit install +``` + +vLLM's `pre-commit` hooks will now run automatically every time you commit. + +!!! tip "Tips" + You can manually run the `pre-commit` hooks using: + + ```bash + pre-commit run # runs on staged files + pre-commit run -a # runs on all files (short for --all-files) + ``` + + --- + + Some `pre-commit` hooks only run in CI. If you need to, you can run them locally with: + + ```bash + pre-commit run --hook-stage manual markdownlint + pre-commit run --hook-stage manual mypy-3.9 + ``` + +### Documentation + +MkDocs is a fast, simple and downright gorgeous static site generator that's geared towards building project documentation. Documentation source files are written in Markdown, and configured with a single YAML configuration file, . + +Get started with: + +```bash +uv pip install -r requirements/docs.txt +``` + +!!! tip + Ensure that your Python version is compatible with the plugins + (e.g., `mkdocs-awesome-nav` requires Python 3.10+) + +MkDocs comes with a built-in dev-server that lets you preview your documentation as you work on it. +From the root of the repository, run: + +```bash +mkdocs serve # with API ref (~10 minutes) +API_AUTONAV_EXCLUDE=vllm mkdocs serve # API ref off (~15 seconds) +``` + +Once you see `Serving on http://127.0.0.1:8000/` in the logs, the live preview is ready! +Open in your browser to see it. + +For additional features and advanced configurations, refer to the: + +- [MkDocs documentation](https://www.mkdocs.org/) +- [Material for MkDocs documentation](https://squidfunk.github.io/mkdocs-material/) (the MkDocs theme we use) + +### Testing + +vLLM uses `pytest` to test the codebase. 
+ +```bash +# Install the test dependencies used in CI (CUDA only) +uv pip install -r requirements/common.txt -r requirements/dev.txt --torch-backend=auto + +# Install some common test dependencies (hardware agnostic) +uv pip install pytest pytest-asyncio + +# Run all tests +pytest tests/ + +# Run tests for a single test file with detailed output +pytest -s -v tests/test_logger.py +``` + +!!! tip "Install python3-dev if Python.h is missing" If any of the above commands fails with `Python.h: No such file or directory`, install `python3-dev` with `sudo apt install python3-dev`. -!!! note +!!! warning "Warnings" Currently, the repository is not fully checked by `mypy`. -!!! note + --- + Currently, not all unit tests pass when run on CPU platforms. If you don't have access to a GPU platform to run unit tests locally, rely on the continuous integration system to run the tests for now. @@ -194,8 +204,7 @@ appropriately to indicate the type of change. Please use one of the following: The PR needs to meet the following code quality standards: - We adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html). -- Pass all linter checks. Please use `pre-commit` to format your code. See - if `pre-commit` is new to you. +- Pass all linter checks. - The code needs to be well-documented to ensure future contributors can easily understand the code. - Include sufficient tests to ensure the project stays correct and robust. This diff --git a/docs/contributing/benchmarks.md b/docs/contributing/benchmarks.md index 2bbed778f3c6a..2a03ce1dffd63 100644 --- a/docs/contributing/benchmarks.md +++ b/docs/contributing/benchmarks.md @@ -1,9 +1,787 @@ +--- +toc_depth: 4 +--- + # Benchmark Suites -vLLM contains two sets of benchmarks: +vLLM provides comprehensive benchmarking tools for performance testing and evaluation: -- [Performance benchmarks][performance-benchmarks] -- [Nightly benchmarks][nightly-benchmarks] +- **[Benchmark CLI]**: `vllm bench` CLI tools and specialized benchmark scripts for interactive performance testing +- **[Performance benchmarks][performance-benchmarks]**: Automated CI benchmarks for development +- **[Nightly benchmarks][nightly-benchmarks]**: Comparative benchmarks against alternatives + +[Benchmark CLI]: #benchmark-cli + +## Benchmark CLI + +This section guides you through running benchmark tests with the extensive +datasets supported on vLLM. It's a living document, updated as new features and datasets +become available. + +### Dataset Overview + + + +| Dataset | Online | Offline | Data Path | +|---------|--------|---------|-----------| +| ShareGPT | ✅ | ✅ | `wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json` | +| ShareGPT4V (Image) | ✅ | ✅ | `wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/blob/main/sharegpt4v_instruct_gpt4-vision_cap100k.json`
Note that the images need to be downloaded separately. For example, to download COCO's 2017 Train images:
`wget http://images.cocodataset.org/zips/train2017.zip` | +| ShareGPT4Video (Video) | ✅ | ✅ | `git clone https://huggingface.co/datasets/ShareGPT4Video/ShareGPT4Video` | +| BurstGPT | ✅ | ✅ | `wget https://github.com/HPMLL/BurstGPT/releases/download/v1.1/BurstGPT_without_fails_2.csv` | +| Sonnet (deprecated) | ✅ | ✅ | Local file: `benchmarks/sonnet.txt` | +| Random | ✅ | ✅ | `synthetic` | +| RandomMultiModal (Image/Video) | 🟡 | 🚧 | `synthetic` | +| Prefix Repetition | ✅ | ✅ | `synthetic` | +| HuggingFace-VisionArena | ✅ | ✅ | `lmarena-ai/VisionArena-Chat` | +| HuggingFace-MMVU | ✅ | ✅ | `yale-nlp/MMVU` | +| HuggingFace-InstructCoder | ✅ | ✅ | `likaixin/InstructCoder` | +| HuggingFace-AIMO | ✅ | ✅ | `AI-MO/aimo-validation-aime`, `AI-MO/NuminaMath-1.5`, `AI-MO/NuminaMath-CoT` | +| HuggingFace-Other | ✅ | ✅ | `lmms-lab/LLaVA-OneVision-Data`, `Aeala/ShareGPT_Vicuna_unfiltered` | +| HuggingFace-MTBench | ✅ | ✅ | `philschmid/mt-bench` | +| HuggingFace-Blazedit | ✅ | ✅ | `vdaita/edit_5k_char`, `vdaita/edit_10k_char` | +| Spec Bench | ✅ | ✅ | `wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl` | +| Custom | ✅ | ✅ | Local file: `data.jsonl` | + +Legend: + +- ✅ - supported +- 🟡 - Partial support +- 🚧 - to be supported + +!!! note + HuggingFace dataset's `dataset-name` should be set to `hf`. + For local `dataset-path`, please set `hf-name` to its Hugging Face ID like + + ```bash + --dataset-path /datasets/VisionArena-Chat/ --hf-name lmarena-ai/VisionArena-Chat + ``` + +### Examples + +#### 🚀 Online Benchmark + +
+Show more + +First start serving your model + +```bash +vllm serve NousResearch/Hermes-3-Llama-3.1-8B +``` + +Then run the benchmarking script + +```bash +# download dataset +# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json +vllm bench serve \ + --backend vllm \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --endpoint /v1/completions \ + --dataset-name sharegpt \ + --dataset-path /ShareGPT_V3_unfiltered_cleaned_split.json \ + --num-prompts 10 +``` + +If successful, you will see the following output + +```text +============ Serving Benchmark Result ============ +Successful requests: 10 +Benchmark duration (s): 5.78 +Total input tokens: 1369 +Total generated tokens: 2212 +Request throughput (req/s): 1.73 +Output token throughput (tok/s): 382.89 +Total Token throughput (tok/s): 619.85 +---------------Time to First Token---------------- +Mean TTFT (ms): 71.54 +Median TTFT (ms): 73.88 +P99 TTFT (ms): 79.49 +-----Time per Output Token (excl. 1st token)------ +Mean TPOT (ms): 7.91 +Median TPOT (ms): 7.96 +P99 TPOT (ms): 8.03 +---------------Inter-token Latency---------------- +Mean ITL (ms): 7.74 +Median ITL (ms): 7.70 +P99 ITL (ms): 8.39 +================================================== +``` + +##### Custom Dataset + +If the dataset you want to benchmark is not supported yet in vLLM, even then you can benchmark on it using `CustomDataset`. Your data needs to be in `.jsonl` format and needs to have "prompt" field per entry, e.g., data.jsonl + +```json +{"prompt": "What is the capital of India?"} +{"prompt": "What is the capital of Iran?"} +{"prompt": "What is the capital of China?"} +``` + +```bash +# start server +VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct +``` + +```bash +# run benchmarking script +vllm bench serve --port 9001 --save-result --save-detailed \ + --backend vllm \ + --model meta-llama/Llama-3.1-8B-Instruct \ + --endpoint /v1/completions \ + --dataset-name custom \ + --dataset-path \ + --custom-skip-chat-template \ + --num-prompts 80 \ + --max-concurrency 1 \ + --temperature=0.3 \ + --top-p=0.75 \ + --result-dir "./log/" +``` + +You can skip applying chat template if your data already has it by using `--custom-skip-chat-template`. 
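+
+If you build the prompt list programmatically, a minimal Python sketch such as the one below (the file name and prompts are just placeholders) produces a file in the expected format:
+
+```python
+# Minimal sketch: write a custom benchmark dataset in the expected JSONL format.
+# The file name and prompts are placeholders for your own data.
+import json
+
+prompts = [
+    "What is the capital of India?",
+    "What is the capital of Iran?",
+    "What is the capital of China?",
+]
+
+with open("data.jsonl", "w") as f:
+    for prompt in prompts:
+        f.write(json.dumps({"prompt": prompt}) + "\n")
+```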
+ +##### VisionArena Benchmark for Vision Language Models + +```bash +# need a model with vision capability here +vllm serve Qwen/Qwen2-VL-7B-Instruct +``` + +```bash +vllm bench serve \ + --backend openai-chat \ + --model Qwen/Qwen2-VL-7B-Instruct \ + --endpoint /v1/chat/completions \ + --dataset-name hf \ + --dataset-path lmarena-ai/VisionArena-Chat \ + --hf-split train \ + --num-prompts 1000 +``` + +##### InstructCoder Benchmark with Speculative Decoding + +``` bash +VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \ + --speculative-config $'{"method": "ngram", + "num_speculative_tokens": 5, "prompt_lookup_max": 5, + "prompt_lookup_min": 2}' +``` + +``` bash +vllm bench serve \ + --model meta-llama/Meta-Llama-3-8B-Instruct \ + --dataset-name hf \ + --dataset-path likaixin/InstructCoder \ + --num-prompts 2048 +``` + +##### Spec Bench Benchmark with Speculative Decoding + +``` bash +VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \ + --speculative-config $'{"method": "ngram", + "num_speculative_tokens": 5, "prompt_lookup_max": 5, + "prompt_lookup_min": 2}' +``` + +[SpecBench dataset](https://github.com/hemingkx/Spec-Bench) + +Run all categories: + +``` bash +# Download the dataset using: +# wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl + +vllm bench serve \ + --model meta-llama/Meta-Llama-3-8B-Instruct \ + --dataset-name spec_bench \ + --dataset-path "/data/spec_bench/question.jsonl" \ + --num-prompts -1 +``` + +Available categories include `[writing, roleplay, reasoning, math, coding, extraction, stem, humanities, translation, summarization, qa, math_reasoning, rag]`. + +Run only a specific category like "summarization": + +``` bash +vllm bench serve \ + --model meta-llama/Meta-Llama-3-8B-Instruct \ + --dataset-name spec_bench \ + --dataset-path "/data/spec_bench/question.jsonl" \ + --num-prompts -1 + --spec-bench-category "summarization" +``` + +##### Other HuggingFaceDataset Examples + +```bash +vllm serve Qwen/Qwen2-VL-7B-Instruct +``` + +`lmms-lab/LLaVA-OneVision-Data`: + +```bash +vllm bench serve \ + --backend openai-chat \ + --model Qwen/Qwen2-VL-7B-Instruct \ + --endpoint /v1/chat/completions \ + --dataset-name hf \ + --dataset-path lmms-lab/LLaVA-OneVision-Data \ + --hf-split train \ + --hf-subset "chart2text(cauldron)" \ + --num-prompts 10 +``` + +`Aeala/ShareGPT_Vicuna_unfiltered`: + +```bash +vllm bench serve \ + --backend openai-chat \ + --model Qwen/Qwen2-VL-7B-Instruct \ + --endpoint /v1/chat/completions \ + --dataset-name hf \ + --dataset-path Aeala/ShareGPT_Vicuna_unfiltered \ + --hf-split train \ + --num-prompts 10 +``` + +`AI-MO/aimo-validation-aime`: + +``` bash +vllm bench serve \ + --model Qwen/QwQ-32B \ + --dataset-name hf \ + --dataset-path AI-MO/aimo-validation-aime \ + --num-prompts 10 \ + --seed 42 +``` + +`philschmid/mt-bench`: + +``` bash +vllm bench serve \ + --model Qwen/QwQ-32B \ + --dataset-name hf \ + --dataset-path philschmid/mt-bench \ + --num-prompts 80 +``` + +`vdaita/edit_5k_char` or `vdaita/edit_10k_char`: + +``` bash +vllm bench serve \ + --model Qwen/QwQ-32B \ + --dataset-name hf \ + --dataset-path vdaita/edit_5k_char \ + --num-prompts 90 \ + --blazedit-min-distance 0.01 \ + --blazedit-max-distance 0.99 +``` + +##### Running With Sampling Parameters + +When using OpenAI-compatible backends such as `vllm`, optional sampling +parameters can be specified. 
Example client command: + +```bash +vllm bench serve \ + --backend vllm \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --endpoint /v1/completions \ + --dataset-name sharegpt \ + --dataset-path /ShareGPT_V3_unfiltered_cleaned_split.json \ + --top-k 10 \ + --top-p 0.9 \ + --temperature 0.5 \ + --num-prompts 10 +``` + +##### Running With Ramp-Up Request Rate + +The benchmark tool also supports ramping up the request rate over the +duration of the benchmark run. This can be useful for stress testing the +server or finding the maximum throughput that it can handle, given some latency budget. + +Two ramp-up strategies are supported: + +- `linear`: Increases the request rate linearly from a start value to an end value. +- `exponential`: Increases the request rate exponentially. + +The following arguments can be used to control the ramp-up: + +- `--ramp-up-strategy`: The ramp-up strategy to use (`linear` or `exponential`). +- `--ramp-up-start-rps`: The request rate at the beginning of the benchmark. +- `--ramp-up-end-rps`: The request rate at the end of the benchmark. + +
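+
+For example, the following command (reusing the ShareGPT setup from the online benchmark above; the start and end rates are illustrative) ramps the request rate linearly from 1 to 20 requests per second over the run:
+
+```bash
+vllm bench serve \
+    --backend vllm \
+    --model NousResearch/Hermes-3-Llama-3.1-8B \
+    --endpoint /v1/completions \
+    --dataset-name sharegpt \
+    --dataset-path /ShareGPT_V3_unfiltered_cleaned_split.json \
+    --num-prompts 100 \
+    --ramp-up-strategy linear \
+    --ramp-up-start-rps 1 \
+    --ramp-up-end-rps 20
+```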
+ +#### 📈 Offline Throughput Benchmark + +
+Show more + +```bash +vllm bench throughput \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --dataset-name sonnet \ + --dataset-path vllm/benchmarks/sonnet.txt \ + --num-prompts 10 +``` + +If successful, you will see the following output + +```text +Throughput: 7.15 requests/s, 4656.00 total tokens/s, 1072.15 output tokens/s +Total num prompt tokens: 5014 +Total num output tokens: 1500 +``` + +##### VisionArena Benchmark for Vision Language Models + +```bash +vllm bench throughput \ + --model Qwen/Qwen2-VL-7B-Instruct \ + --backend vllm-chat \ + --dataset-name hf \ + --dataset-path lmarena-ai/VisionArena-Chat \ + --num-prompts 1000 \ + --hf-split train +``` + +The `num prompt tokens` now includes image token counts + +```text +Throughput: 2.55 requests/s, 4036.92 total tokens/s, 326.90 output tokens/s +Total num prompt tokens: 14527 +Total num output tokens: 1280 +``` + +##### InstructCoder Benchmark with Speculative Decoding + +``` bash +VLLM_WORKER_MULTIPROC_METHOD=spawn \ +VLLM_USE_V1=1 \ +vllm bench throughput \ + --dataset-name=hf \ + --dataset-path=likaixin/InstructCoder \ + --model=meta-llama/Meta-Llama-3-8B-Instruct \ + --input-len=1000 \ + --output-len=100 \ + --num-prompts=2048 \ + --async-engine \ + --speculative-config $'{"method": "ngram", + "num_speculative_tokens": 5, "prompt_lookup_max": 5, + "prompt_lookup_min": 2}' +``` + +```text +Throughput: 104.77 requests/s, 23836.22 total tokens/s, 10477.10 output tokens/s +Total num prompt tokens: 261136 +Total num output tokens: 204800 +``` + +##### Other HuggingFaceDataset Examples + +`lmms-lab/LLaVA-OneVision-Data`: + +```bash +vllm bench throughput \ + --model Qwen/Qwen2-VL-7B-Instruct \ + --backend vllm-chat \ + --dataset-name hf \ + --dataset-path lmms-lab/LLaVA-OneVision-Data \ + --hf-split train \ + --hf-subset "chart2text(cauldron)" \ + --num-prompts 10 +``` + +`Aeala/ShareGPT_Vicuna_unfiltered`: + +```bash +vllm bench throughput \ + --model Qwen/Qwen2-VL-7B-Instruct \ + --backend vllm-chat \ + --dataset-name hf \ + --dataset-path Aeala/ShareGPT_Vicuna_unfiltered \ + --hf-split train \ + --num-prompts 10 +``` + +`AI-MO/aimo-validation-aime`: + +```bash +vllm bench throughput \ + --model Qwen/QwQ-32B \ + --backend vllm \ + --dataset-name hf \ + --dataset-path AI-MO/aimo-validation-aime \ + --hf-split train \ + --num-prompts 10 +``` + +Benchmark with LoRA adapters: + +``` bash +# download dataset +# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json +vllm bench throughput \ + --model meta-llama/Llama-2-7b-hf \ + --backend vllm \ + --dataset_path /ShareGPT_V3_unfiltered_cleaned_split.json \ + --dataset_name sharegpt \ + --num-prompts 10 \ + --max-loras 2 \ + --max-lora-rank 8 \ + --enable-lora \ + --lora-path yard1/llama-2-7b-sql-lora-test +``` + +
+ +#### 🛠️ Structured Output Benchmark + +
+Show more + +Benchmark the performance of structured output generation (JSON, grammar, regex). + +##### Server Setup + +```bash +vllm serve NousResearch/Hermes-3-Llama-3.1-8B +``` + +##### JSON Schema Benchmark + +```bash +python3 benchmarks/benchmark_serving_structured_output.py \ + --backend vllm \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --dataset json \ + --structured-output-ratio 1.0 \ + --request-rate 10 \ + --num-prompts 1000 +``` + +##### Grammar-based Generation Benchmark + +```bash +python3 benchmarks/benchmark_serving_structured_output.py \ + --backend vllm \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --dataset grammar \ + --structure-type grammar \ + --request-rate 10 \ + --num-prompts 1000 +``` + +##### Regex-based Generation Benchmark + +```bash +python3 benchmarks/benchmark_serving_structured_output.py \ + --backend vllm \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --dataset regex \ + --request-rate 10 \ + --num-prompts 1000 +``` + +##### Choice-based Generation Benchmark + +```bash +python3 benchmarks/benchmark_serving_structured_output.py \ + --backend vllm \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --dataset choice \ + --request-rate 10 \ + --num-prompts 1000 +``` + +##### XGrammar Benchmark Dataset + +```bash +python3 benchmarks/benchmark_serving_structured_output.py \ + --backend vllm \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --dataset xgrammar_bench \ + --request-rate 10 \ + --num-prompts 1000 +``` + +
+ +#### 📚 Long Document QA Benchmark + +
+Show more + +Benchmark the performance of long document question-answering with prefix caching. + +##### Basic Long Document QA Test + +```bash +python3 benchmarks/benchmark_long_document_qa_throughput.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-documents 16 \ + --document-length 2000 \ + --output-len 50 \ + --repeat-count 5 +``` + +##### Different Repeat Modes + +```bash +# Random mode (default) - shuffle prompts randomly +python3 benchmarks/benchmark_long_document_qa_throughput.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-documents 8 \ + --document-length 3000 \ + --repeat-count 3 \ + --repeat-mode random + +# Tile mode - repeat entire prompt list in sequence +python3 benchmarks/benchmark_long_document_qa_throughput.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-documents 8 \ + --document-length 3000 \ + --repeat-count 3 \ + --repeat-mode tile + +# Interleave mode - repeat each prompt consecutively +python3 benchmarks/benchmark_long_document_qa_throughput.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-documents 8 \ + --document-length 3000 \ + --repeat-count 3 \ + --repeat-mode interleave +``` + +
+ +#### 🗂️ Prefix Caching Benchmark + +
+Show more + +Benchmark the efficiency of automatic prefix caching. + +##### Fixed Prompt with Prefix Caching + +```bash +python3 benchmarks/benchmark_prefix_caching.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-prompts 1 \ + --repeat-count 100 \ + --input-length-range 128:256 +``` + +##### ShareGPT Dataset with Prefix Caching + +```bash +# download dataset +# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + +python3 benchmarks/benchmark_prefix_caching.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --dataset-path /path/ShareGPT_V3_unfiltered_cleaned_split.json \ + --enable-prefix-caching \ + --num-prompts 20 \ + --repeat-count 5 \ + --input-length-range 128:256 +``` + +##### Prefix Repetition Dataset + +```bash +vllm bench serve \ + --backend openai \ + --model meta-llama/Llama-2-7b-chat-hf \ + --dataset-name prefix_repetition \ + --num-prompts 100 \ + --prefix-repetition-prefix-len 512 \ + --prefix-repetition-suffix-len 128 \ + --prefix-repetition-num-prefixes 5 \ + --prefix-repetition-output-len 128 +``` + +
+ +#### ⚡ Request Prioritization Benchmark + +
+Show more + +Benchmark the performance of request prioritization in vLLM. + +##### Basic Prioritization Test + +```bash +python3 benchmarks/benchmark_prioritization.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --input-len 128 \ + --output-len 64 \ + --num-prompts 100 \ + --scheduling-policy priority +``` + +##### Multiple Sequences per Prompt + +```bash +python3 benchmarks/benchmark_prioritization.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --input-len 128 \ + --output-len 64 \ + --num-prompts 100 \ + --scheduling-policy priority \ + --n 2 +``` + +
+ +#### 👁️ Multi-Modal Benchmark + +
+Show more + +Benchmark the performance of multi-modal requests in vLLM. + +##### Images (ShareGPT4V) + +Start vLLM: + +```bash +python -m vllm.entrypoints.openai.api_server \ + --model Qwen/Qwen2.5-VL-7B-Instruct \ + --dtype bfloat16 \ + --limit-mm-per-prompt '{"image": 1}' \ + --allowed-local-media-path /path/to/sharegpt4v/images +``` + +Send requests with images: + +```bash +vllm bench serve \ + --backend openai-chat \ + --model Qwen/Qwen2.5-VL-7B-Instruct \ + --dataset-name sharegpt \ + --dataset-path /path/to/ShareGPT4V/sharegpt4v_instruct_gpt4-vision_cap100k.json \ + --num-prompts 100 \ + --save-result \ + --result-dir ~/vllm_benchmark_results \ + --save-detailed \ + --endpoint /v1/chat/completion +``` + +##### Videos (ShareGPT4Video) + +Start vLLM: + +```bash +python -m vllm.entrypoints.openai.api_server \ + --model Qwen/Qwen2.5-VL-7B-Instruct \ + --dtype bfloat16 \ + --limit-mm-per-prompt '{"video": 1}' \ + --allowed-local-media-path /path/to/sharegpt4video/videos +``` + +Send requests with videos: + +```bash +vllm bench serve \ + --backend openai-chat \ + --model Qwen/Qwen2.5-VL-7B-Instruct \ + --dataset-name sharegpt \ + --dataset-path /path/to/ShareGPT4Video/llava_v1_5_mix665k_with_video_chatgpt72k_share4video28k.json \ + --num-prompts 100 \ + --save-result \ + --result-dir ~/vllm_benchmark_results \ + --save-detailed \ + --endpoint /v1/chat/completion +``` + +##### Synthetic Random Images (random-mm) + +Generate synthetic image inputs alongside random text prompts to stress-test vision models without external datasets. + +Notes: + +- Works only with online benchmark via the OpenAI backend (`--backend openai-chat`) and endpoint `/v1/chat/completions`. +- Video sampling is not yet implemented. + +Start the server (example): + +```bash +vllm serve Qwen/Qwen2.5-VL-3B-Instruct \ + --dtype bfloat16 \ + --max-model-len 16384 \ + --limit-mm-per-prompt '{"image": 3, "video": 0}' \ + --mm-processor-kwargs max_pixels=1003520 +``` + +Benchmark. It is recommended to use the flag `--ignore-eos` to simulate real responses. You can set the size of the output via the arg `random-output-len`. + +Ex.1: Fixed number of items and a single image resolution, enforcing generation of approx 40 tokens: + +```bash +vllm bench serve \ + --backend openai-chat \ + --model Qwen/Qwen2.5-VL-3B-Instruct \ + --endpoint /v1/chat/completions \ + --dataset-name random-mm \ + --num-prompts 100 \ + --max-concurrency 10 \ + --random-prefix-len 25 \ + --random-input-len 300 \ + --random-output-len 40 \ + --random-range-ratio 0.2 \ + --random-mm-base-items-per-request 2 \ + --random-mm-limit-mm-per-prompt '{"image": 3, "video": 0}' \ + --random-mm-bucket-config '{(224, 224, 1): 1.0}' \ + --request-rate inf \ + --ignore-eos \ + --seed 42 +``` + +The number of items per request can be controlled by passing multiple image buckets: + +```bash + --random-mm-base-items-per-request 2 \ + --random-mm-num-mm-items-range-ratio 0.5 \ + --random-mm-limit-mm-per-prompt '{"image": 4, "video": 0}' \ + --random-mm-bucket-config '{(256, 256, 1): 0.7, (720, 1280, 1): 0.3}' \ +``` + +Flags specific to `random-mm`: + +- `--random-mm-base-items-per-request`: base number of multimodal items per request. +- `--random-mm-num-mm-items-range-ratio`: vary item count uniformly in the closed integer range [floor(n·(1−r)), ceil(n·(1+r))]. Set r=0 to keep it fixed; r=1 allows 0 items. +- `--random-mm-limit-mm-per-prompt`: per-modality hard caps, e.g. '{"image": 3, "video": 0}'. +- `--random-mm-bucket-config`: dict mapping (H, W, T) → probability. 
Entries with probability 0 are removed; remaining probabilities are renormalized to sum to 1. Use T=1 for images. Set any T>1 for videos (video sampling not yet supported). + +Behavioral notes: + +- If the requested base item count cannot be satisfied under the provided per-prompt limits, the tool raises an error rather than silently clamping. + +How sampling works: + +- Determine per-request item count k by sampling uniformly from the integer range defined by `--random-mm-base-items-per-request` and `--random-mm-num-mm-items-range-ratio`, then clamp k to at most the sum of per-modality limits. +- For each of the k items, sample a bucket (H, W, T) according to the normalized probabilities in `--random-mm-bucket-config`, while tracking how many items of each modality have been added. +- If a modality (e.g., image) reaches its limit from `--random-mm-limit-mm-per-prompt`, all buckets of that modality are excluded and the remaining bucket probabilities are renormalized before continuing. +This should be seen as an edge case, and if this behavior can be avoided by setting `--random-mm-limit-mm-per-prompt` to a large number. Note that this might result in errors due to engine config `--limit-mm-per-prompt`. +- The resulting request contains synthetic image data in `multi_modal_data` (OpenAI Chat format). When `random-mm` is used with the OpenAI Chat backend, prompts remain text and MM content is attached via `multi_modal_data`. + +
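+
+As an illustration only (this is not the benchmark's actual implementation), the following standalone Python sketch mirrors the sampling procedure described above, using placeholder bucket probabilities, per-modality limits, and item counts:
+
+```python
+# Illustrative sketch of the random-mm sampling procedure described above.
+# Not vLLM's implementation; all values below are placeholders.
+import math
+import random
+
+bucket_probs = {(256, 256, 1): 0.7, (720, 1280, 1): 0.3}  # (H, W, T) -> probability
+limits = {"image": 3, "video": 0}                          # per-modality caps
+base_items, range_ratio = 2, 0.5
+
+def modality(bucket):
+    # T == 1 means an image bucket; T > 1 would be a video bucket.
+    return "image" if bucket[2] == 1 else "video"
+
+# Sample k uniformly from [floor(n*(1-r)), ceil(n*(1+r))], clamped to the sum of limits.
+low = math.floor(base_items * (1 - range_ratio))
+high = math.ceil(base_items * (1 + range_ratio))
+k = min(random.randint(low, high), sum(limits.values()))
+
+counts = {m: 0 for m in limits}
+items = []
+for _ in range(k):
+    # Drop buckets whose modality already hit its cap, then renormalize the rest.
+    available = {b: p for b, p in bucket_probs.items() if counts[modality(b)] < limits[modality(b)]}
+    if not available:
+        break
+    weights = [p / sum(available.values()) for p in available.values()]
+    bucket = random.choices(list(available), weights=weights)[0]
+    counts[modality(bucket)] += 1
+    items.append(bucket)
+
+print(items)  # e.g. [(256, 256, 1), (720, 1280, 1)]
+```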
[](){ #performance-benchmarks } @@ -11,9 +789,39 @@ vLLM contains two sets of benchmarks: The performance benchmarks are used for development to confirm whether new changes improve performance under various workloads. They are triggered on every commit with both the `perf-benchmarks` and `ready` labels, and when a PR is merged into vLLM. +### Manually Trigger the benchmark + +Use [vllm-ci-test-repo images](https://gallery.ecr.aws/q9t5s3a7/vllm-ci-test-repo) with vLLM benchmark suite. +For CPU environment, please use the image with "-cpu" postfix. + +Here is an example for docker run command for CPU. + +```bash +docker run -it --entrypoint /bin/bash -v /data/huggingface:/root/.cache/huggingface -e HF_TOKEN='' --shm-size=16g --name vllm-cpu-ci public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:1da94e673c257373280026f75ceb4effac80e892-cpu +``` + +Then, run below command inside the docker instance. + +```bash +bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh +``` + +When run, benchmark script generates results under **benchmark/results** folder, along with the benchmark_results.md and benchmark_results.json. + +#### Runtime environment variables + +- `ON_CPU`: set the value to '1' on Intel® Xeon® Processors. Default value is 0. +- `SERVING_JSON`: JSON file to use for the serving tests. Default value is empty string (use default file). +- `LATENCY_JSON`: JSON file to use for the latency tests. Default value is empty string (use default file). +- `THROUGHPUT_JSON`: JSON file to use for the throughout tests. Default value is empty string (use default file). +- `REMOTE_HOST`: IP for the remote vLLM service to benchmark. Default value is empty string. +- `REMOTE_PORT`: Port for the remote vLLM service to benchmark. Default value is empty string. + +For more results visualization, check the [visualizing the results](https://github.com/intel-ai-tce/vllm/blob/more_cpu_models/.buildkite/nightly-benchmarks/README.md#visualizing-the-results). + The latest performance results are hosted on the public [vLLM Performance Dashboard](https://hud.pytorch.org/benchmark/llms?repoName=vllm-project%2Fvllm). -More information on the performance benchmarks and their parameters can be found [here](gh-file:.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md). +More information on the performance benchmarks and their parameters can be found in [Benchmark README](https://github.com/intel-ai-tce/vllm/blob/more_cpu_models/.buildkite/nightly-benchmarks/README.md) and [performance benchmark description](gh-file:.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md). [](){ #nightly-benchmarks } diff --git a/docs/contributing/incremental_build.md b/docs/contributing/incremental_build.md index 0e34e69245afb..cc01a60ce1e7f 100644 --- a/docs/contributing/incremental_build.md +++ b/docs/contributing/incremental_build.md @@ -40,6 +40,16 @@ python tools/generate_cmake_presets.py The script will prompt you if it cannot automatically determine certain paths (e.g., `nvcc` or a specific Python executable for your vLLM development environment). Follow the on-screen prompts. If an existing `CMakeUserPresets.json` is found, the script will ask for confirmation before overwriting it. 
+**Force overwrite existing file:** + +To automatically overwrite an existing `CMakeUserPresets.json` without prompting, use the `--force-overwrite` flag: + +```console +python tools/generate_cmake_presets.py --force-overwrite +``` + +This is particularly useful in automated scripts or CI/CD environments where interactive prompts are not desired. + After running the script, a `CMakeUserPresets.json` file will be created in the root of your vLLM repository. ### Example `CMakeUserPresets.json` diff --git a/docs/contributing/model/README.md b/docs/contributing/model/README.md index 0ca77fa499db7..36068bc14876b 100644 --- a/docs/contributing/model/README.md +++ b/docs/contributing/model/README.md @@ -3,7 +3,7 @@ !!! important Many decoder language models can now be automatically loaded using the [Transformers backend][transformers-backend] without having to implement them in vLLM. See if `vllm serve ` works first! -vLLM models are specialized [PyTorch](https://pytorch.org/) models that take advantage of various [features](../../features/compatibility_matrix.md) to optimize their performance. +vLLM models are specialized [PyTorch](https://pytorch.org/) models that take advantage of various [features](../../features/README.md#compatibility-matrix) to optimize their performance. The complexity of integrating a model into vLLM depends heavily on the model's architecture. The process is considerably straightforward if the model shares a similar architecture with an existing model in vLLM. @@ -15,6 +15,7 @@ Read through these pages for a step-by-step guide: - [Registering a Model](registration.md) - [Unit Testing](tests.md) - [Multi-Modal Support](multimodal.md) +- [Speech-to-Text Support](transcription.md) !!! tip If you are encountering issues while integrating your model into vLLM, feel free to open a [GitHub issue](https://github.com/vllm-project/vllm/issues) diff --git a/docs/contributing/model/multimodal.md b/docs/contributing/model/multimodal.md index dc742c8fcf2cd..87d34d207cde3 100644 --- a/docs/contributing/model/multimodal.md +++ b/docs/contributing/model/multimodal.md @@ -840,7 +840,6 @@ Some HF processors directly insert feature tokens without replacing anything in Examples: - BLIP-2 (insert at start of prompt): -- Florence2 (insert at start of prompt): - Molmo (insert after `<|endoftext|>` token): ### Handling prompt updates unrelated to multi-modal data diff --git a/docs/contributing/model/transcription.md b/docs/contributing/model/transcription.md new file mode 100644 index 0000000000000..62e58e5c6ac58 --- /dev/null +++ b/docs/contributing/model/transcription.md @@ -0,0 +1,276 @@ +# Speech-to-Text (Transcription/Translation) Support + +This document walks you through the steps to add support for speech-to-text (ASR) models to vLLM’s transcription and translation APIs by implementing [SupportsTranscription][vllm.model_executor.models.interfaces.SupportsTranscription]. +Please refer to the [supported models](../../models/supported_models.md#transcription) for further guidance. + +## Update the base vLLM model + +It is assumed you have already implemented your model in vLLM according to the basic model guide. Extend your model with the [SupportsTranscription][vllm.model_executor.models.interfaces.SupportsTranscription] interface and implement the following class attributes and methods. + +### `supported_languages` and `supports_transcription_only` + +Declare supported languages and capabilities: + +- The `supported_languages` mapping is validated at init time. 
+- Set `supports_transcription_only=True` if the model should not serve text generation (eg Whisper). + +??? code "supported_languages and supports_transcription_only" + ```python + from typing import ClassVar, Mapping, Optional, Literal + import numpy as np + import torch + from torch import nn + + from vllm.config import ModelConfig, SpeechToTextConfig + from vllm.inputs.data import PromptType + from vllm.model_executor.models.interfaces import SupportsTranscription + + class YourASRModel(nn.Module, SupportsTranscription): + # Map of ISO 639-1 language codes to language names + supported_languages: ClassVar[Mapping[str, str]] = { + "en": "English", + "it": "Italian", + # ... add more as needed + } + + # If your model only supports audio-conditioned generation + # (no text-only generation), enable this flag. + supports_transcription_only: ClassVar[bool] = True + ``` + +Provide an ASR configuration via [get_speech_to_text_config][vllm.model_executor.models.interfaces.SupportsTranscription.get_speech_to_text_config]. + +This is for controlling general behavior of the API when serving your model: + +??? code "get_speech_to_text_config()" + ```python + class YourASRModel(nn.Module, SupportsTranscription): + ... + + @classmethod + def get_speech_to_text_config( + cls, + model_config: ModelConfig, + task_type: Literal["transcribe", "translate"], + ) -> SpeechToTextConfig: + return SpeechToTextConfig( + sample_rate=16_000, + max_audio_clip_s=30, + # Set to None to disable server-side chunking if your + # model/processor handles it already + min_energy_split_window_size=None, + ) + ``` + +See [Audio preprocessing and chunking](#audio-preprocessing-and-chunking) for what each field controls. + +Implement the prompt construction via [get_generation_prompt][vllm.model_executor.models.interfaces.SupportsTranscription.get_generation_prompt]. The server passes you the resampled waveform and task parameters; you return a valid [PromptType][vllm.inputs.data.PromptType]. There are two common patterns: + +#### Multimodal LLM with audio embeddings (e.g., Voxtral, Gemma3n) + +Return a dict containing `multi_modal_data` with the audio, and either a `prompt` string or `prompt_token_ids`: + +??? code "get_generation_prompt()" + ```python + class YourASRModel(nn.Module, SupportsTranscription): + ... + + @classmethod + def get_generation_prompt( + cls, + audio: np.ndarray, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + language: Optional[str], + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: Optional[str], + ) -> PromptType: + # Example with a free-form instruction prompt + task_word = "Transcribe" if task_type == "transcribe" else "Translate" + prompt = ( + "user\n" + f"{task_word} this audio: " + "\nmodel\n" + ) + + return { + "multi_modal_data": {"audio": (audio, stt_config.sample_rate)}, + "prompt": prompt, + } + ``` + + For further clarification on multi modal inputs, please refer to [Multi-Modal Inputs](../../features/multimodal_inputs.md). + +#### Encoder–decoder audio-only (e.g., Whisper) + +Return a dict with separate `encoder_prompt` and `decoder_prompt` entries: + +??? code "get_generation_prompt()" + ```python + class YourASRModel(nn.Module, SupportsTranscription): + ... 
+ + @classmethod + def get_generation_prompt( + cls, + audio: np.ndarray, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + language: Optional[str], + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: Optional[str], + ) -> PromptType: + if language is None: + raise ValueError("Language must be specified") + + prompt = { + "encoder_prompt": { + "prompt": "", + "multi_modal_data": { + "audio": (audio, stt_config.sample_rate), + }, + }, + "decoder_prompt": ( + (f"<|prev|>{request_prompt}" if request_prompt else "") + + f"<|startoftranscript|><|{language}|>" + + f"<|{task_type}|><|notimestamps|>" + ), + } + return cast(PromptType, prompt) + ``` + +### `validate_language` (optional) + +Language validation via [validate_language][vllm.model_executor.models.interfaces.SupportsTranscription.validate_language] + +If your model requires a language and you want a default, override this method (see Whisper): + +??? code "validate_language()" + ```python + @classmethod + def validate_language(cls, language: Optional[str]) -> Optional[str]: + if language is None: + logger.warning( + "Defaulting to language='en'. If you wish to transcribe audio in a different language, pass the `language` field.") + language = "en" + return super().validate_language(language) + ``` + +### `get_num_audio_tokens` (optional) + +Token accounting for streaming via [get_num_audio_tokens][vllm.model_executor.models.interfaces.SupportsTranscription.get_num_audio_tokens] + +Provide a fast duration→token estimate to improve streaming usage statistics: + +??? code "get_num_audio_tokens()" + ```python + class YourASRModel(nn.Module, SupportsTranscription): + ... + + @classmethod + def get_num_audio_tokens( + cls, + audio_duration_s: float, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + ) -> Optional[int]: + # Return None if unknown; otherwise return an estimate. + return int(audio_duration_s * stt_config.sample_rate // 320) # example + ``` + +## Audio preprocessing and chunking + +The API server takes care of basic audio I/O and optional chunking before building prompts: + +- Resampling: Input audio is resampled to `SpeechToTextConfig.sample_rate` using `librosa`. +- Chunking: If `SpeechToTextConfig.allow_audio_chunking` is True and the duration exceeds `max_audio_clip_s`, the server splits the audio into overlapping chunks and generates a prompt per chunk. Overlap is controlled by `overlap_chunk_second`. +- Energy-aware splitting: When `min_energy_split_window_size` is set, the server finds low-energy regions to minimize cutting within words. + +Relevant server logic: + +??? code "_preprocess_speech_to_text()" + ```python + # vllm/entrypoints/openai/speech_to_text.py + async def _preprocess_speech_to_text(...): + language = self.model_cls.validate_language(request.language) + ... 
+ y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate) + duration = librosa.get_duration(y=y, sr=sr) + do_split_audio = (self.asr_config.allow_audio_chunking + and duration > self.asr_config.max_audio_clip_s) + chunks = [y] if not do_split_audio else self._split_audio(y, int(sr)) + prompts = [] + for chunk in chunks: + prompt = self.model_cls.get_generation_prompt( + audio=chunk, + stt_config=self.asr_config, + model_config=self.model_config, + language=language, + task_type=self.task_type, + request_prompt=request.prompt, + to_language=to_language, + ) + prompts.append(prompt) + return prompts, duration + ``` + +## Exposing tasks automatically + +vLLM automatically advertises transcription support if your model implements the interface: + +```python +if supports_transcription(model): + if model.supports_transcription_only: + return ["transcription"] + supported_tasks.append("transcription") +``` + +When enabled, the server initializes the transcription and translation handlers: + +```python +state.openai_serving_transcription = OpenAIServingTranscription(...) if "transcription" in supported_tasks else None +state.openai_serving_translation = OpenAIServingTranslation(...) if "transcription" in supported_tasks else None +``` + +No extra registration is required beyond having your model class available via the model registry and implementing `SupportsTranscription`. + +## Examples in-tree + +- Whisper encoder–decoder (audio-only): +- Voxtral decoder-only (audio embeddings + LLM): +- Gemma3n decoder-only with fixed instruction prompt: + +## Test with the API + +Once your model implements `SupportsTranscription`, you can test the endpoints (API mimics OpenAI): + +- Transcription (ASR): + + ```bash + curl -s -X POST \ + -H "Authorization: Bearer $VLLM_API_KEY" \ + -H "Content-Type: multipart/form-data" \ + -F "file=@/path/to/audio.wav" \ + -F "model=$MODEL_ID" \ + http://localhost:8000/v1/audio/transcriptions + ``` + +- Translation (source → English unless otherwise supported): + + ```bash + curl -s -X POST \ + -H "Authorization: Bearer $VLLM_API_KEY" \ + -H "Content-Type: multipart/form-data" \ + -F "file=@/path/to/audio.wav" \ + -F "model=$MODEL_ID" \ + http://localhost:8000/v1/audio/translations + ``` + +Or check out more examples in . + +!!! note + - If your model handles chunking internally (e.g., via its processor or encoder), set `min_energy_split_window_size=None` in the returned `SpeechToTextConfig` to disable server-side chunking. + - Implementing `get_num_audio_tokens` improves accuracy of streaming usage metrics (`prompt_tokens`) without an extra forward pass. + - For multilingual behavior, keep `supported_languages` aligned with actual model capabilities. diff --git a/docs/contributing/profiling.md b/docs/contributing/profiling.md index dffd62385e017..5b83d93274f0d 100644 --- a/docs/contributing/profiling.md +++ b/docs/contributing/profiling.md @@ -19,7 +19,7 @@ When using `vllm bench serve`, you can enable profiling by passing the `--profil Traces can be visualized using . !!! tip -You can directly call bench module without installing vllm using `python -m vllm.entrypoints.cli.main bench`. + You can directly call bench module without installing vLLM using `python -m vllm.entrypoints.cli.main bench`. !!! tip Only send a few requests through vLLM when profiling, as the traces can get quite large. Also, no need to untar the traces, they can be viewed directly. 
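
For example, a short profiled benchmark run might look like the following. This is a minimal sketch: `VLLM_TORCH_PROFILER_DIR` is the environment variable vLLM uses to decide where to write torch profiler traces, and the model name, dataset choice, request count, and exact `vllm bench serve` flags are illustrative and may differ between versions.

```bash
# Start the server with torch profiling enabled; traces are written to this directory.
VLLM_TORCH_PROFILER_DIR=/tmp/vllm_profile vllm serve Qwen/Qwen2-7B &

# Send only a handful of requests while profiling so the traces stay small.
vllm bench serve \
    --model Qwen/Qwen2-7B \
    --dataset-name random \
    --num-prompts 8 \
    --profile
```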
diff --git a/docs/deployment/frameworks/anything-llm.md b/docs/deployment/frameworks/anything-llm.md index 0b41e73b030cc..40a463a8a596c 100644 --- a/docs/deployment/frameworks/anything-llm.md +++ b/docs/deployment/frameworks/anything-llm.md @@ -1,41 +1,53 @@ -# Anything LLM +# AnythingLLM -[Anything LLM](https://github.com/Mintplex-Labs/anything-llm) is a full-stack application that enables you to turn any document, resource, or piece of content into context that any LLM can use as references during chatting. +[AnythingLLM](https://github.com/Mintplex-Labs/anything-llm) is a full-stack application that enables you to turn any document, resource, or piece of content into context that any LLM can use as references during chatting. It allows you to deploy a large language model (LLM) server with vLLM as the backend, which exposes OpenAI-compatible endpoints. ## Prerequisites -- Setup vLLM environment +Set up the vLLM environment: + +```bash +pip install vllm +``` ## Deploy -- Start the vLLM server with the supported chat completion model, e.g. +1. Start the vLLM server with a supported chat-completion model, for example: -```bash -vllm serve Qwen/Qwen1.5-32B-Chat-AWQ --max-model-len 4096 -``` + ```bash + vllm serve Qwen/Qwen1.5-32B-Chat-AWQ --max-model-len 4096 + ``` -- Download and install [Anything LLM desktop](https://anythingllm.com/desktop). +1. Download and install [AnythingLLM Desktop](https://anythingllm.com/desktop). -- On the bottom left of open settings, AI Providers --> LLM: - - LLM Provider: Generic OpenAI - - Base URL: http://{vllm server host}:{vllm server port}/v1 - - Chat Model Name: `Qwen/Qwen1.5-32B-Chat-AWQ` +1. Configure the AI provider: -![](../../assets/deployment/anything-llm-provider.png) + - At the bottom, click the 🔧 wrench icon -> **Open settings** -> **AI Providers** -> **LLM**. + - Enter the following values: + - LLM Provider: Generic OpenAI + - Base URL: `http://{vllm server host}:{vllm server port}/v1` + - Chat Model Name: `Qwen/Qwen1.5-32B-Chat-AWQ` -- Back to home page, New Workspace --> create `vllm` workspace, and start to chat: + ![set AI providers](../../assets/deployment/anything-llm-provider.png) -![](../../assets/deployment/anything-llm-chat-without-doc.png) +1. Create a workspace: -- Click the upload button: - - upload the doc - - select the doc and move to the workspace - - save and embed + 1. At the bottom, click the ↺ back icon and back to workspaces. + 1. Create a workspace (e.g., `vllm`) and start chatting. -![](../../assets/deployment/anything-llm-upload-doc.png) + ![create a workspace](../../assets/deployment/anything-llm-chat-without-doc.png) -- Chat again: +1. Add a document. -![](../../assets/deployment/anything-llm-chat-with-doc.png) + 1. Click the 📎 attachment icon. + 1. Upload a document. + 1. Select and move the document into your workspace. + 1. Save and embed it. + + ![add a document](../../assets/deployment/anything-llm-upload-doc.png) + +1. Chat using your document as context. 
+ + ![chat with your context](../../assets/deployment/anything-llm-chat-with-doc.png) diff --git a/docs/deployment/frameworks/autogen.md b/docs/deployment/frameworks/autogen.md index c255a85d38401..7517ee771c097 100644 --- a/docs/deployment/frameworks/autogen.md +++ b/docs/deployment/frameworks/autogen.md @@ -4,9 +4,7 @@ ## Prerequisites -- Setup vLLM environment - -- Setup [AutoGen](https://microsoft.github.io/autogen/0.2/docs/installation/) environment +Set up the vLLM and [AutoGen](https://microsoft.github.io/autogen/0.2/docs/installation/) environment: ```bash pip install vllm @@ -18,14 +16,14 @@ pip install -U "autogen-agentchat" "autogen-ext[openai]" ## Deploy -- Start the vLLM server with the supported chat completion model, e.g. +1. Start the vLLM server with the supported chat completion model, e.g. -```bash -python -m vllm.entrypoints.openai.api_server \ - --model mistralai/Mistral-7B-Instruct-v0.2 -``` + ```bash + python -m vllm.entrypoints.openai.api_server \ + --model mistralai/Mistral-7B-Instruct-v0.2 + ``` -- Call it with AutoGen: +1. Call it with AutoGen: ??? code diff --git a/docs/deployment/frameworks/chatbox.md b/docs/deployment/frameworks/chatbox.md index cbca6e6282fc6..002935da56009 100644 --- a/docs/deployment/frameworks/chatbox.md +++ b/docs/deployment/frameworks/chatbox.md @@ -6,27 +6,31 @@ It allows you to deploy a large language model (LLM) server with vLLM as the bac ## Prerequisites -- Setup vLLM environment +Set up the vLLM environment: + +```bash +pip install vllm +``` ## Deploy -- Start the vLLM server with the supported chat completion model, e.g. +1. Start the vLLM server with the supported chat completion model, e.g. -```bash -vllm serve qwen/Qwen1.5-0.5B-Chat -``` + ```bash + vllm serve qwen/Qwen1.5-0.5B-Chat + ``` -- Download and install [Chatbox desktop](https://chatboxai.app/en#download). +1. Download and install [Chatbox desktop](https://chatboxai.app/en#download). -- On the bottom left of settings, Add Custom Provider +1. On the bottom left of settings, Add Custom Provider - API Mode: `OpenAI API Compatible` - Name: vllm - API Host: `http://{vllm server host}:{vllm server port}/v1` - API Path: `/chat/completions` - Model: `qwen/Qwen1.5-0.5B-Chat` -![](../../assets/deployment/chatbox-settings.png) + ![](../../assets/deployment/chatbox-settings.png) -- Go to `Just chat`, and start to chat: +1. Go to `Just chat`, and start to chat: -![](../../assets/deployment/chatbox-chat.png) + ![](../../assets/deployment/chatbox-chat.png) diff --git a/docs/deployment/frameworks/dify.md b/docs/deployment/frameworks/dify.md index 35f02c33cb02b..820ef0cbed9fa 100644 --- a/docs/deployment/frameworks/dify.md +++ b/docs/deployment/frameworks/dify.md @@ -8,44 +8,50 @@ This guide walks you through deploying Dify using a vLLM backend. ## Prerequisites -- Setup vLLM environment -- Install [Docker](https://docs.docker.com/engine/install/) and [Docker Compose](https://docs.docker.com/compose/install/) +Set up the vLLM environment: + +```bash +pip install vllm +``` + +And install [Docker](https://docs.docker.com/engine/install/) and [Docker Compose](https://docs.docker.com/compose/install/). ## Deploy -- Start the vLLM server with the supported chat completion model, e.g. +1. Start the vLLM server with the supported chat completion model, e.g. -```bash -vllm serve Qwen/Qwen1.5-7B-Chat -``` + ```bash + vllm serve Qwen/Qwen1.5-7B-Chat + ``` -- Start the Dify server with docker compose ([details](https://github.com/langgenius/dify?tab=readme-ov-file#quick-start)): +1. 
Start the Dify server with docker compose ([details](https://github.com/langgenius/dify?tab=readme-ov-file#quick-start)): -```bash -git clone https://github.com/langgenius/dify.git -cd dify -cd docker -cp .env.example .env -docker compose up -d -``` + ```bash + git clone https://github.com/langgenius/dify.git + cd dify + cd docker + cp .env.example .env + docker compose up -d + ``` -- Open the browser to access `http://localhost/install`, config the basic login information and login. +1. Open the browser to access `http://localhost/install`, config the basic login information and login. -- In the top-right user menu (under the profile icon), go to Settings, then click `Model Provider`, and locate the `vLLM` provider to install it. +1. In the top-right user menu (under the profile icon), go to Settings, then click `Model Provider`, and locate the `vLLM` provider to install it. + +1. Fill in the model provider details as follows: -- Fill in the model provider details as follows: - **Model Type**: `LLM` - **Model Name**: `Qwen/Qwen1.5-7B-Chat` - **API Endpoint URL**: `http://{vllm_server_host}:{vllm_server_port}/v1` - **Model Name for API Endpoint**: `Qwen/Qwen1.5-7B-Chat` - **Completion Mode**: `Completion` -![](../../assets/deployment/dify-settings.png) + ![](../../assets/deployment/dify-settings.png) -- To create a test chatbot, go to `Studio → Chatbot → Create from Blank`, then select Chatbot as the type: +1. To create a test chatbot, go to `Studio → Chatbot → Create from Blank`, then select Chatbot as the type: -![](../../assets/deployment/dify-create-chatbot.png) + ![](../../assets/deployment/dify-create-chatbot.png) -- Click the chatbot you just created to open the chat interface and start interacting with the model: +1. Click the chatbot you just created to open the chat interface and start interacting with the model: -![](../../assets/deployment/dify-chat.png) + ![](../../assets/deployment/dify-chat.png) diff --git a/docs/deployment/frameworks/haystack.md b/docs/deployment/frameworks/haystack.md index 70b4b48d4543e..836305cf15c42 100644 --- a/docs/deployment/frameworks/haystack.md +++ b/docs/deployment/frameworks/haystack.md @@ -6,7 +6,7 @@ It allows you to deploy a large language model (LLM) server with vLLM as the bac ## Prerequisites -- Setup vLLM and Haystack environment +Set up the vLLM and Haystack environment: ```bash pip install vllm haystack-ai @@ -14,13 +14,13 @@ pip install vllm haystack-ai ## Deploy -- Start the vLLM server with the supported chat completion model, e.g. +1. Start the vLLM server with the supported chat completion model, e.g. -```bash -vllm serve mistralai/Mistral-7B-Instruct-v0.1 -``` + ```bash + vllm serve mistralai/Mistral-7B-Instruct-v0.1 + ``` -- Use the `OpenAIGenerator` and `OpenAIChatGenerator` components in Haystack to query the vLLM server. +1. Use the `OpenAIGenerator` and `OpenAIChatGenerator` components in Haystack to query the vLLM server. ??? code diff --git a/docs/deployment/frameworks/litellm.md b/docs/deployment/frameworks/litellm.md index c7e514f2276e0..0d6c3729911ad 100644 --- a/docs/deployment/frameworks/litellm.md +++ b/docs/deployment/frameworks/litellm.md @@ -13,7 +13,7 @@ And LiteLLM supports all models on VLLM. ## Prerequisites -- Setup vLLM and litellm environment +Set up the vLLM and litellm environment: ```bash pip install vllm litellm @@ -23,13 +23,13 @@ pip install vllm litellm ### Chat completion -- Start the vLLM server with the supported chat completion model, e.g. +1. 
Start the vLLM server with the supported chat completion model, e.g. -```bash -vllm serve qwen/Qwen1.5-0.5B-Chat -``` + ```bash + vllm serve qwen/Qwen1.5-0.5B-Chat + ``` -- Call it with litellm: +1. Call it with litellm: ??? code @@ -51,13 +51,13 @@ vllm serve qwen/Qwen1.5-0.5B-Chat ### Embeddings -- Start the vLLM server with the supported embedding model, e.g. +1. Start the vLLM server with the supported embedding model, e.g. -```bash -vllm serve BAAI/bge-base-en-v1.5 -``` + ```bash + vllm serve BAAI/bge-base-en-v1.5 + ``` -- Call it with litellm: +1. Call it with litellm: ```python from litellm import embedding diff --git a/docs/deployment/frameworks/retrieval_augmented_generation.md b/docs/deployment/frameworks/retrieval_augmented_generation.md index d5f2ec302b6cd..d86ab1600f126 100644 --- a/docs/deployment/frameworks/retrieval_augmented_generation.md +++ b/docs/deployment/frameworks/retrieval_augmented_generation.md @@ -11,7 +11,7 @@ Here are the integrations: ### Prerequisites -- Setup vLLM and langchain environment +Set up the vLLM and langchain environment: ```bash pip install -U vllm \ @@ -22,33 +22,33 @@ pip install -U vllm \ ### Deploy -- Start the vLLM server with the supported embedding model, e.g. +1. Start the vLLM server with the supported embedding model, e.g. -```bash -# Start embedding service (port 8000) -vllm serve ssmits/Qwen2-7B-Instruct-embed-base -``` + ```bash + # Start embedding service (port 8000) + vllm serve ssmits/Qwen2-7B-Instruct-embed-base + ``` -- Start the vLLM server with the supported chat completion model, e.g. +1. Start the vLLM server with the supported chat completion model, e.g. -```bash -# Start chat service (port 8001) -vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001 -``` + ```bash + # Start chat service (port 8001) + vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001 + ``` -- Use the script: +1. Use the script: -- Run the script +1. Run the script -```python -python retrieval_augmented_generation_with_langchain.py -``` + ```python + python retrieval_augmented_generation_with_langchain.py + ``` ## vLLM + llamaindex ### Prerequisites -- Setup vLLM and llamaindex environment +Set up the vLLM and llamaindex environment: ```bash pip install vllm \ @@ -60,24 +60,24 @@ pip install vllm \ ### Deploy -- Start the vLLM server with the supported embedding model, e.g. +1. Start the vLLM server with the supported embedding model, e.g. -```bash -# Start embedding service (port 8000) -vllm serve ssmits/Qwen2-7B-Instruct-embed-base -``` + ```bash + # Start embedding service (port 8000) + vllm serve ssmits/Qwen2-7B-Instruct-embed-base + ``` -- Start the vLLM server with the supported chat completion model, e.g. +1. Start the vLLM server with the supported chat completion model, e.g. -```bash -# Start chat service (port 8001) -vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001 -``` + ```bash + # Start chat service (port 8001) + vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001 + ``` -- Use the script: +1. Use the script: -- Run the script +1. 
Run the script: -```python -python retrieval_augmented_generation_with_llamaindex.py -``` + ```python + python retrieval_augmented_generation_with_llamaindex.py + ``` diff --git a/docs/deployment/frameworks/streamlit.md b/docs/deployment/frameworks/streamlit.md index af0f0690c68e2..c119878f137a4 100644 --- a/docs/deployment/frameworks/streamlit.md +++ b/docs/deployment/frameworks/streamlit.md @@ -6,35 +6,33 @@ It can be quickly integrated with vLLM as a backend API server, enabling powerfu ## Prerequisites -- Setup vLLM environment +Set up the vLLM environment by installing all required packages: + +```bash +pip install vllm streamlit openai +``` ## Deploy -- Start the vLLM server with the supported chat completion model, e.g. +1. Start the vLLM server with a supported chat completion model, e.g. -```bash -vllm serve qwen/Qwen1.5-0.5B-Chat -``` + ```bash + vllm serve Qwen/Qwen1.5-0.5B-Chat + ``` -- Install streamlit and openai: +1. Use the script: -```bash -pip install streamlit openai -``` +1. Start the streamlit web UI and start to chat: -- Use the script: - -- Start the streamlit web UI and start to chat: - -```bash -streamlit run streamlit_openai_chatbot_webserver.py - -# or specify the VLLM_API_BASE or VLLM_API_KEY -VLLM_API_BASE="http://vllm-server-host:vllm-server-port/v1" \ + ```bash streamlit run streamlit_openai_chatbot_webserver.py -# start with debug mode to view more details -streamlit run streamlit_openai_chatbot_webserver.py --logger.level=debug -``` + # or specify the VLLM_API_BASE or VLLM_API_KEY + VLLM_API_BASE="http://vllm-server-host:vllm-server-port/v1" \ + streamlit run streamlit_openai_chatbot_webserver.py -![](../../assets/deployment/streamlit-chat.png) + # start with debug mode to view more details + streamlit run streamlit_openai_chatbot_webserver.py --logger.level=debug + ``` + + ![Chat with vLLM assistant in Streamlit](../../assets/deployment/streamlit-chat.png) diff --git a/docs/deployment/integrations/llamastack.md b/docs/deployment/integrations/llamastack.md index 28031f01f85e8..8eb7f8d81275d 100644 --- a/docs/deployment/integrations/llamastack.md +++ b/docs/deployment/integrations/llamastack.md @@ -1,6 +1,6 @@ # Llama Stack -vLLM is also available via [Llama Stack](https://github.com/meta-llama/llama-stack) . +vLLM is also available via [Llama Stack](https://github.com/llamastack/llama-stack). To install Llama Stack, run @@ -8,9 +8,9 @@ To install Llama Stack, run pip install llama-stack -q ``` -## Inference using OpenAI Compatible API +## Inference using OpenAI-Compatible API -Then start Llama Stack server pointing to your vLLM server with the following configuration: +Then start the Llama Stack server and configure it to point to your vLLM server with the following settings: ```yaml inference: @@ -20,15 +20,15 @@ inference: url: http://127.0.0.1:8000 ``` -Please refer to [this guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/remote-vllm.html) for more details on this remote vLLM provider. +Please refer to [this guide](https://llama-stack.readthedocs.io/en/latest/providers/inference/remote_vllm.html) for more details on this remote vLLM provider. -## Inference via Embedded vLLM +## Inference using Embedded vLLM -An [inline vLLM provider](https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/inline/inference/vllm) +An [inline provider](https://github.com/llamastack/llama-stack/tree/main/llama_stack/providers/inline/inference) is also available. 
This is a sample configuration using that method:
```yaml
-inference
+inference:
   - provider_type: vllm
     config:
       model: Llama3.1-8B-Instruct
diff --git a/docs/design/fused_moe_modular_kernel.md b/docs/design/fused_moe_modular_kernel.md
index b03483d1c9b21..cb2037b575e53 100644
--- a/docs/design/fused_moe_modular_kernel.md
+++ b/docs/design/fused_moe_modular_kernel.md
@@ -54,8 +54,8 @@ The `FusedMoEModularKernel` acts as a bridge between the `FusedMoEPermuteExperts
### FusedMoEPrepareAndFinalize
-The `FusedMoEPrepareAndFinalize` abstract class exposes `prepare` and `finalize` functions.
-The `prepare` function is responsible for input activation Quantization and All2All Dispatch. The `finalize` function is responsible for invoking the All2All Combine. Additionally the `finalize` function may or may not do the TopK weight application and reduction (Please refer to the TopKWeightAndReduce section)
+The `FusedMoEPrepareAndFinalize` abstract class exposes `prepare`, `prepare_no_receive` and `finalize` functions.
+The `prepare` function is responsible for input activation Quantization and All2All Dispatch. If implemented, `prepare_no_receive` is like `prepare` except that it does not wait to receive results from other workers. Instead, it returns a "receiver" callback that must be invoked to wait for the final results. Not every `FusedMoEPrepareAndFinalize` class is required to support this method, but when it is available it can be used to interleave work with the initial all-to-all communication, e.g. interleaving shared experts with fused experts. The `finalize` function is responsible for invoking the All2All Combine. Additionally the `finalize` function may or may not do the TopK weight application and reduction (Please refer to the TopKWeightAndReduce section)
![](../assets/design/fused_moe_modular_kernel/prepare_and_finalize_blocks.png "FusedMoEPrepareAndFinalize Blocks")
@@ -146,6 +146,10 @@ This section describes the significance of the various functions exposed by the
`FusedMoEPrepareAndFinalize::prepare()`: The prepare method implements the Quantization and All2All Dispatch. Typically the Dispatch function from the relevant All2All Manager is invoked.
+`FusedMoEPrepareAndFinalize::has_prepare_no_receive()`: Indicates whether or not this subclass implements `prepare_no_receive`. Defaults to False.
+
+`FusedMoEPrepareAndFinalize::prepare_no_receive()`: The prepare_no_receive method implements the Quantization and All2All Dispatch. It does not wait for the result of the dispatch operation but instead returns a thunk that can be invoked to wait for the final results. Typically the Dispatch function from the relevant All2All Manager is invoked.
+
`FusedMoEPrepareAndFinalize::finalize()`: Maybe perform TopK Weight Application and Reduction and All2All Combine. Typically the Combine function from the relevant All2AllManager is invoked.
`FusedMoEPrepareAndFinalize::activation_format()`: Return `FusedMoEActivationFormat.BatchedExperts` if the output of the prepare method (i.e. the All2All dispatch) is Batched. Return `FusedMoEActivationFormat.Standard` otherwise.
diff --git a/docs/design/huggingface_integration.md b/docs/design/huggingface_integration.md
index 5a7582c86d49f..412ce658b92a2 100644
--- a/docs/design/huggingface_integration.md
+++ b/docs/design/huggingface_integration.md
@@ -1,31 +1,31 @@
# Integration with Hugging Face
-This document describes how vLLM integrates with HuggingFace libraries.
We will explain step by step what happens under the hood when we run `vllm serve`. +This document describes how vLLM integrates with Hugging Face libraries. We will explain step by step what happens under the hood when we run `vllm serve`. -Let's say we want to serve the popular QWen model by running `vllm serve Qwen/Qwen2-7B`. +Let's say we want to serve the popular Qwen model by running `vllm serve Qwen/Qwen2-7B`. 1. The `model` argument is `Qwen/Qwen2-7B`. vLLM determines whether this model exists by checking for the corresponding config file `config.json`. See this [code snippet](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L162-L182) for the implementation. Within this process: - If the `model` argument corresponds to an existing local path, vLLM will load the config file directly from this path. - - If the `model` argument is a HuggingFace model ID consisting of a username and model name, vLLM will first try to use the config file from the HuggingFace local cache, using the `model` argument as the model name and the `--revision` argument as the revision. See [their website](https://huggingface.co/docs/huggingface_hub/en/package_reference/environment_variables#hfhome) for more information on how the HuggingFace cache works. - - If the `model` argument is a HuggingFace model ID but it is not found in the cache, vLLM will download the config file from the HuggingFace model hub. Refer to [this function](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L91) for the implementation. The input arguments include the `model` argument as the model name, the `--revision` argument as the revision, and the environment variable `HF_TOKEN` as the token to access the model hub. In our case, vLLM will download the [config.json](https://huggingface.co/Qwen/Qwen2-7B/blob/main/config.json) file. + - If the `model` argument is a Hugging Face model ID consisting of a username and model name, vLLM will first try to use the config file from the Hugging Face local cache, using the `model` argument as the model name and the `--revision` argument as the revision. See [their website](https://huggingface.co/docs/huggingface_hub/en/package_reference/environment_variables#hfhome) for more information on how the Hugging Face cache works. + - If the `model` argument is a Hugging Face model ID but it is not found in the cache, vLLM will download the config file from the Hugging Face model hub. Refer to [this function](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L91) for the implementation. The input arguments include the `model` argument as the model name, the `--revision` argument as the revision, and the environment variable `HF_TOKEN` as the token to access the model hub. In our case, vLLM will download the [config.json](https://huggingface.co/Qwen/Qwen2-7B/blob/main/config.json) file. 2. After confirming the existence of the model, vLLM loads its config file and converts it into a dictionary. See this [code snippet](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L185-L186) for the implementation. 3. 
Next, vLLM [inspects](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L189) the `model_type` field in the config dictionary to [generate](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L190-L216) the config object to use. There are some `model_type` values that vLLM directly supports; see [here](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L48) for the list. If the `model_type` is not in the list, vLLM will use [AutoConfig.from_pretrained](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoConfig.from_pretrained) to load the config class, with `model`, `--revision`, and `--trust_remote_code` as the arguments. Please note that: - - HuggingFace also has its own logic to determine the config class to use. It will again use the `model_type` field to search for the class name in the transformers library; see [here](https://github.com/huggingface/transformers/tree/main/src/transformers/models) for the list of supported models. If the `model_type` is not found, HuggingFace will use the `auto_map` field from the config JSON file to determine the class name. Specifically, it is the `AutoConfig` field under `auto_map`. See [DeepSeek](https://huggingface.co/deepseek-ai/DeepSeek-V2.5/blob/main/config.json) for an example. - - The `AutoConfig` field under `auto_map` points to a module path in the model's repository. To create the config class, HuggingFace will import the module and use the `from_pretrained` method to load the config class. This can generally cause arbitrary code execution, so it is only executed when `--trust_remote_code` is enabled. + - Hugging Face also has its own logic to determine the config class to use. It will again use the `model_type` field to search for the class name in the transformers library; see [here](https://github.com/huggingface/transformers/tree/main/src/transformers/models) for the list of supported models. If the `model_type` is not found, Hugging Face will use the `auto_map` field from the config JSON file to determine the class name. Specifically, it is the `AutoConfig` field under `auto_map`. See [DeepSeek](https://huggingface.co/deepseek-ai/DeepSeek-V2.5/blob/main/config.json) for an example. + - The `AutoConfig` field under `auto_map` points to a module path in the model's repository. To create the config class, Hugging Face will import the module and use the `from_pretrained` method to load the config class. This can generally cause arbitrary code execution, so it is only executed when `--trust_remote_code` is enabled. 4. Subsequently, vLLM applies some historical patches to the config object. These are mostly related to RoPE configuration; see [here](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/config.py#L244) for the implementation. 5. Finally, vLLM can reach the model class we want to initialize. vLLM uses the `architectures` field in the config object to determine the model class to initialize, as it maintains the mapping from architecture name to model class in [its registry](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/model_executor/models/registry.py#L80). If the architecture name is not found in the registry, it means this model architecture is not supported by vLLM. 
For `Qwen/Qwen2-7B`, the `architectures` field is `["Qwen2ForCausalLM"]`, which corresponds to the `Qwen2ForCausalLM` class in [vLLM's code](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/model_executor/models/qwen2.py#L364). This class will initialize itself depending on various configs. -Beyond that, there are two more things vLLM depends on HuggingFace for. +Beyond that, there are two more things vLLM depends on Hugging Face for. -1. **Tokenizer**: vLLM uses the tokenizer from HuggingFace to tokenize the input text. The tokenizer is loaded using [AutoTokenizer.from_pretrained](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained) with the `model` argument as the model name and the `--revision` argument as the revision. It is also possible to use a tokenizer from another model by specifying the `--tokenizer` argument in the `vllm serve` command. Other relevant arguments are `--tokenizer-revision` and `--tokenizer-mode`. Please check HuggingFace's documentation for the meaning of these arguments. This part of the logic can be found in the [get_tokenizer](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/tokenizer.py#L87) function. After obtaining the tokenizer, notably, vLLM will cache some expensive attributes of the tokenizer in [get_cached_tokenizer](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/tokenizer.py#L24). +1. **Tokenizer**: vLLM uses the tokenizer from Hugging Face to tokenize the input text. The tokenizer is loaded using [AutoTokenizer.from_pretrained](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained) with the `model` argument as the model name and the `--revision` argument as the revision. It is also possible to use a tokenizer from another model by specifying the `--tokenizer` argument in the `vllm serve` command. Other relevant arguments are `--tokenizer-revision` and `--tokenizer-mode`. Please check Hugging Face's documentation for the meaning of these arguments. This part of the logic can be found in the [get_tokenizer](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/tokenizer.py#L87) function. After obtaining the tokenizer, notably, vLLM will cache some expensive attributes of the tokenizer in [get_cached_tokenizer](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/tokenizer.py#L24). -2. **Model weight**: vLLM downloads the model weight from the HuggingFace model hub using the `model` argument as the model name and the `--revision` argument as the revision. vLLM provides the argument `--load-format` to control what files to download from the model hub. By default, it will try to load the weights in the safetensors format and fall back to the PyTorch bin format if the safetensors format is not available. We can also pass `--load-format dummy` to skip downloading the weights. +2. **Model weight**: vLLM downloads the model weight from the Hugging Face model hub using the `model` argument as the model name and the `--revision` argument as the revision. vLLM provides the argument `--load-format` to control what files to download from the model hub. By default, it will try to load the weights in the safetensors format and fall back to the PyTorch bin format if the safetensors format is not available. 
We can also pass `--load-format dummy` to skip downloading the weights. - It is recommended to use the safetensors format, as it is efficient for loading in distributed inference and also safe from arbitrary code execution. See the [documentation](https://huggingface.co/docs/safetensors/en/index) for more information on the safetensors format. This part of the logic can be found [here](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/model_executor/model_loader/loader.py#L385). Please note that: -This completes the integration between vLLM and HuggingFace. +This completes the integration between vLLM and Hugging Face. -In summary, vLLM reads the config file `config.json`, tokenizer, and model weight from the HuggingFace model hub or a local directory. It uses the config class from either vLLM, HuggingFace transformers, or loads the config class from the model's repository. +In summary, vLLM reads the config file `config.json`, tokenizer, and model weight from the Hugging Face model hub or a local directory. It uses the config class from either vLLM, Hugging Face transformers, or loads the config class from the model's repository. diff --git a/docs/design/io_processor_plugins.md b/docs/design/io_processor_plugins.md index ee474b5a7b997..e70ee4a076e54 100644 --- a/docs/design/io_processor_plugins.md +++ b/docs/design/io_processor_plugins.md @@ -2,7 +2,7 @@ IO Processor plugins are a feature that allows pre and post processing of the model input and output for pooling models. The idea is that users are allowed to pass a custom input to vLLM that is converted into one or more model prompts and fed to the model `encode` method. One potential use-case of such plugins is that of using vLLM for generating multi-modal data. Say users feed an image to vLLM and get an image in output. -When performing an inference with IO Processor plugins, the prompt type is defined by the plugin and the same is valid for the final request output. vLLM does not perform any validation of input/output data, and it is up to the plugin to ensure the correct data is being fed to the model and returned to the user. As of now these plugins support only pooling models and can be triggerd via the `encode` method in `LLM` and `AsyncLLM`, or in online serving mode via the `/pooling` endpoint. +When performing an inference with IO Processor plugins, the prompt type is defined by the plugin and the same is valid for the final request output. vLLM does not perform any validation of input/output data, and it is up to the plugin to ensure the correct data is being fed to the model and returned to the user. As of now these plugins support only pooling models and can be triggered via the `encode` method in `LLM` and `AsyncLLM`, or in online serving mode via the `/pooling` endpoint. ## Writing an IO Processor Plugin diff --git a/docs/design/logits_processors.md b/docs/design/logits_processors.md new file mode 100644 index 0000000000000..20d78ca3aae2c --- /dev/null +++ b/docs/design/logits_processors.md @@ -0,0 +1,559 @@ +# Logits Processors + +!!! important + Some logits processors design changes are still in progress and the API may + change in the near future. We hope to stabilize this part of the API soon + +This document describes how the vLLM engine interacts with logits processors, and the programming model which vLLM supports for implementing logits processors. 
+
+## Logits Processors Background
+
+A logits processor adjusts the next-token probability distribution, usually with the intention of steering the model towards a desired type of behavior.
+
+In vLLM, logits processors operate at batch granularity. During a given engine step, the logits processor consumes a `(num_requests) x (vocab_size)` tensor of raw logits output by the model. For all requests which enable the logits processor, the logits processor applies a transformation to the corresponding row of the logits tensor, while leaving other rows unmodified. The transformed logits tensor is then passed to softmax.
+
+## Logits Processors in the vLLM engine
+
+The vLLM engine's persistent batch data structure maintains a list of loaded logits processors.
+
+In order to operate on the entire batch at once, each logits processor may maintain metadata about the requests in the batch (i.e. each request's logits-processor-specific configuration settings). Therefore, logits processors are stateful.
+
+In each engine step, the vLLM engine will (1) update each logits processor's internal state and (2) apply logits processors to the model output logits.
+
+### Updating Logits Processor Internal State
+
+At the beginning of each engine step, the persistent batch may add, discard and/or reorder requests in response to the scheduler output. After the persistent batch has reorganized, the vLLM engine invokes each logits processor's `update_state()` method. This is necessary to ensure that logits processors' internal states are reorganized to match the new persistent batch state at the beginning of the engine step.
+
+The pseudocode below shows the process by which the vLLM persistent batch notifies each logits processor of changes in batch state:
+
+??? code "Model Runner Updates Logits Processor States"
+
+    ``` python
+    # gpu_model_runner.py
+
+    class GPUModelRunner(...):
+
+        ...
+
+        def execute_model(self, scheduler_output, ...):
+            self._update_states(scheduler_output)
+
+            ...
+
+        def _update_states(...):
+
+            ...
+
+            # ...update persistent batch to reflect new/finished requests & reordering
+            # of requests within batch...
+
+            ...
+
+            self.input_batch.refresh_metadata()
+
+
+    # gpu_input_batch.py
+
+    class InputBatch:
+
+        ...
+
+        def refresh_metadata(self):
+
+            ...
+
+            # Update each logits processor's state to reflect persistent batch state
+            batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
+            for logit_proc in self.logitsprocs.all:
+                logit_proc.update_state(batch_update)
+
+            ...
+
+
+    # vllm/v1/sample/logits_processor/interface.py
+
+    @dataclass(frozen=True)
+    class BatchUpdate:
+        # Batch state-change data structure which is passed to logits processors'
+        # update_state() methods
+
+        batch_size: int
+
+        removed: Sequence[RemovedRequest]
+        added: Sequence[AddedRequest]
+        moved: Sequence[MovedRequest]
+
+    ```
+
+### Applying Logits Processors to the Model Output Logits
+
+After updating persistent batch state, the vLLM model runner performs model inference to obtain logits. Then, the model runner invokes the sampler against the logits. In turn, part of the sampler's operation is to invoke the logits processors' `apply()` methods against the model output logits, yielding transformed logits (the `apply()` methods may modify the logits in-place or out-of-place, although in-place is more memory-efficient). This process is shown in the pseudocode below.
+
+Note that the sampler will access the logits processors via `SamplingMetadata.logitsprocs`.
When the vLLM engine constructs `SamplingMetadata` (not shown in the code below), the reference to the list of logits processors is passed from the persistent batch data structure to `SamplingMetadata`.
+
+??? code "Apply logits processors to model output logits"
+
+    ``` python
+    # gpu_model_runner.py
+
+    class GPUModelRunner(...):
+
+        ...
+
+        def execute_model(self, scheduler_output, ...):
+            # (discussed in previous section)
+            self._update_states(scheduler_output)
+
+            ...
+
+            # ...run model inference to obtain logits...
+
+            ...
+
+            # Invoke sampler, which applies logits processors
+            sampler_output = self.sampler(logits=logits,
+                                          sampling_metadata=sampling_metadata)
+
+            ...
+
+
+    # sampler.py
+
+    class Sampler(nn.Module):
+
+        ...
+
+        def forward(self, logits, sampling_metadata):
+
+            ...
+
+            # Apply non-argmax-invariant logits processors to model output logits
+            for processor in (sampling_metadata.logitsprocs.non_argmax_invariant):
+                logits = processor.apply(logits)
+
+            sampled = self.sample(logits, sampling_metadata)
+
+            ...
+
+            # ...return sampler output data structure...
+
+
+        def sample(self, logits, sampling_metadata):
+
+            ...
+
+            # ...exit early if all requests are greedy-sampling...
+
+            ...
+
+            # Apply argmax-invariant logits processors
+            for processor in sampling_metadata.logitsprocs.argmax_invariant:
+                logits = processor.apply(logits)
+
+            ...
+
+            # ...perform sampling and return sampling result...
+    ```
+
+At sampling time, the sampler checks whether all requests in the persistent batch employ greedy sampling. If that is the case, the sampler saves compute by skipping "argmax-invariant" logits processors. Here, "argmax" is shorthand for the token ID with the highest logit value in a given row of the logits tensor (i.e. the token which the model weighted the highest for a given request).
+
+* An **argmax-invariant logits processor** is a logits processor (such as Min-P) which does not modify the argmax. For example, a logits processor which masks out the lowest-probability tokens will not change which token ID has the max logit. Greedy sampling always picks the highest-logit-value token ID, and so conceptually an argmax-invariant logits processor can be skipped for greedy sampling requests.
+
+* A **non-argmax-invariant logits processor** is a logits processor which may modify the argmax. For example, a logits processor which masks all tokens except for EOS after a certain number of steps in order to force decoding to terminate might end up masking the max-logit-value token and therefore change the argmax. Conceptually, these logits processors cannot be skipped for greedy sampling requests.
+
+The vLLM logits processor abstraction requires the engine to apply logits processors at batch granularity; therefore in practice the argmax-invariant logits processors can only be skipped when the entire batch uses greedy sampling.
+
+## Logits Processor Programming Model
+
+The previous sections alluded to the interfaces which vLLM logits processors must support. This section introduces in full the programming model for implementing logits processors that are compatible with the vLLM engine, including the `LogitsProcessor` base class and its interface methods as well as the `BatchUpdate` data structure for representing persistent batch state changes, both of which are shown in the code below:
+
+???
code "`LogitsProcessor` base class and `BatchUpdate` data structure" + + ``` python + from abc import ABC, abstractmethod + from collections.abc import Sequence + from dataclasses import dataclass + from enum import Enum, auto + from typing import TYPE_CHECKING, Optional + + import torch + + from vllm import SamplingParams + + if TYPE_CHECKING: + from vllm.config import VllmConfig + + + class MoveDirectionality(Enum): + # One-way i1->i2 req move within batch + UNIDIRECTIONAL = auto() + # Two-way i1<->i2 req swap within batch + SWAP = auto() + + + # (index, params, prompt_tok_ids, output_tok_ids) tuples for new + # requests added to the batch. + AddedRequest = tuple[int, SamplingParams, list[int], list[int]] + + # (index 1, index 2, directionality) tuples representing + # one-way moves or two-way swaps of requests in batch + MovedRequest = tuple[int, int, MoveDirectionality] + + # Batch indices of any removed requests. + RemovedRequest = int + + + @dataclass(frozen=True) + class BatchUpdate: + """Persistent batch state change info for logitsprocs""" + batch_size: int # Current num reqs in batch + + # Metadata for requests added to, removed from, and moved + # within the persistent batch. + # + # Key assumption: the `output_tok_ids` list (which is an element of each + # tuple in `added`) is a reference to the request's running output tokens + # list; via this reference, the logits processors always see the latest + # list of generated output tokens + removed: Sequence[RemovedRequest] + moved: Sequence[MovedRequest] + added: Sequence[AddedRequest] + + + class LogitsProcessor(ABC): + + @abstractmethod + def __init__(self, vllm_config: "VllmConfig", device: torch.device, + is_pin_memory: bool) -> None: + raise NotImplementedError + + @abstractmethod + def apply(self, logits: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + @abstractmethod + def is_argmax_invariant(self) -> bool: + """True if logits processor has no impact on the + argmax computation in greedy sampling. + NOTE: may or may not have the same value for all + instances of a given LogitsProcessor subclass, + depending on subclass implementation. + """ + raise NotImplementedError + + @abstractmethod + def update_state( + self, + batch_update: Optional["BatchUpdate"], + ) -> None: + """Called when there are new output tokens, prior + to each forward pass. + + Args: + batch_update is non-None iff there have been + changes to the batch makeup. 
+ """ + raise NotImplementedError + + ``` + +A vLLM logits processor must subclass `LogitsProcessor` and define (at minimum) the following methods: + +* `__init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool)` + * `vllm_config`: engine configuration data structure + * `device`: hardware accelerator device info + * `is_pin_memory`: flag indicating whether pin memory is available to support logits processor implementation + +* `apply(self, logits: torch.Tensor) -> torch.Tensor`: + * Consume a `(num_requests) x (vocab_size)` logits tensor (`logits`) + * Apply logits processor transformation at batch granularity + * Return a transformed `(num_requests) x (vocab_size)` logits tensor + * You can modify the input logits processors in-place or out-of-place; in-place is more memory-efficient + +* `is_argmax_invariant(self) -> bool`: + * Return `True` if the logits processor is argmax invariant (never changes what is the highest-logit-value token ID for a given request), `False` if the logits processor may modify argmax + * `is_argmax_invariant()` is evaluated once at startup; if `True`, vLLM will skip applying this logits processor in a given step when all requests use greedy sampling + +* `update_state(self, batch_update: Optional["BatchUpdate"]) -> None`: + * Consume a `BatchUpdate` data structure representing persistent batch state changes at the beginning of the current engine step + * Use the `BatchUpdate` members to update logits processor internal state + * **Note:** batch update data structure may be `None`, signaling no change to the batch constituents. In this case, the LogitsProcessor might still want to update its state based on the updated `output_token_ids` lists that it could have retained when they were added. + +### `BatchUpdate` data structure + +The `BatchUpdate` abstraction models the persistent batch as a list of requests, supporting the following operations to change batch state (note that the order in which the operations are mentioned below reflects the order in which they should be processed in `update_state()`): + +* **Remove:** remove (without replacement) request at index `i` + + * A Remove is represented in `Batchupdate.removed` by an `int` (representing `i`) + + * Effect of remove-at-index on batch: + + ``` text + Batch: [A,B,C] + Remove @ i: 1 + + => + + New Batch: [A,x,C] # Discard B and leave an empty slot + ``` + +* **Add:** add (or replace existing request with) a new request at index `i`. If a request is replaced, its associated state should be discarded. + + * An Add is represented in `Batchupdate.added` as a tuple of + + ``` text + (index, new request SamplingParams, prompt token ids, output token ids) + ``` + + * `prompt token ids` and `output token ids` are references to the request's prompt token ids and output token ids lists, respectively. Note that the output token ids list grows with each engine step, and this growth is visible to the logits processor because output token ids are passed by reference. **This is important for LogitsProcessors that take into account the tokens generated so far**. + + * The implementation of the particular logits processor subclass determines whether or how the fields in the added request tuple are digested into an internal representation. 
For example, a logits processor that does not utilize prompt or output token ids may only need to utilize `index` and `SamplingParams` and discard the other tuple fields
+
+    * If index `i` currently holds a request, a replacement occurs:
+
+        ``` text
+        Batch: [A,B,C]
+        New request to be added @ i: D @ 1
+
+        =>
+
+        New Batch: [A,D,C] # Add D, discard B
+        ```
+
+    * If index `i` does not currently hold a request (because `i` is out of bounds of the current batch size):
+
+        ``` text
+        Batch: [A,B,C]
+        New request to be added @ i: D @ 3
+
+        =>
+
+        New Batch: [A,B,C,D] # Add D, extending batch
+        ```
+
+* **Move:** move request at index `s` to index `d` OR swap requests at indices `s` and `d`
+
+    * A Move is represented in `BatchUpdate.moved` as a tuple of
+
+        ``` text
+        (s, d, UNIDIRECTIONAL or SWAP)
+        ```
+
+    * If the Move specifies `UNIDIRECTIONAL`:
+
+        * The request at index `s` is moved to index `d`; index `s` becomes an empty slot
+
+            ``` text
+            Batch: [A,x,C,D]
+            Unidirectionally Move s -> d: 3 -> 1
+
+            =>
+
+            New Batch: [A,D,C,x] # Move D to 1, leaving empty slot at 3
+            ```
+
+        * If another request already resided at index `d`, it is replaced and discarded
+
+            ``` text
+            Batch: [A,B,C,D]
+            Unidirectionally Move s -> d: 3 -> 1
+
+            =>
+
+            New Batch: [A,D,C,x] # Move D to 1, discarding B and leaving empty slot at 3
+            ```
+
+    * If the Move specifies `SWAP`, the requests at `s` and `d` exchange indices
+
+        ``` text
+        Batch: [A,B,C,D]
+        Swap Move s <-> d: 3 <-> 1
+
+        =>
+
+        New Batch: [A,D,C,B] # Swap B and D
+        ```
+
+Additionally, the `BatchUpdate` data structure includes a representation (`batch_size`) of the size of the persistent batch at the beginning of the engine step.
+
+### How the vLLM engine builds the `BatchUpdate` data structure
+
+Logits processor `update_state()` implementations should assume the following model for how the model runner updates persistent batch state (expressed here in terms of the `BatchUpdate` abstraction):
+
+1. Identify indices of requests which finished in the current engine step
+
+2. Identify new requests introduced in the current step
+
+3. Use Add operations to replace as many finished requests with new requests, in order of increasing index of the replaced request starting with the lowest index
+
+4. Based on the relative number of new and finished requests:
+
+    1. If the numbers of new and finished requests are the same, proceed to next step
+
+    2. *If there are more new requests than finished requests:* apply Add operations to extend the batch with the remaining new requests which did not replace finished requests. Assign consecutive indices to these new requests, starting with `current_max_batch_index + 1`
+
+    3. *If there are fewer new requests than finished requests:*
+
+        * Apply Remove operations to finished requests which were not replaced with new requests. These removed request indices will necessarily be greater than the greatest index of the finished requests which were replaced in the previous step. The Removes may leave the batch in a non-contiguous state
+
+        * **"Condense" the batch to be contiguous:** starting with the lowest-index empty slot (which was caused by a Remove), apply a Unidirectional Move from the current highest non-empty slot in the batch to fill the empty slot.
Proceed with additional Unidirectional Move operations in order of increasing empty slot destination index and decreasing non-empty slot source index until the batch is contiguous + + * **Shrink the batch:** a side-effect of condensing the batch is that empty slots resulting from Remove operations are grouped in a contiguous block at the end of the batch array. Thus, after condensing, update `BatchUpdate.batch_size` to reflect the number of non-empty slots + +5. Reorder the batch for improved efficiency. Depending on the attention backend implementation and the current characteristics of the batch, zero or more Swap Move operations may be applied to reorder the batch + +Notes: + +* A logits processor `update_state()` method must process batch update operations in the following order: removes, adds, moves + +* The index argument for Add operations refers to the index *at the time the Add occurred*, i.e. before any Move operations + * Example: if a request is Added at index 5 and then swapped with index 3, the Add operation in `BatchUpdate.added` will be associated with index 5, not 3 + * In other words, Move operations can be assumed to be applied after Adds and Removes + +* Move operations can be assumed to be applied in the order in which they appear in `BatchUpdate.moved` + +* If there are no new/finished requests and there is no batch reordering, then the batch update for the logits processors will be `None` + +#### Example: Batch Update with Fewer New Requests Than Finished Requests + +The following example models an engine step where 1 new request is introduced and 2 finished requests are eliminated; additionally, the attention backend performs a swap to optimize the batch ordering. + +``` text +Batch state (beginning of engine step): [A,B,C,D] +Batch size: 4 + +New requests: E + +Finished requests: A, C + +Processing steps (using BatchUpdate abstraction): + +1. Add E at index 0 + +[E,B,C,D] # Discard A +Batch size: 4 + +2. Remove at index 2 + +[E,B,x,D] # Discard C, empty slot at index 2 +Batch size: 4 + +3. Condense batch with a Unidirectional Move 3 -> 2 operation and shrink batch + +[E,B,D] x # Empty slot is now outside batch +Batch size: 3 + +4. Attention backend optimization: reorder batch with Swap 0 <-> 1 + +[B,E,D] +Batch size: 3 + +``` + +The resulting `BatchUpdate` data structure will look like: + +``` text +BatchUpdate instance +* added: [(0,E's SamplingParams,E's prompt tokens ref,E's output tokens ref)] +* removed: [2] # request C was removed without replacement +* moved: [(3,2,UNIDIRECTIONAL),(0,1,SWAP)] +``` + +#### Example: Batch Update with More New Requests Than Finished Requests + +The following example models an engine step where 2 new requests are introduced and 1 finished request is eliminated; additionally, the attention backend performs a swap to optimize the batch ordering. + +``` text +Batch state (beginning of engine step): [A,B,C,D] +Batch size: 4 + +New requests: E,F + +Finished requests: C + +Processing steps (using BatchUpdate abstraction): + +1. Add E at index 2 + +[A,B,E,D] # Discard C +Batch size: 4 + +2. Add F at index 4 (current max batch index + 1) + +[A,B,E,D,F] # Extend batch by 1 +Batch size: 5 + +4. Attention backend optimization: reorder batch with Swap 0 <-> 1 + +[B,A,E,D,F] +Batch size: 5 + +``` + +Note that batch condensation is skipped because there are no empty slots left behind by Remove operations, which is why there is no condense step in this example.
+ +The resulting `BatchUpdate` data structure will look like: + +``` text +BatchUpdate instance +* added: [(2,E's SamplingParams,E's prompt tokens ref,E's output tokens ref),(4,F's SamplingParams,F's prompt tokens ref,F's output tokens ref)] +* removed: [] # no requests were removed without replacement +* moved: [(0,1,SWAP)] +``` + +## How to Introduce a New Logits Processor to vLLM + +### Best Practices for Writing Built-In Logits Processors + +* Write efficient `apply()` and `update_state()` implementations in light of the fact that logits processors operate at batch granularity + * For example, you may be able to use efficient vectorized operations to implement `apply()` or update internal state vectors in `update_state()` + * However, if you think that a logits processor may be used infrequently, it may be appropriate to use a "sparse" representation of request state, e.g. the class can represent request configuration using a dictionary which only stores metadata about requests that enable the logits processor + +* It is up to the logits processor author to determine: + + 1. **The per-request attributes which configure the logits processor's behavior for that request.** For example, if you are writing a new built-in logits processor for vLLM, you may or may not need to add additional fields to `SamplingParams` and the vLLM REST API + + 2. **The conditions under which the logits processor is or is not enabled on a per-request basis.** Unless your intention is for the built-in logits processor to act on all requests all the time, you should write your logits processor in such a way that it is possible to disable the logits processor for a given request, e.g. by defaulting an argument to `None` or by passing in a specific do-nothing argument value such as `0.0`. Try to save compute and memory for requests which disable the logits processor + + 3. **The conditions under which the logits processor is short-circuited at the batch level.** Even if you have defined a way to disable the built-in logits processor at the request level, it may be difficult to translate this into compute savings, e.g. if your `update_state()` and `apply()` implementations use efficient vectorized implementations that operate on the whole persistent batch in a single operation. For example, you cannot skip an entire vectorized operation in `apply()` just because one request disabled the logits processor. To save compute in the edge case where no running requests utilize the built-in logits processor, we recommend designing `apply()` to return the unmodified input tensor if all requests have the logits processor disabled. Similarly, consider whether steps can be skipped in `update_state()` if no requests enable the logits processor + + * Additionally, an easy way to save compute in `update_state()` is to exit early when the `batch_update` argument is `None` + +* Ensure that the logits processor `update_state` method discards information about finished requests (i.e. requests which are replaced by an Add or which are subject to a Remove) + +* `is_argmax_invariant()` can be hard-coded to `True` or `False` if the logits processor has consistent behavior. However, the argmax invariance may also be determined programmatically (e.g. if your logits processor is user-customizable in some way that impacts whether the logits processor is argmax invariant). For this reason, `is_argmax_invariant()` is not a class method + +### Built-In Logits Processors + +Built-in logits processors are always loaded when the vLLM engine starts.
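To make the programming model described above concrete, the following is a minimal sketch of a logits processor that adds a per-request bias to a single token id. It is illustrative only, not one of the built-ins: the import paths, the `MoveDirectionality` name, and the use of `SamplingParams.extra_args` to carry the per-request settings are assumptions for illustration and may differ in your vLLM version.

```python
from typing import Optional

import torch

# Assumed import locations; check your vLLM version for the exact paths.
from vllm.config import VllmConfig
from vllm.v1.sample.logits_processor import (BatchUpdate, LogitsProcessor,
                                             MoveDirectionality)


class BoostTokenLogitsProcessor(LogitsProcessor):
    """Illustrative only: add a bias to one token id per request."""

    def __init__(self, vllm_config: VllmConfig, device: torch.device,
                 is_pin_memory: bool):
        self.device = device
        # Sparse per-request state: batch index -> (token id, bias).
        self.biases: dict[int, tuple[int, float]] = {}

    def is_argmax_invariant(self) -> bool:
        # Adding a bias can change which token id has the highest logit.
        return False

    def update_state(self, batch_update: Optional[BatchUpdate]) -> None:
        if batch_update is None:
            return  # No change to the batch constituents this step.

        # Process operations in the documented order: removes, adds, moves.
        for index in batch_update.removed:
            self.biases.pop(index, None)

        for index, params, _prompt_tok_ids, _output_tok_ids in batch_update.added:
            # Hypothetical per-request knobs carried in SamplingParams.extra_args.
            extra = params.extra_args or {}
            if "boost_token_id" in extra:
                self.biases[index] = (extra["boost_token_id"],
                                      extra.get("boost_bias", 1.0))
            else:
                # An Add may replace a request; discard any stale state.
                self.biases.pop(index, None)

        for src, dst, direction in batch_update.moved:
            src_state = self.biases.pop(src, None)
            dst_state = self.biases.pop(dst, None)
            if src_state is not None:
                self.biases[dst] = src_state
            if direction == MoveDirectionality.SWAP and dst_state is not None:
                self.biases[src] = dst_state

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if not self.biases:
            return logits  # Short-circuit when no request enables the processor.
        for index, (token_id, bias) in self.biases.items():
            logits[index, token_id] += bias
        return logits
```

Note how `update_state()` processes removes, adds, and moves in the documented order, exits early when `batch_update` is `None`, and how `apply()` returns the input tensor unchanged when no request enables the processor, in line with the best practices above.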
See the existing vLLM built-in logits processors in `vllm/v1/sample/logits_processor/builtin.py` for examples of how to write a new built-in vLLM logits processor. It makes sense to write a PR to introduce a new logits processor as a built-in if it is likely to be useful to a wide audience. vLLM currently employs the following built-in logits processors based on the programming model described above: + +* Min-P + +* Logit bias + +* Min-tokens + +Review these logits processor implementations for guidance on writing built-in logits processors. + +Additionally, the following logits-processor-like functionalities are hard-coded into the sampler and do not yet utilize the programming model described above. Most of them will be refactored to use the aforementioned logits processor programming model. + +* Allowed token IDs + +* Bad words + +* Repetition penalty + +* Frequency penalty + +* Presence penalty + +* Temperature + +* Top-K + +* Top-P + +### Custom Logits Processors + +vLLM can be augmented with [user-provided custom logits processors](../features/custom_logitsprocs.md). diff --git a/docs/design/multiprocessing.md b/docs/design/multiprocessing.md index 247072d1cb275..6e92b20d267b4 100644 --- a/docs/design/multiprocessing.md +++ b/docs/design/multiprocessing.md @@ -8,7 +8,7 @@ page for information on known issues and how to solve them. ## Introduction !!! important - The source code references are to the state of the code at the time of writing in December, 2024. + The source code references are to the state of the code at the time of writing in December 2024. The use of Python multiprocessing in vLLM is complicated by: diff --git a/docs/examples/README.md b/docs/examples/README.md index 3cf93027f4209..94f5efc92f386 100644 --- a/docs/examples/README.md +++ b/docs/examples/README.md @@ -2,6 +2,6 @@ vLLM's examples are split into three categories: -- If you are using vLLM from within Python code, see [Offline Inference](./offline_inference) -- If you are using vLLM from an HTTP application or client, see [Online Serving](./online_serving) -- For examples of using some of vLLM's advanced features (e.g. LMCache or Tensorizer) which are not specific to either of the above use cases, see [Others](./others) +- If you are using vLLM from within Python code, see the *Offline Inference* section. +- If you are using vLLM from an HTTP application or client, see the *Online Serving* section. +- For examples of using some of vLLM's advanced features (e.g. LMCache or Tensorizer) which are not specific to either of the above use cases, see the *Others* section. diff --git a/docs/features/compatibility_matrix.md b/docs/features/README.md similarity index 95% rename from docs/features/compatibility_matrix.md rename to docs/features/README.md index 5b08b3810776c..d8e26ec02aecc 100644 --- a/docs/features/compatibility_matrix.md +++ b/docs/features/README.md @@ -1,4 +1,6 @@ -# Compatibility Matrix +# Features + +## Compatibility Matrix The tables below show mutually exclusive features and the support on some hardware. @@ -12,7 +14,7 @@ The symbols used have the following meanings: !!! note Check the ❌ or 🟠 with links to see tracking issue for unsupported feature/hardware combination.
-## Feature x Feature +### Feature x Feature -| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | Intel Gaudi | x86 CPU | AWS Neuron | Google TPU | -|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-------------|-----------|--------------|--------------| -| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ | -| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ | -| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | -| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ❌ | -| BitBLAS | ✅︎ | ✅ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| BitBLAS (GPTQ) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | -| INC (W8A8) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ | ❌ | ❌ | ❌ | +| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | Intel Gaudi | x86 CPU | Google TPU | +|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-------------|-----------|--------------| +| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | +| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | +| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | +| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | +| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | +| BitBLAS | ✅︎ | ✅ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | +| BitBLAS (GPTQ) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | +| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | +| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | +| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | +| INC (W8A8) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ | ❌ | ❌ | - Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0. - ✅︎ indicates that the quantization method is supported on the specified hardware. 
diff --git a/docs/features/reasoning_outputs.md b/docs/features/reasoning_outputs.md index d9a785eb73fbe..85681669dfb22 100644 --- a/docs/features/reasoning_outputs.md +++ b/docs/features/reasoning_outputs.md @@ -10,11 +10,12 @@ vLLM currently supports the following reasoning models: | Model Series | Parser Name | Structured Output Support | Tool Calling | |--------------|-------------|------------------|-------------| -| [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` | ❌ | -| [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` | ✅ | +| [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `json`, `regex` | ❌ | +| [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `json`, `regex` | ✅ | | [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ | -| [Qwen3 series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `qwen3` | `guided_json`, `guided_regex` | ✅ | -| [Hunyuan A13B series](https://huggingface.co/collections/tencent/hunyuan-a13b-685ec38e5b46321e3ea7c4be) | `hunyuan_a13b` | `guided_json`, `guided_regex` | ✅ | +| [Qwen3 series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `qwen3` | `json`, `regex` | ✅ | +| [Hunyuan A13B series](https://huggingface.co/collections/tencent/hunyuan-a13b-685ec38e5b46321e3ea7c4be) | `hunyuan_a13b` | `json`, `regex` | ✅ | +| [GLM-4.5 series](https://huggingface.co/collections/zai-org/glm-45-687c621d34bda8c9e4bf503b) | `glm45` | `json`, `regex` | ✅ | !!! note IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`. diff --git a/docs/features/structured_outputs.md b/docs/features/structured_outputs.md index 0d6294a5fdd79..1f955c6e30d6c 100644 --- a/docs/features/structured_outputs.md +++ b/docs/features/structured_outputs.md @@ -12,23 +12,23 @@ You can generate structured outputs using the OpenAI's [Completions](https://pla The following parameters are supported, which must be added as extra parameters: -- `guided_choice`: the output will be exactly one of the choices. -- `guided_regex`: the output will follow the regex pattern. -- `guided_json`: the output will follow the JSON schema. -- `guided_grammar`: the output will follow the context free grammar. +- `choice`: the output will be exactly one of the choices. +- `regex`: the output will follow the regex pattern. +- `json`: the output will follow the JSON schema. +- `grammar`: the output will follow the context free grammar. - `structural_tag`: Follow a JSON schema within a set of specified tags within the generated text. You can see the complete list of supported parameters on the [OpenAI-Compatible Server](../serving/openai_compatible_server.md) page. Structured outputs are supported by default in the OpenAI-Compatible Server. You may choose to specify the backend to use by setting the -`--guided-decoding-backend` flag to `vllm serve`. The default backend is `auto`, +`--structured-outputs-config.backend` flag to `vllm serve`. The default backend is `auto`, which will try to choose an appropriate backend based on the details of the request. You may also choose a specific backend, along with some options. 
A full set of options is available in the `vllm serve --help` text. -Now let´s see an example for each of the cases, starting with the `guided_choice`, as it´s the easiest one: +Now let´s see an example for each of the cases, starting with the `choice`, as it´s the easiest one: ??? code @@ -45,12 +45,12 @@ Now let´s see an example for each of the cases, starting with the `guided_choic messages=[ {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"} ], - extra_body={"guided_choice": ["positive", "negative"]}, + extra_body={"structured_outputs": {"choice": ["positive", "negative"]}}, ) print(completion.choices[0].message.content) ``` -The next example shows how to use the `guided_regex`. The idea is to generate an email address, given a simple regex template: +The next example shows how to use the `regex`. The idea is to generate an email address, given a simple regex template: ??? code @@ -63,18 +63,18 @@ The next example shows how to use the `guided_regex`. The idea is to generate an "content": "Generate an example email address for Alan Turing, who works in Enigma. End in .com and new line. Example result: alan.turing@enigma.com\n", } ], - extra_body={"guided_regex": r"\w+@\w+\.com\n", "stop": ["\n"]}, + extra_body={"structured_outputs": {"regex": r"\w+@\w+\.com\n"}, "stop": ["\n"]}, ) print(completion.choices[0].message.content) ``` One of the most relevant features in structured text generation is the option to generate a valid JSON with pre-defined fields and formats. -For this we can use the `guided_json` parameter in two different ways: +For this we can use the `json` parameter in two different ways: - Using directly a [JSON Schema](https://json-schema.org/) - Defining a [Pydantic model](https://docs.pydantic.dev/latest/) and then extracting the JSON Schema from it (which is normally an easier option). -The next example shows how to use the `guided_json` parameter with a Pydantic model: +The next example shows how to use the `response_format` parameter with a Pydantic model: ??? code @@ -119,7 +119,7 @@ The next example shows how to use the `guided_json` parameter with a Pydantic mo JSON schema and how the fields should be populated. This can improve the results notably in most cases. -Finally we have the `guided_grammar` option, which is probably the most +Finally we have the `grammar` option, which is probably the most difficult to use, but it´s really powerful. It allows us to define complete languages like SQL queries. It works by using a context free EBNF grammar. As an example, we can use to define a specific format of simplified SQL queries: @@ -149,7 +149,7 @@ As an example, we can use to define a specific format of simplified SQL queries: "content": "Generate an SQL query to show the 'username' and 'email' from the 'users' table.", } ], - extra_body={"guided_grammar": simplified_sql_grammar}, + extra_body={"structured_outputs": {"grammar": simplified_sql_grammar}}, ) print(completion.choices[0].message.content) ``` @@ -292,8 +292,8 @@ An example of using `structural_tag` can be found here: , which contains several features in addition to what's -available on vLLM V0. Please utilize the AWS Fork for the following features: - -- Llama-3.2 multi-modal support -- Multi-node distributed inference - -Refer to [vLLM User Guide for NxD Inference](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/vllm-user-guide.html) - for more details and usage examples. 
- -To install the AWS Neuron fork, run the following: - -```bash -git clone -b neuron-2.23-vllm-v0.7.2 https://github.com/aws-neuron/upstreaming-to-vllm.git -cd upstreaming-to-vllm -pip install -r requirements/neuron.txt -VLLM_TARGET_DEVICE="neuron" pip install -e . -``` - -Note that the AWS Neuron fork is only intended to support Neuron hardware; compatibility with other hardwares is not tested. - -## Set up using Docker - -### Pre-built images - -Currently, there are no pre-built Neuron images. - -### Build image from source - -See [deployment-docker-build-image-from-source][deployment-docker-build-image-from-source] for instructions on building the Docker image. - -Make sure to use in place of the default Dockerfile. - -## Extra information - -[](){ #feature-support-through-nxd-inference-backend } - -### Feature support through NxD Inference backend - -The current vLLM and Neuron integration relies on either the `neuronx-distributed-inference` (preferred) or `transformers-neuronx` backend -to perform most of the heavy lifting which includes PyTorch model initialization, compilation, and runtime execution. Therefore, most -[features supported on Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/feature-guide.html) are also available via the vLLM integration. - -To configure NxD Inference features through the vLLM entrypoint, use the `override_neuron_config` setting. Provide the configs you want to override -as a dictionary (or JSON object when starting vLLM from the CLI). For example, to disable auto bucketing, include - -```python -override_neuron_config={ - "enable_bucketing":False, -} -``` - -or when launching vLLM from the CLI, pass - -```bash ---override-neuron-config "{\"enable_bucketing\":false}" -``` - -Alternatively, users can directly call the NxDI library to trace and compile your model, then load the pre-compiled artifacts -(via `NEURON_COMPILED_ARTIFACTS` environment variable) in vLLM to run inference workloads. - -### Known limitations - -- EAGLE speculative decoding: NxD Inference requires the EAGLE draft checkpoint to include the LM head weights from the target model. Refer to this - [guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/feature-guide.html#eagle-checkpoint-compatibility) - for how to convert pretrained EAGLE model checkpoints to be compatible for NxDI. -- Quantization: the native quantization flow in vLLM is not well supported on NxD Inference. It is recommended to follow this - [Neuron quantization guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/custom-quantization.html) - to quantize and compile your model using NxD Inference, and then load the compiled artifacts into vLLM. -- Multi-LoRA serving: NxD Inference only supports loading of LoRA adapters at server startup. Dynamic loading of LoRA adapters at - runtime is not currently supported. Refer to [multi-lora example](https://github.com/aws-neuron/upstreaming-to-vllm/blob/neuron-2.23-vllm-v0.7.2/examples/offline_inference/neuron_multi_lora.py) -- Multi-modal support: multi-modal support is only available through the AWS Neuron fork. This feature has not been upstreamed - to vLLM main because NxD Inference currently relies on certain adaptations to the core vLLM logic to support this feature. -- Multi-node support: distributed inference across multiple Trainium/Inferentia instances is only supported on the AWS Neuron fork. 
Refer - to this [multi-node example](https://github.com/aws-neuron/upstreaming-to-vllm/tree/neuron-2.23-vllm-v0.7.2/examples/neuron/multi_node) - to run. Note that tensor parallelism (distributed inference across NeuronCores) is available in vLLM main. -- Known edge case bug in speculative decoding: An edge case failure may occur in speculative decoding when sequence length approaches - max model length (e.g. when requesting max tokens up to the max model length and ignoring eos). In this scenario, vLLM may attempt - to allocate an additional block to ensure there is enough memory for number of lookahead slots, but since we do not have good support - for paged attention, there isn't another Neuron block for vLLM to allocate. A workaround fix (to terminate 1 iteration early) is - implemented in the AWS Neuron fork but is not upstreamed to vLLM main as it modifies core vLLM logic. - -### Environment variables - -- `NEURON_COMPILED_ARTIFACTS`: set this environment variable to point to your pre-compiled model artifacts directory to avoid - compilation time upon server initialization. If this variable is not set, the Neuron module will perform compilation and save the - artifacts under `neuron-compiled-artifacts/{unique_hash}/` subdirectory in the model path. If this environment variable is set, - but the directory does not exist, or the contents are invalid, Neuron will also fall back to a new compilation and store the artifacts - under this specified path. -- `NEURON_CONTEXT_LENGTH_BUCKETS`: Bucket sizes for context encoding. (Only applicable to `transformers-neuronx` backend). -- `NEURON_TOKEN_GEN_BUCKETS`: Bucket sizes for token generation. (Only applicable to `transformers-neuronx` backend). diff --git a/docs/getting_started/installation/cpu.md b/docs/getting_started/installation/cpu.md index 7f0ecb2bc0b74..f8b4f75308df7 100644 --- a/docs/getting_started/installation/cpu.md +++ b/docs/getting_started/installation/cpu.md @@ -180,7 +180,7 @@ Inference batch size is an important parameter for the performance. Larger batch - Offline Inference: `256 * world_size` - Online Serving: `128 * world_size` -vLLM CPU supports data parallel (DP), tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more details of tuning DP, TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommend to use DP, TP and PP together if there are enough CPU sockets and memory nodes. +vLLM CPU supports data parallel (DP), tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more details of tuning DP, TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommended to use DP, TP and PP together if there are enough CPU sockets and memory nodes. ### Which quantization configs does vLLM CPU support? @@ -194,3 +194,35 @@ vLLM CPU supports data parallel (DP), tensor parallel (TP) and pipeline parallel - Both of them require `amx` CPU flag. - `VLLM_CPU_MOE_PREPACK` can provides better performance for MoE models - `VLLM_CPU_SGL_KERNEL` can provides better performance for MoE models and small-batch scenarios. + +### Why do I see `get_mempolicy: Operation not permitted` when running in Docker? + +In some container environments (like Docker), NUMA-related syscalls used by vLLM (e.g., `get_mempolicy`, `migrate_pages`) are blocked/denied in the runtime's default seccomp/capabilities settings. 
This may lead to warnings like `get_mempolicy: Operation not permitted`. Functionality is not affected, but NUMA memory binding/migration optimizations may not take effect and performance can be suboptimal. + +To enable these optimizations inside Docker with the least privilege, you can follow below tips: + +```bash +docker run ... --cap-add SYS_NICE --security-opt seccomp=unconfined ... + +# 1) `--cap-add SYS_NICE` is to address `get_mempolicy` EPERM issue. + +# 2) `--security-opt seccomp=unconfined` is to enable `migrate_pages` for `numa_migrate_pages()`. +# Actually, `seccomp=unconfined` bypasses the seccomp for container, +# if it's unacceptable, you can customize your own seccomp profile, +# based on docker/runtime default.json and add `migrate_pages` to `SCMP_ACT_ALLOW` list. + +# reference : https://docs.docker.com/engine/security/seccomp/ +``` + +Alternatively, running with `--privileged=true` also works but is broader and not generally recommended. + +In K8S, the following configuration can be added to workload yaml to achieve the same effect as above: + +```yaml +securityContext: + seccompProfile: + type: Unconfined + capabilities: + add: + - SYS_NICE +``` diff --git a/docs/getting_started/installation/cpu/apple.inc.md b/docs/getting_started/installation/cpu/apple.inc.md index 124a41adf1ae2..7e2ed55008a57 100644 --- a/docs/getting_started/installation/cpu/apple.inc.md +++ b/docs/getting_started/installation/cpu/apple.inc.md @@ -52,6 +52,24 @@ uv pip install -e . 1 error generated. ``` + --- + + If the build fails with C++11/C++17 compatibility errors like the following, the issue is that the build system is defaulting to an older C++ standard: + + ```text + [...] error: 'constexpr' is not a type + [...] error: expected ';' before 'constexpr' + [...] error: 'constexpr' does not name a type + ``` + + **Solution**: Your compiler might be using an older C++ standard. Edit `cmake/cpu_extension.cmake` and add `set(CMAKE_CXX_STANDARD 17)` before `set(CMAKE_CXX_STANDARD_REQUIRED ON)`. + + To check your compiler's C++ standard support: + ```bash + clang++ -std=c++17 -pedantic -dM -E -x c++ /dev/null | grep __cplusplus + ``` + On Apple Clang 16 you should see: `#define __cplusplus 201703L` + # --8<-- [end:build-wheel-from-source] # --8<-- [start:pre-built-images] diff --git a/docs/getting_started/installation/cpu/arm.inc.md b/docs/getting_started/installation/cpu/arm.inc.md index cac578eefb1d7..e45baa0aa4938 100644 --- a/docs/getting_started/installation/cpu/arm.inc.md +++ b/docs/getting_started/installation/cpu/arm.inc.md @@ -48,6 +48,10 @@ docker run --rm \ --dtype=bfloat16 \ other vLLM OpenAI server arguments ``` + +!!! tip + An alternative of `--privileged=true` is `--cap-add SYS_NICE --security-opt seccomp=unconfined`. + # --8<-- [end:build-image-from-source] # --8<-- [start:extra-information] # --8<-- [end:extra-information] diff --git a/docs/getting_started/installation/cpu/s390x.inc.md b/docs/getting_started/installation/cpu/s390x.inc.md index c1917267ce91b..f9c4ccb942fac 100644 --- a/docs/getting_started/installation/cpu/s390x.inc.md +++ b/docs/getting_started/installation/cpu/s390x.inc.md @@ -89,6 +89,9 @@ docker run --rm \ other vLLM OpenAI server arguments ``` +!!! tip + An alternative of `--privileged true` is `--cap-add SYS_NICE --security-opt seccomp=unconfined`. 
+ # --8<-- [end:build-image-from-source] # --8<-- [start:extra-information] # --8<-- [end:extra-information] diff --git a/docs/getting_started/installation/cpu/x86.inc.md b/docs/getting_started/installation/cpu/x86.inc.md index f7af259ace628..836da33f65317 100644 --- a/docs/getting_started/installation/cpu/x86.inc.md +++ b/docs/getting_started/installation/cpu/x86.inc.md @@ -44,6 +44,7 @@ docker build -f docker/Dockerfile.cpu \ # Launching OpenAI server docker run --rm \ --security-opt seccomp=unconfined \ + --cap-add SYS_NICE \ --shm-size=4g \ -p 8000:8000 \ -e VLLM_CPU_KVCACHE_SPACE= \ diff --git a/docs/getting_started/installation/gpu/cuda.inc.md b/docs/getting_started/installation/gpu/cuda.inc.md index 275232e12e08c..9e64c6f2540af 100644 --- a/docs/getting_started/installation/gpu/cuda.inc.md +++ b/docs/getting_started/installation/gpu/cuda.inc.md @@ -168,6 +168,7 @@ There are scenarios where the PyTorch dependency cannot be easily installed with To build vLLM using an existing PyTorch installation: ```bash +# install PyTorch first, either from PyPI or from source git clone https://github.com/vllm-project/vllm.git cd vllm python use_existing_torch.py @@ -175,6 +176,17 @@ uv pip install -r requirements/build.txt uv pip install --no-build-isolation -e . ``` +Alternatively: if you are exclusively using `uv` to create and manage virtual environments, it has [a unique mechanism](https://docs.astral.sh/uv/concepts/projects/config/#disabling-build-isolation) +for disabling build isolation for specific packages. vLLM can leverage this mechanism to specify `torch` as the package to disable build isolation for: + +```bash +# install PyTorch first, either from PyPI or from source +git clone https://github.com/vllm-project/vllm.git +cd vllm +# pip install -e . does not work directly, only uv can do this +uv pip install -e . +``` + ##### Use the local cutlass for compilation Currently, before starting the build process, vLLM fetches cutlass code from GitHub. However, there may be scenarios where you want to use a local version of cutlass instead. diff --git a/docs/getting_started/installation/gpu/rocm.inc.md b/docs/getting_started/installation/gpu/rocm.inc.md index 80e99d3034d39..37c6647929b51 100644 --- a/docs/getting_started/installation/gpu/rocm.inc.md +++ b/docs/getting_started/installation/gpu/rocm.inc.md @@ -1,6 +1,6 @@ # --8<-- [start:installation] -vLLM supports AMD GPUs with ROCm 6.3. +vLLM supports AMD GPUs with ROCm 6.3 or above. !!! tip [Docker](#set-up-using-docker) is the recommended way to use vLLM on ROCm. @@ -11,8 +11,9 @@ vLLM supports AMD GPUs with ROCm 6.3. # --8<-- [end:installation] # --8<-- [start:requirements] -- GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100/1101), Radeon RX 9000 series (gfx1200/1201) -- ROCm 6.3 +- GPU: MI200s (gfx90a), MI300 (gfx942), MI350 (gfx950), Radeon RX 7900 series (gfx1100/1101), Radeon RX 9000 series (gfx1200/1201) +- ROCm 6.3 or above + - MI350 requires ROCm 7.0 or above # --8<-- [end:requirements] # --8<-- [start:set-up-using-python] @@ -32,35 +33,35 @@ Currently, there are no pre-built ROCm wheels. - [ROCm](https://rocm.docs.amd.com/en/latest/deploy/linux/index.html) - [PyTorch](https://pytorch.org/) - For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.3_ubuntu24.04_py3.12_pytorch_release_2.4.0`, `rocm/pytorch-nightly`. If you are using docker image, you can skip to Step 3. 
+ For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.4.3_ubuntu24.04_py3.12_pytorch_release_2.6.0`, `rocm/pytorch-nightly`. If you are using docker image, you can skip to Step 3. Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guide in PyTorch [Getting Started](https://pytorch.org/get-started/locally/). Example: ```bash # Install PyTorch pip uninstall torch -y - pip install --no-cache-dir --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.3 + pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/rocm6.4 ``` -1. Install [Triton flash attention for ROCm](https://github.com/ROCm/triton) +1. Install [Triton for ROCm](https://github.com/triton-lang/triton) - Install ROCm's Triton flash attention (the default triton-mlir branch) following the instructions from [ROCm/triton](https://github.com/ROCm/triton/blob/triton-mlir/README.md) + Install ROCm's Triton (the default triton-mlir branch) following the instructions from [ROCm/triton](https://github.com/ROCm/triton/blob/triton-mlir/README.md) ```bash python3 -m pip install ninja cmake wheel pybind11 pip uninstall -y triton - git clone https://github.com/OpenAI/triton.git + git clone https://github.com/triton-lang/triton.git cd triton git checkout e5be006 - cd python - pip3 install . + if [ ! -f setup.py ]; then cd python; fi + python3 setup.py install cd ../.. ``` !!! note If you see HTTP issue related to downloading packages during building triton, please try again as the HTTP error is intermittent. -2. Optionally, if you choose to use CK flash attention, you can install [flash attention for ROCm](https://github.com/ROCm/flash-attention) +2. Optionally, if you choose to use CK flash attention, you can install [flash attention for ROCm](https://github.com/Dao-AILab/flash-attention) Install ROCm's flash attention (v2.7.2) following the instructions from [ROCm/flash-attention](https://github.com/ROCm/flash-attention#amd-rocm-support) Alternatively, wheels intended for vLLM use can be accessed under the releases. @@ -68,9 +69,9 @@ Currently, there are no pre-built ROCm wheels. For example, for ROCm 6.3, suppose your gfx arch is `gfx90a`. To get your gfx architecture, run `rocminfo |grep gfx`. ```bash - git clone https://github.com/ROCm/flash-attention.git + git clone https://github.com/Dao-AILab/flash-attention.git cd flash-attention - git checkout b7d29fb + git checkout 1a7f4dfa git submodule update --init GPU_ARCHS="gfx90a" python3 setup.py install cd .. @@ -119,7 +120,7 @@ Currently, there are no pre-built ROCm wheels. This may take 5-10 minutes. Currently, `pip install .` does not work for ROCm installation. !!! tip - - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers. + - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm-up step before collecting perf numbers. - Triton flash attention does not currently support sliding window attention. If using half precision, please use CK flash-attention for sliding window support. - To use CK flash-attention or PyTorch naive attention, please use this flag `export VLLM_USE_TRITON_FLASH_ATTN=0` to turn off triton flash attention. - The ROCm version of PyTorch, ideally, should match the ROCm driver version. 
@@ -194,16 +195,6 @@ To build vllm on ROCm 6.3 for MI200 and MI300 series, you can use the default: DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile.rocm -t vllm-rocm . ``` -To build vllm on ROCm 6.3 for Radeon RX7900 series (gfx1100), you should pick the alternative base image: - -```bash -DOCKER_BUILDKIT=1 docker build \ - --build-arg BASE_IMAGE="rocm/vllm-dev:navi_base" \ - -f docker/Dockerfile.rocm \ - -t vllm-rocm \ - . -``` - To run the above docker image `vllm-rocm`, use the below command: ??? console "Command" @@ -218,8 +209,7 @@ To run the above docker image `vllm-rocm`, use the below command: --device /dev/kfd \ --device /dev/dri \ -v :/app/model \ - vllm-rocm \ - bash + vllm-rocm ``` Where the `` is the location where the model is stored, for example, the weights for llama2 or llama3 models. diff --git a/docs/getting_started/installation/gpu/xpu.inc.md b/docs/getting_started/installation/gpu/xpu.inc.md index b77c4e00cf0c4..ed1dc0418cf7e 100644 --- a/docs/getting_started/installation/gpu/xpu.inc.md +++ b/docs/getting_started/installation/gpu/xpu.inc.md @@ -3,13 +3,16 @@ vLLM initially supports basic model inference and serving on Intel GPU platform. !!! warning - There are no pre-built wheels or images for this device, so you must build vLLM from source. + There are no pre-built wheels for this device, so you need build vLLM from source. Or you can use pre-built images which are based on vLLM released versions. # --8<-- [end:installation] # --8<-- [start:requirements] - Supported Hardware: Intel Data Center GPU, Intel ARC GPU -- OneAPI requirements: oneAPI 2025.0 +- OneAPI requirements: oneAPI 2025.1 +- Python: 3.12 +!!! warning + The provided IPEX whl is Python3.12 specific so this version is a MUST. # --8<-- [end:requirements] # --8<-- [start:set-up-using-python] @@ -24,7 +27,7 @@ Currently, there are no pre-built XPU wheels. # --8<-- [end:pre-built-wheels] # --8<-- [start:build-wheel-from-source] -- First, install required [driver](https://dgpu-docs.intel.com/driver/installation.html#installing-gpu-drivers) and [Intel OneAPI](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) 2025.0 or later. +- First, install required [driver](https://dgpu-docs.intel.com/driver/installation.html#installing-gpu-drivers) and [Intel OneAPI](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) 2025.1 or later. - Second, install Python packages for vLLM XPU backend building: ```bash @@ -40,14 +43,10 @@ pip install -v -r requirements/xpu.txt VLLM_TARGET_DEVICE=xpu python setup.py install ``` -!!! note - - FP16 is the default data type in the current XPU backend. The BF16 data - type is supported on Intel Data Center GPU, not supported on Intel Arc GPU yet. - # --8<-- [end:build-wheel-from-source] # --8<-- [start:pre-built-images] -Currently, there are no pre-built XPU images. +Currently, we release prebuilt XPU images at docker [hub](https://hub.docker.com/r/intel/vllm/tags) based on vLLM released version. For more information, please refer release [note](https://github.com/intel/ai-containers/blob/main/vllm). # --8<-- [end:pre-built-images] # --8<-- [start:build-image-from-source] @@ -65,14 +64,14 @@ docker run -it \ # --8<-- [end:build-image-from-source] # --8<-- [start:supported-features] -XPU platform supports **tensor parallel** inference/serving and also supports **pipeline parallel** as a beta feature for online serving. We require Ray as the distributed runtime backend. 
For example, a reference execution like following: +XPU platform supports **tensor parallel** inference/serving and also supports **pipeline parallel** as a beta feature for online serving. For **pipeline parallel**, we support it on single node with mp as the backend. For example, a reference execution like following: ```bash python -m vllm.entrypoints.openai.api_server \ --model=facebook/opt-13b \ --dtype=bfloat16 \ --max_model_len=1024 \ - --distributed-executor-backend=ray \ + --distributed-executor-backend=mp \ --pipeline-parallel-size=2 \ -tp=8 ``` diff --git a/docs/getting_started/installation/python_env_setup.inc.md b/docs/getting_started/installation/python_env_setup.inc.md index 423bf9b00d07f..06794f8d3120e 100644 --- a/docs/getting_started/installation/python_env_setup.inc.md +++ b/docs/getting_started/installation/python_env_setup.inc.md @@ -1,4 +1,4 @@ -It's recommended to use [uv](https://docs.astral.sh/uv/), a very fast Python environment manager, to create and manage Python environments. Please follow the [documentation](https://docs.astral.sh/uv/#getting-started) to install `uv`. After installing `uv`, you can create a new Python environment and install vLLM using the following commands: +It's recommended to use [uv](https://docs.astral.sh/uv/), a very fast Python environment manager, to create and manage Python environments. Please follow the [documentation](https://docs.astral.sh/uv/#getting-started) to install `uv`. After installing `uv`, you can create a new Python environment using the following commands: ```bash uv venv --python 3.12 --seed diff --git a/docs/mkdocs/hooks/generate_argparse.py b/docs/mkdocs/hooks/generate_argparse.py index 051a2d904406d..91454ec272b81 100644 --- a/docs/mkdocs/hooks/generate_argparse.py +++ b/docs/mkdocs/hooks/generate_argparse.py @@ -165,6 +165,7 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): # Generate documentation for each parser for stem, parser in parsers.items(): doc_path = ARGPARSE_DOC_DIR / f"{stem}.md" - with open(doc_path, "w") as f: + # Specify encoding for building on Windows + with open(doc_path, "w", encoding="utf-8") as f: f.write(parser.format_help()) logger.info("Argparse generated: %s", doc_path.relative_to(ROOT_DIR)) diff --git a/docs/mkdocs/hooks/generate_examples.py b/docs/mkdocs/hooks/generate_examples.py index 881df791698e2..0cbaebb598a34 100644 --- a/docs/mkdocs/hooks/generate_examples.py +++ b/docs/mkdocs/hooks/generate_examples.py @@ -106,13 +106,41 @@ class Example: def determine_title(self) -> str: if not self.is_code: - with open(self.main_file) as f: + # Specify encoding for building on Windows + with open(self.main_file, encoding="utf-8") as f: first_line = f.readline().strip() match = re.match(r'^#\s+(?P.+)$', first_line) if match: return match.group('title') return fix_case(self.path.stem.replace("_", " ").title()) + def fix_relative_links(self, content: str) -> str: + """ + Fix relative links in markdown content by converting them to gh-file + format. 
+ + Args: + content (str): The markdown content to process + + Returns: + str: Content with relative links converted to gh-file format + """ + # Regex to match markdown links [text](relative_path) + # This matches links that don't start with http, https, ftp, or # + link_pattern = r'\[([^\]]*)\]\((?!(?:https?|ftp)://|#)([^)]+)\)' + + def replace_link(match): + link_text = match.group(1) + relative_path = match.group(2) + + # Make relative to repo root + gh_file = (self.main_file.parent / relative_path).resolve() + gh_file = gh_file.relative_to(ROOT_DIR) + + return f'[{link_text}](gh-file:{gh_file})' + + return re.sub(link_pattern, replace_link, content) + def generate(self) -> str: content = f"# {self.title}\n\n" content += f"Source <gh-file:{self.path.relative_to(ROOT_DIR)}>.\n\n" @@ -120,14 +148,16 @@ class Example: # Use long code fence to avoid issues with # included files containing code fences too code_fence = "``````" - # Skip the title from md snippets as it's been included above - start_line = 2 + if self.is_code: - content += f"{code_fence}{self.main_file.suffix[1:]}\n" - start_line = 1 - content += f'--8<-- "{self.main_file}:{start_line}"\n' - if self.is_code: - content += f"{code_fence}\n" + content += (f"{code_fence}{self.main_file.suffix[1:]}\n" + f'--8<-- "{self.main_file}"\n' + f"{code_fence}\n") + else: + with open(self.main_file) as f: + # Skip the title from md snippets as it's been included above + main_content = f.readlines()[1:] + content += self.fix_relative_links("".join(main_content)) content += "\n" if not self.other_files: @@ -174,6 +204,7 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): doc_path = EXAMPLE_DOC_DIR / example.category / example_name if not doc_path.parent.exists(): doc_path.parent.mkdir(parents=True) - with open(doc_path, "w+") as f: + # Specify encoding for building on Windows + with open(doc_path, "w+", encoding="utf-8") as f: f.write(example.generate()) logger.debug("Example generated: %s", doc_path.relative_to(ROOT_DIR)) diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md index d2fbb1870dde6..0521a22c07029 100644 --- a/docs/models/pooling_models.md +++ b/docs/models/pooling_models.md @@ -228,7 +228,7 @@ outputs = llm.embed(["Follow the white rabbit."], print(outputs[0].outputs) ``` -A code example can be found here: <gh-file:examples/offline_inference/embed_matryoshka_fy.py> +A code example can be found here: <gh-file:examples/offline_inference/pooling/embed_matryoshka_fy.py> ### Online Inference @@ -258,4 +258,4 @@ Expected output: {"id":"embd-5c21fc9a5c9d4384a1b021daccaf9f64","object":"list","created":1745476417,"model":"jinaai/jina-embeddings-v3","data":[{"index":0,"object":"embedding","embedding":[-0.3828125,-0.1357421875,0.03759765625,0.125,0.21875,0.09521484375,-0.003662109375,0.1591796875,-0.130859375,-0.0869140625,-0.1982421875,0.1689453125,-0.220703125,0.1728515625,-0.2275390625,-0.0712890625,-0.162109375,-0.283203125,-0.055419921875,-0.0693359375,0.031982421875,-0.04052734375,-0.2734375,0.1826171875,-0.091796875,0.220703125,0.37890625,-0.0888671875,-0.12890625,-0.021484375,-0.0091552734375,0.23046875]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0,"prompt_tokens_details":null}} ``` -An OpenAI client example can be found here: <gh-file:examples/online_serving/openai_embedding_matryoshka_fy.py> +An OpenAI client example can be found here: <gh-file:examples/online_serving/pooling/openai_embedding_matryoshka_fy.py> diff --git a/docs/models/supported_models.md 
b/docs/models/supported_models.md index 4b4cebb6a31c2..b67ebcbe3c81a 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -322,15 +322,15 @@ th { | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|-------------------|----------------------|---------------------------|---------------------| +| `ApertusForCausalLM` | Apertus | `swiss-ai/Apertus-8B-2509`, `swiss-ai/Apertus-70B-Instruct-2509`, etc. | ✅︎ | ✅︎ | ✅︎ | | `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `ArceeForCausalLM` | Arcee (AFM) | `arcee-ai/AFM-4.5B-Base`, etc. | ✅︎ | ✅︎ | ✅︎ | | `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ | | `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `BailingMoeV2ForCausalLM` | Ling | `inclusionAI/Ling-mini-2.0`, etc. | ✅︎ | ✅︎ | ✅︎ | | `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ | | `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | ✅︎ | -| `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | | -| `MBartForConditionalGeneration` | mBART | `facebook/mbart-large-en-ro`, `facebook/mbart-large-50`, etc. | | | | | `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `zai-org/chatglm2-6b`, `zai-org/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ | | `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R, Command-A | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, `CohereLabs/c4ai-command-a-03-2025`, `CohereLabs/command-a-reasoning-08-2025`, etc. | ✅︎ | ✅︎ | ✅︎ | | `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | | ✅︎ | ✅︎ | @@ -365,8 +365,8 @@ th { | `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ | | `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ | | `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ | -| `HunYuanDenseV1ForCausalLM` | Hunyuan-7B-Instruct-0124 | `tencent/Hunyuan-7B-Instruct-0124` | ✅︎ | | ✅︎ | -| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | | ✅︎ | +| `HunYuanDenseV1ForCausalLM` | Hunyuan-7B-Instruct-0124 | `tencent/Hunyuan-7B-Instruct-0124` | ✅︎ | ✅︎ | ✅︎ | +| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ | | `HCXVisionForCausalLM` | HyperCLOVAX-SEED-Vision-Instruct-3B | `naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B` | | | ✅︎ | | `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ | @@ -382,11 +382,13 @@ th { | `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. 
| ✅︎ | ✅︎ | ✅︎ | | `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `MotifForCausalLM` | Motif-1-Tiny | `Motif-Technologies/Motif-2.6B`, `Motif-Technologies/Motif-2.6b-v1.1-LC`, etc. | ✅︎ | ✅︎ | | | `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ | | `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ | | `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ | | `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ | | `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `OLMo3ForCausalLM` | OLMo3 | TBA | ✅︎ | ✅︎ | ✅︎ | | `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ | | `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ | | `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ | @@ -395,12 +397,13 @@ th { | `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Phi4FlashForCausalLM` | Phi-4-mini-flash-reasoning | `microsoft/microsoft/Phi-4-mini-instruct`, etc. | | | | | `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ | -| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | ✅︎ | | +| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | ✅︎ | ✅︎ | | `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen3NextForCausalLM` | Qwen3NextMoE | `Qwen/Qwen3-Next-80B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `SeedOssForCausalLM` | SeedOss | `ByteDance-Seed/Seed-OSS-36B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ | | `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ | @@ -421,9 +424,6 @@ Some models are supported only via the [Transformers backend](#transformers). Th !!! note Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096. -!!! note - Some mBART models' config files do not have an `architecture` defined. Therefore, you need to use `--hf-overrides '{"architectures": ["MBartForConditionalGeneration"]}'` to explicitly specify the use of the `MBartForConditionalGeneration` architecture. 
- ### Pooling Models See [this page](./pooling_models.md) for more information on how to use pooling models. @@ -440,6 +440,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A |--------------|--------|-------------------|----------------------|---------------------------|---------------------| | `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | ✅︎ | | `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Gemma3TextModel`<sup>C</sup> | Gemma 3-based | `google/embeddinggemma-300m`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ | | `GteModel`<sup>C</sup> | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | | ✅︎ | | `GteNewModel`<sup>C</sup> | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | ✅︎ | @@ -525,7 +526,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A ``` !!! note - Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: <gh-file:examples/offline_inference/qwen3_reranker.py>. + Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: <gh-file:examples/offline_inference/pooling/qwen3_reranker.py>. ```bash vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}' @@ -553,6 +554,17 @@ If your model is not in the above list, we will try to automatically convert the For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly, e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`. +#### Token Classification + +These models primarily support the [`LLM.encode`](./pooling_models.md#llmencode) API. + +| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | +|--------------|--------|-------------------|-----------------------------|-----------------------------------------|---------------------| +| `BertForTokenClassification` | bert-based | `boltuix/NeuroBERT-NER` (see note), etc. | | | ✅︎ | + +!!! note + Named Entity Recognition (NER) usage, please refer to <gh-file:examples/offline_inference/pooling/ner.py>, <gh-file:examples/online_serving/pooling/ner.py>. + [](){ #supported-mm-models } ## List of Multimodal Language Models @@ -619,9 +631,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ | ✅︎ | | `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I<sup>+</sup> | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | ✅︎ | | `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ | -| `DonutForConditionalGeneration`<sup>^</sup> | Donut | T + I | `ByteDance/Dolphin`, `naver-clova-ix/donut-base-finetuned-docvqa`, etc. 
| | | | | `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | ✅︎ | -| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | | | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ | | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ | | `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ | @@ -643,11 +653,11 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `LlavaNextForConditionalGeneration` | LLaVA-NeXT | T + I<sup>E+</sup> | `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc. | | ✅︎ | ✅︎ | | `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | ✅︎ | | `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I<sup>+</sup> + V<sup>+</sup> | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ | ✅︎ | +| `MiDashengLMModel` | MiDashengLM | T + A<sup>+</sup> | `mispeech/midashenglm-7b` | | ✅︎ | ✅︎ | | `MiniCPMO` | MiniCPM-O | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>E+</sup> | `openbmb/MiniCPM-o-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, `openbmb/MiniCPM-V-4`, `openbmb/MiniCPM-V-4_5`, etc. | ✅︎ | | ✅︎ | | `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + I<sup>E+</sup> | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | ✅︎ | | `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MllamaForConditionalGeneration` | Llama 3.2 | T + I<sup>+</sup> | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | | | `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ | | `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ | | `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ | @@ -661,7 +671,9 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A<sup>+</sup> | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ | ✅︎ | | `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎ | +| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-3B`, `Qwen/Qwen2.5-Omni-7B` | ✅︎ | ✅︎ | ✅︎ | +| `Qwen3VLForConditionalGeneration` | Qwen3-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-4B-Instruct`, etc. 
| ✅︎ | ✅︎ | ✅︎ | +| `Qwen3VLMoeForConditionalGeneration` | Qwen3-VL-MOE | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-30B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `RForConditionalGeneration` | R-VL-4B | T + I<sup>E+</sup> | `YannQi/R-4B` | | ✅︎ | ✅︎ | | `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ | | `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ | @@ -761,8 +773,9 @@ Speech2Text models trained specifically for Automatic Speech Recognition. | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|-------------------|----------------------|---------------------------|---------------------| -| `WhisperForConditionalGeneration` | Whisper | `openai/whisper-small`, `openai/whisper-large-v3-turbo`, etc. | | | | -| `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | | ✅︎ | ✅︎ | +| `WhisperForConditionalGeneration` | Whisper | `openai/whisper-small`, `openai/whisper-large-v3-turbo`, etc. | | | ✅︎ | +| `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ | ### Pooling Models diff --git a/docs/serving/expert_parallel_deployment.md b/docs/serving/expert_parallel_deployment.md index 280b3322b11c3..7489fc2609831 100644 --- a/docs/serving/expert_parallel_deployment.md +++ b/docs/serving/expert_parallel_deployment.md @@ -10,7 +10,7 @@ Before using EP, you need to install the necessary dependencies. We are actively 1. **Install DeepEP and pplx-kernels**: Set up host environment following vLLM's guide for EP kernels [here](gh-file:tools/ep_kernels). 2. **Install DeepGEMM library**: Follow the [official instructions](https://github.com/deepseek-ai/DeepGEMM#installation). -3. **For disaggregated serving**: Install UCX and NIXL following the [script](gh-file:tools/install_nixl.sh). +3. **For disaggregated serving**: Install `gdrcopy` by running the [`install_gdrcopy.sh`](gh-file:tools/install_gdrcopy.sh) script (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/). ### Backend Selection Guide @@ -123,18 +123,46 @@ When enabled, vLLM collects load statistics with every forward pass and periodic ### EPLB Parameters +Configure EPLB with the `--eplb-config` argument, which accepts a JSON string. 
The available keys and their descriptions are:
+
 | Parameter | Description | Default |
 |-----------|-------------|---------|
-| `--eplb-window-size` | Number of engine steps to track for rebalancing decisions | - |
-| `--eplb-step-interval` | Frequency of rebalancing (every N engine steps) | - |
-| `--eplb-log-balancedness` | Log balancedness metrics (avg tokens per expert ÷ max tokens per expert) | `false` |
-| `--num-redundant-experts` | Additional global experts per EP rank beyond equal distribution | `0` |
+| `window_size` | Number of engine steps to track for rebalancing decisions | 1000 |
+| `step_interval` | Frequency of rebalancing (every N engine steps) | 3000 |
+| `log_balancedness` | Log balancedness metrics (avg tokens per expert ÷ max tokens per expert) | `false` |
+| `num_redundant_experts` | Additional global experts per EP rank beyond equal distribution | `0` |
+
+For example:
+
+```bash
+vllm serve Qwen/Qwen3-30B-A3B \
+    --enable-eplb \
+    --eplb-config '{"window_size":1000,"step_interval":3000,"num_redundant_experts":2,"log_balancedness":true}'
+```
+
+??? tip "Prefer individual arguments instead of JSON?"
+
+    ```bash
+    vllm serve Qwen/Qwen3-30B-A3B \
+        --enable-eplb \
+        --eplb-config.window_size 1000 \
+        --eplb-config.step_interval 3000 \
+        --eplb-config.num_redundant_experts 2 \
+        --eplb-config.log_balancedness true
+    ```

 ### Expert Distribution Formula

 - **Default**: Each EP rank has `NUM_TOTAL_EXPERTS ÷ NUM_EP_RANKS` experts
 - **With redundancy**: Each EP rank has `(NUM_TOTAL_EXPERTS + NUM_REDUNDANT_EXPERTS) ÷ NUM_EP_RANKS` experts

+### Memory Footprint Overhead
+
+EPLB uses redundant experts that need to fit in GPU memory. This means that EPLB may not be a good fit for memory-constrained environments or when KV cache space is at a premium.
+
+This overhead equals `NUM_MOE_LAYERS * BYTES_PER_EXPERT * (NUM_TOTAL_EXPERTS + NUM_REDUNDANT_EXPERTS) ÷ NUM_EP_RANKS`.
+For DeepSeekV3, this is approximately `2.4 GB` for one redundant expert per EP rank.
+
 ### Example Command

 Single node deployment with EPLB enabled:
@@ -146,12 +174,10 @@ VLLM_ALL2ALL_BACKEND=pplx VLLM_USE_DEEP_GEMM=1 vllm serve deepseek-ai/DeepSeek-V
     --data-parallel-size 8 \ # Data parallelism
     --enable-expert-parallel \ # Enable EP
     --enable-eplb \ # Enable load balancer
-    --eplb-log-balancedness \ # Log balancing metrics
-    --eplb-window-size 1000 \ # Track last 1000 engine steps
-    --eplb-step-interval 3000 # Rebalance every 3000 steps
+    --eplb-config '{"window_size":1000,"step_interval":3000,"num_redundant_experts":2,"log_balancedness":true}'
 ```

-For multi-node deployment, add these EPLB flags to each node's command. We recommend setting `--num-redundant-experts` to 32 in large scale use cases so the most popular experts are always available.
+For multi-node deployment, add these EPLB flags to each node's command. We recommend setting `--eplb-config '{"num_redundant_experts":32}'` in large-scale use cases so the most popular experts are always available.

 ## Disaggregated Serving (Prefill/Decode Split)

@@ -165,7 +191,7 @@ For production deployments requiring strict SLA guarantees for time-to-first-tok

 ### Setup Steps

-1. **Install KV Connector**: Install NIXL using the [installation script](gh-file:tools/install_nixl.sh)
+1. **Install gdrcopy/ucx/nixl**: For maximum performance, run the [install_gdrcopy.sh](gh-file:tools/install_gdrcopy.sh) script to install `gdrcopy` (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`).
You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/). If `gdrcopy` is not installed, things will still work with a plain `pip install nixl`, just with lower performance. `nixl` and `ucx` are installed as dependencies via pip. 2. **Configure Both Instances**: Add this flag to both prefill and decode instances `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}` diff --git a/docs/serving/openai_compatible_server.md b/docs/serving/openai_compatible_server.md index dfed15d4ace97..bac3f6c1fe90c 100644 --- a/docs/serving/openai_compatible_server.md +++ b/docs/serving/openai_compatible_server.md @@ -133,7 +133,7 @@ completion = client.chat.completions.create( {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"} ], extra_body={ - "guided_choice": ["positive", "negative"] + "structured_outputs": {"choice": ["positive", "negative"]} } ) ``` @@ -239,7 +239,7 @@ you can use the [official OpenAI Python client](https://github.com/openai/openai If the model has a [chat template][chat-template], you can replace `inputs` with a list of `messages` (same schema as [Chat API][chat-api]) which will be treated as a single prompt to the model. -Code example: <gh-file:examples/online_serving/openai_embedding_client.py> +Code example: <gh-file:examples/online_serving/pooling/openai_embedding_client.py> #### Multi-modal inputs @@ -313,14 +313,15 @@ and passing a list of `messages` in the request. Refer to the examples below for `MrLight/dse-qwen2-2b-mrl-v1` requires a placeholder image of the minimum image size for text query embeddings. See the full code example below for details. -Full example: <gh-file:examples/online_serving/openai_chat_embedding_client_for_multimodal.py> +Full example: <gh-file:examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py> #### Extra parameters -The following [pooling parameters][pooling-params] are supported. +The following [pooling parameters][vllm.PoolingParams] are supported. ```python ---8<-- "vllm/entrypoints/openai/protocol.py:embedding-pooling-params" +--8<-- "vllm/pooling_params.py:common-pooling-params" +--8<-- "vllm/pooling_params.py:embedding-pooling-params" ``` The following extra parameters are supported by default: @@ -374,7 +375,7 @@ The following extra parameters are supported: ```python --8<-- "vllm/entrypoints/openai/protocol.py:transcription-extra-params" ``` - + [](){ #translations-api } ### Translations API @@ -421,7 +422,7 @@ Our Pooling API encodes input prompts using a [pooling model](../models/pooling_ The input format is the same as [Embeddings API][embeddings-api], but the output data can contain an arbitrary nested list, not just a 1-D list of floats. -Code example: <gh-file:examples/online_serving/openai_pooling_client.py> +Code example: <gh-file:examples/online_serving/pooling/openai_pooling_client.py> [](){ #classification-api } @@ -431,7 +432,7 @@ Our Classification API directly supports Hugging Face sequence-classification mo We automatically wrap any other transformer via `as_seq_cls_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities. 
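To make the pooling behaviour described above concrete, here is a minimal PyTorch sketch of last-token pooling followed by a classification head and softmax. This is an illustration only, not vLLM's implementation; the hidden size, number of classes, and the plain `torch.nn.Linear` standing in for the `RowParallelLinear` head are made up for the example.

```python
import torch

# Dummy batch of decoder hidden states: (batch, seq_len, hidden_size)
hidden_states = torch.randn(2, 16, 64)
seq_lens = torch.tensor([16, 9])     # true (unpadded) length of each sequence
head = torch.nn.Linear(64, 3)        # stand-in for the RowParallelLinear head, 3 classes

# Pool on the last real token of each sequence, apply the head, then softmax
last_token = hidden_states[torch.arange(hidden_states.size(0)), seq_lens - 1]
probs = torch.softmax(head(last_token), dim=-1)
print(probs)  # each row is a per-class probability distribution that sums to 1
```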
-Code example: <gh-file:examples/online_serving/openai_classification_client.py> +Code example: <gh-file:examples/online_serving/pooling/openai_classification_client.py> #### Example Requests @@ -527,10 +528,11 @@ curl -v "http://127.0.0.1:8000/classify" \ #### Extra parameters -The following [pooling parameters][pooling-params] are supported. +The following [pooling parameters][vllm.PoolingParams] are supported. ```python ---8<-- "vllm/entrypoints/openai/protocol.py:classification-pooling-params" +--8<-- "vllm/pooling_params.py:common-pooling-params" +--8<-- "vllm/pooling_params.py:classification-pooling-params" ``` The following extra parameters are supported: @@ -733,10 +735,11 @@ Full example: <gh-file:examples/online_serving/openai_cross_encoder_score_for_mu #### Extra parameters -The following [pooling parameters][pooling-params] are supported. +The following [pooling parameters][vllm.PoolingParams] are supported. ```python ---8<-- "vllm/entrypoints/openai/protocol.py:score-pooling-params" +--8<-- "vllm/pooling_params.py:common-pooling-params" +--8<-- "vllm/pooling_params.py:classification-pooling-params" ``` The following extra parameters are supported: @@ -760,7 +763,7 @@ endpoints are compatible with both [Jina AI's re-rank API interface](https://jin [Cohere's re-rank API interface](https://docs.cohere.com/v2/reference/rerank) to ensure compatibility with popular open-source tools. -Code example: <gh-file:examples/online_serving/jinaai_rerank_client.py> +Code example: <gh-file:examples/online_serving/pooling/jinaai_rerank_client.py> #### Example Request @@ -815,10 +818,11 @@ Result documents will be sorted by relevance, and the `index` property can be us #### Extra parameters -The following [pooling parameters][pooling-params] are supported. +The following [pooling parameters][vllm.PoolingParams] are supported. ```python ---8<-- "vllm/entrypoints/openai/protocol.py:rerank-pooling-params" +--8<-- "vllm/pooling_params.py:common-pooling-params" +--8<-- "vllm/pooling_params.py:classification-pooling-params" ``` The following extra parameters are supported: diff --git a/docs/serving/parallelism_scaling.md b/docs/serving/parallelism_scaling.md index fa7fc1b290d50..cef1127fc5c15 100644 --- a/docs/serving/parallelism_scaling.md +++ b/docs/serving/parallelism_scaling.md @@ -66,7 +66,7 @@ Ray is a distributed computing framework for scaling Python programs. Multi-node vLLM uses Ray to manage the distributed execution of tasks across multiple nodes and control where execution happens. -Ray also offers high-level APIs for large-scale [offline batch inference](https://docs.ray.io/en/latest/data/working-with-llms.html) and [online serving](https://docs.ray.io/en/latest/serve/llm/serving-llms.html) that can leverage vLLM as the engine. These APIs add production-grade fault tolerance, scaling, and distributed observability to vLLM workloads. +Ray also offers high-level APIs for large-scale [offline batch inference](https://docs.ray.io/en/latest/data/working-with-llms.html) and [online serving](https://docs.ray.io/en/latest/serve/llm) that can leverage vLLM as the engine. These APIs add production-grade fault tolerance, scaling, and distributed observability to vLLM workloads. For details, see the [Ray documentation](https://docs.ray.io/en/latest/index.html). @@ -104,7 +104,7 @@ Note that `VLLM_HOST_IP` is unique for each worker. Keep the shells running thes From any node, enter a container and run `ray status` and `ray list nodes` to verify that Ray finds the expected number of nodes and GPUs. 
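If you prefer checking from Python rather than the CLI, the same information is exposed by Ray's public API. This is an optional sanity check and assumes `ray` is importable inside the container:

```python
import ray

# Attach to the already-running cluster instead of starting a new one
ray.init(address="auto")

alive = [node for node in ray.nodes() if node["Alive"]]
print(f"Alive Ray nodes: {len(alive)}")
print(f"GPUs visible to Ray: {ray.cluster_resources().get('GPU', 0)}")
```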
!!! tip - Alternatively, set up the Ray cluster using KubeRay. For more information, see [KubeRay vLLM documentation](https://docs.ray.io/en/latest/cluster/kubernetes/examples/vllm-rayservice.html). + Alternatively, set up the Ray cluster using KubeRay. For more information, see [KubeRay vLLM documentation](https://docs.ray.io/en/latest/cluster/kubernetes/examples/rayserve-llm-example.html). ### Running vLLM on a Ray cluster diff --git a/docs/usage/troubleshooting.md b/docs/usage/troubleshooting.md index b92c6cef4a3fa..6e700d1faaa9c 100644 --- a/docs/usage/troubleshooting.md +++ b/docs/usage/troubleshooting.md @@ -40,6 +40,34 @@ If other strategies don't solve the problem, it's likely that the vLLM instance - `export NCCL_DEBUG=TRACE` to turn on more logging for NCCL. - `export VLLM_TRACE_FUNCTION=1` to record all function calls for inspection in the log files to tell which function crashes or hangs. Do not use this flag unless absolutely needed for debugging, it will cause significant delays in startup time. +## Breakpoints + +Setting normal `pdb` breakpoints may not work in vLLM's codebase if they are executed in a subprocess. You will experience something like: + +``` text + File "/usr/local/uv/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/bdb.py", line 100, in trace_dispatch + return self.dispatch_line(frame) + ^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/usr/local/uv/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/bdb.py", line 125, in dispatch_line + if self.quitting: raise BdbQuit + ^^^^^^^^^^^^^ +bdb.BdbQuit +``` + +One solution is using [forked-pdb](https://github.com/Lightning-AI/forked-pdb). Install with `pip install fpdb` and set a breakpoint with something like: + +``` python +__import__('fpdb').ForkedPdb().set_trace() +``` + +Another option is to disable multiprocessing entirely, with the `VLLM_ENABLE_V1_MULTIPROCESSING` environment variable. +This keeps the scheduler in the same process, so you can use stock `pdb` breakpoints: + +``` python +import os +os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" +``` + ## Incorrect network setup The vLLM instance cannot get the correct IP address if you have a complicated network config. You can find a log such as `DEBUG 06-10 21:32:17 parallel_state.py:88] world_size=8 rank=0 local_rank=0 distributed_init_method=tcp://xxx.xxx.xxx.xxx:54641 backend=nccl` and the IP address should be the correct one. @@ -295,4 +323,5 @@ This indicates vLLM failed to initialize the NCCL communicator, possibly due to ## Known Issues - In `v0.5.2`, `v0.5.3`, and `v0.5.3.post1`, there is a bug caused by [zmq](https://github.com/zeromq/pyzmq/issues/2000) , which can occasionally cause vLLM to hang depending on the machine configuration. The solution is to upgrade to the latest version of `vllm` to include the [fix](gh-pr:6759). -- To circumvent a NCCL [bug](https://github.com/NVIDIA/nccl/issues/1234) , all vLLM processes will set an environment variable `NCCL_CUMEM_ENABLE=0` to disable NCCL's `cuMem` allocator. It does not affect performance but only gives memory benefits. When external processes want to set up a NCCL connection with vLLM's processes, they should also set this environment variable, otherwise, inconsistent environment setup will cause NCCL to hang or crash, as observed in the [RLHF integration](https://github.com/OpenRLHF/OpenRLHF/pull/604) and the [discussion](gh-issue:5723#issuecomment-2554389656) . 
+- To address a memory overhead issue in older NCCL versions (see [bug](https://github.com/NVIDIA/nccl/issues/1234)), vLLM versions `>= 0.4.3, <= 0.10.1.1` would set the environment variable `NCCL_CUMEM_ENABLE=0`. External processes connecting to vLLM also needed to set this variable to prevent hangs or crashes. Since the underlying NCCL bug was fixed in NCCL 2.22.3, this override was removed in newer vLLM versions to allow for NCCL performance optimizations. +- In some PCIe machines (e.g. machines without NVLink), if you see an error like `transport/shm.cc:590 NCCL WARN Cuda failure 217 'peer access is not supported between these two devices'`, it's likely caused by a driver bug. See [this issue](https://github.com/NVIDIA/nccl/issues/1838) for more details. In that case, you can try to set `NCCL_CUMEM_HOST_ENABLE=0` to disable the feature, or upgrade your driver to the latest version. diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index f71805436a6ae..340aaf54bb720 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -83,7 +83,7 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the | Model Type | Status | |-----------------------------|------------------------------------------------------------------------------------| | **Decoder-only Models** | <nobr>🚀 Optimized</nobr> | -| **Encoder-Decoder Models** | <nobr>🟠 Delayed</nobr> | +| **Encoder-Decoder Models** | <nobr>🟢 Whisper only</nobr> | | **Embedding Models** | <nobr>🟢 Functional</nobr> | | **Mamba Models** | <nobr>🟢 (Mamba-2), 🟢 (Mamba-1)</nobr> | | **Multimodal Models** | <nobr>🟢 Functional</nobr> | @@ -110,7 +110,7 @@ Models using selective state-space mechanisms instead of standard transformer at Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`,`FalconMambaForCausalLM`) are supported. Hybrid models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`, -`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`). +`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`, `Plamo2ForCausalLM`). Hybrid models with mechanisms different to Mamba are also supported (e.g, `MiniMaxText01ForCausalLM`, `MiniMaxM1ForCausalLM`, `Lfm2ForCausalLM`). @@ -118,8 +118,9 @@ Please note that prefix caching is not yet supported for any of the above models #### Encoder-Decoder Models -Models requiring cross-attention between separate encoder and decoder (e.g., `BartForConditionalGeneration`, `MllamaForConditionalGeneration`) -are not yet supported. +Whisper is supported. Other models requiring cross-attention between separate +encoder and decoder (e.g., `BartForConditionalGeneration`, +`MllamaForConditionalGeneration`) are not supported. 
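For reference, a minimal offline Whisper transcription sketch is shown below. It follows the pattern of the Whisper example touched elsewhere in this PR; the bundled `mary_had_lamb` audio asset and the exact prompt format are assumptions and may differ slightly between vLLM versions.

```python
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

llm = LLM(model="openai/whisper-large-v3-turbo", max_model_len=448)

# One transcription request: the decoder prompt is the start-of-transcript token,
# and the audio clip is passed as multimodal data (bundled test asset assumed here).
prompts = [
    {
        "prompt": "<|startoftranscript|>",
        "multi_modal_data": {"audio": AudioAsset("mary_had_lamb").audio_and_sample_rate},
    }
]

outputs = llm.generate(prompts, SamplingParams(temperature=0, max_tokens=200))
print(outputs[0].outputs[0].text)
```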
### Features diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index 22cb8b057dac7..65a87d2dd9e8e 100644 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -117,7 +117,7 @@ def run_gemma3n(question: str, audio_count: int) -> ModelRequestData: # Granite Speech def run_granite_speech(question: str, audio_count: int) -> ModelRequestData: - # NOTE - the setting in this example are somehat different than what is + # NOTE - the setting in this example are somewhat different from what is # optimal for granite speech, and it is generally recommended to use beam # search. Check the model README for suggested settings. # https://huggingface.co/ibm-granite/granite-speech-3.3-8b @@ -146,6 +146,36 @@ def run_granite_speech(question: str, audio_count: int) -> ModelRequestData: ) +# MiDashengLM +def run_midashenglm(question: str, audio_count: int): + model_name = "mispeech/midashenglm-7b" + + engine_args = EngineArgs( + model=model_name, + trust_remote_code=True, + max_model_len=4096, + max_num_seqs=5, + limit_mm_per_prompt={"audio": audio_count}, + ) + + audio_in_prompt = "".join( + ["<|audio_bos|><|AUDIO|><|audio_eos|>" for idx in range(audio_count)] + ) + + default_system = "You are a helpful language and speech assistant." + + prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n" + f"{audio_in_prompt}{question}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + ) + + # MiniCPM-O def run_minicpmo(question: str, audio_count: int) -> ModelRequestData: model_name = "openbmb/MiniCPM-o-2_6" @@ -352,6 +382,7 @@ model_example_map = { "voxtral": run_voxtral, "gemma3n": run_gemma3n, "granite_speech": run_granite_speech, + "midashenglm": run_midashenglm, "minicpmo": run_minicpmo, "phi4_mm": run_phi4mm, "phi4_multimodal": run_phi4_multimodal, diff --git a/examples/offline_inference/chat_with_tools.py b/examples/offline_inference/chat_with_tools.py index 6e56e24f2092c..3a95b1fdfbabc 100644 --- a/examples/offline_inference/chat_with_tools.py +++ b/examples/offline_inference/chat_with_tools.py @@ -143,5 +143,5 @@ outputs = llm.chat(messages, sampling_params, tools=tools) print(outputs[0].outputs[0].text.strip()) # yields -# 'The weather in Dallas, TX is 85 degrees fahrenheit. ' +# 'The weather in Dallas, TX is 85 degrees Fahrenheit. ' # 'It is partly cloudly, with highs in the 90's.' 
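For context, the `tools` argument passed to `llm.chat(...)` in the snippet above uses the OpenAI-style function-calling schema. A hypothetical tool definition is sketched below; the name and fields are illustrative only and are not taken from the example file itself, which defines its own weather tool.

```python
# Hypothetical OpenAI-style tool schema for illustration purposes
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given city.",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city and state, e.g. 'Dallas, TX'",
                    },
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["city", "unit"],
            },
        },
    }
]
```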
diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index dd7559451c4c6..98fe36d0fb796 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -87,6 +87,16 @@ def parse_args(): default=0.8, help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."), ) + parser.add_argument( + "--enable-dbo", + action="store_true", + help=("Enable microbatched execution"), + ) + parser.add_argument( + "--compilation-config", + type=int, + help=("Compilation optimization (O) level 0-3."), + ) parser.add_argument( "--quantization", type=str, @@ -106,7 +116,9 @@ def main( trust_remote_code, max_num_seqs, max_model_len, + compilation_config, gpu_memory_utilization, + enable_dbo, quantization, ): os.environ["VLLM_DP_RANK"] = str(global_dp_rank) @@ -161,7 +173,9 @@ def main( max_num_seqs=max_num_seqs, max_model_len=max_model_len, gpu_memory_utilization=gpu_memory_utilization, + enable_dbo=enable_dbo, quantization=quantization, + compilation_config=compilation_config, ) outputs = llm.generate(prompts, sampling_params) # Print the outputs. @@ -218,7 +232,9 @@ if __name__ == "__main__": args.trust_remote_code, args.max_num_seqs, args.max_model_len, + args.compilation_config, args.gpu_memory_utilization, + args.enable_dbo, args.quantization, ), ) diff --git a/examples/offline_inference/disaggregated_prefill.py b/examples/offline_inference/disaggregated_prefill.py index 05a361fee0717..f619fa584f801 100644 --- a/examples/offline_inference/disaggregated_prefill.py +++ b/examples/offline_inference/disaggregated_prefill.py @@ -30,12 +30,12 @@ def run_prefill(prefill_done): ] sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) - # Using PyNcclConnector to transmit KV caches between vLLM instances. + # Using P2pNcclConnector to transmit KV caches between vLLM instances. # This instance is the prefill node (kv_producer, rank 0). # The number of parallel instances for KV cache transfer is set to 2, - # as required for PyNcclConnector. + # as required for P2pNcclConnector. ktc = KVTransferConfig( - kv_connector="PyNcclConnector", + kv_connector="P2pNcclConnector", kv_role="kv_producer", kv_rank=0, kv_parallel_size=2, @@ -74,12 +74,12 @@ def run_decode(prefill_done): ] sampling_params = SamplingParams(temperature=0, top_p=0.95) - # Using PyNcclConnector to transmit KV caches between vLLM instances. + # Using P2pNcclConnector to transmit KV caches between vLLM instances. # This instance is the decode node (kv_consumer, rank 1). # The number of parallel instances for KV cache transfer is set to 2, - # as required for PyNcclConnector. + # as required for P2pNcclConnector. 
ktc = KVTransferConfig( - kv_connector="PyNcclConnector", + kv_connector="P2pNcclConnector", kv_role="kv_consumer", kv_rank=1, kv_parallel_size=2, diff --git a/examples/offline_inference/dolphin.py b/examples/offline_inference/dolphin.py deleted file mode 100644 index d2ba27cd1e027..0000000000000 --- a/examples/offline_inference/dolphin.py +++ /dev/null @@ -1,311 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import argparse -import copy -import os -from dataclasses import dataclass - -import cv2 -import numpy as np -import regex as re -from PIL import Image -from transformers import DonutProcessor - -from vllm import LLM, SamplingParams -from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt -from vllm.multimodal.utils import fetch_image - - -# Copied from https://github.com/bytedance/Dolphin/utils/utils.py -@dataclass -class ImageDimensions: - original_w: int - original_h: int - padded_w: int - padded_h: int - - -# Copied from https://github.com/bytedance/Dolphin/utils/utils.py -def map_to_original_coordinates( - x1, y1, x2, y2, dims: ImageDimensions -) -> tuple[int, int, int, int]: - try: - top = (dims.padded_h - dims.original_h) // 2 - left = (dims.padded_w - dims.original_w) // 2 - orig_x1 = max(0, x1 - left) - orig_y1 = max(0, y1 - top) - orig_x2 = min(dims.original_w, x2 - left) - orig_y2 = min(dims.original_h, y2 - top) - if orig_x2 <= orig_x1: - orig_x2 = min(orig_x1 + 1, dims.original_w) - if orig_y2 <= orig_y1: - orig_y2 = min(orig_y1 + 1, dims.original_h) - return int(orig_x1), int(orig_y1), int(orig_x2), int(orig_y2) - except Exception as e: - print(f"map_to_original_coordinates error: {str(e)}") - return 0, 0, min(100, dims.original_w), min(100, dims.original_h) - - -# Copied from https://github.com/bytedance/Dolphin/utils/utils.py -def adjust_box_edges(image, boxes: list[list[float]], max_pixels=15, threshold=0.2): - if isinstance(image, str): - image = cv2.imread(image) - img_h, img_w = image.shape[:2] - new_boxes = [] - for box in boxes: - best_box = copy.deepcopy(box) - - def check_edge(img, current_box, i, is_vertical): - edge = current_box[i] - gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) - _, binary = cv2.threshold( - gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU - ) - if is_vertical: - line = binary[current_box[1] : current_box[3] + 1, edge] - else: - line = binary[edge, current_box[0] : current_box[2] + 1] - transitions = np.abs(np.diff(line)) - return np.sum(transitions) / len(transitions) - - edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)] - current_box = copy.deepcopy(box) - current_box[0] = min(max(current_box[0], 0), img_w - 1) - current_box[1] = min(max(current_box[1], 0), img_h - 1) - current_box[2] = min(max(current_box[2], 0), img_w - 1) - current_box[3] = min(max(current_box[3], 0), img_h - 1) - - for i, direction, is_vertical in edges: - best_score = check_edge(image, current_box, i, is_vertical) - if best_score <= threshold: - continue - for step in range(max_pixels): - current_box[i] += direction - if i == 0 or i == 2: - current_box[i] = min(max(current_box[i], 0), img_w - 1) - else: - current_box[i] = min(max(current_box[i], 0), img_h - 1) - score = check_edge(image, current_box, i, is_vertical) - if score < best_score: - best_score = score - best_box = copy.deepcopy(current_box) - if score <= threshold: - break - new_boxes.append(best_box) - return new_boxes - - -# Copied from https://github.com/bytedance/Dolphin/utils/utils.py 
-def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None): - try: - x1, y1 = int(coords[0] * dims.padded_w), int(coords[1] * dims.padded_h) - x2, y2 = int(coords[2] * dims.padded_w), int(coords[3] * dims.padded_h) - x1, y1, x2, y2 = ( - max(0, min(x1, dims.padded_w - 1)), - max(0, min(y1, dims.padded_h - 1)), - max(0, min(x2, dims.padded_w)), - max(0, min(y2, dims.padded_h)), - ) - if x2 <= x1: - x2 = min(x1 + 1, dims.padded_w) - if y2 <= y1: - y2 = min(y1 + 1, dims.padded_h) - new_boxes = adjust_box_edges(padded_image, [[x1, y1, x2, y2]]) - x1, y1, x2, y2 = new_boxes[0] - x1, y1, x2, y2 = ( - max(0, min(x1, dims.padded_w - 1)), - max(0, min(y1, dims.padded_h - 1)), - max(0, min(x2, dims.padded_w)), - max(0, min(y2, dims.padded_h)), - ) - if x2 <= x1: - x2 = min(x1 + 1, dims.padded_w) - if y2 <= y1: - y2 = min(y1 + 1, dims.padded_h) - if previous_box is not None: - prev_x1, prev_y1, prev_x2, prev_y2 = previous_box - if (x1 < prev_x2 and x2 > prev_x1) and (y1 < prev_y2 and y2 > prev_y1): - y1 = prev_y2 - y1 = min(y1, dims.padded_h - 1) - if y2 <= y1: - y2 = min(y1 + 1, dims.padded_h) - new_previous_box = [x1, y1, x2, y2] - orig_x1, orig_y1, orig_x2, orig_y2 = map_to_original_coordinates( - x1, y1, x2, y2, dims - ) - return x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box - except Exception as e: - print(f"process_coordinates error: {str(e)}") - orig_x1, orig_y1, orig_x2, orig_y2 = ( - 0, - 0, - min(100, dims.original_w), - min(100, dims.original_h), - ) - return 0, 0, 100, 100, orig_x1, orig_y1, orig_x2, orig_y2, [0, 0, 100, 100] - - -# Copied from https://github.com/bytedance/Dolphin/utils/utils.py -def prepare_image(image) -> tuple[np.ndarray, ImageDimensions]: - try: - image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) - original_h, original_w = image_cv.shape[:2] - max_size = max(original_h, original_w) - top = (max_size - original_h) // 2 - bottom = max_size - original_h - top - left = (max_size - original_w) // 2 - right = max_size - original_w - left - padded_image = cv2.copyMakeBorder( - image_cv, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0) - ) - padded_h, padded_w = padded_image.shape[:2] - dimensions = ImageDimensions( - original_w=original_w, - original_h=original_h, - padded_w=padded_w, - padded_h=padded_h, - ) - return padded_image, dimensions - except Exception as e: - print(f"prepare_image error: {str(e)}") - h, w = image.height, image.width - dimensions = ImageDimensions(original_w=w, original_h=h, padded_w=w, padded_h=h) - return np.zeros((h, w, 3), dtype=np.uint8), dimensions - - -# Copied from https://github.com/bytedance/Dolphin/utils/utils.py -def parse_layout_string(bbox_str): - """Parse layout string using regular expressions""" - pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)" - matches = re.finditer(pattern, bbox_str) - - parsed_results = [] - for match in matches: - coords = [float(match.group(i)) for i in range(1, 5)] - label = match.group(5).strip() - parsed_results.append((coords, label)) - - return parsed_results - - -model_id = "ByteDance/Dolphin" - -# The input image size for Dolphin is 896 x 896, -# and the patch_size is 4 x 4. -# Therefore, the initial number of patches is: -# Height: 896 / 4 = 224 patches -# Width: 896 / 4 = 224 patches - -# The Dolphin model uses a staged downsampling approach, -# defined by the "depths": [2, 2, 14, 2] configuration. 
-# Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed, -# which halves the feature map's dimensions (dividing both height and width by 2). -# Before Stage 2: The size changes from 224 x 224 to (224/2) x (224/2) = 112 x 112. -# Before Stage 3: The size changes from 112 x 112 to (112/2) x (112/2) = 56 x 56. -# Before Stage 4: The size changes from 56 x 56 to (56/2) x (56/2) = 28 x 28. - -# Because vLLM needs to fill the image features with an encoder_prompt, -# and the encoder_prompt will have `<pad>` tokens added when tokenized, -# we need to construct an encoder_prompt with a length of 28 x 28 - 1 = 783. -encoder_prompt = "".join(["0"] * 783) -sampling_params = SamplingParams( - temperature=0.0, - max_tokens=2048, -) - -processor = DonutProcessor.from_pretrained(model_id) -llm = LLM( - model=model_id, - dtype="float16", - max_num_seqs=8, - hf_overrides={"architectures": ["DonutForConditionalGeneration"]}, -) - -parser = argparse.ArgumentParser() -parser.add_argument( - "--image_path", type=str, default=None, help="Path to a local image file." -) -args = parser.parse_args() - -if args.image_path: - if not os.path.exists(args.image_path): - raise FileNotFoundError(f"Error: File not found at {args.image_path}") - image = Image.open(args.image_path).convert("RGB") -else: - image = fetch_image( - "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg" - ) - - -prompt = "Parse the reading order of this document. " -decoder_prompt = f"<s>{prompt}<Answer/>" -decoder_prompt_tokens = TokensPrompt( - prompt_token_ids=processor.tokenizer(decoder_prompt, add_special_tokens=False)[ - "input_ids" - ] -) -enc_dec_prompt = ExplicitEncoderDecoderPrompt( - encoder_prompt=TextPrompt(prompt=encoder_prompt, multi_modal_data={"image": image}), - decoder_prompt=decoder_prompt_tokens, -) -layout_outputs = llm.generate(prompts=enc_dec_prompt, sampling_params=sampling_params) -layout_result_str = layout_outputs[0].outputs[0].text -print(f"Layout analysis output:\n{layout_result_str}") - -padded_image, dims = prepare_image(image) -layout_results = parse_layout_string(layout_result_str) -text_table_elements = [] -previous_box = None -reading_order = 0 -for bbox_coords, label in layout_results: - if label == "fig": - continue - try: - x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = ( - process_coordinates(bbox_coords, padded_image, dims, previous_box) - ) - cropped = padded_image[y1:y2, x1:x2] - if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3: - pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)) - prompt_ocr = ( - "Parse the table in the image. " - if label == "tab" - else "Read text in the image. 
" - ) - text_table_elements.append( - { - "crop": pil_crop, - "prompt": prompt_ocr, - "reading_order": reading_order, - } - ) - reading_order += 1 - except Exception as e: - print(f"Error processing bbox (label: {label}): {str(e)}") - continue - -if text_table_elements: - batch_prompts = [] - for elem in text_table_elements: - decoder_prompt_str = f"<s>{elem['prompt']}<Answer/>" - decoder_prompt_tokens = TokensPrompt( - prompt_token_ids=processor.tokenizer( - decoder_prompt_str, add_special_tokens=False - )["input_ids"] - ) - enc_dec_prompt = ExplicitEncoderDecoderPrompt( - encoder_prompt=TextPrompt( - prompt=encoder_prompt, multi_modal_data={"image": elem["crop"]} - ), - decoder_prompt=decoder_prompt_tokens, - ) - batch_prompts.append(enc_dec_prompt) - batch_outputs = llm.generate(prompts=batch_prompts, sampling_params=sampling_params) - for i, output in enumerate(batch_outputs): - text_table_elements[i]["text"] = output.outputs[0].text.strip() - -print("------" * 8) -text_table_elements.sort(key=lambda x: x["reading_order"]) -for elem in text_table_elements: - print(elem.get("text", "")) diff --git a/examples/offline_inference/encoder_decoder.py b/examples/offline_inference/encoder_decoder.py deleted file mode 100644 index df6c1eaf4a21e..0000000000000 --- a/examples/offline_inference/encoder_decoder.py +++ /dev/null @@ -1,193 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Demonstrate prompting of text-to-text -encoder/decoder models, specifically BART and mBART. - -This script is refactored to allow model selection via command-line arguments. -""" - -import argparse -from typing import NamedTuple, Optional - -from vllm import LLM, SamplingParams -from vllm.inputs import ( - ExplicitEncoderDecoderPrompt, - TextPrompt, - TokensPrompt, - zip_enc_dec_prompts, -) - - -class ModelRequestData(NamedTuple): - """ - Holds the configuration for a specific model, including its - HuggingFace ID and the prompts to use for the demo. - """ - - model_id: str - encoder_prompts: list - decoder_prompts: list - hf_overrides: Optional[dict] = None - - -def get_bart_config() -> ModelRequestData: - """ - Returns the configuration for facebook/bart-large-cnn. - This uses the exact test cases from the original script. - """ - encoder_prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "An encoder prompt", - ] - decoder_prompts = [ - "A decoder prompt", - "Another decoder prompt", - ] - return ModelRequestData( - model_id="facebook/bart-large-cnn", - encoder_prompts=encoder_prompts, - decoder_prompts=decoder_prompts, - ) - - -def get_mbart_config() -> ModelRequestData: - """ - Returns the configuration for facebook/mbart-large-en-ro. - This uses prompts suitable for an English-to-Romanian translation task. - """ - encoder_prompts = [ - "The quick brown fox jumps over the lazy dog.", - "How are you today?", - ] - decoder_prompts = ["", ""] - hf_overrides = {"architectures": ["MBartForConditionalGeneration"]} - return ModelRequestData( - model_id="facebook/mbart-large-en-ro", - encoder_prompts=encoder_prompts, - decoder_prompts=decoder_prompts, - hf_overrides=hf_overrides, - ) - - -MODEL_GETTERS = { - "bart": get_bart_config, - "mbart": get_mbart_config, -} - - -def create_all_prompt_types( - encoder_prompts_raw: list, - decoder_prompts_raw: list, - tokenizer, -) -> list: - """ - Generates a list of diverse prompt types for demonstration. 
- This function is generic and uses the provided raw prompts - to create various vLLM input objects. - """ - text_prompt_raw = encoder_prompts_raw[0] - text_prompt = TextPrompt(prompt=encoder_prompts_raw[1 % len(encoder_prompts_raw)]) - tokens_prompt = TokensPrompt( - prompt_token_ids=tokenizer.encode( - encoder_prompts_raw[2 % len(encoder_prompts_raw)] - ) - ) - - decoder_tokens_prompt = TokensPrompt( - prompt_token_ids=tokenizer.encode(decoder_prompts_raw[0]) - ) - single_prompt_examples = [ - text_prompt_raw, - text_prompt, - tokens_prompt, - ] - explicit_pair_examples = [ - ExplicitEncoderDecoderPrompt( - encoder_prompt=text_prompt_raw, - decoder_prompt=decoder_tokens_prompt, - ), - ExplicitEncoderDecoderPrompt( - encoder_prompt=text_prompt, - decoder_prompt=decoder_prompts_raw[1 % len(decoder_prompts_raw)], - ), - ExplicitEncoderDecoderPrompt( - encoder_prompt=tokens_prompt, - decoder_prompt=text_prompt, - ), - ] - zipped_prompt_list = zip_enc_dec_prompts( - encoder_prompts_raw, - decoder_prompts_raw, - ) - return single_prompt_examples + explicit_pair_examples + zipped_prompt_list - - -def create_sampling_params() -> SamplingParams: - """Create a sampling params object.""" - return SamplingParams( - temperature=0, - top_p=1.0, - min_tokens=0, - max_tokens=30, - ) - - -def print_outputs(outputs: list): - """Formats and prints the generation outputs.""" - print("-" * 80) - for i, output in enumerate(outputs): - prompt = output.prompt - encoder_prompt = output.encoder_prompt - generated_text = output.outputs[0].text - print(f"Output {i + 1}:") - print(f"Encoder Prompt: {encoder_prompt!r}") - print(f"Decoder Prompt: {prompt!r}") - print(f"Generated Text: {generated_text!r}") - print("-" * 80) - - -def main(args): - """Main execution function.""" - model_key = args.model - if model_key not in MODEL_GETTERS: - raise ValueError( - f"Unknown model: {model_key}. " - f"Available models: {list(MODEL_GETTERS.keys())}" - ) - config_getter = MODEL_GETTERS[model_key] - model_config = config_getter() - - print(f"🚀 Running demo for model: {model_config.model_id}") - llm = LLM( - model=model_config.model_id, - dtype="float", - hf_overrides=model_config.hf_overrides, - ) - tokenizer = llm.llm_engine.get_tokenizer_group() - prompts = create_all_prompt_types( - encoder_prompts_raw=model_config.encoder_prompts, - decoder_prompts_raw=model_config.decoder_prompts, - tokenizer=tokenizer, - ) - sampling_params = create_sampling_params() - outputs = llm.generate(prompts, sampling_params) - print_outputs(outputs) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="A flexible demo for vLLM encoder-decoder models." - ) - parser.add_argument( - "--model", - "-m", - type=str, - default="bart", - choices=MODEL_GETTERS.keys(), - help="The short name of the model to run.", - ) - args = parser.parse_args() - main(args) diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py index 655f9f3fce7ae..4a1b0c40604b2 100644 --- a/examples/offline_inference/encoder_decoder_multimodal.py +++ b/examples/offline_inference/encoder_decoder_multimodal.py @@ -5,6 +5,7 @@ This example shows how to use vLLM for running offline inference with the explicit/implicit prompt format on enc-dec LMMs for text generation. 
""" +import os import time from collections.abc import Sequence from dataclasses import asdict @@ -12,8 +13,6 @@ from typing import NamedTuple from vllm import LLM, EngineArgs, PromptType, SamplingParams from vllm.assets.audio import AudioAsset -from vllm.assets.image import ImageAsset -from vllm.multimodal.utils import fetch_image from vllm.utils import FlexibleArgumentParser @@ -22,114 +21,9 @@ class ModelRequestData(NamedTuple): prompts: Sequence[PromptType] -def run_donut(): - engine_args = EngineArgs( - model="naver-clova-ix/donut-base-finetuned-docvqa", - max_num_seqs=2, - limit_mm_per_prompt={"image": 1}, - dtype="float16", - hf_overrides={"architectures": ["DonutForConditionalGeneration"]}, - ) - - # The input image size for donut-base-finetuned-docvqa is 2560 x 1920, - # and the patch_size is 4 x 4. - # Therefore, the initial number of patches is: - # Height: 1920 / 4 = 480 patches - # Width: 2560 / 4 = 640 patches - # The Swin model uses a staged downsampling approach, - # defined by the "depths": [2, 2, 14, 2] configuration. - # Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed, - # which halves the feature map's dimensions (dividing both height and width by 2). - # Before Stage 2: The size changes from 480 x 640 to (480/2) x (640/2) = 240 x 320. - # Before Stage 3: The size changes from 240 x 320 to (240/2) x (320/2) = 120 x 160. - # Before Stage 4: The size changes from 120 x 160 to (120/2) x (160/2) = 60 x 80. - # Because vLLM needs to fill the image features with an encoder_prompt, - # and the encoder_prompt will have `<pad>` tokens added when tokenized, - # we need to construct an encoder_prompt with a length of 60 x 80 - 1 = 4799. - prompts = [ - { - "encoder_prompt": { - "prompt": "".join(["$"] * 4799), - "multi_modal_data": { - "image": fetch_image( - "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg" - ) # noqa: E501 - }, - }, - "decoder_prompt": "<s_docvqa><s_question>What time is the coffee break?</s_question><s_answer>", # noqa: E501 - }, - ] - - return ModelRequestData( - engine_args=engine_args, - prompts=prompts, - ) - - -def run_florence2(): - engine_args = EngineArgs( - model="microsoft/Florence-2-large", - tokenizer="Isotr0py/Florence-2-tokenizer", - max_num_seqs=8, - trust_remote_code=True, - limit_mm_per_prompt={"image": 1}, - dtype="half", - ) - - prompts = [ - { # implicit prompt with task token - "prompt": "<DETAILED_CAPTION>", - "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image}, - }, - { # explicit encoder/decoder prompt - "encoder_prompt": { - "prompt": "Describe in detail what is shown in the image.", - "multi_modal_data": {"image": ImageAsset("cherry_blossom").pil_image}, - }, - "decoder_prompt": "", - }, - ] - - return ModelRequestData( - engine_args=engine_args, - prompts=prompts, - ) - - -def run_mllama(): - engine_args = EngineArgs( - model="meta-llama/Llama-3.2-11B-Vision-Instruct", - max_model_len=8192, - max_num_seqs=2, - limit_mm_per_prompt={"image": 1}, - dtype="half", - ) - - prompts = [ - { # Implicit prompt - "prompt": "<|image|><|begin_of_text|>What is the content of this image?", # noqa: E501 - "multi_modal_data": { - "image": ImageAsset("stop_sign").pil_image, - }, - }, - { # Explicit prompt - "encoder_prompt": { - "prompt": "<|image|>", - "multi_modal_data": { - "image": ImageAsset("stop_sign").pil_image, - }, - }, - "decoder_prompt": "<|image|><|begin_of_text|>Please describe the image.", # noqa: E501 - }, - ] - - return ModelRequestData( - 
engine_args=engine_args, - prompts=prompts, - ) - - def run_whisper(): + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + engine_args = EngineArgs( model="openai/whisper-large-v3-turbo", max_model_len=448, @@ -163,9 +57,6 @@ def run_whisper(): model_example_map = { - "donut": run_donut, - "florence2": run_florence2, - "mllama": run_mllama, "whisper": run_whisper, } @@ -179,7 +70,7 @@ def parse_args(): "--model-type", "-m", type=str, - default="mllama", + default="whisper", choices=model_example_map.keys(), help='Huggingface "model_type".', ) diff --git a/examples/offline_inference/logits_processor.py b/examples/offline_inference/logits_processor/custom.py similarity index 93% rename from examples/offline_inference/logits_processor.py rename to examples/offline_inference/logits_processor/custom.py index 3e122319169eb..4112a498f37ab 100644 --- a/examples/offline_inference/logits_processor.py +++ b/examples/offline_inference/logits_processor/custom.py @@ -56,7 +56,6 @@ class DummyLogitsProcessor(LogitsProcessor): self.req_info: dict[int, int] = {} def is_argmax_invariant(self) -> bool: - """Never impacts greedy sampling""" return False def update_state(self, batch_update: Optional[BatchUpdate]): @@ -75,13 +74,12 @@ class DummyLogitsProcessor(LogitsProcessor): return logits # Save target values before modification - rows_list = list(self.req_info.keys()) cols = torch.tensor( - [self.req_info[i] for i in rows_list], - dtype=torch.long, - device=logits.device, + list(self.req_info.values()), dtype=torch.long, device=logits.device + ) + rows = torch.tensor( + list(self.req_info.keys()), dtype=torch.long, device=logits.device ) - rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device) values_to_keep = logits[rows, cols].clone() # Mask all but target tokens diff --git a/examples/offline_inference/logits_processor/custom_req.py b/examples/offline_inference/logits_processor/custom_req.py new file mode 100644 index 0000000000000..4c19bb4ce2bae --- /dev/null +++ b/examples/offline_inference/logits_processor/custom_req.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""This example demonstrates wrapping a request-level logits processor to be +compatible with vLLM's batch-level logits processing + +For demo purposes, a dummy logits processor is employed which, if +`target_token` is passed as a keyword argument to `SamplingParams.extra_args`, +will mask out all tokens except `target_token`. This logits processor can be +applied to a vector of logits associated with a single decode step for a single +request. The logits processor cannot be applied to a request which does not +pass in a `target_token` custom argument. + +The request-level dummy logits processor is wrapped to create a batch-level +logits processor, which can apply the logits processor to output logits from +all requests in the persistent batch in a given decode step. For requests which +do not provide a `target_token` argument, the corresponding row of `logits` +will not be modified. 
+ +A batch is constructed with `temperature=0.0` and 50% of requests specifying +`target_token`, and for these requests - and *only* these requests - we +expect the `target_token` to be decoded in each step, yielding an output +similar to that shown below: + +Generated Outputs: +------------------------------------------------------------ +Prompt: 'Hello, my name is' +Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '" +------------------------------------------------------------ +Prompt: 'The president of the United States is' +Output: " not a racist. He is a racist.\nHe's a racist because he" +------------------------------------------------------------ +Prompt: 'The capital of France is' +Output: ' also also also also also also also also also also also also also + also also also' +------------------------------------------------------------ +Prompt: 'The future of AI is' +Output: ' in the hands of the people.\n\nThe future of AI is in the' +------------------------------------------------------------ +""" + +from typing import Any, Optional + +import torch + +from vllm import LLM, SamplingParams +from vllm.logger import init_logger +from vllm.v1.sample.logits_processor import ( + AdapterLogitsProcessor, + RequestLogitsProcessor, +) + +logger = init_logger(__name__) + + +class DummyPerReqLogitsProcessor: + """The request-level logits processor masks out all logits except the + token id identified by `target_token`""" + + def __init__(self, target_token: int) -> None: + """Specify `target_token`""" + self.target_token = target_token + + def __call__( + self, + output_ids: list[int], + logits: torch.Tensor, + ) -> torch.Tensor: + val_to_keep = logits[self.target_token].item() + logits[:] = float("-inf") + logits[self.target_token] = val_to_keep + return logits + + +class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor): + """Example of wrapping a fake request-level logit processor to create a + batch-level logits processor""" + + def is_argmax_invariant(self) -> bool: + return False + + def new_req_logits_processor( + self, + params: SamplingParams, + ) -> Optional[RequestLogitsProcessor]: + """This method returns a new request-level logits processor, customized + to the `target_token` value associated with a particular request. + + Returns None if the logits processor should not be applied to the + particular request. To use the logits processor the request must have + a "target_token" custom argument with an integer value. + + Args: + params: per-request sampling params + + Returns: + `Callable` request logits processor, or None + """ + target_token: Optional[Any] = params.extra_args and params.extra_args.get( + "target_token" + ) + if target_token is None: + return None + if not isinstance(target_token, int): + logger.warning( + "target_token value %s is not int; not applying logits" + " processor to request.", + target_token, + ) + return None + return DummyPerReqLogitsProcessor(target_token) + + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a mixture of requests which do and don't utilize the dummy logitproc +sampling_params_list = [ + SamplingParams(temperature=0.0, extra_args={"target_token": 128}), + SamplingParams(temperature=0.0), + SamplingParams(temperature=0.0, extra_args={"target_token": 67}), + SamplingParams(temperature=0.0), +] + + +def main(): + # Create an LLM. 
+ llm = LLM( + model="facebook/opt-125m", + logits_processors=[WrappedPerReqLogitsProcessor], + ) + # Generate texts from the prompts. + # The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params_list) + # Print the outputs. + print("\nGenerated Outputs:\n" + "-" * 60) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Output: {generated_text!r}") + print("-" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/logits_processor/custom_req_init.py b/examples/offline_inference/logits_processor/custom_req_init.py new file mode 100644 index 0000000000000..62947d122e01c --- /dev/null +++ b/examples/offline_inference/logits_processor/custom_req_init.py @@ -0,0 +1,165 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""This example demonstrates a special case of wrapping a request-level logits +processor, namely the case where it is necessary to utilize engine config or +environment info passed to the constructor. The subclass must override the +wrapper base class `__init__()` method to access the engine config, the device +identifier, or the flag which indicates whether pinned memory is available. + +For demo purposes, a request-level dummy logits processor is employed which +causes the same token (`target_token`) to be decoded in each step. The +request-level dummy logits processor is wrapped to create a batch-level logits +processor, which can apply the logits processor to output logits from all +requests in the persistent batch in a given decode step. + +The wrapped dummy logits processor below models a scenario where we must +disable the logits processor on non-"cuda" platforms. The wrapper base class +`__init__()` is overridden in order to check this condition and set a flag. + +A batch is constructed with `temperature=0.0` and 50% of requests specifying +`target_token`, and for these requests - and *only* these requests - we +expect that on a "cuda" device the output will look something like: + +Generated Outputs: +------------------------------------------------------------ +Prompt: 'Hello, my name is' +Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '" +------------------------------------------------------------ +Prompt: 'The president of the United States is' +Output: " not a racist. He is a racist.\nHe's a racist because he" +------------------------------------------------------------ +Prompt: 'The capital of France is' +Output: ' also also also also also also also also also also also also also + also also also' +------------------------------------------------------------ +Prompt: 'The future of AI is' +Output: ' in the hands of the people.\n\nThe future of AI is in the' +------------------------------------------------------------ + +which indicates that the logits processor is running. However, on a non-"cuda" +device, the first and third requests would not repeat the same token. 
+""" + +from typing import Optional + +import torch + +from vllm import LLM, SamplingParams +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.sample.logits_processor import ( + AdapterLogitsProcessor, + RequestLogitsProcessor, +) + +logger = init_logger(__name__) + + +class DummyPerReqLogitsProcessor: + """The request-level logits processor masks out all logits except the + token id identified by `target_token`""" + + def __init__(self, target_token: int) -> None: + """Specify `target_token`""" + self.target_token = target_token + + def __call__( + self, + output_ids: list[int], + logits: torch.Tensor, + ) -> torch.Tensor: + val_to_keep = logits[self.target_token].item() + logits[:] = float("-inf") + logits[self.target_token] = val_to_keep + return logits + + +class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor): + """Example of overriding the wrapper class `__init__()` in order to utilize + info about the device type""" + + def __init__( + self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool + ): + super().__init__(vllm_config, device, is_pin_memory) + self.is_cuda = device.type == "cuda" + + def is_argmax_invariant(self) -> bool: + return False + + def new_req_logits_processor( + self, + params: SamplingParams, + ) -> Optional[RequestLogitsProcessor]: + """This method returns a new request-level logits processor, customized + to the `target_token` value associated with a particular request. + + Returns None if the logits processor should not be applied to the + particular request. To use the logits processor the request must have + a "target_token" custom argument with an integer value, and the device + must be "cuda"-type + + Args: + params: per-request sampling params + + Returns: + `Callable` request logits processor, or None + """ + if ( + not self.is_cuda + or ( + target_token := params.extra_args + and params.extra_args.get("target_token") + ) + is None + ): + return None + if not isinstance(target_token, int): + logger.warning( + "target_token value %s is not int; not applying logits" + " processor to request.", + target_token, + ) + return None + return DummyPerReqLogitsProcessor(target_token) + + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a mixture of requests which do and don't utilize the dummy logitproc +sampling_params_list = [ + SamplingParams(temperature=0.0, extra_args={"target_token": 128}), + SamplingParams(temperature=0.0), + SamplingParams(temperature=0.0, extra_args={"target_token": 67}), + SamplingParams(temperature=0.0), +] + + +def main(): + # Create an LLM. + llm = LLM( + model="facebook/opt-125m", + logits_processors=[WrappedPerReqLogitsProcessor], + ) + # Generate texts from the prompts. + # The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params_list) + # Print the outputs. 
+ print("\nGenerated Outputs:\n" + "-" * 60) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Output: {generated_text!r}") + print("-" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/neuron.py b/examples/offline_inference/neuron.py deleted file mode 100644 index 7826629a36d01..0000000000000 --- a/examples/offline_inference/neuron.py +++ /dev/null @@ -1,49 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from vllm import LLM, SamplingParams - -# Sample prompts. -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] -# Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) - - -def main(): - # Create an LLM. - llm = LLM( - model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", - max_num_seqs=8, - # The max_model_len and block_size arguments are required to be same as - # max sequence length when targeting neuron device. - # Currently, this is a known limitation in continuous batching support - # in transformers-neuronx. - # TODO(liangfu): Support paged-attention in transformers-neuronx. - max_model_len=1024, - block_size=1024, - # ruff: noqa: E501 - # The device can be automatically detected when AWS Neuron SDK is installed. - # The device argument can be either unspecified for automated detection, - # or explicitly assigned. - device="neuron", - tensor_parallel_size=2, - ) - # Generate texts from the prompts. The output is a list of RequestOutput objects - # that contain the prompt, generated text, and other information. - outputs = llm.generate(prompts, sampling_params) - # Print the outputs. - print("-" * 50) - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") - print("-" * 50) - - -if __name__ == "__main__": - main() diff --git a/examples/offline_inference/neuron_eagle.py b/examples/offline_inference/neuron_eagle.py deleted file mode 100644 index 8b1d235ff9742..0000000000000 --- a/examples/offline_inference/neuron_eagle.py +++ /dev/null @@ -1,61 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -This example shows how to run offline inference with an EAGLE speculative -decoding model on neuron. To use EAGLE speculative decoding, you must use -a draft model that is specifically fine-tuned for EAGLE speculation. -Additionally, to use EAGLE with NxD Inference, the draft model must include -the LM head weights from the target model. These weights are shared between -the draft and target model. -""" - -from vllm import LLM, SamplingParams - -# Sample prompts. -prompts = [ - "What is annapurna labs?", -] - - -def main(): - # Create a sampling params object. - sampling_params = SamplingParams(top_k=1, max_tokens=500, ignore_eos=True) - - # Create an LLM. - llm = LLM( - model="/home/ubuntu/model_hf/Meta-Llama-3.1-70B-Instruct", - speculative_config={ - "model": "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft", - "num_speculative_tokens": 5, - "max_model_len": 2048, - }, - max_num_seqs=4, - # The max_model_len and block_size arguments are required to be same as - # max sequence length when targeting neuron device. 
- # Currently, this is a known limitation in continuous batching support - # in neuronx-distributed-inference. - max_model_len=2048, - block_size=2048, - # The device can be automatically detected when AWS Neuron SDK is installed. - # The device argument can be either unspecified for automated detection, - # or explicitly assigned. - device="neuron", - tensor_parallel_size=32, - override_neuron_config={ - "enable_eagle_speculation": True, - "enable_fused_speculation": True, - }, - ) - - # Generate texts from the prompts. The output is a list of RequestOutput objects - # that contain the prompt, generated text, and other information. - outputs = llm.generate(prompts, sampling_params) - # Print the outputs. - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, \n\n\n Generated text: {generated_text!r}") - - -if __name__ == "__main__": - main() diff --git a/examples/offline_inference/neuron_int8_quantization.py b/examples/offline_inference/neuron_int8_quantization.py deleted file mode 100644 index c0ecfac508996..0000000000000 --- a/examples/offline_inference/neuron_int8_quantization.py +++ /dev/null @@ -1,63 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os - -from vllm import LLM, SamplingParams - -# creates XLA hlo graphs for all the context length buckets. -os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128,512,1024,2048" -# creates XLA hlo graphs for all the token gen buckets. -os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048" -# Quantizes neuron model weight to int8 , -# The default config for quantization is int8 dtype. -os.environ["NEURON_QUANT_DTYPE"] = "s8" - -# Sample prompts. -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] -# Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) - - -def main(): - # Create an LLM. - llm = LLM( - model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", - max_num_seqs=8, - # The max_model_len and block_size arguments are required to be same as - # max sequence length when targeting neuron device. - # Currently, this is a known limitation in continuous batching support - # in transformers-neuronx. - # TODO(liangfu): Support paged-attention in transformers-neuronx. - max_model_len=2048, - block_size=2048, - # ruff: noqa: E501 - # The device can be automatically detected when AWS Neuron SDK is installed. - # The device argument can be either unspecified for automated detection, - # or explicitly assigned. - device="neuron", - quantization="neuron_quant", - override_neuron_config={ - "cast_logits_dtype": "bfloat16", - }, - tensor_parallel_size=2, - ) - # Generate texts from the prompts. The output is a list of RequestOutput objects - # that contain the prompt, generated text, and other information. - outputs = llm.generate(prompts, sampling_params) - # Print the outputs. 
- print("-" * 50) - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") - print("-" * 50) - - -if __name__ == "__main__": - main() diff --git a/examples/offline_inference/neuron_multimodal.py b/examples/offline_inference/neuron_multimodal.py deleted file mode 100644 index 26f7505f2fa53..0000000000000 --- a/examples/offline_inference/neuron_multimodal.py +++ /dev/null @@ -1,110 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import requests -import torch -from neuronx_distributed_inference.models.mllama.utils import add_instruct -from PIL import Image - -from vllm import LLM, SamplingParams, TextPrompt - - -def get_image(image_url): - image = Image.open(requests.get(image_url, stream=True).raw) - return image - - -# Model Inputs -PROMPTS = [ - "What is in this image? Tell me a story", - "What is the recipe of mayonnaise in two sentences?", - "Describe this image", - "What is the capital of Italy famous for?", -] -IMAGES = [ - get_image( - "https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500" - ), - None, - get_image( - "https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500" - ), - None, -] -SAMPLING_PARAMS = [ - dict(top_k=1, temperature=1.0, top_p=1.0, max_tokens=16) - for _ in range(len(PROMPTS)) -] - - -def get_VLLM_mllama_model_inputs(prompt, single_image, sampling_params): - # Prepare all inputs for mllama generation, including: - # 1. put text prompt into instruct chat template - # 2. compose single text and single image prompt into Vllm's prompt class - # 3. prepare sampling parameters - input_image = single_image - has_image = torch.tensor([1]) - if isinstance(single_image, torch.Tensor) and single_image.numel() == 0: - has_image = torch.tensor([0]) - - instruct_prompt = add_instruct(prompt, has_image) - inputs = TextPrompt(prompt=instruct_prompt) - - if input_image is not None: - inputs["multi_modal_data"] = {"image": input_image} - - sampling_params = SamplingParams(**sampling_params) - return inputs, sampling_params - - -def print_outputs(outputs): - # Print the outputs. - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - - -def main(): - assert ( - len(PROMPTS) == len(IMAGES) == len(SAMPLING_PARAMS) - ), f"""Text, image prompts and sampling parameters should have the - same batch size; but got {len(PROMPTS)}, {len(IMAGES)}, - and {len(SAMPLING_PARAMS)}""" - - # Create an LLM. 
- llm = LLM( - model="meta-llama/Llama-3.2-11B-Vision-Instruct", - max_num_seqs=1, - max_model_len=4096, - block_size=4096, - device="neuron", - tensor_parallel_size=32, - override_neuron_config={ - "sequence_parallel_enabled": False, - "skip_warmup": True, - "save_sharded_checkpoint": True, - "on_device_sampling_config": { - "global_topk": 1, - "dynamic": False, - "deterministic": False, - }, - }, - ) - - batched_inputs = [] - batched_sample_params = [] - for pmpt, img, params in zip(PROMPTS, IMAGES, SAMPLING_PARAMS): - inputs, sampling_params = get_VLLM_mllama_model_inputs(pmpt, img, params) - # test batch-size = 1 - outputs = llm.generate(inputs, sampling_params) - print_outputs(outputs) - batched_inputs.append(inputs) - batched_sample_params.append(sampling_params) - - # test batch-size = 4 - outputs = llm.generate(batched_inputs, batched_sample_params) - print_outputs(outputs) - - -if __name__ == "__main__": - main() diff --git a/examples/offline_inference/neuron_speculation.py b/examples/offline_inference/neuron_speculation.py deleted file mode 100644 index 7fc22caee742d..0000000000000 --- a/examples/offline_inference/neuron_speculation.py +++ /dev/null @@ -1,64 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -This example shows how to run offline inference with a speculative -decoding model on neuron. -""" - -import os - -from vllm import LLM, SamplingParams - -# Sample prompts. -prompts = [ - "Hello, I am a language model and I can help", - "The president of the United States is", - "The capital of France is", -] - - -def config_buckets(): - """Configure context length and token gen buckets.""" - # creates XLA hlo graphs for all the context length buckets. - os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128,512,1024,2048" - # creates XLA hlo graphs for all the token gen buckets. - os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048" - - -def initialize_llm(): - """Create an LLM with speculative decoding.""" - return LLM( - model="openlm-research/open_llama_7b", - speculative_config={ - "model": "openlm-research/open_llama_3b", - "num_speculative_tokens": 4, - "max_model_len": 2048, - }, - max_num_seqs=4, - max_model_len=2048, - block_size=2048, - device="neuron", - tensor_parallel_size=32, - ) - - -def process_requests(llm: LLM, sampling_params: SamplingParams): - """Generate texts from prompts and print them.""" - outputs = llm.generate(prompts, sampling_params) - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - - -def main(): - """Main function that sets up the llm and processes prompts.""" - config_buckets() - llm = initialize_llm() - # Create a sampling params object. 
- sampling_params = SamplingParams(max_tokens=100, top_k=1) - process_requests(llm, sampling_params) - - -if __name__ == "__main__": - main() diff --git a/examples/offline_inference/pooling/README.md b/examples/offline_inference/pooling/README.md new file mode 100644 index 0000000000000..79afbd9cfac47 --- /dev/null +++ b/examples/offline_inference/pooling/README.md @@ -0,0 +1,39 @@ +# Pooling models + +## Convert llm model to seq cls + +```bash +# for BAAI/bge-reranker-v2-gemma +# Caution: "Yes" and "yes" are two different tokens +python examples/offline_inference/pooling/convert_model_to_seq_cls.py --model_name BAAI/bge-reranker-v2-gemma --classifier_from_tokens '["Yes"]' --method no_post_processing --path ./bge-reranker-v2-gemma-seq-cls +# for mxbai-rerank-v2 +python examples/offline_inference/pooling/convert_model_to_seq_cls.py --model_name mixedbread-ai/mxbai-rerank-base-v2 --classifier_from_tokens '["0", "1"]' --method from_2_way_softmax --path ./mxbai-rerank-base-v2-seq-cls +# for Qwen3-Reranker +python examples/offline_inference/pooling/convert_model_to_seq_cls.py --model_name Qwen/Qwen3-Reranker-0.6B --classifier_from_tokens '["no", "yes"]' --method from_2_way_softmax --path ./Qwen3-Reranker-0.6B-seq-cls +``` + +## Embed jina_embeddings_v3 usage + +Only text matching task is supported for now. See <gh-pr:16120> + +```bash +python examples/offline_inference/pooling/embed_jina_embeddings_v3.py +``` + +## Embed matryoshka dimensions usage + +```bash +python examples/offline_inference/pooling/embed_matryoshka_fy.py +``` + +## Named Entity Recognition (NER) usage + +```bash +python examples/offline_inference/pooling/ner.py +``` + +## Qwen3 reranker usage + +```bash +python examples/offline_inference/pooling/qwen3_reranker.py +``` diff --git a/examples/offline_inference/convert_model_to_seq_cls.py b/examples/offline_inference/pooling/convert_model_to_seq_cls.py similarity index 100% rename from examples/offline_inference/convert_model_to_seq_cls.py rename to examples/offline_inference/pooling/convert_model_to_seq_cls.py diff --git a/examples/offline_inference/embed_jina_embeddings_v3.py b/examples/offline_inference/pooling/embed_jina_embeddings_v3.py similarity index 100% rename from examples/offline_inference/embed_jina_embeddings_v3.py rename to examples/offline_inference/pooling/embed_jina_embeddings_v3.py diff --git a/examples/offline_inference/embed_matryoshka_fy.py b/examples/offline_inference/pooling/embed_matryoshka_fy.py similarity index 100% rename from examples/offline_inference/embed_matryoshka_fy.py rename to examples/offline_inference/pooling/embed_matryoshka_fy.py diff --git a/examples/offline_inference/pooling/ner.py b/examples/offline_inference/pooling/ner.py new file mode 100644 index 0000000000000..f18742fac0d54 --- /dev/null +++ b/examples/offline_inference/pooling/ner.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from https://huggingface.co/boltuix/NeuroBERT-NER + +from argparse import Namespace + +from vllm import LLM, EngineArgs +from vllm.utils import FlexibleArgumentParser + + +def parse_args(): + parser = FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + # Set example specific arguments + parser.set_defaults( + model="boltuix/NeuroBERT-NER", + runner="pooling", + enforce_eager=True, + trust_remote_code=True, + ) + return parser.parse_args() + + +def main(args: Namespace): + # Sample prompts. 
+ prompts = [ + "Barack Obama visited Microsoft headquarters in Seattle on January 2025." + ] + + # Create an LLM. + llm = LLM(**vars(args)) + tokenizer = llm.get_tokenizer() + label_map = llm.llm_engine.vllm_config.model_config.hf_config.id2label + + # Run inference + outputs = llm.encode(prompts) + + for prompt, output in zip(prompts, outputs): + logits = output.outputs.data + predictions = logits.argmax(dim=-1) + + # Map predictions to labels + tokens = tokenizer.convert_ids_to_tokens(output.prompt_token_ids) + labels = [label_map[p.item()] for p in predictions] + + # Print results + for token, label in zip(tokens, labels): + if token not in tokenizer.all_special_tokens: + print(f"{token:15} → {label}") + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/offline_inference/qwen3_reranker.py b/examples/offline_inference/pooling/qwen3_reranker.py similarity index 100% rename from examples/offline_inference/qwen3_reranker.py rename to examples/offline_inference/pooling/qwen3_reranker.py diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py index b6007b9f46301..1a5879a6d35f5 100644 --- a/examples/offline_inference/prithvi_geospatial_mae.py +++ b/examples/offline_inference/prithvi_geospatial_mae.py @@ -45,7 +45,11 @@ datamodule_config = { class PrithviMAE: def __init__(self, model): self.model = LLM( - model=model, skip_tokenizer_init=True, dtype="float16", enforce_eager=True + model=model, + skip_tokenizer_init=True, + dtype="float16", + enforce_eager=True, + model_impl="terratorch", ) def run(self, input_data, location_coords): diff --git a/examples/offline_inference/prithvi_geospatial_mae_io_processor.py b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py index 8023cd6677762..418c40645f9f2 100644 --- a/examples/offline_inference/prithvi_geospatial_mae_io_processor.py +++ b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py @@ -12,13 +12,13 @@ from vllm.pooling_params import PoolingParams # multimodal data. In this specific case this example will take a geotiff # image as input, process it using the multimodal data processor, and # perform inference. -# Reuirement - install plugin at: +# Requirement - install plugin at: # https://github.com/christian-pinto/prithvi_io_processor_plugin def main(): torch.set_default_dtype(torch.float16) - image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/India_900498_S2Hand.tif" # noqa: E501 + image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501 img_prompt = dict( data=image_url, @@ -36,7 +36,8 @@ def main(): # to avoid the model going OOM. 
# The maximum number depends on the available GPU memory max_num_seqs=32, - io_processor_plugin="prithvi_to_tiff_india", + io_processor_plugin="prithvi_to_tiff", + model_impl="terratorch", ) pooling_params = PoolingParams(task="encode", softmax=False) diff --git a/examples/offline_inference/rlhf_colocate.py b/examples/offline_inference/rlhf_colocate.py index 65621023ab6ce..360fd79b55aad 100644 --- a/examples/offline_inference/rlhf_colocate.py +++ b/examples/offline_inference/rlhf_colocate.py @@ -28,12 +28,15 @@ Learn more about Ray placement groups: https://docs.ray.io/en/latest/placement-groups.html """ +import gc import os import ray import torch +import zmq from ray.util.placement_group import placement_group from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from torch.multiprocessing.reductions import reduce_tensor from vllm import LLM @@ -86,20 +89,72 @@ class RayTrainingActor: from vllm.platforms import current_platform self.device_uuid = current_platform.get_device_uuid(0) + self.zmq_context = zmq.Context() + self.zmq_address_counter = 0 + self.zmq_handle = None def report_device_id(self) -> str: return self.device_uuid - def get_weight_ipc_handles(self): - from torch.multiprocessing.reductions import reduce_tensor + def get_zmq_handles(self) -> dict[str, str]: + suffix = f"{self.device_uuid}-{self.zmq_address_counter}" + self.zmq_handle = f"ipc:///tmp/rl-colocate-zmq-{suffix}.sock" + self.zmq_address_counter += 1 + return {self.device_uuid: self.zmq_handle} - data = {} - for name, p in self.model.named_parameters(): - # A training actor might hold only a subset of the weights and may - # need to gather weights from other actors. For demonstration - # purposes, each training actor owns the full weight set. - data[name] = reduce_tensor(p.detach()) - return {self.device_uuid: data} + def update_weights(self): + # align size to avoid misaligned address + align_size = 256 + + def get_size(p: torch.Tensor) -> int: + return (p.nbytes + align_size - 1) // align_size * align_size + + named_parameters: dict[str, torch.nn.Parameter] = dict( + self.model.named_parameters() + ) + max_tensor_size = max(get_size(p) for p in named_parameters.values()) + # use max_tensor_size * 2 as buffer size + buffer = torch.empty(max_tensor_size * 2, dtype=torch.uint8, device="cuda:0") + s = self.zmq_context.socket(zmq.REQ) + s.bind(self.zmq_handle) + handle = reduce_tensor(buffer) + + offset = 0 + buckets: list[tuple[list[dict], list[torch.Tensor]]] = [] + named_tensors: list[dict] = [] + real_tensors: list[torch.Tensor] = [] + for name, p in named_parameters.items(): + size = get_size(p) + if offset + size > buffer.numel(): + buckets.append((named_tensors, real_tensors)) + named_tensors, real_tensors = [], [] + offset = 0 + # assume tensors are contiguous + named_tensors.append( + {"name": name, "dtype": p.dtype, "shape": p.shape, "offset": offset} + ) + real_tensors.append(p) + offset += size + if named_tensors: + buckets.append((named_tensors, real_tensors)) + s.send_pyobj(handle) + s.recv() + for named_tensors, real_tensors in buckets: + offset = 0 + for p in real_tensors: + buffer[offset : offset + p.nbytes].data.copy_( + p.data.view(-1).view(dtype=torch.uint8), non_blocking=True + ) + offset += get_size(p) + torch.cuda.synchronize() + s.send_pyobj(named_tensors) + s.recv() + s.send_pyobj(None) + s.recv() + s.close() + del buffer + gc.collect() + torch.cuda.empty_cache() # Ray manages four GPUs. 
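Taken together, the new `update_weights()` above and the worker-side `update_weights_from_ipc()` further down reduce to a three-step ZMQ REQ/REP exchange: the training actor shares a CUDA-IPC handle to a reusable staging buffer, streams one metadata bucket per message while copying the flattened weights into that buffer, and finally sends `None`; the worker answers every message with an empty frame. A condensed sketch of the sending side, using a placeholder socket address and a toy parameter dict rather than the actual actor code, could look like this:

```python
# Condensed sketch of the trainer-side handshake (not part of the diff).
# The address and the params dict are placeholders; the real code packs
# several tensors per bucket at 256-byte-aligned offsets.
import torch
import zmq
from torch.multiprocessing.reductions import reduce_tensor


def send_weights(address: str, params: dict[str, torch.Tensor]) -> None:
    sock = zmq.Context().socket(zmq.REQ)
    sock.bind(address)  # the trainer binds; the vLLM worker connects

    # One reusable staging buffer on the GPU, shared via a CUDA IPC handle.
    buffer = torch.empty(
        max(p.nbytes for p in params.values()), dtype=torch.uint8, device="cuda:0"
    )
    sock.send_pyobj(reduce_tensor(buffer))  # step 1: ship the buffer handle
    sock.recv()

    for name, p in params.items():  # step 2: one metadata bucket per message
        flat = p.detach().contiguous().view(-1).view(torch.uint8)
        buffer[: p.nbytes].copy_(flat)
        torch.cuda.synchronize()
        sock.send_pyobj(
            [{"name": name, "dtype": p.dtype, "shape": p.shape, "offset": 0}]
        )
        sock.recv()

    sock.send_pyobj(None)  # step 3: tell the worker the update is finished
    sock.recv()
    sock.close()
```

The REQ/REP pairing doubles as flow control: the trainer cannot overwrite the staging buffer until the worker has acknowledged loading the previous bucket.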
@@ -175,18 +230,22 @@ assert training_actor_device_ids[:2] == inference_engine_device_ids[0] # the second inference engine. assert training_actor_device_ids[2:] == inference_engine_device_ids[1] -print("Gather all the IPC handles from the training actors.") -ipc_handles = {} +print("Gather all the ZMQ handles from the training actors.") +zmq_handles = {} for actor in training_actors: - ipc_handles.update(ray.get(actor.get_weight_ipc_handles.remote())) + zmq_handles.update(ray.get(actor.get_zmq_handles.remote())) + +print(f"ZMQ handles: {zmq_handles}") print("Update the weights of the inference engines.") -for llm in inference_engines: - ray.get( - llm.collective_rpc.remote( - "update_weights_from_ipc_handles", args=(ipc_handles,) - ) - ) +ray.get( + [actor.update_weights.remote() for actor in training_actors] + + [ + llm.collective_rpc.remote("update_weights_from_ipc", args=(zmq_handles,)) + for llm in inference_engines + ] +) + print("Check if the weights are updated.") for llm in inference_engines: assert ray.get(llm.collective_rpc.remote("check_weights_changed", args=tuple())) diff --git a/examples/offline_inference/rlhf_utils.py b/examples/offline_inference/rlhf_utils.py index d2a8419ffabcd..c0e60b9793407 100644 --- a/examples/offline_inference/rlhf_utils.py +++ b/examples/offline_inference/rlhf_utils.py @@ -1,6 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import gc +from typing import Callable, Optional, TypedDict + import torch +import zmq def stateless_init_process_group(master_address, master_port, rank, world_size, device): @@ -66,6 +70,27 @@ class WorkerExtension: return weights_updated +def rebuild_ipc( + handle: tuple[Callable, tuple], device_id: Optional[int] = None +) -> torch.Tensor: + func, args = handle + list_args = list(args) + if device_id is not None: + # the key is to change device id to the current device id + # in case two processes have different CUDA_VISIBLE_DEVICES + list_args[6] = device_id + buffer = func(*list_args) + return buffer + + +class FlattenedTensorMetadata(TypedDict): + name: str + shape: torch.Size + dtype: torch.dtype + # specify the start offset of this tensor in shared ipc_buffer tensor + offset: int + + class ColocateWorkerExtension: """ The class for vLLM's worker to inherit from, in the colocate setting. @@ -76,27 +101,62 @@ class ColocateWorkerExtension: should pass the full qualified name as `worker_extension_cls` argument. """ + def update_weights_from_ipc(self, zmq_handles: dict[str, str]): + from vllm.model_executor.model_loader.utils import process_weights_after_loading + + assert self.device is not None + if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None: + self._zmq_ctx = zmq.Context() + socket = self._zmq_ctx.socket(zmq.REP) + socket.connect(zmq_handles[self.report_device_id()]) + buffer: Optional[torch.Tensor] = None + while True: + payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = ( + socket.recv_pyobj() + ) + if payload is None: + # means the update is done + process_weights_after_loading( + self.model_runner.model, self.model_config, self.device + ) + torch.cuda.synchronize() + socket.send(b"") + break + if isinstance(payload, tuple): + # an ipc handle that vLLM can use `func, args = handle` + # and `func(*args)` to rebuild GPU tensor. 
+ buffer = rebuild_ipc(payload, self.device.index) + assert buffer.dtype == torch.uint8 + socket.send(b"") + continue + assert isinstance(payload, list) + assert buffer is not None + weights = [] + for item in payload: + shape = item["shape"] + if isinstance(shape, (list, tuple)): + shape = torch.Size(shape) + assert isinstance(shape, torch.Size) + dtype, offset = item["dtype"], item["offset"] + size = dtype.itemsize * shape.numel() + tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape) + weights.append((item["name"], tensor)) + self.model_runner.model.load_weights(weights=weights) + del weights + torch.cuda.synchronize() + socket.send(b"") + + socket.close() + del buffer + gc.collect() + torch.cuda.empty_cache() + def report_device_id(self) -> str: from vllm.platforms import current_platform self.device_uuid = current_platform.get_device_uuid(self.device.index) return self.device_uuid - def update_weights_from_ipc_handles(self, ipc_handles): - handles = ipc_handles[self.device_uuid] - device_id = self.device.index - weights = [] - for name, handle in handles.items(): - func, args = handle - list_args = list(args) - # the key is to change device id to the current device id - # in case two processes have different CUDA_VISIBLE_DEVICES - list_args[6] = device_id - tensor = func(*list_args) - weights.append((name, tensor)) - self.model_runner.model.load_weights(weights=weights) - torch.cuda.synchronize() - def check_weights_changed(self): """ Check if the weights are updated to 0. diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 5af232cb6af6a..004e75b204642 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -53,7 +53,6 @@ def parse_args(): "--method", type=str, default="eagle", - choices=["ngram", "eagle", "eagle3", "mtp"], ) parser.add_argument("--num-spec-tokens", type=int, default=2) parser.add_argument("--prompt-lookup-max", type=int, default=5) @@ -118,6 +117,11 @@ def main(): "prompt_lookup_max": args.prompt_lookup_max, "prompt_lookup_min": args.prompt_lookup_min, } + elif args.method.endswith("mtp"): + speculative_config = { + "method": args.method, + "num_speculative_tokens": args.num_spec_tokens, + } else: raise ValueError(f"unknown method: {args.method}") diff --git a/examples/offline_inference/structured_outputs.py b/examples/offline_inference/structured_outputs.py index 88d87beb4874d..6b6099f71b120 100644 --- a/examples/offline_inference/structured_outputs.py +++ b/examples/offline_inference/structured_outputs.py @@ -1,11 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ -This file demonstrates the example usage of guided decoding -to generate structured outputs using vLLM. It shows how to apply -different guided decoding techniques such as Choice, Regex, JSON schema, -and Grammar to produce structured and formatted results -based on specific prompts. +This file demonstrates the example usage of structured outputs +in vLLM. It shows how to apply different constraints such as choice, +regex, json schema, and grammar to produce structured and formatted +results based on specific prompts. 
""" from enum import Enum @@ -13,19 +12,23 @@ from enum import Enum from pydantic import BaseModel from vllm import LLM, SamplingParams -from vllm.sampling_params import GuidedDecodingParams +from vllm.sampling_params import StructuredOutputsParams MAX_TOKENS = 50 -# Guided decoding by Choice (list of possible options) -guided_decoding_params_choice = GuidedDecodingParams(choice=["Positive", "Negative"]) -sampling_params_choice = SamplingParams(guided_decoding=guided_decoding_params_choice) +# Structured outputs by Choice (list of possible options) +structured_outputs_params_choice = StructuredOutputsParams( + choice=["Positive", "Negative"] +) +sampling_params_choice = SamplingParams( + structured_outputs=structured_outputs_params_choice +) prompt_choice = "Classify this sentiment: vLLM is wonderful!" -# Guided decoding by Regex -guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n") +# Structured outputs by Regex +structured_outputs_params_regex = StructuredOutputsParams(regex=r"\w+@\w+\.com\n") sampling_params_regex = SamplingParams( - guided_decoding=guided_decoding_params_regex, + structured_outputs=structured_outputs_params_regex, stop=["\n"], max_tokens=MAX_TOKENS, ) @@ -36,7 +39,7 @@ prompt_regex = ( ) -# Guided decoding by JSON using Pydantic schema +# Structured outputs by JSON using Pydantic schema class CarType(str, Enum): sedan = "sedan" suv = "SUV" @@ -51,17 +54,16 @@ class CarDescription(BaseModel): json_schema = CarDescription.model_json_schema() -guided_decoding_params_json = GuidedDecodingParams(json=json_schema) +structured_outputs_params_json = StructuredOutputsParams(json=json_schema) sampling_params_json = SamplingParams( - guided_decoding=guided_decoding_params_json, - max_tokens=MAX_TOKENS, + structured_outputs=structured_outputs_params_json, max_tokens=MAX_TOKENS ) prompt_json = ( - "Generate a JSON with the brand, model and car_type of" + "Generate a JSON with the brand, model and car_type of " "the most iconic car from the 90's" ) -# Guided decoding by Grammar +# Structured outputs by Grammar simplified_sql_grammar = """ root ::= select_statement select_statement ::= "SELECT " column " from " table " where " condition @@ -70,13 +72,15 @@ table ::= "table_1 " | "table_2 " condition ::= column "= " number number ::= "1 " | "2 " """ -guided_decoding_params_grammar = GuidedDecodingParams(grammar=simplified_sql_grammar) +structured_outputs_params_grammar = StructuredOutputsParams( + grammar=simplified_sql_grammar +) sampling_params_grammar = SamplingParams( - guided_decoding=guided_decoding_params_grammar, + structured_outputs=structured_outputs_params_grammar, max_tokens=MAX_TOKENS, ) prompt_grammar = ( - "Generate an SQL query to show the 'username' and 'email'from the 'users' table." + "Generate an SQL query to show the 'username' and 'email' from the 'users' table." 
) @@ -93,16 +97,16 @@ def main(): llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=100) choice_output = generate_output(prompt_choice, sampling_params_choice, llm) - format_output("Guided decoding by Choice", choice_output) + format_output("Structured outputs by Choice", choice_output) regex_output = generate_output(prompt_regex, sampling_params_regex, llm) - format_output("Guided decoding by Regex", regex_output) + format_output("Structured outputs by Regex", regex_output) json_output = generate_output(prompt_json, sampling_params_json, llm) - format_output("Guided decoding by JSON", json_output) + format_output("Structured outputs by JSON", json_output) grammar_output = generate_output(prompt_grammar, sampling_params_grammar, llm) - format_output("Guided decoding by Grammar", grammar_output) + format_output("Structured outputs by Grammar", grammar_output) if __name__ == "__main__": diff --git a/examples/offline_inference/tpu.py b/examples/offline_inference/tpu.py index 9776f4fe322b9..0093b63b0b1f3 100644 --- a/examples/offline_inference/tpu.py +++ b/examples/offline_inference/tpu.py @@ -42,7 +42,7 @@ def main(): llm_args["model"] = "meta-llama/Llama-3.1-8B-Instruct" # Set `enforce_eager=True` to avoid ahead-of-time compilation. - # In real workloads, `enforace_eager` should be `False`. + # In real workloads, `enforce_eager` should be `False`. llm = LLM(**llm_args) outputs = llm.generate(prompts, sampling_params) print("-" * 50) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index b104113b88213..de3f3afc17948 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -204,28 +204,6 @@ def run_ernie45_vl(questions: list[str], modality: str) -> ModelRequestData: ) -# Florence2 -def run_florence2(questions: list[str], modality: str) -> ModelRequestData: - assert modality == "image" - - engine_args = EngineArgs( - model="microsoft/Florence-2-large", - tokenizer="Isotr0py/Florence-2-tokenizer", - max_model_len=4096, - max_num_seqs=2, - trust_remote_code=True, - dtype="bfloat16", - limit_mm_per_prompt={modality: 1}, - ) - - prompts = ["<MORE_DETAILED_CAPTION>" for _ in questions] - - return ModelRequestData( - engine_args=engine_args, - prompts=prompts, - ) - - # Fuyu def run_fuyu(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1008,44 +986,6 @@ def run_mistral3(questions: list[str], modality: str) -> ModelRequestData: ) -# LLama 3.2 -def run_mllama(questions: list[str], modality: str) -> ModelRequestData: - assert modality == "image" - - model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct" - - # Note: The default setting of max_num_seqs (256) and - # max_model_len (131072) for this model may cause OOM. - # You may lower either to run this example on lower-end GPUs. - - # The configuration below has been confirmed to launch on a single L40 GPU. 
- engine_args = EngineArgs( - model=model_name, - max_model_len=8192, - max_num_seqs=2, - limit_mm_per_prompt={modality: 1}, - ) - - tokenizer = AutoTokenizer.from_pretrained(model_name) - messages = [ - [ - { - "role": "user", - "content": [{"type": "image"}, {"type": "text", "text": question}], - } - ] - for question in questions - ] - prompts = tokenizer.apply_chat_template( - messages, add_generation_prompt=True, tokenize=False - ) - - return ModelRequestData( - engine_args=engine_args, - prompts=prompts, - ) - - # Molmo def run_molmo(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1497,6 +1437,80 @@ def run_qwen2_5_omni(questions: list[str], modality: str): ) +# Qwen3-VL-Dense +def run_qwen3_vl(questions: list[str], modality: str) -> ModelRequestData: + model_name = "Qwen/Qwen3-VL-4B-Instruct" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=5, + mm_processor_kwargs={ + "min_pixels": 28 * 28, + "max_pixels": 1280 * 28 * 28, + "fps": 1, + }, + limit_mm_per_prompt={modality: 1}, + ) + + if modality == "image": + placeholder = "<|image_pad|>" + elif modality == "video": + placeholder = "<|video_pad|>" + + prompts = [ + ( + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>" + f"{question}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + +# Qwen3-VL-MOE +def run_qwen3_vl_moe(questions: list[str], modality: str) -> ModelRequestData: + model_name = "Qwen/Qwen3-VL-30B-A3B-Instruct" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=5, + mm_processor_kwargs={ + "min_pixels": 28 * 28, + "max_pixels": 1280 * 28 * 28, + "fps": 1, + }, + limit_mm_per_prompt={modality: 1}, + ) + + if modality == "image": + placeholder = "<|image_pad|>" + elif modality == "video": + placeholder = "<|video_pad|>" + + prompts = [ + ( + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>" + f"{question}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # R-4B def run_r_vl(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1665,7 +1679,6 @@ model_example_map = { "command_a_vision": run_command_a_vision, "deepseek_vl_v2": run_deepseek_vl2, "ernie45_vl": run_ernie45_vl, - "florence2": run_florence2, "fuyu": run_fuyu, "gemma3": run_gemma3, "gemma3n": run_gemma3n, @@ -1691,7 +1704,6 @@ model_example_map = { "minicpmv": run_minicpmv, "minimax_vl_01": run_minimax_vl_01, "mistral3": run_mistral3, - "mllama": run_mllama, "molmo": run_molmo, "nemotron_vl": run_nemotron_vl, "NVLM_D": run_nvlm_d, @@ -1707,6 +1719,8 @@ model_example_map = { "qwen2_vl": run_qwen2_vl, "qwen2_5_vl": run_qwen2_5_vl, "qwen2_5_omni": run_qwen2_5_omni, + "qwen3_vl": run_qwen3_vl, + "qwen3_vl_moe": run_qwen3_vl_moe, "rvl": run_r_vl, "skywork_chat": run_skyworkr1v, "smolvlm": run_smolvlm, @@ -1716,6 +1730,15 @@ model_example_map = { } +MODELS_NEED_VIDEO_METADATA = [ + "glm4_1v", + "glm4_5v", + "glm4_5v_fp8", + "qwen3_vl", + "qwen3_vl_moe", +] + + def get_multi_modal_input(args): """ return { @@ -1740,12 +1763,13 @@ def get_multi_modal_input(args): if args.modality == "video": # Input video and question + needs_metadata = args.model_type 
in MODELS_NEED_VIDEO_METADATA video = VideoAsset(name="baby_reading", num_frames=args.num_frames).np_ndarrays metadata = VideoAsset(name="baby_reading", num_frames=args.num_frames).metadata vid_questions = ["Why is this video funny?"] return { - "data": [(video, metadata)] if args.model_type == "glm4_1v" else video, + "data": ([(video, metadata)] if needs_metadata else video), "questions": vid_questions, } @@ -1764,6 +1788,7 @@ def apply_image_repeat( probs = [1.0 - image_repeat_prob, image_repeat_prob] inputs = [] + inputs_with_empty_media = [] cur_image = data for i in range(num_prompts): if image_repeat_prob is not None: @@ -1774,14 +1799,25 @@ def apply_image_repeat( new_val = (i // 256 // 256, i // 256, i % 256) cur_image.putpixel((0, 0), new_val) + uuid = "uuid_{}".format(i) + inputs.append( { "prompt": prompts[i % len(prompts)], "multi_modal_data": {modality: cur_image}, + "multi_modal_uuids": {modality: uuid}, } ) - return inputs + inputs_with_empty_media.append( + { + "prompt": prompts[i % len(prompts)], + "multi_modal_data": {modality: None}, + "multi_modal_uuids": {modality: uuid}, + } + ) + + return inputs, inputs_with_empty_media @contextmanager @@ -1860,6 +1896,13 @@ def parse_args(): help="If True, then use different prompt (with the same multi-modal " "data) for each request.", ) + + parser.add_argument( + "--verify-mm-cache-hit-with-uuids", + action="store_true", + help="If True, will send all requests in a second batch with empty mm " + "data to verify cache hits with UUIDs.", + ) return parser.parse_args() @@ -1903,26 +1946,48 @@ def main(args): assert args.num_prompts > 0 if args.num_prompts == 1: # Single inference + uuid = "uuid_0" inputs = { "prompt": prompts[0], "multi_modal_data": {modality: data}, + "multi_modal_uuids": {modality: uuid}, + } + inputs_with_empty_media = { + "prompt": prompts[0], + "multi_modal_data": {modality: None}, + "multi_modal_uuids": {modality: uuid}, } else: # Batch inference if args.image_repeat_prob is not None: # Repeat images with specified probability of "image_repeat_prob" - inputs = apply_image_repeat( - args.image_repeat_prob, args.num_prompts, data, prompts, modality + inputs, inputs_with_empty_media = apply_image_repeat( + args.image_repeat_prob, + args.num_prompts, + data, + prompts, + modality, ) else: # Use the same image for all prompts - inputs = [ - { - "prompt": prompts[i % len(prompts)], - "multi_modal_data": {modality: data}, - } - for i in range(args.num_prompts) - ] + inputs = [] + inputs_with_empty_media = [] + for i in range(args.num_prompts): + uuid = "uuid_{}".format(i) + inputs.append( + { + "prompt": prompts[i % len(prompts)], + "multi_modal_data": {modality: data}, + "multi_modal_uuids": {modality: uuid}, + } + ) + inputs_with_empty_media.append( + { + "prompt": prompts[i % len(prompts)], + "multi_modal_data": {modality: None}, + "multi_modal_uuids": {modality: uuid}, + } + ) # Add LoRA request if applicable lora_request = ( @@ -1942,6 +2007,26 @@ def main(args): print(generated_text) print("-" * 50) + if args.verify_mm_cache_hit_with_uuids: + try: + # Verify cache hits with UUIDs + print( + "Sending a second batch of requests with empty media" + " and matching UUIDs." + ) + outputs = llm.generate( + inputs_with_empty_media, + sampling_params=sampling_params, + lora_request=lora_request, + ) + print("-" * 50) + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + print("-" * 50) + except Exception as e: + print(f"Failed to verify cache hits with UUIDs. 
Error: {e}") + if __name__ == "__main__": args = parse_args() diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 01c2905cf26d8..51b41f34b2ff6 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -637,26 +637,6 @@ def load_mistral3(question: str, image_urls: list[str]) -> ModelRequestData: ) -def load_mllama(question: str, image_urls: list[str]) -> ModelRequestData: - model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct" - - # The configuration below has been confirmed to launch on a single L40 GPU. - engine_args = EngineArgs( - model=model_name, - max_model_len=8192, - max_num_seqs=2, - limit_mm_per_prompt={"image": len(image_urls)}, - ) - - img_prompt = "Given the first image <|image|> and the second image<|image|>" - prompt = f"<|begin_of_text|>{img_prompt}, {question}?" - return ModelRequestData( - engine_args=engine_args, - prompt=prompt, - image_data=[fetch_image(url) for url in image_urls], - ) - - def load_nvlm_d(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "nvidia/NVLM-D-72B" @@ -1253,7 +1233,6 @@ model_example_map = { "llava-next": load_llava_next, "llava-onevision": load_llava_onevision, "mistral3": load_mistral3, - "mllama": load_mllama, "NVLM_D": load_nvlm_d, "ovis": load_ovis, "ovis2_5": load_ovis2_5, diff --git a/examples/online_serving/dashboards/README.md b/examples/online_serving/dashboards/README.md new file mode 100644 index 0000000000000..30cea6b24d57e --- /dev/null +++ b/examples/online_serving/dashboards/README.md @@ -0,0 +1,87 @@ +# Monitoring Dashboards + +This directory contains monitoring dashboard configurations for vLLM, providing +comprehensive observability for your vLLM deployments. 
+ +## Dashboard Platforms + +We provide dashboards for two popular observability platforms: + +- **[Grafana](https://grafana.com)** +- **[Perses](https://perses.dev)** + +## Dashboard Format Approach + +All dashboards are provided in **native formats** that work across different +deployment methods: + +### Grafana (JSON) + +- ✅ Works with any Grafana instance (cloud, self-hosted, Docker) +- ✅ Direct import via Grafana UI or API +- ✅ Can be wrapped in Kubernetes operators when needed +- ✅ No vendor lock-in or deployment dependencies + +### Perses (YAML) + +- ✅ Works with standalone Perses instances +- ✅ Compatible with Perses API and CLI +- ✅ Supports Dashboard-as-Code workflows +- ✅ Can be wrapped in Kubernetes operators when needed + +## Dashboard Contents + +Both platforms provide equivalent monitoring capabilities: + +| Dashboard | Description | +|-----------|-------------| +| **Performance Statistics** | Tracks latency, throughput, and performance metrics | +| **Query Statistics** | Monitors request volume, query performance, and KPIs | + +## Quick Start + +First, navigate to this example's directory: + +```bash +cd examples/online_serving/dashboards +``` + +### Grafana + +Import the JSON directly into the Grafana UI, or use the API: + +```bash +curl -X POST http://grafana/api/dashboards/db \ + -H "Content-Type: application/json" \ + -d @grafana/performance_statistics.json +``` + +### Perses + +Import via the Perses CLI: + +```bash +percli apply -f perses/performance_statistics.yaml +``` + +## Requirements + +- **Prometheus** metrics from your vLLM deployment +- **Data source** configured in your monitoring platform +- **vLLM metrics** enabled and accessible + +## Platform-Specific Documentation + +For detailed deployment instructions and platform-specific options, see: + +- **[Grafana Documentation](./grafana)** - JSON dashboards, operator usage, manual import +- **[Perses Documentation](./perses)** - YAML specs, CLI usage, operator wrapping + +## Contributing + +When adding new dashboards, please: + +1. Provide native formats (JSON for Grafana, YAML specs for Perses) +2. Update platform-specific README files +3. Ensure dashboards work across deployment methods +4. Test with the latest platform versions diff --git a/examples/online_serving/dashboards/grafana/README.md b/examples/online_serving/dashboards/grafana/README.md new file mode 100644 index 0000000000000..e42b0f814367d --- /dev/null +++ b/examples/online_serving/dashboards/grafana/README.md @@ -0,0 +1,59 @@ +# Grafana Dashboards for vLLM Monitoring + +This directory contains Grafana dashboard configurations (as JSON) designed to monitor +vLLM performance and metrics. + +## Requirements + +- Grafana 8.0+ +- Prometheus data source configured in Grafana +- vLLM deployment with Prometheus metrics enabled + +## Dashboard Descriptions + +- **[performance_statistics.json](./performance_statistics.json)**: Tracks performance metrics including latency and + throughput for your vLLM service. +- **[query_statistics.json](./query_statistics.json)**: Tracks query performance, request volume, and key + performance indicators for your vLLM service. + +## Deployment Options + +### Manual Import (Recommended) + +The easiest way to use these dashboards is to manually import the JSON configurations +directly into your Grafana instance: + +1. Navigate to your Grafana instance +2. Click the '+' icon in the sidebar +3. Select 'Import' +4. 
Copy and paste the JSON content from the dashboard files, or upload the JSON files + directly + +### Grafana Operator + +If you're using the [Grafana Operator](https://github.com/grafana-operator/grafana-operator) +in Kubernetes, you can wrap these JSON configurations in a `GrafanaDashboard` custom +resource: + +```yaml +# Note: Adjust the instanceSelector to match your Grafana instance's labels +# You can check with: kubectl get grafana -o yaml +apiVersion: grafana.integreatly.org/v1beta1 +kind: GrafanaDashboard +metadata: + name: vllm-performance-dashboard +spec: + instanceSelector: + matchLabels: + dashboards: grafana # Adjust to match your Grafana instance labels + folder: "vLLM Monitoring" + json: | + # Replace this comment with the complete JSON content from + # performance_statistics.json - The JSON should start with { and end with } +``` + +Then apply to your cluster: + +```bash +kubectl apply -f your-dashboard.yaml -n <namespace> +``` diff --git a/examples/online_serving/dashboards/grafana/performance_statistics.json b/examples/online_serving/dashboards/grafana/performance_statistics.json new file mode 100644 index 0000000000000..390d3dd6d2594 --- /dev/null +++ b/examples/online_serving/dashboards/grafana/performance_statistics.json @@ -0,0 +1,1405 @@ +{ + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": { + "type": "grafana", + "uid": "-- Grafana --" + }, + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "type": "dashboard" + } + ] + }, + "editable": true, + "fiscalYearStartMonth": 0, + "graphTooltip": 0, + "id": 26, + "links": [], + "panels": [ + { + "collapsed": false, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 0 + }, + "id": 9, + "panels": [], + "title": "Graph: E2E latency over time ", + "type": "row" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "End-to-End latency of requests, showing average and key percentiles over time.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "Latency", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 18, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": true, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "decimals": 2, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 1 + }, + "id": 1, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "rate(vllm:e2e_request_latency_seconds_sum[$__interval]) / rate(vllm:e2e_request_latency_seconds_count[$__interval])", + "format": "table", + "legendFormat": "E2E Latency", + "range": true, + 
"refId": "A" + } + ], + "title": "E2E Latency over Time", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "99th percentile of End-to-End request latency over the selected time range.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "decimals": 2, + "displayName": "P99", + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 4, + "w": 6, + "x": 12, + "y": 1 + }, + "id": 5, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by(le) (rate(vllm:e2e_request_latency_seconds_bucket[$__range])))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "E2E Latency (P99)", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "90th percentile of End-to-End request latency over the selected time range.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "decimals": 2, + "displayName": "P90", + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 4, + "w": 6, + "x": 18, + "y": 1 + }, + "id": 4, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.90, sum by(le) (rate(vllm:e2e_request_latency_seconds_bucket[$__range])))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "E2E Latency (P90)", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "Average End-to-End request latency over the selected time range.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "decimals": 2, + "displayName": "Average", + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 4, + "w": 6, + "x": 12, + "y": 5 + }, + "id": 2, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "(sum(increase(vllm:e2e_request_latency_seconds_sum[$__range])) / 
sum(increase(vllm:e2e_request_latency_seconds_count[$__range])))", + "legendFormat": "Average E2E Latency", + "range": true, + "refId": "A" + } + ], + "title": "E2E Latency (Avg)", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "50th percentile (median) of End-to-End request latency over the selected time range.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "decimals": 2, + "displayName": "P50", + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 4, + "w": 6, + "x": 18, + "y": 5 + }, + "id": 3, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.50, sum by(le) (rate(vllm:e2e_request_latency_seconds_bucket[$__range])))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "E2E Latency (P50)", + "type": "stat" + }, + { + "collapsed": false, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 9 + }, + "id": 8, + "panels": [], + "title": "Graph: TTFT(Time To First Token) over time ", + "type": "row" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "Time to first token (TTFT) latency, showing average and key percentiles over time.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "Latency", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 18, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "decimals": 2, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 10 + }, + "id": 10, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "rate(vllm:time_to_first_token_seconds_sum[$__interval]) / rate(vllm:time_to_first_token_seconds_count[$__interval])", + "format": "table", + "legendFormat": "TTFT (Avg)", + "range": true, + "refId": "A" + } + ], + "title": "TTFT Over Time", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "99th percentile of Time To First Token latency over the selected time range.", + "fieldConfig": { + "defaults": { + "color": { + 
"mode": "thresholds" + }, + "decimals": 2, + "displayName": "P99", + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 4, + "w": 6, + "x": 12, + "y": 10 + }, + "id": 14, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by(le) (rate(vllm:time_to_first_token_seconds_bucket[$__range])))", + "legendFormat": "TTFT (p99)", + "range": true, + "refId": "A" + } + ], + "title": "TTFT (P99)", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "90th percentile of Time To First Token latency over the selected time range.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "decimals": 2, + "displayName": "P90", + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 4, + "w": 6, + "x": 18, + "y": 10 + }, + "id": 13, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.90, sum by(le) (rate(vllm:time_to_first_token_seconds_bucket[$__range])))", + "legendFormat": "TTFT (p90)", + "range": true, + "refId": "A" + } + ], + "title": "TTFT (P90)", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "Average Time To First Token latency over the selected time range.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "decimals": 2, + "displayName": "Average", + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 4, + "w": 6, + "x": 12, + "y": 14 + }, + "id": 11, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "(sum(increase(vllm:time_to_first_token_seconds_sum[$__range])) / sum(increase(vllm:time_to_first_token_seconds_count[$__range])))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "TTFT (Avg)", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "50th percentile (median) of Time To First Token latency over the 
selected time range.",
+      "fieldConfig": {
+        "defaults": {
+          "color": {
+            "mode": "thresholds"
+          },
+          "displayName": "P50",
+          "mappings": [],
+          "thresholds": {
+            "mode": "absolute",
+            "steps": [
+              {
+                "color": "green",
+                "value": null
+              },
+              {
+                "color": "red",
+                "value": 80
+              }
+            ]
+          },
+          "unit": "s"
+        },
+        "overrides": []
+      },
+      "gridPos": {
+        "h": 4,
+        "w": 6,
+        "x": 18,
+        "y": 14
+      },
+      "id": 12,
+      "options": {
+        "colorMode": "value",
+        "graphMode": "area",
+        "justifyMode": "auto",
+        "orientation": "auto",
+        "percentChangeColorMode": "standard",
+        "reduceOptions": {
+          "calcs": [
+            "lastNotNull"
+          ],
+          "fields": "",
+          "values": false
+        },
+        "showPercentChange": false,
+        "textMode": "auto",
+        "wideLayout": true
+      },
+      "pluginVersion": "11.3.0",
+      "targets": [
+        {
+          "editorMode": "code",
+          "expr": "histogram_quantile(0.50, sum by(le) (rate(vllm:time_to_first_token_seconds_bucket[$__range])))",
+          "legendFormat": "__auto",
+          "range": true,
+          "refId": "A"
+        }
+      ],
+      "title": "TTFT (P50)",
+      "type": "stat"
+    },
+    {
+      "collapsed": false,
+      "gridPos": {
+        "h": 1,
+        "w": 24,
+        "x": 0,
+        "y": 18
+      },
+      "id": 7,
+      "panels": [],
+      "title": "ITL (Iteration Latency / Time Per Output Token) over time",
+      "type": "row"
+    },
+    {
+      "datasource": {
+        "type": "prometheus",
+        "uid": "${DS_PROMETHEUS}"
+      },
+      "description": "Iteration latency, or average time taken to generate a single output token, with percentiles.",
+      "fieldConfig": {
+        "defaults": {
+          "color": {
+            "mode": "palette-classic"
+          },
+          "custom": {
+            "axisBorderShow": false,
+            "axisCenteredZero": false,
+            "axisColorMode": "text",
+            "axisLabel": "Latency",
+            "axisPlacement": "auto",
+            "barAlignment": 0,
+            "barWidthFactor": 0.6,
+            "drawStyle": "line",
+            "fillOpacity": 17,
+            "gradientMode": "none",
+            "hideFrom": {
+              "legend": false,
+              "tooltip": false,
+              "viz": false
+            },
+            "insertNulls": false,
+            "lineInterpolation": "linear",
+            "lineWidth": 1,
+            "pointSize": 5,
+            "scaleDistribution": {
+              "type": "linear"
+            },
+            "showPoints": "auto",
+            "spanNulls": false,
+            "stacking": {
+              "group": "A",
+              "mode": "none"
+            },
+            "thresholdsStyle": {
+              "mode": "off"
+            }
+          },
+          "decimals": 2,
+          "mappings": [],
+          "thresholds": {
+            "mode": "absolute",
+            "steps": [
+              {
+                "color": "green",
+                "value": null
+              },
+              {
+                "color": "red",
+                "value": 80
+              }
+            ]
+          },
+          "unit": "s"
+        },
+        "overrides": []
+      },
+      "gridPos": {
+        "h": 8,
+        "w": 12,
+        "x": 0,
+        "y": 19
+      },
+      "id": 15,
+      "options": {
+        "legend": {
+          "calcs": [],
+          "displayMode": "list",
+          "placement": "bottom",
+          "showLegend": true
+        },
+        "tooltip": {
+          "mode": "single",
+          "sort": "none"
+        }
+      },
+      "pluginVersion": "11.3.0",
+      "targets": [
+        {
+          "editorMode": "code",
+          "expr": "rate(vllm:time_per_output_token_seconds_sum[$__interval]) / rate(vllm:time_per_output_token_seconds_count[$__interval])",
+          "legendFormat": "ITL (Avg)",
+          "range": true,
+          "refId": "A"
+        },
+        {
+          "datasource": {
+            "type": "prometheus",
+            "uid": "${DS_PROMETHEUS}"
+          },
+          "editorMode": "code",
+          "expr": "histogram_quantile(0.50, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket[$__interval])))",
+          "hide": false,
+          "instant": false,
+          "legendFormat": "ITL (p50)",
+          "range": true,
+          "refId": "B"
+        },
+        {
+          "datasource": {
+            "type": "prometheus",
+            "uid": "${DS_PROMETHEUS}"
+          },
+          "editorMode": "code",
+          "expr": "histogram_quantile(0.90, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket[$__interval])))",
+          "hide": false,
+          "instant": false,
+          "legendFormat": "ITL (p90)",
+          "range": true,
+          "refId": "C"
+        },
+        {
+          "datasource": {
+            "type": "prometheus",
"uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket[$__interval])))", + "hide": false, + "instant": false, + "legendFormat": "ITL (p99)", + "range": true, + "refId": "D" + } + ], + "title": "ITL (Time Per Output Token) Over Time", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "90th percentile of Iteration Latency over the selected time range.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "decimals": 2, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 4, + "w": 6, + "x": 12, + "y": 19 + }, + "id": 18, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.90, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket[$__range])))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "ITL (P90)", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "99th percentile of Iteration Latency over the selected time range.\n\n", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "decimals": 2, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 4, + "w": 6, + "x": 18, + "y": 19 + }, + "id": 19, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket[$__range])))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "ITL (P99)", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "Average Iteration Latency (time per output token) over the selected time range.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "decimals": 2, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 4, + "w": 6, + "x": 12, + "y": 23 + }, + "id": 16, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + 
}, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "(sum(increase(vllm:time_per_output_token_seconds_sum[$__range])) / sum(increase(vllm:time_per_output_token_seconds_count[$__range])))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "ITL (Avg)", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "50th percentile (median) of Iteration Latency over the selected time range.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "decimals": 2, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 4, + "w": 6, + "x": 18, + "y": 23 + }, + "id": 17, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.50, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket[$__range])))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "ITL (P50)", + "type": "stat" + }, + { + "collapsed": false, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 27 + }, + "id": 6, + "panels": [], + "title": "TPS (Tokens Per Second)", + "type": "row" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "Rate of tokens processed per second, including prompt and generation phases.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "tokens/sec (tps)" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 28 + }, + "id": 20, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "rate(vllm:generation_tokens_total[$__interval])", + "legendFormat": "Generation TPS", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "rate(vllm:prompt_tokens_total[$__interval])", + "hide": false, + "instant": false, + "legendFormat": "Prompt TPS", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": "prometheus", + 
"uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "rate(vllm:iteration_tokens_total_count[$__interval])", + "hide": false, + "instant": false, + "legendFormat": "Overall Iteration TPS", + "range": true, + "refId": "C" + } + ], + "title": "TPS (Tokens Per Second) Over Time", + "type": "timeseries" + } + ], + "preload": false, + "schemaVersion": 40, + "tags": [], + "templating": { + "list": [ + { + "name": "DS_PROMETHEUS", + "type": "datasource", + "label": "datasource", + "query": "prometheus", + "refresh": 1, + "current": { + "text": "Prometheus", + "value": "prometheus" + } + }, + { + "current": { + "text": "avg : Average\n0.50 : P50\n0.90 : P90\n0.99 : P99\n0.999 : Max (Approx)", + "value": "avg : Average\n0.50 : P50\n0.90 : P90\n0.99 : P99\n0.999 : Max (Approx)" + }, + "label": "Aggregation", + "name": "agg_method", + "options": [ + { + "selected": true, + "text": "avg : Average\n0.50 : P50\n0.90 : P90\n0.99 : P99\n0.999 : Max (Approx)", + "value": "avg : Average\n0.50 : P50\n0.90 : P90\n0.99 : P99\n0.999 : Max (Approx)" + } + ], + "query": "avg : Average\n0.50 : P50\n0.90 : P90\n0.99 : P99\n0.999 : Max (Approx)", + "type": "custom" + }, + { + "current": { + "text": [ + "granite-33-2b-instruct" + ], + "value": [ + "granite-33-2b-instruct" + ] + }, + "definition": "label_values(vllm:generation_tokens_total,model_name)", + "includeAll": true, + "label": "Deployment_ID", + "multi": true, + "name": "Deployment_id", + "options": [], + "query": { + "qryType": 1, + "query": "label_values(vllm:generation_tokens_total,model_name)", + "refId": "PrometheusVariableQueryEditor-VariableQuery" + }, + "refresh": 1, + "regex": "", + "type": "query" + } + ] + }, + "time": { + "from": "now-12h", + "to": "now" + }, + "timezone": "browser", + "uid": "performance-statistics", + "title": "Performance Statistics", + "version": 40, + "weekStart": "" +} \ No newline at end of file diff --git a/examples/online_serving/dashboards/grafana/query_statistics.json b/examples/online_serving/dashboards/grafana/query_statistics.json new file mode 100644 index 0000000000000..880f6c5d71764 --- /dev/null +++ b/examples/online_serving/dashboards/grafana/query_statistics.json @@ -0,0 +1,760 @@ +{ + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": { + "type": "grafana", + "uid": "-- Grafana --" + }, + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "type": "dashboard" + } + ] + }, + "description": "High-level overview of VLLM model deployment behavior and key performance indicators. 
Designed for Data Scientists and Product Managers to monitor request volume, token throughput, and latency.",
+  "editable": true,
+  "fiscalYearStartMonth": 0,
+  "graphTooltip": 0,
+  "id": 47,
+  "links": [],
+  "panels": [
+    {
+      "collapsed": true,
+      "gridPos": { "h": 1, "w": 24, "x": 0, "y": 0 },
+      "id": 20,
+      "panels": [],
+      "title": "Request Over Time",
+      "type": "row"
+    },
+    {
+      "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
+      "fieldConfig": {
+        "defaults": {
+          "color": { "mode": "palette-classic" },
+          "custom": {
+            "axisBorderShow": false,
+            "axisCenteredZero": false,
+            "axisColorMode": "text",
+            "axisLabel": "",
+            "axisPlacement": "auto",
+            "barAlignment": 0,
+            "barWidthFactor": 0.6,
+            "drawStyle": "line",
+            "fillOpacity": 0,
+            "gradientMode": "none",
+            "hideFrom": { "legend": false, "tooltip": false, "viz": false },
+            "insertNulls": false,
+            "lineInterpolation": "linear",
+            "lineWidth": 1,
+            "pointSize": 5,
+            "scaleDistribution": { "type": "linear" },
+            "showPoints": "auto",
+            "spanNulls": false,
+            "stacking": { "group": "A", "mode": "none" },
+            "thresholdsStyle": { "mode": "off" }
+          },
+          "mappings": [],
+          "thresholds": {
+            "mode": "absolute",
+            "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }]
+          },
+          "unit": "req/s"
+        },
+        "overrides": []
+      },
+      "gridPos": { "h": 6, "w": 10, "x": 0, "y": 1 },
+      "id": 1,
+      "options": {
+        "legend": { "calcs": [], "displayMode": "list", "placement": "bottom", "showLegend": true },
+        "tooltip": { "mode": "single", "sort": "none" }
+      },
+      "pluginVersion": "11.3.0",
+      "targets": [
+        {
+          "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
+          "editorMode": "code",
+          "expr": "sum by (model_name) (\n rate(vllm:request_success_total{model_name=~\"$Deployment_id\"}[$__rate_interval])\n)",
+          "interval": "1",
+          "legendFormat": "{{model_name}}",
+          "range": true,
+          "refId": "A"
+        }
+      ],
+      "title": "Successful Requests Over Time",
+      "type": "timeseries"
+    },
+    {
+      "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
+      "fieldConfig": {
+        "defaults": {
+          "color": { "mode": "thresholds" },
+          "mappings": [],
+          "thresholds": {
+            "mode": "absolute",
+            "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }]
+          },
+          "unit": "req/s"
+        },
+        "overrides": []
+      },
+      "gridPos": { "h": 3, "w": 7, "x": 10, "y": 1 },
+      "id": 2,
+      "options": {
+        "colorMode": "value",
+        "graphMode": "area",
+        "justifyMode": "auto",
+        "orientation": "auto",
+        "percentChangeColorMode": "standard",
+        "reduceOptions": { "calcs": ["mean"], "fields": "", "values": false },
+        "showPercentChange": false,
+        "textMode": "auto",
+        "wideLayout": true
+      },
+      "pluginVersion": "11.3.0",
+      "targets": [
+        {
+          "editorMode": "code",
+          "expr": "sum(rate(vllm:request_success_total{model_name=~\"$Deployment_id\"}[$__rate_interval]))",
+          "legendFormat": "__auto",
+          "range": true,
+          "refId": "A"
+        }
+      ],
+      "title": "Requests Avg Rate",
+      "type": "stat"
+    },
+    {
+      "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
+      "fieldConfig": {
+        "defaults": {
+          "color": { "mode": "thresholds" },
+          "mappings": [
+            { "options": { "Calculation": { "index": 0, "text": "Last (not null)" } }, "type": "value" }
+          ],
+          "thresholds": {
+            "mode": "absolute",
+            "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }]
+          },
+          "unit": "s"
+        },
+        "overrides": []
+      },
+      "gridPos": { "h": 3, "w": 7, "x": 17, "y": 1 },
+      "id": 3,
+      "options": {
+        "colorMode": "value",
+        "graphMode": "area",
+        "justifyMode": "auto",
+        "orientation": "auto",
+        "percentChangeColorMode": "standard",
+        "reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false },
+        "showPercentChange": false,
+        "textMode": "auto",
+        "wideLayout": true
+      },
+      "pluginVersion": "11.3.0",
+      "targets": [
+        {
+          "editorMode": "code",
+          "expr": "histogram_quantile(0.50, sum by(le, model_name) (rate(vllm:e2e_request_latency_seconds_bucket{model_name=~\"$Deployment_id\"}[$__rate_interval])))",
+          "legendFormat": "__auto",
+          "range": true,
+          "refId": "A"
+        }
+      ],
+      "title": "p50 Latency",
+      "type": "stat"
+    },
+    {
+      "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
+      "fieldConfig": {
+        "defaults": {
+          "color": { "mode": "thresholds" },
+          "mappings": [
+            { "options": { "Calculation": { "index": 0, "text": "Last (not null)" } }, "type": "value" }
+          ],
+          "thresholds": {
+            "mode": "absolute",
+            "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }]
+          },
+          "unit": "s"
+        },
+        "overrides": []
+      },
+      "gridPos": { "h": 3, "w": 7, "x": 10, "y": 4 },
+      "id": 4,
+      "options": {
+        "colorMode": "value",
+        "graphMode": "area",
+        "justifyMode": "auto",
+        "orientation": "auto",
+        "percentChangeColorMode": "standard",
+        "reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false },
+        "showPercentChange": false,
+        "textMode": "auto",
+        "wideLayout": true
+      },
+      "pluginVersion": "11.3.0",
+      "targets": [
+        {
+          "editorMode": "code",
+          "expr": "histogram_quantile(0.90, sum by(le, model_name) (rate(vllm:e2e_request_latency_seconds_bucket{model_name=~\"$Deployment_id\"}[$__rate_interval])))",
+          "legendFormat": "__auto",
+          "range": true,
+          "refId": "A"
+        }
+      ],
+      "title": "p90 Latency",
+      "type": "stat"
+    },
+    {
+      "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
+      "fieldConfig": {
+        "defaults": {
+          "color": { "mode": "thresholds" },
+          "mappings": [
+            { "options": { "Calculation": { "index": 0, "text": "Last (not null)" } }, "type": "value" }
+          ],
+          "thresholds": {
+            "mode": "absolute",
+            "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }]
+          },
+          "unit": "s"
+        },
+        "overrides": []
+      },
+      "gridPos": { "h": 3, "w": 7, "x": 17, "y": 4 },
+      "id": 5,
+      "options": {
+        "colorMode": "value",
+        "graphMode": "area",
+        "justifyMode": "auto",
+        "orientation": "auto",
+        "percentChangeColorMode": "standard",
+        "reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false },
+        "showPercentChange": false,
+        "textMode": "auto",
+        "wideLayout": true
+      },
+      "pluginVersion": "11.3.0",
+      "targets": [
+        {
+          "editorMode": "code",
+          "expr": "histogram_quantile(0.99, sum by(le, model_name) (rate(vllm:e2e_request_latency_seconds_bucket{model_name=~\"$Deployment_id\"}[$__rate_interval])))",
+          "legendFormat": "__auto",
+          "range": true,
+          "refId": "A"
+        }
+      ],
+      "title": "p99 Latency",
+      "type": "stat"
+    },
+    {
+      "collapsed": false,
+      "gridPos": { "h": 1, "w": 24, "x": 0, "y": 7 },
+      "id": 19,
+      "panels": [],
+      "title": "Size Distribution",
+      "type": "row"
+    },
+    {
+      "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" },
+      "fieldConfig": {
+        "defaults": {
+          "color": { "mode": "palette-classic" },
+          "custom": {
+            "fillOpacity": 80,
+            "gradientMode": "none",
+            "hideFrom": { "legend": false, "tooltip": false, "viz": false },
+            "lineWidth": 1,
+            "stacking": { "group": "A", "mode": "none" }
+          },
+          "mappings": [],
+          "thresholds": {
+            "mode": "absolute",
+            "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }]
+          },
+          "unit": "cps"
+        },
"overrides": [] + }, + "gridPos": { "h": 6, "w": 10, "x": 0, "y": 8 }, + "id": 6, + "options": { + "legend": { "calcs": [], "displayMode": "list", "placement": "bottom", "showLegend": true }, + "tooltip": { "mode": "single", "sort": "none" } + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "sum by (le, model_name) (rate(vllm:request_prompt_tokens_bucket{model_name=~\"$Deployment_id\"}[$__rate_interval]))", + "legendFormat": "{{model_name}} le={{le}}", + "range": true, + "refId": "A" + } + ], + "title": "Input Token Size Distribution", + "type": "histogram" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "fieldConfig": { + "defaults": { + "color": { "mode": "thresholds" }, + "mappings": [ + { "options": { "calculation ": { "index": 0, "text": "Last (not null)" } }, "type": "value" } + ], + "thresholds": { + "mode": "absolute", + "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }] + }, + "unit": "cps" + }, + "overrides": [] + }, + "gridPos": { "h": 3, "w": 7, "x": 10, "y": 8 }, + "id": 9, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.90, sum by(le, model_name) (rate(vllm:request_prompt_tokens_bucket{model_name=~\"$Deployment_id\"}[$__rate_interval])))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "Input Token Size p90", + "type": "stat" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "fieldConfig": { + "defaults": { + "color": { "mode": "thresholds" }, + "mappings": [ + { "options": { "Calcultion": { "index": 0, "text": "Last (not null)" } }, "type": "value" } + ], + "thresholds": { + "mode": "absolute", + "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }] + }, + "unit": "cps" + }, + "overrides": [] + }, + "gridPos": { "h": 3, "w": 7, "x": 17, "y": 8 }, + "id": 8, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.50, sum by(le, model_name) (rate(vllm:request_prompt_tokens_bucket{model_name=~\"$Deployment_id\"}[$__rate_interval])))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "Input Token Size p50", + "type": "stat" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "fieldConfig": { + "defaults": { + "color": { "mode": "thresholds" }, + "mappings": [ + { "options": { "Calcultaion": { "index": 0, "text": "mean" } }, "type": "value" } + ], + "thresholds": { + "mode": "absolute", + "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }] + }, + "unit": "cps" + }, + "overrides": [] + }, + "gridPos": { "h": 3, "w": 7, "x": 10, "y": 11 }, + "id": 7, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + 
"reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "sum(rate(vllm:prompt_tokens_total{model_name=~\"$Deployment_id\"}[$__rate_interval]))\n/\nsum(rate(vllm:request_success_total{model_name=~\"$Deployment_id\"}[$__rate_interval]))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "Input Token Size Avg", + "type": "stat" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "fieldConfig": { + "defaults": { + "color": { "mode": "thresholds" }, + "mappings": [ + { "options": { "Calculation": { "index": 0, "text": "Last (not null)" } }, "type": "value" } + ], + "thresholds": { + "mode": "absolute", + "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }] + }, + "unit": "cps" + }, + "overrides": [] + }, + "gridPos": { "h": 3, "w": 7, "x": 17, "y": 11 }, + "id": 10, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by(le, model_name) (rate(vllm:request_prompt_tokens_bucket{model_name=~\"$Deployment_id\"}[$__rate_interval])))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "Input Token Size p99", + "type": "stat" + }, + { + "collapsed": true, + "gridPos": { "h": 1, "w": 24, "x": 0, "y": 14 }, + "id": 18, + "panels": [], + "title": "Input Token Over Time", + "type": "row" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "fieldConfig": { + "defaults": { + "color": { "mode": "palette-classic" }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { "legend": false, "tooltip": false, "viz": false }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { "type": "linear" }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { "group": "A", "mode": "none" }, + "thresholdsStyle": { "mode": "off" } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }] + }, + "unit": "cps" + }, + "overrides": [] + }, + "gridPos": { "h": 6, "w": 10, "x": 0, "y": 15 }, + "id": 11, + "options": { + "legend": { "calcs": [], "displayMode": "list", "placement": "bottom", "showLegend": true }, + "tooltip": { "mode": "single", "sort": "none" } + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "sum by (model_name) (rate(vllm:prompt_tokens_total{model_name=~\"$Deployment_id\"}[$__rate_interval]))", + "legendFormat": "{{model_name}}", + "range": true, + "refId": "A" + } + ], + "title": "Input Tokens Over Time", + "type": "timeseries" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "fieldConfig": { + "defaults": { + "color": { "mode": "thresholds" }, + "mappings": [ + { "options": { "Calculation": 
{ "index": 0, "text": "mean" } }, "type": "value" } + ], + "thresholds": { + "mode": "absolute", + "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }] + }, + "unit": "cps" + }, + "overrides": [] + }, + "gridPos": { "h": 3, "w": 7, "x": 10, "y": 15 }, + "id": 12, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "sum(rate(vllm:prompt_tokens_total{model_name=~\"$Deployment_id\"}[$__rate_interval]))", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "Input Tokens/Sec Avg", + "type": "stat" + }, + { + "collapsed": false, + "gridPos": { "h": 1, "w": 24, "x": 0, "y": 21 }, + "id": 17, + "panels": [], + "title": "Output Token Over Time", + "type": "row" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "fieldConfig": { + "defaults": { + "color": { "mode": "palette-classic" }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { "legend": false, "tooltip": false, "viz": false }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { "type": "linear" }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { "group": "A", "mode": "none" }, + "thresholdsStyle": { "mode": "off" } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }] + }, + "unit": "cps" + }, + "overrides": [] + }, + "gridPos": { "h": 6, "w": 10, "x": 0, "y": 22 }, + "id": 13, + "options": { + "legend": { "calcs": [], "displayMode": "list", "placement": "bottom", "showLegend": true }, + "tooltip": { "mode": "single", "sort": "none" } + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "sum by (model_name) (rate(vllm:generation_tokens_total{model_name=~\"$Deployment_id\"}[$__rate_interval]))", + "legendFormat": "{{model_name}}", + "range": true, + "refId": "A" + } + ], + "title": "Output Tokens Over Time", + "type": "timeseries" + }, + { + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, + "fieldConfig": { + "defaults": { + "color": { "mode": "thresholds" }, + "mappings": [ + { "options": { "Calculation": { "index": 0, "text": "mean" } }, "type": "value" } + ], + "thresholds": { + "mode": "absolute", + "steps": [{ "color": "green", "value": null }, { "color": "red", "value": 80 }] + }, + "unit": "cps" + }, + "overrides": [] + }, + "gridPos": { "h": 3, "w": 7, "x": 10, "y": 22 }, + "id": 14, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { "calcs": ["lastNotNull"], "fields": "", "values": false }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "11.3.0", + "targets": [ + { + "editorMode": "code", + "expr": "sum(rate(vllm:generation_tokens_total{model_name=~\"$Deployment_id\"}[$__rate_interval]))", + "legendFormat": 
"__auto", + "range": true, + "refId": "A" + } + ], + "title": "Output Tokens/Sec Avg", + "type": "stat" + } + ], + "preload": false, + "schemaVersion": 40, + "tags": [], + "templating": { + "list": [ + { + "current": { "text": "Prometheus", "value": "4184fc20-68a7-483a-8d9b-7caa59c680dd" }, + "label": "datasource", + "name": "DS_PROMETHEUS", + "options": [], + "query": "prometheus", + "refresh": 1, + "type": "datasource" + }, + { + "current": { "text": ["All"], "value": ["$__all"] }, + "definition": "label_values(vllm:request_success_total,model_name)", + "includeAll": true, + "label": "Deployment_ID", + "multi": true, + "name": "Deployment_id", + "options": [], + "query": { + "qryType": 1, + "query": "label_values(vllm:request_success_total,model_name)", + "refId": "PrometheusVariableQueryEditor-VariableQuery" + }, + "refresh": 1, + "regex": "", + "sort": 1, + "type": "query" + }, + { + "current": { "text": "All hours", "value": "All hours" }, + "hide": 2, + "label": "Rush Hours Only", + "name": "rush_hours", + "options": [ + { "selected": true, "text": "false", "value": "All hours" }, + { "selected": false, "text": "true", "value": "Rush hours" } + ], + "query": "false : All hours, true : Rush hours", + "type": "custom" + }, + { + "current": { "text": "All", "value": "All" }, + "hide": 2, + "label": "Rush Hours Type", + "name": "rush_hours_type", + "options": [ + { "selected": true, "text": "^All__.*$", "value": "All" }, + { "selected": false, "text": "^Static__.*$", "value": "Static" }, + { "selected": false, "text": "^Dynamic__.*$", "value": "Dynamic" } + ], + "query": "^All__.*$ : All, ^Static__.*$ : Static, ^Dynamic__.*$ : Dynamic", + "type": "custom" + }, + { + "current": { "text": "", "value": "" }, + "hide": 2, + "name": "query0", + "options": [], + "query": "", + "refresh": 1, + "regex": "", + "type": "query" + } + ] + }, + "time": { "from": "now-12h", "to": "now" }, + "timepicker": {}, + "timezone": "browser", + "title": "Query Statistics_New4", + "uid": "query-statistics4", + "version": 2, + "weekStart": "" +} + diff --git a/examples/online_serving/dashboards/perses/README.md b/examples/online_serving/dashboards/perses/README.md new file mode 100644 index 0000000000000..ae04fd17b1b9d --- /dev/null +++ b/examples/online_serving/dashboards/perses/README.md @@ -0,0 +1,48 @@ +# Perses Dashboards for vLLM Monitoring + +This directory contains Perses dashboard configurations designed to monitor vLLM +performance and metrics. 
+ +## Requirements + +- Perses instance (standalone or via operator) +- Prometheus data source configured in Perses +- vLLM deployment with Prometheus metrics enabled + +## Dashboard Format + +We provide dashboards in the **native Perses YAML format** that works across all +deployment methods: + +- **Files**: `*.yaml` (native Perses dashboard specifications) +- **Format**: Pure dashboard specifications that work everywhere +- **Usage**: Works with standalone Perses, API imports, CLI, and file provisioning +- **Kubernetes**: Directly compatible with Perses Operator + +## Dashboard Descriptions + +- **[performance_statistics.yaml](./performance_statistics.yaml)**: Performance metrics with aggregated latency + statistics +- **[query_statistics.yaml](./query_statistics.yaml)**: Query performance and deployment metrics + +## Deployment Options + +### Direct Import to Perses + +Import the dashboard specifications via Perses API or CLI: + +```bash +percli apply -f performance_statistics.yaml +``` + +### Perses Operator (Kubernetes) + +The native YAML format works directly with the Perses Operator: + +```bash +kubectl apply -f performance_statistics.yaml -n <namespace> +``` + +### File Provisioning + +Place the YAML files in a Perses provisioning folder for automatic loading. diff --git a/examples/online_serving/dashboards/perses/performance_statistics.yaml b/examples/online_serving/dashboards/perses/performance_statistics.yaml new file mode 100644 index 0000000000000..2e8d24c3324b9 --- /dev/null +++ b/examples/online_serving/dashboards/perses/performance_statistics.yaml @@ -0,0 +1,764 @@ +kind: PersesDashboard +metadata: + name: performance-statistics + createdAt: 0001-01-01T00:00:00Z + updatedAt: 0001-01-01T00:00:00Z + version: 0 + project: "" +spec: + display: + name: Performance Statistics + + variables: + - kind: ListVariable + spec: + display: + name: Deployment_ID + hidden: false + name: Deployment_id + allowAllValue: true + allowMultiple: true + defaultValue: + - $__all + sort: alphabetical-asc + plugin: + kind: PrometheusLabelValuesVariable + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + labelName: model_name + matchers: + # Any one vllm metric that always carries model_name + - vllm:generation_tokens_total{} + + panels: + "1": + kind: Panel + spec: + display: + name: E2E Latency over Time + plugin: + kind: TimeSeriesChart + spec: + legend: + mode: table + position: bottom + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + # avg latency by model = sum(rate(sum)) / sum(rate(count)) + query: > + sum by (model_name) (rate(vllm:e2e_request_latency_seconds_sum{model_name=~"$Deployment_id"}[$__interval])) + / + sum by (model_name) (rate(vllm:e2e_request_latency_seconds_count{model_name=~"$Deployment_id"}[$__interval])) + seriesNameFormat: '{{model_name}}' + + "2": + kind: Panel + spec: + display: + name: E2E Latency (Avg) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + (sum by (model_name) (increase(vllm:e2e_request_latency_seconds_sum{model_name=~"$Deployment_id"}[$__range]))) + / + (sum by (model_name) (increase(vllm:e2e_request_latency_seconds_count{model_name=~"$Deployment_id"}[$__range]))) + 
+ "3": + kind: Panel + spec: + display: + name: E2E Latency (P50) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + histogram_quantile( + 0.50, + sum by (le, model_name) ( + rate(vllm:e2e_request_latency_seconds_bucket{model_name=~"$Deployment_id"}[$__interval]) + ) + ) + + "4": + kind: Panel + spec: + display: + name: E2E Latency (P90) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + histogram_quantile( + 0.90, + sum by (le, model_name) ( + rate(vllm:e2e_request_latency_seconds_bucket{model_name=~"$Deployment_id"}[$__interval]) + ) + ) + + "5": + kind: Panel + spec: + display: + name: E2E Latency (P99) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + histogram_quantile( + 0.99, + sum by (le, model_name) ( + rate(vllm:e2e_request_latency_seconds_bucket{model_name=~"$Deployment_id"}[$__interval]) + ) + ) + + "6": + kind: Panel + spec: + display: + name: TTFT over Time + plugin: + kind: TimeSeriesChart + spec: + legend: + mode: table + position: bottom + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + sum by (model_name) (rate(vllm:time_to_first_token_seconds_sum{model_name=~"$Deployment_id"}[$__interval])) + / + sum by (model_name) (rate(vllm:time_to_first_token_seconds_count{model_name=~"$Deployment_id"}[$__interval])) + seriesNameFormat: '{{model_name}}' + + "7": + kind: Panel + spec: + display: + name: TTFT (Avg) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + (sum by (model_name) (increase(vllm:time_to_first_token_seconds_sum{model_name=~"$Deployment_id"}[$__range]))) + / + (sum by (model_name) (increase(vllm:time_to_first_token_seconds_count{model_name=~"$Deployment_id"}[$__range]))) + + "8": + kind: Panel + spec: + display: + name: TTFT (P50) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + histogram_quantile( + 0.50, + sum by (le, model_name) ( + rate(vllm:time_to_first_token_seconds_bucket{model_name=~"$Deployment_id"}[$__interval]) + ) + ) + + "9": + kind: Panel + spec: + display: + name: TTFT (P90) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + histogram_quantile( + 0.90, + sum by (le, model_name) ( + 
rate(vllm:time_to_first_token_seconds_bucket{model_name=~"$Deployment_id"}[$__interval]) + ) + ) + + "10": + kind: Panel + spec: + display: + name: TTFT (P99) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + histogram_quantile( + 0.99, + sum by (le, model_name) ( + rate(vllm:time_to_first_token_seconds_bucket{model_name=~"$Deployment_id"}[$__interval]) + ) + ) + + "11": + kind: Panel + spec: + display: + name: ITL (Time per Output Token) over Time + plugin: + kind: TimeSeriesChart + spec: + legend: + mode: table + position: bottom + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + sum by (model_name) (rate(vllm:time_per_output_token_seconds_sum{model_name=~"$Deployment_id"}[$__interval])) + / + sum by (model_name) (rate(vllm:time_per_output_token_seconds_count{model_name=~"$Deployment_id"}[$__interval])) + seriesNameFormat: '{{model_name}}' + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + histogram_quantile( + 0.50, + sum by (le, model_name) ( + rate(vllm:time_per_output_token_seconds_bucket{model_name=~"$Deployment_id"}[$__interval]) + ) + ) + seriesNameFormat: '{{model_name}} p50' + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + histogram_quantile( + 0.90, + sum by (le, model_name) ( + rate(vllm:time_per_output_token_seconds_bucket{model_name=~"$Deployment_id"}[$__interval]) + ) + ) + seriesNameFormat: '{{model_name}} p90' + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + histogram_quantile( + 0.99, + sum by (le, model_name) ( + rate(vllm:time_per_output_token_seconds_bucket{model_name=~"$Deployment_id"}[$__interval]) + ) + ) + seriesNameFormat: '{{model_name}} p99' + + "12": + kind: Panel + spec: + display: + name: ITL (Avg) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + (sum by (model_name) (increase(vllm:time_per_output_token_seconds_sum{model_name=~"$Deployment_id"}[$__range]))) + / + (sum by (model_name) (increase(vllm:time_per_output_token_seconds_count{model_name=~"$Deployment_id"}[$__range]))) + + "13": + kind: Panel + spec: + display: + name: ITL (P50) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + histogram_quantile( + 0.50, + sum by (le, model_name) ( + rate(vllm:time_per_output_token_seconds_bucket{model_name=~"$Deployment_id"}[$__interval]) + ) + ) + + "14": + kind: Panel + spec: + display: + name: ITL (P90) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - 
kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + histogram_quantile( + 0.90, + sum by (le, model_name) ( + rate(vllm:time_per_output_token_seconds_bucket{model_name=~"$Deployment_id"}[$__interval]) + ) + ) + + "15": + kind: Panel + spec: + display: + name: ITL (P99) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + histogram_quantile( + 0.99, + sum by (le, model_name) ( + rate(vllm:time_per_output_token_seconds_bucket{model_name=~"$Deployment_id"}[$__interval]) + ) + ) + + "16": + kind: Panel + spec: + display: + name: TPS (Tokens/sec) over Time + plugin: + kind: TimeSeriesChart + spec: + legend: + mode: table + position: bottom + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + sum by (model_name) (rate(vllm:generation_tokens_total{model_name=~"$Deployment_id"}[$__interval])) + seriesNameFormat: '{{model_name}} generation' + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + sum by (model_name) (rate(vllm:prompt_tokens_total{model_name=~"$Deployment_id"}[$__interval])) + seriesNameFormat: '{{model_name}} prompt' + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + # overall iteration tokens/sec if exposed + query: > + rate(vllm:iteration_tokens_total_count[$__interval]) + seriesNameFormat: 'iteration overall' + + "17": + kind: Panel + spec: + display: + name: KV Cache Usage (avg %) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + # Multiply by 100 so we can read it as a percentage without setting a unit (avoids CUE unit conflicts) + query: > + 100 * avg(vllm:gpu_cache_usage_perc) + + "18": + kind: Panel + spec: + display: + name: Running Requests by Pod + plugin: + kind: TimeSeriesChart + spec: + legend: + mode: table + position: bottom + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + sum by (pod) (vllm:num_requests_running) + seriesNameFormat: '{{pod}}' + + "19": + kind: Panel + spec: + display: + name: Waiting Requests by Pod + plugin: + kind: TimeSeriesChart + spec: + legend: + mode: table + position: bottom + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: > + sum by (pod) (vllm:num_requests_waiting) + seriesNameFormat: '{{pod}}' + + "20": + kind: Panel + spec: + display: + name: Running Requests (sum) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: 
PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: sum(vllm:num_requests_running) + + "21": + kind: Panel + spec: + display: + name: Waiting Requests (sum) + plugin: + kind: StatChart + spec: + calculation: last-number + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: + kind: PrometheusDatasource + name: accelerators-thanos-querier-datasource + query: sum(vllm:num_requests_waiting) + + layouts: + - kind: Grid + spec: + display: + title: Overview + items: + - x: 0 + y: 0 + width: 6 + height: 3 + content: { $ref: '#/spec/panels/17' } # KV cache % + - x: 6 + y: 0 + width: 6 + height: 3 + content: { $ref: '#/spec/panels/20' } # running sum + - x: 12 + y: 0 + width: 6 + height: 3 + content: { $ref: '#/spec/panels/21' } # waiting sum + + - kind: Grid + spec: + display: + title: E2E Latency + items: + - x: 0 + y: 1 + width: 10 + height: 6 + content: { $ref: '#/spec/panels/1' } + - x: 10 + y: 1 + width: 7 + height: 3 + content: { $ref: '#/spec/panels/2' } + - x: 17 + y: 1 + width: 7 + height: 3 + content: { $ref: '#/spec/panels/3' } + - x: 10 + y: 4 + width: 7 + height: 3 + content: { $ref: '#/spec/panels/4' } + - x: 17 + y: 4 + width: 7 + height: 3 + content: { $ref: '#/spec/panels/5' } + + - kind: Grid + spec: + display: + title: TTFT + items: + - x: 0 + y: 8 + width: 10 + height: 6 + content: { $ref: '#/spec/panels/6' } + - x: 10 + y: 8 + width: 7 + height: 3 + content: { $ref: '#/spec/panels/7' } + - x: 17 + y: 8 + width: 7 + height: 3 + content: { $ref: '#/spec/panels/8' } + - x: 10 + y: 11 + width: 7 + height: 3 + content: { $ref: '#/spec/panels/9' } + - x: 17 + y: 11 + width: 7 + height: 3 + content: { $ref: '#/spec/panels/10' } + + - kind: Grid + spec: + display: + title: ITL (Time per Output Token) + items: + - x: 0 + y: 15 + width: 10 + height: 6 + content: { $ref: '#/spec/panels/11' } + - x: 10 + y: 15 + width: 7 + height: 3 + content: { $ref: '#/spec/panels/12' } + - x: 17 + y: 15 + width: 7 + height: 3 + content: { $ref: '#/spec/panels/13' } + - x: 10 + y: 18 + width: 7 + height: 3 + content: { $ref: '#/spec/panels/14' } + - x: 17 + y: 18 + width: 7 + height: 3 + content: { $ref: '#/spec/panels/15' } + + - kind: Grid + spec: + display: + title: TPS (Prompt / Generation / Iteration) + items: + - x: 0 + y: 22 + width: 14 + height: 6 + content: { $ref: '#/spec/panels/16' } + + - kind: Grid + spec: + display: + title: Per-Pod Request State + items: + - x: 0 + y: 28 + width: 12 + height: 6 + content: { $ref: '#/spec/panels/18' } + - x: 12 + y: 28 + width: 12 + height: 6 + content: { $ref: '#/spec/panels/19' } + diff --git a/examples/online_serving/dashboards/perses/query_statistics.yaml b/examples/online_serving/dashboards/perses/query_statistics.yaml new file mode 100644 index 0000000000000..28109aae81511 --- /dev/null +++ b/examples/online_serving/dashboards/perses/query_statistics.yaml @@ -0,0 +1,392 @@ +kind: PersesDashboard +metadata: + name: query-statistics + createdAt: 0001-01-01T00:00:00Z + updatedAt: 0001-01-01T00:00:00Z + version: 0 + project: "" +spec: + display: + name: Query Statistics_New + + variables: + - kind: ListVariable + spec: + name: NS + display: { name: Namespace } + allowMultiple: false + defaultValue: llm-d + plugin: + kind: PrometheusLabelValuesVariable + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + labelName: namespace + matchers: + - up{service=~".*vllm.*"} + + - 
kind: ListVariable + spec: + name: SVC + display: { name: Service } + allowMultiple: false + defaultValue: vllm-qwen2-0-5b-sim + plugin: + kind: PrometheusLabelValuesVariable + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + labelName: service + matchers: + - up{namespace="$NS",service=~".*vllm.*"} + + - kind: ListVariable + spec: + name: MODEL + display: { name: Model (real vLLM) } + allowAllValue: true + allowMultiple: true + defaultValue: ["$__all"] + plugin: + kind: PrometheusLabelValuesVariable + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + labelName: model_name + matchers: + - vllm:request_success_total{namespace="$NS",service="$SVC"} + + panels: + + # --- Core (works on Simulator & Real) --- + core_running_now: + kind: Panel + spec: + display: { name: Running Requests (now) } + plugin: { kind: StatChart, spec: { calculation: last-number } } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: sum(vllm:num_requests_running{namespace="$NS",service="$SVC"}) or vector(0) + minStep: "15s" + + core_waiting_now: + kind: Panel + spec: + display: { name: Waiting Requests (now) } + plugin: { kind: StatChart, spec: { calculation: last-number } } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: sum(vllm:num_requests_waiting{namespace="$NS",service="$SVC"}) or vector(0) + minStep: "15s" + + core_kv_usage_now: + kind: Panel + spec: + display: { name: KV Cache Usage (0–1) } + plugin: { kind: StatChart, spec: { calculation: last-number } } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: avg(vllm:gpu_cache_usage_perc{namespace="$NS",service="$SVC"}) or vector(0) + minStep: "15s" + + core_running_ts: + kind: Panel + spec: + display: { name: Running Over Time } + plugin: + kind: TimeSeriesChart + spec: + legend: { mode: table, position: bottom } + visual: { display: line, lineWidth: 1, areaOpacity: 0.3 } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: sum by (service) (vllm:num_requests_running{namespace="$NS",service="$SVC"}) or vector(0) + minStep: "15s" + + core_waiting_ts: + kind: Panel + spec: + display: { name: Waiting Over Time } + plugin: + kind: TimeSeriesChart + spec: + legend: { mode: table, position: bottom } + visual: { display: line, lineWidth: 1, areaOpacity: 0.3 } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: sum by (service) (vllm:num_requests_waiting{namespace="$NS",service="$SVC"}) or vector(0) + minStep: "15s" + + core_targets_up: + kind: Panel + spec: + display: { name: Scrape Targets Up } + plugin: { kind: StatChart, spec: { calculation: last-number } } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: 
count(up{namespace="$NS",service="$SVC"} == 1) or vector(0) + minStep: "15s" + + # --- KV Cache as Percent (works on Simulator & Real) --- + core_kv_usage_pct_now: + kind: Panel + spec: + display: { name: KV Cache Usage (%) – now } + plugin: { kind: StatChart, spec: { calculation: last-number } } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + # multiply by 100 to present percentage; omit format.unit to avoid schema conflicts + query: (avg(vllm:gpu_cache_usage_perc{namespace="$NS",service="$SVC"}) * 100) or vector(0) + minStep: "15s" + + core_kv_usage_pct_ts: + kind: Panel + spec: + display: { name: KV Cache Usage (%) – over time } + plugin: + kind: TimeSeriesChart + spec: + legend: { mode: table, position: bottom } + visual: { display: line, lineWidth: 1, areaOpacity: 0.3 } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: (avg by (service) (vllm:gpu_cache_usage_perc{namespace="$NS",service="$SVC"}) * 100) or vector(0) + minStep: "15s" + + # --- Per-Pod breakdowns (works on Simulator & Real) --- + per_pod_running_ts: + kind: Panel + spec: + display: { name: Running by Pod } + plugin: + kind: TimeSeriesChart + spec: + legend: { mode: table, position: bottom } + visual: { display: line, lineWidth: 1, areaOpacity: 0.3 } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: sum by (pod) (vllm:num_requests_running{namespace="$NS",service="$SVC"}) or vector(0) + minStep: "15s" + + per_pod_waiting_ts: + kind: Panel + spec: + display: { name: Waiting by Pod } + plugin: + kind: TimeSeriesChart + spec: + legend: { mode: table, position: bottom } + visual: { display: line, lineWidth: 1, areaOpacity: 0.3 } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: sum by (pod) (vllm:num_requests_waiting{namespace="$NS",service="$SVC"}) or vector(0) + minStep: "15s" + + per_pod_kv_pct_ts: + kind: Panel + spec: + display: { name: KV Cache (%) by Pod } + plugin: + kind: TimeSeriesChart + spec: + legend: { mode: table, position: bottom } + visual: { display: line, lineWidth: 1, areaOpacity: 0.3 } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + # if your exporter labels kv metric with pod (the sim does), this works; otherwise it will just return empty + query: (avg by (pod) (vllm:gpu_cache_usage_perc{namespace="$NS",service="$SVC"}) * 100) or vector(0) + minStep: "15s" + + # --- Real vLLM only (zeros on simulator) --- + real_req_rate_ts: + kind: Panel + spec: + display: { name: Request Rate (real vLLM) } + plugin: + kind: TimeSeriesChart + spec: + legend: { mode: table, position: bottom } + visual: { display: line, lineWidth: 1, areaOpacity: 0.3 } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: sum by (model_name) 
(rate(vllm:request_success_total{namespace="$NS",service="$SVC",model_name=~"$MODEL"}[$__interval])) or vector(0) + minStep: "15s" + + real_p50: + kind: Panel + spec: + display: { name: p50 Latency (real vLLM) } + plugin: { kind: StatChart, spec: { calculation: last-number } } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: histogram_quantile(0.50, sum by (le, model_name) (rate(vllm:e2e_request_latency_seconds_bucket{namespace="$NS",service="$SVC",model_name=~"$MODEL"}[$__interval]))) or vector(0) + minStep: "15s" + + real_p90: + kind: Panel + spec: + display: { name: p90 Latency (real vLLM) } + plugin: { kind: StatChart, spec: { calculation: last-number } } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: histogram_quantile(0.90, sum by (le, model_name) (rate(vllm:e2e_request_latency_seconds_bucket{namespace="$NS",service="$SVC",model_name=~"$MODEL"}[$__interval]))) or vector(0) + minStep: "15s" + + real_p99: + kind: Panel + spec: + display: { name: p99 Latency (real vLLM) } + plugin: { kind: StatChart, spec: { calculation: last-number } } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: histogram_quantile(0.99, sum by (le, model_name) (rate(vllm:e2e_request_latency_seconds_bucket{namespace="$NS",service="$SVC",model_name=~"$MODEL"}[$__interval]))) or vector(0) + minStep: "15s" + + real_input_tokens_ts: + kind: Panel + spec: + display: { name: Input Tokens / sec (real vLLM) } + plugin: + kind: TimeSeriesChart + spec: + legend: { mode: table, position: bottom } + visual: { display: line, lineWidth: 1, areaOpacity: 0.3 } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: sum by (model_name) (rate(vllm:prompt_tokens_total{namespace="$NS",service="$SVC",model_name=~"$MODEL"}[$__interval])) or vector(0) + minStep: "15s" + + real_output_tokens_ts: + kind: Panel + spec: + display: { name: Output Tokens / sec (real vLLM) } + plugin: + kind: TimeSeriesChart + spec: + legend: { mode: table, position: bottom } + visual: { display: line, lineWidth: 1, areaOpacity: 0.3 } + queries: + - kind: TimeSeriesQuery + spec: + plugin: + kind: PrometheusTimeSeriesQuery + spec: + datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } + query: sum by (model_name) (rate(vllm:generation_tokens_total{namespace="$NS",service="$SVC",model_name=~"$MODEL"}[$__interval])) or vector(0) + minStep: "15s" + + layouts: + - kind: Grid + spec: + display: { title: Core (Sim & Real) } + items: + - { x: 0, y: 0, width: 6, height: 3, content: { $ref: '#/spec/panels/core_running_now' } } + - { x: 6, y: 0, width: 6, height: 3, content: { $ref: '#/spec/panels/core_waiting_now' } } + - { x: 12, y: 0, width: 6, height: 3, content: { $ref: '#/spec/panels/core_kv_usage_now' } } + - { x: 18, y: 0, width: 6, height: 3, content: { $ref: '#/spec/panels/core_targets_up' } } + - { x: 0, y: 3, width: 12, height: 6, content: { $ref: '#/spec/panels/core_running_ts' } } + - { x: 12, y: 3, width: 12, height: 6, content: { $ref: 
'#/spec/panels/core_waiting_ts' } } + + - kind: Grid + spec: + display: { title: KV Cache (%) } + items: + - { x: 0, y: 9, width: 6, height: 3, content: { $ref: '#/spec/panels/core_kv_usage_pct_now' } } + - { x: 6, y: 9, width: 18, height: 6, content: { $ref: '#/spec/panels/core_kv_usage_pct_ts' } } + + - kind: Grid + spec: + display: { title: Per-Pod breakdowns } + items: + - { x: 0, y: 15, width: 12, height: 6, content: { $ref: '#/spec/panels/per_pod_running_ts' } } + - { x: 12, y: 15, width: 12, height: 6, content: { $ref: '#/spec/panels/per_pod_waiting_ts' } } + - { x: 0, y: 21, width: 24, height: 6, content: { $ref: '#/spec/panels/per_pod_kv_pct_ts' } } + + - kind: Grid + spec: + display: { title: Real vLLM only (shows 0 on simulator) } + items: + - { x: 0, y: 27, width: 12, height: 6, content: { $ref: '#/spec/panels/real_req_rate_ts' } } + - { x: 12, y: 27, width: 4, height: 3, content: { $ref: '#/spec/panels/real_p50' } } + - { x: 16, y: 27, width: 4, height: 3, content: { $ref: '#/spec/panels/real_p90' } } + - { x: 20, y: 27, width: 4, height: 3, content: { $ref: '#/spec/panels/real_p99' } } + - { x: 0, y: 33, width: 12, height: 6, content: { $ref: '#/spec/panels/real_input_tokens_ts' } } + - { x: 12, y: 33, width: 12, height: 6, content: { $ref: '#/spec/panels/real_output_tokens_ts' } } + diff --git a/examples/online_serving/disaggregated_prefill.sh b/examples/online_serving/disaggregated_prefill.sh index 6925dc8af07e9..d434e22b1ae88 100644 --- a/examples/online_serving/disaggregated_prefill.sh +++ b/examples/online_serving/disaggregated_prefill.sh @@ -53,7 +53,7 @@ CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \ --gpu-memory-utilization 0.8 \ --trust-remote-code \ --kv-transfer-config \ - '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' & + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' & # decoding instance, which is the KV consumer CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL_NAME \ @@ -62,7 +62,7 @@ CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL_NAME \ --gpu-memory-utilization 0.8 \ --trust-remote-code \ --kv-transfer-config \ - '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' & + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' & # wait until prefill and decode instances are ready wait_for_server 8100 diff --git a/examples/online_serving/kv_events_subscriber.py b/examples/online_serving/kv_events_subscriber.py index f238c66234dcc..9fd55fc9ddc94 100644 --- a/examples/online_serving/kv_events_subscriber.py +++ b/examples/online_serving/kv_events_subscriber.py @@ -6,6 +6,8 @@ import msgspec import zmq from msgspec.msgpack import Decoder +from vllm.v1.core.kv_cache_utils import BlockHash + # # Types copied from vllm.distributed.kv_events @@ -22,8 +24,8 @@ class KVCacheEvent( class BlockStored(KVCacheEvent): - block_hashes: list[int] - parent_block_hash: Optional[int] + block_hashes: list[BlockHash] + parent_block_hash: Optional[BlockHash] token_ids: list[int] block_size: int lora_id: Optional[int] @@ -31,7 +33,7 @@ class BlockStored(KVCacheEvent): class BlockRemoved(KVCacheEvent): - block_hashes: list[int] + block_hashes: list[BlockHash] medium: Optional[str] diff --git a/examples/online_serving/openai_chat_completion_client_for_multimodal.py b/examples/online_serving/openai_chat_completion_client_for_multimodal.py index ac5f79b56e49f..37216a5cfe574 100644 --- 
a/examples/online_serving/openai_chat_completion_client_for_multimodal.py +++ b/examples/online_serving/openai_chat_completion_client_for_multimodal.py @@ -266,10 +266,52 @@ def run_audio(model: str) -> None: print("Chat completion output from base64 encoded audio:", result) +def run_multi_audio(model: str) -> None: + from vllm.assets.audio import AudioAsset + + # Two different audios to showcase batched inference. + audio_url = AudioAsset("winning_call").url + audio_base64 = encode_base64_content_from_url(audio_url) + audio_url2 = AudioAsset("azacinto_foscolo").url + audio_base64_2 = encode_base64_content_from_url(audio_url2) + + # OpenAI-compatible schema (`input_audio`) + chat_completion_from_base64 = client.chat.completions.create( + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "Are these two audios the same?"}, + { + "type": "input_audio", + "input_audio": { + "data": audio_base64, + "format": "wav", + }, + }, + { + "type": "input_audio", + "input_audio": { + "data": audio_base64_2, + "format": "wav", + }, + }, + ], + } + ], + model=model, + max_completion_tokens=64, + ) + + result = chat_completion_from_base64.choices[0].message.content + print("Chat completion output from input audio:", result) + + example_function_map = { "text-only": run_text_only, "single-image": run_single_image, "multi-image": run_multi_image, + "multi-audio": run_multi_audio, "video": run_video, "audio": run_audio, } diff --git a/examples/online_serving/openai_chat_completion_client_with_tools_required.py b/examples/online_serving/openai_chat_completion_client_with_tools_required.py index 7eb8668213eef..6ff65b18f6674 100644 --- a/examples/online_serving/openai_chat_completion_client_with_tools_required.py +++ b/examples/online_serving/openai_chat_completion_client_with_tools_required.py @@ -6,7 +6,7 @@ without any specific flags: ```bash VLLM_USE_V1=0 vllm serve unsloth/Llama-3.2-1B-Instruct \ - --guided-decoding-backend outlines + --structured-outputs-config.backend outlines ``` This example demonstrates how to generate chat completions diff --git a/examples/online_serving/openai_embedding_long_text/service.sh b/examples/online_serving/openai_embedding_long_text/service.sh index f356d7d4529ea..56888c8aa0e4c 100644 --- a/examples/online_serving/openai_embedding_long_text/service.sh +++ b/examples/online_serving/openai_embedding_long_text/service.sh @@ -120,7 +120,7 @@ echo " - API Key: $API_KEY" echo " - Native Pooling: $POOLING_TYPE | Cross-chunk: MEAN" echo "" echo "🧪 Test the server with:" -echo " python examples/online_serving/openai_embedding_long_text_client.py" +echo " python examples/online_serving/openai_embedding_long_text/client.py" echo "" echo "📚 Enhanced features enabled:" echo " ✅ Intelligent native pooling type detection" diff --git a/examples/online_serving/pooling/README.md b/examples/online_serving/pooling/README.md new file mode 100644 index 0000000000000..2c271b6a32bc2 --- /dev/null +++ b/examples/online_serving/pooling/README.md @@ -0,0 +1,49 @@ +# Pooling models + +## Cohere rerank usage + +```bash +python examples/online_serving/pooling/cohere_rerank_client.py +``` + +## Jinaai rerank usage + +```bash +python examples/online_serving/pooling/jinaai_rerank_client.py +``` + +## Named Entity Recognition (NER) usage + +```bash +python examples/online_serving/pooling/ner.py +``` + +## Openai chat embedding for multimodal usage + +```bash +python examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py +``` + +## Openai classification usage + 
+```bash +python examples/online_serving/pooling/openai_classification_client.py +``` + +## Openai embedding usage + +```bash +python examples/online_serving/pooling/openai_embedding_client.py +``` + +## Openai embedding matryoshka dimensions usage + +```bash +python examples/online_serving/pooling/openai_embedding_matryoshka_fy.py +``` + +## Openai pooling usage + +```bash +python examples/online_serving/pooling/openai_pooling_client.py +``` diff --git a/examples/online_serving/cohere_rerank_client.py b/examples/online_serving/pooling/cohere_rerank_client.py similarity index 100% rename from examples/online_serving/cohere_rerank_client.py rename to examples/online_serving/pooling/cohere_rerank_client.py diff --git a/examples/online_serving/jinaai_rerank_client.py b/examples/online_serving/pooling/jinaai_rerank_client.py similarity index 100% rename from examples/online_serving/jinaai_rerank_client.py rename to examples/online_serving/pooling/jinaai_rerank_client.py diff --git a/examples/online_serving/pooling/ner.py b/examples/online_serving/pooling/ner.py new file mode 100644 index 0000000000000..9ec2bd45a0fe5 --- /dev/null +++ b/examples/online_serving/pooling/ner.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from https://huggingface.co/boltuix/NeuroBERT-NER + +""" +Example online usage of Pooling API for Named Entity Recognition (NER). + +Run `vllm serve <model> --runner pooling` +to start up the server in vLLM. e.g. + +vllm serve boltuix/NeuroBERT-NER +""" + +import argparse + +import requests +import torch + + +def post_http_request(prompt: dict, api_url: str) -> requests.Response: + headers = {"User-Agent": "Test Client"} + response = requests.post(api_url, headers=headers, json=prompt) + return response + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--model", type=str, default="boltuix/NeuroBERT-NER") + + return parser.parse_args() + + +def main(args): + from transformers import AutoConfig, AutoTokenizer + + api_url = f"http://{args.host}:{args.port}/pooling" + model_name = args.model + + # Load tokenizer and config + tokenizer = AutoTokenizer.from_pretrained(model_name) + config = AutoConfig.from_pretrained(model_name) + label_map = config.id2label + + # Input text + text = "Barack Obama visited Microsoft headquarters in Seattle on January 2025." 
+ prompt = {"model": model_name, "input": text} + + pooling_response = post_http_request(prompt=prompt, api_url=api_url) + + # Run inference + output = pooling_response.json()["data"][0] + logits = torch.tensor(output["data"]) + predictions = logits.argmax(dim=-1) + inputs = tokenizer(text, return_tensors="pt") + + # Map predictions to labels + tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) + labels = [label_map[p.item()] for p in predictions] + assert len(tokens) == len(predictions) + + # Print results + for token, label in zip(tokens, labels): + if token not in tokenizer.all_special_tokens: + print(f"{token:15} → {label}") + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/online_serving/openai_chat_embedding_client_for_multimodal.py b/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py similarity index 92% rename from examples/online_serving/openai_chat_embedding_client_for_multimodal.py rename to examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py index 771ad8511e972..30cb3325b9b18 100644 --- a/examples/online_serving/openai_chat_embedding_client_for_multimodal.py +++ b/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py @@ -1,5 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501 +"""Example Python client for multimodal embedding API using vLLM API server +NOTE: + start a supported multimodal embeddings model server with `vllm serve`, e.g. + vllm serve TIGER-Lab/VLM2Vec-Full --runner pooling --trust_remote_code --max_model_len=1024 +""" import argparse import base64 diff --git a/examples/online_serving/openai_classification_client.py b/examples/online_serving/pooling/openai_classification_client.py similarity index 86% rename from examples/online_serving/openai_classification_client.py rename to examples/online_serving/pooling/openai_classification_client.py index b10e7acbd26c1..d8dc2ef001112 100644 --- a/examples/online_serving/openai_classification_client.py +++ b/examples/online_serving/pooling/openai_classification_client.py @@ -1,5 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Example Python client for classification API using vLLM API server +NOTE: + start a supported classification model server with `vllm serve`, e.g. + vllm serve jason9693/Qwen2.5-1.5B-apeach +""" import argparse import pprint diff --git a/examples/online_serving/openai_embedding_client.py b/examples/online_serving/pooling/openai_embedding_client.py similarity index 82% rename from examples/online_serving/openai_embedding_client.py rename to examples/online_serving/pooling/openai_embedding_client.py index 6bc390861e2ee..f5f6820d07d73 100644 --- a/examples/online_serving/openai_embedding_client.py +++ b/examples/online_serving/pooling/openai_embedding_client.py @@ -1,5 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Example Python client for embedding API using vLLM API server +NOTE: + start a supported embeddings model server with `vllm serve`, e.g. 
+ vllm serve intfloat/e5-small +""" from openai import OpenAI diff --git a/examples/online_serving/openai_embedding_matryoshka_fy.py b/examples/online_serving/pooling/openai_embedding_matryoshka_fy.py similarity index 100% rename from examples/online_serving/openai_embedding_matryoshka_fy.py rename to examples/online_serving/pooling/openai_embedding_matryoshka_fy.py diff --git a/examples/online_serving/openai_pooling_client.py b/examples/online_serving/pooling/openai_pooling_client.py similarity index 89% rename from examples/online_serving/openai_pooling_client.py rename to examples/online_serving/pooling/openai_pooling_client.py index 95555d41cbea5..569015746b128 100644 --- a/examples/online_serving/openai_pooling_client.py +++ b/examples/online_serving/pooling/openai_pooling_client.py @@ -4,7 +4,9 @@ Example online usage of Pooling API. Run `vllm serve <model> --runner pooling` -to start up the server in vLLM. +to start up the server in vLLM. e.g. + +vllm serve internlm/internlm2-1_8b-reward --trust-remote-code """ import argparse @@ -23,7 +25,7 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=8000) - parser.add_argument("--model", type=str, default="jason9693/Qwen2.5-1.5B-apeach") + parser.add_argument("--model", type=str, default="internlm/internlm2-1_8b-reward") return parser.parse_args() diff --git a/examples/online_serving/prithvi_geospatial_mae.py b/examples/online_serving/prithvi_geospatial_mae.py index 31301e0042cf4..611a7cbc89fa2 100644 --- a/examples/online_serving/prithvi_geospatial_mae.py +++ b/examples/online_serving/prithvi_geospatial_mae.py @@ -10,18 +10,19 @@ import requests # multimodal data. In this specific case this example will take a geotiff # image as input, process it using the multimodal data processor, and # perform inference. 
-# Reuirements : +# Requirements : # - install plugin at: # https://github.com/christian-pinto/prithvi_io_processor_plugin # - start vllm in serving mode with the below args # --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM' +# --model-impl terratorch # --task embed --trust-remote-code # --skip-tokenizer-init --enforce-eager -# --io-processor-plugin prithvi_to_tiff_india +# --io-processor-plugin prithvi_to_tiff def main(): - image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/India_900498_S2Hand.tif" # noqa: E501 + image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501 server_endpoint = "http://localhost:8000/pooling" request_payload_url = { diff --git a/examples/online_serving/prometheus_grafana/grafana.json b/examples/online_serving/prometheus_grafana/grafana.json index 3488956a5b24c..37abc9de926fd 100644 --- a/examples/online_serving/prometheus_grafana/grafana.json +++ b/examples/online_serving/prometheus_grafana/grafana.json @@ -402,7 +402,7 @@ }, "disableTextWrap": false, "editorMode": "builder", - "expr": "histogram_quantile(0.99, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))", + "expr": "histogram_quantile(0.99, sum by(le) (rate(vllm:inter_token_latency_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))", "fullMetaSearch": false, "includeNullMetadata": false, "instant": false, @@ -418,7 +418,7 @@ }, "disableTextWrap": false, "editorMode": "builder", - "expr": "histogram_quantile(0.95, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))", + "expr": "histogram_quantile(0.95, sum by(le) (rate(vllm:inter_token_latency_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))", "fullMetaSearch": false, "hide": false, "includeNullMetadata": false, @@ -435,7 +435,7 @@ }, "disableTextWrap": false, "editorMode": "builder", - "expr": "histogram_quantile(0.9, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))", + "expr": "histogram_quantile(0.9, sum by(le) (rate(vllm:inter_token_latency_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))", "fullMetaSearch": false, "hide": false, "includeNullMetadata": false, @@ -452,7 +452,7 @@ }, "disableTextWrap": false, "editorMode": "builder", - "expr": "histogram_quantile(0.5, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))", + "expr": "histogram_quantile(0.5, sum by(le) (rate(vllm:inter_token_latency_seconds_bucket{model_name=\"$model_name\"}[$__rate_interval])))", "fullMetaSearch": false, "hide": false, "includeNullMetadata": false, @@ -468,7 +468,7 @@ "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", - "expr": "rate(vllm:time_per_output_token_seconds_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(vllm:time_per_output_token_seconds_count{model_name=\"$model_name\"}[$__rate_interval])", + "expr": "rate(vllm:inter_token_latency_seconds_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(vllm:inter_token_latency_seconds_count{model_name=\"$model_name\"}[$__rate_interval])", "hide": false, "instant": false, "legendFormat": "Mean", @@ -476,7 +476,7 @@ "refId": "E" } ], - "title": "Time Per Output Token Latency", + "title": "Inter Token Latency", "type": "timeseries" }, { diff --git a/examples/online_serving/structured_outputs/structured_outputs.py 
b/examples/online_serving/structured_outputs/structured_outputs.py index 2a8f4637260c2..3ea6c73e90e8f 100644 --- a/examples/online_serving/structured_outputs/structured_outputs.py +++ b/examples/online_serving/structured_outputs/structured_outputs.py @@ -86,7 +86,7 @@ PARAMS: dict[ConstraintsFormat, dict[str, Any]] = { "content": "Classify this sentiment: vLLM is wonderful!", } ], - "extra_body": {"guided_choice": ["positive", "negative"]}, + "extra_body": {"structured_outputs": {"choice": ["positive", "negative"]}}, }, "regex": { "messages": [ @@ -96,7 +96,7 @@ PARAMS: dict[ConstraintsFormat, dict[str, Any]] = { } ], "extra_body": { - "guided_regex": r"[a-z0-9.]{1,20}@\w{6,10}\.com\n", + "structured_outputs": {"regex": r"[a-z0-9.]{1,20}@\w{6,10}\.com\n"}, }, }, "json": { @@ -122,7 +122,8 @@ PARAMS: dict[ConstraintsFormat, dict[str, Any]] = { } ], "extra_body": { - "guided_grammar": """ + "structured_outputs": { + "grammar": """ root ::= select_statement select_statement ::= "SELECT " column " from " table " where " condition @@ -135,6 +136,7 @@ condition ::= column "= " number number ::= "1 " | "2 " """, + } }, }, "structural_tag": { diff --git a/examples/tool_chat_template_phi4_mini.jinja b/examples/tool_chat_template_phi4_mini.jinja index 83886762c2893..6f40c38c20644 100644 --- a/examples/tool_chat_template_phi4_mini.jinja +++ b/examples/tool_chat_template_phi4_mini.jinja @@ -9,7 +9,7 @@ <|system|> {{ system_message }} {%- if tools %} -In addition to plain text responses, you can chose to call one or more of the provided functions. +In addition to plain text responses, you can choose to call one or more of the provided functions. Use the following rule to decide when to call a function: * if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so @@ -19,7 +19,7 @@ If you decide to call functions: * prefix function calls with functools marker (no closing marker required) * all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...] * follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples - * respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0 + * respect the argument type formatting. E.g., if the type is number and format is float, write value 7 as 7.0 * make sure you pick the right functions that match the user intent diff --git a/find_cuda_init.py b/find_cuda_init.py deleted file mode 100644 index 308fc6fc2d61c..0000000000000 --- a/find_cuda_init.py +++ /dev/null @@ -1,36 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import importlib -import traceback -from typing import Callable -from unittest.mock import patch - - -def find_cuda_init(fn: Callable[[], object]) -> None: - """ - Helper function to debug CUDA re-initialization errors. - - If `fn` initializes CUDA, prints the stack trace of how this happens. 
- """ - from torch.cuda import _lazy_init - - stack = None - - def wrapper(): - nonlocal stack - stack = traceback.extract_stack() - return _lazy_init() - - with patch("torch.cuda._lazy_init", wrapper): - fn() - - if stack is not None: - print("==== CUDA Initialized ====") - print("".join(traceback.format_list(stack)).strip()) - print("==========================") - - -if __name__ == "__main__": - find_cuda_init( - lambda: importlib.import_module("vllm.model_executor.models.llava")) diff --git a/mkdocs.yaml b/mkdocs.yaml index 507a80c41e8b4..6f2be65a18af8 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -79,6 +79,7 @@ plugins: - "re:vllm\\._.*" # Internal modules - "vllm.third_party" - "vllm.vllm_flash_attn" + - !ENV [API_AUTONAV_EXCLUDE, "re:^$"] # Match nothing by default - mkdocstrings: handlers: python: diff --git a/pyproject.toml b/pyproject.toml index e63f8aeae2787..fe55461db00be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,7 +115,6 @@ follow_imports = "silent" # move the directory here and remove it from tools/mypy.sh files = [ "vllm/*.py", - "vllm/adapter_commons", "vllm/assets", "vllm/entrypoints", "vllm/core", @@ -145,6 +144,7 @@ skip_gitignore = true [tool.pytest.ini_options] markers = [ + "slow_test", "skip_global_cleanup", "core_model: enable this model test in each PR instead of only nightly", "hybrid_model: models that contain mamba layers (including pure SSM and hybrid architectures)", @@ -228,6 +228,7 @@ fo = "fo" ba = "ba" [tool.typos.type.py.extend-words] +ba = "ba" [tool.typos.type.cpp] extend-glob = ["*.cu"] @@ -344,3 +345,6 @@ extend-ignore-re = [] windo = "windo" [tool.typos.type.vimscript.extend-words] + +[tool.uv] +no-build-isolation-package = ["torch"] diff --git a/requirements/common.txt b/requirements/common.txt index e21abfb9a30bd..b8665104bd09a 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -20,12 +20,11 @@ prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer == 0.11.3 llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64" -outlines_core == 0.2.10 ; platform_machine != "s390x" -outlines == 0.1.11 ; platform_machine == "s390x" +outlines_core == 0.2.11 # required for outlines backend disk cache diskcache == 5.6.3 lark == 1.2.2 -xgrammar == 0.1.21; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64" +xgrammar == 0.1.23; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64" typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 partial-json-parser # used for parsing partial JSON outputs diff --git a/requirements/kv_connectors.txt b/requirements/kv_connectors.txt index 262675a231206..3b610e0d97363 100644 --- a/requirements/kv_connectors.txt +++ b/requirements/kv_connectors.txt @@ -1 +1,2 @@ -lmcache \ No newline at end of file +lmcache +nixl >= 0.5.1 # Required for disaggregated prefill diff --git a/requirements/neuron.txt b/requirements/neuron.txt deleted file mode 100644 index 7df478eddde3f..0000000000000 --- a/requirements/neuron.txt +++ /dev/null @@ -1,9 +0,0 @@ -# Common dependencies --r common.txt - -# Dependencies for Neuron devices -packaging>=24.2 -setuptools>=77.0.3,<80.0.0 -torch-neuronx >= 2.5.0 -neuronx-cc>=2.0.0a0 -torchvision # Required for Llama3.2 multimodal image preprocessing diff --git a/requirements/rocm-build.txt b/requirements/rocm-build.txt index 
affe562c24f6b..a86a8ab6df149 100644 --- a/requirements/rocm-build.txt +++ b/requirements/rocm-build.txt @@ -14,3 +14,4 @@ setuptools-scm>=8 wheel jinja2>=3.1.6 amdsmi==6.2.4 +timm>=1.0.17 diff --git a/requirements/rocm-test.txt b/requirements/rocm-test.txt index 25f950a99eceb..869fb28c3d85c 100644 --- a/requirements/rocm-test.txt +++ b/requirements/rocm-test.txt @@ -1,5 +1,6 @@ # Common dependencies -r common.txt +tblib==3.1.0 # entrypoints test # librosa==0.10.2.post1 # required by audio tests in entrypoints/openai diff --git a/requirements/rocm.txt b/requirements/rocm.txt index c3bb65b70a0b8..c129dd345c81a 100644 --- a/requirements/rocm.txt +++ b/requirements/rocm.txt @@ -8,7 +8,7 @@ numba == 0.61.2; python_version > '3.9' boto3 botocore datasets -ray>=2.10.0,<2.45.0 +ray[cgraph]>=2.48.0 # Ray Compiled Graph, required for pipeline parallelism in V1. peft pytest-asyncio tensorizer==2.10.1 @@ -17,4 +17,5 @@ setuptools>=77.0.3,<80.0.0 setuptools-scm>=8 runai-model-streamer==0.11.0 runai-model-streamer-s3==0.11.0 -conch-triton-kernels==1.2.1 \ No newline at end of file +conch-triton-kernels==1.2.1 +timm>=1.0.17 \ No newline at end of file diff --git a/requirements/test.in b/requirements/test.in index 5b1688c76c954..451bd73879107 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -6,6 +6,7 @@ pytest-asyncio pytest-rerunfailures pytest-shard pytest-timeout +pytest-cov # testing utils backoff # required for phi4mm test @@ -21,6 +22,7 @@ ray[cgraph,default]>=2.48.0 # Ray Compiled Graph, required by pipeline paralleli sentence-transformers # required for embedding tests soundfile # required for audio tests jiwer # required for audio tests +tblib # for pickling test exceptions timm >=1.0.17 # required for internvl and gemma3n-mm test torch==2.8.0 torchaudio==2.8.0 @@ -53,5 +55,5 @@ runai-model-streamer==0.11.0 runai-model-streamer-s3==0.11.0 fastsafetensors>=0.1.10 pydantic>=2.10 # 2.9 leads to error on python 3.10 -terratorch==1.1rc2 # required for PrithviMAE test decord==0.6.0 +terratorch @ git+https://github.com/IBM/terratorch.git@1.1.rc3 # required for PrithviMAE test diff --git a/requirements/test.txt b/requirements/test.txt index 0b728ebfb0071..39040f210b2fd 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -135,9 +135,11 @@ colorful==0.5.6 # via ray contourpy==1.3.0 # via matplotlib +coverage==7.10.6 + # via pytest-cov cramjam==2.9.0 # via fastparquet -cupy-cuda12x==13.3.0 +cupy-cuda12x==13.6.0 # via ray cycler==0.12.1 # via matplotlib @@ -686,7 +688,9 @@ platformdirs==4.3.6 plotly==5.24.1 # via genai-perf pluggy==1.5.0 - # via pytest + # via + # pytest + # pytest-cov polars==1.29.0 # via mteb pooch==1.8.2 @@ -786,6 +790,7 @@ pytest==8.3.5 # buildkite-test-collector # genai-perf # pytest-asyncio + # pytest-cov # pytest-forked # pytest-mock # pytest-rerunfailures @@ -796,6 +801,8 @@ pytest==8.3.5 # terratorch pytest-asyncio==0.24.0 # via -r requirements/test.in +pytest-cov==6.3.0 + # via -r requirements/test.in pytest-forked==1.6.0 # via -r requirements/test.in pytest-mock==3.14.0 @@ -1032,6 +1039,8 @@ tabledata==1.3.3 # via pytablewriter tabulate==0.9.0 # via sacrebleu +tblib==3.1.0 + # via -r requirements/test.in tcolorpy==0.1.6 # via pytablewriter tenacity==9.0.0 @@ -1042,7 +1051,7 @@ tensorboardx==2.6.4 # via lightning tensorizer==2.10.1 # via -r requirements/test.in -terratorch==1.1rc2 +terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e # via -r requirements/test.in threadpoolctl==3.5.0 # via scikit-learn diff 
--git a/requirements/xpu.txt b/requirements/xpu.txt index 4607c3efdf14c..74f5b05b2382a 100644 --- a/requirements/xpu.txt +++ b/requirements/xpu.txt @@ -10,10 +10,10 @@ wheel jinja2>=3.1.6 datasets # for benchmark scripts numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding ---extra-index-url=https://download.pytorch.org/whl/xpu +nixl==0.3.0 # for PD disaggregation torch==2.8.0+xpu torchaudio torchvision -pytorch-triton-xpu ---extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ -intel-extension-for-pytorch==2.8.10+xpu +--extra-index-url=https://download.pytorch.org/whl/xpu + +intel-extension-for-pytorch @ https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.8.10.post0%2Bxpu-cp312-cp312-linux_x86_64.whl diff --git a/setup.py b/setup.py index ffe8ec4e79af7..e4c40d22b928d 100644 --- a/setup.py +++ b/setup.py @@ -56,8 +56,6 @@ elif (sys.platform.startswith("linux") and torch.version.cuda is None # fallback to cpu VLLM_TARGET_DEVICE = "cpu" -MAIN_CUDA_VERSION = "12.8" - def is_sccache_available() -> bool: return which("sccache") is not None and \ @@ -413,8 +411,7 @@ def _no_device() -> bool: def _is_cuda() -> bool: has_cuda = torch.version.cuda is not None - return (VLLM_TARGET_DEVICE == "cuda" and has_cuda - and not (_is_neuron() or _is_tpu())) + return (VLLM_TARGET_DEVICE == "cuda" and has_cuda and not _is_tpu()) def _is_hip() -> bool: @@ -422,10 +419,6 @@ def _is_hip() -> bool: or VLLM_TARGET_DEVICE == "rocm") and torch.version.hip is not None -def _is_neuron() -> bool: - return VLLM_TARGET_DEVICE == "neuron" - - def _is_tpu() -> bool: return VLLM_TARGET_DEVICE == "tpu" @@ -470,25 +463,6 @@ def get_rocm_version(): return None -def get_neuronxcc_version(): - import sysconfig - site_dir = sysconfig.get_paths()["purelib"] - version_file = os.path.join(site_dir, "neuronxcc", "version", - "__init__.py") - - # Check if the command was executed successfully - with open(version_file) as fp: - content = fp.read() - - # Extract the version using a regular expression - match = re.search(r"__version__ = '(\S+)'", content) - if match: - # Return the version string - return match.group(1) - else: - raise RuntimeError("Could not find Neuron version in the output") - - def get_nvcc_cuda_version() -> Version: """Get the CUDA version from nvcc. 
@@ -531,7 +505,7 @@ def get_vllm_version() -> str: version += f"{sep}precompiled" else: cuda_version = str(get_nvcc_cuda_version()) - if cuda_version != MAIN_CUDA_VERSION: + if cuda_version != envs.VLLM_MAIN_CUDA_VERSION: cuda_version_str = cuda_version.replace(".", "")[:3] # skip this for source tarball, required for pypi if "sdist" not in sys.argv: @@ -539,14 +513,8 @@ def get_vllm_version() -> str: elif _is_hip(): # Get the Rocm Version rocm_version = get_rocm_version() or torch.version.hip - if rocm_version and rocm_version != MAIN_CUDA_VERSION: + if rocm_version and rocm_version != envs.VLLM_MAIN_CUDA_VERSION: version += f"{sep}rocm{rocm_version.replace('.', '')[:3]}" - elif _is_neuron(): - # Get the Neuron version - neuron_version = str(get_neuronxcc_version()) - if neuron_version != MAIN_CUDA_VERSION: - neuron_version_str = neuron_version.replace(".", "")[:3] - version += f"{sep}neuron{neuron_version_str}" elif _is_tpu(): version += f"{sep}tpu" elif _is_cpu(): @@ -591,8 +559,6 @@ def get_requirements() -> list[str]: requirements = modified_requirements elif _is_hip(): requirements = _read_requirements("rocm.txt") - elif _is_neuron(): - requirements = _read_requirements("neuron.txt") elif _is_tpu(): requirements = _read_requirements("tpu.txt") elif _is_cpu(): @@ -601,7 +567,7 @@ def get_requirements() -> list[str]: requirements = _read_requirements("xpu.txt") else: raise ValueError( - "Unsupported platform, please use CUDA, ROCm, Neuron, or CPU.") + "Unsupported platform, please use CUDA, ROCm, or CPU.") return requirements @@ -688,13 +654,15 @@ setup( "bench": ["pandas", "datasets"], "tensorizer": ["tensorizer==2.10.1"], "fastsafetensors": ["fastsafetensors >= 0.1.10"], - "runai": - ["runai-model-streamer >= 0.13.3", "runai-model-streamer-s3", "boto3"], + "runai": [ + "runai-model-streamer >= 0.14.0", "runai-model-streamer-gcs", + "google-cloud-storage", "runai-model-streamer-s3", "boto3" + ], "audio": ["librosa", "soundfile", "mistral_common[audio]"], # Required for audio processing "video": [], # Kept for backwards compatibility # FlashInfer should be updated together with the Dockerfile - "flashinfer": ["flashinfer-python==0.2.14.post1"], + "flashinfer": ["flashinfer-python==0.3.1"], # Optional deps for AMD FP4 quantization support "petit-kernel": ["petit-kernel"], }, diff --git a/tests/async_engine/api_server_async_engine.py b/tests/async_engine/api_server_async_engine.py deleted file mode 100644 index ec6b20f5e04b9..0000000000000 --- a/tests/async_engine/api_server_async_engine.py +++ /dev/null @@ -1,54 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""vllm.entrypoints.api_server with some extra logging for testing.""" -from collections.abc import Iterable -from typing import Any - -import uvicorn -from fastapi.responses import JSONResponse, Response - -import vllm.entrypoints.api_server -import vllm.envs as envs -from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.utils import FlexibleArgumentParser - -app = vllm.entrypoints.api_server.app - - -class AsyncLLMEngineWithStats(AsyncLLMEngine): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._num_aborts = 0 - - async def _engine_abort(self, request_ids: Iterable[str]): - ids = list(request_ids) - self._num_aborts += len(ids) - await super()._engine_abort(ids) - - def testing_stats(self) -> dict[str, Any]: - return {"num_aborted_requests": self._num_aborts} - - 
-@app.get("/stats") -def stats() -> Response: - """Get the statistics of the engine.""" - return JSONResponse(engine.testing_stats()) - - -if __name__ == "__main__": - parser = FlexibleArgumentParser() - parser.add_argument("--host", type=str, default="localhost") - parser.add_argument("--port", type=int, default=8000) - parser = AsyncEngineArgs.add_cli_args(parser) - args = parser.parse_args() - - engine_args = AsyncEngineArgs.from_cli_args(args) - engine = AsyncLLMEngineWithStats.from_engine_args(engine_args) - vllm.entrypoints.api_server.engine = engine - uvicorn.run(app, - host=args.host, - port=args.port, - log_level="debug", - timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE) diff --git a/tests/async_engine/conftest.py b/tests/async_engine/conftest.py deleted file mode 100644 index 375b248ebedaa..0000000000000 --- a/tests/async_engine/conftest.py +++ /dev/null @@ -1,12 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - Since this module is V0 only, set VLLM_USE_V1=0 for - all tests in the module. - """ - monkeypatch.setenv('VLLM_USE_V1', '0') diff --git a/tests/async_engine/test_api_server.py b/tests/async_engine/test_api_server.py deleted file mode 100644 index 90f63e7ea17db..0000000000000 --- a/tests/async_engine/test_api_server.py +++ /dev/null @@ -1,113 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os -import subprocess -import sys -import time -from multiprocessing import Pool -from pathlib import Path - -import pytest -import requests - - -def _query_server(prompt: str, max_tokens: int = 5) -> dict: - response = requests.post("http://localhost:8000/generate", - json={ - "prompt": prompt, - "max_tokens": max_tokens, - "temperature": 0, - "ignore_eos": True - }) - response.raise_for_status() - return response.json() - - -def _query_server_long(prompt: str) -> dict: - return _query_server(prompt, max_tokens=500) - - -@pytest.fixture -def api_server(distributed_executor_backend: str): - script_path = Path(__file__).parent.joinpath( - "api_server_async_engine.py").absolute() - commands = [ - sys.executable, - "-u", - str(script_path), - "--model", - "facebook/opt-125m", - "--host", - "127.0.0.1", - "--distributed-executor-backend", - distributed_executor_backend, - ] - - # API Server Test Requires V0. - my_env = os.environ.copy() - my_env["VLLM_USE_V1"] = "0" - uvicorn_process = subprocess.Popen(commands, env=my_env) - yield - uvicorn_process.terminate() - - -@pytest.mark.parametrize("distributed_executor_backend", ["mp", "ray"]) -def test_api_server(api_server, distributed_executor_backend: str): - """ - Run the API server and test it. - - We run both the server and requests in separate processes. - - We test that the server can handle incoming requests, including - multiple requests at the same time, and that it can handle requests - being cancelled without crashing. 
- """ - with Pool(32) as pool: - # Wait until the server is ready - prompts = ["warm up"] * 1 - result = None - while not result: - try: - for r in pool.map(_query_server, prompts): - result = r - break - except requests.exceptions.ConnectionError: - time.sleep(1) - - # Actual tests start here - # Try with 1 prompt - for result in pool.map(_query_server, prompts): - assert result - - num_aborted_requests = requests.get( - "http://localhost:8000/stats").json()["num_aborted_requests"] - assert num_aborted_requests == 0 - - # Try with 100 prompts - prompts = ["test prompt"] * 100 - for result in pool.map(_query_server, prompts): - assert result - - with Pool(32) as pool: - # Cancel requests - prompts = ["canceled requests"] * 100 - pool.map_async(_query_server_long, prompts) - time.sleep(0.01) - pool.terminate() - pool.join() - - # check cancellation stats - # give it some time to update the stats - time.sleep(1) - - num_aborted_requests = requests.get( - "http://localhost:8000/stats").json()["num_aborted_requests"] - assert num_aborted_requests > 0 - - # check that server still runs after cancellations - with Pool(32) as pool: - # Try with 100 prompts - prompts = ["test prompt after canceled"] * 100 - for result in pool.map(_query_server, prompts): - assert result diff --git a/tests/async_engine/test_request_tracker.py b/tests/async_engine/test_request_tracker.py deleted file mode 100644 index 1851eeeda7905..0000000000000 --- a/tests/async_engine/test_request_tracker.py +++ /dev/null @@ -1,71 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -from vllm.engine.async_llm_engine import RequestTracker -from vllm.outputs import RequestOutput - - -@pytest.mark.asyncio -async def test_request_tracker(): - tracker = RequestTracker() - stream_1 = tracker.add_request("1") - assert tracker.new_requests_event.is_set() - await tracker.wait_for_new_requests() - new, aborted = tracker.get_new_and_aborted_requests() - assert not tracker.new_requests_event.is_set() - assert len(new) == 1 - assert new[0]["request_id"] == "1" - assert not aborted - assert not stream_1.finished - - stream_2 = tracker.add_request("2") - stream_3 = tracker.add_request("3") - assert tracker.new_requests_event.is_set() - await tracker.wait_for_new_requests() - new, aborted = tracker.get_new_and_aborted_requests() - assert not tracker.new_requests_event.is_set() - assert len(new) == 2 - assert new[0]["request_id"] == "2" - assert new[1]["request_id"] == "3" - assert not aborted - assert not stream_2.finished - assert not stream_3.finished - - # request_ids must be unique - with pytest.raises(KeyError): - tracker.add_request("1") - assert not tracker.new_requests_event.is_set() - - tracker.abort_request("1") - new, aborted = tracker.get_new_and_aborted_requests() - assert len(aborted) == 1 - assert "1" in aborted - assert not new - assert stream_1.finished - - stream_4 = tracker.add_request("4") - tracker.abort_request("4") - assert tracker.new_requests_event.is_set() - await tracker.wait_for_new_requests() - new, aborted = tracker.get_new_and_aborted_requests() - # aborted new requests will cancel each other out - - # there's no need for them to propagate into the - # engine - assert not aborted - assert not new - assert stream_4.finished - - stream_5 = tracker.add_request("5") - assert tracker.new_requests_event.is_set() - tracker.process_request_output( - RequestOutput("2", "output", [], [], [], finished=True)) - await 
tracker.wait_for_new_requests() - new, aborted = tracker.get_new_and_aborted_requests() - assert not tracker.new_requests_event.is_set() - assert not aborted - assert len(new) == 1 - assert new[0]["request_id"] == "5" - assert stream_2.finished - assert not stream_5.finished diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index a3b09cc817917..fba18f197074b 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -62,6 +62,8 @@ def _fix_prompt_embed_outputs( @pytest.mark.parametrize("backend", ["FLASH_ATTN"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("enforce_eager", [False]) +@pytest.mark.parametrize("async_scheduling", [True, False]) +@pytest.mark.parametrize("model_executor", ["uni", "mp"]) @pytest.mark.parametrize("enable_prompt_embeds", [True, False]) def test_models( monkeypatch: pytest.MonkeyPatch, @@ -70,6 +72,8 @@ def test_models( backend: str, max_tokens: int, enforce_eager: bool, + async_scheduling: bool, + model_executor: str, enable_prompt_embeds: bool, ) -> None: @@ -77,6 +81,12 @@ def test_models( "VLLM_USE_V1") and envs.VLLM_USE_V1: pytest.skip("enable_prompt_embeds is not supported in v1.") + if not envs.VLLM_USE_V1: + if async_scheduling: + pytest.skip("async_scheduling only supported in v1.") + if model_executor != "uni": + pytest.skip("only test uniproc executor for v0.") + if backend == "XFORMERS" and model == "google/gemma-2-2b-it": pytest.skip( f"{backend} does not support gemma2 with full context length.") @@ -98,11 +108,15 @@ def test_models( prompt_embeds = hf_model.get_prompt_embeddings( example_prompts) - with VllmRunner(model, - max_model_len=8192, - enforce_eager=enforce_eager, - enable_prompt_embeds=enable_prompt_embeds, - gpu_memory_utilization=0.7) as vllm_model: + with VllmRunner( + model, + max_model_len=8192, + enforce_eager=enforce_eager, + enable_prompt_embeds=enable_prompt_embeds, + gpu_memory_utilization=0.7, + async_scheduling=async_scheduling, + distributed_executor_backend=model_executor, + ) as vllm_model: if enable_prompt_embeds: vllm_outputs = vllm_model.generate_greedy( prompt_embeds, max_tokens) diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py deleted file mode 100644 index db2fa2f6bef6f..0000000000000 --- a/tests/basic_correctness/test_preemption.py +++ /dev/null @@ -1,189 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Compare the short outputs of HF and vLLM when using greedy sampling. - -VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 has to be set before running this test. - -Run `VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 -pytest tests/basic_correctness/test_preemption.py`. -""" -import pytest -from prometheus_client import REGISTRY - -import vllm.envs as envs -from vllm import SamplingParams -from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT, - ENABLE_ARTIFICIAL_PREEMPT) - -from ..models.utils import check_outputs_equal - -MODELS = [ - "distilbert/distilgpt2", -] - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - We should enable this for V1, but VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT, - so use VLLM_USE_V1=0 for all tests in the file. 
- """ - monkeypatch.setenv('VLLM_USE_V1', '0') - - -@pytest.fixture(scope="module", autouse=True) -def check_settings(): - assert ENABLE_ARTIFICIAL_PREEMPT is True, ( - "Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1." - "`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 " - "pytest tests/basic_correctness/test_preemption.py`") - - -@pytest.fixture -def distributed_executor_backend() -> str: - # When SPMD worker is used, use distributed_executor_backend="ray" - # to test delta input optimization works with preemption. - return "ray" if envs.VLLM_USE_RAY_SPMD_WORKER else "mp" - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [96]) -@pytest.mark.parametrize("chunked_prefill_token_size", [16]) -def test_chunked_prefill_recompute( - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - chunked_prefill_token_size: int, - distributed_executor_backend: str, -) -> None: - """Ensure that chunked prefill works with preemption.""" - max_num_seqs = min(chunked_prefill_token_size, 256) - enable_chunked_prefill = False - max_num_batched_tokens = None - if chunked_prefill_token_size != -1: - enable_chunked_prefill = True - max_num_batched_tokens = chunked_prefill_token_size - - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - - with vllm_runner( - model, - dtype=dtype, - max_num_batched_tokens=max_num_batched_tokens, - enable_chunked_prefill=enable_chunked_prefill, - max_num_seqs=max_num_seqs, - distributed_executor_backend=distributed_executor_backend, - disable_log_stats=False, - ) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - assert (vllm_model.llm.llm_engine.scheduler[0].artificial_preempt_cnt - < ARTIFICIAL_PREEMPTION_MAX_CNT) - - for i in range(len(example_prompts)): - hf_output_ids, hf_output_str = hf_outputs[i] - vllm_output_ids, vllm_output_str = vllm_outputs[i] - assert hf_output_str == vllm_output_str, ( - f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") - assert hf_output_ids == vllm_output_ids, ( - f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [96]) -def test_preemption( - caplog_vllm, - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - distributed_executor_backend: str, -) -> None: - """By default, recompute preemption is enabled""" - - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - - with vllm_runner( - model, - dtype=dtype, - disable_log_stats=False, - distributed_executor_backend=distributed_executor_backend, - ) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - assert (vllm_model.llm.llm_engine.scheduler[0].artificial_preempt_cnt - < ARTIFICIAL_PREEMPTION_MAX_CNT) - total_preemption = ( - vllm_model.llm.llm_engine.scheduler[0].num_cumulative_preemption) - - check_outputs_equal( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) - - assert ("is preempted by PreemptionMode.RECOMPUTE mode because there " - "is not enough KV cache space." 
in caplog_vllm.text) - # Ensure the count bucket of request-level histogram metrics matches - # the number of requests as a simple sanity check to ensure metrics are - # generated - preemption_metrics = None - for m in REGISTRY.collect(): - if m.name == "vllm:num_preemptions": - preemption_metrics = m - assert preemption_metrics is not None - total_recorded_preemption = 0 - for sample in preemption_metrics.samples: - total_recorded_preemption += sample.value - assert total_preemption == total_recorded_preemption - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [96]) -def test_preemption_infeasible( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - distributed_executor_backend: str, -) -> None: - """Verify infeasible preemption request will be ignored.""" - BLOCK_SIZE = 16 - prefill_blocks = 2 - decode_blocks = max_tokens // BLOCK_SIZE - with vllm_runner( - model, - dtype=dtype, - block_size=BLOCK_SIZE, - # Not enough gpu blocks to complete a single sequence. - # preemption should happen, and the sequence should be - # ignored instead of hanging forever. - num_gpu_blocks_override=prefill_blocks + decode_blocks // 2, - max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE), - distributed_executor_backend=distributed_executor_backend, - ) as vllm_model: - sampling_params = SamplingParams(max_tokens=max_tokens, - ignore_eos=True) - req_outputs = vllm_model.llm.generate( - example_prompts, - sampling_params=sampling_params, - ) - - assert (vllm_model.llm.llm_engine.scheduler[0].artificial_preempt_cnt - < ARTIFICIAL_PREEMPTION_MAX_CNT) - - # Verify the request is ignored and not hang. - for req_output in req_outputs: - outputs = req_output.outputs - assert len(outputs) == 1 - assert outputs[0].finish_reason == "length" diff --git a/tests/benchmarks/test_serve_cli.py b/tests/benchmarks/test_serve_cli.py index bfcf274727e27..fafbef5f37180 100644 --- a/tests/benchmarks/test_serve_cli.py +++ b/tests/benchmarks/test_serve_cli.py @@ -45,3 +45,34 @@ def test_bench_serve(server): print(result.stderr) assert result.returncode == 0, f"Benchmark failed: {result.stderr}" + +@pytest.mark.benchmark +def test_bench_serve_chat(server): + command = [ + "vllm", + "bench", + "serve", + "--model", + MODEL_NAME, + "--host", + server.host, + "--port", + str(server.port), + "--dataset-name", + "random", + "--random-input-len", + "32", + "--random-output-len", + "4", + "--num-prompts", + "5", + "--endpoint", + "/v1/chat/completions", + "--backend", + "openai-chat", + ] + result = subprocess.run(command, capture_output=True, text=True) + print(result.stdout) + print(result.stderr) + + assert result.returncode == 0, f"Benchmark failed: {result.stderr}" diff --git a/tests/ci_envs.py b/tests/ci_envs.py new file mode 100644 index 0000000000000..d16ecce1ef8dd --- /dev/null +++ b/tests/ci_envs.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +These envs only work for a small part of the tests, fix what you need! +""" + +import os +from typing import TYPE_CHECKING, Any, Callable, Optional + +if TYPE_CHECKING: + VLLM_CI_NO_SKIP: bool = False + VLLM_CI_DTYPE: Optional[str] = None + VLLM_CI_HEAD_DTYPE: Optional[str] = None + VLLM_CI_HF_DTYPE: Optional[str] = None + +environment_variables: dict[str, Callable[[], Any]] = { + # A model family has many models with the same architecture. 
+ # By default, a model family tests only one model. + # Through this flag, all models can be tested. + "VLLM_CI_NO_SKIP": lambda: bool(int(os.getenv("VLLM_CI_NO_SKIP", "0"))), + # Allow changing the dtype used by vllm in tests + "VLLM_CI_DTYPE": lambda: os.getenv("VLLM_CI_DTYPE", None), + # Allow changing the head dtype used by vllm in tests + "VLLM_CI_HEAD_DTYPE": lambda: os.getenv("VLLM_CI_HEAD_DTYPE", None), + # Allow changing the head dtype used by transformers in tests + "VLLM_CI_HF_DTYPE": lambda: os.getenv("VLLM_CI_HF_DTYPE", None), +} + + +def __getattr__(name: str): + # lazy evaluation of environment variables + if name in environment_variables: + return environment_variables[name]() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__(): + return list(environment_variables.keys()) + + +def is_set(name: str): + """Check if an environment variable is explicitly set.""" + if name in environment_variables: + return name in os.environ + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/tests/compile/backend.py b/tests/compile/backend.py index ace4d25534cdd..2c4287950dcfe 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -64,4 +64,8 @@ class TestBackend: num_pre = len(list(find_op_nodes(op, self.graph_pre_pass))) num_post = len(list(find_op_nodes(op, self.graph_post_pass))) assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph" - assert num_post > 0, f"Op {op.name()} not found in post-pass graph" \ No newline at end of file + assert num_post > 0, f"Op {op.name()} not found in post-pass graph" + + def op_count(self, op: OpOverload, before=False) -> int: + graph = self.graph_pre_pass if before else self.graph_post_pass + return len(list(find_op_nodes(op, graph))) diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py index 97140a9db7af6..2454f85342eba 100644 --- a/tests/compile/piecewise/test_full_cudagraph.py +++ b/tests/compile/piecewise/test_full_cudagraph.py @@ -61,6 +61,16 @@ backend_configs = { "cudagraph_mode": "FULL_AND_PIECEWISE", }, specific_gpu_arch=(9, 0)), + # FlashAttention MLA on Hopper + "FlashAttentionMLA": + BackendConfig(name="FlashAttentionMLA", + env_vars={ + "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA", + }, + comp_config={ + "cudagraph_mode": "FULL_DECODE_ONLY", + }, + specific_gpu_arch=(9, 0)), # Cutlass MLA on Blackwell "CutlassMLA": BackendConfig( @@ -102,7 +112,7 @@ backend_configs = { test_params_full_cudagraph = [] # deepseek-ai/DeepSeek-V2-Lite with MLA -MLA_backends = ["FlashMLA", "CutlassMLA"] +MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"] for mla_backend in MLA_backends: test_params_full_cudagraph.append( pytest.param( diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py index f5e2d9ddb7528..5cfebfce9ea2a 100644 --- a/tests/compile/piecewise/test_multiple_graphs.py +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -4,9 +4,9 @@ Test (piecewise) compilation with a simple model where multiple submodules are compiled and graph captured separately. 
""" + import torch from torch import nn -from torch.library import Library from vllm.compilation.backends import set_model_tag from vllm.compilation.counter import compilation_counter @@ -15,10 +15,9 @@ from vllm.compilation.decorators import (ignore_torch_compile, from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, VllmConfig, set_current_vllm_config) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import direct_register_custom_op -# create a library to hold the custom op -silly_lib = Library("silly", "FRAGMENT") # noqa +# This import automatically registers `torch.ops.silly.attention` +from .. import silly_attention # noqa: F401 BATCH_SIZE = 32 MLP_SIZE = 128 @@ -26,27 +25,6 @@ HIDDEN_SIZE = 1024 RANDOM_SEED = 0 -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - out.copy_(q) - out += k - out += v - - -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - return - - -direct_register_custom_op( - op_name="attention", - op_func=silly_attention, - mutates_args=["out"], - fake_impl=silly_attention_fake, - target_lib=silly_lib, -) - - @support_torch_compile class ParentModel(nn.Module): @@ -134,7 +112,7 @@ class SimpleModelWithTwoGraphs(ParentModel): # Test will fail without set_model_tag here with error: # "ValueError: too many values to unpack (expected 3)" # This is because CompiledAttention and CompiledAttentionTwo - # have different implmentations but the same torch.compile + # have different implementations but the same torch.compile # cache dir will be used as default prefix is 'model_tag' with set_model_tag("attn_one"): self.attn_one = CompiledAttention( diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 2d1a72d44ec70..84f4945c82725 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -4,10 +4,10 @@ Test the piecewise compilation with a simple model so that we can exactly calculate the expected output and side effects. 
""" + import pytest import torch from torch import nn -from torch.library import Library from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile @@ -15,35 +15,9 @@ from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, VllmConfig, set_current_vllm_config) from vllm.envs import VLLM_USE_V1 from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import direct_register_custom_op -global_counter = 0 - -# create a library to hold the custom op -silly_lib = Library("silly", "FRAGMENT") # noqa - - -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - global global_counter - global_counter += 1 - print(f"{global_counter=}") - out.copy_(q) - out[0] += 1 - - -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - return - - -direct_register_custom_op( - op_name="attention", - op_func=silly_attention, - mutates_args=["out"], - fake_impl=silly_attention_fake, - target_lib=silly_lib, -) +# This import automatically registers `torch.ops.silly.attention` +from ..silly_attention import get_global_counter, reset_global_counter @support_torch_compile @@ -59,8 +33,7 @@ class SillyModel(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: """ Overall effect: - x += 1 - x[0] += 2 + x = 3 * x + 19 global_counter += 2 """ x = x + 1 @@ -78,6 +51,7 @@ class SillyModel(nn.Module): @pytest.mark.parametrize("use_inductor", [True, False]) +@torch.inference_mode() def test_simple_piecewise_compile(use_inductor): assert VLLM_USE_V1 @@ -121,13 +95,12 @@ def test_simple_piecewise_compile(use_inductor): model(torch.randn(1).cuda()) input = torch.zeros(2).cuda() - global global_counter - global_counter = 0 + reset_global_counter() with set_forward_context( None, vllm_config=vllm_config, cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, batch_descriptor=BatchDescriptor(num_tokens=2, )): output = model(input) - assert global_counter == 2 - assert torch.allclose(output.cpu(), torch.tensor([3., 1.])) + assert get_global_counter() == 2 + assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0])) diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index bcfd0d834c5db..cba7517647e51 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -14,38 +14,15 @@ from typing import Any, Optional import pytest import torch from torch import nn -from torch.library import Library from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, VllmConfig, set_current_vllm_config) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import direct_register_custom_op -# create a library to hold the custom op -silly_lib = Library("silly", "FRAGMENT") # noqa - - -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - out.copy_(q) - out += k - out += v - - -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - return - - -direct_register_custom_op( - op_name="attention", - op_func=silly_attention, - mutates_args=["out"], - fake_impl=silly_attention_fake, - target_lib=silly_lib, -) +# This import automatically registers `torch.ops.silly.attention` +from .. 
import silly_attention # noqa: F401 @dataclass diff --git a/tests/compile/silly_attention.py b/tests/compile/silly_attention.py new file mode 100644 index 0000000000000..13eb0bf4b1fa1 --- /dev/null +++ b/tests/compile/silly_attention.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Shared PyTorch custom silly attention for compilation tests. +Centralizes custom operation definitions to avoid duplicate registrations. +""" + +import torch +from torch.library import Library + +from vllm.utils import direct_register_custom_op + +# Shared library for all compilation test operations +# Using "silly" namespace to match existing test expectations +# import this file will automatically register +# torch ops for testing (like silly.attention) +silly_lib = Library("silly", "FRAGMENT") + +# Global counter that counts the number of times attention is invoked +_global_counter = 0 + + +def get_global_counter(): + """Get the current global counter value""" + return _global_counter + + +def reset_global_counter(): + """Reset the global counter to 0""" + global _global_counter + _global_counter = 0 + + +def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + """ + Unified attention implementation that depends on + all inputs and affects the output. + Always increments a global counter that tests can use or ignore. + """ + global _global_counter + + # Always increment the global counter + _global_counter += 1 + + # Unified implementation that depends on all inputs + out.copy_(q + k + v) + + +def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + """Fake implementation for testing""" + return + + +# Register the unified attention operation +direct_register_custom_op( + op_name="attention", + op_func=silly_attention, + mutates_args=["out"], + fake_impl=silly_attention_fake, + target_lib=silly_lib, +) diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 422cb94b036ca..fd2b1866e62e1 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -23,7 +23,7 @@ class TestSetting: fullgraph: bool -# we cannot afford testing the full Catesian product +# we cannot afford testing the full Cartesian product # of all models and all levels @pytest.mark.parametrize( "test_setting", @@ -62,8 +62,12 @@ class TestSetting: TestSetting( model="BAAI/bge-multilingual-gemma2", model_args=[ - "--runner", "pooling", "--dtype", "bfloat16", - "--max-model-len", "2048" + "--runner", + "pooling", + "--dtype", + "bfloat16", + "--max-model-len", + "2048", ], pp_size=1, tp_size=1, @@ -71,17 +75,15 @@ class TestSetting: method="encode", fullgraph=True, ), - # TODO: bert models are not supported in V1 yet - # # encoder-based embedding model (BERT) - # TestSetting( - # model="BAAI/bge-base-en-v1.5", - # model_args=["--runner", "pooling"], - # pp_size=1, - # tp_size=1, - # attn_backend="XFORMERS", - # method="encode", - # fullgraph=True, - # ), + TestSetting( + model="BAAI/bge-base-en-v1.5", + model_args=["--runner", "pooling"], + pp_size=1, + tp_size=1, + attn_backend="FLASH_ATTN", + method="encode", + fullgraph=True, + ), # vision language model TestSetting( model="microsoft/Phi-3.5-vision-instruct", @@ -92,7 +94,8 @@ class TestSetting: method="generate_with_image", fullgraph=False, ), - ]) + ], +) def test_compile_correctness( monkeypatch: pytest.MonkeyPatch, 
test_setting: TestSetting, diff --git a/tests/compile/test_decorator.py b/tests/compile/test_decorator.py index 51f8ddd566d56..d73586d53ff3e 100644 --- a/tests/compile/test_decorator.py +++ b/tests/compile/test_decorator.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch from torch import nn -from torch.library import Library from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import (ignore_torch_compile, @@ -10,36 +9,14 @@ from vllm.compilation.decorators import (ignore_torch_compile, from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, CUDAGraphMode, VllmConfig, set_current_vllm_config) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import direct_register_custom_op -# create a library to hold the custom op -silly_lib = Library("silly", "FRAGMENT") # noqa +# This import automatically registers `torch.ops.silly.attention` +from . import silly_attention # noqa: F401 BATCH_SIZE = 32 MLP_SIZE = 128 -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - out.copy_(q) - out += k - out += v - - -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - return - - -direct_register_custom_op( - op_name="attention", - op_func=silly_attention, - mutates_args=["out"], - fake_impl=silly_attention_fake, - target_lib=silly_lib, -) - - @torch.inference_mode def run_model(vllm_config: VllmConfig, model: nn.Module, cudagraph_runtime_mode: CUDAGraphMode): @@ -151,7 +128,7 @@ def test_ignore_torch_compile_decorator(): run_model(vllm_config, mod_C, cudagraph_runtime_mode) -# Only enable torch.compile if +# Only enable torch.compile if # vllm_config.cache_config.kv_sharing_fast_prefill=True @support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config. kv_sharing_fast_prefill) @@ -173,7 +150,7 @@ class B(nn.Module): return x -# Only enable torch.compile if +# Only enable torch.compile if # vllm_config.cache_config.kv_sharing_fast_prefill=False @support_torch_compile(enable_if=lambda vllm_config: not vllm_config. 
cache_config.kv_sharing_fast_prefill) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index c4229f93464ac..eedb9bdcd5299 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -15,9 +15,10 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, QuantKey, ScaleDesc) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, maybe_create_device_identity) + Fp8LinearOp, cutlass_fp8_supported, maybe_create_device_identity) from vllm.platforms import current_platform +from ..utils import override_cutlass_fp8_supported from .backend import TestBackend FP8_DTYPE = current_platform.fp8_dtype() @@ -26,9 +27,9 @@ FP8_DTYPE = current_platform.fp8_dtype() class TestModel(torch.nn.Module): def __init__(self, hidden_size: int, eps: float, static: bool, - force_fp8_e4m3fnuz: bool, *args, **kwargs): + cuda_force_torch: bool, *args, **kwargs): super().__init__(*args, **kwargs) - self.force_fp8_e4m3fnuz = force_fp8_e4m3fnuz + self.cuda_force_torch = cuda_force_torch self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN @@ -42,11 +43,12 @@ class TestModel(torch.nn.Module): torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() for _ in range(2) ] - self.fp8_linear = Fp8LinearOp( - force_fp8_e4m3fnuz=force_fp8_e4m3fnuz, - act_quant_static=static, - act_quant_group_shape=group_shape, - ) + + with override_cutlass_fp8_supported(not cuda_force_torch): + self.fp8_linear = Fp8LinearOp( + act_quant_static=static, + act_quant_group_shape=group_shape, + ) def forward(self, x): resid = torch.sqrt(x) @@ -81,11 +83,14 @@ class TestModel(torch.nn.Module): @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("static", [True, False]) -@pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False]) +# cuda_force_torch used to test torch code path on platforms that +# cutlass_fp8_supported() == True. 
+@pytest.mark.parametrize("cuda_force_torch", + [True, False] if cutlass_fp8_supported() else [True]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm") def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, - force_fp8_e4m3fnuz): + cuda_force_torch): torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) @@ -102,7 +107,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, fusion_pass = FusionPass.instance(vllm_config) backend = TestBackend(noop_pass, fusion_pass) - model = TestModel(hidden_size, eps, static, force_fp8_e4m3fnuz) + model = TestModel(hidden_size, eps, static, cuda_force_torch) # First dimension dynamic x = torch.rand(num_tokens, hidden_size) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index dba668cfa16a6..6baf4bf83f499 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -40,13 +40,12 @@ backend_unfused: Optional[TestBackend] = None @pytest.mark.parametrize( "model, quant_key", [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)]) -@pytest.mark.parametrize( - "use_triton_fa", [True, False] if current_platform.is_rocm() else [False]) +@pytest.mark.parametrize("use_triton_fa", [True, False]) @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") -@pytest.mark.skipif(not current_platform.is_cuda_alike(), - reason="Only test CUDA and ROCm") -def test_attention_fusion(example_prompts, monkeypatch, model: str, - quant_key: QuantKey, use_triton_fa: bool): +@pytest.mark.skipif(not current_platform.is_rocm(), + reason="V0 attn quant fusion only on ROCm") +def test_attention_fusion_v0(example_prompts, monkeypatch, model: str, + quant_key: QuantKey, use_triton_fa: bool): # Clean Dynamo cache to avoid reusing other test cases # (for some reason the reset at the end is not enough) torch._dynamo.reset() @@ -69,13 +68,17 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, backend="tests.compile.test_fusion_attn.backend_unfused", custom_ops=["+quant_fp8"], ) - vllm_config = VllmConfig(compilation_config=compile_config) + vllm_config = VllmConfig(compilation_config=compile_config, + model_config=ModelConfig( + model=model, + dtype=torch.bfloat16, + )) backend_unfused = TestBackend(NoOpEliminationPass(vllm_config)) llm = LLM(model, enforce_eager=True, compilation_config=compile_config, - gpu_memory_utilization=0.9, + gpu_memory_utilization=0.5, max_model_len=2048) sampling_params = SamplingParams(temperature=0.0, @@ -93,7 +96,11 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, backend="tests.compile.test_fusion_attn.backend", custom_ops=["+quant_fp8"], ) - vllm_config = VllmConfig(compilation_config=compile_config) + vllm_config = VllmConfig(compilation_config=compile_config, + model_config=ModelConfig( + model=model, + dtype=torch.bfloat16, + )) # AttnFusionPass needs attention layers to be registered in config upon init # so we initialize it during compilation. 
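As a sketch of the deferred-initialization idea described in the comment above (the fusion pass is built at compile time, once the model's attention layers are registered in the config), a module-level compile backend can construct its passes lazily on first call. The names below follow the surrounding tests (`TestBackend`, `NoOpEliminationPass`, `AttnFusionPass`), but the exact wiring, the `AttnFusionPass` import path, and the module-level `vllm_config` are assumptions for illustration, not the file's actual code.

    # Hedged sketch only: shows one way to defer pass construction until
    # torch.compile actually invokes the backend.
    from typing import Optional

    import torch

    from vllm.compilation.noop_elimination import NoOpEliminationPass
    # Import path assumed; the pass class is named in the comment above.
    from vllm.compilation.fusion_attn import AttnFusionPass
    from vllm.config import VllmConfig

    from .backend import TestBackend

    # Assumed to be assigned by the test before the model is compiled.
    vllm_config: Optional[VllmConfig] = None

    # Filled in lazily on the first compile call, so the passes are created
    # only after the attention layers exist in `vllm_config`.
    backend_instance: Optional[TestBackend] = None


    def backend(graph: torch.fx.GraphModule, example_inputs):
        global backend_instance
        if backend_instance is None:
            backend_instance = TestBackend(NoOpEliminationPass(vllm_config),
                                           AttnFusionPass(vllm_config))
        return backend_instance(graph, example_inputs)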
@@ -102,7 +109,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, llm2 = LLM(model, enforce_eager=True, compilation_config=compile_config, - gpu_memory_utilization=0.9, + gpu_memory_utilization=0.5, max_model_len=2048) # check support @@ -171,6 +178,8 @@ class AttentionQuantPatternModel(torch.nn.Module): cache_config=vllm_config.cache_config, prefix="model.layers.0.self_attn.attn", ) + self.attn._k_scale = self.attn._k_scale.to(device) + self.attn._v_scale = self.attn._v_scale.to(device) self.block_size = 16 @@ -188,7 +197,7 @@ class AttentionQuantPatternModel(torch.nn.Module): device=self.device, ) - def build_attn_metadata(self, batch_size: int): + def build_attn_metadata(self, batch_size: int, use_hnd: bool): """Initialize attention metadata.""" # Create common attn metadata @@ -205,10 +214,8 @@ class AttentionQuantPatternModel(torch.nn.Module): num_blocks = batch_size * max_blocks # Create dummy KV cache for FlashInfer TRTLLM - # - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size] - # - HND: [num_blocks, 2, num_kv_heads, block_size, head_size] - # Create kv_cache in HND layout and permute to NHD layout - # (later will be permuted back to HND layout in forward pass) + # - NHD: [num_blocks, block_size, num_kv_heads, head_size] + # - HND: [num_blocks, num_kv_heads, block_size, head_size] kv_cache = torch.zeros(num_blocks, 2, self.num_kv_heads, @@ -216,7 +223,17 @@ class AttentionQuantPatternModel(torch.nn.Module): self.head_size, dtype=self.kv_cache_dtype, device=self.device) - kv_cache = kv_cache.permute(0, 1, 3, 2, 4) + if current_platform.is_rocm(): + # k/v as 1st dimension + if use_hnd: + kv_cache = kv_cache.permute(1, 0, 2, 3, 4) + else: + kv_cache = kv_cache.permute(1, 0, 3, 2, 4) + else: + # k/v as 2nd dimension + # Create kv_cache in HND layout and permute to NHD layout + # (later will be permuted back to HND layout in forward pass) + kv_cache = kv_cache.permute(0, 1, 3, 2, 4) self.attn.kv_cache = [kv_cache] # Build attn metadata @@ -296,28 +313,51 @@ class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel): out_dtype=attn_output.dtype) -@pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)]) +if current_platform.is_cuda(): + MODELS = [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", + TestAttentionFp8StaticQuantPatternModel), + ("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", + TestAttentionNvfp4QuantPatternModel)] + HEADS = [(64, 8), (40, 8)] +elif current_platform.is_rocm(): + MODELS = [("amd/Llama-3.1-8B-Instruct-FP8-KV", + TestAttentionFp8StaticQuantPatternModel)] + HEADS = [(32, 8), (40, 8)] +else: + MODELS = [] + HEADS = [] + + +@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS) @pytest.mark.parametrize("head_size", [128]) -@pytest.mark.parametrize("batch_size", [7, 256, 533]) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("model_name, model_class", - [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", - TestAttentionFp8StaticQuantPatternModel), - ("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", - TestAttentionNvfp4QuantPatternModel)]) -@pytest.mark.parametrize("backend", [_Backend.FLASHINFER]) -@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") +@pytest.mark.parametrize("batch_size", + [7, 256, 533] if current_platform.is_cuda() else [8]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("model_name, model_class", MODELS) +@pytest.mark.parametrize("backend", [_Backend.FLASHINFER] if + current_platform.is_cuda()
else [_Backend.ROCM_FLASH]) +@pytest.mark.parametrize( + "split_attention", + [False, True] if current_platform.is_rocm() else [False]) +@pytest.mark.skipif(not current_platform.is_cuda_alike(), + reason="Only test ROCm or CUDA") @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") -@pytest.mark.skipif(not current_platform.is_device_capability((10, 0)), - reason="Only test on SM100(Blackwell)") +@pytest.mark.skipif(current_platform.is_cuda() + and not current_platform.is_device_capability((10, 0)), + reason="On CUDA only test on SM100(Blackwell)") +@pytest.mark.skipif(not current_platform.is_cuda_alike(), + reason="Only test ROCm or CUDA") def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, head_size: int, batch_size: int, dtype: torch.dtype, model_name: str, model_class: type[AttentionQuantPatternModel], - backend: _Backend, monkeypatch, dist_init): + backend: _Backend, split_attention: bool, + monkeypatch, dist_init): """Test AttentionStaticQuantPattern fusion pass""" monkeypatch.setenv("VLLM_USE_V1", "1") + if split_attention: + monkeypatch.setenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "1") device = torch.device("cuda:0") torch.manual_seed(42) @@ -326,6 +366,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, model_config=ModelConfig( model=model_name, max_model_len=2048, + dtype=dtype, ), scheduler_config=SchedulerConfig(max_num_seqs=1024), compilation_config=CompilationConfig( @@ -368,7 +409,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, forward_ctx = get_forward_context() forward_ctx.attn_metadata = model_unfused.build_attn_metadata( - batch_size) + batch_size, use_hnd=split_attention) # Run model directly without compilation and fusion result_unfused = model_unfused(q, k, v) @@ -389,7 +430,8 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, model_fused = model_fused.to(device) forward_ctx = get_forward_context() - forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size) + forward_ctx.attn_metadata = model_fused.build_attn_metadata( + batch_size, use_hnd=split_attention) # Create test backend with fusion passes enabled noop_pass = NoOpEliminationPass(vllm_config) @@ -404,12 +446,19 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, assert model_compiled.attn._o_scale_float is None result_fused_1 = model_compiled(q, k, v) - # After the 1st round of the forward pass, output quant scale should be - # loaded into the attn layer's _o_scale_float, the 2nd round should - # reuse the loaded _o_scale_float - assert model_compiled.attn._o_scale_float is not None - result_fused_2 = model_compiled(q, k, v) - assert model_compiled.attn._o_scale_float is not None + if backend == _Backend.FLASHINFER: + # With the Flashinfer backend after the 1st round of the forward + # pass, output quant scale should be loaded into the attn layer's + # _o_scale_float, the 2nd round should reuse the loaded + # _o_scale_float + assert model_compiled.attn._o_scale_float is not None + result_fused_2 = model_compiled(q, k, v) + assert model_compiled.attn._o_scale_float is not None + + torch.testing.assert_close(result_unfused, + result_fused_2, + atol=1e-2, + rtol=1e-2) # Check attn fusion support quant_key = model_class.quant_key @@ -444,12 +493,8 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \ "Attention should have output_block_scale after FP4 fusion" # noqa: 
E501 - # Check that results are closed + # Check that results are close torch.testing.assert_close(result_unfused, result_fused_1, atol=1e-2, rtol=1e-2) - torch.testing.assert_close(result_unfused, - result_fused_2, - atol=1e-2, - rtol=1e-2) diff --git a/tests/compile/test_noop_elimination.py b/tests/compile/test_noop_elimination.py new file mode 100644 index 0000000000000..242d531312675 --- /dev/null +++ b/tests/compile/test_noop_elimination.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +import vllm +from vllm.compilation.noop_elimination import NoOpEliminationPass +from vllm.config import (CompilationConfig, CompilationLevel, PassConfig, + VllmConfig) + +from .backend import TestBackend + + +@pytest.mark.parametrize("dtype", + [torch.float16, torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("num_tokens", [256, 1024]) +@pytest.mark.parametrize("hidden_size", [64, 4096]) +def test_noop_elimination(dtype, num_tokens, hidden_size): + torch.set_default_device("cuda") + torch.set_default_dtype(dtype) + torch.manual_seed(1) + + class Model(torch.nn.Module): + + def forward(self, x): + # Chain of reshapes + y = x.reshape(-1, 128, 32) + z = y.reshape(-1, 4096) + # No-op reshape + a = z.reshape(-1, 4096) + # Final reshape that should remain + b = a.reshape(-1, 128, 32) + # No-op slice + c = b[0:b.shape[0]] + # The pass should replace the result of this op with `c`. + d = torch.slice_scatter( + torch.ones_like(c), # Dummy tensor to be scattered into + c, # Source tensor + 0, # dim + 0, # start + c.shape[0], # end + ) + return d + + vllm_config = VllmConfig(compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + pass_config=PassConfig(enable_noop=True), + )) + with vllm.config.set_current_vllm_config(vllm_config): + noop_pass = NoOpEliminationPass(vllm_config) + + backend = TestBackend(noop_pass) + + model = Model() + # First dimension dynamic + x = torch.rand(num_tokens, hidden_size) + torch._dynamo.mark_dynamic(x, 0) + + result = model(x) + + model2 = torch.compile(model, backend=backend) + result2 = model2(x) + + ATOL, RTOL = (2e-3, 2e-3) + torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) + + # The no-op reshape and slice should be eliminated. + # The chain of reshapes should be fused into a single reshape. + assert backend.op_count(torch.ops.aten.reshape.default) == 1 + assert backend.op_count(torch.ops.aten.slice.Tensor) == 0 + assert backend.op_count(torch.ops.aten.slice_scatter.default) == 0 + + +def test_non_noop_slice_preserved(): + """Ensure that a slice with end=-1 (dropping last row) is NOT eliminated. + + Regression test for a bug where end=-1 was treated like an inferred + dimension (reshape semantics) leading to incorrect elimination. 
+ """ + torch.set_default_device("cuda") + x = torch.randn(16, 16) + + class SliceModel(torch.nn.Module): + + def forward(self, x): + base = x.clone() + src = torch.ones(15, 16) + y = torch.slice_scatter(base, src, dim=0, start=0, end=-1) + return x[0:-1, :], y + + vllm_config = VllmConfig(compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + pass_config=PassConfig(enable_noop=True), + )) + with vllm.config.set_current_vllm_config(vllm_config): + noop_pass = NoOpEliminationPass(vllm_config) + backend = TestBackend(noop_pass) + model = SliceModel() + ref = model(x) + compiled = torch.compile(model, backend=backend) + out = compiled(x) + torch.testing.assert_close(ref, out) + # The slice should remain (not a no-op). + assert backend.op_count(torch.ops.aten.slice.Tensor) == 1 + assert backend.op_count(torch.ops.aten.slice_scatter.default) == 1 diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index fcc2589e42116..ae190d25cad62 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -1,9 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import cast + import pytest import torch import vllm.envs as envs +from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant # yapf conflicts with isort for this block # yapf: disable @@ -17,9 +20,10 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, kFp8StaticTensorSym, kNvfp4Quant) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp) + Fp8LinearOp, cutlass_fp8_supported) from vllm.platforms import current_platform +from ..utils import override_cutlass_fp8_supported from .backend import TestBackend FP8_DTYPE = current_platform.fp8_dtype() @@ -32,7 +36,7 @@ def is_nvfp4_supported(): class TestSiluMulFp8QuantModel(torch.nn.Module): - def __init__(self, hidden_size: int, force_fp8_e4m3fnuz: bool, **kwargs): + def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs): super().__init__() self.silu_and_mul = SiluAndMul() self.wscale = torch.rand(1, dtype=torch.float32) @@ -40,11 +44,11 @@ class TestSiluMulFp8QuantModel(torch.nn.Module): self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() - self.fp8_linear = Fp8LinearOp( - force_fp8_e4m3fnuz=force_fp8_e4m3fnuz, - act_quant_static=True, - act_quant_group_shape=GroupShape.PER_TENSOR, - ) + with override_cutlass_fp8_supported(not cuda_force_torch): + self.fp8_linear = Fp8LinearOp( + act_quant_static=True, + act_quant_group_shape=GroupShape.PER_TENSOR, + ) def forward(self, x): y = self.silu_and_mul(x) @@ -63,24 +67,27 @@ class TestSiluMulFp8QuantModel(torch.nn.Module): class TestSiluMulNvfp4QuantModel(torch.nn.Module): - def __init__(self, hidden_size: int, **kwargs): + def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs): super().__init__() self.silu_and_mul = SiluAndMul() - self.w = torch.randint(256, (hidden_size, hidden_size // 2), - dtype=FP4_DTYPE) - self.wscale = torch.randn(hidden_size, - hidden_size // 16).to(dtype=FP8_DTYPE) - self.wscale2 = torch.rand(1, dtype=torch.float32) - self.scale = torch.rand(1, dtype=torch.float32) + + # create nvfp4 weight + w = torch.rand((hidden_size, hidden_size)) + self.w, self.w_block_scale, self.w_global_scale = 
quant_nvfp4_tensor(w) + + # get global scale offline + _, _, self.y_global_scale = quant_nvfp4_tensor(self.silu_and_mul(x)) + + self.alpha = 1.0 / (self.w_global_scale * self.y_global_scale) def forward(self, x): y = self.silu_and_mul(x) - y_quant, y_block_scale = scaled_fp4_quant(y, 1 / self.scale) + y_quant, y_block_scale = scaled_fp4_quant(y, self.y_global_scale) out = cutlass_scaled_fp4_mm(a=y_quant, b=self.w, block_scale_a=y_block_scale, - block_scale_b=self.wscale, - alpha=self.scale * self.wscale2, + block_scale_b=self.w_block_scale, + alpha=self.alpha, out_dtype=y.dtype) return out @@ -91,21 +98,28 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module): return [FUSED_OPS[kNvfp4Quant]] -@pytest.mark.parametrize("num_tokens", [64]) -@pytest.mark.parametrize("hidden_size", [128]) +@pytest.mark.parametrize("num_tokens", [32, 64]) +@pytest.mark.parametrize("hidden_size", [128, 256]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize( - "model_class", [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel] - if is_nvfp4_supported() else [TestSiluMulFp8QuantModel]) -@pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False]) + "model_class", + cast(list[type], [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel] + if is_nvfp4_supported() else [TestSiluMulFp8QuantModel])) +# cuda_force_torch used to test torch code path on platforms that +# cutlass_fp8_supported() == True. +@pytest.mark.parametrize("cuda_force_torch", + [True, False] if cutlass_fp8_supported() else [True]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm") -def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class, - force_fp8_e4m3fnuz): - if model_class == TestSiluMulNvfp4QuantModel and force_fp8_e4m3fnuz: +def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class, + cuda_force_torch): + if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch: pytest.skip("Duplicate tests for NVFP4") torch.set_default_device("cuda") - torch.set_default_dtype(torch.float16) + torch.set_default_dtype(dtype) + + x = torch.rand(num_tokens, hidden_size * 2) # Reshape pass is needed for the fusion pass to work config = VllmConfig() @@ -115,10 +129,10 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class, backend = TestBackend(NoOpEliminationPass(config), fusion_pass) model = model_class(hidden_size=hidden_size, - force_fp8_e4m3fnuz=force_fp8_e4m3fnuz) + cuda_force_torch=cuda_force_torch, + x=x) # First dimension dynamic - x = torch.rand(num_tokens, hidden_size * 2) torch._dynamo.mark_dynamic(x, 0) result = model(x) @@ -127,10 +141,15 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class, result2 = model2(x) # Check that it gives the same answer - torch.testing.assert_close(result[0].to(dtype=torch.float16), - result2[0].to(dtype=torch.float16), - atol=1e-3, - rtol=1e-3) + if model_class == TestSiluMulFp8QuantModel: + atol, rtol = 1e-3, 1e-3 + elif model_class == TestSiluMulNvfp4QuantModel: + atol, rtol = 1e-1, 1e-1 + + torch.testing.assert_close(result[0].to(dtype=dtype), + result2[0].to(dtype=dtype), + atol=atol, + rtol=rtol) # In pre-nodes, quant op should be present and fused kernels should not backend.check_before_ops(model.ops_in_model_before()) diff --git a/tests/conftest.py b/tests/conftest.py index 27db5422ceac2..0440e859fe02d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,24 @@ # SPDX-License-Identifier: Apache-2.0 # 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# ruff: noqa + +from tblib import pickling_support + +# Install support for pickling exceptions so that we can nicely propagate +# failures from tests running in a subprocess. +# This should be run before any custom exception subclasses are defined. +pickling_support.install() + +import http.server import json import math +import mimetypes import os +import socket import tempfile +import threading +from collections.abc import Generator from enum import Enum from typing import Any, Callable, Optional, TypedDict, TypeVar, Union, cast @@ -32,6 +47,7 @@ from vllm.distributed import (cleanup_dist_env_and_memory, from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, to_enc_dec_tuple_list, zip_enc_dec_prompts) from vllm.logger import init_logger +from vllm.multimodal.utils import fetch_image from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams from vllm.sequence import Logprob @@ -1253,3 +1269,119 @@ def cli_config_file(): def cli_config_file_with_model(): """Return the path to the CLI config file with model.""" return os.path.join(_TEST_DIR, "config", "test_config_with_model.yaml") + + +class AssetHandler(http.server.BaseHTTPRequestHandler): + # _IMAGE_CACHE : Dict[str, bytes] = {} + + def log_message(self, *args, **kwargs): + pass + + def do_GET(self): + # Accepts paths like: /1280px-Venn_diagram_rgb.jpg + filename = self.path.lstrip("/") + if not filename or "." not in filename: + self.send_error(404, "Missing filename (expected /<name>.<ext>)") + return + + base, ext = filename.rsplit(".", 1) + ext = ext.lower() + + if ext not in ["jpg", "png"]: + self.send_error(404, f"Unsupported extension: .{ext}") + return + + try: + data = ImageAsset(base).read_bytes(ext=ext) + except Exception as e: + self.send_error(500, f"Failed to load asset: {ext} {base} {e} ") + return + + ctype, _ = mimetypes.guess_type(filename) + if ctype is None: + ctype = {"jpg": "image/jpg", "png": "image/png"}[ext] + self.send_response(200) + self.send_header("Content-Type", ctype) + self.send_header("Content-Length", str(len(data))) + self.end_headers() + self.wfile.write(data) + + +def _find_free_port() -> int: + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +class LocalAssetServer: + + address: str + port: int + server: Optional[http.server.ThreadingHTTPServer] + thread: Optional[threading.Thread] + + def __init__(self, address: str = "127.0.0.1") -> None: + self.address = address + self.port = -1 + self.server = None + self.thread = None + + def __enter__(self): + self.port = _find_free_port() + self.server = http.server.ThreadingHTTPServer( + (self.address, self.port), AssetHandler) + self.thread = threading.Thread(target=self.server.serve_forever, + daemon=True) + self.thread.start() + return self + + def __exit__(self, exc_type, exc_value, traceback): + if self.server: + self.server.shutdown() + del self.server + + if self.thread: + self.thread.join() + del self.thread + + if exc_type is None: + return None + + return False + + @property + def base_url(self) -> str: + assert self.port is not None + return f"http://{self.address}:{self.port}" + + def url_for(self, name: str) -> str: + """e.g., name='RGBA_comp.png' -> 'http://127.0.0.1:PORT/RGBA_comp.png'""" + return f"{self.base_url}/{name}" + + def get_image_asset(self, name: str) -> Image.Image: + return fetch_image(self.url_for(name)) + + +@pytest.fixture(scope="session") +def local_asset_server() -> 
Generator[LocalAssetServer, None, None]: + """ + Starts a thread-based HTTP server bound to 127.0.0.1 on a random free port. + The server currently serves images at: + http://127.0.0.1:<port>/<name>.<ext> + """ + with LocalAssetServer() as srv: + yield srv + + +@pytest.fixture +def image_url(request, local_asset_server) -> str: + # request.param is one of the IMAGE_ASSETS filenames + name = request.param + return local_asset_server.url_for(name) + + +@pytest.fixture +def image_urls(request, local_asset_server) -> list[str]: + """Indirect fixture: takes a list of names, returns list of full URLs.""" + names: list[str] = request.param + return [local_asset_server.url_for(name) for name in names] diff --git a/tests/core/block/conftest.py b/tests/core/block/conftest.py deleted file mode 100644 index 6afe98d78ce81..0000000000000 --- a/tests/core/block/conftest.py +++ /dev/null @@ -1,15 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - - -@pytest.fixture() -def should_do_global_cleanup_after_test() -> bool: - """Disable the global cleanup fixture for tests in this directory. This - provides a ~10x speedup for unit tests that don't load a model to GPU. - - This requires that tests in this directory clean up after themselves if they - use the GPU. - """ - return False diff --git a/tests/core/block/e2e/conftest.py b/tests/core/block/e2e/conftest.py deleted file mode 100644 index e2c6c66b259c8..0000000000000 --- a/tests/core/block/e2e/conftest.py +++ /dev/null @@ -1,71 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from collections.abc import Iterable -from typing import Callable, Optional - -import pytest - -from vllm import LLM -from vllm.distributed import cleanup_dist_env_and_memory -from vllm.model_executor.utils import set_random_seed - - -@pytest.fixture -def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, seed): - return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, seed) - - -@pytest.fixture -def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, - test_llm_kwargs, seed): - return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, - test_llm_kwargs, seed) - - -def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, - distinct_llm_kwargs, seed): - kwargs = { - **common_llm_kwargs, - **per_test_common_llm_kwargs, - **distinct_llm_kwargs, - } - - def generator_inner(): - llm = LLM(**kwargs) - - set_random_seed(seed) - - yield llm - del llm - cleanup_dist_env_and_memory() - - for llm in generator_inner(): - yield llm - del llm - - -def get_text_from_llm_generator(llm_generator: Iterable[LLM], - prompts, - sampling_params, - llm_cb: Optional[Callable[[LLM], - None]] = None): - for llm in llm_generator: - if llm_cb: - llm_cb(llm) - outputs = llm.generate(prompts, sampling_params, use_tqdm=True) - text = [output.outputs[0].text for output in outputs] - del llm - - return text - - -def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): - for llm in llm_generator: - outputs = llm.generate(prompts, sampling_params, use_tqdm=True) - token_ids = [output.outputs[0].token_ids for output in outputs] - del llm - - return token_ids diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py deleted file mode 100644 index 8de48ef59a013..0000000000000 ---
a/tests/core/block/e2e/test_correctness.py +++ /dev/null @@ -1,479 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from itertools import cycle - -import pytest - -from vllm import SamplingParams - -from .conftest import get_token_ids_from_llm_generator - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Use a small model for a fast test. - "model": "facebook/opt-125m", - - # skip cuda graph creation for fast test. - "enforce_eager": True, - - # Allow only 5 sequences of ~1024 tokens in worst case. - "block_size": 16, - "num_gpu_blocks_override": 5 * (64 + 1), - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "preemption_mode": "swap" -}, { - "preemption_mode": "recompute" -}]) -@pytest.mark.parametrize("batch_size", [10]) -@pytest.mark.parametrize("seed", [1]) -def test_block_manager_with_preemption(baseline_llm_generator, - test_llm_generator, batch_size): - """Verify block manager produces same outputs even when there is preemption. - - This constructs two LLM, each with limited number of GPU blocks. The limit - is decided such that as the sequences in the batch grow, sequences must be - preempted and removed from cache. - - If the output token ids are equivalent, then we have confidence that the KV - cache is not corrupted. - - NOTE: We want a significant number of generated tokens so that any incorrect - KV mapping has time to build up error. - - NOTE(Kuntai): Though we have removed block manager v1, this test is still - useful as it asserts the behavior of block manager v2 (now it is called - SelfAttnBlockSpaceManager) is the same when swapping / preemption, so we - keep this test. - """ - output_len = 1024 - temperature = 0.0 - - # We want to ensure equality even with preemption. - # We force the total block size to be 1 + cdiv(output_len, block_size) - # so that only one sequence can fit at a time (once the sequences grow). - - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - - prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] - - sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, - temperature=temperature, - ) - - baseline_token_ids = get_token_ids_from_llm_generator( - baseline_llm_generator, prompts, sampling_params) - - test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, - prompts, sampling_params) - - for expected_token_ids, actual_token_ids in zip(baseline_token_ids, - test_token_ids): - assert expected_token_ids == actual_token_ids - - assert baseline_token_ids == test_token_ids - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Use a small model for a fast test. - "model": "facebook/opt-125m", - - # Our prompts will generate 128 tokens; since the prompts themselves are - # small, we don't need much KV space beyond 128. - "max_model_len": 160, - - # skip cuda graph creation for fast test. - "enforce_eager": True, - }]) -@pytest.mark.parametrize( - "per_test_common_llm_kwargs", - [ - { - "block_size": 16, - - # Allow only 2 sequences of ~128 tokens in worst case. - # Note 8 = 128/block_size - "num_gpu_blocks_override": 2 * (8 + 1), - }, - { - "block_size": 8, - - # Allow only 2 sequences of ~128 tokens in worst case. 
- # Note 16 = 128/block_size - "num_gpu_blocks_override": 2 * (16 + 2), - } - ]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{ - "num_lookahead_slots": 0, -}]) -@pytest.mark.parametrize( - "test_llm_kwargs", - [ - { - # We run one test with block_size < lookahead_slots, one test with - # block_size > lookahead_slots - "num_lookahead_slots": 10, - "preemption_mode": "swap", - }, - { - "num_lookahead_slots": 10, - "preemption_mode": "recompute", - } - ]) -@pytest.mark.parametrize("batch_size", [4]) -@pytest.mark.parametrize("seed", [1]) -def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator, - test_llm_generator, - batch_size): - """Verify vLLM produces the same output with greedy sampling, when lookahead - scheduling is used vs. not. - - Lookahead scheduling is not expected to modify the output, as it simply - allocates empty slots ahead of the known token ids in a sliding fashion. - - This test constrains the total number of blocks to force preemption. It also - varies the block size so that the lookahead size is less than and greater - than the block size. - """ - output_len = 128 - temperature = 0.0 - - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - - prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] - - sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, - temperature=temperature, - ) - - print('Getting token ids without lookahead scheduling') - baseline_token_ids = get_token_ids_from_llm_generator( - baseline_llm_generator, prompts, sampling_params) - - print('Getting token ids with lookahead scheduling') - test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, - prompts, sampling_params) - - for expected_token_ids, actual_token_ids in zip(baseline_token_ids, - test_token_ids): - assert expected_token_ids == actual_token_ids - - assert baseline_token_ids == test_token_ids - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [ - { - # Use a small model for a fast test. - "model": "facebook/opt-125m", - - # skip cuda graph creation for fast test. - "enforce_eager": True, - "enable_chunked_prefill": True, - }, - ]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", - [{ - "block_size": 16, - "max_num_batched_tokens": 2, - "max_num_seqs": 2, - }, { - "block_size": 16, - "max_num_batched_tokens": 3, - "max_num_seqs": 2, - }, { - "block_size": 16, - "max_num_batched_tokens": 256, - "max_num_seqs": 10, - }]) -@pytest.mark.parametrize("baseline_llm_kwargs", [ - {}, -]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "num_lookahead_slots": 0, - }, - { - "num_lookahead_slots": 5, - }, -]) -@pytest.mark.parametrize("batch_size", [4]) -@pytest.mark.parametrize("seed", [1]) -def test_chunked_prefill_block_manager(baseline_llm_generator, - test_llm_generator, batch_size): - """Verify that chunked prefill works with SelfAttnBlockSpaceManager, - with and without lookahead scheduling. - """ - output_len = 32 - temperature = 0.0 - - prompts = [ - "Hello, my name is", - "The president of the United States is", - ("1 + " * 50) + " 1 = ", # Longer prompt. 
- "The capital of France is", - "The future of AI is", - ] - - prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] - - sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, - temperature=temperature, - ) - - print('Getting token ids with BlockManager') - baseline_token_ids = get_token_ids_from_llm_generator( - baseline_llm_generator, prompts, sampling_params) - - print('Getting token ids with BlockManager, with lookahead slots.') - test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, - prompts, sampling_params) - - for expected_token_ids, actual_token_ids in zip(baseline_token_ids, - test_token_ids): - assert expected_token_ids == actual_token_ids - - assert baseline_token_ids == test_token_ids - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Use a small model for a fast test. - "model": "facebook/opt-125m", - - # skip cuda graph creation for fast test. - "enforce_eager": True, - - # Allow only 5 sequences of ~1024 tokens in worst case. - "block_size": 16, - "num_gpu_blocks_override": 5 * (64 + 1), - - # Enable prefill cache - "enable_prefix_caching": True, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "preemption_mode": "swap" -}, { - "preemption_mode": "recompute" -}]) -@pytest.mark.parametrize("batch_size", [10]) -@pytest.mark.parametrize("seed", [1]) -def test_block_manager_prefix_caching_enabled_with_preemption( - baseline_llm_generator, test_llm_generator, batch_size): - """Verify block manager produces same outputs even when there is preemption. - - This constructs two LLM, each with limited number of GPU blocks. The limit - is decided such that as the sequences in the batch grow, sequences must be - preempted and removed from cache. - - If the output token ids are equivalent, then we have confidence that the KV - cache is not corrupted. - - NOTE: We want a significant number of generated tokens so that any incorrect - KV mapping has time to build up error. - - NOTE(Kuntai): Though we have removed block manager v1, this test is still - useful as it asserts the behavior of block manager v2 (now it is called - SelfAttnBlockSpaceManager) is the same when swapping / preemption, so we - keep this test. - """ - output_len = 1024 - temperature = 0.0 - - # We want to ensure equality even with preemption. - # We force the total block size to be 1 + cdiv(output_len, block_size) - # so that only one sequence can fit at a time (once the sequences grow). - - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - - prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] - - sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, - temperature=temperature, - ) - - print('Getting token ids from block manager') - baseline_token_ids = get_token_ids_from_llm_generator( - baseline_llm_generator, prompts, sampling_params) - - print('Getting token ids from block manager, with preemption') - test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, - prompts, sampling_params) - - for expected_token_ids, actual_token_ids in zip(baseline_token_ids, - test_token_ids): - assert expected_token_ids == actual_token_ids - - assert baseline_token_ids == test_token_ids - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Use a small model for a fast test. 
- "model": "facebook/opt-125m", - - # skip cuda graph creation for fast test. - "enforce_eager": True, - - # Allow only 5 sequences of ~1024 tokens in worst case. - "block_size": 16, - "num_gpu_blocks_override": 5 * (64 + 1), - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{ - "enable_prefix_caching": False -}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "enable_prefix_caching": True, - "preemption_mode": "swap" -}, { - "enable_prefix_caching": True, - "preemption_mode": "recompute" -}]) -@pytest.mark.parametrize("batch_size", [10]) -@pytest.mark.parametrize("seed", [1]) -def test_auto_prefix_caching_with_preemption(baseline_llm_generator, - test_llm_generator, batch_size): - """Verify block manager v2 with auto prefix caching enabled produces same - outputs as auto prefix caching disabled, even when there is preemption. - - This constructs two LLM, each with limited number of GPU blocks. The limit - is decided such that as the sequences in the batch grow, sequences must be - preempted and removed from cache. - - If the output token ids are equivalent, then we have confidence that auto - prefix caching itself at least don't cause result error. - """ - output_len = 1024 - temperature = 0.0 - - # We want to ensure equality even with preemption. - # We force the total block size to be 1 + cdiv(output_len, block_size) - # so that only one sequence can fit at a time (once the sequences grow). - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - - prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] - - sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, - temperature=temperature, - ) - - print('Getting token ids with APC disabled') - baseline_token_ids = get_token_ids_from_llm_generator( - baseline_llm_generator, prompts, sampling_params) - - print('Getting token ids with APC enabled') - test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, - prompts, sampling_params) - - for expected_token_ids, actual_token_ids in zip(baseline_token_ids, - test_token_ids): - assert expected_token_ids == actual_token_ids - - assert baseline_token_ids == test_token_ids - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Use a small model for a fast test. - "model": "facebook/opt-125m", - - # skip cuda graph creation for fast test. - "enforce_eager": True, - - # we keep the blocks small, so that hit eviction quickly - "max_model_len": 48, - "block_size": 16, - "num_gpu_blocks_override": 3, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{ - "enable_prefix_caching": False -}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "enable_prefix_caching": True, -}]) -@pytest.mark.parametrize("seed", [1]) -def test_auto_prefix_caching_after_eviction_start(baseline_llm_generator, - test_llm_generator): - """Verify block manager v2 with auto prefix caching could work normally - even when eviction started. - With APC enabled, all blocks are held by native block at the beginning. - Then blocks are managed by evictor instead. If cache hit at the evictor's - block, then it could be reused, or we need to recompute its kv cache. - """ - output_len = 10 - temperature = 0.0 - - prompts = [ - "You are a helpful assistant. 
Please answer truthfully and write " - "out your thinking step by step to be sure you get the right answer. " - "If you make a mistake, attempt to correct it. who are you?", - "You are a helpful assistant. Please answer truthfully and write out " - "your thinking step by step to be sure you get the right answer. You " - "are helpful and harmless and you follow ethical guidelines. " - "who are you?" - ] - - sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, - temperature=temperature, - ) - - print('Getting token ids with APC disabled') - baseline_token_ids = get_token_ids_from_llm_generator( - baseline_llm_generator, prompts, sampling_params) - - print('Getting token ids with APC enabled') - test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, - prompts, sampling_params) - - for expected_token_ids, actual_token_ids in zip(baseline_token_ids, - test_token_ids): - assert expected_token_ids == actual_token_ids - - assert baseline_token_ids == test_token_ids diff --git a/tests/core/block/e2e/test_correctness_sliding_window.py b/tests/core/block/e2e/test_correctness_sliding_window.py deleted file mode 100644 index 27fe27a880e3d..0000000000000 --- a/tests/core/block/e2e/test_correctness_sliding_window.py +++ /dev/null @@ -1,185 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import random - -import pytest - -from tests.kernels.utils import override_backend_env_variable -from vllm import LLM, SamplingParams -from vllm.platforms import current_platform - -from .conftest import get_text_from_llm_generator - -# relatively small model with 4k sliding window -MODEL = "bigcode/starcoder2-3b" -BLOCK_SIZE = 16 - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model": MODEL, - - # skip cuda graph creation for fast test. - "enforce_eager": True, - "block_size": BLOCK_SIZE, - # needed due to https://github.com/vllm-project/vllm/issues/1908#issuecomment-2101122008 - "num_gpu_blocks_override": 100000 // BLOCK_SIZE, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{}]) -@pytest.mark.parametrize("batch_size", [5]) -@pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS"]) -def test_sliding_window_retrieval(baseline_llm_generator, test_llm_generator, - batch_size, seed, backend, monkeypatch): - """ - The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then - asks for value of one of them (which is outside the sliding window). - If we tell it upfront which we are going to be looking for, then - it answers correctly (mostly). - - Additionally, we compare the results of the v1 and v2 managers. 
- """ - if backend == "XFORMERS" and current_platform.is_rocm(): - pytest.skip("Xformers does not support ROCm/HIP.") - - override_backend_env_variable(monkeypatch, backend) - - sampling_params = SamplingParams( - max_tokens=1024, - ignore_eos=True, - temperature=0.0, - ) - - prompts, answer, indices = prep_prompts(batch_size) - - baseline_texts = get_text_from_llm_generator(baseline_llm_generator, - prompts, - sampling_params, - llm_cb=check_window(prompts)) - - check_answers(indices, answer, baseline_texts) - - print('Getting token ids from block manager v2') - test_texts = get_text_from_llm_generator(test_llm_generator, prompts, - sampling_params) - check_answers(indices, answer, test_texts) - - cmp = [ - expected_text == actual_text - for expected_text, actual_text in zip(baseline_texts, test_texts) - ] - print(cmp) - # make sure it's mostly OK; this is possibly because https://github.com/vllm-project/vllm/pull/4768 - # however, https://github.com/vllm-project/vllm/issues/3385#issuecomment-1995924290 - # states that xformers and flash_attn have different ideas about the window - # size anyways - assert sum(cmp) > 0.7 * len(cmp) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model": MODEL, - - # skip cuda graph creation for fast test. - "enforce_eager": True, - "block_size": BLOCK_SIZE, - "num_gpu_blocks_override": 100000 // BLOCK_SIZE, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{"enable_chunked_prefill": True}]) -@pytest.mark.parametrize("batch_size", [5]) -@pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS"]) -def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed, - backend, monkeypatch): - """ - This is similar to test_sliding_window_retrieval, however, it doesn't - compare against the v1 block manager since v1 doesn't support - chunked prefill with sliding window. - - The results with and without chunked prefill are not the same due to - numerical instabilities. - """ - if backend == "XFORMERS" and current_platform.is_rocm(): - pytest.skip("Xformers does not support ROCm/HIP.") - override_backend_env_variable(monkeypatch, backend) - - sampling_params = SamplingParams( - max_tokens=10, - ignore_eos=True, - temperature=0.0, - ) - - prompts, answer, indices = prep_prompts(batch_size) - - # We don't compare with the baseline model here, since the results - # slightly different due to different tailing in attention. - test_texts = get_text_from_llm_generator(test_llm_generator, - prompts, - sampling_params, - llm_cb=check_window(prompts)) - check_answers(indices, answer, test_texts) - - -def prep_prompts(batch_size: int, ln_range: tuple[int, int] = (800, 1100)): - """ - Generate prompts which a bunch of assignments, - then asking for the value of one of them. - The prompt is just under 10k tokens; sliding window is 4k - so the answer is outside sliding window, but should still be correct. 
- - Args: - batch_size: number of prompts to generate - ln_range: an argument to control the length of the prompt - """ - prompts: list[str] = [] - answer: list[int] = [] - indices: list[int] = [] - random.seed(1) - for _ in range(batch_size): - idx = random.randint(30, 90) - indices.append(idx) - prompt = "```python\n# We set a number of variables, " + \ - f"x{idx} will be important later\n" - ln = random.randint(*ln_range) - for k in range(30, ln): - v = random.randint(10, 99) - if k == idx: - answer.append(v) - prompt += f"x{k} = {v}\n" - prompt += f"# Now, we check the value of x{idx}:\n" - prompt += f"assert x{idx} == " - prompts.append(prompt) - return prompts, answer, indices - - -def check_answers(indices: list[int], - answer: list[int], - outputs: list[str], - accept_rate: float = 0.7): - answer2 = [int(text[0:2].strip()) for text in outputs] - print(list(zip(indices, zip(answer, answer2)))) - numok = 0 - for a1, a2 in zip(answer, answer2): - if a1 == a2: - numok += 1 - frac_ok = numok / len(answer) - print(f"Num OK: {numok}/{len(answer)} {frac_ok}") - assert frac_ok >= accept_rate - - -def check_window(prompts: list[str]): - - def inner(llm: LLM): - sliding_window = llm.llm_engine.model_config.get_sliding_window() - assert sliding_window and sliding_window > 0 - assert any( - len(llm.get_tokenizer().tokenize(prompt)) > sliding_window - for prompt in prompts) - - return inner diff --git a/tests/core/block/test_block_manager.py b/tests/core/block/test_block_manager.py deleted file mode 100644 index 9eed264fd7d43..0000000000000 --- a/tests/core/block/test_block_manager.py +++ /dev/null @@ -1,494 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, - STR_NOT_IMPL_ENC_DEC_SWA) -from vllm.core.block_manager import SelfAttnBlockSpaceManager -from vllm.core.interfaces import AllocStatus -from vllm.sequence import Logprob, SequenceStatus -from vllm.utils import chunk_list - -from ..utils import (create_dummy_prompt, create_seq_group, - create_seq_group_encoder_decoder) - - -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.parametrize("num_gpu_blocks", [8, 40, 80]) -@pytest.mark.parametrize("num_seqs_per_group", [1, 4]) -@pytest.mark.parametrize("watermark", [0.0, 0.5]) -def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int, - num_gpu_blocks: int, watermark: float): - block_manager = SelfAttnBlockSpaceManager( - block_size=block_size, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=1024, - watermark=watermark, - ) - num_watermark_blocks = int(watermark * num_gpu_blocks) - - num_output_blocks_per_seq = 1 - - # NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but - # the current implementation assumes all seqs are new prompts / don't have - # different output lens. 
- num_output_blocks = num_output_blocks_per_seq - - for num_prompt_blocks in range(1, num_gpu_blocks - num_output_blocks): - seq_group = create_seq_group( - seq_prompt_len=block_size * num_prompt_blocks, - seq_output_lens=[ - block_size * num_output_blocks_per_seq - for _ in range(num_seqs_per_group) - ], - ) - - assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks - - can_allocate_result = block_manager.can_allocate(seq_group) - - num_required_blocks = num_prompt_blocks + num_output_blocks - - if num_gpu_blocks - num_required_blocks < num_watermark_blocks: - assert can_allocate_result == AllocStatus.NEVER - elif num_gpu_blocks >= num_required_blocks: - assert can_allocate_result == AllocStatus.OK - else: - assert can_allocate_result == AllocStatus.LATER - - -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.parametrize("num_gpu_blocks", [16, 80, 160]) -@pytest.mark.parametrize("num_seqs_per_group", [1, 4]) -@pytest.mark.parametrize("watermark", [0.0, 0.5]) -def test_can_allocate_seq_group_encoder_decoder(block_size: int, - num_seqs_per_group: int, - num_gpu_blocks: int, - watermark: float): - block_manager = SelfAttnBlockSpaceManager( - block_size=block_size, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=1024, - watermark=watermark, - ) - num_watermark_blocks = int(watermark * num_gpu_blocks) - - num_output_blocks_per_seq = 1 - - # NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but - # the current implementation assumes all seqs are new prompts / don't have - # different output lens. - num_output_blocks = num_output_blocks_per_seq - - for bdx, num_prompt_blocks in enumerate( - range(1, num_gpu_blocks - num_output_blocks)): - num_cross_blocks_per_seq = num_prompt_blocks - - seq_group = create_seq_group_encoder_decoder( - seq_prompt_len=block_size * num_prompt_blocks, - seq_output_lens=[ - block_size * num_output_blocks_per_seq - for _ in range(num_seqs_per_group) - ], - request_id=str(bdx)) - - assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks - - can_allocate_result = block_manager.can_allocate(seq_group) - - num_required_blocks = num_prompt_blocks + \ - num_output_blocks + \ - num_cross_blocks_per_seq - - if num_gpu_blocks - num_required_blocks < num_watermark_blocks: - assert can_allocate_result == AllocStatus.NEVER - elif num_gpu_blocks >= num_required_blocks: - assert can_allocate_result == AllocStatus.OK - else: - assert can_allocate_result == AllocStatus.LATER - - -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.parametrize("num_gpu_blocks", [16]) -@pytest.mark.parametrize("num_seqs_per_group", [1]) -@pytest.mark.parametrize("watermark", [0.0, 0.5]) -def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int, - num_seqs_per_group: int, - num_gpu_blocks: int, - watermark: float): - ''' - SWA short for Sliding Window Attention. - - At time of writing block manager does not support SWA. - - However even when SWA is implemented for block manager, - there will still most likely be a separate workstream required - to enable SWA for encoder/decoder models. - - Therefore this test enforces that one of the following cases - hold true: - 1. Block manager does not support SWA at all (true at time of writing) - 2. Block manager fails with NotImplementError when SWA is enabled - AND a SequenceGroup with an encoder sequence (i.e. 
in support of an - encoder/decoder model) is passed into can_allocate() as an argument - - The setup for this test is stripped down version of - test_can_allocate_seq_group_encoder_decoder() - ''' - - with pytest.raises((NotImplementedError, AssertionError)) as exc_info: - block_manager = SelfAttnBlockSpaceManager( - block_size=block_size, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=1024, - watermark=watermark, - sliding_window=5 # SWA - ) - - num_output_blocks_per_seq = 1 - num_prompt_blocks = 1 - num_output_blocks = num_output_blocks_per_seq - seq_group = create_seq_group_encoder_decoder( - seq_prompt_len=block_size * num_prompt_blocks, - seq_output_lens=[ - block_size * num_output_blocks_per_seq - for _ in range(num_seqs_per_group) - ], - request_id="0") - - assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks - block_manager.can_allocate(seq_group) - - # Assert that either - # 1. Block manager constructor fails with assertion that sliding window - # is not yet supported (most likely near-term outcome at time of - # writing), or - # 2. can_allocate() fails with NotImplementedError due to combination of - # encoder/decoder and sliding window attention - if isinstance(exc_info.value, NotImplementedError): - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA - elif isinstance(exc_info.value, AssertionError): - assert str(exc_info.value) == "Sliding window not yet supported" - - -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.parametrize("num_gpu_blocks", [16]) -@pytest.mark.parametrize("num_seqs_per_group", [1]) -@pytest.mark.parametrize("watermark", [0.0, 0.5]) -def test_can_allocate_encoder_decoder_fails_with_prefix_cache( - block_size: int, num_seqs_per_group: int, num_gpu_blocks: int, - watermark: float): - - block_manager = SelfAttnBlockSpaceManager( - block_size=block_size, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=1024, - watermark=watermark, - enable_caching=True # Prefix cache - ) - - num_output_blocks_per_seq = 1 - num_prompt_blocks = 1 - num_output_blocks = num_output_blocks_per_seq - seq_group = create_seq_group_encoder_decoder( - seq_prompt_len=block_size * num_prompt_blocks, - seq_output_lens=[ - block_size * num_output_blocks_per_seq - for _ in range(num_seqs_per_group) - ], - request_id="0") - - assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks - - # Assert that either can_allocate() fails with NotImplementedError - # due to combination of encoder/decoder and prefix cache - with pytest.raises(NotImplementedError) as exc_info: - block_manager.can_allocate(seq_group) - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE - - -@pytest.mark.parametrize("block_size", [1, 8]) -@pytest.mark.parametrize("prompt_len", [1, 7, 8]) -@pytest.mark.parametrize("num_slots_to_append", [1, 8, 129]) -@pytest.mark.parametrize("num_lookahead_slots", [0, 10]) -def test_append_slots(block_size, prompt_len, num_slots_to_append, - num_lookahead_slots): - """Verify append_slots consumes the correct number of blocks from the block - table. 
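For example, taking block_size=8, prompt_len=7, num_slots_to_append=8 and num_lookahead_slots=10 from the parametrization above, the sequence needs ceil((7 + 8 + 10) / 8) = 4 blocks in total, of which ceil(7 / 8) = 1 is already allocated for the prompt, so append_slots is expected to consume 3 additional blocks.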
- """ - - num_gpu_blocks = 1024 - watermark = 0.1 - block_manager = SelfAttnBlockSpaceManager( - block_size=block_size, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=0, - watermark=watermark, - ) - - seq_group = create_seq_group( - seq_prompt_len=prompt_len, - seq_output_lens=[0], - ) - - # Allocate seq - assert block_manager.can_allocate(seq_group) - block_manager.allocate(seq_group) - - # Seq seq to RUNNING - seq = seq_group.get_seqs()[0] - seq.status = SequenceStatus.RUNNING - - # Append tokens to the sequeqnce - for token_id in range(num_slots_to_append): - seq.append_token_id(token_id, {token_id: Logprob(0.0)}) - - # Append slots for new tokens and lookahead slots. - free_blocks_before_append = block_manager.get_num_free_gpu_blocks() - block_manager.append_slots(seq, num_lookahead_slots) - num_consumed_blocks = (free_blocks_before_append - - block_manager.get_num_free_gpu_blocks()) - - # Expect consumed blocks to be new blocks required to support the new slots. - expected_consumed_blocks = len( - list( - chunk_list( - list( - range(prompt_len + num_slots_to_append + - num_lookahead_slots)), - block_size))) - len( - list(chunk_list(list(range(prompt_len)), block_size))) - assert num_consumed_blocks == expected_consumed_blocks - - -@pytest.mark.parametrize("block_size", [8]) -@pytest.mark.parametrize("num_cpu_blocks", [4]) -@pytest.mark.parametrize("num_gpu_blocks", [4]) -@pytest.mark.parametrize("num_lookahead_slots", [0, 2, 10]) -@pytest.mark.parametrize("enable_caching", [False, True]) -def test_swap(block_size, num_cpu_blocks, num_gpu_blocks, num_lookahead_slots, - enable_caching): - """Verify blocks number on src/desc device is correct after swapping in/out - sequence group (not missing or extra blocks). - """ - block_manager = SelfAttnBlockSpaceManager(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0, - enable_caching=enable_caching) - prompt, seq_group = create_dummy_prompt("1", prompt_length=block_size - 1) - prompt.status = SequenceStatus.WAITING - block_manager.allocate(seq_group) - - # Emulate a forward pass by appending a single token. - # The block manager then knows how many unprocessed - # tokens will be written in the next forward pass. - token_id = 0 - prompt.status = SequenceStatus.RUNNING - prompt.append_token_id(token_id, {token_id: Logprob(0.0)}) - - # Swap seq group from GPU -> CPU. - gpu_blocks = block_manager.get_block_table(prompt) - assert block_manager.can_swap_out(seq_group) - before_cpu_blocks = block_manager.get_num_free_cpu_blocks() - before_gpu_blocks = block_manager.get_num_free_gpu_blocks() - mapping = block_manager.swap_out(seq_group) - mapping_keys = [key for key, _ in mapping] - assert mapping_keys == gpu_blocks - after_cpu_blocks = block_manager.get_num_free_cpu_blocks() - after_gpu_blocks = block_manager.get_num_free_gpu_blocks() - assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks) - assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks - prompt.status = SequenceStatus.SWAPPED - - # Swap seq group from CPU -> GPU. 
- assert block_manager.can_swap_in(seq_group, num_lookahead_slots) - before_cpu_blocks = block_manager.get_num_free_cpu_blocks() - before_gpu_blocks = block_manager.get_num_free_gpu_blocks() - mapping = block_manager.swap_in(seq_group) - cpu_blocks = block_manager.get_block_table(prompt) - mapping_keys = [key for key, _ in mapping] - assert mapping_keys == [cpu_blocks[0]] - after_cpu_blocks = block_manager.get_num_free_cpu_blocks() - after_gpu_blocks = block_manager.get_num_free_gpu_blocks() - assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks) - - -@pytest.mark.parametrize("block_size", [8]) -@pytest.mark.parametrize("num_gpu_blocks", [4]) -@pytest.mark.parametrize("num_lookahead_slots", [3, 8, 10]) -@pytest.mark.parametrize("enable_caching", [True, False]) -def test_can_swap(block_size, num_gpu_blocks, num_lookahead_slots, - enable_caching): - """ Verify the block manager can correctly determine if a sequence group - can be swapped in/out. - """ - num_cpu_blocks = num_gpu_blocks - block_manager = SelfAttnBlockSpaceManager(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0, - enable_caching=enable_caching) - prompt, seq_group = create_dummy_prompt( - "1", prompt_length=(num_gpu_blocks - 1) * block_size - 1) - prompt.status = SequenceStatus.WAITING - block_manager.allocate(seq_group) - prompt.status = SequenceStatus.RUNNING - - # Swap seq group from GPU -> CPU. - gpu_blocks = block_manager.get_block_table(prompt) - assert block_manager.can_swap_out(seq_group) - before_cpu_blocks = block_manager.get_num_free_cpu_blocks() - before_gpu_blocks = block_manager.get_num_free_gpu_blocks() - mapping = block_manager.swap_out(seq_group) - mapping_keys = [key for key, _ in mapping] - assert mapping_keys == gpu_blocks - after_cpu_blocks = block_manager.get_num_free_cpu_blocks() - after_gpu_blocks = block_manager.get_num_free_gpu_blocks() - assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks) - assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks - prompt.status = SequenceStatus.SWAPPED - - # At this moment, we still have enough free blocks to swap in the seq group. - if num_lookahead_slots <= block_size: - assert block_manager.can_swap_in(seq_group, - num_lookahead_slots) == AllocStatus.OK - else: - assert block_manager.can_swap_in( - seq_group, num_lookahead_slots) == AllocStatus.NEVER - - # During Swapped out, 2 cached blocks were evicted from the GPU, - # so the prompt1 can't be swapped in - prompt2_len = 2 * block_size - 1 - prompt2, seq_group2 = create_dummy_prompt( - "2", - prompt_length=prompt2_len, - prompt_tokens=[10000 + i for i in range(prompt2_len)]) - prompt2.status = SequenceStatus.WAITING - block_manager.allocate(seq_group2) - - # Swap seq group from CPU -> GPU. - if num_lookahead_slots <= block_size: - assert block_manager.can_swap_in( - seq_group, num_lookahead_slots) == AllocStatus.LATER - else: - assert block_manager.can_swap_in( - seq_group, num_lookahead_slots) == AllocStatus.NEVER - - -@pytest.mark.parametrize("num_lookahead_slots", [0, 2, 10]) -@pytest.mark.parametrize("enable_caching", [False, True]) -def test_swap_in_infeasible(num_lookahead_slots, enable_caching): - """Verifies that swapping fails if there is not enough free blocks - to account for unseen tokens and lookahead_slots. 
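For example, with the single GPU block of 8 slots used below and a prompt of block_size - 3 = 5 tokens plus one unseen token, swap-in still fits for num_lookahead_slots of 0 or 2 (5 + 1 + 2 = 8 <= 8), but with 10 lookahead slots the request would need 16 slots and can_swap_in is expected to report AllocStatus.NEVER.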
- """ - block_size = 8 - num_cpu_blocks = 1 - num_gpu_blocks = 1 - block_manager = SelfAttnBlockSpaceManager(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0, - enable_caching=enable_caching) - prompt_length = block_size - 3 - assert prompt_length > 0 - prompt, seq_group = create_dummy_prompt("1", prompt_length=prompt_length) - prompt.status = SequenceStatus.WAITING - block_manager.allocate(seq_group) - # Emulate a forward pass by appending a single token. - # The block manager then knows how many unprocessed - # tokens will be written in the next forward pass. - token_id = 0 - prompt.status = SequenceStatus.RUNNING - prompt.append_token_id(token_id, {token_id: Logprob(0.0)}) - - # Swap seq group from GPU -> CPU. - assert block_manager.can_swap_out(seq_group) - block_manager.swap_out(seq_group) - prompt.status = SequenceStatus.SWAPPED - - # Swap seq group from CPU -> GPU. - # The number of unseen tokens is 1. If the number of existing - # tokens plus the unseen ones and number of lookahead slots exceeds - # the total number of available GPU blocks then the swap - # should fail. - num_unseen_tokens = 1 - if (num_lookahead_slots + num_unseen_tokens + - prompt_length) <= (block_size * num_gpu_blocks): - assert block_manager.can_swap_in(seq_group, - num_lookahead_slots) == AllocStatus.OK - else: - assert block_manager.can_swap_in( - seq_group, num_lookahead_slots) == AllocStatus.NEVER - - -# TODO(cade/kaiyang): add comprehensive tests for swapping at allocator level. - - -@pytest.mark.parametrize("block_size", [8, 16]) -@pytest.mark.parametrize("prompt_len", [10, 300, 1000]) -@pytest.mark.parametrize("num_slots_to_append", [50]) -@pytest.mark.parametrize("sliding_window", [20, 32, 200, 512]) -def test_sliding_window(block_size, prompt_len, num_slots_to_append, - sliding_window): - """Verify append_slots consumes the correct number of blocks from the block - table. 
- """ - - num_gpu_blocks = 1024 - watermark = 0.1 - block_manager = SelfAttnBlockSpaceManager( - block_size=block_size, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=0, - watermark=watermark, - sliding_window=sliding_window, - ) - - def check_used(min_n, max_n=None): - if max_n is None: - max_n = min_n - used = num_gpu_blocks - block_manager.get_num_free_gpu_blocks() - assert min_n <= used - assert used <= max_n - - def num_blocks(num_tokens): - return (num_tokens + block_size - 1) // block_size - - check_used(0) - - seq_group = create_seq_group( - seq_prompt_len=prompt_len, - seq_output_lens=[0], - ) - - check_used(0) - - # Allocate seq - assert block_manager.can_allocate(seq_group) - block_manager.allocate(seq_group) - - check_used(num_blocks(prompt_len)) - - # Seq seq to RUNNING - seq = seq_group.get_seqs()[0] - seq.status = SequenceStatus.RUNNING - - seq.data.update_num_computed_tokens(prompt_len) - check_used(num_blocks(prompt_len)) - - # this is how we compute it in SelfAttnBlockSpaceManager.__init__ - sliding_blocks = (sliding_window // block_size) + 2 - # plus one block for null block - sliding_blocks += 1 - - # Append tokens to the sequeqnce - for token_id in range(num_slots_to_append): - seq.append_token_id(token_id, {token_id: Logprob(0.0)}) - seq.data.update_num_computed_tokens(1) - block_manager.append_slots(seq, num_lookahead_slots=0) - if prompt_len < sliding_window + 10: - check_used(0, sliding_blocks + 1) - else: - check_used(sliding_blocks, sliding_blocks + 1) diff --git a/tests/core/block/test_block_table.py b/tests/core/block/test_block_table.py deleted file mode 100644 index ba085001136be..0000000000000 --- a/tests/core/block/test_block_table.py +++ /dev/null @@ -1,577 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -from vllm.core.block.block_table import BlockTable -from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator -from vllm.utils import Device, cdiv, chunk_list - - -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.parametrize("sequence_len", [1, 16, 129]) -def test_allocate_naive(block_size: int, sequence_len: int): - """Test the allocation of blocks using the naive allocator. - - This test creates a CpuGpuBlockAllocator with the specified block size and - number of blocks. It then allocates multiple BlockTables with varying - sequence lengths and verifies that the number of free blocks decreases as - expected after each allocation. - """ - assert block_size > 1 - num_gpu_blocks = 1024 - - allocator = CpuGpuBlockAllocator.create( - allocator_type="naive", - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=1024, - block_size=block_size, - ) - - token_ids = list(range(sequence_len)) - num_blocks_per_alloc = len(list(chunk_list(token_ids, block_size))) - - block_tables: list[BlockTable] = [] - for i in range(5): - assert allocator.get_num_free_blocks( - device=Device.GPU) == num_gpu_blocks - i * num_blocks_per_alloc - - block_tables.append( - BlockTable( - block_size=block_size, - block_allocator=allocator, - )) - block_tables[-1].allocate(token_ids=token_ids, device=Device.GPU) - - -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.parametrize("sequence_len", [1, 16, 129]) -def test_allocate_prefix_caching(block_size: int, sequence_len: int): - """Test the allocation of blocks using the prefix caching allocator. 
- - This test creates a CpuGpuBlockAllocator with the specified block size and - number of blocks, using the prefix caching allocator. It then allocates - multiple BlockTables with varying sequence lengths and verifies that the - number of free blocks decreases as expected after each allocation. - - The test expects all sequences to share allocations, except for their last - block, which may be mutable. It calculates the expected number of immutable - and mutable blocks per allocation based on the sequence length and block - size. - """ - assert block_size > 1 - num_gpu_blocks = 1024 - - allocator = CpuGpuBlockAllocator.create( - allocator_type="prefix_caching", - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=1024, - block_size=block_size, - ) - - token_ids = list(range(sequence_len)) - chunked_tokens = list(chunk_list(token_ids, block_size)) - num_mutable_blocks_per_alloc = 0 if len( - chunked_tokens[-1]) == block_size else 1 - num_immutable_blocks_per_alloc = len( - chunked_tokens) - num_mutable_blocks_per_alloc - - block_tables: list[BlockTable] = [] - for alloc_i in range(1, 6): - - block_tables.append( - BlockTable( - block_size=block_size, - block_allocator=allocator, - )) - block_tables[-1].allocate(token_ids=token_ids, device=Device.GPU) - - # Expect all sequences to share allocations, except for their last block - # (which may be mutable). - assert allocator.get_num_free_blocks( - device=Device.GPU) == num_gpu_blocks - ( - num_immutable_blocks_per_alloc + num_mutable_blocks_per_alloc * - (alloc_i)) - - -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.parametrize("sequence_len", [1, 16, 129]) -@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -@pytest.mark.parametrize("device", ["cpu", "gpu"]) -def test_allocate_free(block_size: int, sequence_len: int, allocator_type: str, - device: str): - """Test the allocation and freeing of blocks using different allocators and - devices. - - This test creates a CpuGpuBlockAllocator with the specified block size, - number of blocks, allocator type, and device. It then allocates a BlockTable - multiple times with the same sequence and verifies that the number of free - blocks remains consistent after each allocation and freeing. - """ - device = Device[device.upper()] - - num_device_blocks = 1024 - allocator = CpuGpuBlockAllocator.create( - allocator_type=allocator_type, - num_gpu_blocks=num_device_blocks, - num_cpu_blocks=num_device_blocks, - block_size=block_size, - ) - - token_ids = list(range(sequence_len)) - num_blocks_per_alloc = len(list(chunk_list(token_ids, block_size))) - - block_table = BlockTable( - block_size=block_size, - block_allocator=allocator, - ) - - for i in range(5): - block_table.allocate(token_ids=token_ids, device=device) - assert allocator.get_num_free_blocks( - device) == num_device_blocks - num_blocks_per_alloc - assert all(block_id is not None - for block_id in block_table.physical_block_ids) - - block_table.free() - assert allocator.get_num_free_blocks(device) == num_device_blocks - - -@pytest.mark.parametrize("block_size", [1, 8]) -@pytest.mark.parametrize("sequence_len", [1, 16, 129]) -@pytest.mark.parametrize("append_len", [1, 16, 129]) -@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_append_token_ids_allocation(block_size: int, sequence_len: int, - append_len: int, allocator_type: str): - """Test the allocation behavior when appending token IDs to a BlockTable. 
- - This test creates a CpuGpuBlockAllocator with the specified block size, - number of blocks, and allocator type. It then allocates a BlockTable with an - initial sequence and appends additional token IDs to it. The test verifies - that the number of allocated blocks before and after appending matches the - expected values. - """ - - num_gpu_blocks = 1024 - - allocator = CpuGpuBlockAllocator.create( - allocator_type=allocator_type, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=1024, - block_size=block_size, - ) - - token_ids = list(range(sequence_len)) - token_ids_to_append = list(range(append_len)) - - block_table = BlockTable( - block_size=block_size, - block_allocator=allocator, - ) - - num_expected_blocks_before_append = len( - list(chunk_list(token_ids, block_size))) - num_expected_appended_blocks = len( - list(chunk_list(token_ids + token_ids_to_append, - block_size))) - num_expected_blocks_before_append - - block_table.allocate(token_ids=token_ids, device=Device.GPU) - - assert len( - block_table.physical_block_ids) == num_expected_blocks_before_append - block_table.append_token_ids(token_ids_to_append) - assert len( - block_table.physical_block_ids - ) == num_expected_blocks_before_append + num_expected_appended_blocks - - -@pytest.mark.parametrize("block_size", [1, 8]) -@pytest.mark.parametrize("sequence_len", [1, 16, 129]) -@pytest.mark.parametrize("num_empty_slots", [1, 16, 129]) -@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_ensure_num_empty_slots_allocation(block_size: int, sequence_len: int, - num_empty_slots: int, - allocator_type: str): - """Test the allocation behavior when ensuring a certain number of empty - slots in a BlockTable. - - This test creates a CpuGpuBlockAllocator with the specified block size, - number of blocks, and allocator type. It then allocates a BlockTable with an - initial sequence and ensures a certain number of empty slots. The test - verifies that the number of allocated blocks before and after ensuring empty - slots matches the expected values. It also checks that filling up the empty - slots does not consume additional blocks. - """ - num_gpu_blocks = 1024 - - allocator = CpuGpuBlockAllocator.create( - allocator_type=allocator_type, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=1024, - block_size=block_size, - ) - - token_ids = list(range(sequence_len)) - - block_table = BlockTable( - block_size=block_size, - block_allocator=allocator, - ) - - num_expected_blocks_before_append = len( - list(chunk_list(token_ids, block_size))) - num_expected_appended_blocks = len( - list(chunk_list(token_ids + [-1] * num_empty_slots, - block_size))) - num_expected_blocks_before_append - - block_table.allocate(token_ids=token_ids, device=Device.GPU) - - # Assert that the empty slots consume the expected number of additional - # blocks. - assert len( - block_table.physical_block_ids) == num_expected_blocks_before_append - block_table.ensure_num_empty_slots(num_empty_slots) - assert len( - block_table.physical_block_ids - ) == num_expected_blocks_before_append + num_expected_appended_blocks - - # Now, ensure no additional blocks consumed as we fill up the empty slots. 
- num_free_blocks = allocator.get_num_free_blocks(device=Device.GPU) - block_table.append_token_ids(token_ids=list(range(num_empty_slots))) - assert num_free_blocks == allocator.get_num_free_blocks(device=Device.GPU) - - -@pytest.mark.parametrize("block_size", [1, 8]) -@pytest.mark.parametrize("sequence_len", [1, 9]) -@pytest.mark.parametrize("append_len", [1, 16, 129]) -@pytest.mark.parametrize("append_size", [1, 4, 129]) -@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_append_token_ids_correct_content(block_size: int, sequence_len: int, - append_len: int, allocator_type: str, - append_size: int): - """Verify token ids are correctly appended. Appends various amounts of - token ids in various append sizes, and verifies the final sequence is - correct. - """ - num_gpu_blocks = 1024 - - allocator = CpuGpuBlockAllocator.create( - allocator_type=allocator_type, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=1024, - block_size=block_size, - ) - - token_ids = list(range(sequence_len)) - token_ids_to_append = list(range(append_len)) - - block_table = BlockTable( - block_size=block_size, - block_allocator=allocator, - ) - block_table.allocate(token_ids=token_ids, device=Device.GPU) - - appended_so_far: list[int] = [] - for append in chunk_list(token_ids_to_append, append_size): - block_table.append_token_ids(append) - appended_so_far.extend(append) - - assert block_table._get_all_token_ids() == token_ids + appended_so_far - - assert block_table._get_all_token_ids() == token_ids + token_ids_to_append - - -@pytest.mark.parametrize("seq_len", [1, 9, 129]) -@pytest.mark.parametrize("block_size", [1, 8]) -@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_fork(seq_len: int, block_size: int, allocator_type: str): - """Create a sequence using the specified allocator. - 1. Assert that after forking the sequence, the free block count is the - same. - 2. Assert that the forked sequence has the same physical mappings. - 3. Then free the original sequence; verify that the free block count is - the same. - 4. Finally, free the forked sequence and verify that the free block - count drops to zero. - """ - num_gpu_blocks = 1024 - - allocator = CpuGpuBlockAllocator.create( - allocator_type=allocator_type, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=0, - block_size=block_size, - ) - - token_ids = list(range(seq_len)) - - block_table = BlockTable( - block_size=block_size, - block_allocator=allocator, - ) - - block_table.allocate(token_ids) - - num_free_blocks_before_fork = allocator.get_num_free_blocks( - device=Device.GPU) - - forked_block_table = block_table.fork() - - # Expect physical_block_ids and token_ids to match. - assert (block_table.physical_block_ids == - forked_block_table.physical_block_ids) - assert block_table._get_all_token_ids( - ) == forked_block_table._get_all_token_ids() - - # Do not expect any additional allocations. - assert allocator.get_num_free_blocks( - device=Device.GPU) == num_free_blocks_before_fork - - # Free the original blocks. Assert num free blocks does not change, since - # refcount is nonzero. - block_table.free() - assert allocator.get_num_free_blocks( - device=Device.GPU) == num_free_blocks_before_fork - - # Expect the forked block table to be unaffected by the free. - assert all(block_id is not None - for block_id in forked_block_table.physical_block_ids) - - # Free the forked blocks. Assert num free blocks does change, since - # refcount is now zero. 
- forked_block_table.free() - assert allocator.get_num_free_blocks(device=Device.GPU) == num_gpu_blocks - - -@pytest.mark.parametrize("block_size", [8]) -@pytest.mark.parametrize("sequence_len", [1, 16, 129]) -@pytest.mark.parametrize("append_len", [1, 16, 129]) -@pytest.mark.parametrize("appender", ["forked", "original"]) -@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_cow(block_size: int, sequence_len: int, append_len: int, - allocator_type: str, appender: str): - """Fork a sequence; append to the forked sequence; verify there's a CoW. - """ - num_gpu_blocks = 1024 - - allocator = CpuGpuBlockAllocator.create( - allocator_type=allocator_type, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=0, - block_size=block_size, - ) - - token_ids = list(range(sequence_len)) - token_ids_to_append = list(range(append_len)) - - original_block_table = BlockTable( - block_size=block_size, - block_allocator=allocator, - ) - - num_expected_non_cow_blocks = cdiv(sequence_len, block_size) - num_expected_cow_blocks = cdiv(sequence_len + append_len, - block_size) - (sequence_len // block_size) - - original_block_table.allocate(token_ids=token_ids, device=Device.GPU) - original_block_ids = original_block_table.physical_block_ids[:] - - print("original_block_ids = {}".format(original_block_ids)) - forked_block_table = original_block_table.fork() - - # Expect no additional allocation (copy on _write_). - assert allocator.get_num_free_blocks( - Device.GPU) == (num_gpu_blocks - num_expected_non_cow_blocks) - - if appender == "forked": - appender_block_table = forked_block_table - static_block_table = original_block_table - elif appender == "original": - appender_block_table = original_block_table - static_block_table = forked_block_table - else: - raise ValueError(f"unknown test config {appender=}") - - # Write tokens. - appender_block_table.append_token_ids(token_ids_to_append) - - # Expect the non-appending block table to have no change. - assert static_block_table.physical_block_ids == original_block_ids - assert appender_block_table.physical_block_ids != original_block_ids - - # Expect the blocks changed during append to have a CoW. - assert allocator.get_num_free_blocks( - Device.GPU) == num_gpu_blocks - (num_expected_non_cow_blocks + - num_expected_cow_blocks) - - cows = allocator.clear_copy_on_writes() - if sequence_len % block_size > 0: - # If the last block in the sequence is not full, then when appending we - # expect a CoW. - assert cows - - cow_block_id = sequence_len // block_size - expected_src = static_block_table.physical_block_ids[cow_block_id] - expected_dst = appender_block_table.physical_block_ids[cow_block_id] - - assert (expected_src, expected_dst) in cows - else: - # Otherwise, there should be no copy-on-write. - assert not cows - - static_block_table.free() - appender_block_table.free() - - # After free, expect all blocks to be freed. - assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks - - -@pytest.mark.parametrize("block_size", [8]) -@pytest.mark.parametrize("sequence_len", [1, 16, 129]) -@pytest.mark.parametrize("append_len", [1, 16, 129]) -@pytest.mark.parametrize("lookahead_slots", [1, 16, 129]) -@pytest.mark.parametrize("appender", ["forked", "original"]) -@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_cow_lookahead_simple(block_size: int, sequence_len: int, - append_len: int, lookahead_slots: int, - allocator_type: str, appender: str): - """Similar to test_cow, except with lookahead allocation. 
The assertions are - less rigorous due to the complexity of the property under test. - """ - num_gpu_blocks = 1024 - - allocator = CpuGpuBlockAllocator.create( - allocator_type=allocator_type, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=0, - block_size=block_size, - ) - - token_ids = list(range(sequence_len)) - token_ids_to_append = list(range(append_len)) - - original_block_table = BlockTable( - block_size=block_size, - block_allocator=allocator, - ) - - original_block_table.allocate(token_ids=token_ids, device=Device.GPU) - - # Allocate lookahead slots. - original_block_table.ensure_num_empty_slots(lookahead_slots) - original_block_ids = original_block_table.physical_block_ids[:] - - forked_block_table = original_block_table.fork() - - if appender == "forked": - appender_block_table = forked_block_table - static_block_table = original_block_table - elif appender == "original": - appender_block_table = original_block_table - static_block_table = forked_block_table - else: - raise ValueError(f"unknown test config {appender=}") - - # Write tokens. - appender_block_table.append_token_ids(token_ids_to_append) - - # Expect the non-appending block table to have no change. - assert static_block_table.physical_block_ids == original_block_ids - assert appender_block_table.physical_block_ids != original_block_ids - - cows = allocator.clear_copy_on_writes() - - # Always expect copy-on-write - assert cows - - if sequence_len % block_size > 0: - # If the last block in the sequence is not full, then when appending we - # expect a CoW. - assert cows - - cow_block_id = sequence_len // block_size - expected_src = static_block_table.physical_block_ids[cow_block_id] - expected_dst = appender_block_table.physical_block_ids[cow_block_id] - - assert (expected_src, expected_dst) in cows - - static_block_table.free() - appender_block_table.free() - - # After free, expect all blocks to be freed. - assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks - - -@pytest.mark.parametrize("block_size", [1, 8]) -@pytest.mark.parametrize("sequence_len", [1, 16, 129]) -@pytest.mark.parametrize("num_new_tokens", [1, 16, 129]) -@pytest.mark.parametrize("num_lookahead_slots", [1, 7, 8]) -@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_num_blocks_touched_by_append_slots(block_size: int, sequence_len: int, - num_new_tokens: int, - num_lookahead_slots: int, - allocator_type: str): - """Verify correct calculation of get_num_blocks_touched_by_append_slots. - - This is done by using copy-on-write, which requires any modified block to - be copied before write if the refcount > 1. We set the refcount>1 by forking - a sequence, then measure the free blocks before and after an append. If the - number of consumed blocks equals what `get_num_blocks_touched_by_append_ - slots` returns, then the calculation is correct. - """ - - num_gpu_blocks = 1024 - - allocator = CpuGpuBlockAllocator.create( - allocator_type=allocator_type, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=0, - block_size=block_size, - ) - - token_ids = list(range(sequence_len)) - token_ids_to_append = list(range(num_new_tokens)) - - block_table = BlockTable( - block_size=block_size, - block_allocator=allocator, - ) - - block_table.allocate(token_ids=token_ids, device=Device.GPU) - - # Add lookahead before fork so both sequences have the same lookahead - # blocks. - block_table.ensure_num_empty_slots(num_empty_slots=num_lookahead_slots) - - # Fork sequence so that every block has refcount > 1. 
- _ = block_table.fork() - - # Determine how many blocks should be touched. - expected_num_touched_blocks = ( - block_table.get_num_blocks_touched_by_append_slots( - token_ids=token_ids_to_append, - num_lookahead_slots=num_lookahead_slots)) - - # Measure how many blocks are touched by measuring num_free_blocks before - # and after the append. - # - # We expect append_token_ids to CoW all mutated blocks that have refcount>1. - num_free_blocks_before_append = allocator.get_num_free_blocks(Device.GPU) - block_table.append_token_ids(token_ids_to_append, num_lookahead_slots) - num_consumed_blocks = (num_free_blocks_before_append - - allocator.get_num_free_blocks(Device.GPU)) - - # TODO(cade) ensure equality when num_lookahead_slots > 0. - # The reason we have < is because lookahead blocks are not copied eagerly; - # they are copied on first write. This will cause issues for beam search + - # speculative decoding. This is acceptable for now as it is a large effort - # to combine the two. To fix this, we can ensure single sequence ownership - # of lookahead blocks by appending empty slots to each block, which will - # trigger the CoW. - # - # Until then, we can accept that the consumed tokens are <= the expected - # tokens when appending with lookahead. - if num_lookahead_slots > 0: - assert num_consumed_blocks <= expected_num_touched_blocks - else: - assert num_consumed_blocks == expected_num_touched_blocks diff --git a/tests/core/block/test_common.py b/tests/core/block/test_common.py deleted file mode 100644 index 65400899b811c..0000000000000 --- a/tests/core/block/test_common.py +++ /dev/null @@ -1,45 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import random - -import pytest - -from vllm.core.block.common import RefCounter - - -@pytest.mark.parametrize("seed", list(range(20))) -@pytest.mark.parametrize("num_incrs", [1, 100]) -@pytest.mark.parametrize("num_blocks", [1024]) -def test_incr(seed: int, num_incrs: int, num_blocks: int): - random.seed(seed) - - all_block_indices = list(range(num_blocks)) - counter = RefCounter(all_block_indices=all_block_indices) - - block_id = random.randint(0, num_blocks - 1) - for i in range(num_incrs): - value = counter.incr(block_id) - assert value == i + 1 - - -@pytest.mark.parametrize("seed", list(range(20))) -@pytest.mark.parametrize("num_incrs", [1, 100]) -@pytest.mark.parametrize("num_blocks", [1024]) -def test_incr_decr(seed: int, num_incrs: int, num_blocks: int): - random.seed(seed) - - all_block_indices = list(range(num_blocks)) - counter = RefCounter(all_block_indices=all_block_indices) - - block_id = random.randint(0, num_blocks - 1) - for i in range(num_incrs): - value = counter.incr(block_id) - assert value == i + 1 - - for i in range(num_incrs): - value = counter.decr(block_id) - assert value == num_incrs - (i + 1) - - with pytest.raises(AssertionError): - counter.decr(block_id) diff --git a/tests/core/block/test_cpu_gpu_block_allocator.py b/tests/core/block/test_cpu_gpu_block_allocator.py deleted file mode 100644 index 795eef6743fd1..0000000000000 --- a/tests/core/block/test_cpu_gpu_block_allocator.py +++ /dev/null @@ -1,96 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator -from vllm.utils import Device, chunk_list - - -@pytest.mark.parametrize("num_cpu_blocks", [0, 512]) -@pytest.mark.parametrize("num_gpu_blocks", 
[1024]) -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_allocate_mutable_block(num_cpu_blocks: int, num_gpu_blocks: int, - block_size: int, allocator_type: str): - allocator = CpuGpuBlockAllocator.create( - allocator_type=allocator_type, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks, - block_size=block_size, - ) - - assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks - assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks - - cpu_blocks = [ - allocator.allocate_mutable_block(prev_block=None, device=Device.CPU) - for _ in range(num_cpu_blocks) - ] - assert allocator.get_num_free_blocks(Device.CPU) == 0 - assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks - - gpu_blocks = [ - allocator.allocate_mutable_block(prev_block=None, device=Device.GPU) - for _ in range(num_gpu_blocks) - ] - assert allocator.get_num_free_blocks(Device.CPU) == 0 - assert allocator.get_num_free_blocks(Device.GPU) == 0 - - _ = [allocator.free(block) for block in cpu_blocks] - assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks - assert allocator.get_num_free_blocks(Device.GPU) == 0 - - _ = [allocator.free(block) for block in gpu_blocks] - assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks - assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks - - -@pytest.mark.parametrize("num_cpu_blocks", [0, 512]) -@pytest.mark.parametrize("num_gpu_blocks", [1024]) -@pytest.mark.parametrize("block_size", [2]) -@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) -def test_allocate_immutable_block(num_cpu_blocks: int, num_gpu_blocks: int, - block_size: int, allocator_type: str): - allocator = CpuGpuBlockAllocator.create( - allocator_type=allocator_type, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks, - block_size=block_size, - ) - - unique_token_ids = list( - range((num_cpu_blocks + num_gpu_blocks) * block_size)) - gpu_token_ids = list( - chunk_list(unique_token_ids[:num_gpu_blocks * block_size], block_size)) - cpu_token_ids = list( - chunk_list(unique_token_ids[num_gpu_blocks * block_size:], block_size)) - - assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks - assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks - - cpu_blocks = [ - allocator.allocate_immutable_block(prev_block=None, - token_ids=token_ids, - device=Device.CPU) - for token_ids in cpu_token_ids - ] - assert allocator.get_num_free_blocks(Device.CPU) == 0 - assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks - - gpu_blocks = [ - allocator.allocate_immutable_block(prev_block=None, - token_ids=token_ids, - device=Device.GPU) - for token_ids in gpu_token_ids - ] - assert allocator.get_num_free_blocks(Device.CPU) == 0 - assert allocator.get_num_free_blocks(Device.GPU) == 0 - - _ = [allocator.free(block) for block in cpu_blocks] - assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks - assert allocator.get_num_free_blocks(Device.GPU) == 0 - - _ = [allocator.free(block) for block in gpu_blocks] - assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks - assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks diff --git a/tests/core/block/test_naive_block.py b/tests/core/block/test_naive_block.py deleted file mode 100644 index a31d1c46b37f0..0000000000000 --- a/tests/core/block/test_naive_block.py +++ /dev/null @@ -1,148 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Optional - -import pytest - -from vllm.core.block.interfaces import Block, BlockAllocator -from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator - - -class TestNaiveBlockAllocator: - - @staticmethod - def create_allocate_lambda(allocate_type: str, - allocator: NaiveBlockAllocator, - prev_block: Optional[Block], - token_ids: list[int]): - if allocate_type == "immutable": - allocate_block = lambda: allocator.allocate_immutable_block( - prev_block=prev_block, token_ids=token_ids) - elif allocate_type == "mutable": - allocate_block = lambda: allocator.allocate_mutable_block( - prev_block=prev_block) - else: - raise ValueError() - - return allocate_block - - @staticmethod - @pytest.mark.parametrize("allocate_type", ["immutable", "mutable"]) - @pytest.mark.parametrize("num_blocks", [1, 1024]) - @pytest.mark.parametrize("block_size", [1, 16]) - def test_allocate_ooms(allocate_type: str, num_blocks: int, - block_size: int): - allocator = NaiveBlockAllocator(create_block=NaiveBlock, - num_blocks=num_blocks, - block_size=block_size) - allocate_block = TestNaiveBlockAllocator.create_allocate_lambda( - allocate_type, - allocator, - prev_block=None, - token_ids=list(range(block_size))) - - [allocate_block() for _ in range(num_blocks)] - with pytest.raises(BlockAllocator.NoFreeBlocksError): - allocate_block() - - @staticmethod - @pytest.mark.parametrize("allocate_type", ["immutable", "mutable"]) - @pytest.mark.parametrize("num_blocks", [1, 1024]) - @pytest.mark.parametrize("block_size", [1, 16]) - def test_free_prevents_oom(allocate_type: str, num_blocks: int, - block_size: int): - allocator = NaiveBlockAllocator(create_block=NaiveBlock, - num_blocks=num_blocks, - block_size=block_size) - allocate_block = TestNaiveBlockAllocator.create_allocate_lambda( - allocate_type, - allocator, - prev_block=None, - token_ids=list(range(block_size))) - - blocks = [allocate_block() for _ in range(num_blocks)] - - with pytest.raises(BlockAllocator.NoFreeBlocksError): - allocate_block() - - block_to_free = blocks.pop() - - for _ in range(100): - block_id = block_to_free.block_id - allocator.free(block_to_free) - assert block_to_free.block_id is None - - new_block = allocate_block() - assert new_block.block_id == block_id - - with pytest.raises(BlockAllocator.NoFreeBlocksError): - allocate_block() - - block_to_free = new_block - - @staticmethod - @pytest.mark.parametrize("allocate_type", ["immutable", "mutable"]) - @pytest.mark.parametrize("num_blocks", [1024]) - @pytest.mark.parametrize("block_size", [16]) - def test_get_num_free_blocks(allocate_type: str, num_blocks: int, - block_size: int): - allocator = NaiveBlockAllocator(create_block=NaiveBlock, - num_blocks=num_blocks, - block_size=block_size) - allocate_block = TestNaiveBlockAllocator.create_allocate_lambda( - allocate_type, - allocator, - prev_block=None, - token_ids=list(range(block_size))) - - assert allocator.get_num_free_blocks() == num_blocks - - blocks = [allocate_block() for _ in range(num_blocks)] - - for i, block in enumerate(blocks): - assert allocator.get_num_free_blocks() == i - allocator.free(block) - - @staticmethod - @pytest.mark.parametrize("num_blocks", [4]) - @pytest.mark.parametrize("block_size", [8]) - def test_naive_block_get_num_full_blocks_touched(num_blocks, block_size): - """ Verify the allocator can correctly return the number of - full blocks touched. 
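For example, with the parametrized num_blocks=4 and block_size=8: after three full immutable source blocks are created the count is 3; appending a fourth block holding a single token leaves it at 3, and only once that block is filled to all 8 slots does the count rise to 4.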
- """ - allocator_src = NaiveBlockAllocator(create_block=NaiveBlock, - num_blocks=num_blocks, - block_size=block_size) - allocator_dst = NaiveBlockAllocator(create_block=NaiveBlock, - num_blocks=num_blocks, - block_size=block_size) - - # Create a chain of cacheable blocks in the dst - allocate_block = TestNaiveBlockAllocator.create_allocate_lambda( - "immutable", - allocator_src, - prev_block=None, - token_ids=list(range(block_size))) - src_blocks = [allocate_block() for _ in range(num_blocks - 1)] - - # All blocks are cached - assert allocator_dst.get_num_full_blocks_touched( - src_blocks) == num_blocks - 1 - - # Insert one non-full block in the src - allocate_non_full_block = \ - TestNaiveBlockAllocator.create_allocate_lambda( - "mutable", allocator_src, - prev_block=src_blocks[-1],token_ids=[] - ) - src_blocks.append(allocate_non_full_block()) - src_blocks[-1].append_token_ids([0]) - - assert allocator_dst.get_num_full_blocks_touched( - src_blocks) == num_blocks - 1 - # Fill up the last source block and then invoke - # get_num_blocks_touched - src_blocks[-1].append_token_ids([0] * (block_size - 1)) - assert allocator_dst.get_num_full_blocks_touched( - src_blocks) == num_blocks diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py deleted file mode 100644 index 46e224c6f53b2..0000000000000 --- a/tests/core/block/test_prefix_caching_block.py +++ /dev/null @@ -1,1035 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import math -import random -from typing import Optional -from unittest.mock import MagicMock - -import pytest - -from tests.core.utils import create_dummy_lora_sequence, create_dummy_sequence -from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator -from vllm.core.block.interfaces import Block, BlockAllocator -from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, - PrefixCachingBlock, - PrefixCachingBlockAllocator) -from vllm.sequence import Logprob -from vllm.utils import Device - - -class TestPrefixCachingBlock: - - @staticmethod - @pytest.mark.parametrize("seed", list(range(10))) - @pytest.mark.parametrize("block_size", [1, 16]) - @pytest.mark.parametrize("is_curr_block_full", [True, False]) - def test_first_block_has_correct_content_hash(seed: int, block_size: int, - is_curr_block_full: bool): - """Verify a block which is first in the sequence has the correct hash. - """ - random.seed(seed) - num_to_fill = block_size if is_curr_block_full else random.randint( - 0, block_size - 1) - token_ids = list(range(num_to_fill)) - mock_allocator = MagicMock(spec=PrefixCachingBlockAllocator) - - block_with_prev = PrefixCachingBlock(prev_block=None, - token_ids=token_ids, - block_size=block_size, - allocator=mock_allocator) - - if is_curr_block_full: - # Expect hash since block is full. - assert block_with_prev.content_hash == ( - PrefixCachingBlock.hash_block_tokens( - is_first_block=True, - prev_block_hash=None, - cur_block_token_ids=token_ids)) - else: - # Do not expect hash since block is not full. 
- assert block_with_prev.content_hash is None - - @staticmethod - @pytest.mark.parametrize("seed", list(range(10))) - @pytest.mark.parametrize("block_size", [1, 16]) - @pytest.mark.parametrize("is_curr_block_full", [True, False]) - @pytest.mark.parametrize("prev_block_has_hash", [True, False]) - def test_nth_block_has_correct_content_hash(seed: int, block_size: int, - is_curr_block_full: bool, - prev_block_has_hash: bool): - """Verify a block which is not first in the sequence has the correct - hash. - """ - - random.seed(seed) - - previous_block = MagicMock(spec=PrefixCachingBlock) - prev_block_hash = random.randint(0, 1000) - previous_block.content_hash = (prev_block_hash if prev_block_has_hash - else hash('None')) - - num_to_fill = block_size if is_curr_block_full else random.randint( - 0, block_size - 1) - token_ids = list(range(num_to_fill)) - mock_allocator = MagicMock(spec=PrefixCachingBlockAllocator) - - block_with_prev = PrefixCachingBlock( - prev_block=previous_block, - token_ids=token_ids, - block_size=block_size, - allocator=mock_allocator, - ) - - if is_curr_block_full and prev_block_has_hash: - # Expect hash since block is full and previous block has hash. - assert (block_with_prev.content_hash == - PrefixCachingBlock.hash_block_tokens( - is_first_block=False, - prev_block_hash=prev_block_hash, - cur_block_token_ids=token_ids)) - else: - # Do not expect hash since block is not full or the previous block - # does not have a hash. - assert block_with_prev.content_hash is None - - @staticmethod - @pytest.mark.parametrize("block_size", [1, 2, 16]) - @pytest.mark.parametrize("num_tokens", list(range(3))) - @pytest.mark.parametrize("num_empty_trailing_blocks", [0, 1, 10]) - def test_blocks_have_correct_hash_in_chain(block_size: int, - num_tokens: int, - num_empty_trailing_blocks: int): - """Create two chains of logical blocks with the same contents. - Assert the hashes are equal. - """ - random.seed(0) - - token_ids = [random.randint(0, 50_000) for _ in range(num_tokens)] - - first_chain, second_chain = (TestPrefixCachingBlock.create_chain( - block_size=block_size, - token_ids=token_ids, - num_empty_trailing_blocks=num_empty_trailing_blocks) - for _ in range(2)) - - for first_chain_block, second_chain_block in zip( - first_chain, second_chain): - assert (first_chain_block.content_hash == - second_chain_block.content_hash) - - if not first_chain or not second_chain: - assert first_chain == second_chain - assert num_tokens == 0 - - @staticmethod - def create_chain(block_size: int, - token_ids: list[int], - num_empty_trailing_blocks=0) -> list[PrefixCachingBlock]: - """Helper method which creates a chain of blocks. 
- """ - blocks: list[PrefixCachingBlock] = [] - num_blocks = math.ceil( - len(token_ids) / block_size) + num_empty_trailing_blocks - - if num_blocks == 0: - return [] - - allocator = MagicMock(spec=PrefixCachingBlockAllocator) - - prev_block = None - for block_number in range(0, num_blocks): - prev_block = PrefixCachingBlock( - prev_block=prev_block, - token_ids=[], - block_size=block_size, - allocator=allocator, - ) - - tokens_to_append = token_ids[block_number * - block_size:(block_number + 1) * - block_size] - if tokens_to_append: - prev_block.append_token_ids(tokens_to_append) - - blocks.append(prev_block) - - return blocks - - -class TestPrefixCachingBlockAllocator: - - @staticmethod - def create_allocate_lambda(allocate_type: str, allocator: BlockAllocator, - prev_block: Optional[Block], - token_ids: list[int]): - if allocate_type == "immutable": - allocate_block = lambda: allocator.allocate_immutable_block( - prev_block=prev_block, token_ids=token_ids) - elif allocate_type == "mutable": - allocate_block = lambda: allocator.allocate_mutable_block( - prev_block=prev_block) - else: - raise ValueError() - - return allocate_block - - @staticmethod - @pytest.mark.parametrize("num_blocks", [1, 1024]) - @pytest.mark.parametrize("block_size", [1, 16]) - def test_allocate_mutable_ooms(num_blocks: int, block_size: int): - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - allocate_block = TestPrefixCachingBlockAllocator.create_allocate_lambda( - allocate_type="mutable", - allocator=allocator, - prev_block=None, - token_ids=list(range(block_size)), - ) - - [allocate_block() for _ in range(num_blocks)] - with pytest.raises(BlockAllocator.NoFreeBlocksError): - allocate_block() - - @staticmethod - @pytest.mark.parametrize("num_blocks", [1, 1024]) - @pytest.mark.parametrize("block_size", [1, 16]) - def test_allocate_immutable_does_not_oom_single_hash( - num_blocks: int, block_size: int): - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - allocate_block = TestPrefixCachingBlockAllocator.create_allocate_lambda( - allocate_type="immutable", - allocator=allocator, - prev_block=None, - token_ids=list(range(block_size)), - ) - - blocks = [allocate_block() for _ in range(num_blocks)] - - # Expect no OOM. If these were mutable blocks, this would OOM. - non_oom_block = allocate_block() - - # Expect all blocks to have same physical block index. - for block in blocks: - assert (block.block_id == non_oom_block.block_id) - - @staticmethod - @pytest.mark.parametrize("num_blocks", [1, 1024]) - @pytest.mark.parametrize("block_size", [1, 16]) - def test_allocate_immutable_ooms_many_hash(num_blocks: int, - block_size: int): - """Consume all blocks using many different hashes/block content. - - Do this by creating a sequence that is very long. - Expect next block to OOM. - """ - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - - # Create token ids that will exhaust all blocks. - token_ids = list(range(num_blocks * block_size)) - - chain = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids, - allocator=allocator, - ) - - # Expect allocation with unseen hash to fail. - with pytest.raises(BlockAllocator.NoFreeBlocksError): - allocator.allocate_immutable_block(prev_block=chain[-1], - token_ids=list( - range(block_size))) - - # Expect mutable allocation to fail. 
- with pytest.raises(BlockAllocator.NoFreeBlocksError): - allocator.allocate_mutable_block(prev_block=chain[-1]) - - # Expect allocation of exact same chain to pass. - second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids, - allocator=allocator, - ) - - # Expect physical block indices to be the same in both chains. - assert chain and second_chain - for first_chain_block, second_chain_block in zip(chain, second_chain): - assert (first_chain_block.block_id == second_chain_block.block_id) - - @staticmethod - @pytest.mark.parametrize("num_blocks", [1, 1024]) - @pytest.mark.parametrize("block_size", [1, 16]) - def test_free_prevents_oom(num_blocks: int, block_size: int): - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - - # Create token ids that will exhaust all blocks. - token_ids = list(range(num_blocks * block_size)) - - chain = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids, - allocator=allocator, - ) - - # Expect mutable allocation to fail. - with pytest.raises(BlockAllocator.NoFreeBlocksError): - allocator.allocate_mutable_block(prev_block=None) - - block_to_free = chain[-1] - - # Expect free/allocate loop to succeed many times. - for i in range(100): - block_id = block_to_free.block_id - allocator.free(block_to_free) - assert block_to_free.block_id is None, i - - new_block = allocator.allocate_mutable_block(prev_block=None) - assert new_block.block_id == block_id, i - - with pytest.raises(BlockAllocator.NoFreeBlocksError): - allocator.allocate_mutable_block(prev_block=None) - - block_to_free = new_block - - @staticmethod - @pytest.mark.parametrize("num_blocks", [1024]) - @pytest.mark.parametrize("block_size", [16]) - @pytest.mark.parametrize("seed", list(range(20))) - def test_get_num_free_blocks(num_blocks: int, block_size: int, seed: int): - random.seed(seed) - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - num_blocks_to_consume = random.randint(1, num_blocks - 1) - - # Create token ids that will exhaust all blocks. - token_ids = list(range(num_blocks_to_consume * block_size)) - - chain = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids, - allocator=allocator, - ) - - # Free each block in chain, assert num free blocks includes new free - # block. - for i, block in enumerate(chain): - assert allocator.get_num_free_blocks() == (num_blocks - - num_blocks_to_consume + - i) - allocator.free(block) - - @staticmethod - @pytest.mark.parametrize("num_blocks", [4]) - @pytest.mark.parametrize("block_size", [8]) - def test_prefix_caching_block_get_num_full_blocks_touched( - num_blocks, block_size): - """ Verify the allocator can correctly return the number of - blocks touched, when there are cached prefixes. 
- """ - allocator_src = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - allocator_dst = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - - # Create token ids that will exhaust all blocks except the last - token_ids = list(range((num_blocks - 1) * block_size)) - - # Create a chain of cacheable blocks in the dst - cached_blocks = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids, - allocator=allocator_dst, - ) - - # Create a chain of the same blocks in the src - blocks_to_swap_in = \ - TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids, - allocator=allocator_src, - ) - # All blocks are cached - assert allocator_dst.get_num_full_blocks_touched( - blocks_to_swap_in) == 0 - - # Free the first block in the dst - allocator_dst.free(cached_blocks[0]) - - # Now the first block becomes dangling, the swapped blocks need - # to reclaim the first block in the dst - assert allocator_dst.get_num_full_blocks_touched( - blocks_to_swap_in) == 1 - - # Insert one non-full block in the src - non_full_block = allocator_src.allocate_mutable_block( - blocks_to_swap_in[-1]) - non_full_block.append_token_ids([0]) - blocks_to_swap_in.append(non_full_block) - assert allocator_dst.get_num_full_blocks_touched( - blocks_to_swap_in) == 1 - # Fill up the last mutable block and invoke get_num_blocks_touched. - # Note: The last block is not cached so it will be touched. - non_full_block.append_token_ids([0] * (block_size - 1)) - assert allocator_dst.get_num_full_blocks_touched( - blocks_to_swap_in) == 2 - - @staticmethod - @pytest.mark.parametrize("num_blocks", [1024]) - @pytest.mark.parametrize("block_size", [16]) - @pytest.mark.parametrize("seed", list(range(20))) - def test_get_num_free_blocks_shared(num_blocks: int, block_size: int, - seed: int): - """Verify sharing occurs by allocating two sequences that share prefixes - and incrementally freeing blocks. - """ - random.seed(seed) - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - num_blocks_to_consume = random.randint(1, num_blocks - 1) - - # Create token ids that will exhaust all blocks. - token_ids = list(range(num_blocks_to_consume * block_size)) - - first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids, - allocator=allocator, - ) - second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids, - allocator=allocator, - ) - - # Free each block in the first chain. Since all blocks are shared, the - # free count should stay constant. - for i, block in enumerate(first_chain): - assert allocator.get_num_free_blocks() == (num_blocks - - num_blocks_to_consume) - allocator.free(block) - - # Free each block in the second chain. Since the refcount is now zero, - # the free count should increment with each free. 
- for i, block in enumerate(second_chain): - assert allocator.get_num_free_blocks() == (num_blocks - - num_blocks_to_consume + - i) - allocator.free(block) - - @staticmethod - @pytest.mark.parametrize("num_blocks", [1024]) - @pytest.mark.parametrize("block_size", [16]) - @pytest.mark.parametrize("seed", list(range(20))) - def test_get_common_computed_block_ids(num_blocks: int, block_size: int, - seed: int): - """Verify get_common_computed_block_ids could get correct result - by create two immutable chain sharing prefix at specified pos, - and compare whether we also could get right result - from get_common_computed_block_ids. - """ - random.seed(seed) - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks * 2, - block_size=block_size) - num_blocks_to_consume = random.randint(1, num_blocks - 1) - - # Create token ids that will exhaust all blocks. - token_ids = list(range(num_blocks_to_consume * block_size)) - - first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids, - allocator=allocator, - ) - - # After zero_point, second_chain's token_ids would be set -1, which - # make it different from here comparing with first_chain - zero_point = random.randint(1, len(token_ids) - 1) - zero_point_blocks = zero_point // block_size - token_ids[zero_point:] = [-1] * (len(token_ids) - zero_point) - - second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids, - allocator=allocator, - ) - - first_computed_ids = [ - first_chain[i].block_id for i in range(num_blocks_to_consume) - ] - second_computed_ids = [ - second_chain[i].block_id for i in range(num_blocks_to_consume) - ] - res = allocator.get_common_computed_block_ids( - [first_computed_ids, second_computed_ids]) - - assert (len(res) == zero_point_blocks) - - # Test case that assume those prompted block after first immutable would - # be freed into hashless allocator, while first immutable block get ref - # increased. 
- @staticmethod - @pytest.mark.parametrize("num_blocks", [3]) - @pytest.mark.parametrize("block_size", [16]) - @pytest.mark.parametrize("seed", list(range(10))) - def test_alloc_promotion(num_blocks: int, block_size: int, seed: int): - random.seed(seed) - - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - token_ids = list(range(block_size)) - - block = allocator.allocate_immutable_block(prev_block=None, - token_ids=token_ids) - - assert allocator._refcounter.get(block.block_id) == 1 - m = allocator.allocate_mutable_block(prev_block=None) - - block_id = m.block_id - for i in range(block_size): - m.append_token_ids([i]) - - # After block get promoted to immutable from mutable, if there is - # already same content hash block, then it shall be released into - # hashless_allocator - # And first immutable block's ref get increased by 1 - assert m.block_id == block.block_id - assert block_id in allocator._hashless_allocator._free_block_indices - assert allocator._refcounter.get(block.block_id) == 2 - - # Test case when eviction and allocation are mixed, - # make sure they work as expected - @staticmethod - @pytest.mark.parametrize("num_blocks", [3]) - @pytest.mark.parametrize("block_size", [16]) - @pytest.mark.parametrize("seed", list(range(10))) - def test_eviction_alloc_mixed(num_blocks: int, block_size: int, seed: int): - random.seed(seed) - - all_blocks_list = [i for i in range(num_blocks)] - zero_ref = {i: 0 for i in range(num_blocks)} - one_ref = {i: 1 for i in range(num_blocks)} - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - token_ids = list(range(num_blocks * block_size)) - - # Verify initial/pre-alloc state - - # Ensure all blocks are free inside hashless allocator - assert list(allocator._hashless_allocator._free_block_indices - ) == all_blocks_list - # Ensure no tracked blocks - assert len(allocator._block_tracker.keys()) == num_blocks - for block_id in range(num_blocks): - assert not allocator._block_tracker[block_id].active - # Ensure no cached blocks - assert len(allocator._cached_blocks.values()) == 0 - # Ensure no evicted blocks - assert len(allocator.evictor.free_table.keys()) == 0 - # Ensure 0s ref counts for all blocks - assert allocator._refcounter._refcounts == zero_ref - - # Allocate immutable chains with only one block residuled in - new_block = [] - for i in range(num_blocks): - block = allocator.allocate_immutable_block( - prev_block=None, - token_ids=token_ids[block_size * i:block_size * (i + 1)]) - new_block.append(block) - - # Verify post-alloc state - - # Ensure no blocks are free inside hashless allocator - assert (len(allocator._hashless_allocator._free_block_indices) == 0) - # Ensure all blocks are tracked - assert len(allocator._block_tracker.keys()) == num_blocks - for block_id in range(num_blocks): - assert allocator._block_tracker[block_id].active - # Ensure all blocks are cached (all promoted) - assert len(allocator._cached_blocks.values()) == num_blocks - # Ensure no evicted blocks - assert len(allocator.evictor.free_table.keys()) == 0 - # Ensure 1s ref counts for all blocks - assert allocator._refcounter._refcounts == one_ref - - # Free all blocks, and now all blocks shall be in the evictor - # there shall be no tracking data left in _block_tracker - # all blocks shall be tracked in _cached_blocks - # all blocks' ref shall be zero - for block in new_block: - allocator.free(block) - - # Verify post-free state - - # Ensure no tracked blocks - assert 
len(allocator._block_tracker.keys()) == num_blocks - for block_id in range(num_blocks): - assert not allocator._block_tracker[block_id].active - # Ensure no blocks in hashless allocator (all promoted) - assert len(allocator._hashless_allocator._free_block_indices) == 0 - # Ensure all blocks are cached - assert list(allocator._cached_blocks.values()) == all_blocks_list - # Ensure all blocks are inside the evictor - assert list(allocator.evictor.free_table.keys()) == all_blocks_list - # Ensure 0s refcounts - assert allocator._refcounter._refcounts == zero_ref - - # Allocate a mutable block, and the first block shall be evicted - # and set its content hash into None, ref to 1 - mutable = allocator.allocate_mutable_block(prev_block=None) - - assert mutable.block_id == 0 - assert mutable.content_hash is None - assert allocator._block_tracker[0].active - assert allocator._refcounter.get(0) == 1 - assert 0 not in allocator._cached_blocks - assert 0 not in allocator.evictor - - # Since this mutable block has no hash yet, it shall be released into - # hashless allocator - allocator.free(mutable) - - assert not allocator._block_tracker[0].active - assert allocator._refcounter._refcounts == zero_ref - assert 0 not in allocator._cached_blocks - assert 0 not in allocator.evictor - assert 0 in allocator._hashless_allocator._free_block_indices - - # When allocate immutable with first block_size tokens, we - # shall get free block from hashless allocator, thus no block left - # in hashless - block = allocator.allocate_immutable_block( - prev_block=None, token_ids=token_ids[:block_size]) - - assert block.block_id == 0 - assert len(allocator._hashless_allocator._free_block_indices) == 0 - assert allocator._block_tracker[0].active - assert 0 in allocator._cached_blocks.values() - assert allocator._refcounter.get(0) == 1 - assert 0 not in allocator.evictor - - # allocate mutable block again, it shall be popped from evictor - mutable = allocator.allocate_mutable_block(prev_block=None) - assert len(allocator._hashless_allocator._free_block_indices) == 0 - assert mutable.block_id not in allocator.evictor.free_table - assert allocator._refcounter.get(mutable.block_id) == 1 - - # Test case where two last accessed times are equal - @staticmethod - @pytest.mark.parametrize("num_blocks", [1024]) - @pytest.mark.parametrize("block_size", [16]) - @pytest.mark.parametrize("seed", list(range(20))) - def test_eviction_order(num_blocks: int, block_size: int, seed: int): - """This test case simulate the two chain created and free in order, - and together they would exhaust the initial freed blocks. - - So the next block created after those two chain shall use the block - from the first chain as that block has long access time. - While first chain has two blocks, it shall pick up the last one, as - it has larger token number. 
- """ - - random.seed(seed) - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - num_blocks_to_consume = num_blocks + 1 - - token_ids = list(range(num_blocks_to_consume * block_size)) - - num_blocks_in_first_chain = 2 - num_tokens_in_first_chain = block_size * num_blocks_in_first_chain - # First chain takes the first block - first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids[:num_tokens_in_first_chain], - allocator=allocator, - ) - # There should only be one block allocated at this point - assert allocator.get_num_free_blocks() == (num_blocks - - num_blocks_in_first_chain) - - # Set the last accessed time of the first block to 1 - blocks_ids = [block.block_id for block in first_chain] - allocator.mark_blocks_as_accessed(blocks_ids, 1) - - # Second chain takes the rest of the blocks - second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids[num_tokens_in_first_chain:-block_size], - allocator=allocator, - ) - - # There shouldn't be any blocks left at this point - assert allocator.get_num_free_blocks() == (0) - - assert len(first_chain) == num_blocks_in_first_chain - last_block_id = first_chain[-1].block_id - # Free each block in the first chain. - for i, block in enumerate(first_chain): - allocator.free(block) - - # Set the last accessed time on all of the blocks in the second chain - # to 2 - blocks_ids = [block.block_id for block in second_chain] - allocator.mark_blocks_as_accessed(blocks_ids, 2) - - # Free each block in the second chain. - for i, block in enumerate(second_chain): - allocator.free(block) - - # Allocate a new block and check that it's the least recently used block - # from the first chain. - new_block = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids[-block_size:], - allocator=allocator, - ) - - assert new_block[0].block_id == last_block_id - - # Test case for cache mertics - @staticmethod - def test_metric(): - block_size = 16 - allocator = PrefixCachingBlockAllocator(num_blocks=4, - block_size=block_size) - # Test when no query (0/0) - assert allocator.get_prefix_cache_hit_rate() == 0.0 - - token_ids = list(range(block_size)) - allocator.allocate_immutable_block(prev_block=None, - token_ids=token_ids) - # Test 0/1 hit rate - assert allocator.get_prefix_cache_hit_rate() == 0.0 - - allocator.allocate_immutable_block(prev_block=None, - token_ids=token_ids) - # Test 1/2 hit rate - assert allocator.get_prefix_cache_hit_rate() == 0.5 - - # Test more than one block - for _ in range(2, 1005): - allocator.allocate_immutable_block(prev_block=None, - token_ids=token_ids) - assert allocator.get_prefix_cache_hit_rate() > 0.99 - - # Test case for marking cache hit blocks as computed right after - # a batch of prefill sequences are scheduled. - @staticmethod - def test_touch_block(): - block_size = 16 - common_blocks = 4 - allocator = PrefixCachingBlockAllocator(num_blocks=8, - block_size=block_size) - - common_token_ids = list(range(block_size * common_blocks)) - - # Mimic the behavior of allocating the same block chain - # (i.e., common prefix) for a batch of 3 different prefill sequences. 
- for _ in range(3): - blocks = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=common_token_ids, - allocator=allocator, - ) - block_hashes = [block.content_hash for block in blocks] - # The allocated blocks should be marked as touched - # but not computed. - computed_block_ids = allocator.find_cached_blocks_prefix( - block_hashes) - assert len(computed_block_ids) == 0 - - allocator.mark_blocks_as_computed([]) - computed_block_ids = allocator.find_cached_blocks_prefix( - block_hashes=block_hashes) - assert len(computed_block_ids) == common_blocks - - @staticmethod - def test_find_cached_blocks_prefix(): - """ - This test verifies the behavior of find_cached_blocks_prefix. - """ - block_size = 4 - num_blocks = 8 - total_test_blocks = 12 - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - - token_ids = list(range(total_test_blocks * block_size)) - block_tokens_seq1 = token_ids[:num_blocks * block_size] - blocks_seq1 = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=block_tokens_seq1, - allocator=allocator, - ) - block_hashes_seq1 = [block.content_hash for block in blocks_seq1] - allocator.mark_blocks_as_computed([]) - - # All blocks should be cached. - cached_blocks_seq1 = allocator.find_cached_blocks_prefix( - block_hashes=block_hashes_seq1) - assert len(cached_blocks_seq1) == num_blocks - - # Free the first sequence. - for block in blocks_seq1: - allocator.free(block) - - # All blocks should be still be cached if not required to be allocated. - cached_blocks = allocator.find_cached_blocks_prefix( - block_hashes=block_hashes_seq1) - assert len(cached_blocks) == num_blocks - - block_tokens_seq2 = token_ids[num_blocks * block_size:] - blocks_seq2 = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=block_tokens_seq2, - allocator=allocator, - ) - block_hashes_seq2 = [block.content_hash for block in blocks_seq2] - allocator.mark_blocks_as_computed([]) - cached_blocks = allocator.find_cached_blocks_prefix( - block_hashes=block_hashes_seq2) - assert len(cached_blocks) == len(blocks_seq2) - - # Half of the blocks from seq1 should still be cached. - num_evicted_blocks = len(blocks_seq2) - cached_blocks = allocator.find_cached_blocks_prefix( - block_hashes=block_hashes_seq1) - assert len(cached_blocks) == len(blocks_seq1) - num_evicted_blocks - - # Test reset prefix cache - @staticmethod - @pytest.mark.parametrize("num_blocks", [10]) - @pytest.mark.parametrize("block_size", [16]) - def test_reset_prefix_cache(num_blocks: int, block_size: int): - """This test case simulates the case of resetting the prefix cache.""" - - allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, - block_size=block_size) - token_ids = list(range(3 * block_size)) - - first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids, - allocator=allocator, - ) - second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=token_ids, - allocator=allocator, - ) - - # Free each block in the first chain. - for block in first_chain: - allocator.free(block) - - # Failed to reset prefix cache because some blocks are not freed yet. - assert not allocator.reset_prefix_cache() - assert allocator.get_prefix_cache_hit_rate() > 0.0 - - # Free each block in the second chain. - for block in second_chain: - allocator.free(block) - - # Reset prefix cache. 
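[Editor's note] The reset exercised just below only succeeds once every cached block has been freed; while any block is still referenced the call returns False and the cache is left untouched. A hedged sketch of that precondition with illustrative names:

def reset_prefix_cache(refcounts, cached_blocks):
    # Illustrative precondition: refuse the reset while any block is in use.
    if any(count > 0 for count in refcounts.values()):
        return False
    cached_blocks.clear()
    return True

refcounts = {0: 1, 1: 0}
cached_blocks = {"hash-0": 0, "hash-1": 1}
assert not reset_prefix_cache(refcounts, cached_blocks)   # block 0 still referenced
refcounts[0] = 0
assert reset_prefix_cache(refcounts, cached_blocks) and not cached_blocks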
- assert allocator.reset_prefix_cache() - assert allocator.get_prefix_cache_hit_rate() == 0.0 - - @staticmethod - def create_immutable_chain( - block_size: int, - token_ids: list[int], - allocator: PrefixCachingBlockAllocator, - extra_hash: Optional[int] = None, - ) -> list[PrefixCachingBlock]: - """Helper method which creates a chain of blocks. - """ - blocks: list[Block] = [] - num_blocks = math.ceil(len(token_ids) / block_size) - - if num_blocks == 0: - return [] - - prev_block = None - for block_number in range(0, num_blocks): - block_token_ids = token_ids[block_number * - block_size:(block_number + 1) * - block_size] - prev_block = allocator.allocate_immutable_block( - prev_block=prev_block, - token_ids=block_token_ids, - extra_hash=extra_hash) - blocks.append(prev_block) - - return blocks - - -class TestComputedBlocksTracker: - - @staticmethod - def _get_mock_allocator(): - return MagicMock(spec=PrefixCachingBlockAllocator) - - @staticmethod - def test_get_num_cached_tokens(): - """ - Test it correctly computes the number of cached tokens for a given - sequence: - - - The cache token count is derived from the number of cached blocks. - - The cache token count is updated when the allocator is updated. - - When a sequence is removed, the cache token count should be updated - accordingly. - - # TODO(rickyx): This behaviour for prefill sequence is a hack until - we fix the computed blocks tracking. - - The cache token count for prefill sequence doesn't change while - the sequence is in continuous prefill (chunked prefill). - """ - block_size = 4 - mock_allocator = TestComputedBlocksTracker._get_mock_allocator() - tracker = ComputedBlocksTracker( - allocator=mock_allocator, - block_size=block_size, - enable_caching=True, - ) - - # Not yet allocated. - tokens = [0, 1, 2, 3, 4, 5] - seq1 = create_dummy_sequence(request_id=0, - token_ids=tokens, - block_size=block_size) - mock_allocator.find_cached_blocks_prefix.return_value = [] - assert tracker.get_num_cached_tokens(seq1) == 0 - - mock_allocator.find_cached_blocks_prefix.return_value = [ - None - ] # 1 block cached. - # Result is cached for prefill sequence. - assert tracker.get_num_cached_tokens(seq1) == 0 - - # Mark the sequence as non-prefill. - seq1.data.update_num_computed_tokens(len(tokens)) # 6 tokens computed. - assert not seq1.is_prefill() - - # Recomputes for decoding sequence. - assert tracker.get_num_cached_tokens(seq1) == 4 - - # Append new tokens to the sequence. - num_new_tokens = 3 - for i in range(num_new_tokens): - seq1.append_token_id(i, {i: Logprob(logprob=0.0)}) - - assert tracker.get_num_cached_tokens(seq1) == 4 - - # Update the allocator. - mock_allocator.find_cached_blocks_prefix.return_value = [ - None - ] * 2 # 2 blocks cached. - assert tracker.get_num_cached_tokens(seq1) == 8 - - # Remove the sequence. - tracker.remove_seq(seq1.seq_id) - - # Re-create the sequence with the same request id to simulate recompute. - seq1 = create_dummy_sequence(request_id=0, - token_ids=tokens, - block_size=block_size) - mock_allocator.find_cached_blocks_prefix.return_value = [ - ] # no cached block - assert tracker.get_num_cached_tokens(seq1) == 0 - - @staticmethod - def test_correct_block_hash(): - """ - Test that the block hash is correctly computed for a sequence (should - match the underlying block allocator's block hash). So the number of - cached tokens is correctly retrieved. 
- """ - block_size = 4 - allocator = CpuGpuBlockAllocator.create( - allocator_type="prefix_caching", - num_gpu_blocks=16, - num_cpu_blocks=16, - block_size=block_size, - ) - gpu_allocator = allocator._allocators[Device.GPU] - - tracker = ComputedBlocksTracker( - allocator=allocator, - block_size=block_size, - enable_caching=True, - ) - - tokens = list(range(block_size * 4)) # 4 blocks. - seq = create_dummy_sequence(request_id=0, - token_ids=tokens, - block_size=block_size) - _ = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=tokens, - allocator=gpu_allocator, - ) - allocator.mark_blocks_as_computed([]) - - assert tracker.get_num_cached_tokens(seq) == len(tokens) - - @staticmethod - def test_correct_extra_hash(): - """ - Test that the block hash is correctly computed based on the extra hash, - ensuring it matches the allocator's block hash, specifically for the - LoRA case, and that the correct number of cached tokens is retrieved. - """ - block_size = 4 - allocator = CpuGpuBlockAllocator.create( - allocator_type="prefix_caching", - num_gpu_blocks=16, - num_cpu_blocks=16, - block_size=block_size, - ) - gpu_allocator = allocator._allocators[Device.GPU] - - tracker = ComputedBlocksTracker( - allocator=allocator, - block_size=block_size, - enable_caching=True, - ) - - tokens = list(range(block_size * 4)) - - # Create a dummy LoRA sequence with a specific LoRA ID. - lora_seq = create_dummy_lora_sequence(request_id=0, - token_ids=tokens, - block_size=block_size, - lora_int_id=1) - - _ = TestPrefixCachingBlockAllocator.create_immutable_chain( - block_size=block_size, - token_ids=tokens, - allocator=gpu_allocator, - extra_hash=lora_seq.extra_hash(), - ) - - allocator.mark_blocks_as_computed([]) - - # Create different dummy sequences that have the same token IDs - # but different LoRA IDs. - seq = create_dummy_sequence(request_id=1, - token_ids=tokens, - block_size=block_size) - - different_lora_seq = create_dummy_lora_sequence(request_id=2, - token_ids=tokens, - block_size=block_size, - lora_int_id=2) - - # Due to the different LoRA IDs, corresponding blocks are not cached. - assert tracker.get_num_cached_tokens(seq) == 0 - assert tracker.get_num_cached_tokens(different_lora_seq) == 0 - - # The number of cached tokens matches the length of the tokens - # for the cached LoRA sequence. - assert tracker.get_num_cached_tokens(lora_seq) == len(tokens) diff --git a/tests/core/conftest.py b/tests/core/conftest.py deleted file mode 100644 index 375b248ebedaa..0000000000000 --- a/tests/core/conftest.py +++ /dev/null @@ -1,12 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - Since this module is V0 only, set VLLM_USE_V1=0 for - all tests in the module. 
- """ - monkeypatch.setenv('VLLM_USE_V1', '0') diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py deleted file mode 100644 index ce1fe189b3ca1..0000000000000 --- a/tests/core/test_chunked_prefill_scheduler.py +++ /dev/null @@ -1,858 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from unittest.mock import MagicMock - -import pytest # noqa - -from vllm.config import CacheConfig, SchedulerConfig -from vllm.core.scheduler import Scheduler -from vllm.engine.arg_utils import EngineArgs -from vllm.engine.llm_engine import LLMEngine -from vllm.sampling_params import SamplingParams -from vllm.sequence import Logprob, SequenceGroup - -from .utils import create_dummy_prompt - - -def get_sequence_groups(scheduler_output): - return [s.seq_group for s in scheduler_output.scheduled_seq_groups] - - -def append_new_token(seq_group: SequenceGroup, token_id: int): - for seq in seq_group.get_seqs(): - seq.append_token_id(token_id, {token_id: Logprob(token_id)}) - - -def schedule_and_update_computed_tokens(scheduler): - metas, out, _ = scheduler.schedule() - for s, meta in zip(out.scheduled_seq_groups, metas): - s.seq_group.update_num_computed_tokens(meta.token_chunk_size) - return metas, out - - -def test_simple(): - """Verify basic scheduling works.""" - block_size = 4 - num_seq_group = 4 - max_model_len = 16 - max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig("generate", - max_num_batched_tokens, - num_seq_group, - max_model_len, - enable_chunked_prefill=True) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - - # Add seq groups to scheduler. - for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=block_size, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - - # Schedule seq groups prompts. - num_tokens = block_size * num_seq_group - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set(running) - assert out.num_batched_tokens == num_tokens - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - assert len(seq_group_meta) == num_seq_group - for s in running: - append_new_token(s, 1) - - # Schedule seq groups generation. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set(running) - assert out.num_batched_tokens == num_seq_group - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - assert len(seq_group_meta) == num_seq_group - - -def test_chunk(): - """Verify prefills are chunked properly.""" - block_size = 4 - max_seqs = 60 - max_model_len = 80 - max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 32 - cache_config.num_gpu_blocks = 32 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - - # Add seq groups to scheduler. 
- for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - - # Verify the second request is chunked. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - print() - assert set(get_sequence_groups(out)) == set(running) - assert seq_group_meta[0].token_chunk_size == 60 - # Verify it is chunked. - assert seq_group_meta[1].token_chunk_size == 4 - assert out.num_prefill_groups == 2 - assert out.num_batched_tokens == 64 - # Only the first seq group has a new token appended. - append_new_token(running[0], 1) - - # One chunked prefill, and one decoding. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set(running) - # The first one is prefill. Scheduler guarantees ordering. - assert seq_group_meta[0].token_chunk_size == 56 - # The second one is a chunked prefill. - assert seq_group_meta[1].token_chunk_size == 1 - assert out.num_prefill_groups == 1 - assert out.num_batched_tokens == 57 - - -def test_concurrent_chunking(): - """Verify prefills are chunked properly when - --max-num-partial-prefills is > 1""" - block_size = 4 - max_seqs = 60 - max_model_len = 2000 - max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - max_num_partial_prefills=2, # Up to 2 partial prefills at a time - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 32 - cache_config.num_gpu_blocks = 32 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - - # Add seq groups to scheduler. - for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - - # Verify both requests are chunked with half of max_num_batched_tokens each - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set(running) - assert seq_group_meta[0].token_chunk_size == 32 - assert seq_group_meta[1].token_chunk_size == 32 - assert out.num_prefill_groups == 2 - assert out.num_batched_tokens == 64 - - # After one iteration, both should have 60 - 32 = 28 tokens left to prefill - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set(running) - assert seq_group_meta[0].token_chunk_size == 28 - assert seq_group_meta[1].token_chunk_size == 28 - assert out.num_prefill_groups == 2 - assert out.num_batched_tokens == 56 - - -def test_concurrent_chunking_large_requests(): - """Verify large prefill requests are run one at a time""" - block_size = 4 - max_seqs = 60 - max_model_len = 2000 - max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - max_num_partial_prefills=2, # Up to 2 partial prefills at a time - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests - cache_config.num_gpu_blocks = 3200 - scheduler = Scheduler(scheduler_config, cache_config, None) - - # Add seq groups to scheduler. 
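[Editor's note] With --max-num-partial-prefills greater than 1 the token budget is divided across the concurrent prefill slots, so two 60-token prompts each receive 32 of the 64 budgeted tokens in the first step. A hedged sketch of that even split, assumed from the tests above rather than taken from the scheduler's exact logic:

def per_slot_chunk(num_partial_prefills, token_budget, tokens_left):
    # Assumed rule: each concurrent prefill slot gets an even share of the
    # budget, capped by what the prompt still needs.
    return min(token_budget // num_partial_prefills, tokens_left)

assert per_slot_chunk(2, 64, tokens_left=1200) == 32   # both large prompts get 32
assert per_slot_chunk(2, 64, tokens_left=28) == 28     # a nearly done prefill takes less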
- for i in range(2): - _, seq_group = create_dummy_prompt( - str(i), - prompt_length=1200, # Very large prompt - block_size=block_size) - scheduler.add_seq_group(seq_group) - - # Verify only a single request is chunked, and it gets all 64 tokens - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(get_sequence_groups(out)) == 1 - assert seq_group_meta[0].token_chunk_size == 64 - assert out.num_prefill_groups == 1 - assert out.num_batched_tokens == 64 - - -def test_short_prompts_jump_long_prompts_in_queue(): - """Verify large prefill requests are punted behind smaller ones if - another large prefill request is already running""" - block_size = 4 - max_seqs = 60 - max_model_len = 2000 - max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - max_num_partial_prefills=2, # Up to 2 partial prefills at a time - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests - cache_config.num_gpu_blocks = 3200 - scheduler = Scheduler(scheduler_config, cache_config, None) - long_seqs: list[SequenceGroup] = [] - short_seqs: list[SequenceGroup] = [] - - # Add 2 large seq groups to scheduler. - for i in range(2): - _, seq_group = create_dummy_prompt( - str(i), - prompt_length=1200, # Very large prompt - block_size=block_size) - scheduler.add_seq_group(seq_group) - long_seqs.append(seq_group) - assert seq_group.is_prefill() - - # Add 2 small seq groups behind them - for i in range(2): - _, seq_group = create_dummy_prompt( - str(i + 2), - prompt_length=40, # Very small prompt - block_size=block_size) - scheduler.add_seq_group(seq_group) - short_seqs.append(seq_group) - assert seq_group.is_prefill() - - # Verify one large req and 1 small req chunked - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert seq_group_meta[0].token_chunk_size == 32 # large req gets 32 tokens - assert seq_group_meta[1].token_chunk_size == 32 # small req gets 32 tokens - - # all 4 are prefilling - assert long_seqs[0].is_prefill() - assert long_seqs[1].is_prefill() - assert short_seqs[0].is_prefill() - assert short_seqs[1].is_prefill() - # First short and first long sequences have been scheduled - assert long_seqs[0].first_seq.get_num_computed_tokens() == 32 - assert long_seqs[1].first_seq.get_num_computed_tokens() == 0 - assert short_seqs[0].first_seq.get_num_computed_tokens() == 32 - assert short_seqs[1].first_seq.get_num_computed_tokens() == 0 - - assert out.num_prefill_groups == 2 - assert out.num_batched_tokens == 64 - - # in the second iteration, - # the first small request had only 8 tokens left - # so it went to decode - # The other small req is scheduled - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - # the new small req got 64 - (32+8) tokens - assert seq_group_meta[0].token_chunk_size == 24 - assert seq_group_meta[1].token_chunk_size == 32 # large req still got 32 - # the other small request had only 8 tokens left - assert seq_group_meta[2].token_chunk_size == 8 # 40-32 - - # The first small request got to decode now - assert long_seqs[0].is_prefill() - assert long_seqs[1].is_prefill() - assert not short_seqs[0].is_prefill() - assert short_seqs[1].is_prefill() - # Both small requests have started in front of the second long request - assert long_seqs[0].first_seq.get_num_computed_tokens() == 64 - assert long_seqs[1].first_seq.get_num_computed_tokens() == 
0 - assert short_seqs[0].first_seq.get_num_computed_tokens() == 40 - assert short_seqs[1].first_seq.get_num_computed_tokens() == 24 - - assert out.num_prefill_groups == 3 - assert out.num_batched_tokens == 64 - # the first small seq group has a new token appended. - append_new_token(short_seqs[0], 1) - - # in the third iteration, - # the first small request is already decoding - # the second small request only has 16 tokens left and will enter decoding - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert seq_group_meta[0].token_chunk_size == 32 # large still got 32 - # small req finished prefilling 40-24=16 tokens - assert seq_group_meta[1].token_chunk_size == 16 - assert seq_group_meta[2].token_chunk_size == 1 # decode - assert out.num_prefill_groups == 2 - assert out.num_batched_tokens == 49 # (32+16+1 decode) - - # both small requests have now reached decode - assert long_seqs[0].is_prefill() - assert long_seqs[1].is_prefill() - assert not short_seqs[0].is_prefill() - assert not short_seqs[1].is_prefill() - assert long_seqs[0].first_seq.get_num_computed_tokens() == 96 - assert long_seqs[1].first_seq.get_num_computed_tokens() == 0 - assert short_seqs[0].first_seq.get_num_computed_tokens() == 41 - assert short_seqs[1].first_seq.get_num_computed_tokens() == 40 - - # both the small seq groups have a new token appended - append_new_token(short_seqs[0], 1) - append_new_token(short_seqs[1], 1) - - # in the fourth iteration, both small requests are decoding - # so large request gets all the budget - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - - # large req gets 62 tokens (minus 2 for decode) - assert seq_group_meta[0].token_chunk_size == 62 - assert seq_group_meta[1].token_chunk_size == 1 # decode - assert seq_group_meta[2].token_chunk_size == 1 # decode - assert out.num_prefill_groups == 1 - assert out.num_batched_tokens == 64 - - assert long_seqs[0].first_seq.get_num_computed_tokens() == 158 - - # assert long_seqs[0].is_prefill() - # assert long_seqs[1].is_prefill() - # assert not short_seqs[0].is_prefill() - # assert not short_seqs[1].is_prefill() - - # # both the small seq groups have a new token appended - # append_new_token(short_seqs[0], 1) - # append_new_token(short_seqs[1], 1) - - # # in the fifth iteration, large request gets all the budget - # # while both small requests are decoding - # seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - # assert seq_group_meta[0].token_chunk_size == 62 - # assert seq_group_meta[1].token_chunk_size == 1 # decode - # assert seq_group_meta[2].token_chunk_size == 1 # decode - # assert out.num_prefill_groups == 1 - # assert out.num_batched_tokens == 64 - - -def test_complex(): - block_size = 4 - max_seqs = 60 - max_model_len = 80 - max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 64 - cache_config.num_gpu_blocks = 64 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - - # Add seq groups to scheduler. - for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - assert seq_group.is_prefill() - - # Verify the second request is chunked. 
- seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - - assert set(get_sequence_groups(out)) == set(running) - assert seq_group_meta[0].token_chunk_size == 60 - # Verify it is chunked. - assert seq_group_meta[1].token_chunk_size == 4 - assert not running[0].is_prefill() - assert running[1].is_prefill() - assert out.num_prefill_groups == 2 - assert out.num_batched_tokens == 64 - # Only the first seq group has a new token appended. - append_new_token(running[0], 1) - - # Add 2 more requests. - for i in range(2, 4): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - - # Decoding & chunked prefill & first chunk of 3rd request is scheduled. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(get_sequence_groups(out)) == 3 - # The first one is the first chunked prefill. - assert seq_group_meta[0].token_chunk_size == 7 - # The second one is the second new chunked prefill. - assert seq_group_meta[1].token_chunk_size == 56 - # The last one is decode. - assert seq_group_meta[2].token_chunk_size == 1 - # Two of them are in chunked prefill. - assert out.num_prefill_groups == 2 - assert out.num_batched_tokens == 64 - # The first 2 requests are now in decodine phase. - append_new_token(running[0], 1) - assert not running[0].is_prefill() - append_new_token(running[1], 1) - assert not running[1].is_prefill() - # The third request is still in prefill stage. - assert running[2].is_prefill() - - -def test_maximal_decoding(): - """Verify decoding requests are prioritized.""" - block_size = 4 - max_seqs = 2 - max_model_len = 8 - max_num_batched_tokens = 2 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - - # Add seq groups to scheduler. - for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=2, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - assert seq_group.is_prefill() - - # The first prefill is scheduled. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(get_sequence_groups(out)) == 1 - assert seq_group_meta[0].token_chunk_size == 2 - assert not running[0].is_prefill() - assert running[1].is_prefill() - assert out.num_prefill_groups == 1 - assert out.num_batched_tokens == 2 - # Only the first seq group has a new token appended. - append_new_token(running[0], 1) - - # Create one more seq_group. - _, seq_group = create_dummy_prompt("3", - prompt_length=2, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - assert seq_group.is_prefill() - # The first decoding + second chunk is scheduled. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(get_sequence_groups(out)) == 2 - assert seq_group_meta[0].token_chunk_size == 1 - assert seq_group_meta[1].token_chunk_size == 1 - assert not running[0].is_prefill() - assert running[1].is_prefill() - assert running[2].is_prefill() - assert out.num_prefill_groups == 1 - assert out.num_batched_tokens == 2 - append_new_token(running[0], 1) - - # Decoding + running prefill is prioritized. 
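[Editor's note] test_maximal_decoding relies on decode requests being funded first: each running decode takes one token from the budget, and only the remainder goes to prefills. A small illustrative sketch; schedule_step is hypothetical:

def schedule_step(num_decodes, pending_prefill_tokens, token_budget):
    # Illustrative priority: running decodes are funded first (one token each),
    # prefills receive whatever budget remains.
    decode_tokens = min(num_decodes, token_budget)
    prefill_tokens = min(pending_prefill_tokens, token_budget - decode_tokens)
    return decode_tokens, prefill_tokens

# With a budget of 2, one decode plus a single-token prefill chunk fit per step.
assert schedule_step(num_decodes=1, pending_prefill_tokens=1, token_budget=2) == (1, 1)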
- seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(get_sequence_groups(out)) == 2 - assert seq_group_meta[0].token_chunk_size == 1 - assert seq_group_meta[1].token_chunk_size == 1 - assert not running[0].is_prefill() - assert not running[1].is_prefill() - assert out.num_prefill_groups == 1 - assert out.num_batched_tokens == 2 - append_new_token(running[0], 1) - append_new_token(running[1], 1) - - # Only decoding is prioritized. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(get_sequence_groups(out)) == 2 - assert seq_group_meta[0].token_chunk_size == 1 - assert seq_group_meta[1].token_chunk_size == 1 - assert not running[0].is_prefill() - assert not running[1].is_prefill() - assert out.num_prefill_groups == 0 - assert out.num_batched_tokens == 2 - append_new_token(running[0], 1) - append_new_token(running[1], 1) - - # After aborting the decoding request, the fcfs new prefill is prioritized. - scheduler.abort_seq_group(running[0].request_id) - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(get_sequence_groups(out)) == 2 - assert seq_group_meta[0].token_chunk_size == 1 - assert seq_group_meta[1].token_chunk_size == 1 - assert not running[1].is_prefill() - assert running[2].is_prefill() - assert out.num_prefill_groups == 1 - assert out.num_batched_tokens == 2 - - -def test_prompt_limit(): - """Verify max_num_batched_tokens < max_model_len is possible.""" - block_size = 4 - max_seqs = 32 - max_model_len = 64 - max_num_batched_tokens = 32 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 16 - cache_config.num_gpu_blocks = 16 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - - _, seq_group = create_dummy_prompt("1", - prompt_length=48, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - assert seq_group.is_prefill() - - # The prompt length > max_num_batched_tokens should be still scheduled. 
- seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(get_sequence_groups(out)) == 1 - assert seq_group_meta[0].token_chunk_size == 32 - assert running[0].is_prefill() - assert out.num_prefill_groups == 1 - assert out.num_batched_tokens == 32 - - -def test_prompt_limit_exceed(): - block_size = 4 - max_seqs = 64 - max_model_len = 32 - max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig("generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 16 - cache_config.num_gpu_blocks = 16 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - _, seq_group = create_dummy_prompt("2", - prompt_length=48, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - assert seq_group.is_prefill() - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.ignored_seq_groups) == 1 - assert out.ignored_seq_groups[0] == seq_group - - -def test_chunked_prefill_preempt(): - """Verify preempt works with chunked prefill requests""" - block_size = 4 - max_seqs = 30 - max_model_len = 200 - max_num_batched_tokens = 30 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 16 - cache_config.num_gpu_blocks = 16 - scheduler = Scheduler(scheduler_config, cache_config, None) - - _, seq_group = create_dummy_prompt("1", - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - _, out = schedule_and_update_computed_tokens(scheduler) - # The request is chunked. - # prefill scheduled now. - assert len(out.scheduled_seq_groups) == 1 - assert out.num_prefill_groups == 1 - assert seq_group.is_prefill() - assert out.num_batched_tokens == max_num_batched_tokens - - # The request should be preempted. - scheduler.block_manager.can_append_slots = MagicMock() - - def cannot_append_second_group1(seq_group, num_lookahead_slots): - return seq_group.request_id != "1" - - scheduler.block_manager.can_append_slots.side_effect = ( - cannot_append_second_group1) - - # The running prefill is now preempted. - _, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 0 - assert out.num_batched_tokens == 0 - assert out.blocks_to_swap_out == [] - assert out.blocks_to_swap_in == [] - - # Make sure we can reschedule preempted request. - _, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 1 - assert out.num_prefill_groups == 1 - assert seq_group.is_prefill() - assert out.num_batched_tokens == max_num_batched_tokens - assert seq_group.get_num_uncomputed_tokens() == 30 - - # We should be able to run prefill twice as it is chunked. 
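[Editor's note] The preemption path exercised in test_chunked_prefill_preempt is recompute-style: when the block manager cannot append slots, the request's progress is dropped and its prefill restarts from scratch on the next schedule, which is why 30 tokens remain uncomputed after rescheduling. A hedged sketch of that bookkeeping; the dict fields are illustrative:

def preempt_by_recompute(seq):
    # Illustrative recompute preemption: progress is dropped and the whole
    # prefill is redone when the request is next scheduled.
    seq["num_computed_tokens"] = 0
    seq["status"] = "WAITING"

seq = {"prompt_len": 60, "num_computed_tokens": 30, "status": "RUNNING"}
preempt_by_recompute(seq)                      # block manager had no free slots
assert seq["num_computed_tokens"] == 0
seq["num_computed_tokens"] += 30               # rescheduled: the first chunk reruns
assert seq["prompt_len"] - seq["num_computed_tokens"] == 30   # 30 still uncomputed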
- def cannot_append_second_group2(seq_group, num_lookahead_slots): - return True - - scheduler.block_manager.can_append_slots.side_effect = ( - cannot_append_second_group2) - _, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 1 - assert out.num_prefill_groups == 1 - assert not seq_group.is_prefill() - assert out.num_batched_tokens == max_num_batched_tokens - - -def test_chunked_prefill_spec_prefill(): - """Verify that the num_lookahead_slots is set appropriately for an all""" - """prefill batch.""" - block_size = 4 - max_seqs = 30 - max_model_len = 200 - max_num_batched_tokens = 30 - num_lookahead_slots = 4 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - num_lookahead_slots=num_lookahead_slots, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 16 - cache_config.num_gpu_blocks = 16 - scheduler = Scheduler(scheduler_config, cache_config, None) - - _, seq_group = create_dummy_prompt("1", - prompt_length=30, - block_size=block_size) - scheduler.add_seq_group(seq_group) - _, out = schedule_and_update_computed_tokens(scheduler) - # The request is chunked. - # prefill scheduled now. - assert len(out.scheduled_seq_groups) == 1 - assert out.num_prefill_groups == 1 - assert out.num_batched_tokens == max_num_batched_tokens - print(out.num_lookahead_slots) - assert out.num_lookahead_slots == 0 - - -def test_chunked_prefill_max_seqs(): - block_size = 4 - max_seqs = 2 - max_model_len = 80 - max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 128 - cache_config.num_gpu_blocks = 128 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - - _, seq_group = create_dummy_prompt("1", - prompt_length=65, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - # The first prefill is chunked. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert seq_group_meta[0].token_chunk_size == max_num_batched_tokens - assert len(get_sequence_groups(out)) == 1 - - # Add new requests. - for i in range(4): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=65, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - - # Make sure only 2 requests are scheduled. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert out.num_batched_tokens == max_num_batched_tokens - assert len(get_sequence_groups(out)) == 2 - assert not running[0].is_prefill() - assert running[1].is_prefill() - append_new_token(running[0], 1) - - # Although we have enough token budget, we can only schedule max_seqs. 
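[Editor's note] The assertion below is about the sequence cap, not the token budget: even with budget left over, at most max_num_seqs sequences can run in one step. A small sketch of that admission rule; admit_new_seqs is a hypothetical helper:

def admit_new_seqs(waiting, num_running, max_seqs):
    # Illustrative admission rule: stop adding sequences once max_seqs is hit,
    # regardless of how much token budget remains.
    admitted = []
    for seq in waiting:
        if num_running + len(admitted) >= max_seqs:
            break
        admitted.append(seq)
    return admitted

# One sequence already running, max_seqs == 2: only one of the four waiting
# sequences is admitted even though the token budget could cover more.
assert admit_new_seqs(["s1", "s2", "s3", "s4"], num_running=1, max_seqs=2) == ["s1"]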
- seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert seq_group_meta[0].token_chunk_size == 2 - assert seq_group_meta[1].token_chunk_size == 1 - assert out.num_batched_tokens == 3 - assert len(get_sequence_groups(out)) == max_seqs - assert not running[0].is_prefill() - assert not running[1].is_prefill() - - -def test_prefix_caching(): - """Verify allocating full blocks when prefix caching is enabled.""" - block_size = 4 - max_seqs = 10 - max_model_len = 80 - max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - ) - cache_config = CacheConfig(block_size, - 1.0, - 1, - "auto", - enable_prefix_caching=True) - cache_config.num_cpu_blocks = 0 - cache_config.num_gpu_blocks = 32 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - - # Add seq groups to scheduler. - for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - block_size=block_size, - prompt_length=50) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set(running) - assert seq_group_meta[0].token_chunk_size == 50 - # Verify it is chunked. Note that although the budget is 64-50=14, - # we only allocate full blocks for prefix caching, so only 4*(14//4)=12 - # tokens are allocated. - assert seq_group_meta[1].token_chunk_size == 12 - assert out.num_prefill_groups == 2 - assert out.num_batched_tokens == 62 - - -def test_prefix_caching_with_concurrent_partial_prefills(): - """Verify allocating full blocks when prefix caching is enabled with - --max-num-partial-prefills > 1.""" - block_size = 4 - max_seqs = 10 - max_model_len = 8000 - max_num_batched_tokens = 60 # With two slots, each slot will get 30 tokens - scheduler_config = SchedulerConfig("generate", - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - max_num_partial_prefills=2) - cache_config = CacheConfig(block_size, - 1.0, - 1, - "auto", - enable_prefix_caching=True) - cache_config.num_cpu_blocks = 0 - cache_config.num_gpu_blocks = 32 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - - # Add seq groups to scheduler. - for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - block_size=block_size, - prompt_length=50) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set(running) - # To partially prefill both sequences, both can chunk up to 30 tokens - # But the next lowest multiple of the block size (4) is 28 - assert seq_group_meta[0].token_chunk_size == 28 - assert seq_group_meta[1].token_chunk_size == 28 - assert out.num_prefill_groups == 2 - assert out.num_batched_tokens == 56 - - # On the next iteration, both sequences should finish prefill - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set(running) - # Both sequences have 50 - 28 = 22 tokens left to prefill. 
- # This is not a multiple of the block size, but we don't care since we don't - # cache the final partial block of prefix sequences - assert seq_group_meta[0].token_chunk_size == 22 - assert seq_group_meta[1].token_chunk_size == 22 - assert out.num_prefill_groups == 2 - assert out.num_batched_tokens == 44 - - -@pytest.mark.parametrize("model", ["facebook/opt-125m"]) -@pytest.mark.parametrize("max_num_partial_prefills", [2, 4, 8]) -def test_chunked_prefill_with_actual_engine(model: str, - max_num_partial_prefills: int): - """Make sure the model can actually sample with concurrent - partial prefills - """ - - prompt = "hello" * 40 - - engine_args = EngineArgs( - model=model, - max_num_partial_prefills=max_num_partial_prefills, - max_num_batched_tokens=40, - max_num_seqs=8, - enable_chunked_prefill=True, - gpu_memory_utilization=0.8, - ) - - engine = LLMEngine.from_engine_args(engine_args) - sampling_params = SamplingParams(temperature=0) - - for req_num in range(max_num_partial_prefills): - engine.add_request(f"{req_num}", prompt, sampling_params) - # first step - request_outputs = engine.step() - # means all are prefilling - assert len(request_outputs) == 0 - assert len(engine.scheduler[0].running) == max_num_partial_prefills diff --git a/tests/core/test_num_computed_tokens_update.py b/tests/core/test_num_computed_tokens_update.py deleted file mode 100644 index 131a7b3a6299b..0000000000000 --- a/tests/core/test_num_computed_tokens_update.py +++ /dev/null @@ -1,67 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -from tests.conftest import VllmRunner -from tests.core.utils import create_dummy_prompt -from vllm.engine.llm_engine import LLMEngine -from vllm.sequence import SequenceGroup - -MODEL = "JackFram/llama-160m" - - -def add_seq_group_to_engine(engine: LLMEngine, seq_group: SequenceGroup): - scheduler = engine.scheduler[0] - scheduler.add_seq_group(seq_group) - - -@pytest.mark.parametrize("enable_chunked_prefill", [False, True]) -@pytest.mark.parametrize("enforce_eager", [False, True]) -def test_num_computed_tokens_update(enable_chunked_prefill: bool, - enforce_eager: bool): - - # Make a vllm engine - runner = VllmRunner(model_name=MODEL, - gpu_memory_utilization=0.7, - enable_chunked_prefill=enable_chunked_prefill, - enforce_eager=enforce_eager) - engine: LLMEngine = runner.llm.llm_engine - - num_prompt_steps = 1 - - num_output_tokens_list = [4, 8, 12, 15, 16, 17] - - # Create sequence and add to engine - prompt_len = 10 - - for req_idx, num_output_tokens in enumerate(num_output_tokens_list): - seq, seq_group = create_dummy_prompt(request_id=str(req_idx), - prompt_length=prompt_len, - min_tokens=num_output_tokens, - max_tokens=num_output_tokens) - add_seq_group_to_engine(engine, seq_group) - - assert seq.data.get_num_computed_tokens() == 0 - - for _ in range(num_prompt_steps): - # prompt steps - engine.step() - - if not seq.is_finished(): - prompt_num_computed_tokens = seq.data.get_num_computed_tokens() - # Test correctness of num_computed_tokens after the prompt steps - assert prompt_num_computed_tokens == \ - prompt_len + num_prompt_steps - 1 - - decode_step_counter = 0 - while not seq.is_finished(): - # Test correctness of num_computed_tokens after the decode steps - assert seq.data.get_num_computed_tokens( - ) == prompt_num_computed_tokens + decode_step_counter - engine.step() - decode_step_counter += 1 - - # Test correctness of num_computed_tokens after the sequence finish. 
- assert seq.data.get_num_computed_tokens( - ) == prompt_len + num_output_tokens - 1 diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py deleted file mode 100644 index 591e1780c11c6..0000000000000 --- a/tests/core/test_scheduler.py +++ /dev/null @@ -1,1337 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import time -from collections import deque -from typing import Optional -from unittest.mock import MagicMock - -import pytest # noqa -import torch -from torch import Use # noqa - -from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig -from vllm.core.interfaces import AllocStatus -from vllm.core.scheduler import Scheduler, SchedulingBudget -from vllm.lora.request import LoRARequest -from vllm.sequence import SequenceGroup, SequenceStatus - -from .utils import (append_new_token, append_new_token_seq, - append_new_token_seq_group, create_dummy_prompt, - get_sequence_groups, schedule_and_update_computed_tokens) - - -def test_scheduler_add_seq_group(): - block_size = 4 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens=100, - max_num_seqs=64, - max_model_len=1, - ) - cache_config = CacheConfig(block_size, 1.0, 1, cache_dtype="auto") - cache_config.num_cpu_blocks = 4 - cache_config.num_gpu_blocks = 4 - scheduler = Scheduler(scheduler_config, cache_config, None) - - # Add seq group to scheduler. - num_seq_group = 4 - for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), - block_size, - block_size=block_size) - scheduler.add_seq_group(seq_group) - assert scheduler.get_num_unfinished_seq_groups() == i + 1 - - -def test_scheduler_abort_seq_group(): - block_size = 4 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens=100, - max_num_seqs=64, - max_model_len=1, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 4 - cache_config.num_gpu_blocks = 4 - scheduler = Scheduler(scheduler_config, cache_config, None) - - # Add multiple seq groups to scheduler. - num_seq_group = 4 - request_ids: set[str] = set() - for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), block_size) - scheduler.add_seq_group(seq_group) - request_ids.add(str(i)) - - # Abort all added seq groups. - assert scheduler.get_num_unfinished_seq_groups() == num_seq_group - scheduler.abort_seq_group(request_ids) - assert scheduler.get_num_unfinished_seq_groups() == 0 - - -def test_scheduler_schedule_simple(): - block_size = 4 - num_seq_group = 4 - max_model_len = 16 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens=64, - max_num_seqs=num_seq_group, - max_model_len=max_model_len, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - - # Add seq groups to scheduler. - for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=block_size, - block_size=block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - - # Schedule seq groups prompts. 
- num_tokens = block_size * num_seq_group - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set(running) - assert out.num_batched_tokens == num_tokens - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - assert len(seq_group_meta) == num_seq_group - append_new_token(out, 1) - - # Schedule seq groups generation. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set(running) - assert out.num_batched_tokens == num_seq_group - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - assert len(seq_group_meta) == num_seq_group - append_new_token(out, 1) - - -def test_scheduler_prefill_prioritized(): - """Verify running batched tokens are not applied to prefill requests.""" - block_size = 4 - max_model_len = 30 - max_batched_num_tokens = 30 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens=max_batched_num_tokens, - max_num_seqs=2, - max_model_len=max_model_len, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 16 - cache_config.num_gpu_blocks = 16 - scheduler = Scheduler(scheduler_config, cache_config, None) - - # Add seq groups to scheduler. - _, seq_group_a = create_dummy_prompt("1", 1, block_size=block_size) - scheduler.add_seq_group(seq_group_a) - - # Schedule seq groups prompts. - _, out = schedule_and_update_computed_tokens(scheduler) - assert get_sequence_groups(out) == [seq_group_a] - - # Add a new prefill request B. - _, seq_group_b = create_dummy_prompt("2", 30, block_size=block_size) - scheduler.add_seq_group(seq_group_b) - - # Verify prefill requests are prioritized. Since max_batched_num_tokens - # is 1, new prefill request has to be scheduled first. - _, out = schedule_and_update_computed_tokens(scheduler) - assert get_sequence_groups(out) == [seq_group_b] - - -def test_scheduler_schedule_preempt_abort(): - block_size = 4 - max_model_len = 16 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens=64, - max_num_seqs=2, - max_model_len=max_model_len, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 2 - cache_config.num_gpu_blocks = 2 - scheduler = Scheduler(scheduler_config, cache_config, None) - - # Add seq groups to scheduler. - seq_a, seq_group_a = create_dummy_prompt("1", - block_size, - block_size=block_size) - seq_b, seq_group_b = create_dummy_prompt("2", - block_size, - block_size=block_size) - scheduler.add_seq_group(seq_group_a) - scheduler.add_seq_group(seq_group_b) - - # Schedule seq groups prompts. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert get_sequence_groups(out) == [seq_group_a, seq_group_b] - assert out.num_batched_tokens == block_size * 2 # seq_a and seq_b - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - assert len(seq_group_meta) == 2 - assert scheduler.get_num_unfinished_seq_groups() == 2 - - # Append "generated" tokens, allowing the sequence to mark prompt tokens as - # processed. - append_new_token(out, 1) - - # Schedule seq groups generation and preempt seq group b. 
- seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert get_sequence_groups(out) == [seq_group_a] - assert out.num_batched_tokens == 1 - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - assert len(seq_group_meta) == 1 - assert scheduler.get_num_unfinished_seq_groups() == 2 - assert out.preempted == 1 - - # Abort seq group a. Re-schedule seq group b prompt with recomputation. - scheduler.abort_seq_group("1") - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert get_sequence_groups(out) == [seq_group_b] - assert out.num_batched_tokens == 5 # 4 prompt + 1 generation. - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - assert len(seq_group_meta) == 1 - assert scheduler.get_num_unfinished_seq_groups() == 1 - - -def test_scheduler_max_seqs(): - block_size = 4 - num_seq_group = 4 - max_seq_group = 2 - max_model_len = 16 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens=64, - max_num_seqs=max_seq_group, - max_model_len=max_model_len, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) - - all_seq_groups: list[SequenceGroup] = [] - # Add seq groups to scheduler. - for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=block_size, - block_size=block_size) - all_seq_groups.append(seq_group) - - # Append 1 seq group - scheduler.add_seq_group(all_seq_groups[0]) - - # Schedule seq groups prompts. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set([all_seq_groups[0]]) - append_new_token(out, 1) - - # Schedule seq groups generation. - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set([all_seq_groups[0]]) - append_new_token(out, 1) - - # Append 2 more seq group - scheduler.add_seq_group(all_seq_groups[1]) - scheduler.add_seq_group(all_seq_groups[2]) - - # Schedule seq groups prompts. - # Only 1 seq group should be scheduled since max_seq_group is 2 - # and one is prompting. 
- _, out = schedule_and_update_computed_tokens(scheduler) - assert set(get_sequence_groups(out)) == set([all_seq_groups[1]]) - - -def test_scheduler_delay_factor(): - block_size = 4 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens=100, - max_num_seqs=64, - max_model_len=16, - delay_factor=0.5, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 - scheduler = Scheduler(scheduler_config, cache_config, None) - - # schedule first prompt - seq_group_meta, seq_group = create_dummy_prompt("0", - prompt_length=block_size, - block_size=block_size) - scheduler.add_seq_group(seq_group) - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert out.num_prefill_groups > 0 - assert seq_group_meta[0].request_id == '0' - append_new_token(out, 1) - - # wait for a second before scheduling next prompt - time.sleep(1) - seq_group_meta, seq_group = create_dummy_prompt("1", - prompt_length=block_size, - block_size=block_size) - scheduler.add_seq_group(seq_group) - - # second prompt should *not* be scheduled - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert out.num_prefill_groups == 0 - assert seq_group_meta[0].request_id == '0' - append_new_token(out, 1) - - # wait for more than 0.5 second and try again - time.sleep(0.6) - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert out.num_prefill_groups > 0 - assert seq_group_meta[0].request_id == '1' - append_new_token(out, 1) - - -def initialize_scheduler( - *, - max_num_seqs=1000, - max_token_budget=1000, - max_model_len=1000, - lora_config=None, - block_size=4, - num_cpu_blocks=8, - num_gpu_blocks=8, - enable_prefix_caching=False, - enable_chunked_prefill=False, -): - block_size = block_size - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens=max_token_budget, - max_num_seqs=max_num_seqs, - max_model_len=max_model_len, - enable_chunked_prefill=enable_chunked_prefill, - ) - cache_config = CacheConfig( - block_size, - 1.0, - 1, - "auto", - enable_prefix_caching=enable_prefix_caching, - ) - cache_config.num_cpu_blocks = num_cpu_blocks - cache_config.num_gpu_blocks = num_gpu_blocks - scheduler = Scheduler(scheduler_config, cache_config, lora_config) - return scheduler - - -def create_token_budget(token_budget: int = 10000, - max_num_seqs: int = 10000) -> SchedulingBudget: - return SchedulingBudget( - token_budget=token_budget, - max_num_seqs=max_num_seqs, - ) - - -def add_token_budget(budget: SchedulingBudget, - num_batched_tokens: int = 0, - num_curr_seqs: int = 0): - mock_seq_group = create_dummy_prompt('10', prompt_length=60)[1] - budget.add_num_batched_tokens(mock_seq_group.request_id, - num_batched_tokens) - budget.add_num_seqs(mock_seq_group.request_id, num_curr_seqs) - - -def test_prefill_schedule_max_prompt_len(): - """ - Test prompt longer than max_prompt_len is aborted. 
- """ - block_size = 4 - scheduler = initialize_scheduler(max_model_len=30, block_size=block_size) - _, seq_group = create_dummy_prompt("0", - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - budget = create_token_budget() - output = scheduler._schedule_prefills(budget, None) - remaining_waiting = scheduler.waiting - assert len(output.ignored_seq_groups) == 1 - assert len(output.seq_groups) == 0 - assert budget.num_batched_tokens == 0 - assert budget.num_curr_seqs == 0 - assert len(remaining_waiting) == 0 - - -def test_prefill_schedule_token_budget(): - """ - Test token budget respected. - """ - block_size = 4 - scheduler = initialize_scheduler(block_size=block_size, - num_cpu_blocks=64, - num_gpu_blocks=64) - budget = create_token_budget(token_budget=0) - for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - - # 0 token budget == nothing is scheduled. - output = scheduler._schedule_prefills(budget, None) - remaining_waiting = scheduler.waiting - assert len(output.ignored_seq_groups) == 0 - assert len(output.seq_groups) == 0 - assert budget.num_batched_tokens == 0 - assert budget.num_curr_seqs == 0 - assert len(remaining_waiting) == 2 - - # 60 token budget == 1 request scheduled. - budget = create_token_budget(token_budget=60) - output = scheduler._schedule_prefills(budget, None) - remaining_waiting = scheduler.waiting - assert len(output.ignored_seq_groups) == 0 - assert len(output.seq_groups) == 1 - assert budget.num_batched_tokens == 60 - assert budget.num_curr_seqs == 1 - assert len(remaining_waiting) == 1 - - # Test when current_batched_tokens respected. - scheduler = initialize_scheduler(block_size=block_size, - num_cpu_blocks=16, - num_gpu_blocks=16) - budget = create_token_budget(token_budget=60) - add_token_budget(budget, 30, 0) - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - # Cannot schedule a prompt that doesn't fit the budget. - scheduler.add_seq_group(seq_group) - output = scheduler._schedule_prefills(budget, None) - remaining_waiting = scheduler.waiting - assert len(output.ignored_seq_groups) == 0 - assert len(output.seq_groups) == 0 - assert budget.num_batched_tokens == 30 - assert budget.num_curr_seqs == 0 - assert len(remaining_waiting) == 1 - budget = create_token_budget(token_budget=90) - add_token_budget(budget, 30, 0) - output = scheduler._schedule_prefills(budget, None) - remaining_waiting = scheduler.waiting - assert len(output.seq_groups) == 1 - assert budget.num_batched_tokens == 90 - assert budget.num_curr_seqs == 1 - assert len(remaining_waiting) == 0 - - -def test_prefill_schedule_max_seqs(): - """ - Test max seq respected. - """ - block_size = 4 - scheduler = initialize_scheduler(block_size=block_size, - num_cpu_blocks=64, - num_gpu_blocks=64) - budget = create_token_budget(max_num_seqs=2) - for i in range(3): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - output = scheduler._schedule_prefills(budget, None) - remaining_waiting = scheduler.waiting - assert len(output.ignored_seq_groups) == 0 - assert len(output.seq_groups) == 2 - assert budget.num_batched_tokens == 120 - assert budget.num_curr_seqs == 2 - assert len(remaining_waiting) == 1 - - # Verify curr_num_seqs respected. 
- scheduler.waiting = deque() - budget = create_token_budget(max_num_seqs=2) - add_token_budget(budget, 0, 2) - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - output = scheduler._schedule_prefills(budget, None) - remaining_waiting = scheduler.waiting - assert len(output.ignored_seq_groups) == 0 - assert len(output.seq_groups) == 0 - assert budget.num_batched_tokens == 0 - assert budget.num_curr_seqs == 2 - assert len(remaining_waiting) == 1 - - -def test_prefill_schedule_max_lora(): - """ - Test that max lora is respected and prioritized. - """ - block_size = 4 - lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) - scheduler = initialize_scheduler(lora_config=lora_config, - block_size=block_size, - num_cpu_blocks=64, - num_gpu_blocks=64) - budget = create_token_budget(token_budget=120) - curr_loras: set[int] = set() - for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size, - lora_request=LoRARequest( - lora_name=str(i), - lora_int_id=i + 1, - lora_path="abc")) - scheduler.add_seq_group(seq_group) - # Add two more requests to verify lora is prioritized. - # 0: LoRA, 1: LoRA, 2: regular, 3: regular - # In the first iteration, indices 0 and 2 are scheduled. - # If a request is not scheduled because it hits max lora, it is - # prioritized. Verify that. - for i in range(2, 4): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - # Schedule 2 requests (0 and 2) - output = scheduler._schedule_prefills(budget, curr_loras) - remaining_waiting = scheduler.waiting - assert len(output.ignored_seq_groups) == 0 - assert len(output.seq_groups) == 2 - assert budget.num_batched_tokens == 120 - assert budget.num_curr_seqs == 2 - assert len(remaining_waiting) == 2 - assert len(curr_loras) == 1 - # The second lora request is scheduled next per the FCFS policy. - # Reset curr_loras so that it can be scheduled. - curr_loras = set() - budget = create_token_budget(token_budget=60) - output = scheduler._schedule_prefills(budget, curr_loras) - remaining_waiting = scheduler.waiting - assert len(output.seq_groups) == 1 - assert output.seq_groups[0].seq_group.request_id == "1" - assert len(remaining_waiting) == 1 - assert len(curr_loras) == 1 - assert budget.num_batched_tokens == 60 - - -def test_prefill_schedule_no_block_manager_capacity(): - """ - Test that a sequence cannot be scheduled when the block manager has no capacity.
- """ - block_size = 4 - scheduler = initialize_scheduler(block_size=block_size, - num_gpu_blocks=128, - num_cpu_blocks=128) - budget = create_token_budget() - for i in range(3): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - scheduler.block_manager.can_allocate = MagicMock() - scheduler.block_manager.can_allocate.return_value = AllocStatus.LATER - output = scheduler._schedule_prefills(budget, None) - remaining_waiting = scheduler.waiting - assert len(output.ignored_seq_groups) == 0 - assert len(output.seq_groups) == 0 - assert budget.num_batched_tokens == 0 - assert budget.num_curr_seqs == 0 - assert len(remaining_waiting) == 3 - - scheduler = initialize_scheduler() - budget = create_token_budget() - for i in range(3): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group) - scheduler.block_manager.can_allocate = MagicMock() - scheduler.block_manager.can_allocate.return_value = AllocStatus.NEVER - output = scheduler._schedule_prefills(budget, None) - remaining_waiting = scheduler.waiting - assert len(output.ignored_seq_groups) == 3 - assert len(output.seq_groups) == 0 - assert budget.num_batched_tokens == 0 - assert budget.num_curr_seqs == 0 - assert len(remaining_waiting) == 0 - - -def test_decode_schedule_preempted(): - """ - Test decodes cannot be scheduled and preempted. - """ - block_size = 4 - scheduler = initialize_scheduler(block_size=block_size, - num_cpu_blocks=64, - num_gpu_blocks=64) - curr_loras = None - for i in range(3): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler._allocate_and_set_running(seq_group) - append_new_token_seq_group(60, seq_group, 1) - scheduler._add_seq_group_to_running(seq_group) - scheduler.block_manager.can_append_slots = MagicMock() - - def cannot_append_second_group(seq_group, num_lookahead_slots): - return seq_group.request_id != "1" - - scheduler.block_manager.can_append_slots.side_effect = ( - cannot_append_second_group) - - # 1 cannot be scheduled, and the lowest priority (request 2) - # should be preempted. 1 will also be preempted. - budget = create_token_budget() - output = scheduler._schedule_running(budget, curr_loras) - remaining_running = scheduler.running - assert len(remaining_running) == 0 - assert len(output.decode_seq_groups) == 1 - assert len(output.prefill_seq_groups) == 0 - assert output.decode_seq_groups[0].seq_group.request_id == "0" - assert len(output.preempted) == 2 - # Verify budgets are updated. - assert budget.num_batched_tokens == 1 - # NOTE: When enable_chunk is False, num_seqs budget is not updated. - # assert budget.num_curr_seqs == 1 - # Both should be preempted, not swapped. - assert output.blocks_to_swap_out == [] - # Nothing is copied. - assert output.blocks_to_copy == [] - - -def test_schedule_decode_blocks_to_copy_update(): - """ - Verify blocks_to_copy is updated. - """ - block_size = 4 - scheduler = initialize_scheduler(block_size=4, - num_cpu_blocks=16, - num_gpu_blocks=16) - _, seq_group = create_dummy_prompt("1", - prompt_length=60, - block_size=block_size) - curr_loras = None - scheduler._allocate_and_set_running(seq_group) - append_new_token_seq_group(60, seq_group, 1) - scheduler._add_seq_group_to_running(seq_group) - - # The last request should be swapped out. 
- scheduler.block_manager.append_slots = MagicMock() - scheduler.block_manager.append_slots.return_value = [(2, 3)] - - budget = create_token_budget() - output = scheduler._schedule_running(budget, curr_loras) - remaining_running = scheduler.running - assert len(remaining_running) == 0 - assert len(output.decode_seq_groups) == 1 - assert len(output.prefill_seq_groups) == 0 - assert len(output.preempted) == 0 - assert len(output.swapped_out) == 0 - # Nothing is preempted. - assert output.blocks_to_swap_out == [] - # Since append_slot returns the source -> dist mapping, it should - # applied. - assert output.blocks_to_copy == [(2, 3)] - - -def test_schedule_swapped_max_loras(): - block_size = 4 - lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) - scheduler = initialize_scheduler(lora_config=lora_config, - block_size=block_size, - num_cpu_blocks=32, - num_gpu_blocks=32) - curr_loras: set[int] = set() - blocks_to_swap_out: list[tuple[int, int]] = [] - for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size, - lora_request=LoRARequest( - lora_name=str(i), - lora_int_id=i + 1, - lora_path="abc")) - scheduler._allocate_and_set_running(seq_group) - append_new_token_seq_group(60, seq_group, 1) - scheduler._swap_out(seq_group, blocks_to_swap_out) - scheduler._add_seq_group_to_swapped(seq_group) - - budget = create_token_budget() - output = scheduler._schedule_swapped(budget, curr_loras) - remaining_swapped = scheduler.swapped - assert len(remaining_swapped) == 1 - assert budget.num_batched_tokens == 1 - assert budget.num_curr_seqs == 1 - assert len(output.decode_seq_groups) == 1 - assert len(output.prefill_seq_groups) == 0 - assert len(curr_loras) == 1 - - -def test_schedule_swapped_cannot_swap_in(): - block_size = 4 - scheduler = initialize_scheduler(block_size=block_size, - num_cpu_blocks=32, - num_gpu_blocks=32) - curr_loras = None - blocks_to_swap_out: list[tuple[int, int]] = [] - for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler._allocate_and_set_running(seq_group) - append_new_token_seq_group(60, seq_group, 1) - scheduler._swap_out(seq_group, blocks_to_swap_out) - scheduler._add_seq_group_to_swapped(seq_group) - - # The last request should be swapped out. - scheduler.block_manager.can_swap_in = MagicMock() - scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER - # Since we cannot swap in, none of the requests are swapped in. - budget = create_token_budget() - output = scheduler._schedule_swapped(budget, curr_loras) - remaining_swapped = scheduler.swapped - assert len(remaining_swapped) == 2 - assert budget.num_batched_tokens == 0 - assert budget.num_curr_seqs == 0 - assert len(output.decode_seq_groups) == 0 - assert len(output.prefill_seq_groups) == 0 - - -def test_infeasible_swap(): - block_size = 4 - scheduler = initialize_scheduler(block_size=block_size, - num_cpu_blocks=32, - num_gpu_blocks=32) - curr_loras = None - blocks_to_swap_out: list[tuple[int, int]] = [] - for i in range(2): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=block_size) - scheduler._allocate_and_set_running(seq_group) - append_new_token_seq_group(60, seq_group, 1) - scheduler._swap_out(seq_group, blocks_to_swap_out) - scheduler._add_seq_group_to_swapped(seq_group) - - # The last request should be swapped out. 
- scheduler.block_manager.can_swap_in = MagicMock() - scheduler.block_manager.can_swap_in.return_value = AllocStatus.NEVER - # Since we cannot swap in, none of the requests are swapped in. - budget = create_token_budget() - output = scheduler._schedule_swapped(budget, curr_loras) - remaining_swapped = scheduler.swapped - assert len(remaining_swapped) == 0 - assert len(output.infeasible_seq_groups) == 2 - assert budget.num_batched_tokens == 0 - assert budget.num_curr_seqs == 0 - assert len(output.decode_seq_groups) == 0 - assert len(output.prefill_seq_groups) == 0 - - -def test_schedule_swapped_blocks_to_copy(): - block_size = 4 - scheduler = initialize_scheduler(block_size=block_size, - num_cpu_blocks=32, - num_gpu_blocks=32) - curr_loras = None - _, seq_group = create_dummy_prompt("1", - prompt_length=60, - block_size=block_size) - scheduler._allocate_and_set_running(seq_group) - append_new_token_seq_group(60, seq_group, 1) - blocks_to_swap_out: list[tuple[int, int]] = [] - scheduler._swap_out(seq_group, blocks_to_swap_out) - scheduler._add_seq_group_to_swapped(seq_group) - - # The last request should be swapped out. - scheduler.block_manager.append_slots = MagicMock() - scheduler.block_manager.append_slots.return_value = [(2, 3)] - - budget = create_token_budget() - output = scheduler._schedule_swapped(budget, curr_loras) - remaining_swapped = scheduler.swapped - assert len(remaining_swapped) == 0 - assert len(output.decode_seq_groups) == 1 - assert len(output.prefill_seq_groups) == 0 - assert output.blocks_to_copy == [(2, 3)] - - -def test_scheduling_budget(): - TOKEN_BUDGET = 4 - MAX_SEQS = 4 - budget = SchedulingBudget(token_budget=TOKEN_BUDGET, max_num_seqs=MAX_SEQS) - assert budget.can_schedule(num_new_tokens=1, num_new_seqs=1) - assert budget.can_schedule(num_new_tokens=4, num_new_seqs=4) - assert not budget.can_schedule(num_new_tokens=1, num_new_seqs=5) - assert not budget.can_schedule(num_new_tokens=5, num_new_seqs=1) - assert not budget.can_schedule(num_new_tokens=5, num_new_seqs=5) - assert budget.remaining_token_budget() == TOKEN_BUDGET - - # Verify add/subtract num batched tokens. - _, seq_group = create_dummy_prompt("1", 3) - budget.add_num_batched_tokens(seq_group.request_id, 2) - assert budget.remaining_token_budget() == 2 - assert budget.num_batched_tokens == 2 - assert budget.can_schedule(num_new_tokens=2, num_new_seqs=1) - assert not budget.can_schedule(num_new_tokens=3, num_new_seqs=1) - # Verify adding another seq group is no-op. - budget.add_num_batched_tokens(seq_group.request_id, 2) - assert budget.remaining_token_budget() == 2 - assert budget.num_batched_tokens == 2 - budget.subtract_num_batched_tokens(seq_group.request_id, 2) - assert budget.remaining_token_budget() == 4 - assert budget.num_batched_tokens == 0 - budget.subtract_num_batched_tokens(seq_group.request_id, 2) - assert budget.remaining_token_budget() == 4 - assert budget.num_batched_tokens == 0 - - # Verify add/subtract max seqs. - _, seq_group = create_dummy_prompt("1", 3) - budget.add_num_seqs(seq_group.request_id, 2) - assert budget.can_schedule(num_new_tokens=1, num_new_seqs=2) - assert not budget.can_schedule(num_new_tokens=1, num_new_seqs=3) - assert budget.num_curr_seqs == 2 - # Verify adding another seq group is no-op. 
- budget.add_num_seqs(seq_group.request_id, 2) - assert budget.num_curr_seqs == 2 - budget.subtract_num_seqs(seq_group.request_id, 2) - assert budget.num_curr_seqs == 0 - budget.subtract_num_seqs(seq_group.request_id, 2) - assert budget.num_curr_seqs == 0 - - -@pytest.mark.parametrize("enable_prefix_caching", [True, False]) -def test_prefix_caching_aware_prefills(enable_prefix_caching): - """ - Test the below scenario: - - For 3 sequences, seqA, seqB, seqC, share the first block as prefix. - - The test verifies the below scenarios: - 1. SeqA is first scheduled. - 2. SeqB and SeqC can be prefilled together in a single schedule round - even though there are not enough token budgets to prefill both without - considering prefix caching. - """ - - block_size = 4 - max_num_batched_tokens = 12 - max_seq_group = 3 - scheduler = initialize_scheduler( - block_size=block_size, - num_cpu_blocks=16, - num_gpu_blocks=16, - max_token_budget=max_num_batched_tokens, - max_num_seqs=max_seq_group, - max_model_len=max_num_batched_tokens, - enable_prefix_caching=enable_prefix_caching, - ) - - seqA_tokens = list(range(8)) - num_shared_tokens = 4 - seqB_tokens = seqA_tokens[:num_shared_tokens] + list(range( - 12, 16)) # Shared prefix first 4. - seqC_tokens = seqA_tokens[:num_shared_tokens] + list(range( - 16, 20)) # Shared prefix first 4. - - seqA, seqA_group = create_dummy_prompt("0", - prompt_tokens=seqA_tokens, - block_size=block_size) - seqB, seqB_group = create_dummy_prompt("1", - prompt_tokens=seqB_tokens, - block_size=block_size) - seqC, seqC_group = create_dummy_prompt("2", - prompt_tokens=seqC_tokens, - block_size=block_size) - - # Schedule seqA prefill. - scheduler.add_seq_group(seqA_group) - metas, out, _ = scheduler.schedule() - assert (len(out.scheduled_seq_groups) == 1 - and out.scheduled_seq_groups[0].seq_group == seqA_group) - assert out.scheduled_seq_groups[0].token_chunk_size == len(seqA_tokens) - - # Schedule seqA decode. - append_new_token_seq_group(len(seqA_tokens), seqA_group, 999) - metas, out, _ = scheduler.schedule() - - assert len(out.scheduled_seq_groups) == 1 - assert out.scheduled_seq_groups[0].seq_group == seqA_group - assert out.scheduled_seq_groups[0].token_chunk_size == 1 - - # Schedule seqB and seqC prefills should work with prefix caching. - scheduler.add_seq_group(seqB_group) - scheduler.add_seq_group(seqC_group) - metas, out, _ = scheduler.schedule() - - if enable_prefix_caching: - assert len(out.scheduled_seq_groups) == 2 - assert set([ - out.scheduled_seq_groups[0].seq_group, - out.scheduled_seq_groups[1].seq_group, - ]) == set([seqB_group, seqC_group]) - assert len(metas) == 2 - for meta in metas: - assert meta.token_chunk_size == 8 - assert (len(meta.computed_block_nums) == num_shared_tokens // - block_size) # 1 Block for the 8 tokens. - else: - assert len(out.scheduled_seq_groups) == 1 - assert len(metas) == 1 - assert metas[0].token_chunk_size == 8 - assert len(metas[0].computed_block_nums) == 0 # No blocks computed. - - -def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching( -): - """ - This test verifies that we don't schedule new prefills if there's already - a continuous prefill in progress even though the new prefills with shared - prefix can fit in the token budget: - - - SeqA is being chunked prefill. - - SeqB with the same prompt shouldn't be scheduled for prefill even though - there's enough token budget to prefill the cached tokens. - - Neither should seqC be scheduled. 
- - When seqA is in decoding phase, seqB and seqC can be scheduled. - - Entire seqB should be prefilled since it's a full prefix cache hit. - - SeqC would be partially prefilled with the prefix shared, and the - remaining unique tokens would be prefilled (rounded down to be - block-size aligned). - """ - - block_size = 2 - max_num_batched_tokens = 4 - max_seq_group = 3 - scheduler = initialize_scheduler( - block_size=block_size, - num_cpu_blocks=16, - num_gpu_blocks=16, - max_token_budget=max_num_batched_tokens, - max_num_seqs=max_seq_group, - max_model_len=100, - enable_prefix_caching=True, - enable_chunked_prefill=True, - ) - - seqA_tokens = list(range(8)) - seqB_tokens = seqA_tokens - seqC_shared_prefix_len = 4 - seqC_tokens = seqA_tokens[:seqC_shared_prefix_len] + list(range(12, 20)) - - seqA, seqA_group = create_dummy_prompt("0", - prompt_tokens=seqA_tokens, - block_size=block_size) - seqB, seqB_group = create_dummy_prompt("1", - prompt_tokens=seqB_tokens, - block_size=block_size) - - # Chunked prefill seqA. - scheduler.add_seq_group(seqA_group) - metas, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 1 - assert out.scheduled_seq_groups[0].seq_group == seqA_group - assert out.scheduled_seq_groups[0].token_chunk_size == 4 - - # seqB should not be scheduled while there are ongoing prefills. - scheduler.add_seq_group(seqB_group) - metas, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 1 - assert out.scheduled_seq_groups[0].seq_group == seqA_group - assert out.scheduled_seq_groups[0].token_chunk_size == 4 - - # Both seqB and seqC can now be scheduled since seqA's prefill is over. - # seqA is in the decoding phase. - append_new_token_seq(seqA, 999) - seqC, seqC_group = create_dummy_prompt("2", - prompt_tokens=seqC_tokens, - block_size=block_size) - scheduler.add_seq_group(seqC_group) - metas, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 3 - - metas = {meta.request_id: meta for meta in metas} - assert metas[seqA_group.request_id].token_chunk_size == 1 # Decode - assert (metas[seqB_group.request_id].token_chunk_size == 8 - ) # Fully cached prefill - assert ( - metas[seqC_group.request_id].token_chunk_size == 6 - ), "A partial prefix of C (4 tokens) should be prefilled, with the " - "remaining tokens fitting into the 3 token budget (4 - 1 from seqA). It will " - "then be rounded down to 2 tokens on the block size, thus 6 tokens in total." - - -def test_no_batches_mixed_with_prompt_tokens_and_prompt_embeds(): - """ - Test that the scheduler does not schedule batches with prompt tokens and - prompt embeddings co-mingled.
- """ - block_size = 2 - max_seq_group = 3 - scheduler = initialize_scheduler( - block_size=block_size, - num_cpu_blocks=16, - num_gpu_blocks=16, - max_num_seqs=max_seq_group, - max_model_len=100, - enable_prefix_caching=True, - ) - - # the odd indexed inputs should be passed in via embeddings, - # evens via token_ids - seq_length = 7 - embedding_size = 5 - num_seqs = 11 - seq_tokens: list[list[int]] = [] - seq_embeds: list[Optional[torch.Tensor]] = [] - for i in range(num_seqs): - if i % 2: - seq_tokens.append(list(range(seq_length))) - seq_embeds.append(None) - else: - seq_tokens.append([0] * seq_length) - seq_embeds.append(torch.rand(embedding_size)) - - seq_and_seq_groups = [ - create_dummy_prompt(f"{i}", - prompt_tokens=seq_tokens[i], - prompt_embeds=seq_embeds[i], - block_size=block_size) - for i in range(len(seq_tokens)) - ] - - for _, seq_group in seq_and_seq_groups: - scheduler.add_seq_group(seq_group) - - while not all(seq.is_finished() for seq, _ in seq_and_seq_groups): - unfinished_seq_groups = [ - seq_group for _, seq_group in seq_and_seq_groups - if not seq_group.is_finished() - ] - _, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) > 0 - batch_is_prompt_embeds = out.scheduled_seq_groups[ - 0].seq_group.uses_prompt_embeds() - expected_scheduled_seq_groups = [ - seq_group for seq_group in unfinished_seq_groups - if seq_group.uses_prompt_embeds() == batch_is_prompt_embeds - ] - - # We should have as many scheduled groups as possible, without mixing - assert len(out.scheduled_seq_groups) == min( - max_seq_group, len(expected_scheduled_seq_groups)) - assert all(scheduled_seq_group.seq_group.uses_prompt_embeds() == - batch_is_prompt_embeds - for scheduled_seq_group in out.scheduled_seq_groups) - - # Finish the scheduled groups - for scheduled_seq_group in out.scheduled_seq_groups: - for seq in scheduled_seq_group.seq_group.seqs: - seq.status = SequenceStatus.FINISHED_STOPPED - scheduler.free_finished_seq_groups() - - -def test_remove_seq_from_computed_blocks_tracker(): - """ - Test that computed_blocks_tracker correctly removes stale sequences - during scheduling. - - The test covers 9 scheduling branches where stale seqs are removed: - - 1 in _schedule_swapped - - 1 in _schedule_priority_preemption - - 7 in _schedule_prefill - - Each branch is tested to ensure proper cleanup of - _seq_id_to_num_tokens_computed. - """ - # Budget can not schedule in swapped - block_size = 2 - max_seq_group = 3 - seq_tokens_with_swapped: list[list[int]] = [] - blocks_to_swap_out: list[tuple[int, int]] = [] - curr_loras: set[int] = set() - - scheduler = initialize_scheduler( - block_size=block_size, - num_cpu_blocks=64, - num_gpu_blocks=16, - max_num_seqs=max_seq_group, - enable_prefix_caching=True, - ) - budget = create_token_budget(token_budget=15) - - seq_length = 16 - num_seqs = 3 - for i in range(num_seqs): - seq_tokens_with_swapped.append([i] * seq_length) - - seq_and_seq_groups = [ - create_dummy_prompt(f"{i}", - prompt_tokens=seq_tokens_with_swapped[i], - block_size=block_size) - for i in range(len(seq_tokens_with_swapped)) - ] - - for _, seq_group in seq_and_seq_groups: - scheduler._allocate_and_set_running(seq_group) - scheduler._swap_out(seq_group, blocks_to_swap_out) - scheduler._add_seq_group_to_swapped(seq_group) - - scheduler._schedule_swapped(budget, curr_loras) - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. 
- _seq_id_to_num_tokens_computed.get(1)) - assert seq_id_to_num_tokens_computed is None - - # Prefill schedule don't have a space for another LoRA, so - # we ignore this request for now. - block_size = 4 - lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) - scheduler = initialize_scheduler(lora_config=lora_config, - block_size=block_size, - num_cpu_blocks=64, - num_gpu_blocks=64, - enable_prefix_caching=True) - budget = create_token_budget(token_budget=120) - num_seqs = 2 - for i in range(num_seqs): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=seq_length, - block_size=block_size, - lora_request=LoRARequest( - lora_name=str(i), - lora_int_id=i + 1, - lora_path="abc")) - scheduler.add_seq_group(seq_group) - - scheduler._schedule_prefills(budget, curr_loras) - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. - _seq_id_to_num_tokens_computed.get(1)) - assert seq_id_to_num_tokens_computed is None - - # Priority preemption schedule - scheduler._schedule_priority_preemption(budget) - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. - _seq_id_to_num_tokens_computed.get(1)) - assert seq_id_to_num_tokens_computed is None - - # Prefill scheduler does not schedule batches with prompt tokens and - # prompt embeddings co-mingled. - block_size = 2 - max_seq_group = 3 - scheduler = initialize_scheduler( - block_size=block_size, - num_cpu_blocks=16, - num_gpu_blocks=16, - max_num_seqs=max_seq_group, - max_model_len=100, - enable_prefix_caching=True, - ) - seq_length = 7 - embedding_size = 5 - seq_tokens_with_embedding: list[list[int]] = [] - seq_embeds: list[Optional[torch.Tensor]] = [] - - seq_tokens_with_embedding.append(list(range(seq_length))) - seq_embeds.append(None) - seq_tokens_with_embedding.append([0] * seq_length) - seq_embeds.append(torch.rand(embedding_size)) - - seq_and_seq_groups = [ - create_dummy_prompt(f"{i}", - prompt_tokens=seq_tokens_with_embedding[i], - prompt_embeds=seq_embeds[i], - block_size=block_size) - for i in range(len(seq_tokens_with_embedding)) - ] - - for _, seq_group in seq_and_seq_groups: - scheduler.add_seq_group(seq_group) - - scheduler._schedule_default() - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. - _seq_id_to_num_tokens_computed.get(1)) - assert seq_id_to_num_tokens_computed is None - - # Prefill scheduler budget num_batched_tokens - # >= scheduler_config max_num_batched_tokens - block_size = 2 - max_seq_group = 3 - seq_tokens_prefill_budget: list[list[int]] = [] - - scheduler = initialize_scheduler( - block_size=block_size, - max_token_budget=8, - num_cpu_blocks=16, - num_gpu_blocks=16, - max_num_seqs=max_seq_group, - max_model_len=5, - enable_prefix_caching=True, - ) - seq_length = 4 - num_seqs = 3 - for i in range(num_seqs): - seq_tokens_prefill_budget.append([i] * seq_length) - - seq_and_seq_groups = [ - create_dummy_prompt(f"{i}", - prompt_tokens=seq_tokens_prefill_budget[i], - block_size=block_size) - for i in range(len(seq_tokens_prefill_budget)) - ] - - for _, seq_group in seq_and_seq_groups: - scheduler.add_seq_group(seq_group) - - scheduler._schedule_default() - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. 
- _seq_id_to_num_tokens_computed.get(2)) - assert seq_id_to_num_tokens_computed is None - - # Budget can not schedule in waiting - block_size = 2 - max_seq_group = 3 - - scheduler = initialize_scheduler( - block_size=block_size, - max_token_budget=30, - num_cpu_blocks=16, - num_gpu_blocks=16, - max_num_seqs=max_seq_group, - max_model_len=30, - enable_prefix_caching=True, - ) - seq_length = 16 - num_seqs = 3 - seq_tokens_prefill_budget_waiting: list[list[int]] = [] - - for i in range(num_seqs): - seq_tokens_prefill_budget_waiting.append(list(range(seq_length))) - - seq_and_seq_groups = [ - create_dummy_prompt(f"{i}", - prompt_tokens=seq_tokens_prefill_budget_waiting[i], - block_size=block_size) - for i in range(len(seq_tokens_prefill_budget_waiting)) - ] - - for _, seq_group in seq_and_seq_groups: - scheduler.add_seq_group(seq_group) - - scheduler._schedule_default() - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. - _seq_id_to_num_tokens_computed.get(1)) - assert seq_id_to_num_tokens_computed is None - - # Sequence num_new_tokens > prompt_limit marked FINISHED_IGNORED - block_size = 2 - max_seq_group = 3 - scheduler = initialize_scheduler( - block_size=block_size, - num_cpu_blocks=16, - num_gpu_blocks=16, - max_num_seqs=max_seq_group, - max_model_len=30, - enable_prefix_caching=True, - ) - - seq_length = 31 - seq_tokens_prompt_limit: list[list[int]] = [] - seq_tokens_prompt_limit.append(list(range(seq_length))) - seq_and_seq_groups = [ - create_dummy_prompt("0", - prompt_tokens=seq_tokens_prompt_limit[0], - block_size=block_size) - ] - for _, seq_group in seq_and_seq_groups: - scheduler.add_seq_group(seq_group) - scheduler._schedule_default() - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. - _seq_id_to_num_tokens_computed.get(0)) - assert seq_id_to_num_tokens_computed is None - - # Budget can not allocate, AllocStatus is NEVER marked FINISHED_IGNORED - block_size = 2 - max_seq_group = 3 - scheduler = initialize_scheduler( - block_size=block_size, - num_cpu_blocks=160, - num_gpu_blocks=160, - max_num_seqs=max_seq_group, - max_model_len=320, - enable_prefix_caching=True, - ) - - seq_length = 320 - num_seqs = 1 - seq_tokens_never: list[list[int]] = [] - for i in range(num_seqs): - seq_tokens_never.append(list(range(seq_length))) - - seq_and_seq_groups = [ - create_dummy_prompt(f"{i}", - prompt_tokens=seq_tokens_never[i], - block_size=block_size) - for i in range(len(seq_tokens_never)) - ] - - for _, seq_group in seq_and_seq_groups: - scheduler.add_seq_group(seq_group) - - scheduler._schedule_default() - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. 
- _seq_id_to_num_tokens_computed.get(0)) - assert seq_id_to_num_tokens_computed is None - - # Budget can not allocate, AllocStatus is LATER - block_size = 2 - max_seq_group = 3 - scheduler = initialize_scheduler( - block_size=block_size, - num_cpu_blocks=160, - num_gpu_blocks=160, - max_num_seqs=max_seq_group, - max_model_len=320, - enable_prefix_caching=True, - ) - - seq_length = 160 - num_seqs = 2 - seq_tokens_later: list[list[int]] = [] - for i in range(num_seqs): - seq_tokens_later.append(list(range(seq_length))) - - seq_and_seq_groups = [ - create_dummy_prompt(f"{i}", - prompt_tokens=seq_tokens_later[i], - block_size=block_size) - for i in range(len(seq_tokens_later)) - ] - - for _, seq_group in seq_and_seq_groups: - scheduler.add_seq_group(seq_group) - - scheduler._schedule_default() - seq_id_to_num_tokens_computed = ( - scheduler.block_manager._computed_blocks_tracker. - _seq_id_to_num_tokens_computed.get(1)) - assert seq_id_to_num_tokens_computed is None diff --git a/tests/core/test_scheduler_encoder_decoder.py b/tests/core/test_scheduler_encoder_decoder.py deleted file mode 100644 index 20cc083ec8db4..0000000000000 --- a/tests/core/test_scheduler_encoder_decoder.py +++ /dev/null @@ -1,105 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest # noqa - -from vllm.config import CacheConfig, SchedulerConfig -from vllm.core.scheduler import Scheduler -from vllm.sequence import SequenceGroup - -from .utils import (append_new_token, create_dummy_prompt_encoder_decoder, - get_sequence_groups, schedule_and_update_computed_tokens) - - -def test_scheduler_schedule_simple_encoder_decoder(): - ''' - Test basic scheduler functionality in the context - of an encoder/decoder model. Focus on testing - enc/dec-specific functionality sense tests already - exist for decoder-only functionality - - Test behavior: - * Construct Scheduler - * Construct dummy encoder/decoder sequence groups - * Add dummy seq groups to scheduler backlog - * Schedule the next seq group & validate: - * Cross-attn block tables - * Updated states of seq groups - * Number of batched tokens - * Number of blocks to copy/swap-in/swap-out - * Number of scheduled seq groups - * Repeat for both prefill- and decode-phase - * Abort scheduled seq groups - * Assert that aborted seq groups no longer appear in - cross-attention block table - ''' - - block_size = 4 - num_seq_group = 4 - max_model_len = 16 - scheduler_config = SchedulerConfig( - "generate", - max_num_batched_tokens=64, - max_num_seqs=num_seq_group, - max_model_len=max_model_len, - ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 16 # enc and dec prompts per seq_group - cache_config.num_gpu_blocks = 16 # enc and dec prompts per seq_group - scheduler = Scheduler(scheduler_config, cache_config, None) - running: list[SequenceGroup] = [] - - # Add seq groups to scheduler. - req_id_list = [] - for i in range(num_seq_group): - req_id = str(i) - req_id_list.append(req_id) - _, _, seq_group = create_dummy_prompt_encoder_decoder( - req_id, block_size, block_size, block_size) - scheduler.add_seq_group(seq_group) - running.append(seq_group) - - # Schedule seq groups prefill. 
- num_tokens = block_size * num_seq_group - seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler) - # - Verify that sequence group cross-attention block tables are - # registered with the block manager - assert all([(req_id in scheduler.block_manager.cross_block_tables) - for req_id in req_id_list]) - # - Validate sequence-group status - assert set(get_sequence_groups(out)) == set(running) - # - Validate number of batched tokens - assert out.num_batched_tokens == num_tokens - # - Validate there are no remaining blocks to swap - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - # - Validate all seq groups were scheduled - assert len(seq_group_meta_list) == num_seq_group - append_new_token(out, 1) - - # Schedule seq groups decode. - seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler) - # - Verify that sequence group metadata includes encoder attention - # and cross-attention metadata - assert all([ - not ((seq_group_meta.encoder_seq_data is None) or - (seq_group_meta.cross_block_table is None)) - for seq_group_meta in seq_group_meta_list - ]) - # - Validate sequence-group status - assert set(get_sequence_groups(out)) == set(running) - # - Validate there is one batched token per seq group - assert out.num_batched_tokens == num_seq_group - # - Validate there are no remaining blocks to swap - assert (not out.blocks_to_copy and not out.blocks_to_swap_in - and not out.blocks_to_swap_out) - # - Validate that all seq groups were scheduled - assert len(seq_group_meta_list) == num_seq_group - append_new_token(out, 1) - - # Abort sequences - for req_id in req_id_list: - scheduler.abort_seq_group(req_id) - # - Verify that sequence group cross-attention block tables are - # NO LONGER registered with the block manager - assert req_id not in scheduler.block_manager.cross_block_tables diff --git a/tests/core/test_serialization.py b/tests/core/test_serialization.py deleted file mode 100644 index ee9ac2129f2db..0000000000000 --- a/tests/core/test_serialization.py +++ /dev/null @@ -1,36 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import msgspec - -from vllm.executor.msgspec_utils import decode_hook, encode_hook -from vllm.sequence import ExecuteModelRequest - -from .utils import create_batch - - -def test_msgspec_serialization(): - num_lookahead_slots = 4 - seq_group_metadata_list, _, _ = create_batch(16, num_lookahead_slots) - execute_model_req = ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=num_lookahead_slots, - running_queue_size=4) - - encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) - decoder = msgspec.msgpack.Decoder(ExecuteModelRequest, - dec_hook=decode_hook) - req = decoder.decode(encoder.encode(execute_model_req)) - expected = execute_model_req.seq_group_metadata_list - actual = req.seq_group_metadata_list - assert (len(expected) == len(actual)) - expected = expected[0] - actual = actual[0] - - assert expected.block_tables == actual.block_tables - assert expected.is_prompt == actual.is_prompt - assert expected.request_id == actual.request_id - assert (expected.seq_data[0].prompt_token_ids == - actual.seq_data[0].prompt_token_ids) - assert (expected.seq_data[0].output_token_ids == - actual.seq_data[0].output_token_ids) diff --git a/tests/core/utils.py b/tests/core/utils.py deleted file mode 100644 index 033fffd2c4e24..0000000000000 --- a/tests/core/utils.py +++ /dev/null @@ -1,392 
+0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import time -from collections import defaultdict -from collections.abc import Sequence as GenericSequence -from itertools import count -from typing import Any, Optional, Union - -import torch - -from vllm.core.scheduler import Scheduler, SchedulerOutputs -from vllm.inputs import EncoderDecoderInputs, embeds_inputs, token_inputs -from vllm.lora.request import LoRARequest -from vllm.sampling_params import SamplingParams -from vllm.sequence import (Logprob, Sequence, SequenceData, SequenceGroup, - SequenceGroupMetadata) - - -def create_dummy_prompt( - request_id: str, - prompt_length: int = -1, - block_size: Optional[int] = None, - lora_request: Optional[LoRARequest] = None, - prompt_tokens: Optional[list[int]] = None, - prompt_embeds: Optional[torch.Tensor] = None, - min_tokens: int = 0, - max_tokens: int = 16, -) -> tuple[Sequence, SequenceGroup]: - if not block_size: - block_size = prompt_length - - if prompt_tokens is None: - # Create dummy prompt sequence with tokens 0...block_size-1 - # and prompt "0 ... block_size". - prompt_tokens = list(range(prompt_length)) - - prompt_str = " ".join([str(t) for t in prompt_tokens]) - inputs = token_inputs( - prompt_token_ids=prompt_tokens, - prompt=prompt_str) if prompt_embeds is None else embeds_inputs( - prompt_embeds=prompt_embeds) - prompt = Sequence( - int(request_id), - inputs=inputs, - block_size=block_size, - ) - seq_group = SequenceGroup( - request_id=request_id, - seqs=[prompt], - arrival_time=time.time(), - sampling_params=SamplingParams(max_tokens=max_tokens, - min_tokens=min_tokens), - lora_request=lora_request, - ) - - return prompt, seq_group - - -def create_dummy_lora_sequence(request_id: int, token_ids: list[int], - block_size: int, lora_int_id: int) -> Sequence: - return Sequence(seq_id=request_id, - inputs=token_inputs(token_ids), - block_size=block_size, - lora_request=LoRARequest(lora_name="dummy", - lora_path="/dummy", - lora_int_id=lora_int_id)) - - -def create_dummy_sequence(request_id: int, token_ids: list[int], - block_size: int) -> Sequence: - return Sequence( - seq_id=request_id, - inputs=token_inputs(token_ids), - block_size=block_size, - ) - - -def create_dummy_prompt_encoder_decoder( - request_id: str, - decoder_prompt_length: int, - encoder_prompt_length: int, - block_size: Optional[int] = None, - lora_request: Optional[LoRARequest] = None, -) -> tuple[Sequence, Sequence, SequenceGroup]: - if not block_size: - block_size = decoder_prompt_length - - # Create dummy prompt sequence with tokens 0...block_size-1 - # and prompt "0 ... block_size". 
Note that the prompt string - # doesn't actually match the tokens - decoder_prompt_tokens = list(range(decoder_prompt_length)) - decoder_prompt_str = " ".join([str(t) for t in decoder_prompt_tokens]) - encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length)))) - encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens]) - - inputs: EncoderDecoderInputs = { - "decoder": token_inputs(decoder_prompt_tokens, - prompt=decoder_prompt_str), - "encoder": token_inputs(encoder_prompt_tokens, - prompt=encoder_prompt_str), - } - - decoder_prompt = Sequence(int(request_id), - inputs=inputs["decoder"], - block_size=block_size) - - encoder_prompt = Sequence(int(request_id), - inputs=inputs["encoder"], - block_size=block_size) - - seq_group = SequenceGroup(request_id=request_id, - seqs=[decoder_prompt], - arrival_time=time.time(), - lora_request=lora_request, - encoder_seq=encoder_prompt) - - return decoder_prompt, encoder_prompt, seq_group - - -def create_seq_group( - seq_prompt_len: int = 1024, - seq_output_lens: GenericSequence[int] = (128, ), - request_id: str = '0', - seq_id_start: int = 0, - sampling_params: Optional[SamplingParams] = None) -> SequenceGroup: - - assert len(seq_output_lens) > 0 - - if sampling_params is None: - sampling_params = SamplingParams() - - prompt_token_ids = [0] * seq_prompt_len - - seqs: list[Sequence] = [] - for seq_id_offset, output_len in enumerate(seq_output_lens): - seq = Sequence( - seq_id=seq_id_start + seq_id_offset, - inputs=token_inputs(prompt_token_ids), - block_size=16, - ) - - for i in range(output_len): - seq.append_token_id( - token_id=i, - logprobs={i: Logprob(0.0)}, - ) - seqs.append(seq) - - seq_group = SequenceGroup( - request_id=request_id, - seqs=seqs, - sampling_params=sampling_params, - arrival_time=time.time(), - ) - - return seq_group - - -def create_seq_group_encoder_decoder( - seq_prompt_len: int = 1024, - seq_output_lens: GenericSequence[int] = (128, ), - request_id: str = '0', - seq_id_start: int = 0, - sampling_params: Optional[SamplingParams] = None) -> SequenceGroup: - - assert len(seq_output_lens) > 0 - - if sampling_params is None: - sampling_params = SamplingParams() - - prompt_token_ids = [0] * seq_prompt_len - - inputs: EncoderDecoderInputs = { - "decoder": token_inputs(prompt_token_ids), - "encoder": token_inputs(prompt_token_ids), - } - - seqs = [] - for seq_id_offset, output_len in enumerate(seq_output_lens): - # Construct decoder input sequences - seq = Sequence( - seq_id=seq_id_start + seq_id_offset, - inputs=inputs["decoder"], - block_size=16, - ) - - for i in range(output_len): - seq.append_token_id( - token_id=i, - logprobs={i: Logprob(0.0)}, - ) - seqs.append(seq) - - # Encoder input sequence - encoder_seq = Sequence( - seq_id=seq_id_start + len(seq_output_lens), - inputs=inputs["encoder"], - block_size=16, - ) - - return SequenceGroup(request_id=request_id, - seqs=seqs, - sampling_params=sampling_params, - arrival_time=time.time(), - encoder_seq=encoder_seq) - - -def round_up_to_next_block(seq_len: int, block_size: int) -> int: - return (seq_len + block_size - 1) // block_size - - -# Helper functions for scheduler tests - - -def get_sequence_groups(scheduler_output): - return [s.seq_group for s in scheduler_output.scheduled_seq_groups] - - -def append_new_token(out, token_id: int): - seq_groups = get_sequence_groups(out) - for seq_group in seq_groups: - for seq in seq_group.get_seqs(): - seq.append_token_id(token_id, {token_id: Logprob(token_id)}) - - -def 
schedule_and_update_computed_tokens(scheduler): - metas, out, _ = scheduler.schedule() - for s in out.scheduled_seq_groups: - s.seq_group.update_num_computed_tokens(s.token_chunk_size) - return metas, out - - -def append_new_token_seq(seq: Sequence, token_id: int): - seq.append_token_id(token_id, {token_id: Logprob(token_id)}) - - -def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int): - seq_group.update_num_computed_tokens(token_chunk_size) - for seq in seq_group.get_seqs(): - seq.append_token_id(token_id, {token_id: Logprob(token_id)}) - - -class SchedulerProxy: - """ - A proxy class to forward calls to the scheduler. - """ - - def __init__(self, scheduler: Scheduler): - self.scheduler_ = scheduler - self.call_history: dict[str, list[Any]] = defaultdict(list) - - def __getattr__(self, name: str) -> Any: - - def wrapper(*args, **kwargs): - result = getattr(self.scheduler_, name)(*args, **kwargs) - self.call_history[name].append((args, kwargs, result)) - return result - - return wrapper - - def last_schedule_ret( - self, ) -> tuple[list[SequenceGroupMetadata], SchedulerOutputs, Any]: - _, _, ret = self.call_history["schedule"][-1] - return ret - - -def create_seq_group_metadata_from_prompts( - prompts: list[list[int]], - num_gpu_blocks: int, - block_size: int, - final_prompt_lens: list[int], - continuations: Optional[list[list[int]]] = None, - seq_ids: Optional[list[int]] = None, -) -> list[SequenceGroupMetadata]: - - if continuations is None: - continuations = [[] for _ in prompts] - - if seq_ids is None: - seq_ids = list(i for i, _ in enumerate(prompts)) - - free_gpu_blocks = list(range(num_gpu_blocks)) - - block_allocations = { - i: [ - free_gpu_blocks.pop() - for _ in range(round_up_to_next_block(final_len, block_size)) - ] - for i, final_len in enumerate(final_prompt_lens) - } - - seq_grou_metadata_list = [] - for i, (prompt_token_ids, - cont_token_ids) in enumerate(zip(prompts, continuations)): - data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids) - data.update_num_computed_tokens( - len(prompt_token_ids) + len(cont_token_ids) - 1) - seq_data = {i: data} - seq_grou_metadata_list.append( - SequenceGroupMetadata( - request_id=str(i), - is_prompt=len(cont_token_ids) == 0, - seq_data=seq_data, - sampling_params=SamplingParams(temperature=0.0), - block_tables={i: block_allocations[i][:]}, - )) - return seq_grou_metadata_list - - -def create_chunked_seq_group_metadata_from_prompt( - prompt: list[int], - num_gpu_blocks: int, - chunk_size: int, - block_size: int, - seq_id: Optional[int] = None) -> list[SequenceGroupMetadata]: - - if seq_id is None: - seq_id = 0 - - free_gpu_blocks = list(range(num_gpu_blocks)) - - block_allocations = [ - free_gpu_blocks.pop() - for _ in range(round_up_to_next_block(len(prompt), block_size)) - ] - - seq_group_metadata_list = [] - for i, idx in enumerate(range(0, len(prompt), chunk_size)): - chunk_ids = prompt[idx:idx + chunk_size] - data = SequenceData.from_seqs(prompt) - data.update_num_computed_tokens(idx) - seq_data = {i: data} - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=str(seq_id), - is_prompt=True, - do_sample=idx + chunk_size >= len(prompt), # terminal chunk - seq_data=seq_data, - sampling_params=SamplingParams(temperature=0.0), - block_tables={i: block_allocations}, - token_chunk_size=len(chunk_ids))) - return seq_group_metadata_list - - -def create_batch(batch_size, - k, - prompt_len: Union[int, list[int]] = 10, - prev_output_token_len: int = 10, - seq_ids: Optional[list[int]] = None, - 
num_gpu_blocks: Optional[int] = None, - block_size: Optional[int] = None, - prefill_chunk_size: Optional[int] = None): - if block_size is None: - block_size = 8 - - if num_gpu_blocks is None: - num_gpu_blocks = 2048 // block_size - - iterator = count() - - if isinstance(prompt_len, int): - prompt_lens = [prompt_len for _ in range(batch_size)] - else: - prompt_lens = prompt_len - - prompts = [[next(iterator) for _ in range(p_len)] for p_len in prompt_lens] - - if prefill_chunk_size: - # Create a batch of chunked prompts. - if not seq_ids: - seq_ids = list(range(len(prompts))) - seq_group_metadata_list = [] - for p, sid in zip(prompts, seq_ids): - seq_group_metadata_list += \ - create_chunked_seq_group_metadata_from_prompt( - p, num_gpu_blocks, prefill_chunk_size, block_size, sid) - seq_group_metadata_list = seq_group_metadata_list[:batch_size] - prev_output_tokens = [] - else: - prev_output_tokens = [[ - next(iterator) for _ in range(prev_output_token_len) - ] for _ in range(batch_size)] - final_prompt_lens = [ - len(prompt) + len(prev_output_token) + k + 1 - for prompt, prev_output_token in zip(prompts, prev_output_tokens) - ] - - seq_group_metadata_list = create_seq_group_metadata_from_prompts( - prompts, num_gpu_blocks, block_size, final_prompt_lens, - prev_output_tokens, seq_ids) - return seq_group_metadata_list, prompts, prev_output_tokens diff --git a/tests/detokenizer/conftest.py b/tests/detokenizer/conftest.py deleted file mode 100644 index f2c125355c83c..0000000000000 --- a/tests/detokenizer/conftest.py +++ /dev/null @@ -1,11 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest - - -@pytest.fixture(autouse=True) -def v1(run_with_both_engines): - # Simple autouse wrapper to run both engines for each test - # This can be promoted up to conftest.py to run for every - # test in a package - pass diff --git a/tests/detokenizer/test_min_tokens.py b/tests/detokenizer/test_min_tokens.py index 887e83342536e..26003373c569c 100644 --- a/tests/detokenizer/test_min_tokens.py +++ b/tests/detokenizer/test_min_tokens.py @@ -31,16 +31,14 @@ def test_min_tokens_with_stop(min_tokens: int, stop: str, truth: str): stop=stop, min_tokens=min_tokens, ) - request = EngineCoreRequest("", - prompt_token_ids, - None, - None, - None, - params, - None, - None, - 0.0, - None, + request = EngineCoreRequest(request_id="", + prompt_token_ids=prompt_token_ids, + mm_features=None, + sampling_params=params, + pooling_params=None, + eos_token_id=None, + arrival_time=0.0, + lora_request=None, cache_salt=None, data_parallel_rank=None) diff --git a/tests/detokenizer/test_stop_checker.py b/tests/detokenizer/test_stop_checker.py deleted file mode 100644 index bd221977224f9..0000000000000 --- a/tests/detokenizer/test_stop_checker.py +++ /dev/null @@ -1,89 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from unittest.mock import MagicMock - -import pytest -from transformers import PreTrainedTokenizer - -from vllm.engine.output_processor.stop_checker import StopChecker -from vllm.inputs import token_inputs -from vllm.sampling_params import SamplingParams -from vllm.sequence import Logprob, Sequence, SequenceStatus - - -def sequence_with_eos(text: str, eos_token: str, - eos_token_id: int) -> Sequence: - """ - Create a Sequence that ends with an EOS token. 
- """ - seq = Sequence( - seq_id=0, - inputs=token_inputs([]), - block_size=16, - eos_token_id=eos_token_id, - ) - seq.output_text = text + eos_token - - offset = eos_token_id + 1 - for i in range(offset, len(text) + offset): - seq.append_token_id(token_id=i, logprobs={i: Logprob(0.0)}) - seq.append_token_id(token_id=eos_token_id, - logprobs={eos_token_id: Logprob(0.0)}) - - seq.status = SequenceStatus.RUNNING - - return seq - - -@pytest.mark.parametrize(["text_wo_eos", "eos_token", "eos_token_id"], [ - ("This text ends with EOS token", "</s>", 2), -]) -@pytest.mark.parametrize("ignore_eos", [True, False]) -@pytest.mark.parametrize("include_stop_str_in_output", [True, False]) -@pytest.mark.skip_global_cleanup -def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int, - ignore_eos: bool, include_stop_str_in_output: bool): - """ - Test the behavior of the StopChecker's maybe_stop_sequence method - when an EOS token is encountered. - - This test covers: - - When the EOS token should stop the sequence and be removed from the output - - When the EOS token should stop the sequence and be included in the output - - When the EOS token should be ignored, and the sequence continues - """ - - tokenizer = MagicMock(spec=PreTrainedTokenizer) - get_tokenizer_for_seq = MagicMock(return_value=tokenizer) - stop_checker = StopChecker(max_model_len=1024, - get_tokenizer_for_seq=get_tokenizer_for_seq) - - seq = sequence_with_eos( - text=text_wo_eos, - eos_token=eos_token, - eos_token_id=eos_token_id, - ) - new_char_count = len(eos_token) - - # Note that `stop` and `stop_token_ids` are not specified - sampling_params = SamplingParams( - min_tokens=1, - ignore_eos=ignore_eos, - include_stop_str_in_output=include_stop_str_in_output) - - stop_checker.maybe_stop_sequence( - seq=seq, - new_char_count=new_char_count, - sampling_params=sampling_params, - ) - - if ignore_eos: - assert seq.status == SequenceStatus.RUNNING - assert seq.output_text == text_wo_eos + eos_token - elif include_stop_str_in_output: - assert seq.status == SequenceStatus.FINISHED_STOPPED - assert seq.output_text == text_wo_eos + eos_token - else: - assert seq.status == SequenceStatus.FINISHED_STOPPED - assert seq.output_text == text_wo_eos diff --git a/tests/detokenizer/test_stop_string_while_stop_model_terminates.py b/tests/detokenizer/test_stop_string_while_stop_model_terminates.py new file mode 100644 index 0000000000000..9b32a2927f2de --- /dev/null +++ b/tests/detokenizer/test_stop_string_while_stop_model_terminates.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.sampling_params import SamplingParams +from vllm.v1.engine import EngineCoreRequest +from vllm.v1.engine.detokenizer import BaseIncrementalDetokenizer + + +@pytest.fixture(params=[True, False]) +def include_stop_str_in_output(request): + return request.param + + +class _DummyDetokenizer(BaseIncrementalDetokenizer): + + def __init__(self, request: EngineCoreRequest): + super().__init__(request) + + def decode_next(self, next_token_id: int) -> str: + # Map token id to single ASCII character for deterministic testing. + return chr(next_token_id) + + +def _make_request(stop, include_stop_str_in_output: bool, min_tokens: int = 0): + params = SamplingParams( + stop=stop, + include_stop_str_in_output=include_stop_str_in_output, + min_tokens=min_tokens) + # Keep other fields minimal for unit test purposes. 
+    req = EngineCoreRequest(
+        request_id="test",
+        prompt_token_ids=[],
+        mm_features=None,
+        sampling_params=params,
+        pooling_params=None,
+        eos_token_id=None,
+        arrival_time=0.0,
+        lora_request=None,
+        cache_salt=None,
+        data_parallel_rank=None,
+    )
+    return req
+
+
+def test_stop_string_while_stop_token_terminates(
+        include_stop_str_in_output: bool):
+    """
+    This test verifies that the detokenizer correctly handles the case where
+    the generated token sequence contains both:
+    - a stop token
+    - an <eos> token
+
+    The detokenizer should respect the stop string and truncate the output
+    accordingly.
+
+    Imagine the following sequence:
+    - "abcdeZ" is generated, where "Z" is the <eos> token.
+    - "cd" is the stop string.
+
+    If include_stop_str_in_output=False, the detokenizer should truncate the
+    output to "ab" because the stop string "cd" is excluded.
+    If include_stop_str_in_output=True, the detokenizer should include the stop
+    string "cd" in the output, resulting in "abcd".
+
+
+    This verifies the behavioral change introduced in BaseIncrementalDetokenizer
+    where stop-string evaluation occurs before the early-return on
+    stop_terminated.
+    """
+
+    # Generate text "abcdeZ" and tokenize it.
+    generated_text = "abcde"
+    eos_token = "Z"
+    stop_string = "cd"
+    generated_text = generated_text + eos_token
+    token_ids = [ord(c) for c in generated_text]
+
+    # Create a request with the stop string and initialize the detokenizer.
+    req = _make_request(stop=[stop_string],
+                        include_stop_str_in_output=include_stop_str_in_output)
+    detok = _DummyDetokenizer(req)
+
+    # Simulate that the last token ('Z') is a stop token (stop_terminated=True).
+    result = detok.update(new_token_ids=token_ids, stop_terminated=True)
+
+    # The update should report the matched stop string
+    assert result == stop_string
+
+    # Output text should reflect stop-string handling:
+    # - include_stop_str_in_output=False => exclude "cd" => "ab"
+    # - include_stop_str_in_output=True => include "cd" => "abcd"
+    expected_text = "abcd" if include_stop_str_in_output else "ab"
+    assert detok.output_text == expected_text
+
+    # The skipped final token should still be recorded in token_ids.
+    assert detok.output_token_ids == token_ids
+
+    # get_next_output_text should return the full text when finished=True.
+    # (Buffering only applies during streaming when finished=False.)
+    assert detok.get_next_output_text(finished=True,
+                                      delta=False) == expected_text
diff --git a/tests/distributed/conftest.py b/tests/distributed/conftest.py
index 666a715cc0da1..7dc4a0cc3d582 100644
--- a/tests/distributed/conftest.py
+++ b/tests/distributed/conftest.py
@@ -8,7 +8,7 @@ import msgspec.msgpack
 import pytest
 import zmq
-from vllm.config import KVEventsConfig
+from vllm.config.kv_events import KVEventsConfig
 from vllm.distributed.kv_events import EventPublisherFactory
 from .test_events import SampleBatch
diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py
new file mode 100644
index 0000000000000..11685bc90c413
--- /dev/null
+++ b/tests/distributed/test_context_parallel.py
@@ -0,0 +1,266 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+WARNING: This test runs in both single-node (4 GPUs) and multi-node
+ (2 node with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
+ important to set the distributed backend to "mp" to avoid Ray scheduling
+ all workers in a node other than the head node, which can cause the test
+ to fail.
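+
+Example invocation on a single 4-GPU node (a minimal sketch; CI wiring may
+ pass additional flags):
+    pytest tests/distributed/test_context_parallel.py
+Set VLLM_MULTI_NODE=1 to exercise the multi-node path.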
+""" +import json +import os +from dataclasses import dataclass +from typing import Literal, NamedTuple, Optional + +import pytest + +from vllm.config import RunnerOption +from vllm.logger import init_logger + +from ..models.registry import HF_EXAMPLE_MODELS +from ..utils import compare_two_settings, create_new_process_for_each_test + +logger = init_logger("test_context_parallel") + +VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" + + +class ParallelSetup(NamedTuple): + tp_size: int + pp_size: int + dcp_size: int + eager_mode: bool + chunked_prefill: bool + + +class CPTestOptions(NamedTuple): + multi_node_only: bool + load_format: Optional[str] = None + + +@dataclass +class CPTestSettings: + parallel_setups: list[ParallelSetup] + # NOTE: the length of distributed_backends and + # vllm_major_versions should be the same, and they + # are first zipped together to iterate over all + # test settings. + distributed_backends: list[str] + # vllm major version: "0" for V0, "1" for V1 + vllm_major_versions: list[str] + runner: RunnerOption + test_options: CPTestOptions + + def __post_init__(self): + if len(self.distributed_backends) != len(self.vllm_major_versions): + raise ValueError( + f"Length mismatch: distributed_backends " + f"({len(self.distributed_backends)}) != " + f"vllm_major_versions ({len(self.vllm_major_versions)})") + + @staticmethod + def detailed( + *, + tp_base: int = 4, + pp_base: int = 1, + dcp_base: int = 1, + multi_node_only: bool = False, + runner: RunnerOption = "auto", + load_format: Optional[str] = None, + ): + parallel_setups = [] + for eager_mode_val in [False]: + for pp_multiplier in [1]: + for dcp_multiplier in [0.5, 1]: + for chunked_prefill_val in [True]: + parallel_setups.append( + ParallelSetup(tp_size=tp_base, + pp_size=pp_multiplier * pp_base, + dcp_size=int(dcp_multiplier * + tp_base), + eager_mode=eager_mode_val, + chunked_prefill=chunked_prefill_val)) + return CPTestSettings( + parallel_setups=parallel_setups, + distributed_backends=["mp"], + vllm_major_versions=["1"], + runner=runner, + test_options=CPTestOptions(multi_node_only=multi_node_only, + load_format=load_format), + ) + + def iter_params(self, model_id: str): + opts = self.test_options + + for parallel_setup in self.parallel_setups: + for backend, vllm_major_version in zip(self.distributed_backends, + self.vllm_major_versions): + yield (model_id, parallel_setup, backend, vllm_major_version, + self.runner, opts) + + +def _compare_cp_with_tp( + model_id: str, + parallel_setup: ParallelSetup, + distributed_backend: str, + vllm_major_version: str, + runner: RunnerOption, + test_options: CPTestOptions, + num_gpus_available: int, + *, + method: Literal["generate"], + is_multimodal: bool, +): + ( + tp_size, + pp_size, + dcp_size, + eager_mode, + chunked_prefill, + ) = parallel_setup + + multi_node_only, load_format = test_options + + model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) + model_info.check_transformers_version(on_fail="skip") + + trust_remote_code = model_info.trust_remote_code + tokenizer_mode = model_info.tokenizer_mode + hf_overrides = model_info.hf_overrides + + if load_format == "dummy": + # Avoid OOM + text_overrides = { + "num_hidden_layers": 4, + "hidden_size": 512, + "intermediate_size": 800, + "num_attention_heads": 4, + "num_key_value_heads": 1, + } + + if is_multimodal: + hf_overrides.update({"text_config": text_overrides}) + else: + hf_overrides.update(text_overrides) + else: + model_info.check_available_online(on_fail="skip") + + if num_gpus_available < tp_size * 
pp_size: + pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") + if VLLM_MULTI_NODE and distributed_backend == "mp": + pytest.skip("Skipping multi-node pipeline parallel test for " + "multiprocessing distributed backend") + if multi_node_only and not VLLM_MULTI_NODE: + pytest.skip("Not in multi-node setting") + + common_args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "8", + ] + if chunked_prefill: + common_args.append("--enable-chunked-prefill") + if eager_mode: + common_args.append("--enforce-eager") + if runner != "auto": + common_args.extend(["--runner", runner]) + if trust_remote_code: + common_args.append("--trust-remote-code") + if tokenizer_mode: + common_args.extend(["--tokenizer-mode", tokenizer_mode]) + if load_format: + common_args.extend(["--load-format", load_format]) + if hf_overrides: + common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) + + cp_env = tp_env = { + "VLLM_USE_V1": + vllm_major_version, # Note(hc): DCP only support V1 engine only + } + + cp_args = [ + *common_args, + "--tensor-parallel-size", + str(tp_size), + "--pipeline-parallel-size", + str(pp_size), + "--decode-context-parallel-size", + str(dcp_size), + "--distributed-executor-backend", + distributed_backend, + ] + + tp_args = [ + *common_args, + "--tensor-parallel-size", + str(tp_size), + "--pipeline-parallel-size", + str(pp_size), + "--distributed-executor-backend", + distributed_backend, + ] + + try: + compare_two_settings(model_id, + cp_args, + tp_args, + cp_env, + tp_env, + method=method, + max_wait_seconds=720) + except Exception: + testing_ray_compiled_graph = cp_env is not None + if testing_ray_compiled_graph and vllm_major_version == "0": + # Ray Compiled Graph tests are flaky for V0, + # so we don't want to fail the test + logger.exception("Ray Compiled Graph tests failed") + else: + raise + + +CP_TEXT_GENERATION_MODELS = { + # [MLA attention only] + "deepseek-ai/DeepSeek-V2-Lite-Chat": + [CPTestSettings.detailed(), + CPTestSettings.detailed(tp_base=2)], +} + +CP_TEST_MODELS = [ + # TODO support other models + # [LANGUAGE GENERATION] + "deepseek-ai/DeepSeek-V2-Lite-Chat", +] + + +@pytest.mark.parametrize( + ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", + "runner", "test_options"), + [ + params for model_id, settings in CP_TEXT_GENERATION_MODELS.items() + for setting in settings for params in setting.iter_params(model_id) + if model_id in CP_TEST_MODELS + ], +) +@create_new_process_for_each_test() +def test_cp_generation( + model_id: str, + parallel_setup: ParallelSetup, + distributed_backend: str, + vllm_major_version: str, + runner: RunnerOption, + test_options: CPTestOptions, + num_gpus_available, +): + _compare_cp_with_tp(model_id, + parallel_setup, + distributed_backend, + vllm_major_version, + runner, + test_options, + num_gpus_available, + method="generate", + is_multimodal=False) diff --git a/tests/distributed/test_expert_placement.py b/tests/distributed/test_expert_placement.py new file mode 100644 index 0000000000000..a3b1b3193deb0 --- /dev/null +++ b/tests/distributed/test_expert_placement.py @@ -0,0 +1,194 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.model_executor.layers.fused_moe.layer import determine_expert_map + + +def verify_round_robin_pattern(expert_map, ep_rank, ep_size, + global_num_experts): + """Verify that the expert map 
follows the round_robin pattern."""
+    # Calculate expected local experts (supporting non-divisible cases)
+    base_experts = global_num_experts // ep_size
+    remainder = global_num_experts % ep_size
+
+    if ep_rank < remainder:
+        local_num_experts = base_experts + 1
+    else:
+        local_num_experts = base_experts
+
+    # Expected expert IDs for this rank in round_robin pattern
+    # For non-divisible cases, ranks with extra experts start earlier
+    expected_expert_ids = []
+    for expert_idx in range(local_num_experts):
+        global_expert_id = ep_rank + expert_idx * ep_size
+        expected_expert_ids.append(global_expert_id)
+
+    # Check that only expected experts are mapped to this rank
+    for global_expert_id in range(global_num_experts):
+        if global_expert_id in expected_expert_ids:
+            local_expert_id = expert_map[global_expert_id]
+            expected_local_id = expected_expert_ids.index(global_expert_id)
+            assert (
+                local_expert_id == expected_local_id
+            ), f"Global expert {global_expert_id} should map to local expert " \
+                f"{expected_local_id}, got {local_expert_id}"
+        else:
+            assert (
+                expert_map[global_expert_id] == -1
+            ), f"Global expert {global_expert_id} should not be mapped to " \
+                f"this rank"
+
+    # Verify that all local expert IDs are consecutive starting from 0
+    local_expert_ids = [
+        expert_map[global_id] for global_id in expected_expert_ids
+    ]
+    expected_local_ids = list(range(local_num_experts))
+    assert (
+        local_expert_ids == expected_local_ids
+    ), f"Expected local expert IDs {expected_local_ids}, got {local_expert_ids}"
+
+
+@pytest.mark.parametrize("expert_placement_strategy", ["round_robin"])
+@pytest.mark.parametrize("world_size", [2, 4])
+def test_expert_placement_various_sizes(expert_placement_strategy, world_size):
+    """Test round_robin expert placement with various expert counts."""
+
+    # Test with different global_num_experts values
+    # Include both divisible and non-divisible cases
+    if world_size == 2:
+        test_cases = [
+            (4, 2),  # 4 experts (divisible)
+            (8, 2),  # 8 experts (divisible)
+            (9, 2),  # 9 experts (non-divisible)
+            (16, 2),  # 16 experts (divisible)
+            (17, 2),  # 17 experts (non-divisible)
+        ]
+    elif world_size == 4:
+        test_cases = [
+            (8, 4),  # 8 experts (divisible)
+            (16, 4),  # 16 experts (divisible)
+            (18, 4),  # 18 experts (non-divisible)
+            (32, 4),  # 32 experts (divisible)
+            (33, 4),  # 33 experts (non-divisible)
+        ]
+    else:
+        test_cases = []
+
+    for test_global_experts, test_ep_size in test_cases:
+        # Ensure ep_size matches world_size
+        assert (test_ep_size == world_size
+                ), f"ep_size {test_ep_size} must equal world_size {world_size}"
+
+        # Test each rank
+        for ep_rank in range(world_size):
+            # Calculate expected local experts
+            base_experts = test_global_experts // test_ep_size
+            remainder = test_global_experts % test_ep_size
+            if ep_rank < remainder:
+                expected_test_local = base_experts + 1
+            else:
+                expected_test_local = base_experts
+
+            test_local_experts, test_expert_map = determine_expert_map(
+                ep_size=test_ep_size,
+                ep_rank=ep_rank,
+                global_num_experts=test_global_experts,
+                expert_placement_strategy=expert_placement_strategy,
+            )
+
+            assert (
+                test_local_experts == expected_test_local
+            ), f"For {test_global_experts} experts on {test_ep_size} ranks, " \
+                f"rank {ep_rank}: expected {expected_test_local} local " \
+                f"experts, got {test_local_experts}"
+
+            if test_expert_map is not None:
+                assert test_expert_map.shape == (
+                    test_global_experts,
+                ), f"Expected expert map shape ({test_global_experts},), " \
+                    f"got {test_expert_map.shape}"
+
+                # Verify round_robin
pattern for this test case + verify_round_robin_pattern(test_expert_map, ep_rank, + test_ep_size, test_global_experts) + + +@pytest.mark.parametrize("expert_placement_strategy", ["round_robin"]) +@pytest.mark.parametrize("world_size", [2, 4]) +def test_expert_placement_edge_cases(expert_placement_strategy, world_size): + """Test edge cases for round_robin expert placement.""" + + # Test case 1: ep_size = 1 (should return None for expert_map) + local_num_experts, expert_map = determine_expert_map( + ep_size=1, + ep_rank=0, + global_num_experts=8, + expert_placement_strategy=expert_placement_strategy, + ) + assert local_num_experts == 8, "For ep_size=1, should get all experts" + assert expert_map is None, "For ep_size=1, expert_map should be None" + + # Test case 2: ep_size = 0 (should raise assertion) + with pytest.raises(AssertionError): + determine_expert_map( + ep_size=0, + ep_rank=0, + global_num_experts=8, + expert_placement_strategy=expert_placement_strategy, + ) + + +def test_determine_expert_map_comprehensive(): + """Test of determine_expert_map function with various configurations.""" + + # Test cases: (ep_size, ep_rank, global_num_experts, + # expert_placement_strategy, expected_local, expected_map_pattern) + test_cases = [ + # Round robin placement tests + (2, 0, 8, "round_robin", 4, [0, -1, 1, -1, 2, -1, 3, + -1]), # rank 0 gets even experts + (2, 1, 8, "round_robin", 4, [-1, 0, -1, 1, -1, 2, -1, + 3]), # rank 1 gets odd experts + (2, 0, 9, "round_robin", 5, [0, -1, 1, -1, 2, -1, 3, -1, 4 + ]), # rank 0 gets 5 experts (even + last) + (2, 1, 9, "round_robin", 4, [-1, 0, -1, 1, -1, 2, -1, 3, + -1]), # rank 1 gets 4 experts (odd) + + # 4-rank tests + (4, 0, 8, "round_robin", 2, [0, -1, -1, -1, 1, -1, -1, + -1]), # rank 0 gets experts 0, 4 + (4, 1, 8, "round_robin", 2, [-1, 0, -1, -1, -1, 1, -1, + -1]), # rank 1 gets experts 1, 5 + (4, 2, 8, "round_robin", 2, [-1, -1, 0, -1, -1, -1, 1, + -1]), # rank 2 gets experts 2, 6 + (4, 3, 8, "round_robin", 2, [-1, -1, -1, 0, -1, -1, -1, + 1]), # rank 3 gets experts 3, 7 + ] + + for ep_size, ep_rank, global_num_experts, expert_placement_strategy, \ + expected_local, expected_map_pattern in test_cases: + local_num_experts, expert_map = determine_expert_map( + ep_size=ep_size, + ep_rank=ep_rank, + global_num_experts=global_num_experts, + expert_placement_strategy=expert_placement_strategy, + ) + + assert local_num_experts == expected_local, \ + f"ep_size={ep_size}, ep_rank={ep_rank}, " \ + f"global_num_experts={global_num_experts}, " \ + f"expert_placement_strategy={expert_placement_strategy}: " \ + f"expected {expected_local} local experts, got {local_num_experts}" + + if expected_map_pattern is None: + assert expert_map is None, "Expected expert_map to be None" + else: + assert expert_map is not None, "Expected expert_map to not be None" + actual_map = expert_map.tolist() + assert actual_map == expected_map_pattern, \ + f"ep_size={ep_size}, ep_rank={ep_rank}, " \ + f"global_num_experts={global_num_experts}, " \ + f"expert_placement_strategy={expert_placement_strategy}: " \ + f"expected map {expected_map_pattern}, got {actual_map}" diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 1afe9ea970c97..fcd09844c0951 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -26,23 +26,10 @@ logger = init_logger("test_pipeline_parallel") VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" -@pytest.fixture(scope="function", autouse=True) 
-def use_v0_only(monkeypatch): - """ - For PP, we fall back to V0 by default. This means - that the TP baseline runs with V1 while the PP engine - runs with V0. This gives divergent results with dummy - weights. Once we enable V1 by default for PP, we can - remove this. - """ - monkeypatch.setenv('VLLM_USE_V1', '0') - - class ParallelSetup(NamedTuple): tp_size: int pp_size: int eager_mode: bool - chunked_prefill: bool class PPTestOptions(NamedTuple): @@ -53,23 +40,10 @@ class PPTestOptions(NamedTuple): @dataclass class PPTestSettings: parallel_setups: list[ParallelSetup] - # NOTE: the length of distributed_backends and - # vllm_major_versions should be the same, and they - # are first zipped together to iterate over all - # test settings. distributed_backends: list[str] - # vllm major version: "0" for V0, "1" for V1 - vllm_major_versions: list[str] runner: RunnerOption test_options: PPTestOptions - def __post_init__(self): - if len(self.distributed_backends) != len(self.vllm_major_versions): - raise ValueError( - f"Length mismatch: distributed_backends " - f"({len(self.distributed_backends)}) != " - f"vllm_major_versions ({len(self.vllm_major_versions)})") - @staticmethod def detailed( *, @@ -83,27 +57,21 @@ class PPTestSettings: parallel_setups=[ ParallelSetup(tp_size=tp_base, pp_size=pp_base, - eager_mode=False, - chunked_prefill=False), + eager_mode=False), ParallelSetup(tp_size=tp_base, pp_size=2 * pp_base, - eager_mode=False, - chunked_prefill=True), + eager_mode=False), ParallelSetup(tp_size=tp_base, pp_size=2 * pp_base, - eager_mode=True, - chunked_prefill=False), + eager_mode=True), ParallelSetup(tp_size=2 * tp_base, pp_size=pp_base, - eager_mode=False, - chunked_prefill=True), + eager_mode=False), ParallelSetup(tp_size=2 * tp_base, pp_size=pp_base, - eager_mode=True, - chunked_prefill=False), + eager_mode=True), ], - distributed_backends=["mp", "mp", "ray", "ray"], - vllm_major_versions=["0", "1", "0", "1"], + distributed_backends=["mp", "ray"], runner=runner, test_options=PPTestOptions(multi_node_only=multi_node_only, load_format=load_format), @@ -118,17 +86,14 @@ class PPTestSettings: multi_node_only: bool = False, load_format: Optional[str] = None, ): - vllm_major_versions = ["1"] if runner == "pooling" else ["0"] return PPTestSettings( parallel_setups=[ ParallelSetup(tp_size=tp_base, pp_size=pp_base, - eager_mode=True, - chunked_prefill=False), + eager_mode=True), ], distributed_backends=["mp"], - vllm_major_versions=vllm_major_versions, runner=runner, test_options=PPTestOptions(multi_node_only=multi_node_only, load_format=load_format), @@ -138,10 +103,8 @@ class PPTestSettings: opts = self.test_options for parallel_setup in self.parallel_setups: - for backend, vllm_major_version in zip(self.distributed_backends, - self.vllm_major_versions): - yield (model_id, parallel_setup, backend, vllm_major_version, - self.runner, opts) + for backend in self.distributed_backends: + yield (model_id, parallel_setup, backend, self.runner, opts) # NOTE: You can adjust tp_base and/or pp_base locally to fit the model in GPU @@ -215,9 +178,7 @@ TEXT_GENERATION_MODELS = { EMBEDDING_MODELS = { # type: ignore[var-annotated] # [Text-only] "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(runner="pooling"), - # TODO: re-enable when https://github.com/vllm-project/vllm/issues/23883 - # is fixed - #"BAAI/bge-multilingual-gemma2": PPTestSettings.fast(runner="pooling"), + "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(runner="pooling"), "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast( 
load_format="dummy", runner="pooling" ), @@ -244,9 +205,6 @@ MULTIMODAL_MODELS = { "Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(), "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(), "fixie-ai/ultravox-v0_5-llama-3_2-1b": PPTestSettings.fast(), - # [Encoder-decoder] - # TODO: Implement PP - # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(), } # yapf: enable @@ -274,7 +232,6 @@ def _compare_tp( model_id: str, parallel_setup: ParallelSetup, distributed_backend: str, - vllm_major_version: str, runner: RunnerOption, test_options: PPTestOptions, num_gpus_available: int, @@ -286,7 +243,6 @@ def _compare_tp( tp_size, pp_size, eager_mode, - chunked_prefill, ) = parallel_setup multi_node_only, load_format = test_options @@ -298,6 +254,8 @@ def _compare_tp( tokenizer_mode = model_info.tokenizer_mode hf_overrides = model_info.hf_overrides hf_config = get_config(model_id, trust_remote_code) + skip_tokenizer_init = model_info.skip_tokenizer_init + max_num_seqs = model_info.max_num_seqs dtype = "float16" if hf_config.model_type in _FLOAT16_NOT_SUPPORTED_MODELS: @@ -337,8 +295,6 @@ def _compare_tp( "--max-num-seqs", "8", ] - if chunked_prefill: - common_args.append("--enable-chunked-prefill") if eager_mode: common_args.append("--enforce-eager") if runner != "auto": @@ -351,15 +307,15 @@ def _compare_tp( common_args.extend(["--load-format", load_format]) if hf_overrides: common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) + if skip_tokenizer_init: + common_args.append("--skip-tokenizer-init") + if max_num_seqs: + common_args.extend(["--max-num-seqs", f"{max_num_seqs}"]) - specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill - testing_ray_compiled_graph = False - if distributed_backend == "ray" and (vllm_major_version == "1" - or specific_case): + if distributed_backend == "ray": # For V1, test Ray Compiled Graph for all the tests - # For V0, test Ray Compiled Graph for a subset of the tests pp_env = { - "VLLM_USE_V1": vllm_major_version, + "VLLM_USE_V1": "1", "VLLM_USE_RAY_COMPILED_DAG": "1", "VLLM_USE_RAY_SPMD_WORKER": "1", "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1", @@ -367,17 +323,15 @@ def _compare_tp( # Temporary. Currently when zeromq + SPMD is used, it does not properly # terminate because of a Ray Compiled Graph issue. 
common_args.append("--disable-frontend-multiprocessing") - testing_ray_compiled_graph = True elif distributed_backend == "mp": - # Both V0/V1 of multiprocessing executor support PP pp_env = { - "VLLM_USE_V1": vllm_major_version, + "VLLM_USE_V1": "1", } else: pp_env = None tp_env = { - "VLLM_USE_V1": vllm_major_version, + "VLLM_USE_V1": "1", } pp_args = [ @@ -403,25 +357,17 @@ def _compare_tp( "mp", ] - try: - compare_two_settings(model_id, - pp_args, - tp_args, - pp_env, - tp_env, - method=method) - except Exception: - if testing_ray_compiled_graph and vllm_major_version == "0": - # Ray Compiled Graph tests are flaky for V0, - # so we don't want to fail the test - logger.exception("Ray Compiled Graph tests failed") - else: - raise + compare_two_settings(model_id, + pp_args, + tp_args, + pp_env, + tp_env, + method=method) @pytest.mark.parametrize( - ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", - "runner", "test_options"), + ("model_id", "parallel_setup", "distributed_backend", "runner", + "test_options"), [ params for model_id, settings in TEXT_GENERATION_MODELS.items() for params in settings.iter_params(model_id) if model_id in TEST_MODELS @@ -432,15 +378,14 @@ def test_tp_language_generation( model_id: str, parallel_setup: ParallelSetup, distributed_backend: str, - vllm_major_version: str, runner: RunnerOption, test_options: PPTestOptions, num_gpus_available, ): + pytest.skip("Skipping the test until V1 passes it.") _compare_tp(model_id, parallel_setup, distributed_backend, - vllm_major_version, runner, test_options, num_gpus_available, @@ -449,8 +394,8 @@ def test_tp_language_generation( @pytest.mark.parametrize( - ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", - "runner", "test_options"), + ("model_id", "parallel_setup", "distributed_backend", "runner", + "test_options"), [ params for model_id, settings in EMBEDDING_MODELS.items() for params in settings.iter_params(model_id) if model_id in TEST_MODELS @@ -461,15 +406,14 @@ def test_tp_language_embedding( model_id: str, parallel_setup: ParallelSetup, distributed_backend: str, - vllm_major_version: str, runner: RunnerOption, test_options: PPTestOptions, num_gpus_available, ): + pytest.skip("Skipping the test until V1 passes it.") _compare_tp(model_id, parallel_setup, distributed_backend, - vllm_major_version, runner, test_options, num_gpus_available, @@ -478,8 +422,8 @@ def test_tp_language_embedding( @pytest.mark.parametrize( - ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", - "runner", "test_options"), + ("model_id", "parallel_setup", "distributed_backend", "runner", + "test_options"), [ params for model_id, settings in MULTIMODAL_MODELS.items() for params in settings.iter_params(model_id) if model_id in TEST_MODELS @@ -490,15 +434,14 @@ def test_tp_multimodal_generation( model_id: str, parallel_setup: ParallelSetup, distributed_backend: str, - vllm_major_version: str, runner: RunnerOption, test_options: PPTestOptions, num_gpus_available, ): + pytest.skip("Skipping the test until V1 passes it.") _compare_tp(model_id, parallel_setup, distributed_backend, - vllm_major_version, runner, test_options, num_gpus_available, diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index c93b436f384b9..ded3d834faf00 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -178,6 +178,7 @@ def _compare_sp( trust_remote_code = model_info.trust_remote_code 
tokenizer_mode = model_info.tokenizer_mode hf_overrides = model_info.hf_overrides + skip_tokenizer_init = model_info.skip_tokenizer_init if load_format == "dummy": # Avoid OOM @@ -227,12 +228,13 @@ def _compare_sp( common_args.extend(["--load-format", load_format]) if hf_overrides: common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) + if skip_tokenizer_init: + common_args.append("--skip-tokenizer-init") compilation_config = { 'level': 3, 'custom_ops': ["+rms_norm"], 'compile_sizes': [4, 8], - 'splitting_ops': [], 'pass_config': { 'enable_sequence_parallelism': True, 'enable_fusion': enable_fusion, @@ -248,6 +250,8 @@ def _compare_sp( *common_args, "--tensor-parallel-size", str(tp_size), + "--pipeline-parallel-size", + str(pp_size), "--distributed-executor-backend", distributed_backend, "--compilation_config", diff --git a/tests/distributed/test_shm_buffer.py b/tests/distributed/test_shm_buffer.py new file mode 100644 index 0000000000000..f70028b879609 --- /dev/null +++ b/tests/distributed/test_shm_buffer.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import traceback +import unittest + +from vllm.distributed.device_communicators.shm_object_storage import ( + SingleWriterShmRingBuffer) + + +class TestSingleWriterShmRingBuffer(unittest.TestCase): + """Test suite for the ring buffer implementation""" + + def setUp(self): + """Set up test fixtures""" + self.buffer_size = 4096 + self.ring_buffer = None + + def tearDown(self): + """Clean up after tests""" + if self.ring_buffer: + del self.ring_buffer + + def test_buffer_opening(self): + """Test opening an existing buffer""" + # First create a buffer + self.ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=self.buffer_size, create=True) + + # Then open it with another instance + reader_buffer = SingleWriterShmRingBuffer(*self.ring_buffer.handle()) + self.assertFalse(reader_buffer.is_writer) + self.assertEqual(reader_buffer.shared_memory.name, + self.ring_buffer.shared_memory.name) + + def test_buffer_access(self): + """Test accessing allocated buffers""" + self.ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=self.buffer_size, create=True) + + size = 100 + address, monotonic_id = self.ring_buffer.allocate_buf(size) + + # Write some test data + test_data = b"Hello, World!" 
* 7 # 91 bytes + with self.ring_buffer.access_buf(address) as (data_buf, metadata): + data_buf[0:len(test_data)] = test_data + + # Read it back + with self.ring_buffer.access_buf(address) as (data_buf2, metadata2): + read_data = bytes(data_buf2[0:len(test_data)]) + read_id = metadata2[0] + + self.assertEqual(read_data, test_data) + self.assertEqual(read_id, monotonic_id) + + def test_memory_error_on_full_buffer(self): + """Test that MemoryError is raised when buffer is full""" + small_buffer_size = 200 + self.ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=small_buffer_size, create=True) + + # Fill up the buffer + self.ring_buffer.allocate_buf(100) + self.ring_buffer.allocate_buf(80) # Total: 196 bytes used + + # This should fail + with self.assertRaises(MemoryError): + self.ring_buffer.allocate_buf(1) # Would exceed buffer capacity + + def test_allocation_and_free(self): + """Test allocation and freeing of buffers""" + small_buffer_size = 200 + self.ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=small_buffer_size, create=True) + + size = 80 + # Write some data + test_data = b"Repeated test data" + for i in range(5): + address, monotonic_id = self.ring_buffer.allocate_buf(size) + with self.ring_buffer.access_buf(address) as (data_buf, metadata): + data_buf[0:4] = (0).to_bytes(4, "little") # 0 for not in-use + data_buf[4:len(test_data) + 4] = test_data + print(self.ring_buffer.metadata) + freed_ids = self.ring_buffer.free_buf(lambda *args: True) + print(f" Freed IDs: {freed_ids}") + self.assertEqual(freed_ids[0], i) + + def test_clear_buffer(self): + """Test clearing the buffer""" + self.ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=self.buffer_size, create=True) + + # Allocate some buffers + for _ in range(3): + self.ring_buffer.allocate_buf(100) + + # Clear the buffer + self.ring_buffer.clear() + + # Check that metadata is empty and IDs reset + self.assertEqual(len(self.ring_buffer.metadata), 0) + self.assertEqual(self.ring_buffer.monotonic_id_start, 0) + self.assertEqual(self.ring_buffer.monotonic_id_end, 0) + self.assertEqual(self.ring_buffer.data_buffer_start, 0) + self.assertEqual(self.ring_buffer.data_buffer_end, 0) + + +def main(): + """Main function demonstrating usage and running tests""" + print("=== SingleWriterShmRingBuffer Test Suite ===\n") + + # Run unit tests + print("Running unit tests...") + unittest.main(argv=[""], exit=False, verbosity=2) + + print("\n" + "=" * 50) + print("=== Manual Demo ===\n") + + # Manual demonstration + try: + print("Creating ring buffer...") + writer_buffer = SingleWriterShmRingBuffer(data_buffer_size=2048, + create=True) + reader_buffer = SingleWriterShmRingBuffer(*writer_buffer.handle()) + + print(f"Buffer created with name: {writer_buffer.shared_memory.name}") + + # Allocate some buffers + print("\nAllocating buffers...") + address_array = [] + for i in range(3): + size = 100 + i * 50 + try: + writer_buffer.free_buf(lambda *args: True) + address, monotonic_id = writer_buffer.allocate_buf(size) + address_array.append((address, size, monotonic_id)) + + # Write some test data + with writer_buffer.access_buf(address) as (data_buf, metadata): + test_message = f"Test message {i}".encode() + data_buf[0:len(test_message)] = test_message + + except MemoryError as e: + print(f" Failed to allocate {size} bytes: {e}") + + print("\nBuffer state:") + print(f" Data buffer start: {writer_buffer.data_buffer_start}") + print(f" Data buffer end: {writer_buffer.data_buffer_end}") + print(f" Monotonic ID start: 
{writer_buffer.monotonic_id_start}") + print(f" Monotonic ID end: {writer_buffer.monotonic_id_end}") + print(f" Metadata entries: {len(writer_buffer.metadata)}") + + # Try to read back the data + print("\nReading back data...") + for address, size, monotonic_id in address_array: + with reader_buffer.access_buf(address) as (data_buf, metadata): + # Find null terminator or read first 50 chars + data_bytes = bytes(data_buf[0:size]) + message = data_bytes.decode() + print(f" ID {monotonic_id}: '{message}'") + + except Exception as e: + print(f"Demo error: {e}") + traceback.print_exc() + + print("\n=== Demo Complete ===") + + +if __name__ == "__main__": + main() diff --git a/tests/distributed/test_shm_storage.py b/tests/distributed/test_shm_storage.py new file mode 100644 index 0000000000000..03495222bc1b8 --- /dev/null +++ b/tests/distributed/test_shm_storage.py @@ -0,0 +1,327 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import multiprocessing +import random +import time +import traceback +import unittest +from multiprocessing import Lock + +import torch + +# Assuming these are imported from your module +from vllm.distributed.device_communicators.shm_object_storage import ( + MsgpackSerde, SingleWriterShmObjectStorage, SingleWriterShmRingBuffer) +from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem, + MultiModalSharedField) + + +def _dummy_elem(modality: str, key: str, size: int): + return MultiModalFieldElem( + modality=modality, + key=key, + data=torch.empty((size, ), dtype=torch.int8), + field=MultiModalSharedField(1), + ) + + +def _dummy_item(modality: str, size_by_key: dict[str, int]): + return MultiModalKwargsItem.from_elems([ + _dummy_elem(modality, key, size) for key, size in size_by_key.items() + ]) + + +class TestSingleWriterShmObjectStorage(unittest.TestCase): + + def setUp(self): + """Set up test fixtures before each test method.""" + ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=1024 * 100, + create=True, # 10 MB buffer + ) + self.storage = SingleWriterShmObjectStorage( + max_object_size=1024 * 10, # 10KB max object + n_readers=2, + ring_buffer=ring_buffer, + serde_class=MsgpackSerde, + reader_lock=Lock(), + ) + + def tearDown(self): + """Clean up after each test.""" + if self.storage: + del self.storage + + def test_minimal_put_get_cycle(self): + """Test basic put and get operations.""" + key = "test_key" + value = _dummy_item("text", {"field1": 10, "field2": 20}) + + # Put operation + address, monotonic_id = self.storage.put(key, value) + + # Verify key is in index + self.assertIn(key, self.storage.key_index) + self.assertEqual(self.storage.key_index[key], (address, monotonic_id)) + self.assertEqual(self.storage.id_index[monotonic_id], key) + + # Get operation + result = self.storage.get(address, monotonic_id) + + # Verify result + self.assertEqual(result, value) + + def test_put_same_key_twice(self): + """Test behavior when putting the same key multiple times.""" + key = "duplicate_key" + value1 = "first value" + value2 = "second value" + + # First put + address1, id1 = self.storage.put(key, value1) + retrieved1 = self.storage.get(address1, id1) + self.assertEqual(retrieved1, value1) + + # should raise an error on second put + with self.assertRaises(ValueError) as context: + self.storage.put(key, value2) + + self.assertIn("already exists in the storage", str(context.exception)) + + def test_large_object_rejection(self): + """Test that objects exceeding max_object_size 
are rejected.""" + # Create an object larger than max_object_size + large_data = "x" * (self.storage.max_object_size + 100) + + with self.assertRaises(ValueError) as context: + self.storage.put("large_key", large_data) + + self.assertIn("exceeds max object size", str(context.exception)) + + def test_buffer_overflow_and_cleanup(self): + """Test behavior when buffer fills up and needs cleanup.""" + # Fill up the buffer with many small objects + stored_items = [] + + try: + for i in range(1000): # Try to store many items + key = f"item_{i}" + value = f"data_{i}" * 100 # Make it reasonably sized + address, monotonic_id = self.storage.put(key, value) + stored_items.append((key, value, address, monotonic_id)) + except MemoryError: + print(f"Buffer filled after {len(stored_items)} items") + + # Verify that some items are still accessible + accessible_count = 0 + for key, original_value, address, monotonic_id in stored_items: + for i in range(self.storage.n_readers): + retrieved = self.storage.get(address, monotonic_id) + if retrieved == original_value: + accessible_count += 1 + + self.assertEqual(accessible_count, len(stored_items)) + + try: + for i in range(len(stored_items), 1000): # Try to store many items + key = f"item_{i}" + value = f"data_{i}" * 100 # Make it reasonably sized + address, monotonic_id = self.storage.put(key, value) + stored_items.append((key, value, address, monotonic_id)) + except MemoryError: + print(f"Buffer filled after {len(stored_items)} items") + + # Verify that some items are still accessibles + for key, original_value, address, monotonic_id in stored_items: + try: + for i in range(self.storage.n_readers): + retrieved = self.storage.get(address, monotonic_id) + if retrieved == original_value: + accessible_count += 1 + except ValueError as e: + print(f"Error retrieving {key}: {e}") + + # some items from the first batch may still be accessible + self.assertGreaterEqual(accessible_count, len(stored_items)) + + def test_blocking_unread_object(self): + """Test behavior when buffer fills up and needs cleanup.""" + # Fill up the buffer with many small objects + stored_items = [] + + try: + for i in range(1000): # Try to store many items + key = f"item_{i}" + value = f"data_{i}" * 100 # Make it reasonably sized + address, monotonic_id = self.storage.put(key, value) + stored_items.append((key, value, address, monotonic_id)) + except MemoryError: + print(f"Buffer filled after {len(stored_items)} items") + + # read all items except the first one + # to simulate a blocking situation + accessible_count = 0 + for key, original_value, address, monotonic_id in stored_items[1:]: + for i in range(self.storage.n_readers): + retrieved = self.storage.get(address, monotonic_id) + if retrieved == original_value: + accessible_count += 1 + + self.assertEqual(accessible_count, len(stored_items) - 1) + + try: + key = f"item_{len(stored_items)}" + value = f"data_{len(stored_items)}" * 100 + address, monotonic_id = self.storage.put(key, value) + except MemoryError: + print(f"Buffer filled after {len(stored_items)} items") + + # read the first item + for i in range(self.storage.n_readers): + key, original_value, address, monotonic_id = stored_items[0] + retrieved = self.storage.get(address, monotonic_id) + self.assertEqual(retrieved, original_value) + + try: + for i in range(len(stored_items), 1000): # Try to store many items + key = f"item_{i}" + value = f"data_{i}" * 100 # Make it reasonably sized + address, monotonic_id = self.storage.put(key, value) + stored_items.append((key, value, address, 
monotonic_id)) + except MemoryError: + print(f"Buffer filled after {len(stored_items)} items") + + # some items from the first batch may still be accessible + self.assertGreaterEqual(len(stored_items), accessible_count + 10) + + def test_invalid_get_operations(self): + """Test various invalid get operations.""" + # Test with non-existent address + with self.assertRaises(ValueError): # Could be various exceptions + self.storage.get(99999, 1) + + # Store something first + address, monotonic_id = self.storage.put("test", "value") + + # Test with wrong monotonic_id + with self.assertRaises(ValueError) as context: + self.storage.get(address, monotonic_id + 100) + + self.assertIn("has been modified or is invalid", \ + str(context.exception)) + + def test_clear_storage(self): + """Test clearing the storage.""" + # Store some items + for i in range(5): + self.storage.put(f"item_{i}", f"value_{i}") + + # Clear the storage + self.storage.clear() + + # Verify that all indices are empty + self.assertEqual(len(self.storage.key_index), 0) + self.assertEqual(len(self.storage.id_index), 0) + self.assertEqual(len(self.storage.ring_buffer.metadata), 0) + + # Verify that new items can be added after clearing + address, monotonic_id = self.storage.put("new_item", "new_value") + self.assertIn("new_item", self.storage.key_index) + self.assertEqual((address, monotonic_id), (0, 0)) + + +# Reader process function +def reader_process(process_id, storage_handle, items_to_read): + """Reader process that connects to existing shared memory and reads data.""" + reader_storage = SingleWriterShmObjectStorage.create_from_handle( + storage_handle) + + print(f"Reader {process_id} started") + + errors = [] + + for key, original_value, address, monotonic_id in items_to_read: + time.sleep(random.random() / 100) + try: + # Read data from shared memory + retrieved_value = reader_storage.get(address, monotonic_id) + + # Verify data integrity + assert retrieved_value == original_value + print(f"Reader {process_id} retrieved {key}: {retrieved_value}") + except Exception as e: + errors.append((key, str(e), type(e).__name__)) + + +def run_multiprocess_example(): + """Run a minimal working example with real shared memory.""" + print("=== Minimal Object Storage Example ===") + + try: + # Create storage instance + ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=1024 * 100, + create=True, # 10 MB buffer + ) + storage = SingleWriterShmObjectStorage( + max_object_size=1024, + n_readers=3, + ring_buffer=ring_buffer, + serde_class=MsgpackSerde, + reader_lock=Lock(), + ) + + print(f"Created storage (writer: {storage.is_writer})") + + # Test basic data types + test_data = [ + ("user_data", { + "name": "Alice", + "age": 30, + "scores": [95, 87, 92] + }), + ("simple_string", "Hello, World!"), + ("number", 42), + ("list_data", [1, 2, 3, "four", 5.0]), + ] + + stored_items = [] + + # Store all data + for key, value in test_data: + print(f"Storing {key}: {value}") + address, monotonic_id = storage.put(key, value) + stored_items.append((key, value, address, monotonic_id)) + print(f" -> Stored at address {address}, ID {monotonic_id}") + + print("\n--- Retrieving Data ---") + processes = [] + handle = storage.handle() + # initialize lock for reader processes + handle.reader_lock = Lock() + for i in range(storage.n_readers): + p = multiprocessing.Process(target=reader_process, + args=(i, handle, stored_items)) + processes.append(p) + p.start() + + for p in processes: + p.join(timeout=10) + if p.is_alive(): + p.terminate() + p.join() + + 
except Exception as e: + print(f"Error in minimal example: {e}") + traceback.print_exc() + + +if __name__ == "__main__": + # Run the minimal example first + run_multiprocess_example() + print("\n" + "=" * 50 + "\n") + + # Run the test suite + print("Running comprehensive test suite...") + unittest.main(verbosity=2, exit=False) diff --git a/tests/encoder_decoder/test_e2e_correctness.py b/tests/encoder_decoder/test_e2e_correctness.py deleted file mode 100644 index 8b99d9d6e21fb..0000000000000 --- a/tests/encoder_decoder/test_e2e_correctness.py +++ /dev/null @@ -1,130 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""E2E tests to verify the correctness of the encoder-decoder framework - -Run `pytest tests/encoder_decoder/test_e2e_correctness.py`. -""" -from typing import Optional - -import pytest -from transformers import AutoModelForSeq2SeqLM - -from vllm.attention.selector import (_Backend, _cached_get_attn_backend, - global_force_attn_backend_context_manager) -from vllm.platforms import current_platform -from vllm.sequence import SampleLogprobs - -from ..conftest import DecoderPromptType -from ..models.utils import check_logprobs_close - -LIST_ENC_DEC_SUPPORTED_BACKENDS = [ - _Backend.XFORMERS, _Backend.FLASH_ATTN, None -] - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - Since this module is V0 only, set VLLM_USE_V1=0 for - all tests in the module. - """ - monkeypatch.setenv('VLLM_USE_V1', '0') - - -def vllm_to_hf_output( - vllm_output: tuple[list[int], str, Optional[SampleLogprobs]], - decoder_prompt_type: DecoderPromptType, -): - """Sanitize vllm output to be comparable with hf output.""" - output_ids, output_str, out_logprobs = vllm_output - - hf_output_str = output_str + "</s>" - if decoder_prompt_type == DecoderPromptType.NONE: - hf_output_str = "<s>" + hf_output_str - - return output_ids, hf_output_str, out_logprobs - - -@pytest.fixture(autouse=True) -def clear_cache(): - """Fixture to clear backend cache before each test.""" - _cached_get_attn_backend.cache_clear() # Clear the cache - yield # This allows the test to run - - -@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) -@pytest.mark.parametrize("enforce_eager", [True, False]) -@pytest.mark.skipif( - current_platform.is_cpu(), - reason="CPU backend is not currently supported with encoder/decoder models" -) -def test_encoder_decoder_e2e( - hf_runner, - vllm_runner, - example_encoder_decoder_prompts, - model: str, - dtype: str, - max_tokens: int, - num_logprobs: int, - decoder_prompt_type: DecoderPromptType, - enforce_eager: bool, - attn_backend: _Backend, -) -> None: - ''' - End-to-End (E2E) test for the encoder-decoder framework. - This test evaluates the encoder-decoder functionality using the BART - model. We compare the outputs of the Hugging Face and vLLM - implementations to ensure that both implementations produce consistent - and correct results. 
- ''' - with global_force_attn_backend_context_manager(attn_backend): - if attn_backend == _Backend.FLASH_ATTN: - # Flash Attention works only with bfloat16 data-type - dtype = 'bfloat16' - test_case_prompts = example_encoder_decoder_prompts[ - decoder_prompt_type] - - # Configuration settings for HF baseline - hf_kwargs = { - "top_k": None, - "num_beams": 1, - "repetition_penalty": 1.0, - "top_p": 1.0, - "length_penalty": 1.0, - "early_stopping": False, - "no_repeat_ngram_size": None, - "min_length": 0 - } - - with hf_runner(model, dtype=dtype, - auto_cls=AutoModelForSeq2SeqLM) as hf_model: - hf_outputs = ( - hf_model.generate_encoder_decoder_greedy_logprobs_limit( - test_case_prompts, - max_tokens, - num_logprobs, - **hf_kwargs, - )) - with vllm_runner(model, dtype=dtype, - enforce_eager=enforce_eager) as vllm_model: - vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( - test_case_prompts, max_tokens, num_logprobs) - - hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE - else 0) - - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(vllm_output, decoder_prompt_type) - for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - num_outputs_0_skip_tokens=hf_skip_tokens, - ) diff --git a/tests/engine/conftest.py b/tests/engine/conftest.py deleted file mode 100644 index 375b248ebedaa..0000000000000 --- a/tests/engine/conftest.py +++ /dev/null @@ -1,12 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - Since this module is V0 only, set VLLM_USE_V1=0 for - all tests in the module. - """ - monkeypatch.setenv('VLLM_USE_V1', '0') diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index ba8e31a79feb5..b82e839638041 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -287,15 +287,6 @@ def test_prefix_cache_default(): }, "mm-processor-kwargs" ), - ( - '{"cast_logits_dtype":"bfloat16","sequence_parallel_norm":true,"sequence_parallel_norm_threshold":2048}', - { - "cast_logits_dtype": "bfloat16", - "sequence_parallel_norm": True, - "sequence_parallel_norm_threshold": 2048, - }, - "override-neuron-config" - ), ]) # yapf: enable def test_composite_arg_parser(arg, expected, option): diff --git a/tests/engine/test_computed_prefix_blocks.py b/tests/engine/test_computed_prefix_blocks.py deleted file mode 100644 index ac5a1f957dfe4..0000000000000 --- a/tests/engine/test_computed_prefix_blocks.py +++ /dev/null @@ -1,37 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -from vllm.engine.arg_utils import EngineArgs -from vllm.engine.llm_engine import LLMEngine -from vllm.sampling_params import SamplingParams - - -@pytest.mark.parametrize("model", ["distilbert/distilgpt2"]) -@pytest.mark.parametrize("block_size", [16]) -def test_computed_prefix_blocks(model: str, block_size: int): - # This test checks if we are able to run the engine to completion - # without triggering asserts. - # We are in a scenario where all blocks from the second request's prompt - # are full and already computed when the second request arrives. - prompt = ( - "You are a helpful assistant. How do I build a car from cardboard and " - "paper clips? 
Is there an easy to follow video tutorial available " - "online for free?") - prompt2 = ( - " Please recommend to me some resources where I can learn not only to " - "handle technical difficulties of building a car, but also " - "decoration.") - - engine_args = EngineArgs(model=model, - block_size=block_size, - enable_prefix_caching=True) - - engine = LLMEngine.from_engine_args(engine_args) - sampling_params = SamplingParams() - - engine.add_request("0", prompt + prompt2, sampling_params) - engine.step() - engine.add_request("1", prompt, sampling_params) - engine.step() diff --git a/tests/engine/test_executor.py b/tests/engine/test_executor.py deleted file mode 100644 index 15c7a97b50e1f..0000000000000 --- a/tests/engine/test_executor.py +++ /dev/null @@ -1,111 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import asyncio -import os -from typing import Any, Callable, Optional, Union - -import pytest - -from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs -from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.engine.llm_engine import LLMEngine -from vllm.executor.uniproc_executor import UniProcExecutor -from vllm.sampling_params import SamplingParams - - -class Mock: - ... - - -class CustomUniExecutor(UniProcExecutor): - - def collective_rpc(self, - method: Union[str, Callable], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict] = None) -> list[Any]: - # Drop marker to show that this was ran - with open(".marker", "w"): - ... - return super().collective_rpc(method, timeout, args, kwargs) - - -CustomUniExecutorAsync = CustomUniExecutor - - -@pytest.mark.parametrize("model", ["distilbert/distilgpt2"]) -def test_custom_executor_type_checking(model): - with pytest.raises(ValueError): - engine_args = EngineArgs(model=model, - distributed_executor_backend=Mock) - LLMEngine.from_engine_args(engine_args) - with pytest.raises(ValueError): - engine_args = AsyncEngineArgs(model=model, - distributed_executor_backend=Mock) - AsyncLLMEngine.from_engine_args(engine_args) - - -@pytest.mark.parametrize("model", ["distilbert/distilgpt2"]) -def test_custom_executor(model, tmp_path): - cwd = os.path.abspath(".") - os.chdir(tmp_path) - try: - assert not os.path.exists(".marker") - - engine_args = EngineArgs( - model=model, - distributed_executor_backend=CustomUniExecutor, - enforce_eager=True, # reduce test time - ) - engine = LLMEngine.from_engine_args(engine_args) - sampling_params = SamplingParams(max_tokens=1) - - engine.add_request("0", "foo", sampling_params) - engine.step() - - assert os.path.exists(".marker") - finally: - os.chdir(cwd) - - -@pytest.mark.parametrize("model", ["distilbert/distilgpt2"]) -def test_custom_executor_async(model, tmp_path): - cwd = os.path.abspath(".") - os.chdir(tmp_path) - try: - assert not os.path.exists(".marker") - - engine_args = AsyncEngineArgs( - model=model, - distributed_executor_backend=CustomUniExecutorAsync, - enforce_eager=True, # reduce test time - ) - engine = AsyncLLMEngine.from_engine_args(engine_args) - sampling_params = SamplingParams(max_tokens=1) - - async def t(): - stream = await engine.add_request("0", "foo", sampling_params) - async for x in stream: - ... - - asyncio.run(t()) - - assert os.path.exists(".marker") - finally: - os.chdir(cwd) - - -@pytest.mark.parametrize("model", ["distilbert/distilgpt2"]) -def test_respect_ray(model): - # even for TP=1 and PP=1, - # if users specify ray, we should use ray. 
- # users might do this if they want to manage the - # resources using ray. - engine_args = EngineArgs( - model=model, - distributed_executor_backend="ray", - enforce_eager=True, # reduce test time - ) - engine = LLMEngine.from_engine_args(engine_args) - assert engine.model_executor.uses_ray diff --git a/tests/engine/test_multiproc_workers.py b/tests/engine/test_multiproc_workers.py deleted file mode 100644 index b5381b61a020a..0000000000000 --- a/tests/engine/test_multiproc_workers.py +++ /dev/null @@ -1,179 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import asyncio -from concurrent.futures import ThreadPoolExecutor -from functools import partial -from time import sleep -from typing import Any - -import pytest - -from vllm.config import VllmConfig -from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, - ResultHandler, WorkerMonitor) -from vllm.worker.worker_base import WorkerWrapperBase - - -class DummyWorkerWrapper(WorkerWrapperBase): - """Dummy version of vllm.worker.worker.Worker""" - - def worker_method(self, worker_input: Any) -> tuple[int, Any]: - sleep(0.05) - - if isinstance(worker_input, Exception): - # simulate error case - raise worker_input - - return self.rpc_rank, input - - -def _start_workers() -> tuple[list[ProcessWorkerWrapper], WorkerMonitor]: - result_handler = ResultHandler() - vllm_config = VllmConfig() - workers = [ - ProcessWorkerWrapper(result_handler, DummyWorkerWrapper, vllm_config, - rank) for rank in range(8) - ] - - worker_monitor = WorkerMonitor(workers, result_handler) - assert not worker_monitor.is_alive() - - result_handler.start() - worker_monitor.start() - assert worker_monitor.is_alive() - - return workers, worker_monitor - - -def test_local_workers() -> None: - """Test workers with sync task submission""" - - workers, worker_monitor = _start_workers() - - def execute_workers(worker_input: str) -> None: - worker_outputs = [ - worker.execute_method("worker_method", worker_input) - for worker in workers - ] - - for rank, output in enumerate(worker_outputs): - assert output.get() == (rank, input) - - executor = ThreadPoolExecutor(max_workers=4) - - # Test concurrent submission from different threads - futures = [ - executor.submit(partial(execute_workers, f"thread {thread_num}")) - for thread_num in range(4) - ] - - for future in futures: - future.result() - - # Test error case - exception = ValueError("fake error") - result = workers[0].execute_method("worker_method", exception) - try: - result.get() - pytest.fail("task should have failed") - except Exception as e: - assert isinstance(e, ValueError) - assert str(e) == "fake error" - - # Test cleanup when a worker fails - assert worker_monitor.is_alive() - workers[3].process.kill() - - # Other workers should get shut down here - worker_monitor.join(20) - - # Ensure everything is stopped - assert not worker_monitor.is_alive() - assert all(not worker.process.is_alive() for worker in workers) - - # Further attempts to submit tasks should fail - try: - _result = workers[0].execute_method("worker_method", "test") - pytest.fail("task should fail once workers have been shut down") - except Exception as e: - assert isinstance(e, ChildProcessError) - - -def test_local_workers_clean_shutdown() -> None: - """Test clean shutdown""" - - workers, worker_monitor = _start_workers() - - assert worker_monitor.is_alive() - assert all(worker.process.is_alive() for worker in workers) - - # Clean shutdown - worker_monitor.close() - - 
worker_monitor.join(20) - - # Ensure everything is stopped - assert not worker_monitor.is_alive() - assert all(not worker.process.is_alive() for worker in workers) - - # Further attempts to submit tasks should fail - try: - _result = workers[0].execute_method("worker_method", "test") - pytest.fail("task should fail once workers have been shut down") - except Exception as e: - assert isinstance(e, ChildProcessError) - - -@pytest.mark.asyncio -async def test_local_workers_async() -> None: - """Test local workers with async task submission""" - - workers, worker_monitor = _start_workers() - - async def execute_workers(worker_input: str) -> None: - worker_coros = [ - worker.execute_method_async("worker_method", worker_input) - for worker in workers - ] - - results = await asyncio.gather(*worker_coros) - for rank, result in enumerate(results): - assert result == (rank, input) - - tasks = [ - asyncio.create_task(execute_workers(f"task {task_num}")) - for task_num in range(4) - ] - - for task in tasks: - await task - - # Test error case - exception = ValueError("fake error") - try: - _result = await workers[0].execute_method_async( - "worker_method", exception) - pytest.fail("task should have failed") - except Exception as e: - assert isinstance(e, ValueError) - assert str(e) == "fake error" - - # Test cleanup when a worker fails - assert worker_monitor.is_alive() - workers[3].process.kill() - - # Other workers should get shut down here - worker_monitor.join(20) - - # Ensure everything is stopped - assert not worker_monitor.is_alive() - assert all(not worker.process.is_alive() for worker in workers) - - # Further attempts to submit tasks should fail - try: - _result = await workers[0].execute_method_async( - "worker_method", "test") - pytest.fail("task should fail once workers have been shut down") - except Exception as e: - assert isinstance(e, ChildProcessError) diff --git a/tests/engine/test_options.py b/tests/engine/test_options.py deleted file mode 100644 index 42e88e84770ab..0000000000000 --- a/tests/engine/test_options.py +++ /dev/null @@ -1,58 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from contextlib import nullcontext - -import pytest - -from vllm.entrypoints.llm import LLM -from vllm.sampling_params import SamplingParams - - -@pytest.mark.parametrize("model", ["distilbert/distilgpt2"]) -def test_skip_tokenizer_initialization(model: str): - # This test checks if the flag skip_tokenizer_init skips the initialization - # of tokenizer and detokenizer. The generated output is expected to contain - # token ids. 
- llm = LLM( - model=model, - skip_tokenizer_init=True, - enforce_eager=True, - ) - sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True) - - with pytest.raises(ValueError, match="cannot pass text prompts when"): - llm.generate("abc", sampling_params) - - outputs = llm.generate({"prompt_token_ids": [1, 2, 3]}, - sampling_params=sampling_params) - assert len(outputs) > 0 - completions = outputs[0].outputs - assert len(completions) > 0 - assert completions[0].text == "" - assert completions[0].token_ids - - -@pytest.mark.parametrize("model", ["distilbert/distilgpt2"]) -@pytest.mark.parametrize("enable_prompt_embeds", [True, False]) -def test_enable_prompt_embeds(hf_runner, model: str, - enable_prompt_embeds: bool): - prompt = "abc" - - with hf_runner(model) as hf_model: - token_ids = hf_model.tokenizer(prompt, return_tensors="pt").input_ids - token_ids = token_ids.to(hf_model.model.device) - - embed_layer = hf_model.model.get_input_embeddings() - prompt_embeds = embed_layer(token_ids).squeeze(0) - - ctx = (nullcontext() if enable_prompt_embeds else pytest.raises( - ValueError, match="set `--enable-prompt-embeds`")) - - llm = LLM( - model=model, - enable_prompt_embeds=enable_prompt_embeds, - enforce_eager=True, - ) - - with ctx: - llm.generate({"prompt_embeds": prompt_embeds}) diff --git a/tests/engine/test_short_mm_context.py b/tests/engine/test_short_mm_context.py index 9c62761d78afb..9eb3dfc09224e 100644 --- a/tests/engine/test_short_mm_context.py +++ b/tests/engine/test_short_mm_context.py @@ -25,6 +25,7 @@ def test_context_length_too_short(vllm_runner, image_assets, model): model, max_model_len=128, # LLaVA has a feature size of 576 enforce_eager=True, + load_format="dummy", ) with vllm_model: diff --git a/tests/entrypoints/conftest.py b/tests/entrypoints/conftest.py index 48fd848e88200..da75806ccf4de 100644 --- a/tests/entrypoints/conftest.py +++ b/tests/entrypoints/conftest.py @@ -184,7 +184,7 @@ def sample_enum_json_schema(): @pytest.fixture -def sample_guided_choice(): +def sample_structured_outputs_choices(): return [ "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", "Ruby", "Swift", "Kotlin" @@ -208,25 +208,3 @@ def zephyr_lora_files(): """Download zephyr LoRA files once per test session.""" from huggingface_hub import snapshot_download return snapshot_download(repo_id="typeof/zephyr-7b-beta-lora") - - -@pytest.fixture(scope="session") -def zephyr_lora_added_tokens_files(zephyr_lora_files): - """Create zephyr LoRA files with added tokens once per test session.""" - import shutil - from tempfile import TemporaryDirectory - - from transformers import AutoTokenizer - - tmp_dir = TemporaryDirectory() - tmp_model_dir = f"{tmp_dir.name}/zephyr" - shutil.copytree(zephyr_lora_files, tmp_model_dir) - tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta") - # Copy tokenizer to adapter and add some unique tokens - # 32000, 32001, 32002 - added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"], - special_tokens=True) - assert added == 3 - tokenizer.save_pretrained(tmp_model_dir) - yield tmp_model_dir - tmp_dir.cleanup() diff --git a/tests/entrypoints/llm/test_chat.py b/tests/entrypoints/llm/test_chat.py index 2cbfed98a577a..bf460d0fb25d3 100644 --- a/tests/entrypoints/llm/test_chat.py +++ b/tests/entrypoints/llm/test_chat.py @@ -7,7 +7,7 @@ import pytest from vllm import LLM from vllm.distributed import cleanup_dist_env_and_memory -from ..openai.test_vision import TEST_IMAGE_URLS +from ..openai.test_vision import TEST_IMAGE_ASSETS 
@pytest.fixture(scope="function") @@ -95,7 +95,8 @@ def vision_llm(): @pytest.mark.parametrize("image_urls", - [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]]) + [[TEST_IMAGE_ASSETS[0], TEST_IMAGE_ASSETS[1]]], + indirect=True) def test_chat_multi_image(vision_llm, image_urls: list[str]): messages = [{ "role": diff --git a/tests/entrypoints/llm/test_lazy_outlines.py b/tests/entrypoints/llm/test_lazy_outlines.py deleted file mode 100644 index ac0b7e134c55a..0000000000000 --- a/tests/entrypoints/llm/test_lazy_outlines.py +++ /dev/null @@ -1,82 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import sys -from contextlib import nullcontext - -from vllm_test_utils import BlameResult, blame - -from vllm import LLM, SamplingParams -from vllm.distributed import cleanup_dist_env_and_memory -from vllm.sampling_params import GuidedDecodingParams - - -def run_normal(): - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - sampling_params = SamplingParams(temperature=0.8, top_p=0.95) - - # Create an LLM without guided decoding as a baseline. - llm = LLM(model="distilbert/distilgpt2", - enforce_eager=True, - gpu_memory_utilization=0.3) - outputs = llm.generate(prompts, sampling_params) - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - - # Destroy the LLM object and free up the GPU memory. - del llm - cleanup_dist_env_and_memory() - - -def run_xgrammar(sample_regex): - # Create an LLM with guided decoding enabled. - llm = LLM(model="distilbert/distilgpt2", - enforce_eager=True, - guided_decoding_backend="xgrammar", - gpu_memory_utilization=0.3) - prompt = f"Give an example IPv4 address with this regex: {sample_regex}" - guided_decoding = GuidedDecodingParams(regex=sample_regex) - sampling_params = SamplingParams(temperature=0.8, - top_p=0.95, - guided_decoding=guided_decoding) - outputs = llm.generate( - prompts=[prompt] * 2, - sampling_params=sampling_params, - use_tqdm=True, - ) - - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - - -def test_lazy_outlines(sample_regex): - """If users don't use guided decoding, outlines should not be imported. - """ - # make sure outlines is not imported - module_name = "outlines" - # In CI, we only check finally if the module is imported. - # If it is indeed imported, we can rerun the test with `use_blame=True`, - # which will trace every function call to find the first import location, - # and help find the root cause. - # We don't run it in CI by default because it is slow. - use_blame = False - context = blame( - lambda: module_name in sys.modules) if use_blame else nullcontext() - with context as result: - run_normal() - run_xgrammar(sample_regex) - if use_blame: - assert isinstance(result, BlameResult) - print(f"the first import location is:\n{result.trace_stack}") - assert module_name not in sys.modules, ( - f"Module {module_name} is imported. 
To see the first" - f" import location, run the test with `use_blame=True`.") diff --git a/tests/entrypoints/offline_mode/test_offline_mode.py b/tests/entrypoints/offline_mode/test_offline_mode.py index a154bb1059aae..f8ed5dda260ff 100644 --- a/tests/entrypoints/offline_mode/test_offline_mode.py +++ b/tests/entrypoints/offline_mode/test_offline_mode.py @@ -79,7 +79,7 @@ def test_offline_mode(monkeypatch: pytest.MonkeyPatch): ) # Need to re-import huggingface_hub - # and friends to setup offline mode + # and friends to set up offline mode _re_import_modules() # Cached model files should be used in offline mode for model_config in MODEL_CONFIGS: @@ -136,7 +136,7 @@ def test_model_from_huggingface_offline(monkeypatch: pytest.MonkeyPatch): disable_connect, ) # Need to re-import huggingface_hub - # and friends to setup offline mode + # and friends to set up offline mode _re_import_modules() engine_args = EngineArgs(model="facebook/opt-125m") LLM(**dataclasses.asdict(engine_args)) diff --git a/tests/entrypoints/openai/correctness/test_lmeval.py b/tests/entrypoints/openai/correctness/test_lmeval.py index 684407cd6ee97..624acd5ffde73 100644 --- a/tests/entrypoints/openai/correctness/test_lmeval.py +++ b/tests/entrypoints/openai/correctness/test_lmeval.py @@ -81,13 +81,3 @@ def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch): more_args = ["--max-num-seqs", "64"] run_test(more_args) - - -@pytest.mark.parametrize("more_args", MORE_ARGS_LIST) -def test_lm_eval_accuracy_v0_engine(monkeypatch: pytest.MonkeyPatch, - more_args): - """Run with the V0 Engine.""" - - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "0") - run_test(more_args) diff --git a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py index 0d0ce0be8c5f8..9122b7003bf9a 100644 --- a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py +++ b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py @@ -32,7 +32,7 @@ def to_bytes(y, sr): async def transcribe_audio(client, tokenizer, y, sr): # Send loaded audio directly instead of loading from disk, - # dont account for that time though + # don't account for that time though with to_bytes(y, sr) as f: start_time = time.perf_counter() transcription = await client.audio.transcriptions.create( diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index c9947c54a9181..a827f94cfbfe5 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# imports for guided decoding tests +# imports for structured outputs tests import json from typing import Optional @@ -12,7 +12,7 @@ import pytest_asyncio import regex as re import requests import torch -from openai import BadRequestError, OpenAI +from openai import BadRequestError from ...utils import RemoteOpenAIServer @@ -29,11 +29,7 @@ def monkeypatch_module(): @pytest.fixture(scope="module", params=[False, True]) -def server( - request, - monkeypatch_module, - zephyr_lora_files, #noqa: F811 - zephyr_lora_added_tokens_files): # noqa: F811 +def server(request, monkeypatch_module, zephyr_lora_files): #noqa: F811 use_v1 = request.param monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0') @@ -49,7 +45,6 @@ def server( "--enable-lora", "--lora-modules", f"zephyr-lora={zephyr_lora_files}", - 
f"zephyr-lora2={zephyr_lora_added_tokens_files}", "--max-lora-rank", "64", "--max-cpu-loras", @@ -79,7 +74,7 @@ async def client(server): @pytest.mark.parametrize( # first test base model, then test loras "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], + [MODEL_NAME, "zephyr-lora"], ) async def test_no_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): messages = [{ @@ -485,10 +480,11 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI, @pytest.mark.asyncio -async def test_guided_choice_chat(client: openai.AsyncOpenAI, - sample_guided_choice, is_v1_server: bool): +async def test_structured_outputs_choice_chat( + client: openai.AsyncOpenAI, sample_structured_outputs_choices, + is_v1_server: bool): if not is_v1_server: - pytest.skip("Guided decoding is only supported in v1 engine") + pytest.skip("Structured outputs is only supported in v1 engine") messages = [{ "role": "system", "content": "you are a helpful assistant" @@ -503,9 +499,10 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI, messages=messages, max_completion_tokens=10, temperature=0.7, - extra_body=dict(guided_choice=sample_guided_choice)) + extra_body=dict( + structured_outputs={"choice": sample_structured_outputs_choices})) choice1 = chat_completion.choices[0].message.content - assert choice1 in sample_guided_choice + assert choice1 in sample_structured_outputs_choices messages.append({"role": "assistant", "content": choice1}) messages.append({ @@ -517,17 +514,19 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI, messages=messages, max_completion_tokens=10, temperature=0.7, - extra_body=dict(guided_choice=sample_guided_choice)) + extra_body=dict( + structured_outputs={"choice": sample_structured_outputs_choices})) choice2 = chat_completion.choices[0].message.content - assert choice2 in sample_guided_choice + assert choice2 in sample_structured_outputs_choices assert choice1 != choice2 @pytest.mark.asyncio -async def test_guided_json_chat(client: openai.AsyncOpenAI, sample_json_schema, - is_v1_server: bool): +async def test_structured_outputs_json_chat(client: openai.AsyncOpenAI, + sample_json_schema, + is_v1_server: bool): if not is_v1_server: - pytest.skip("Guided decoding is only supported in v1 engine") + pytest.skip("Structured outputs is only supported in v1 engine") messages = [{ "role": "system", @@ -543,7 +542,7 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI, sample_json_schema, model=MODEL_NAME, messages=messages, max_completion_tokens=1000, - extra_body=dict(guided_json=sample_json_schema)) + extra_body=dict(structured_outputs={"json": sample_json_schema})) message = chat_completion.choices[0].message assert message.content is not None json1 = json.loads(message.content) @@ -560,7 +559,7 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI, sample_json_schema, model=MODEL_NAME, messages=messages, max_completion_tokens=1000, - extra_body=dict(guided_json=sample_json_schema)) + extra_body=dict(structured_outputs={"json": sample_json_schema})) message = chat_completion.choices[0].message assert message.content is not None json2 = json.loads(message.content) @@ -570,10 +569,10 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI, sample_json_schema, @pytest.mark.asyncio -async def test_guided_regex_chat(client: openai.AsyncOpenAI, sample_regex, - is_v1_server: bool): +async def test_structured_outputs_regex_chat(client: openai.AsyncOpenAI, + sample_regex, is_v1_server: bool): if not is_v1_server: - 
pytest.skip("Guided decoding is only supported in v1 engine") + pytest.skip("Structured outputs is only supported in v1 engine") messages = [{ "role": "system", @@ -588,7 +587,7 @@ async def test_guided_regex_chat(client: openai.AsyncOpenAI, sample_regex, model=MODEL_NAME, messages=messages, max_completion_tokens=20, - extra_body=dict(guided_regex=sample_regex)) + extra_body=dict(structured_outputs={"regex": sample_regex})) ip1 = chat_completion.choices[0].message.content assert ip1 is not None assert re.fullmatch(sample_regex, ip1) is not None @@ -599,7 +598,7 @@ async def test_guided_regex_chat(client: openai.AsyncOpenAI, sample_regex, model=MODEL_NAME, messages=messages, max_completion_tokens=20, - extra_body=dict(guided_regex=sample_regex)) + extra_body=dict(structured_outputs={"regex": sample_regex})) ip2 = chat_completion.choices[0].message.content assert ip2 is not None assert re.fullmatch(sample_regex, ip2) is not None @@ -607,7 +606,7 @@ async def test_guided_regex_chat(client: openai.AsyncOpenAI, sample_regex, @pytest.mark.asyncio -async def test_guided_decoding_type_error(client: openai.AsyncOpenAI): +async def test_structured_outputs_type_error(client: openai.AsyncOpenAI): messages = [{ "role": "system", "content": "you are a helpful assistant" @@ -619,17 +618,19 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI): }] with pytest.raises(openai.BadRequestError): - _ = await client.chat.completions.create(model=MODEL_NAME, - messages=messages, - extra_body=dict(guided_regex={ - 1: "Python", - 2: "C++" - })) + _ = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + extra_body=dict( + structured_outputs={"regex": { + 1: "Python", + 2: "C++" + }})) @pytest.mark.asyncio -async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI, - sample_guided_choice): +async def test_structured_outputs_choice_chat_logprobs( + client: openai.AsyncOpenAI, sample_structured_outputs_choices): messages = [{ "role": "system", @@ -646,7 +647,8 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI, max_completion_tokens=10, logprobs=True, top_logprobs=5, - extra_body=dict(guided_choice=sample_guided_choice)) + extra_body=dict( + structured_outputs={"choice": sample_structured_outputs_choices})) assert chat_completion.choices[0].logprobs is not None assert chat_completion.choices[0].logprobs.content is not None @@ -668,10 +670,23 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema, }, { "role": "user", - "content": - f"Give an example JSON for an employee profile that " - f"fits this schema: {sample_json_schema}" + "content": ("Give an example JSON for an employee " + "profile using the specified tool.") }] + tools = [{ + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": sample_json_schema + } + }] + tool_choice = { + "type": "function", + "function": { + "name": "dummy_function_name" + } + } # non-streaming @@ -679,20 +694,8 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema, model=MODEL_NAME, messages=messages, max_completion_tokens=1000, - tools=[{ - "type": "function", - "function": { - "name": "dummy_function_name", - "description": "This is a dummy function", - "parameters": sample_json_schema - } - }], - tool_choice={ - "type": "function", - "function": { - "name": "dummy_function_name" - } - }, + tools=tools, + tool_choice=tool_choice, ) message = chat_completion.choices[0].message 
assert len(message.content) == 0 @@ -710,25 +713,12 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema, # streaming - stream = await client.chat.completions.create( - model=MODEL_NAME, - messages=messages, - max_completion_tokens=1000, - tools=[{ - "type": "function", - "function": { - "name": "dummy_function_name", - "description": "This is a dummy function", - "parameters": sample_json_schema - } - }], - tool_choice={ - "type": "function", - "function": { - "name": "dummy_function_name" - } - }, - stream=True) + stream = await client.chat.completions.create(model=MODEL_NAME, + messages=messages, + max_completion_tokens=1000, + tools=tools, + tool_choice=tool_choice, + stream=True) output = [] finish_reason_count = 0 @@ -968,59 +958,6 @@ async def test_long_seed(client: openai.AsyncOpenAI): or "less_than_equal" in exc_info.value.message) -@pytest.mark.asyncio -async def test_http_chat_no_model_name_with_curl(server: RemoteOpenAIServer): - url = f"http://localhost:{server.port}/v1/chat/completions" - headers = { - "Content-Type": "application/json", - } - data = { - # model_name is avoided here. - "messages": [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "what is 1+1?" - }], - "max_tokens": - 5 - } - - response = requests.post(url, headers=headers, json=data) - response_data = response.json() - print(response_data) - assert response_data.get("model") == MODEL_NAME - choice = response_data.get("choices")[0] - message = choice.get("message") - assert message is not None - content = message.get("content") - assert content is not None - assert len(content) > 0 - - -@pytest.mark.asyncio -async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer): - openai_api_key = "EMPTY" - openai_api_base = f"http://localhost:{server.port}/v1" - - client = OpenAI( - api_key=openai_api_key, - base_url=openai_api_base, - ) - messages = [ - { - "role": "user", - "content": "Hello, vLLM!" - }, - ] - response = client.chat.completions.create( - model="", # empty string - messages=messages, - ) - assert response.model == MODEL_NAME - - @pytest.mark.asyncio async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenAI): diff --git a/tests/entrypoints/openai/test_chat_echo.py b/tests/entrypoints/openai/test_chat_echo.py index de63f4ed218b6..ce965eb829248 100644 --- a/tests/entrypoints/openai/test_chat_echo.py +++ b/tests/entrypoints/openai/test_chat_echo.py @@ -22,6 +22,8 @@ def server(): "--enforce-eager", "--max-model-len", "4080", + "--max-logprobs", # test prompt_logprobs equal to -1 + "151936" ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -77,3 +79,46 @@ async def test_chat_session_with_echo_and_continue_final_message( else: assert message.content is not None and saying not in message.content assert message.role == "assistant" + + +@pytest.mark.asyncio +async def test_prompt_logprobs(client: openai.AsyncOpenAI): + messages = [{ + "role": "system", + "content": "You are a helpful assistant." + }, { + "role": "user", + "content": "Beijing is the capital of which country?" + }] + + completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + extra_body={"prompt_logprobs": -1}, + ) + + assert completion.prompt_logprobs is not None + assert len(completion.prompt_logprobs) > 0 + + +@pytest.mark.asyncio +async def test_top_logprobs(client: openai.AsyncOpenAI): + messages = [{ + "role": "system", + "content": "You are a helpful assistant." 
+ }, { + "role": "user", + "content": "Beijing is the capital of which country?" + }] + + completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + extra_body={ + "top_logprobs": -1, + "logprobs": "true", + }, + ) + assert completion.choices[0].logprobs is not None + assert completion.choices[0].logprobs.content is not None + assert len(completion.choices[0].logprobs.content) > 0 diff --git a/tests/entrypoints/openai/test_chat_template.py b/tests/entrypoints/openai/test_chat_template.py index 5b6e2a4146b1f..ce90a67c01517 100644 --- a/tests/entrypoints/openai/test_chat_template.py +++ b/tests/entrypoints/openai/test_chat_template.py @@ -104,7 +104,9 @@ def test_get_gen_prompt(model, template, add_generation_prompt, trust_remote_code=model_info.trust_remote_code, revision=model_info.revision, hf_overrides=model_info.hf_overrides, - ) + skip_tokenizer_init=model_info.skip_tokenizer_init, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype) # Initialize the tokenizer tokenizer = get_tokenizer( diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index d55f8d9d65d9b..0347513befe32 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# imports for guided decoding tests +# imports for structured outputs tests import json import os from typing import Optional @@ -23,11 +23,9 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # technically these adapters use a different base model, # but we're not testing generation quality here -GUIDED_DECODING_BACKENDS = ["outlines", "xgrammar", "guidance"] - @pytest.fixture(scope="module") -def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files): +def default_server_args(zephyr_lora_files): return [ # use half precision for speed and memory savings in CI environment "--dtype", @@ -41,7 +39,6 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files): "--enable-lora", "--lora-modules", f"zephyr-lora={zephyr_lora_files}", - f"zephyr-lora2={zephyr_lora_added_tokens_files}", "--max-lora-rank", "64", "--max-cpu-loras", @@ -87,7 +84,7 @@ async def client(server): @pytest.mark.parametrize( # first test base model, then test loras "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], + [MODEL_NAME, "zephyr-lora"], ) async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): completion = await client.completions.create(model=model_name, @@ -115,20 +112,6 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): assert completion.choices[0].prompt_logprobs is None -@pytest.mark.asyncio -async def test_added_lora_tokens(client: openai.AsyncOpenAI): - # test using token IDs - completion = await client.completions.create( - model="zephyr-lora2", - prompt=[0, 0, 32000, 32001, 32002], - echo=True, - max_tokens=5, - temperature=0.0, - ) - # Added tokens should appear in tokenized prompt - assert completion.choices[0].text.startswith("<unk><unk>vllm1vllm2vllm3") - - @pytest.mark.asyncio async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI): # test using token IDs @@ -147,7 +130,7 @@ async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI): @pytest.mark.parametrize( # first test base model, then test loras "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], + [MODEL_NAME, 
"zephyr-lora"], ) async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str): # test using token IDs @@ -610,12 +593,13 @@ async def test_allowed_token_ids(client: openai.AsyncOpenAI): @pytest.mark.asyncio -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) -async def test_guided_json_completion(client: openai.AsyncOpenAI, - guided_decoding_backend: str, - sample_json_schema, is_v1_server: bool): +async def test_structured_outputs_json_completion( + client: openai.AsyncOpenAI, + sample_json_schema, + is_v1_server: bool, +): if not is_v1_server: - pytest.skip("Guided decoding is only supported in v1 engine") + pytest.skip("structured outputs is only supported in v1 engine") completion = await client.completions.create( model=MODEL_NAME, @@ -624,8 +608,7 @@ async def test_guided_json_completion(client: openai.AsyncOpenAI, n=3, temperature=1.0, max_tokens=500, - extra_body=dict(guided_json=sample_json_schema, - guided_decoding_backend=guided_decoding_backend)) + extra_body=dict(structured_outputs=dict(json=sample_json_schema))) assert completion.id is not None assert len(completion.choices) == 3 @@ -635,12 +618,13 @@ async def test_guided_json_completion(client: openai.AsyncOpenAI, @pytest.mark.asyncio -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) -async def test_guided_regex_completion(client: openai.AsyncOpenAI, - guided_decoding_backend: str, - sample_regex, is_v1_server: bool): +async def test_structured_outputs_regex_completion( + client: openai.AsyncOpenAI, + sample_regex, + is_v1_server: bool, +): if not is_v1_server: - pytest.skip("Guided decoding is only supported in v1 engine") + pytest.skip("structured outputs is only supported in v1 engine") completion = await client.completions.create( model=MODEL_NAME, @@ -648,8 +632,7 @@ async def test_guided_regex_completion(client: openai.AsyncOpenAI, n=3, temperature=1.0, max_tokens=20, - extra_body=dict(guided_regex=sample_regex, - guided_decoding_backend=guided_decoding_backend)) + extra_body=dict(structured_outputs=dict(regex=sample_regex))) assert completion.id is not None assert len(completion.choices) == 3 @@ -659,13 +642,13 @@ async def test_guided_regex_completion(client: openai.AsyncOpenAI, @pytest.mark.asyncio -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) -async def test_guided_choice_completion(client: openai.AsyncOpenAI, - guided_decoding_backend: str, - sample_guided_choice, - is_v1_server: bool): +async def test_structured_outputs_choice_completion( + client: openai.AsyncOpenAI, + sample_structured_outputs_choices, + is_v1_server: bool, +): if not is_v1_server: - pytest.skip("Guided decoding is only supported in v1 engine") + pytest.skip("structured outputs is only supported in v1 engine") completion = await client.completions.create( model=MODEL_NAME, @@ -673,20 +656,21 @@ async def test_guided_choice_completion(client: openai.AsyncOpenAI, n=2, temperature=1.0, max_tokens=10, - extra_body=dict(guided_choice=sample_guided_choice, - guided_decoding_backend=guided_decoding_backend)) + extra_body=dict(structured_outputs=dict( + choice=sample_structured_outputs_choices))) assert completion.id is not None assert len(completion.choices) == 2 for i in range(2): - assert completion.choices[i].text in sample_guided_choice + assert completion.choices[i].text in sample_structured_outputs_choices @pytest.mark.asyncio -async def test_guided_grammar(client: openai.AsyncOpenAI, - sample_sql_statements, is_v1_server: bool): +async def 
test_structured_outputs_grammar(client: openai.AsyncOpenAI, + sample_sql_statements, + is_v1_server: bool): if not is_v1_server: - pytest.skip("Guided grammar is only supported in v1 engine") + pytest.skip("grammar is only supported in v1 engine") completion = await client.completions.create( model=MODEL_NAME, @@ -694,7 +678,8 @@ async def test_guided_grammar(client: openai.AsyncOpenAI, "table_1 where it is equals to 1"), temperature=1.0, max_tokens=500, - extra_body=dict(guided_grammar=sample_sql_statements)) + extra_body=dict( + structured_outputs=dict(grammar=sample_sql_statements), )) content = completion.choices[0].text @@ -713,7 +698,7 @@ async def test_guided_grammar(client: openai.AsyncOpenAI, @pytest.mark.parametrize( # first test base model, then test loras "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], + [MODEL_NAME, "zephyr-lora"], ) @pytest.mark.parametrize("logprobs_arg", [1, 0]) async def test_echo_logprob_completion(client: openai.AsyncOpenAI, @@ -745,27 +730,26 @@ async def test_echo_logprob_completion(client: openai.AsyncOpenAI, @pytest.mark.asyncio -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) -async def test_guided_decoding_type_error(client: openai.AsyncOpenAI, - guided_decoding_backend: str, - sample_json_schema, sample_regex, - is_v1_server: bool): +async def test_structured_outputs_type_error(client: openai.AsyncOpenAI, + sample_json_schema, sample_regex, + is_v1_server: bool): if not is_v1_server: - pytest.skip("Guided decoding is only supported in v1 engine") + pytest.skip("structured outputs is only supported in v1 engine") with pytest.raises(openai.BadRequestError): _ = await client.completions.create( model=MODEL_NAME, prompt="Give an example JSON that fits this schema: 42", - extra_body=dict(guided_json=42, - guided_decoding_backend=guided_decoding_backend)) + extra_body=dict(structured_outputs=dict(json=42))) with pytest.raises(openai.BadRequestError): _ = await client.completions.create( model=MODEL_NAME, prompt="Give an example string that fits this regex", - extra_body=dict(guided_regex=sample_regex, - guided_json=sample_json_schema)) + extra_body=dict(structured_outputs=dict( + regex=sample_regex, + json=sample_json_schema, + ))) @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_completion_with_function_calling.py b/tests/entrypoints/openai/test_completion_with_function_calling.py index 4ef5d4e8a699a..3649cefa9bf42 100644 --- a/tests/entrypoints/openai/test_completion_with_function_calling.py +++ b/tests/entrypoints/openai/test_completion_with_function_calling.py @@ -142,7 +142,7 @@ def server(): # noqa: F811 "--dtype", "half", "--enable-auto-tool-choice", - "--guided-decoding-backend", + "--structured-outputs-config.backend", "xgrammar", "--tool-call-parser", "hermes", @@ -225,7 +225,7 @@ def k2_server(): # noqa: F811 "--dtype", "half", "--enable-auto-tool-choice", - "--guided-decoding-backend", + "--structured-outputs-config.backend", "xgrammar", "--tool-call-parser", "hermes", diff --git a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py index a0ef31762ea15..7b58f851a4d21 100644 --- a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py +++ b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py @@ -21,10 +21,7 @@ CONFIG = AutoConfig.from_pretrained(MODEL_NAME) @pytest.fixture(scope="module") -def default_server_args( - zephyr_lora_files, - zephyr_lora_added_tokens_files, -) -> list[str]: +def 
default_server_args() -> list[str]: return [ # use half precision for speed and memory savings in CI environment "--dtype", @@ -231,3 +228,20 @@ async def test_completions_with_logprobs_and_prompt_embeds( assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1 assert len(logprobs.tokens) == 5 + + +@pytest.mark.asyncio +async def test_prompt_logprobs_raises_error( + client_with_prompt_embeds: openai.AsyncOpenAI): + with pytest.raises(BadRequestError, match="not compatible"): + encoded_embeds = create_dummy_embeds() + await client_with_prompt_embeds.completions.create( + model=MODEL_NAME, + prompt="", + max_tokens=5, + temperature=0.0, + extra_body={ + "prompt_embeds": encoded_embeds, + "prompt_logprobs": True + }, + ) diff --git a/tests/entrypoints/openai/test_encoder_decoder.py b/tests/entrypoints/openai/test_encoder_decoder.py deleted file mode 100644 index 9c2aef23e8772..0000000000000 --- a/tests/entrypoints/openai/test_encoder_decoder.py +++ /dev/null @@ -1,55 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import openai -import pytest -import pytest_asyncio - -from ...utils import RemoteOpenAIServer - -MODEL_NAME = "facebook/bart-base" - - -@pytest.fixture(scope="module") -def server(): - args = [ - "--dtype", - "bfloat16", - "--enforce-eager", - ] - - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server - - -@pytest_asyncio.fixture -async def client(server): - async with server.get_async_client() as async_client: - yield async_client - - -@pytest.mark.asyncio -@pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): - completion = await client.completions.create(model=model_name, - prompt="Hello, my name is", - max_tokens=5, - temperature=0.0) - - assert completion.id is not None - assert completion.choices is not None and len(completion.choices) == 1 - - choice = completion.choices[0] - assert len(choice.text) >= 5 - assert choice.finish_reason == "length" - assert completion.usage == openai.types.CompletionUsage( - completion_tokens=5, prompt_tokens=2, total_tokens=7) - - # test using token IDs - completion = await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - ) - assert len(completion.choices[0].text) >= 1 diff --git a/tests/entrypoints/openai/test_lora_adapters.py b/tests/entrypoints/openai/test_lora_adapters.py index f91dcf194b839..10c0cb5f4d151 100644 --- a/tests/entrypoints/openai/test_lora_adapters.py +++ b/tests/entrypoints/openai/test_lora_adapters.py @@ -67,12 +67,6 @@ def server_with_lora_modules_json(request, monkeypatch_module, "base_model_name": MODEL_NAME } - lora_module_2 = { - "name": "zephyr-lora2", - "path": zephyr_lora_files, - "base_model_name": MODEL_NAME - } - args = [ # use half precision for speed and memory savings in CI environment "--dtype", @@ -84,7 +78,6 @@ def server_with_lora_modules_json(request, monkeypatch_module, "--enable-lora", "--lora-modules", json.dumps(lora_module_1), - json.dumps(lora_module_2), "--max-lora-rank", "64", "--max-cpu-loras", @@ -121,7 +114,6 @@ async def test_static_lora_lineage(client: openai.AsyncOpenAI, for lora_model in lora_models) assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models) assert lora_models[0].id == "zephyr-lora" - assert lora_models[1].id == "zephyr-lora2" @pytest.mark.asyncio @@ -209,7 +201,7 @@ async def 
test_dynamic_lora_badrequests(client: openai.AsyncOpenAI, tmp_path, @pytest.mark.asyncio async def test_multiple_lora_adapters(client: openai.AsyncOpenAI, tmp_path, zephyr_lora_files): - """Validate that many loras can be dynamically registered and inferenced + """Validate that many loras can be dynamically registered and inferenced with concurrently""" # This test file configures the server with --max-cpu-loras=2 and this test diff --git a/tests/entrypoints/openai/test_lora_resolvers.py b/tests/entrypoints/openai/test_lora_resolvers.py index 818efd825640c..e2c83b9c40045 100644 --- a/tests/entrypoints/openai/test_lora_resolvers.py +++ b/tests/entrypoints/openai/test_lora_resolvers.py @@ -9,8 +9,7 @@ from unittest.mock import MagicMock import pytest -from vllm.config import MultiModalConfig -from vllm.engine.multiprocessing.client import MQLLMEngineClient +from vllm.config.multimodal import MultiModalConfig from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_models import (BaseModelPath, @@ -18,6 +17,7 @@ from vllm.entrypoints.openai.serving_models import (BaseModelPath, from vllm.lora.request import LoRARequest from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.v1.engine.async_llm import AsyncLLM MODEL_NAME = "openai-community/gpt2" BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] @@ -82,7 +82,7 @@ def register_mock_resolver(): @pytest.fixture def mock_serving_setup(): """Provides a mocked engine and serving completion instance.""" - mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine = MagicMock(spec=AsyncLLM) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py index ff2e7004ff9f8..0c9e0f3a51429 100644 --- a/tests/entrypoints/openai/test_metrics.py +++ b/tests/entrypoints/openai/test_metrics.py @@ -232,6 +232,9 @@ EXPECTED_METRICS_V1 = [ "vllm:gpu_cache_usage_perc", "vllm:gpu_prefix_cache_queries", "vllm:gpu_prefix_cache_hits", + "vllm:kv_cache_usage_perc", + "vllm:prefix_cache_queries", + "vllm:prefix_cache_hits", "vllm:num_preemptions_total", "vllm:prompt_tokens_total", "vllm:generation_tokens_total", @@ -250,12 +253,15 @@ EXPECTED_METRICS_V1 = [ "vllm:request_params_max_tokens_sum", "vllm:request_params_max_tokens_bucket", "vllm:request_params_max_tokens_count", - "vllm:time_to_first_token_seconds_sum", - "vllm:time_to_first_token_seconds_bucket", - "vllm:time_to_first_token_seconds_count", "vllm:time_per_output_token_seconds_sum", "vllm:time_per_output_token_seconds_bucket", "vllm:time_per_output_token_seconds_count", + "vllm:time_to_first_token_seconds_sum", + "vllm:time_to_first_token_seconds_bucket", + "vllm:time_to_first_token_seconds_count", + "vllm:inter_token_latency_seconds_sum", + "vllm:inter_token_latency_seconds_bucket", + "vllm:inter_token_latency_seconds_count", "vllm:e2e_request_latency_seconds_sum", "vllm:e2e_request_latency_seconds_bucket", "vllm:e2e_request_latency_seconds_count", @@ -273,7 +279,14 @@ EXPECTED_METRICS_V1 = [ "vllm:request_decode_time_seconds_count", ] -HIDDEN_DEPRECATED_METRICS: list[str] = [] +HIDDEN_DEPRECATED_METRICS: list[str] = [ + "vllm:gpu_cache_usage_perc", + "vllm:gpu_prefix_cache_queries", + "vllm:gpu_prefix_cache_hits", + 
"vllm:time_per_output_token_seconds_sum", + "vllm:time_per_output_token_seconds_bucket", + "vllm:time_per_output_token_seconds_count", +] @pytest.mark.asyncio @@ -289,9 +302,10 @@ async def test_metrics_exist(server: RemoteOpenAIServer, assert response.status_code == HTTPStatus.OK for metric in (EXPECTED_METRICS_V1 if use_v1 else EXPECTED_METRICS): - if (not server.show_hidden_metrics - and metric not in HIDDEN_DEPRECATED_METRICS): - assert metric in response.text + if (metric in HIDDEN_DEPRECATED_METRICS + and not server.show_hidden_metrics): + continue + assert metric in response.text @pytest.mark.asyncio @@ -299,7 +313,7 @@ async def test_abort_metrics_reset(server: RemoteOpenAIServer, client: openai.AsyncClient, use_v1: bool): running_requests, waiting_requests, kv_cache_usage = ( - _get_running_metrics_from_api(server)) + _get_running_metrics_from_api(server, use_v1)) # Expect no running requests or kvcache usage assert running_requests == 0 @@ -322,7 +336,7 @@ async def test_abort_metrics_reset(server: RemoteOpenAIServer, # Check that we have running requests running_requests, waiting_requests, kv_cache_usage = ( - _get_running_metrics_from_api(server)) + _get_running_metrics_from_api(server, use_v1)) # Expect running requests and kvcache usage assert running_requests > 0 @@ -341,7 +355,7 @@ async def test_abort_metrics_reset(server: RemoteOpenAIServer, # Verify running and waiting requests counts and KV cache usage are zero running_requests_after, waiting_requests_after, kv_cache_usage_after = ( - _get_running_metrics_from_api(server)) + _get_running_metrics_from_api(server, use_v1)) assert running_requests_after == 0,\ (f"Expected 0 running requests after abort, got " @@ -354,7 +368,7 @@ async def test_abort_metrics_reset(server: RemoteOpenAIServer, f"{kv_cache_usage_after}") -def _get_running_metrics_from_api(server: RemoteOpenAIServer): +def _get_running_metrics_from_api(server: RemoteOpenAIServer, use_v1: bool): """Return (running_count, waiting_count, kv_cache_usage)""" response = requests.get(server.url_for("metrics")) @@ -363,6 +377,9 @@ def _get_running_metrics_from_api(server: RemoteOpenAIServer): # Verify running and waiting requests counts and KV cache usage are zero running_requests, waiting_requests, kv_cache_usage = None, None, None + kv_cache_usage_metric = ("vllm:kv_cache_usage_perc" + if use_v1 else "vllm:gpu_cache_usage_perc") + for family in text_string_to_metric_families(response.text): if family.name == "vllm:num_requests_running": for sample in family.samples: @@ -374,9 +391,9 @@ def _get_running_metrics_from_api(server: RemoteOpenAIServer): if sample.name == "vllm:num_requests_waiting": waiting_requests = sample.value break - elif family.name == "vllm:gpu_cache_usage_perc": + elif family.name == kv_cache_usage_metric: for sample in family.samples: - if sample.name == "vllm:gpu_cache_usage_perc": + if sample.name == kv_cache_usage_metric: kv_cache_usage = sample.value break diff --git a/tests/entrypoints/openai/test_models.py b/tests/entrypoints/openai/test_models.py index 7cd3ca196a431..4ee34b19dea32 100644 --- a/tests/entrypoints/openai/test_models.py +++ b/tests/entrypoints/openai/test_models.py @@ -26,7 +26,6 @@ def server(zephyr_lora_files): "--enable-lora", "--lora-modules", f"zephyr-lora={zephyr_lora_files}", - f"zephyr-lora2={zephyr_lora_files}", "--max-lora-rank", "64", "--max-cpu-loras", @@ -56,4 +55,3 @@ async def test_check_models(client: openai.AsyncOpenAI, zephyr_lora_files): assert all(lora_model.root == zephyr_lora_files for lora_model in 
lora_models) assert lora_models[0].id == "zephyr-lora" - assert lora_models[1].id == "zephyr-lora2" diff --git a/tests/entrypoints/openai/test_openai_schema.py b/tests/entrypoints/openai/test_openai_schema.py index 11ed1c4a9ee4b..73f79ac28d110 100644 --- a/tests/entrypoints/openai/test_openai_schema.py +++ b/tests/entrypoints/openai/test_openai_schema.py @@ -102,12 +102,14 @@ def before_generate_case(context: schemathesis.hooks.HookContext, strategy): if "custom" in tool_call: return False - # Sometimes guided_grammar is generated to be empty + # Sometimes structured_outputs.grammar is generated to be empty # Causing a server error in EBNF grammar parsing # https://github.com/vllm-project/vllm/pull/22587#issuecomment-3195253421 - guided_grammar = case.body.get("guided_grammar") + structured_outputs = case.body.get("structured_outputs", {}) + grammar = structured_outputs.get("grammar") if isinstance( + structured_outputs, dict) else None - if guided_grammar == '': + if grammar == '': # Allow None (will be handled as no grammar) # But skip empty strings return False diff --git a/tests/entrypoints/openai/test_prompt_validation.py b/tests/entrypoints/openai/test_prompt_validation.py index 4197583074dfe..bb4c633e5e502 100644 --- a/tests/entrypoints/openai/test_prompt_validation.py +++ b/tests/entrypoints/openai/test_prompt_validation.py @@ -3,14 +3,14 @@ import io -# imports for guided decoding tests +# imports for structured outputs tests import openai import pybase64 import pytest import regex as re import torch -from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.renderer import BaseRenderer from ...utils import RemoteOpenAIServer @@ -27,12 +27,16 @@ async def test_empty_prompt(): with RemoteOpenAIServer(model_name, server_args) as remote_server: client = remote_server.get_async_client() - with pytest.raises(openai.BadRequestError, - match="decoder prompt cannot be empty"): + with pytest.raises( + openai.BadRequestError, + match= + "Either prompt or prompt_embeds must be provided and non-empty." 
+ ): await client.completions.create(model=model_name, prompt="", max_tokens=5, - temperature=0.0) + temperature=0.0, + extra_body={"prompt_embeds": []}) @pytest.mark.asyncio @@ -83,7 +87,7 @@ def test_load_prompt_embeds(dtype: torch.dtype, layout: torch.layout, buffer.seek(0) encoded_tensor = pybase64.b64encode(buffer.getvalue()) - loaded_prompt_embeds = OpenAIServing._load_prompt_embeds(encoded_tensor) + loaded_prompt_embeds = BaseRenderer.load_prompt_embeds(encoded_tensor) assert len(loaded_prompt_embeds) == 1 loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"] assert loaded_tensor.device.type == "cpu" diff --git a/tests/entrypoints/openai/test_response_api_with_harmony.py b/tests/entrypoints/openai/test_response_api_with_harmony.py index 72d468db08f65..eceaff672112f 100644 --- a/tests/entrypoints/openai/test_response_api_with_harmony.py +++ b/tests/entrypoints/openai/test_response_api_with_harmony.py @@ -74,6 +74,20 @@ async def test_basic_with_reasoning_effort(client: OpenAI, model_name: str): assert response.status == "completed" +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_max_tokens(client: OpenAI, model_name: str): + response = await client.responses.create( + model=model_name, + input="What is the first paragraph of Moby Dick?", + reasoning={"effort": "low"}, + max_output_tokens=30, + ) + assert response is not None + assert response.status == "incomplete" + assert response.incomplete_details.reason == "max_output_tokens" + + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_chat(client: OpenAI, model_name: str): @@ -275,7 +289,8 @@ async def test_stateful_multi_turn(client: OpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_streaming(client: OpenAI, model_name: str): +@pytest.mark.parametrize("background", [True, False]) +async def test_streaming(client: OpenAI, model_name: str, background: bool): # TODO: Add back when web search and code interpreter are available in CI prompts = [ "tell me a story about a cat in 20 words", @@ -300,15 +315,43 @@ async def test_streaming(client: OpenAI, model_name: str): # }, ], stream=True, + background=background, ) + current_item_id = "" + current_content_index = -1 + events = [] current_event_mode = None + resp_id = None async for event in response: + if event.type == "response.created": + resp_id = event.response.id + if current_event_mode != event.type: current_event_mode = event.type print(f"\n[{event.type}] ", end="", flush=True) + # verify current_item_id is correct + if event.type == "response.output_item.added": + assert event.item.id != current_item_id + current_item_id = event.item.id + elif event.type in [ + "response.output_text.delta", + "response.reasoning_text.delta" + ]: + assert event.item_id == current_item_id + + # verify content_index_id is correct + if event.type == "response.content_part.added": + assert event.content_index != current_content_index + current_content_index = event.content_index + elif event.type in [ + "response.output_text.delta", + "response.reasoning_text.delta" + ]: + assert event.content_index == current_content_index + if "text.delta" in event.type: print(event.delta, end="", flush=True) elif "reasoning_text.delta" in event.type: @@ -321,6 +364,19 @@ async def test_streaming(client: OpenAI, model_name: str): events.append(event) assert len(events) > 0 + response_completed_event = events[-1] + assert len(response_completed_event.response.output) > 
0 + + if background: + starting_after = 5 + async with await client.responses.retrieve( + response_id=resp_id, + stream=True, + starting_after=starting_after) as stream: + counter = starting_after + async for event in stream: + counter += 1 + assert event == events[counter] @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_return_token_ids.py b/tests/entrypoints/openai/test_return_token_ids.py index 6addcb41c4098..ff8f193fec552 100644 --- a/tests/entrypoints/openai/test_return_token_ids.py +++ b/tests/entrypoints/openai/test_return_token_ids.py @@ -224,7 +224,7 @@ async def test_comparison_with_prompt_logprobs_and_logprobs(server): logprobs_token_ids.append(token_id) # When echo=True, the logprobs include both prompt and response tokens - # The token_ids field should match the the suffix of response portion + # The token_ids field should match the suffix of response portion # The prompt_token_ids should match the prompt portion assert len(completion.choices[0].token_ids) < len(logprobs_token_ids) response_token_ids_length = len(completion.choices[0].token_ids) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 10879f0be83c8..8e68699e5904a 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -1,25 +1,224 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + import asyncio from contextlib import suppress from dataclasses import dataclass, field -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from unittest.mock import MagicMock import pytest +import pytest_asyncio -from vllm.config import MultiModalConfig -from vllm.engine.multiprocessing.client import MQLLMEngineClient +from vllm.config.multimodal import MultiModalConfig from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_models import (BaseModelPath, OpenAIServingModels) from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.v1.engine.async_llm import AsyncLLM + +from ...utils import RemoteOpenAIServer + +if TYPE_CHECKING: + from openai import OpenAI + +GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b" + + +@pytest.fixture(scope="module") +def monkeypatch_module(): + from _pytest.monkeypatch import MonkeyPatch + mpatch = MonkeyPatch() + yield mpatch + mpatch.undo() + + +@pytest.fixture(scope="module", + params=[True, False], + ids=["with_tool_parser", "without_tool_parser"]) +def with_tool_parser(request) -> bool: + return request.param + + +@pytest.fixture(scope="module") +def default_server_args(with_tool_parser: bool): + args = [ + # use half precision for speed and memory savings in CI environment + "--enforce-eager", + "--max-model-len", + "4096", + "--reasoning-parser", + "openai_gptoss", + "--gpu-memory-utilization", + "0.8", + ] + if with_tool_parser: + args.extend([ + "--tool-call-parser", + "openai", + "--enable-auto-tool-choice", + ]) + return args + + +@pytest.fixture(scope="module") +def gptoss_server(monkeypatch_module: pytest.MonkeyPatch, + default_server_args: list[str]): + with monkeypatch_module.context() as m: + m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1") + with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, + default_server_args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def 
gptoss_client(gptoss_server): + async with gptoss_server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI, + with_tool_parser: bool): + tools = [{ + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string" + }, + "state": { + "type": "string" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["city", "state", "unit"], + }, + }, + }] + + messages = [ + { + "role": "user", + "content": "What is the weather in Dallas, TX?" + }, + ] + + stream = await gptoss_client.chat.completions.create( + model=GPT_OSS_MODEL_NAME, + messages=messages, + tools=tools if with_tool_parser else None, + stream=True) + + name = None + args_buf = "" + content_buf = "" + async for chunk in stream: + delta = chunk.choices[0].delta + if delta.tool_calls: + tc = delta.tool_calls[0] + if tc.function and tc.function.name: + name = tc.function.name + if tc.function and tc.function.arguments: + args_buf += tc.function.arguments + if getattr(delta, "content", None): + content_buf += delta.content + if with_tool_parser: + assert name is not None + assert len(args_buf) > 0 + else: + assert name is None + assert len(args_buf) == 0 + assert len(content_buf) > 0 + + +@pytest.mark.asyncio +async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI, + with_tool_parser: bool): + if not with_tool_parser: + pytest.skip("skip non-tool for multi-turn tests") + tools = [{ + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string" + }, + "state": { + "type": "string" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["city", "state", "unit"], + }, + }, + }] + + messages = [ + { + "role": "system", + "content": "you are a helpful assistant" + }, + { + "role": "user", + "content": "What is the weather in Dallas, TX with celsius?" 
+ }, + ] + + first = await gptoss_client.chat.completions.create( + model=GPT_OSS_MODEL_NAME, + messages=messages, + tools=tools, + temperature=0.0, + ) + first_msg = first.choices[0].message + assert first_msg.tool_calls is not None and len(first_msg.tool_calls) > 0 + tc = first_msg.tool_calls[0] + assert tc.function is not None and tc.function.name == "get_current_weather" + args1 = tc.function.arguments + assert args1 is not None and len(args1) > 0 + + messages.append({"role": "assistant", "content": args1}) + messages.append({ + "role": "user", + "content": "Now convert to celsius and return JSON only" + }) + + second = await gptoss_client.chat.completions.create( + model=GPT_OSS_MODEL_NAME, + messages=messages, + tools=tools, + temperature=0.0, + ) + second_msg = second.choices[0].message + assert (second_msg.content is not None and len(second_msg.content) > 0) or \ + (second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0) + MODEL_NAME = "openai-community/gpt2" +MODEL_NAME_SHORT = "gpt2" CHAT_TEMPLATE = "Dummy chat template for testing {}" -BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] +BASE_MODEL_PATHS = [ + BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME), + BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT) +] @dataclass @@ -75,9 +274,45 @@ def test_async_serving_chat_init(): assert serving_completion.chat_template == CHAT_TEMPLATE +@pytest.mark.asyncio +async def test_serving_chat_returns_correct_model_name(): + mock_engine = MagicMock(spec=AsyncLLM) + mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False + + models = OpenAIServingModels(engine_client=mock_engine, + base_model_paths=BASE_MODEL_PATHS, + model_config=MockModelConfig()) + serving_chat = OpenAIServingChat(mock_engine, + MockModelConfig(), + models, + response_role="assistant", + chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", + request_logger=None) + messages = [{"role": "user", "content": "what is 1+1?"}] + + async def return_model_name(*args): + return args[3] + + serving_chat.chat_completion_full_generator = return_model_name + + # Test that full name is returned when short name is requested + req = ChatCompletionRequest(model=MODEL_NAME_SHORT, messages=messages) + assert await serving_chat.create_chat_completion(req) == MODEL_NAME + + # Test that full name is returned when empty string is specified + req = ChatCompletionRequest(model="", messages=messages) + assert await serving_chat.create_chat_completion(req) == MODEL_NAME + + # Test that full name is returned when no model is specified + req = ChatCompletionRequest(messages=messages) + assert await serving_chat.create_chat_completion(req) == MODEL_NAME + + @pytest.mark.asyncio async def test_serving_chat_should_set_correct_max_tokens(): - mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine = MagicMock(spec=AsyncLLM) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False @@ -98,7 +333,6 @@ async def test_serving_chat_should_set_correct_max_tokens(): "role": "user", "content": "what is 1+1?" 
}], - guided_decoding_backend="outlines", ) with suppress(Exception): @@ -120,7 +354,7 @@ async def test_serving_chat_should_set_correct_max_tokens(): } # Reinitialize the engine with new settings - mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine = MagicMock(spec=AsyncLLM) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False @@ -143,7 +377,6 @@ async def test_serving_chat_should_set_correct_max_tokens(): "role": "user", "content": "what is 1+1?" }], - guided_decoding_backend="outlines", ) with suppress(Exception): @@ -175,7 +408,7 @@ async def test_serving_chat_should_set_correct_max_tokens(): } # Reinitialize the engine with new settings - mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine = MagicMock(spec=AsyncLLM) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False @@ -198,7 +431,6 @@ async def test_serving_chat_should_set_correct_max_tokens(): "role": "user", "content": "what is 1+1?" }], - guided_decoding_backend="outlines", ) with suppress(Exception): @@ -232,7 +464,7 @@ async def test_serving_chat_could_load_correct_generation_config(): "repetition_penalty": 1.05 } - mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine = MagicMock(spec=AsyncLLM) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False @@ -254,7 +486,6 @@ async def test_serving_chat_could_load_correct_generation_config(): "role": "user", "content": "what is 1+1?" }], - guided_decoding_backend="outlines", ) with suppress(Exception): @@ -288,7 +519,7 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type): mock_model_config = MockModelConfig() mock_model_config.hf_config.model_type = model_type - mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine = MagicMock(spec=AsyncLLM) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False @@ -313,7 +544,7 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type): }], ) - # By default cache_salt in the engine prompt is not set + # By default, cache_salt in the engine prompt is not set with suppress(Exception): await serving_chat.create_chat_completion(req) assert "cache_salt" not in mock_engine.generate.call_args.args[0] diff --git a/tests/entrypoints/openai/test_skip_tokenizer.py b/tests/entrypoints/openai/test_skip_tokenizer.py index 0bb42ed8aa7fb..840e0dac81c97 100644 --- a/tests/entrypoints/openai/test_skip_tokenizer.py +++ b/tests/entrypoints/openai/test_skip_tokenizer.py @@ -11,7 +11,7 @@ import torch from ...utils import RemoteOpenAIServer -MODEL_NAME = "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM" +MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11" DTYPE = "float16" @@ -35,7 +35,9 @@ def server(): "--trust-remote-code", "--skip-tokenizer-init", "--max-num-seqs", - "32" + "32", + "--model-impl", + "terratorch" ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: diff --git a/tests/entrypoints/openai/test_tokenization.py b/tests/entrypoints/openai/test_tokenization.py index 72c8a3510c9b0..ecb7f50fa7400 100644 --- a/tests/entrypoints/openai/test_tokenization.py +++ b/tests/entrypoints/openai/test_tokenization.py @@ -14,7 +14,7 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" @pytest.fixture(scope="module") -def server(zephyr_lora_added_tokens_files: str): # noqa: F811 +def server(): args = [ # use half precision for speed and memory savings in CI environment "--dtype", @@ -24,12 +24,6 @@ def 
server(zephyr_lora_added_tokens_files: str): # noqa: F811 "--enforce-eager", "--max-num-seqs", "128", - # lora config - "--enable-lora", - "--lora-modules", - f"zephyr-lora2={zephyr_lora_added_tokens_files}", - "--max-lora-rank", - "64", "--enable-tokenizer-info-endpoint", ] @@ -38,10 +32,8 @@ def server(zephyr_lora_added_tokens_files: str): # noqa: F811 @pytest.fixture(scope="module") -def tokenizer_name(model_name: str, - zephyr_lora_added_tokens_files: str): # noqa: F811 - return zephyr_lora_added_tokens_files if ( - model_name == "zephyr-lora2") else model_name +def tokenizer_name(model_name: str): + return model_name @pytest_asyncio.fixture @@ -53,7 +45,7 @@ async def client(server): @pytest.mark.asyncio @pytest.mark.parametrize( "model_name,tokenizer_name", - [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], + [(MODEL_NAME, MODEL_NAME)], indirect=["tokenizer_name"], ) async def test_tokenize_completions( @@ -86,7 +78,7 @@ async def test_tokenize_completions( @pytest.mark.asyncio @pytest.mark.parametrize( "model_name,tokenizer_name", - [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], + [(MODEL_NAME, MODEL_NAME)], indirect=["tokenizer_name"], ) async def test_tokenize_chat( @@ -148,7 +140,7 @@ async def test_tokenize_chat( @pytest.mark.asyncio @pytest.mark.parametrize( "model_name,tokenizer_name", - [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], + [(MODEL_NAME, MODEL_NAME)], indirect=["tokenizer_name"], ) async def test_tokenize_chat_with_tools( @@ -225,7 +217,7 @@ async def test_tokenize_chat_with_tools( @pytest.mark.asyncio @pytest.mark.parametrize( "model_name, tokenizer_name", - [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], + [(MODEL_NAME, MODEL_NAME)], indirect=["tokenizer_name"], ) async def test_tokenize_with_return_token_strs( @@ -260,7 +252,7 @@ async def test_tokenize_with_return_token_strs( @pytest.mark.asyncio @pytest.mark.parametrize( "model_name,tokenizer_name", - [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], + [(MODEL_NAME, MODEL_NAME)], indirect=["tokenizer_name"], ) async def test_detokenize( @@ -287,7 +279,7 @@ async def test_detokenize( @pytest.mark.asyncio @pytest.mark.parametrize( "model_name,tokenizer_name", - [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], + [(MODEL_NAME, MODEL_NAME)], indirect=["tokenizer_name"], ) async def test_tokenizer_info_basic( @@ -384,4 +376,4 @@ async def test_tokenizer_info_chat_template(server: RemoteOpenAIServer): if chat_template: assert isinstance(chat_template, str), ("Chat template should be a string") - assert chat_template.strip(), "Chat template should not be empty" \ No newline at end of file + assert chat_template.strip(), "Chat template should not be empty" diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py index 6a3cdfdfc8081..23c99da97ad3a 100644 --- a/tests/entrypoints/openai/test_transcription_validation.py +++ b/tests/entrypoints/openai/test_transcription_validation.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# imports for guided decoding tests +# imports for structured outputs tests import io import json diff --git a/tests/entrypoints/openai/test_translation_validation.py b/tests/entrypoints/openai/test_translation_validation.py index f43b7a253d28d..eb7879927b9b6 100644 --- a/tests/entrypoints/openai/test_translation_validation.py +++ 
b/tests/entrypoints/openai/test_translation_validation.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import io -# imports for guided decoding tests +# imports for structured outputs tests import json import httpx diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index 106ec121a422e..a324e86666055 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -16,11 +16,11 @@ MODEL_NAME = "microsoft/Phi-3.5-vision-instruct" MAXIMUM_IMAGES = 2 # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) -TEST_IMAGE_URLS = [ - "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", - "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", - "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", - "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", +TEST_IMAGE_ASSETS = [ + "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + "Grayscale_8bits_palette_sample_image.png", # "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", + "1280px-Venn_diagram_rgb.svg.png", # "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", + "RGBA_comp.png", # "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", ] EXPECTED_MM_BEAM_SEARCH_RES = [ @@ -34,11 +34,11 @@ EXPECTED_MM_BEAM_SEARCH_RES = [ ], [ "The image shows a Venn diagram with three over", - "The image shows a Venn diagram with three intersect", + "This image shows a Venn diagram with three over", ], [ "This image displays a gradient of colors ranging from", - "The image displays a gradient of colors ranging from", + "This image displays a gradient of colors forming a spectrum", ], ] @@ -69,10 +69,11 @@ async def client(server): @pytest.fixture(scope="session") -def base64_encoded_image() -> dict[str, str]: +def base64_encoded_image(local_asset_server) -> dict[str, str]: return { - image_url: encode_image_base64(fetch_image(image_url)) - for image_url in TEST_IMAGE_URLS + image_asset: + encode_image_base64(local_asset_server.get_image_asset(image_asset)) + for image_asset in TEST_IMAGE_ASSETS } @@ -97,7 +98,7 @@ def get_hf_prompt_tokens(model_name, content, image_url): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) async def test_single_chat_session_image(client: openai.AsyncOpenAI, model_name: str, image_url: str): content_text = "What's in this image?" 
@@ -157,7 +158,7 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) async def test_error_on_invalid_image_url_type(client: openai.AsyncOpenAI, model_name: str, image_url: str): @@ -187,7 +188,7 @@ async def test_error_on_invalid_image_url_type(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI, model_name: str, image_url: str): @@ -223,10 +224,11 @@ async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("raw_image_url", TEST_IMAGE_ASSETS) +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) async def test_single_chat_session_image_base64encoded( - client: openai.AsyncOpenAI, model_name: str, image_url: str, - base64_encoded_image: dict[str, str]): + client: openai.AsyncOpenAI, model_name: str, raw_image_url: str, + image_url: str, base64_encoded_image: dict[str, str]): content_text = "What's in this image?" messages = [{ @@ -237,7 +239,7 @@ async def test_single_chat_session_image_base64encoded( "type": "image_url", "image_url": { "url": - f"data:image/jpeg;base64,{base64_encoded_image[image_url]}" + f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}" } }, { @@ -287,12 +289,12 @@ async def test_single_chat_session_image_base64encoded( @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_idx", list(range(len(TEST_IMAGE_URLS)))) +@pytest.mark.parametrize("image_idx", list(range(len(TEST_IMAGE_ASSETS)))) async def test_single_chat_session_image_base64encoded_beamsearch( client: openai.AsyncOpenAI, model_name: str, image_idx: int, base64_encoded_image: dict[str, str]): # NOTE: This test also validates that we pass MM data through beam search - image_url = TEST_IMAGE_URLS[image_idx] + raw_image_url = TEST_IMAGE_ASSETS[image_idx] expected_res = EXPECTED_MM_BEAM_SEARCH_RES[image_idx] messages = [{ @@ -303,7 +305,7 @@ async def test_single_chat_session_image_base64encoded_beamsearch( "type": "image_url", "image_url": { "url": - f"data:image/jpeg;base64,{base64_encoded_image[image_url]}" + f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}" } }, { @@ -326,7 +328,7 @@ async def test_single_chat_session_image_base64encoded_beamsearch( @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) async def test_chat_streaming_image(client: openai.AsyncOpenAI, model_name: str, image_url: str): messages = [{ @@ -385,7 +387,8 @@ async def test_chat_streaming_image(client: openai.AsyncOpenAI, @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize( "image_urls", - [TEST_IMAGE_URLS[:i] for i in range(2, len(TEST_IMAGE_URLS))]) + [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))], + indirect=True) async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str, image_urls: 
list[str]): @@ -433,3 +436,197 @@ async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str, ) message = chat_completion.choices[0].message assert message.content is not None and len(message.content) >= 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize( + "image_urls", + [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))], + indirect=True) +async def test_completions_with_image( + client: openai.AsyncOpenAI, + model_name: str, + image_urls: list[str], +): + for image_url in image_urls: + chat_completion = await client.chat.completions.create( + messages=[ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "Describe this image.", + }, + { + "type": "image_url", + "image_url": { + "url": image_url, + } + }, + ], + }, + ], + model=model_name, + ) + assert chat_completion.choices[0].message.content is not None + assert isinstance(chat_completion.choices[0].message.content, str) + assert len(chat_completion.choices[0].message.content) > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize( + "image_urls", + [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))], + indirect=True) +async def test_completions_with_image_with_uuid( + client: openai.AsyncOpenAI, + model_name: str, + image_urls: list[str], +): + for image_url in image_urls: + chat_completion = await client.chat.completions.create( + messages=[ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "Describe this image.", + }, + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + "uuid": image_url + }, + ], + }, + ], + model=model_name, + ) + assert chat_completion.choices[0].message.content is not None + assert isinstance(chat_completion.choices[0].message.content, str) + assert len(chat_completion.choices[0].message.content) > 0 + + # Second request, with empty image but the same uuid. + chat_completion_with_empty_image = await client.chat.completions.create( + messages=[ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "Describe this image.", + }, + { + "type": "image_url", + "image_url": {}, + "uuid": image_url + }, + ], + }, + ], + model=model_name, + ) + assert chat_completion_with_empty_image.choices[ + 0].message.content is not None + assert isinstance( + chat_completion_with_empty_image.choices[0].message.content, str) + assert len( + chat_completion_with_empty_image.choices[0].message.content) > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_completions_with_empty_image_with_uuid_without_cache_hit( + client: openai.AsyncOpenAI, + model_name: str, +): + with pytest.raises(openai.BadRequestError): + _ = await client.chat.completions.create( + messages=[ + { + "role": "system", + "content": "You are a helpful assistant." 
+ }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "Describe this image.", + }, + { + "type": "image_url", + "image_url": {}, + "uuid": "uuid_not_previously_seen" + }, + ], + }, + ], + model=model_name, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize( + "image_urls", + [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))], + indirect=True) +async def test_completions_with_image_with_incorrect_uuid_format( + client: openai.AsyncOpenAI, + model_name: str, + image_urls: list[str], +): + for image_url in image_urls: + chat_completion = await client.chat.completions.create( + messages=[ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "Describe this image.", + }, + { + "type": "image_url", + "image_url": { + "url": image_url, + "incorrect_uuid_key": image_url, + }, + "also_incorrect_uuid_key": image_url, + }, + ], + }, + ], + model=model_name, + ) + assert chat_completion.choices[0].message.content is not None + assert isinstance(chat_completion.choices[0].message.content, str) + assert len(chat_completion.choices[0].message.content) > 0 diff --git a/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py index 28b1f8358d80b..4bab849f47c27 100644 --- a/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py @@ -18,6 +18,8 @@ SERVER_ARGS = [ "--enable-lora", "--lora-modules", f"{LORA_MODEL}={LORA_MODEL}", + "--tokenizer", + f"{LORA_MODEL}", ] TOOLS = [{ diff --git a/tests/async_engine/__init__.py b/tests/entrypoints/pooling/__init__.py similarity index 100% rename from tests/async_engine/__init__.py rename to tests/entrypoints/pooling/__init__.py diff --git a/tests/core/__init__.py b/tests/entrypoints/pooling/correctness/__init__.py similarity index 100% rename from tests/core/__init__.py rename to tests/entrypoints/pooling/correctness/__init__.py diff --git a/tests/entrypoints/openai/correctness/test_mteb_embed.py b/tests/entrypoints/pooling/correctness/test_mteb_embed.py similarity index 71% rename from tests/entrypoints/openai/correctness/test_mteb_embed.py rename to tests/entrypoints/pooling/correctness/test_mteb_embed.py index 783f7d3e0d5aa..12a4875bdacfd 100644 --- a/tests/entrypoints/openai/correctness/test_mteb_embed.py +++ b/tests/entrypoints/pooling/correctness/test_mteb_embed.py @@ -4,10 +4,9 @@ import os import pytest -from tests.models.language.pooling.mteb_utils import (MTEB_EMBED_TASKS, - MTEB_EMBED_TOL, - OpenAIClientMtebEncoder, - run_mteb_embed_task) +from tests.models.language.pooling_mteb_test.mteb_utils import ( + MTEB_EMBED_TASKS, MTEB_EMBED_TOL, OpenAIClientMtebEncoder, + run_mteb_embed_task) from tests.utils import RemoteOpenAIServer os.environ["VLLM_LOGGING_LEVEL"] = "WARNING" @@ -37,4 +36,6 @@ def test_mteb_embed(server): print("SentenceTransformer main score: ", st_main_score) print("Difference: ", st_main_score - vllm_main_score) - assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL) + # We are not concerned that the vllm mteb results are better + # than SentenceTransformers, so we only perform one-sided testing. 
+ assert st_main_score - vllm_main_score < MTEB_EMBED_TOL diff --git a/tests/entrypoints/openai/correctness/test_mteb_score.py b/tests/entrypoints/pooling/correctness/test_mteb_score.py similarity index 64% rename from tests/entrypoints/openai/correctness/test_mteb_score.py rename to tests/entrypoints/pooling/correctness/test_mteb_score.py index cfb865815c9b2..7c059d16b3863 100644 --- a/tests/entrypoints/openai/correctness/test_mteb_score.py +++ b/tests/entrypoints/pooling/correctness/test_mteb_score.py @@ -4,18 +4,15 @@ import os import pytest -# yapf conflicts with isort for this block -# yapf: disable -from tests.models.language.pooling.mteb_utils import ( +from tests.models.language.pooling_mteb_test.mteb_utils import ( MTEB_RERANK_LANGS, MTEB_RERANK_TASKS, MTEB_RERANK_TOL, - RerankClientMtebEncoder, ScoreClientMtebEncoder, - mteb_test_rerank_models_hf, run_mteb_rerank) -# yapf: enable + RerankClientMtebEncoder, ScoreClientMtebEncoder, run_mteb_rerank) from tests.utils import RemoteOpenAIServer os.environ["VLLM_LOGGING_LEVEL"] = "WARNING" MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2" +st_main_score = 0.33457 @pytest.fixture(scope="module") @@ -29,15 +26,7 @@ def server(): yield remote_server -@pytest.fixture(scope="module") -def st_main_score(hf_runner): - # The main score related to the version of the dependency. - # So we need to recalculate every time. - main_score, st_dtype = mteb_test_rerank_models_hf(hf_runner, MODEL_NAME) - return main_score - - -def test_mteb_score(server, st_main_score): +def test_mteb_score(server): url = server.url_for("score") encoder = ScoreClientMtebEncoder(MODEL_NAME, url) vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS, @@ -47,10 +36,12 @@ def test_mteb_score(server, st_main_score): print("SentenceTransformer main score: ", st_main_score) print("Difference: ", st_main_score - vllm_main_score) - assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_RERANK_TOL) + # We are not concerned that the vllm mteb results are better + # than SentenceTransformers, so we only perform one-sided testing. + assert st_main_score - vllm_main_score < MTEB_RERANK_TOL -def test_mteb_rerank(server, st_main_score): +def test_mteb_rerank(server): url = server.url_for("rerank") encoder = RerankClientMtebEncoder(MODEL_NAME, url) vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS, @@ -60,4 +51,6 @@ def test_mteb_rerank(server, st_main_score): print("SentenceTransformer main score: ", st_main_score) print("Difference: ", st_main_score - vllm_main_score) - assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_RERANK_TOL) + # We are not concerned that the vllm mteb results are better + # than SentenceTransformers, so we only perform one-sided testing. 
+ assert st_main_score - vllm_main_score < MTEB_RERANK_TOL diff --git a/tests/core/block/__init__.py b/tests/entrypoints/pooling/llm/__init__.py similarity index 100% rename from tests/core/block/__init__.py rename to tests/entrypoints/pooling/llm/__init__.py diff --git a/tests/entrypoints/llm/test_classify.py b/tests/entrypoints/pooling/llm/test_classify.py similarity index 98% rename from tests/entrypoints/llm/test_classify.py rename to tests/entrypoints/pooling/llm/test_classify.py index 6c0c9cd015801..ff5cea11a9182 100644 --- a/tests/entrypoints/llm/test_classify.py +++ b/tests/entrypoints/pooling/llm/test_classify.py @@ -6,11 +6,10 @@ import weakref import pytest import torch +from tests.models.utils import softmax from vllm import LLM, PoolingParams from vllm.distributed import cleanup_dist_env_and_memory -from ...models.utils import softmax - MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach" prompts = ["The chef prepared a delicious meal."] diff --git a/tests/entrypoints/llm/test_embedding.py b/tests/entrypoints/pooling/llm/test_embedding.py similarity index 100% rename from tests/entrypoints/llm/test_embedding.py rename to tests/entrypoints/pooling/llm/test_embedding.py diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/pooling/llm/test_encode.py similarity index 100% rename from tests/entrypoints/llm/test_encode.py rename to tests/entrypoints/pooling/llm/test_encode.py diff --git a/tests/entrypoints/llm/test_reward.py b/tests/entrypoints/pooling/llm/test_reward.py similarity index 97% rename from tests/entrypoints/llm/test_reward.py rename to tests/entrypoints/pooling/llm/test_reward.py index 2cee3c8d94e36..11d164c978a92 100644 --- a/tests/entrypoints/llm/test_reward.py +++ b/tests/entrypoints/pooling/llm/test_reward.py @@ -6,11 +6,10 @@ import weakref import pytest import torch +from tests.models.utils import softmax from vllm import LLM, PoolingParams from vllm.distributed import cleanup_dist_env_and_memory -from ...models.utils import softmax - MODEL_NAME = "internlm/internlm2-1_8b-reward" prompts = ["The chef prepared a delicious meal."] diff --git a/tests/entrypoints/llm/test_score.py b/tests/entrypoints/pooling/llm/test_score.py similarity index 97% rename from tests/entrypoints/llm/test_score.py rename to tests/entrypoints/pooling/llm/test_score.py index f715dacacb8ff..447378f989d09 100644 --- a/tests/entrypoints/llm/test_score.py +++ b/tests/entrypoints/pooling/llm/test_score.py @@ -6,11 +6,10 @@ import weakref import pytest import torch +from tests.models.utils import softmax from vllm import LLM, PoolingParams from vllm.distributed import cleanup_dist_env_and_memory -from ...models.utils import softmax - MODEL_NAME = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls" diff --git a/tests/core/block/e2e/__init__.py b/tests/entrypoints/pooling/openai/__init__.py similarity index 100% rename from tests/core/block/e2e/__init__.py rename to tests/entrypoints/pooling/openai/__init__.py diff --git a/tests/entrypoints/openai/test_classification.py b/tests/entrypoints/pooling/openai/test_classification.py similarity index 99% rename from tests/entrypoints/openai/test_classification.py rename to tests/entrypoints/pooling/openai/test_classification.py index 36c96d76c2e5f..26c2c8e6af17d 100644 --- a/tests/entrypoints/openai/test_classification.py +++ b/tests/entrypoints/pooling/openai/test_classification.py @@ -6,10 +6,9 @@ import requests import torch import torch.nn.functional as F +from tests.utils import RemoteOpenAIServer from vllm.entrypoints.openai.protocol import 
ClassificationResponse -from ...utils import RemoteOpenAIServer - MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach" DTYPE = "float32" # Use float32 to avoid NaN issue diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/pooling/openai/test_embedding.py similarity index 98% rename from tests/entrypoints/openai/test_embedding.py rename to tests/entrypoints/pooling/openai/test_embedding.py index d46ab304ba6d5..37a10e79d4fc7 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/pooling/openai/test_embedding.py @@ -11,14 +11,13 @@ import requests import torch import torch.nn.functional as F +from tests.models.language.pooling.embed_utils import ( + run_embedding_correctness_test) +from tests.models.utils import check_embeddings_close +from tests.utils import RemoteOpenAIServer from vllm.entrypoints.openai.protocol import EmbeddingResponse from vllm.transformers_utils.tokenizer import get_tokenizer -from ...models.language.pooling.embed_utils import ( - run_embedding_correctness_test) -from ...models.utils import check_embeddings_close -from ...utils import RemoteOpenAIServer - MODEL_NAME = "intfloat/multilingual-e5-small" DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501 DTYPE = "bfloat16" diff --git a/tests/entrypoints/openai/test_embedding_dimensions.py b/tests/entrypoints/pooling/openai/test_embedding_dimensions.py similarity index 95% rename from tests/entrypoints/openai/test_embedding_dimensions.py rename to tests/entrypoints/pooling/openai/test_embedding_dimensions.py index 91e91699b92ca..3c7e88daa8ff3 100644 --- a/tests/entrypoints/openai/test_embedding_dimensions.py +++ b/tests/entrypoints/pooling/openai/test_embedding_dimensions.py @@ -9,13 +9,12 @@ from typing import Optional import openai import pytest -from vllm.entrypoints.openai.protocol import EmbeddingResponse - -from ...conftest import HfRunner -from ...models.language.pooling.embed_utils import ( +from tests.conftest import HfRunner +from tests.models.language.pooling.embed_utils import ( run_embedding_correctness_test) -from ...models.utils import EmbedModelInfo -from ...utils import RemoteOpenAIServer +from tests.models.utils import EmbedModelInfo +from tests.utils import RemoteOpenAIServer +from vllm.entrypoints.openai.protocol import EmbeddingResponse MODELS = [ EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False), diff --git a/tests/entrypoints/openai/test_embedding_long_text.py b/tests/entrypoints/pooling/openai/test_embedding_long_text.py similarity index 99% rename from tests/entrypoints/openai/test_embedding_long_text.py rename to tests/entrypoints/pooling/openai/test_embedding_long_text.py index 86bd34abb97e0..2d3da238d245e 100644 --- a/tests/entrypoints/openai/test_embedding_long_text.py +++ b/tests/entrypoints/pooling/openai/test_embedding_long_text.py @@ -14,10 +14,9 @@ import openai import pytest import pytest_asyncio +from tests.utils import RemoteOpenAIServer from vllm.entrypoints.openai.protocol import EmbeddingResponse -from ...utils import RemoteOpenAIServer - def _generate_random_text(word_count: int) -> str: """Generate random text with approximately the specified word count.""" diff --git a/tests/entrypoints/openai/test_pooling.py b/tests/entrypoints/pooling/openai/test_pooling.py similarity index 99% rename from tests/entrypoints/openai/test_pooling.py rename to tests/entrypoints/pooling/openai/test_pooling.py index 63f4205e0a42b..9f58955cfb40b 100644 --- 
a/tests/entrypoints/openai/test_pooling.py +++ b/tests/entrypoints/pooling/openai/test_pooling.py @@ -8,11 +8,10 @@ import pytest import requests from tests.models.utils import check_embeddings_close +from tests.utils import RemoteOpenAIServer from vllm.entrypoints.openai.protocol import PoolingResponse from vllm.transformers_utils.tokenizer import get_tokenizer -from ...utils import RemoteOpenAIServer - MODEL_NAME = "internlm/internlm2-1_8b-reward" DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501 diff --git a/tests/entrypoints/openai/test_rerank.py b/tests/entrypoints/pooling/openai/test_rerank.py similarity index 99% rename from tests/entrypoints/openai/test_rerank.py rename to tests/entrypoints/pooling/openai/test_rerank.py index ce4d6c5f5d337..992cb5147ef0d 100644 --- a/tests/entrypoints/openai/test_rerank.py +++ b/tests/entrypoints/pooling/openai/test_rerank.py @@ -6,10 +6,9 @@ import requests import torch import torch.nn.functional as F +from tests.utils import RemoteOpenAIServer from vllm.entrypoints.openai.protocol import RerankResponse -from ...utils import RemoteOpenAIServer - MODEL_NAME = "BAAI/bge-reranker-base" DTYPE = "bfloat16" diff --git a/tests/entrypoints/openai/test_score.py b/tests/entrypoints/pooling/openai/test_score.py similarity index 99% rename from tests/entrypoints/openai/test_score.py rename to tests/entrypoints/pooling/openai/test_score.py index 4fafcfb45fa22..d676ecccbc87c 100644 --- a/tests/entrypoints/openai/test_score.py +++ b/tests/entrypoints/pooling/openai/test_score.py @@ -8,10 +8,9 @@ import torch import torch.nn.functional as F from torch import tensor +from tests.utils import RemoteOpenAIServer from vllm.entrypoints.openai.protocol import ScoreResponse -from ...utils import RemoteOpenAIServer - MODELS = [ { "name": "BAAI/bge-reranker-v2-m3", diff --git a/tests/entrypoints/openai/test_truncation.py b/tests/entrypoints/pooling/openai/test_truncation.py similarity index 87% rename from tests/entrypoints/openai/test_truncation.py rename to tests/entrypoints/pooling/openai/test_truncation.py index 121c0413e1af7..6bdf5ce7c4a6c 100644 --- a/tests/entrypoints/openai/test_truncation.py +++ b/tests/entrypoints/pooling/openai/test_truncation.py @@ -73,17 +73,11 @@ async def test_zero_truncation_size(client: openai.AsyncOpenAI): "truncate_prompt_tokens": truncation_size } - with pytest.raises(openai.BadRequestError) as err: - await client.post(path="embeddings", cast_to=object, body={**kwargs}) + response = await client.post(path="embeddings", + cast_to=object, + body={**kwargs}) - assert err.value.status_code == 400 - error_details = err.value.response.json()["error"] - - assert error_details["type"] == "BadRequestError" - assert "This model's maximum context length is" in error_details["message"] - assert "tokens in the input for embedding generation" in error_details[ - "message"] - assert "Please reduce the length of the input" in error_details["message"] + assert response["usage"]["prompt_tokens"] == truncation_size @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_vision_embedding.py b/tests/entrypoints/pooling/openai/test_vision_embedding.py similarity index 74% rename from tests/entrypoints/openai/test_vision_embedding.py rename to tests/entrypoints/pooling/openai/test_vision_embedding.py index d3cc2fac6af57..48434e36eb265 100644 --- a/tests/entrypoints/openai/test_vision_embedding.py +++ b/tests/entrypoints/pooling/openai/test_vision_embedding.py @@ 
-7,11 +7,10 @@ import pytest import requests from transformers import AutoProcessor +from tests.utils import VLLM_PATH, RemoteOpenAIServer from vllm.entrypoints.openai.protocol import EmbeddingResponse from vllm.multimodal.utils import encode_image_base64, fetch_image -from ...utils import VLLM_PATH, RemoteOpenAIServer - MODEL_NAME = "TIGER-Lab/VLM2Vec-Full" MAXIMUM_IMAGES = 2 @@ -19,11 +18,11 @@ vlm2vec_jinja_path = VLLM_PATH / "examples/template_vlm2vec.jinja" assert vlm2vec_jinja_path.exists() # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) -TEST_IMAGE_URLS = [ - "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", - "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", - "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", - "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", +TEST_IMAGE_ASSETS = [ + "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + "Grayscale_8bits_palette_sample_image.png", # "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", + "1280px-Venn_diagram_rgb.svg.png", # "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", + "RGBA_comp.png", # "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", ] @@ -49,10 +48,11 @@ def server(): @pytest.fixture(scope="session") -def base64_encoded_image() -> dict[str, str]: +def base64_encoded_image(local_asset_server) -> dict[str, str]: return { - image_url: encode_image_base64(fetch_image(image_url)) - for image_url in TEST_IMAGE_URLS + image_url: + encode_image_base64(local_asset_server.get_image_asset(image_url)) + for image_url in TEST_IMAGE_ASSETS } @@ -70,7 +70,7 @@ def get_hf_prompt_tokens(model_name, content, image_url): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) async def test_image_embedding(server: RemoteOpenAIServer, model_name: str, image_url: str): content_text = "Represent the given image." 
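The vision and vision-embedding tests above switch from remote Wikipedia URLs to locally served assets and pass images either by URL or as base64 data URLs built with `encode_image_base64`. As a reference for that request shape, here is a minimal standalone sketch of the data-URL pattern those tests exercise; the file path, model name, and server address are placeholders and not taken from this diff, and it assumes an OpenAI-compatible vLLM server is already running.

```python
import base64

from openai import AsyncOpenAI


async def describe_image(path: str, model: str) -> str:
    # Encode a local image the same way the tests encode their assets:
    # raw bytes -> base64 -> "data:image/jpeg;base64,..." URL.
    with open(path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    data_url = f"data:image/jpeg;base64,{encoded}"

    # Placeholder endpoint; point this at whatever server the tests spin up.
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    chat = await client.chat.completions.create(
        model=model,
        messages=[{
            "role": "user",
            "content": [
                {"type": "text", "text": "What's in this image?"},
                {"type": "image_url", "image_url": {"url": data_url}},
            ],
        }],
    )
    return chat.choices[0].message.content
```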
diff --git a/tests/entrypoints/test_api_server_process_manager.py b/tests/entrypoints/test_api_server_process_manager.py index e4af60a782651..a993e24ff838a 100644 --- a/tests/entrypoints/test_api_server_process_manager.py +++ b/tests/entrypoints/test_api_server_process_manager.py @@ -95,7 +95,7 @@ def test_api_server_process_manager_init(api_server_args, with_stats_update): assert not proc.is_alive() -@patch("vllm.entrypoints.cli.serve.run_api_server_worker", +@patch("vllm.entrypoints.cli.serve.run_api_server_worker_proc", mock_run_api_server_worker) def test_wait_for_completion_or_failure(api_server_args): """Test that wait_for_completion_or_failure works with failures.""" diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 647f1c7b7f34f..78370d199b566 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -20,11 +20,10 @@ from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template, parse_chat_messages_futures, resolve_chat_template_content_format, resolve_hf_chat_template) -from vllm.entrypoints.llm import apply_hf_chat_template -from vllm.multimodal import MultiModalDataDict +from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64, encode_video_base64) -from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from ..models.registry import HF_EXAMPLE_MODELS @@ -38,7 +37,6 @@ QWEN2AUDIO_MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct" QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct" QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct" QWEN25OMNI_MODEL_ID = "Qwen/Qwen2.5-Omni-7B" -MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct" LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B" HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B" MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" @@ -46,112 +44,103 @@ MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" @pytest.fixture(scope="function") def phi3v_model_config(): - return ModelConfig(PHI3V_MODEL_ID, - runner="generate", - trust_remote_code=True, - limit_mm_per_prompt={ - "image": 2, - }) + return ModelConfig( + PHI3V_MODEL_ID, + runner="generate", + trust_remote_code=True, + limit_mm_per_prompt={ + "image": 2, + }, + ) @pytest.fixture(scope="function") def phi3v_model_config_mm_interleaved(): - return ModelConfig(PHI3V_MODEL_ID, - runner="generate", - trust_remote_code=True, - interleave_mm_strings=True, - limit_mm_per_prompt={ - "image": 2, - }) + return ModelConfig( + PHI3V_MODEL_ID, + runner="generate", + trust_remote_code=True, + interleave_mm_strings=True, + limit_mm_per_prompt={ + "image": 2, + }, + ) @pytest.fixture(scope="module") def phi3v_tokenizer(): - return TokenizerGroup( - tokenizer_id=PHI3V_MODEL_ID, - enable_lora=False, - max_num_seqs=5, - max_input_length=None, + return get_tokenizer(PHI3V_MODEL_ID) + + +@pytest.fixture(scope="function") +def qwen2_audio_model_config(): + return ModelConfig( + QWEN2AUDIO_MODEL_ID, + runner="generate", + trust_remote_code=True, + limit_mm_per_prompt={ + "audio": 1, + }, ) +@pytest.fixture(scope="module") +def qwen2_audio_tokenizer(): + return get_tokenizer(QWEN2AUDIO_MODEL_ID) + + @pytest.fixture(scope="function") def qwen25omni_model_config_mm_interleaved(): - return ModelConfig(QWEN25OMNI_MODEL_ID, - runner="generate", - 
interleave_mm_strings=True, - limit_mm_per_prompt={ - "image": 2, - "audio": 1, - "video": 1, - }) + return ModelConfig( + QWEN25OMNI_MODEL_ID, + runner="generate", + interleave_mm_strings=True, + limit_mm_per_prompt={ + "image": 2, + "audio": 1, + "video": 1, + }, + ) @pytest.fixture(scope="module") def qwen25omni_tokenizer(): - return TokenizerGroup( - tokenizer_id=QWEN25OMNI_MODEL_ID, - enable_lora=False, - max_num_seqs=5, - max_input_length=None, - ) - - -@pytest.fixture(scope="module") -def mllama_model_config(): - return ModelConfig(MLLAMA_MODEL_ID, - runner="generate", - limit_mm_per_prompt={ - "image": 2, - }) - - -@pytest.fixture(scope="module") -def mllama_tokenizer(): - return TokenizerGroup( - MLLAMA_MODEL_ID, - enable_lora=False, - max_num_seqs=5, - max_input_length=None, - ) + return get_tokenizer(QWEN25OMNI_MODEL_ID) @pytest.fixture(scope="function") def mistral_model_config(): - return ModelConfig(MISTRAL_MODEL_ID, - runner="generate", - limit_mm_per_prompt={ - "image": 2, - }) - - -@pytest.fixture(scope="module") -def mistral_tokenizer(): - return TokenizerGroup( - tokenizer_id=MISTRAL_MODEL_ID, - enable_lora=False, - max_num_seqs=5, - max_input_length=None, + return ModelConfig( + MISTRAL_MODEL_ID, + runner="generate", + limit_mm_per_prompt={ + "image": 2, + }, ) +@pytest.fixture(scope="module") +def mistral_tokenizer(): + return get_tokenizer(MISTRAL_MODEL_ID) + + @pytest.fixture(scope="module") def image_url(): - image = ImageAsset('cherry_blossom') + image = ImageAsset("cherry_blossom") base64 = encode_image_base64(image.pil_image) return f"data:image/jpeg;base64,{base64}" @pytest.fixture(scope="module") def video_url(): - video = VideoAsset('baby_reading', 1) + video = VideoAsset("baby_reading", 1) base64 = encode_video_base64(video.np_ndarrays) return f"data:video/jpeg;base64,{base64}" @pytest.fixture(scope="module") def audio_url(): - audio = AudioAsset('mary_had_lamb') + audio = AudioAsset("mary_had_lamb") base64 = encode_audio_base64(*audio.audio_and_sample_rate) return f"data:audio/ogg;base64,{base64}" @@ -159,6 +148,7 @@ def audio_url(): def _assert_mm_data_is_image_input( mm_data: Optional[MultiModalDataDict], image_count: int, + skipped_image_indices: Optional[list] = None, ) -> None: assert mm_data is not None assert set(mm_data.keys()) == {"image"} @@ -167,6 +157,30 @@ def _assert_mm_data_is_image_input( assert image_data is not None assert isinstance(image_data, list) and len(image_data) == image_count + if skipped_image_indices is not None: + for i in skipped_image_indices: + assert image_data[i] is None + + +def _assert_mm_uuids( + mm_uuids: Optional[MultiModalUUIDDict], + media_count: int, + expected_uuids: list[Optional[str]], + modality: str = "image", +) -> None: + if len(expected_uuids) > 0: + assert mm_uuids is not None + assert modality in mm_uuids + + image_uuids = mm_uuids.get(modality) + assert image_uuids is not None + + assert isinstance(image_uuids, + list) and len(image_uuids) == media_count + + assert image_uuids == expected_uuids + else: + assert mm_uuids is None ModalityType = Literal["image", "video", "audio"] @@ -174,8 +188,10 @@ MultiModalDataCounts = Mapping[ModalityType, int] def _assert_mm_data_inputs( - mm_data: Optional[MultiModalDataDict], - data_count: MultiModalDataCounts, + mm_data: Optional[MultiModalDataDict], + data_count: MultiModalDataCounts, + skipped_media_indices: Optional[dict[ + str, list]] = None, # modality -> list[int] ) -> None: assert mm_data is not None assert set(data_count.keys()) == (set(mm_data.keys())) @@ 
-185,25 +201,35 @@ def _assert_mm_data_inputs( assert modality_data is not None assert isinstance(modality_data, list) and len(modality_data) == n + if skipped_media_indices is not None: + skipped_media_indices_for_modality = skipped_media_indices.get( + modality) + assert skipped_media_indices_for_modality is not None + for i in skipped_media_indices_for_modality: + assert modality_data[i] is None + def test_parse_chat_messages_single_image( phi3v_model_config, phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( + conversation, mm_data, mm_uuids = parse_chat_messages( [{ "role": "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in the image?" - }] + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "What's in the image?" + }, + ], }], phi3v_model_config, phi3v_tokenizer, @@ -215,87 +241,281 @@ def test_parse_chat_messages_single_image( "content": "<|image_1|>\nWhat's in the image?" }] _assert_mm_data_is_image_input(mm_data, 1) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[None]) -def test_parse_chat_messages_empty_system( - mistral_model_config, - mistral_tokenizer, -): - # Test string format - conversation, _ = parse_chat_messages( - [{ - "role": "system", - "content": "" - }, { - "role": "user", - "content": [{ - "type": "text", - "text": "Who are you?" - }] - }], - mistral_model_config, - mistral_tokenizer, - content_format="string", - ) - assert conversation == [{ - "role": "system", - "content": "" - }, { - "role": "user", - "content": "Who are you?" - }] - - # Test openai format - conversation, _ = parse_chat_messages( - [{ - "role": "system", - "content": "" - }, { - "role": "user", - "content": [{ - "type": "text", - "text": "Who are you?" - }] - }], - mistral_model_config, - mistral_tokenizer, - content_format="openai", - ) - assert conversation == [{ - "role": "system", - "content": [{ - "type": "text", - "text": "" - }] - }, { - "role": - "user", - "content": [{ - "type": "text", - "text": "Who are you?" - }] - }] - - -@pytest.mark.asyncio -async def test_parse_chat_messages_single_image_async( +def test_parse_chat_messages_single_image_with_uuid( phi3v_model_config, phi3v_tokenizer, image_url, ): - conversation, mm_future = parse_chat_messages_futures( + image_uuid = str(hash(image_url)) + conversation, mm_data, mm_uuids = parse_chat_messages( [{ "role": "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in the image?" - }] + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + "uuid": image_uuid, + }, + { + "type": "text", + "text": "What's in the image?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": "user", + "content": "<|image_1|>\nWhat's in the image?" + }] + _assert_mm_data_is_image_input(mm_data, 1) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[image_uuid]) + + +def test_parse_chat_messages_single_empty_image_with_uuid( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid = str(hash(image_url)) + conversation, mm_data, mm_uuids = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": None, + "uuid": image_uuid, + }, + { + "type": "text", + "text": "What's in the image?" 
+ }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": "user", + "content": "<|image_1|>\nWhat's in the image?" + }] + _assert_mm_data_is_image_input(mm_data, 1, skipped_image_indices=[0]) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[image_uuid]) + + +def test_parse_chat_messages_single_image_with_bad_uuid_format( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid = str(hash(image_url)) + conversation, mm_data, mm_uuids = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url, + "uuid": image_uuid, + }, + "bad_uuid_key": image_uuid, + }, + { + "type": "text", + "text": "What's in the image?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": "user", + "content": "<|image_1|>\nWhat's in the image?" + }] + _assert_mm_data_is_image_input(mm_data, 1) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[None]) + + +def test_parse_chat_messages_multiple_images_with_uuids( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid1 = "my_uuid_1" + image_uuid2 = "my_uuid_2" + + conversation, mm_data, mm_uuids = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + "uuid": image_uuid1, + }, + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + "uuid": image_uuid2, + }, + { + "type": "text", + "text": "What's in the image?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": + "user", + "content": + "<|image_1|>\n<|image_2|>\nWhat's in the image?", + }] + _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2]) + + +def test_parse_chat_messages_multiple_empty_images_with_uuids( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid1 = "my_uuid_1" + image_uuid2 = "my_uuid_2" + + conversation, mm_data, mm_uuids = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": None, + "uuid": image_uuid1, + }, + { + "type": "image_url", + "image_url": None, + "uuid": image_uuid2, + }, + { + "type": "text", + "text": "What's in the image?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": + "user", + "content": + "<|image_1|>\n<|image_2|>\nWhat's in the image?", + }] + _assert_mm_data_is_image_input(mm_data, 2, skipped_image_indices=[0, 1]) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2]) + + +def test_parse_chat_messages_mixed_empty_images_with_uuids( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid1 = "my_uuid_1" + image_uuid2 = "my_uuid_2" + + conversation, mm_data, mm_uuids = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + "uuid": image_uuid1, + }, + { + "type": "image_url", + "image_url": None, + "uuid": image_uuid2, + }, + { + "type": "text", + "text": "What's in the image?" 
+ }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": + "user", + "content": + "<|image_1|>\n<|image_2|>\nWhat's in the image?", + }] + _assert_mm_data_is_image_input(mm_data, 2, skipped_image_indices=[1]) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2]) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_single_image_with_uuid_async( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid = str(hash(image_url)) + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": image_uuid, + }, + { + "type": "text", + "text": "What's in the image?" + }, + ], }], phi3v_model_config, phi3v_tokenizer, @@ -307,29 +527,78 @@ async def test_parse_chat_messages_single_image_async( "content": "<|image_1|>\nWhat's in the image?" }] _assert_mm_data_is_image_input(await mm_future, 1) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[image_uuid]) -def test_parse_chat_messages_multiple_images( +@pytest.mark.asyncio +async def test_parse_chat_messages_empty_image_with_uuid_async( phi3v_model_config, phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( + image_uuid = str(hash(image_url)) + conversation, mm_future, mm_uuids = parse_chat_messages_futures( [{ "role": "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_pil", - "image_pil": ImageAsset('cherry_blossom').pil_image - }, { - "type": "text", - "text": "What's in these images?" - }] + "content": [ + { + "type": "image_url", + "image_url": None, + "uuid": image_uuid, + }, + { + "type": "text", + "text": "What's in the image?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": "user", + "content": "<|image_1|>\nWhat's in the image?" + }] + _assert_mm_data_is_image_input(await mm_future, + 1, + skipped_image_indices=[0]) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[image_uuid]) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_multiple_images_with_uuids_async( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid1 = "my_uuid_1" + image_uuid2 = "my_uuid_2" + + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": image_uuid1, + }, + { + "type": "image_pil", + "image_pil": ImageAsset("cherry_blossom").pil_image, + "uuid": image_uuid2, + }, + { + "type": "text", + "text": "What's in these images?" + }, + ], }], phi3v_model_config, phi3v_tokenizer, @@ -340,9 +609,358 @@ def test_parse_chat_messages_multiple_images( "role": "user", "content": - "<|image_1|>\n<|image_2|>\nWhat's in these images?" 
+ "<|image_1|>\n<|image_2|>\nWhat's in these images?", + }] + _assert_mm_data_is_image_input(await mm_future, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2]) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_multiple_empty_images_with_uuids_async( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid1 = "my_uuid_1" + image_uuid2 = "my_uuid_2" + + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": None, + "uuid": image_uuid1, + }, + { + "type": "image_pil", + "image_pil": None, + "uuid": image_uuid2, + }, + { + "type": "text", + "text": "What's in these images?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": + "user", + "content": + "<|image_1|>\n<|image_2|>\nWhat's in these images?", + }] + _assert_mm_data_is_image_input(await mm_future, + 2, + skipped_image_indices=[0, 1]) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2]) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_multiple_images_with_partial_uuids_async( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid2 = "my_uuid_2" + + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + }, + }, + { + "type": "image_pil", + "image_pil": ImageAsset("cherry_blossom").pil_image, + "uuid": image_uuid2, + }, + { + "type": "text", + "text": "What's in these images?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": + "user", + "content": + "<|image_1|>\n<|image_2|>\nWhat's in these images?", + }] + _assert_mm_data_is_image_input(await mm_future, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, image_uuid2]) + + +def test_parse_chat_messages_empty_system( + mistral_model_config, + mistral_tokenizer, +): + # Test string format + conversation, _, _ = parse_chat_messages( + [ + { + "role": "system", + "content": "" + }, + { + "role": "user", + "content": [{ + "type": "text", + "text": "Who are you?" + }], + }, + ], + mistral_model_config, + mistral_tokenizer, + content_format="string", + ) + assert conversation == [ + { + "role": "system", + "content": "" + }, + { + "role": "user", + "content": "Who are you?" + }, + ] + + # Test openai format + conversation, _, _ = parse_chat_messages( + [ + { + "role": "system", + "content": "" + }, + { + "role": "user", + "content": [{ + "type": "text", + "text": "Who are you?" + }], + }, + ], + mistral_model_config, + mistral_tokenizer, + content_format="openai", + ) + assert conversation == [ + { + "role": "system", + "content": [{ + "type": "text", + "text": "" + }] + }, + { + "role": "user", + "content": [{ + "type": "text", + "text": "Who are you?" + }] + }, + ] + + +@pytest.mark.asyncio +async def test_parse_chat_messages_single_image_async( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "What's in the image?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": "user", + "content": "<|image_1|>\nWhat's in the image?" 
+ }] + _assert_mm_data_is_image_input(await mm_future, 1) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[None]) + + +def test_parse_chat_messages_multiple_images( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + conversation, mm_data, mm_uuids = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "image_pil", + "image_pil": ImageAsset("cherry_blossom").pil_image, + }, + { + "type": "text", + "text": "What's in these images?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": + "user", + "content": + "<|image_1|>\n<|image_2|>\nWhat's in these images?", }] _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) + + +def test_parse_chat_messages_empty_pil_image_with_uuid( + phi3v_model_config, + phi3v_tokenizer, +): + uuid = "abcd" + conversation, mm_data, mm_uuids = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "image_pil", + "image_pil": None, + "uuid": uuid + }, + { + "type": "text", + "text": "What's in this image?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": "user", + "content": "<|image_1|>\nWhat's in this image?", + }] + _assert_mm_data_is_image_input(mm_data, 1, skipped_image_indices=[0]) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid]) + + +def test_parse_chat_messages_empty_image_embeds_with_uuid( + phi3v_model_config, + phi3v_tokenizer, +): + uuid = "abcd" + conversation, mm_data, mm_uuids = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "image_embeds", + "image_embeds": None, + "uuid": uuid + }, + { + "type": "text", + "text": "What's in this image?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": "user", + "content": "<|image_1|>\nWhat's in this image?", + }] + assert mm_data is not None + assert "image" in mm_data + assert mm_data["image"] is None + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid]) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_empty_image_embeds_with_uuid_async( + phi3v_model_config, + phi3v_tokenizer, +): + uuid = "abcd" + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [{ + "role": + "user", + "content": [ + { + "type": "image_embeds", + "image_embeds": None, + "uuid": uuid + }, + { + "type": "text", + "text": "What's in this image?" + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": "user", + "content": "<|image_1|>\nWhat's in this image?", + }] + mm_data = await mm_future + assert mm_data is not None + assert "image" in mm_data + assert mm_data["image"] is None + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid]) @pytest.mark.asyncio @@ -351,22 +969,26 @@ async def test_parse_chat_messages_multiple_images_async( phi3v_tokenizer, image_url, ): - conversation, mm_future = parse_chat_messages_futures( + conversation, mm_future, mm_uuids = parse_chat_messages_futures( [{ "role": "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_pil", - "image_pil": ImageAsset('cherry_blossom').pil_image - }, { - "type": "text", - "text": "What's in these images?" 
- }] + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "image_pil", + "image_pil": ImageAsset("cherry_blossom").pil_image, + }, + { + "type": "text", + "text": "What's in these images?" + }, + ], }], phi3v_model_config, phi3v_tokenizer, @@ -377,9 +999,10 @@ async def test_parse_chat_messages_multiple_images_async( "role": "user", "content": - "<|image_1|>\n<|image_2|>\nWhat's in these images?" + "<|image_1|>\n<|image_2|>\nWhat's in these images?", }] _assert_mm_data_is_image_input(await mm_future, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) def test_parse_chat_messages_placeholder_already_in_prompt( @@ -387,46 +1010,7 @@ def test_parse_chat_messages_placeholder_already_in_prompt( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": - "text", - "text": - "What's in <|image_1|> and how does it compare to <|image_2|>?" - }] - }], - phi3v_model_config, - phi3v_tokenizer, - content_format="string", - ) - assert conversation == [{ - "role": - "user", - "content": - "What's in <|image_1|> and how does it compare to <|image_2|>?" - }] - _assert_mm_data_is_image_input(mm_data, 2) - - -def test_parse_chat_messages_placeholder_one_already_in_prompt( - phi3v_model_config, - phi3v_tokenizer, - image_url, -): - conversation, mm_data = parse_chat_messages( + conversation, mm_data, mm_uuids = parse_chat_messages( [{ "role": "user", @@ -447,9 +1031,53 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt( "type": "text", "text": - "What's in <|image_1|> and how does it compare to the other one?" # noqa: E501 - } - ] + "What's in <|image_1|> and how does it compare to <|image_2|>?", # noqa: E501 + }, + ], + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + assert conversation == [{ + "role": + "user", + "content": + "What's in <|image_1|> and how does it compare to <|image_2|>?", + }] + _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) + + +def test_parse_chat_messages_placeholder_one_already_in_prompt( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + conversation, mm_data, mm_uuids = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": + "text", + "text": + "What's in <|image_1|> and how does it compare to the other one?", # noqa: E501 + }, + ], }], phi3v_model_config, phi3v_tokenizer, @@ -461,9 +1089,10 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt( "user", "content": "<|image_2|>\nWhat's in <|image_1|> and how does it compare to the " - "other one?" + "other one?", }] _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) def test_parse_chat_messages_multiple_images_across_messages( @@ -471,35 +1100,45 @@ def test_parse_chat_messages_multiple_images_across_messages( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in this image?" - }] - }, { - "role": "assistant", - "content": "Some stuff." 
- }, { - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What about this one?" - }] - }], + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "What's in this image?" + }, + ], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "What about this one?" + }, + ], + }, + ], phi3v_model_config, phi3v_tokenizer, content_format="string", @@ -520,26 +1159,101 @@ def test_parse_chat_messages_multiple_images_across_messages( }, ] _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) + + +def test_parse_chat_messages_multiple_images_with_uuids_across_messages( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + image_uuid = str(hash(image_url)) + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": image_uuid, + }, + { + "type": "text", + "text": "What's in this image?" + }, + ], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": image_uuid, + }, + { + "type": "text", + "text": "What about this one?" + }, + ], + }, + ], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\nWhat's in this image?" + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": "user", + "content": "<|image_2|>\nWhat about this one?" + }, + ] + _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid, image_uuid]) def test_parse_chat_messages_context_text_format( phi3v_model_config, phi3v_tokenizer, ): - conversation, mm_data = parse_chat_messages( - [{ - "role": "user", - "content": [{ - "type": "text", - "text": "What's in this text?" - }] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": "user", - "content": "What about this one?" - }], + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [{ + "type": "text", + "text": "What's in this text?" + }], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": "user", + "content": "What about this one?" + }, + ], phi3v_model_config, phi3v_tokenizer, content_format="openai", @@ -551,23 +1265,25 @@ def test_parse_chat_messages_context_text_format( "content": [{ "type": "text", "text": "What's in this text?" - }] + }], }, { "role": "assistant", "content": [{ "type": "text", "text": "Some stuff." - }] + }], }, { "role": "user", "content": [{ "type": "text", "text": "What about this one?" 
- }] + }], }, ] + assert mm_data is None + assert mm_uuids is None def test_parse_chat_messages_rejects_too_many_images_in_one_message( @@ -578,31 +1294,37 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message( with warnings.catch_warnings(): warnings.filterwarnings( "ignore", - message="coroutine 'async_get_and_parse_image' was never awaited") + message="coroutine 'async_get_and_parse_image' was never awaited", + ) with pytest.raises(ValueError, match="At most"): parse_chat_messages( [{ "role": "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in these images?" - }] + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + }, + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + }, + { + "type": "text", + "text": "What's in these images?" + }, + ], }], phi3v_model_config, phi3v_tokenizer, @@ -618,42 +1340,54 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages( with warnings.catch_warnings(): warnings.filterwarnings( "ignore", - message="coroutine 'async_get_and_parse_image' was never awaited") + message="coroutine 'async_get_and_parse_image' was never awaited", + ) with pytest.raises(ValueError, match="At most"): parse_chat_messages( - [{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in this image?" - }] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What about these two?" - }] - }], + [ + { + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + }, + }, + { + "type": "text", + "text": "What's in this image?" + }, + ], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + }, + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + }, + { + "type": "text", + "text": "What about these two?" + }, + ], + }, + ], phi3v_model_config, phi3v_tokenizer, content_format="string", @@ -665,17 +1399,19 @@ def test_parse_chat_messages_multiple_images_uncommon_input( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( + conversation, mm_data, mm_uuids = parse_chat_messages( [{ "role": "user", "content": [ - "What's in these images?", { + "What's in these images?", + { "image_url": image_url - }, { + }, + { "image_url": image_url - } - ] + }, + ], }], phi3v_model_config, phi3v_tokenizer, @@ -686,9 +1422,10 @@ def test_parse_chat_messages_multiple_images_uncommon_input( "role": "user", "content": - "<|image_1|>\n<|image_2|>\nWhat's in these images?" 
+ "<|image_1|>\n<|image_2|>\nWhat's in these images?", }] _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) def test_parse_chat_messages_multiple_images_interleave( @@ -696,30 +1433,36 @@ def test_parse_chat_messages_multiple_images_interleave( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( + conversation, mm_data, mm_uuids = parse_chat_messages( [{ "role": "user", - "content": [{ - "type": "text", - "text": "I need you to compare this image" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "and this one" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "Do they have differences?" - }] + "content": [ + { + "type": "text", + "text": "I need you to compare this image", + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "and this one" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "Do they have differences?" + }, + ], }], phi3v_model_config_mm_interleaved, phi3v_tokenizer, @@ -731,9 +1474,10 @@ def test_parse_chat_messages_multiple_images_interleave( "user", "content": "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 - "Do they have differences?" + "Do they have differences?", }] _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) @pytest.mark.asyncio @@ -742,30 +1486,36 @@ async def test_parse_chat_messages_multiple_images_interleave_async( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages_futures( + conversation, mm_data, mm_uuids = parse_chat_messages_futures( [{ "role": "user", - "content": [{ - "type": "text", - "text": "I need you to compare this image" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "and this one" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "Do they have differences?" - }] + "content": [ + { + "type": "text", + "text": "I need you to compare this image", + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "and this one" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "Do they have differences?" + }, + ], }], phi3v_model_config_mm_interleaved, phi3v_tokenizer, @@ -777,51 +1527,51 @@ async def test_parse_chat_messages_multiple_images_interleave_async( "user", "content": "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 - "Do they have differences?" + "Do they have differences?", }] _assert_mm_data_is_image_input(await mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) -def test_parse_chat_messages_multiple_images_multiple_messages_interleave( +@pytest.mark.asyncio +async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async( phi3v_model_config_mm_interleaved, phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages( + image_uuid = str(hash(image_url)) + conversation, mm_data, mm_uuids = parse_chat_messages_futures( [{ "role": "user", "content": [ { "type": "text", - "text": "What's on this image?" 
+ "text": "I need you to compare this image", }, { "type": "image_url", "image_url": { "url": image_url - } + }, + "uuid": image_uuid, }, { "type": "text", - "text": "Be accurate." + "text": "and this one" }, - ] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": [{ - "type": "text", - "text": "What's on this image?" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }] + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": image_uuid, + }, + { + "type": "text", + "text": "Do they have differences?" + }, + ], }], phi3v_model_config_mm_interleaved, phi3v_tokenizer, @@ -832,93 +1582,586 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave( "role": "user", "content": - "What's on this image?\n<|image_1|>\nBe accurate." - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": "user", - "content": "What's on this image?\n<|image_2|>" + "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 + "Do they have differences?", }] + _assert_mm_data_is_image_input(await mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid, image_uuid]) + + +def test_parse_chat_messages_multiple_images_multiple_messages_interleave( + phi3v_model_config_mm_interleaved, + phi3v_tokenizer, + image_url, +): + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "Be accurate." + }, + ], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + ], + }, + ], + phi3v_model_config_mm_interleaved, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": "user", + "content": "What's on this image?\n<|image_1|>\nBe accurate.", + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": "user", + "content": "What's on this image?\n<|image_2|>" + }, + ] _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) + + +def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interleave( # noqa: E501 + phi3v_model_config_mm_interleaved, + phi3v_tokenizer, + image_url, +): + image_uuid = str(hash(image_url)) + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": image_uuid, + }, + { + "type": "text", + "text": "Be accurate." + }, + ], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": image_uuid, + }, + ], + }, + ], + phi3v_model_config_mm_interleaved, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": "user", + "content": "What's on this image?\n<|image_1|>\nBe accurate.", + }, + { + "role": "assistant", + "content": "Some stuff." 
+ }, + { + "role": "user", + "content": "What's on this image?\n<|image_2|>" + }, + ] + _assert_mm_data_is_image_input(mm_data, 2) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid, image_uuid]) def test_parse_chat_messages_multiple_modals_multiple_messages_interleave( - qwen25omni_model_config_mm_interleaved, qwen25omni_tokenizer, - image_url, video_url, audio_url): - conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "What's on this image?" - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "Now listen to this audio" - }, - { - "type": "audio_url", - "audio_url": { - "url": audio_url - } - }, - ] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": [{ - "type": "text", - "text": "What's on this image?" - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "And what's in the video?" - }, { - "type": "video_url", - "video_url": { - "url": video_url - } - }] - }], + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + image_url, + video_url, + audio_url, +): + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "Now listen to this audio" + }, + { + "type": "audio_url", + "audio_url": { + "url": audio_url + } + }, + ], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "And what's in the video?" + }, + { + "type": "video_url", + "video_url": { + "url": video_url + } + }, + ], + }, + ], qwen25omni_model_config_mm_interleaved, qwen25omni_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" - "Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>" - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": - "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" - "And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>" - }] + assert conversation == [ + { + "role": + "user", + "content": + "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" + "Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>", # noqa: E501 + }, + { + "role": "assistant", + "content": "Some stuff." 
+ }, + { + "role": + "user", + "content": + "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" + "And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>", + }, + ] _assert_mm_data_inputs(mm_data, {"image": 2, "video": 1, "audio": 1}) + _assert_mm_uuids(mm_uuids, + 2, + modality="image", + expected_uuids=[None, None]) + _assert_mm_uuids(mm_uuids, 1, modality="video", expected_uuids=[None]) + _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None]) + + +def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interleave( # noqa: E501 + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + image_url, + video_url, + audio_url, +): + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": "image_123", + }, + { + "type": "text", + "text": "Now listen to this audio" + }, + { + "type": "audio_url", + "audio_url": { + "url": audio_url + }, + "uuid": "audio_123", + }, + ], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": "image_123", + }, + { + "type": "text", + "text": "And what's in the video?" + }, + { + "type": "video_url", + "video_url": { + "url": video_url + }, + "uuid": "video_123", + }, + ], + }, + ], + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": + "user", + "content": + "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" + "Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>", # noqa: E501 + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": + "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" + "And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>", + }, + ] + + _assert_mm_data_inputs(mm_data, {"image": 2, "video": 1, "audio": 1}) + _assert_mm_uuids(mm_uuids, + 2, + modality="image", + expected_uuids=["image_123", "image_123"]) + _assert_mm_uuids(mm_uuids, + 1, + modality="video", + expected_uuids=["video_123"]) + _assert_mm_uuids(mm_uuids, + 1, + modality="audio", + expected_uuids=["audio_123"]) + + +def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_messages_interleave( # noqa: E501 + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + image_url, + video_url, + audio_url, +): + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": None, + "uuid": "image_123", + }, + { + "type": "text", + "text": "Now listen to this audio" + }, + { + "type": "audio_url", + "audio_url": None, + "uuid": "audio_123", + }, + ], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": None, + "uuid": "image_123", + }, + { + "type": "text", + "text": "And what's in the video?" 
+ }, + { + "type": "video_url", + "video_url": None, + "uuid": "video_123", + }, + ], + }, + ], + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": + "user", + "content": + "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" + "Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>", # noqa: E501 + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": + "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" + "And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>", + }, + ] + + _assert_mm_data_inputs(mm_data, { + "image": 2, + "video": 1, + "audio": 1 + }, + skipped_media_indices={ + "image": [0, 1], + "video": [0], + "audio": [0] + }) + _assert_mm_uuids(mm_uuids, + 2, + modality="image", + expected_uuids=["image_123", "image_123"]) + _assert_mm_uuids(mm_uuids, + 1, + modality="video", + expected_uuids=["video_123"]) + _assert_mm_uuids(mm_uuids, + 1, + modality="audio", + expected_uuids=["audio_123"]) + + +def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_messages_interleave( # noqa: E501 + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + image_url, + video_url, + audio_url, +): + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + "uuid": "image_123", + }, + { + "type": "text", + "text": "Now listen to this audio" + }, + { + "type": "audio_url", + "audio_url": { + "url": audio_url + } + }, + ], + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "And what's in the video?" + }, + { + "type": "video_url", + "video_url": { + "url": video_url + }, + "uuid": "video_123", + }, + ], + }, + ], + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + content_format="string", + ) + + assert conversation == [ + { + "role": + "user", + "content": + "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" + "Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>", # noqa: E501 + }, + { + "role": "assistant", + "content": "Some stuff." 
+ }, + { + "role": + "user", + "content": + "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" + "And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>", + }, + ] + + _assert_mm_data_inputs(mm_data, {"image": 2, "video": 1, "audio": 1}) + _assert_mm_uuids(mm_uuids, + 2, + modality="image", + expected_uuids=["image_123", None]) + _assert_mm_uuids(mm_uuids, + 1, + modality="video", + expected_uuids=["video_123"]) + _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None]) def test_parse_chat_messages_multiple_images_interleave_with_placeholders( @@ -929,7 +2172,8 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders( with pytest.raises( ValueError, match=r"Found more '<|image_1|>' placeholders in input prompt " - "than actual multimodal data items."): + "than actual multimodal data items.", + ): parse_chat_messages( [{ "role": @@ -952,9 +2196,9 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders( "text", "text": "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 - "Do they have differences?" + "Do they have differences?", }, - ] + ], }], phi3v_model_config_mm_interleaved, phi3v_tokenizer, @@ -962,171 +2206,13 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders( ) -### Mllama currently wraps images / texts as interleaved dictionaries -def test_mllama_single_image( - mllama_model_config, - mllama_tokenizer, - image_url, -): - """Ensures that a single image is parsed correctly mllama.""" - conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [{ - 'type': 'text', - 'text': 'The content of this image is:' - }, { - "image_url": image_url - }] - }], - mllama_model_config, - mllama_tokenizer, - content_format="openai", - ) - _assert_mm_data_is_image_input(mm_data, 1) - assert conversation == [{ - 'role': - 'user', - 'content': [{ - 'type': 'text', - 'text': 'The content of this image is:' - }, { - 'type': 'image' - }] - }] - - -def test_mllama_interleaved_images( - mllama_model_config, - mllama_tokenizer, - image_url, -): - """Ensures that multiple image are parsed as interleaved dicts.""" - conversation, mm_data = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - 'type': 'text', - 'text': 'The content of the first image is:' - }, - { - "image_url": image_url - }, - { - 'type': 'text', - 'text': 'The content of the second image is:' - }, - { - "image_url": image_url - }, - ] - }], - mllama_model_config, - mllama_tokenizer, - content_format="openai", - ) - _assert_mm_data_is_image_input(mm_data, 2) - assert conversation == [{ - 'role': - 'user', - 'content': [{ - 'type': 'text', - 'text': 'The content of the first image is:' - }, { - 'type': 'image' - }, { - 'type': 'text', - 'text': 'The content of the second image is:' - }, { - 'type': 'image' - }] - }] - - -@pytest.mark.parametrize("model", [MLLAMA_MODEL_ID]) -def test_multimodal_image_parsing_matches_hf(model, image_url): - """Checks end to end hf alignment for multimodal [image] parsing.""" - - def get_conversation(is_hf: bool): - img_part = {"type": "image_url", "image_url": {"url": image_url}} - if is_hf: - img_part = {'type': 'image'} - return [{ - 'role': - 'user', - 'content': [ - { - 'type': 'text', - 'text': 'The content of the first image is:' - }, - img_part, - { - 'type': 'text', - 'text': 'The content of the second image is:' - }, - img_part, - { - 'type': 'text', - 'text': 'What animal is in the first image?' 
- }, - ] - }] - - # Build a config for the model - model_config = ModelConfig(model, - runner="generate", - limit_mm_per_prompt={ - "image": 2, - }) - - # Build the tokenizer group and grab the underlying tokenizer - tokenizer_group = TokenizerGroup( - model, - enable_lora=False, - max_num_seqs=5, - max_input_length=None, - trust_remote_code=model_config.trust_remote_code, - ) - tokenizer = tokenizer_group.tokenizer - - # Build and parse a conversation with {"type": "image"} using the tokenizer - hf_conversation = get_conversation(is_hf=True) - hf_result = tokenizer.apply_chat_template( - hf_conversation, - tokenize=False, - add_generation_prompt=True, - ) - - # Now parse with vLLMs chat utils & apply the template - vllm_conversation = get_conversation(is_hf=False) - conversation, _ = parse_chat_messages( - vllm_conversation, - model_config, - tokenizer_group, - content_format="openai", - ) - - vllm_result = apply_hf_chat_template( - tokenizer=tokenizer, - conversation=conversation, - chat_template=None, - model_config=model_config, - tools=None, - add_generation_prompt=True, - ) - - assert hf_result == vllm_result - - @pytest.mark.parametrize( "model", [ QWEN2VL_MODEL_ID, # tokenizer.chat_template is of type str HERMES_MODEL_ID, # tokenizer.chat_template is of type dict - ]) + ], +) @pytest.mark.parametrize("use_tools", [True, False]) def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): """checks that chat_template is a dict type for HF models.""" @@ -1140,26 +2226,24 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, hf_overrides=model_info.hf_overrides, - ) + skip_tokenizer_init=model_info.skip_tokenizer_init, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype) - # Build the tokenizer group and grab the underlying tokenizer - tokenizer_group = TokenizerGroup( + # Build the tokenizer + tokenizer = get_tokenizer( model, - enable_lora=False, - max_num_seqs=5, - max_input_length=None, trust_remote_code=model_config.trust_remote_code, ) - tokenizer = tokenizer_group.tokenizer - tools = [{ + tools = ([{ "type": "function", "function": { "name": "dummy_function_name", "description": "This is a dummy function", - "parameters": sample_json_schema - } - }] if use_tools else None + "parameters": sample_json_schema, + }, + }] if use_tools else None) # Test detecting the tokenizer's chat_template chat_template = resolve_hf_chat_template( @@ -1181,7 +2265,6 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): (QWEN25VL_MODEL_ID, "openai"), (ULTRAVOX_MODEL_ID, "string"), (QWEN2AUDIO_MODEL_ID, "openai"), - (MLLAMA_MODEL_ID, "openai"), (LLAMA_GUARD_MODEL_ID, "openai")], ) # yapf: enable @@ -1196,16 +2279,14 @@ def test_resolve_content_format_hf_defined(model, expected_format): revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, hf_overrides=model_info.hf_overrides, - ) + skip_tokenizer_init=model_info.skip_tokenizer_init, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype) - tokenizer_group = TokenizerGroup( + tokenizer = get_tokenizer( model, - enable_lora=False, - max_num_seqs=5, - max_input_length=None, trust_remote_code=model_config.trust_remote_code, ) - tokenizer = tokenizer_group.tokenizer # Test detecting the tokenizer's chat_template chat_template = resolve_hf_chat_template( @@ -1238,7 +2319,6 @@ def test_resolve_content_format_hf_defined(model, expected_format): [("Salesforce/blip2-opt-2.7b", 
"string"), ("facebook/chameleon-7b", "string"), ("deepseek-ai/deepseek-vl2-tiny", "string"), - ("microsoft/Florence-2-base", "string"), ("adept/fuyu-8b", "string"), ("google/paligemma-3b-mix-224", "string"), ("Qwen/Qwen-VL", "string"), @@ -1256,16 +2336,14 @@ def test_resolve_content_format_fallbacks(model, expected_format): revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, hf_overrides=model_info.hf_overrides, - ) + skip_tokenizer_init=model_info.skip_tokenizer_init, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype) - tokenizer_group = TokenizerGroup( + tokenizer = get_tokenizer( model_config.tokenizer, - enable_lora=False, - max_num_seqs=5, - max_input_length=None, trust_remote_code=model_config.trust_remote_code, ) - tokenizer = tokenizer_group.tokenizer # Test detecting the tokenizer's chat_template chat_template = resolve_hf_chat_template( @@ -1322,14 +2400,10 @@ def test_resolve_content_format_examples(template_path, expected_format): trust_remote_code=True, ) - tokenizer_group = TokenizerGroup( + dummy_tokenizer = get_tokenizer( PHI3V_MODEL_ID, # Dummy - enable_lora=False, - max_num_seqs=5, - max_input_length=None, trust_remote_code=model_config.trust_remote_code, ) - dummy_tokenizer = tokenizer_group.tokenizer dummy_tokenizer.chat_template = None chat_template = load_chat_template(EXAMPLES_DIR / template_path) @@ -1386,7 +2460,7 @@ def test_parse_chat_messages_include_thinking_chunk(mistral_model_config, }], }] - conversation_with_thinking, _ = parse_chat_messages( + conversation_with_thinking, _, _ = parse_chat_messages( messages, mistral_model_config, mistral_tokenizer, @@ -1511,3 +2585,82 @@ def test_apply_mistral_chat_template_thinking_chunk(): r"[INST]Thanks, what is 3+3?[/INST]") assert string_tokens == expected_tokens + + +def test_parse_chat_messages_single_empty_audio_with_uuid( + qwen2_audio_model_config, + qwen2_audio_tokenizer, +): + audio_uuid = "abcd" + conversation, mm_data, mm_uuids = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "input_audio", + "input_audio": {}, + "uuid": audio_uuid, + }, + { + "type": "text", + "text": "What does the audio say?" + }, + ], + }], + qwen2_audio_model_config, + qwen2_audio_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": + "user", + "content": + "Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat does the audio say?" + }] + _assert_mm_data_inputs(mm_data, {"audio": 1}) + _assert_mm_uuids(mm_uuids, + 1, + modality="audio", + expected_uuids=[audio_uuid]) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_single_empty_audio_with_uuid_async( + qwen2_audio_model_config, + qwen2_audio_tokenizer, +): + audio_uuid = "abcd" + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [{ + "role": + "user", + "content": [ + { + "type": "input_audio", + "input_audio": {}, + "uuid": audio_uuid, + }, + { + "type": "text", + "text": "What does the audio say?" + }, + ], + }], + qwen2_audio_model_config, + qwen2_audio_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": + "user", + "content": + "Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat does the audio say?" 
+ }] + _assert_mm_data_inputs(await mm_future, {"audio": 1}) + _assert_mm_uuids(mm_uuids, + 1, + modality="audio", + expected_uuids=[audio_uuid]) diff --git a/tests/entrypoints/test_context.py b/tests/entrypoints/test_context.py new file mode 100644 index 0000000000000..2afe9758c2adf --- /dev/null +++ b/tests/entrypoints/test_context.py @@ -0,0 +1,500 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from unittest.mock import MagicMock, patch + +import pytest +from openai_harmony import Author, Message, Role, StreamState, TextContent + +from vllm.entrypoints.context import HarmonyContext, StreamingHarmonyContext +from vllm.outputs import CompletionOutput, RequestOutput + + +# Helper function for Python < 3.10 compatibility +async def async_next(async_iterator): + """Compatibility function equivalent to Python 3.10's anext().""" + return await async_iterator.__anext__() + + +def create_mock_request_output( + prompt_token_ids=None, + output_token_ids=None, + num_cached_tokens=0, + finished=True, +): + """Helper function to create a mock RequestOutput object for testing.""" + outputs = [] + token_ids = output_token_ids if output_token_ids is not None else [] + outputs = [ + CompletionOutput( + index=0, + text="Test output", + token_ids=token_ids, + cumulative_logprob=0.0, + logprobs=None, + finish_reason=None, + stop_reason=None, + ) + ] + + return RequestOutput( + request_id="test-id", + prompt="Test prompt", + prompt_token_ids=prompt_token_ids, + prompt_logprobs=None, + outputs=outputs, + finished=finished, + num_cached_tokens=num_cached_tokens, + ) + + +async def generate_mock_outputs(num_turns, + prompt_token_counts, + output_token_counts, + cached_token_counts=None): + """Generate a sequence of mock RequestOutput objects to simulate multiple + turns.""" + if cached_token_counts is None: + cached_token_counts = [0] * num_turns + + for i in range(num_turns): + # Create mock prompt token IDs and output token IDs + prompt_token_ids = list(range(1, prompt_token_counts[i] + 1)) + output_token_ids = list(range(1, output_token_counts[i] + 1)) + + # Create and yield the RequestOutput + yield create_mock_request_output( + prompt_token_ids=prompt_token_ids, + output_token_ids=output_token_ids, + num_cached_tokens=cached_token_counts[i], + ) + + +@pytest.fixture +def mock_parser(): + """Set up a mock parser for tests.""" + with patch("vllm.entrypoints.context.get_streamable_parser_for_assistant" + ) as mock_parser_factory: + # Create a mock parser object + parser = MagicMock() + parser.messages = [] + parser.current_channel = None + parser.state = StreamState.EXPECT_START + mock_parser_factory.return_value = parser + yield parser + + +def test_single_turn_token_counting(): + """Test token counting behavior for a single turn.""" + # Create a context + context = HarmonyContext(messages=[], available_tools=[]) + + # Create a mock RequestOutput with specific token counts + mock_output = create_mock_request_output( + prompt_token_ids=[1, 2, 3, 4, 5], # 5 prompt tokens + output_token_ids=[6, 7, 8], # 3 output tokens + num_cached_tokens=2, # 2 cached tokens + ) + + # Append the output to the context + context.append_output(mock_output) + + # Verify the token counts + assert context.num_prompt_tokens == 5 + assert context.num_output_tokens == 3 + assert context.num_cached_tokens == 2 + assert context.num_tool_output_tokens == 0 # No tool tokens in first turn + + # Verify internal state tracking + assert not context.is_first_turn + assert 
context.previous_turn.input_tokens == 5
+    assert context.previous_turn.output_tokens == 3
+
+
+@pytest.mark.asyncio
+async def test_multi_turn_token_counting():
+    """Test token counting behavior across multiple turns with tool output."""
+    # Create a context
+    context = HarmonyContext(messages=[], available_tools=["browser"])
+
+    # Simulate a conversation with 3 turns
+    # Turn 1: prefill 5, decode 3, tool 7
+    # Turn 2: prefill 15, cached 5, decode 4, tool 1
+    # Turn 3: prefill 20, cached 15, decode 5
+    prompt_token_counts = [5, 15, 20]
+    output_token_counts = [3, 4, 5]
+    cached_token_counts = [0, 5, 15]
+    mock_generator = generate_mock_outputs(3, prompt_token_counts,
+                                           output_token_counts,
+                                           cached_token_counts)
+
+    # First turn - initial prompt and response
+    mock_output1 = await async_next(mock_generator)
+    context.append_output(mock_output1)
+
+    # At this point, we should have 5 prompt tokens and 3 output tokens
+    assert context.num_prompt_tokens == 5
+    assert context.num_output_tokens == 3
+    assert context.num_tool_output_tokens == 0
+
+    # Second turn - after tool output
+    mock_output2 = await async_next(mock_generator)
+    context.append_output(mock_output2)
+    # Current prompt tokens (15) - last_turn_input_tokens (5) -
+    # last_turn_output_tokens (3) = 7
+    expected_tool_output = 7
+
+    assert context.num_prompt_tokens == 5 + 15
+    assert context.num_output_tokens == 3 + 4
+    assert context.num_tool_output_tokens == expected_tool_output
+    assert context.num_cached_tokens == 5
+
+    # Third turn - final response
+    mock_output3 = await async_next(mock_generator)
+    context.append_output(mock_output3)
+    # Additional tool output tokens from third turn:
+    # Current prompt (20) - last_turn_input_tokens (15) -
+    # last_turn_output_tokens (4) = 1
+    expected_tool_output = 7 + 1
+
+    assert context.num_prompt_tokens == 5 + 15 + 20
+    assert context.num_output_tokens == 3 + 4 + 5
+    assert context.num_tool_output_tokens == expected_tool_output
+    assert context.num_cached_tokens == 5 + 15
+
+
+def test_empty_output_tokens():
+    """Test behavior when RequestOutput has empty output tokens."""
+    context = HarmonyContext(messages=[], available_tools=[])
+
+    # Create a RequestOutput with empty output tokens
+    mock_output = create_mock_request_output(
+        prompt_token_ids=[1, 2, 3],  # 3 prompt tokens
+        output_token_ids=[],  # Empty output tokens list
+        num_cached_tokens=1,
+    )
+
+    context.append_output(mock_output)
+
+    # Should handle empty outputs gracefully
+    assert context.num_prompt_tokens == 3
+    assert context.num_output_tokens == 0  # No output tokens
+    assert context.num_cached_tokens == 1
+    assert context.num_tool_output_tokens == 0
+
+
+def test_missing_prompt_token_ids():
+    """Test behavior when RequestOutput has None prompt_token_ids."""
+    context = HarmonyContext(messages=[], available_tools=[])
+
+    mock_output = create_mock_request_output(
+        prompt_token_ids=None,  # No prompt token IDs
+        output_token_ids=[1, 2],  # 2 output tokens
+        num_cached_tokens=0,
+    )
+
+    # Logger.error will be called, but we don't need to check for warnings
+    # here. Just ensure it doesn't raise an exception.
+    context.append_output(mock_output)
+
+    # Should handle missing prompt tokens gracefully
+    assert context.num_prompt_tokens == 0
+    assert context.num_output_tokens == 2
+    assert context.num_cached_tokens == 0
+    assert context.num_tool_output_tokens == 0
+
+
+def test_reasoning_tokens_counting(mock_parser):
+    """Test that reasoning tokens are counted correctly."""
+    context = HarmonyContext(messages=[], available_tools=[])
+
+    # Mock parser to simulate reasoning channel
Mock parser to simulate reasoning channel + mock_parser.current_channel = "analysis" # Reasoning channel + + mock_output = create_mock_request_output( + prompt_token_ids=[1, 2, 3], + output_token_ids=[4, 5, 6, 7], # 4 tokens, all in reasoning + num_cached_tokens=0, + ) + + context.append_output(mock_output) + + # All output tokens should be counted as reasoning + assert context.num_reasoning_tokens == 4 + assert context.num_output_tokens == 4 + + +def test_zero_tokens_edge_case(): + """Test behavior with all zero token counts.""" + context = HarmonyContext(messages=[], available_tools=[]) + + # Create a request with empty lists (not None) for both prompt and + # output tokens + mock_output = create_mock_request_output( + prompt_token_ids=[], # Empty prompt tokens + output_token_ids=[], # Empty output tokens + num_cached_tokens=0, + ) + + context.append_output(mock_output) + + # All counts should be zero + assert context.num_prompt_tokens == 0 + assert context.num_output_tokens == 0 + assert context.num_cached_tokens == 0 + assert context.num_tool_output_tokens == 0 + assert context.num_reasoning_tokens == 0 + + +@pytest.mark.asyncio +async def test_single_turn_no_tool_output(): + """Test that first turn never generates tool output tokens.""" + context = HarmonyContext( + messages=[], + available_tools=["browser"] # Tools available + ) + + # Even with large prompt in first turn, no tool tokens should be counted + mock_output = create_mock_request_output( + prompt_token_ids=list(range(100)), # 100 tokens + output_token_ids=[1, 2, 3], + num_cached_tokens=0, + ) + + context.append_output(mock_output) + + # First turn should never have tool output tokens + assert context.num_tool_output_tokens == 0 + assert context.is_first_turn is False # Should be updated after first turn + + +@pytest.mark.asyncio +async def test_negative_tool_tokens_edge_case(): + """Test edge case where calculation could result in negative tool + tokens. We should log an error and clamp the value to 0.""" + # Use patch to check if logger.error was called + with patch("vllm.entrypoints.context.logger.error") as mock_log: + context = HarmonyContext(messages=[], available_tools=["browser"]) + + # First turn + mock_output1 = create_mock_request_output( + prompt_token_ids=list(range(10)), # 10 tokens + output_token_ids=[1, 2, 3, 4, 5], # 5 tokens + ) + context.append_output(mock_output1) + + # Second turn with fewer new tokens than previous output + # This could happen in edge cases with aggressive caching + mock_output2 = create_mock_request_output( + prompt_token_ids=list(range(12)), # 12 tokens (only 2 new) + output_token_ids=[6, 7], # 2 tokens + ) + context.append_output(mock_output2) + + # Calculated negative tool tokens (12 - 10 - 5 = -3) should be clamped + # to 0 and an error should be logged + assert context.num_tool_output_tokens == 0 + assert context.num_prompt_tokens == 10 + 12 + assert context.num_output_tokens == 5 + 2 + + # Verify the error was logged properly + mock_log.assert_called_once() + + # Extract the actual log message and arguments from the call + args, _ = mock_log.call_args + log_message = args[0] + + # Check for key parts of the message + assert "Negative tool output tokens calculated" in log_message + assert "-3" in str(args) # Check that -3 is in the arguments + + +@pytest.mark.asyncio +async def test_streaming_multi_turn_token_counting(mock_parser): + """Test token counting for streaming multi-turn conversations. 
+ + This test focuses on how StreamingHarmonyContext counts tokens in a + multi-turn conversation with streaming (token-by-token) outputs and + message boundaries. + """ + # Create a streaming context + context = StreamingHarmonyContext(messages=[], available_tools=["browser"]) + + # Simulate three turns of conversation: + # Turn 1: stream tokens one by one, then finish the message + # Turn 2: new prompt, stream more tokens with a reasoning segment + # Turn 3: new prompt with tool output and cached tokens + + # First turn: 3 tokens streamed one by one + # First token of first turn + context.append_output( + create_mock_request_output( + prompt_token_ids=[1, 2, 3], # 3 prompt tokens + output_token_ids=[101], # Single token + num_cached_tokens=0, + finished=False, # Not end of message yet + )) + + # Second token of first turn + context.append_output( + create_mock_request_output( + output_token_ids=[102], + finished=False, + )) + + # Last token of first turn (finished=True signals end of message) + context.append_output( + create_mock_request_output( + output_token_ids=[103], + finished=True, # End of message + )) + + # Check token counts after first turn + assert context.num_prompt_tokens == 3 # Initial prompt tokens + assert context.num_output_tokens == 3 # Three output tokens + assert context.num_cached_tokens == 0 + assert context.num_tool_output_tokens == 0 # No tool output in first turn + assert context.first_tok_of_message is True # Ready for next message + + # Second turn: reasoning tokens in analysis channel + mock_parser.current_channel = "analysis" # Set to reasoning channel + + # First token of second turn + context.append_output( + create_mock_request_output( + prompt_token_ids=[1, 2, 3, 101, 102, 103, 4, + 5], # 8 tokens (includes previous) + output_token_ids=[201], + num_cached_tokens=3, # Some tokens cached + finished=False, + )) + + # More tokens in reasoning channel + context.append_output( + create_mock_request_output( + output_token_ids=[202], + finished=False, + )) + + context.append_output( + create_mock_request_output( + output_token_ids=[203], + finished=True, # End of reasoning message + )) + + # Check counts after second turn (reasoning message) + assert context.num_prompt_tokens == 3 + 8 # Initial + second prompt + assert context.num_output_tokens == 3 + 3 # First turn + second turn + assert context.num_reasoning_tokens == 3 # All tokens in analysis channel + assert context.num_cached_tokens == 3 # Cached tokens from second turn + + # Formula: this turn prompt tokens - last turn prompt - last turn output + expected_tool_tokens = 8 - 3 - 3 # = 2 + assert context.num_tool_output_tokens == expected_tool_tokens + + # Third turn: regular output channel + mock_parser.current_channel = "final" # Switch back to regular channel + + # Third turn (with more cached tokens) + context.append_output( + create_mock_request_output( + prompt_token_ids=[ + 1, 2, 3, 101, 102, 103, 4, 5, 201, 202, 203, 6, 7 + ], # 13 tokens + output_token_ids=[301], + num_cached_tokens=8, # More cached tokens + finished=False, + )) + + context.append_output( + create_mock_request_output( + output_token_ids=[302], + finished=True, + )) + + # Final token counts check + assert context.num_prompt_tokens == 3 + 8 + 13 # All prompts + assert context.num_output_tokens == 3 + 3 + 2 # All outputs + assert context.num_reasoning_tokens == 3 # Unchanged from second turn + assert context.num_cached_tokens == 3 + 8 # Accumulated cached tokens + + # Additional tool tokens from third turn + # Formula: this turn prompt 
- last turn prompt - last turn output + additional_tool_tokens = 13 - 8 - 3 # = 2 + assert context.num_tool_output_tokens == expected_tool_tokens \ + + additional_tool_tokens + + +@pytest.mark.asyncio +async def test_streaming_message_synchronization(mock_parser): + """Test message synchronization logic from lines 413-417 in context.py. + + This test verifies that when parser.messages contains more messages than + the context's _messages (minus initial messages), the context properly + extends its message list with the new parser messages. + """ + + # Create a streaming context with some initial messages + initial_messages = [ + Message( + author=Author(role=Role.USER, name="user"), + content=[TextContent(text="Hello")], + recipient=Role.ASSISTANT, + ) + ] + context = StreamingHarmonyContext(messages=initial_messages, + available_tools=[]) + + # Verify initial state + assert len(context._messages) == 1 + assert context.num_init_messages == 1 + + # Mock parser to have more messages than context + # Simulate the parser having processed one new message + mock_parser.messages = [ + Message( + author=Author(role=Role.ASSISTANT, name="assistant"), + content=[TextContent(text="Response 1")], + recipient=Role.USER, + ), + ] + + # This should trigger the message synchronization logic + context.append_output( + create_mock_request_output(prompt_token_ids=[1, 2, 3], + output_token_ids=[101], + finished=False)) + + # Verify that messages were synchronized + assert len(context._messages) == 2 + + # Verify the new message was added correctly + assert context._messages[1].content[0].text == "Response 1" + + # Test the specific condition from line 413-414: + # len(self._messages) - self.num_init_messages < len(self.parser.messages) + messages_minus_init = len(context._messages) - context.num_init_messages + parser_messages_count = len(mock_parser.messages) + + # After synchronization, they should be equal (no longer less than) + assert messages_minus_init == parser_messages_count + + # Test edge case: add one more parser message + mock_parser.messages.append( + Message( + author=Author(role=Role.ASSISTANT, name="assistant"), + content=[TextContent(text="Response 4")], + recipient=Role.USER, + )) + + # Create another output to trigger synchronization again + mock_output2 = create_mock_request_output(prompt_token_ids=[1, 2, 3], + output_token_ids=[102], + finished=True) + + context.append_output(mock_output2) + + # Verify the appended message was added; num_init_messages is still 1 + assert len(context._messages) == 3 + assert context.num_init_messages == 1 + assert context._messages[2].content[0].text == "Response 4" diff --git a/tests/entrypoints/test_renderer.py b/tests/entrypoints/test_renderer.py new file mode 100644 index 0000000000000..1f55b1fba613b --- /dev/null +++ b/tests/entrypoints/test_renderer.py @@ -0,0 +1,333 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import io +from dataclasses import dataclass +from typing import Optional +from unittest.mock import AsyncMock, MagicMock + +import pybase64 +import pytest +import torch + +from vllm.entrypoints.renderer import CompletionRenderer, RenderConfig +from vllm.inputs.data import is_embeds_prompt + + +@dataclass +class MockModelConfig: + max_model_len: int = 100 + encoder_config: Optional[dict] = None + + +class MockTokenizerResult: + + def __init__(self, input_ids): + self.input_ids = input_ids + + +@pytest.fixture +def mock_model_config(): + return MockModelConfig() + + 
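+# Note: the tokenizer-related fixtures below stand in for a real tokenizer. +# The `renderer` fixture is built with an empty async_tokenizer_pool, and +# tests that exercise text inputs register `mock_async_tokenizer` under +# `renderer.tokenizer`, so no real tokenization happens in these tests. + 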
+@pytest.fixture +def mock_tokenizer(): + tokenizer = MagicMock() + return tokenizer + + +@pytest.fixture +def mock_async_tokenizer(): + async_tokenizer = AsyncMock() + return async_tokenizer + + +@pytest.fixture +def renderer(mock_model_config, mock_tokenizer): + return CompletionRenderer(model_config=mock_model_config, + tokenizer=mock_tokenizer, + async_tokenizer_pool={}) + + +class TestRenderPrompt: + """Test Category A: Basic Functionality Tests""" + + @pytest.mark.asyncio + async def test_token_input(self, renderer): + tokens = [101, 7592, 2088] + results = await renderer.render_prompt( + prompt_or_prompts=tokens, config=RenderConfig(max_length=100)) + + assert len(results) == 1 + assert results[0]["prompt_token_ids"] == tokens + + @pytest.mark.asyncio + async def test_token_list_input(self, renderer): + token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]] + results = await renderer.render_prompt( + prompt_or_prompts=token_lists, config=RenderConfig(max_length=100)) + + assert len(results) == 3 + assert results[0]["prompt_token_ids"] == [101, 7592, 2088] + assert results[1]["prompt_token_ids"] == [102, 1234, 5678, 9012] + assert results[2]["prompt_token_ids"] == [103, 4567] + + @pytest.mark.asyncio + async def test_text_input(self, renderer, mock_async_tokenizer): + mock_async_tokenizer.return_value = MockTokenizerResult( + [101, 7592, 2088]) + renderer.async_tokenizer_pool[ + renderer.tokenizer] = mock_async_tokenizer + + results = await renderer.render_prompt( + prompt_or_prompts="Hello world", + config=RenderConfig(max_length=100)) + + assert len(results) == 1 + assert results[0]["prompt_token_ids"] == [101, 7592, 2088] + mock_async_tokenizer.assert_called_once() + + @pytest.mark.asyncio + async def test_text_list_input(self, renderer, mock_async_tokenizer): + mock_async_tokenizer.return_value = MockTokenizerResult( + [101, 7592, 2088]) + renderer.async_tokenizer_pool[ + renderer.tokenizer] = mock_async_tokenizer + + text_list_input = ["Hello world", "How are you?", "Good morning"] + results = await renderer.render_prompt( + prompt_or_prompts=text_list_input, + config=RenderConfig(max_length=100)) + + assert len(results) == 3 + for result in results: + assert result["prompt_token_ids"] == [101, 7592, 2088] + assert mock_async_tokenizer.call_count == 3 + + @pytest.mark.asyncio + async def test_no_truncation(self, renderer, mock_async_tokenizer): + mock_async_tokenizer.return_value = MockTokenizerResult( + [101, 7592, 2088]) + renderer.async_tokenizer_pool[ + renderer.tokenizer] = mock_async_tokenizer + + results = await renderer.render_prompt( + prompt_or_prompts="Hello world", + config=RenderConfig(max_length=100)) + + assert len(results) == 1 + call_args = mock_async_tokenizer.call_args + assert "truncation" not in call_args.kwargs or call_args.kwargs[ + "truncation"] is False + + @pytest.mark.asyncio + async def test_truncation_positive(self, renderer, mock_async_tokenizer): + mock_async_tokenizer.return_value = MockTokenizerResult( + [101, 7592, 2088]) # Truncated + renderer.async_tokenizer_pool[ + renderer.tokenizer] = mock_async_tokenizer + + results = await renderer.render_prompt(prompt_or_prompts="Hello world", + config=RenderConfig( + max_length=100, + truncate_prompt_tokens=50)) + + assert len(results) == 1 + call_args = mock_async_tokenizer.call_args + assert call_args.kwargs["truncation"] is True + assert call_args.kwargs["max_length"] == 50 + + @pytest.mark.asyncio + async def test_truncation_negative(self, renderer, mock_async_tokenizer): + # Test 
that negative truncation uses model's max_model_len + mock_async_tokenizer.return_value = MockTokenizerResult( + [101, 7592, 2088]) # Truncated to max_model_len + renderer.async_tokenizer_pool[ + renderer.tokenizer] = mock_async_tokenizer + + results = await renderer.render_prompt(prompt_or_prompts="Hello world", + config=RenderConfig( + max_length=200, + truncate_prompt_tokens=-1)) + + assert len(results) == 1 + call_args = mock_async_tokenizer.call_args + assert call_args.kwargs["truncation"] is True + assert call_args.kwargs["max_length"] == 100 # model's max_model_len + + @pytest.mark.asyncio + async def test_token_truncation_last_elements(self, renderer): + # Test that token truncation keeps the last N elements + long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, + 109] # 10 tokens + results = await renderer.render_prompt(prompt_or_prompts=long_tokens, + config=RenderConfig( + max_length=100, + truncate_prompt_tokens=5)) + + assert len(results) == 1 + # Should keep the last 5 tokens: [105, 106, 107, 108, 109] + assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109] + + @pytest.mark.asyncio + async def test_max_length_exceeded(self, renderer): + long_tokens = list(range(150)) # Exceeds max_model_len=100 + + with pytest.raises(ValueError, match="maximum context length"): + await renderer.render_prompt(prompt_or_prompts=long_tokens, + config=RenderConfig(max_length=100)) + + @pytest.mark.asyncio + async def test_no_tokenizer_for_text(self, mock_model_config): + renderer_no_tokenizer = CompletionRenderer( + model_config=mock_model_config, + tokenizer=None, + async_tokenizer_pool={}) + + with pytest.raises(ValueError, match="No tokenizer available"): + await renderer_no_tokenizer.render_prompt( + prompt_or_prompts="Hello world", + config=RenderConfig(max_length=100)) + + @pytest.mark.asyncio + async def test_token_input_with_needs_detokenization( + self, renderer, mock_async_tokenizer): + # When needs_detokenization=True for token inputs, renderer should + # use the async tokenizer to decode and include the original text + # in the returned prompt object. 
+ mock_async_tokenizer.decode = AsyncMock(return_value="decoded text") + renderer.async_tokenizer_pool[ + renderer.tokenizer] = mock_async_tokenizer + + tokens = [1, 2, 3, 4] + results = await renderer.render_prompt( + prompt_or_prompts=tokens, + config=RenderConfig(needs_detokenization=True), + ) + + assert len(results) == 1 + assert results[0]["prompt_token_ids"] == tokens + assert results[0]["prompt"] == "decoded text" + mock_async_tokenizer.decode.assert_awaited_once() + + +class TestRenderEmbedPrompt: + + def _create_test_embed_bytes(self, tensor: torch.Tensor) -> bytes: + """Helper to create base64-encoded tensor bytes""" + buffer = io.BytesIO() + torch.save(tensor, buffer) + buffer.seek(0) + return pybase64.b64encode(buffer.read()) + + @pytest.mark.asyncio + async def test_single_prompt_embed(self, renderer): + # Create a test tensor + test_tensor = torch.randn(10, 768, dtype=torch.float32) + embed_bytes = self._create_test_embed_bytes(test_tensor) + + results = await renderer.render_prompt_and_embeds( + prompt_embeds=embed_bytes, + config=RenderConfig(cache_salt="test_salt"), + ) + + assert len(results) == 1 + assert is_embeds_prompt(results[0]) + assert torch.allclose(results[0]["prompt_embeds"], test_tensor) + assert results[0]["cache_salt"] == "test_salt" + + @pytest.mark.asyncio + async def test_multiple_prompt_embeds(self, renderer): + # Create multiple test tensors + test_tensors = [ + torch.randn(8, 512, dtype=torch.float32), + torch.randn(12, 512, dtype=torch.float32), + ] + embed_bytes_list = [ + self._create_test_embed_bytes(t) for t in test_tensors + ] + + results = await renderer.render_prompt_and_embeds( + prompt_embeds=embed_bytes_list, + config=RenderConfig(), + ) + + assert len(results) == 2 + for i, result in enumerate(results): + assert is_embeds_prompt(result) + assert torch.allclose(result["prompt_embeds"], test_tensors[i]) + + @pytest.mark.asyncio + async def test_prompt_embed_truncation(self, renderer): + # Create tensor with more tokens than truncation limit + test_tensor = torch.randn(20, 768, dtype=torch.float32) + embed_bytes = self._create_test_embed_bytes(test_tensor) + + results = await renderer.render_prompt_and_embeds( + prompt_embeds=embed_bytes, + config=RenderConfig(truncate_prompt_tokens=10), + ) + + assert len(results) == 1 + # Should keep last 10 tokens + expected = test_tensor[-10:] + assert torch.allclose(results[0]["prompt_embeds"], expected) + + @pytest.mark.asyncio + async def test_prompt_embed_different_dtypes(self, renderer): + # Test different supported dtypes + dtypes = [torch.float32, torch.float16, torch.bfloat16] + + for dtype in dtypes: + test_tensor = torch.randn(5, 256, dtype=dtype) + embed_bytes = self._create_test_embed_bytes(test_tensor) + + results = await renderer.render_prompt_and_embeds( + prompt_embeds=embed_bytes, + config=RenderConfig(), + ) + + assert len(results) == 1 + assert results[0]["prompt_embeds"].dtype == dtype + + @pytest.mark.asyncio + async def test_prompt_embed_squeeze_batch_dim(self, renderer): + # Test tensor with batch dimension gets squeezed + test_tensor = torch.randn(1, 10, 768, dtype=torch.float32) + embed_bytes = self._create_test_embed_bytes(test_tensor) + + results = await renderer.render_prompt_and_embeds( + prompt_embeds=embed_bytes, + config=RenderConfig(), + ) + + assert len(results) == 1 + # Should be squeezed to 2D + assert results[0]["prompt_embeds"].shape == (10, 768) + + @pytest.mark.asyncio + async def test_both_prompts_and_embeds(self, renderer, + mock_async_tokenizer): + # Set up text 
tokenization + mock_async_tokenizer.return_value = MockTokenizerResult( + [101, 102, 103]) + renderer.async_tokenizer_pool[ + renderer.tokenizer] = mock_async_tokenizer + + # Create embed + test_tensor = torch.randn(5, 256, dtype=torch.float32) + embed_bytes = self._create_test_embed_bytes(test_tensor) + + results = await renderer.render_prompt_and_embeds( + prompt_or_prompts="Hello world", + prompt_embeds=embed_bytes, + config=RenderConfig(), + ) + + assert len(results) == 2 + # First should be embed prompt + assert is_embeds_prompt(results[0]) + # Second should be tokens prompt + assert "prompt_token_ids" in results[1] + assert results[1]["prompt_token_ids"] == [101, 102, 103] diff --git a/tests/evals/gpt_oss/__init__.py b/tests/evals/gpt_oss/__init__.py new file mode 100644 index 0000000000000..0fec1fe5bcdfd --- /dev/null +++ b/tests/evals/gpt_oss/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project \ No newline at end of file diff --git a/tests/evals/gpt_oss/conftest.py b/tests/evals/gpt_oss/conftest.py new file mode 100644 index 0000000000000..35528c0a6a36a --- /dev/null +++ b/tests/evals/gpt_oss/conftest.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Pytest configuration for GPT-OSS evaluation tests. +""" + + +def pytest_addoption(parser): + """Add command line options for pytest.""" + parser.addoption("--model", action="store", help="Model name to evaluate") + parser.addoption("--metric", + action="store", + type=float, + help="Expected metric threshold") + parser.addoption("--server-args", + action="store", + default="", + help="Additional server arguments") diff --git a/tests/evals/gpt_oss/test_gpqa_correctness.py b/tests/evals/gpt_oss/test_gpqa_correctness.py new file mode 100644 index 0000000000000..4cc4041a60ce7 --- /dev/null +++ b/tests/evals/gpt_oss/test_gpqa_correctness.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +GPQA evaluation using vLLM server and GPT-OSS evaluation package. 
+ +Usage: +pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py \ + --model openai/gpt-oss-20b \ + --metric 0.58 \ + --server-args "--tensor-parallel-size 2" +""" + +import subprocess +import sys + +import regex as re + +from tests.utils import RemoteOpenAIServer + +TOL = 0.05 # Absolute tolerance for accuracy comparison + + +def run_gpqa_eval(model_name: str, base_url: str) -> float: + """Run GPQA evaluation using the gpt-oss evaluation package.""" + + # Build the command to run the evaluation + cmd = [ + sys.executable, "-m", "gpt_oss.evals", "--eval", "gpqa", "--model", + model_name, "--reasoning-effort", "low", "--base-url", base_url + ] + + try: + # Run the evaluation + result = subprocess.run( + cmd, + text=True, + capture_output=True, + timeout=1800, # 30 minute timeout + env={"OPENAI_API_KEY": "dummy"}) + + print("Evaluation process output:\n", result.stdout) + + # Parse the output to extract the score + match = re.search(r"'metric':\s*([\d.]+)", result.stdout) + if match: + return float(match.group(1)) + + # If we still can't find it, raise an error + raise ValueError( + f"Could not parse score from evaluation output:\n{result.stdout}") + + except subprocess.TimeoutExpired as e: + raise RuntimeError("Evaluation timed out") from e + except subprocess.CalledProcessError as e: + raise RuntimeError( + f"Evaluation failed with exit code {e.returncode}:\n" + f"stdout: {e.stdout}\nstderr: {e.stderr}") from e + + +def test_gpqa_correctness(request): + """Test GPQA correctness for GPT-OSS model.""" + + # Get command line arguments + model_name = request.config.getoption("--model") + expected_metric = request.config.getoption("--metric") + server_args_str = request.config.getoption("--server-args") + + # Parse server arguments + server_args = [] + if server_args_str: + server_args = server_args_str.split() + + # Add standard server arguments + server_args.extend([ + "--max-model-len", + "32768", + "--trust-remote-code", + ]) + + print(f"Starting GPQA evaluation for model: {model_name}") + print(f"Expected metric threshold: {expected_metric}") + print(f"Server args: {' '.join(server_args)}") + + # Launch server and run evaluation + with RemoteOpenAIServer(model_name, server_args, + max_wait_seconds=1800) as remote_server: + base_url = remote_server.url_for("v1") + print(f"Server started at: {base_url}") + + measured_metric = run_gpqa_eval(model_name, base_url) + + print(f"GPQA Results for {model_name}:") + print(f" Measured metric: {measured_metric:.4f}") + print(f" Expected metric: {expected_metric:.4f}") + print(f" Tolerance: {TOL:.4f}") + + # Verify metric is within tolerance + assert measured_metric >= expected_metric - TOL, ( + f"GPQA metric too low: {measured_metric:.4f} < " + f"{expected_metric:.4f} - {TOL:.4f} = {expected_metric - TOL:.4f}") + + print(f"✅ GPQA test passed for {model_name}") diff --git a/tests/evals/gsm8k/configs/DeepSeek-V2-Lite-Instruct-FP8.yaml b/tests/evals/gsm8k/configs/DeepSeek-V2-Lite-Instruct-FP8.yaml new file mode 100644 index 0000000000000..7ec6a1e0be27f --- /dev/null +++ b/tests/evals/gsm8k/configs/DeepSeek-V2-Lite-Instruct-FP8.yaml @@ -0,0 +1,6 @@ +model_name: "RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8" +accuracy_threshold: 0.72 +num_questions: 1319 +num_fewshot: 5 +max_model_len: 4096 + diff --git a/tests/evals/gsm8k/configs/models-small.txt b/tests/evals/gsm8k/configs/models-small.txt index afd1065b9191b..7bce3f0004f7d 100644 --- a/tests/evals/gsm8k/configs/models-small.txt +++ b/tests/evals/gsm8k/configs/models-small.txt @@ -3,3 +3,4 @@ 
Llama-3.2-1B-Instruct-INT8-CT.yaml Llama-3-8B-Instruct-nonuniform-CT.yaml Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml Qwen1.5-MoE-W4A16-CT.yaml +DeepSeek-V2-Lite-Instruct-FP8.yaml diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index aea166da3af2f..190c92e1251c2 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -22,7 +22,10 @@ def clear_cache(): # Define MLA and non-MLA backends separately DEVICE_MLA_BACKENDS = { - "cuda": ["TRITON_MLA", "FLASHMLA"], + "cuda": [ + "TRITON_MLA", "FLASHMLA", "FLASHINFER_MLA", "FLASH_ATTN_MLA", + "CUTLASS_MLA" + ], "hip": ["TRITON_MLA", "ROCM_AITER_MLA"], "cpu": [], } @@ -90,42 +93,56 @@ def test_env( with patch("vllm.attention.selector.current_platform", CpuPlatform()): - backend = get_attn_backend(16, torch.float16, torch.float16, - block_size, False) + backend = get_attn_backend(16, torch.float16, None, block_size, + False) assert backend.get_name() == "TORCH_SDPA_VLLM_V1" elif device == "hip": with patch("vllm.attention.selector.current_platform", RocmPlatform()): if use_mla: - # Validate HIP MLA backend-block_size combinations - valid_combination = ( - (name == "TRITON_MLA" and block_size != 1) - or (name == "ROCM_AITER_MLA" and block_size == 1)) + # ROCm MLA backend logic: + # - TRITON_MLA: supported when block_size != 1 + # - ROCM_AITER_MLA: supported when block_size == 1 + # If backend is forced but doesn't match block_size, + # should raise ValueError - if valid_combination: - backend = get_attn_backend(16, - torch.float16, - torch.float16, - block_size, - False, - use_mla=use_mla) - expected = f"{name}_VLLM_V1" if use_v1 else name - assert backend.get_name() == expected - else: + if name == "TRITON_MLA" and block_size == 1: + # TRITON_MLA doesn't support block_size == 1 with pytest.raises(ValueError) as exc_info: get_attn_backend(16, torch.float16, - torch.float16, + None, block_size, False, use_mla=use_mla) assert f"The selected backend, {name}" in str( exc_info.value) + elif name == "ROCM_AITER_MLA" and block_size != 1: + # ROCM_AITER_MLA only supports block_size == 1 + with pytest.raises(ValueError) as exc_info: + get_attn_backend(16, + torch.float16, + None, + block_size, + False, + use_mla=use_mla) + assert f"The selected backend, {name}" in str( + exc_info.value) + else: + # Valid backend-block_size combination + backend = get_attn_backend(16, + torch.float16, + None, + block_size, + False, + use_mla=use_mla) + expected = f"{name}_VLLM_V1" if use_v1 else name + assert backend.get_name() == expected else: backend = get_attn_backend(16, torch.float16, - torch.float16, + None, block_size, False, use_mla=use_mla) @@ -136,29 +153,92 @@ def test_env( with patch("vllm.attention.selector.current_platform", CudaPlatform()): if use_mla: - if name == "FLASHMLA" and block_size == 64: - from vllm.attention.backends.flashmla import ( - is_flashmla_supported) + # CUDA MLA backend logic: + # - CUTLASS_MLA: only supported with block_size == 128 + # and Blackwell GPUs (SM 10.0), V1 only + # - FLASHINFER_MLA: only supported on Blackwell GPUs + # (SM 10.0+), V1 only + # - FLASHMLA: only supported with block_size == 64 + # - FLASH_ATTN_MLA: V1 only + # - TRITON_MLA: fallback for other cases - # only on cuda platforms with specific capability. - is_supported, _ = is_flashmla_supported() - - if not is_supported: - # if platform is not supported then skip this case. 
- pytest.skip() + if name == "CUTLASS_MLA": + if not use_v1: + # CUTLASS_MLA only supported on V1 engine + pytest.skip( + "CUTLASS_MLA only supported on V1 engine") + elif block_size != 128: + # CUTLASS_MLA only supports block_size == 128 + pytest.skip( + "CUTLASS_MLA only supports block_size 128") else: backend = get_attn_backend(16, torch.float16, - torch.float16, + None, block_size, False, use_mla=use_mla) - expected = f"{name}_VLLM_V1" if use_v1 else name + expected = "CUTLASS_MLA_VLLM_V1" + assert backend.get_name() == expected + elif name == "FLASHINFER_MLA": + if not use_v1: + # FlashInfer MLA only supported on V1 engine + pytest.skip( + "FlashInfer MLA only supported on V1 engine") + elif block_size not in [32, 64]: + # FlashInfer MLA only supports block_size 32 or 64 + pytest.skip( + "FlashInfer MLA only supports block_size 32 " + "or 64") + else: + backend = get_attn_backend(16, + torch.float16, + None, + block_size, + False, + use_mla=use_mla) + expected = "FLASHINFER_MLA" + assert backend.get_name() == expected + elif name == "FLASHMLA": + if block_size != 64: + # FlashMLA only supports block_size == 64 + pytest.skip("FlashMLA only supports block_size 64") + else: + from vllm.attention.backends.flashmla import ( + is_flashmla_supported) + is_supported, _ = is_flashmla_supported() + if not is_supported: + pytest.skip( + "FlashMLA not supported on this platform") + else: + backend = get_attn_backend(16, + torch.float16, + None, + block_size, + False, + use_mla=use_mla) + expected = f"{name}_VLLM_V1" if use_v1 else name + assert backend.get_name() == expected + elif name == "FLASH_ATTN_MLA": + if not use_v1: + # FlashAttention MLA only supported on V1 engine + pytest.skip( + "FlashAttention MLA only supported on V1 engine" + ) + else: + backend = get_attn_backend(16, + torch.float16, + None, + block_size, + False, + use_mla=use_mla) + expected = "FLASH_ATTN_MLA" assert backend.get_name() == expected else: + # TRITON_MLA or other fallback backend = get_attn_backend(16, torch.float16, - torch.float16, + None, block_size, False, use_mla=use_mla) @@ -168,7 +248,7 @@ def test_env( elif name == "FLASHINFER": backend = get_attn_backend(16, torch.float16, - torch.float16, + None, block_size, False, use_mla=use_mla) @@ -177,7 +257,7 @@ def test_env( else: backend = get_attn_backend(32, torch.float16, - torch.float16, + None, block_size, False, use_mla=use_mla) @@ -187,7 +267,7 @@ def test_env( if use_v1: backend = get_attn_backend(16, torch.float16, - torch.float16, + None, block_size, False, use_mla=use_mla) @@ -213,15 +293,13 @@ def test_fp32_fallback( with patch("vllm.attention.selector.current_platform", CpuPlatform()): - backend = get_attn_backend(16, torch.float32, torch.float32, - 16, False) + backend = get_attn_backend(16, torch.float32, None, 16, False) assert backend.get_name() == "TORCH_SDPA_VLLM_V1" elif device == "cuda": with patch("vllm.attention.selector.current_platform", CudaPlatform()): - backend = get_attn_backend(16, torch.float32, torch.float32, - 16, False) + backend = get_attn_backend(16, torch.float32, None, 16, False) assert (backend.get_name() == "FLEX_ATTENTION" if use_v1 else "XFORMERS") @@ -275,7 +353,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch): assert backend.get_name() != STR_FLASH_ATTN_VAL # Attention-free models should bypass env and use PlaceholderAttention - backend = get_attn_backend(16, torch.float16, torch.float16, 16, True) + backend = get_attn_backend(16, torch.float16, None, 16, True) assert backend.get_name() != STR_FLASH_ATTN_VAL @@ 
-290,4 +368,4 @@ def test_invalid_env(use_v1: bool, monkeypatch: pytest.MonkeyPatch): # Should raise ValueError for invalid backend with pytest.raises(ValueError) as exc_info: get_attn_backend(32, torch.float16, None, 16, False) - assert "Invalid attention backend: 'INVALID'" in str(exc_info.value) + assert "Invalid value 'INVALID'" in str(exc_info.value) diff --git a/tests/kernels/attention/test_cutlass_mla_decode.py b/tests/kernels/attention/test_cutlass_mla_decode.py new file mode 100644 index 0000000000000..5078bd730a1a3 --- /dev/null +++ b/tests/kernels/attention/test_cutlass_mla_decode.py @@ -0,0 +1,198 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +import random +from typing import Optional + +import pytest +import torch + +import vllm._custom_ops as ops +from vllm.platforms import current_platform +from vllm.triton_utils import triton + + +def cal_diff(x: torch.Tensor, + y: torch.Tensor, + name: str, + use_fp8: bool = False, + diff_threshold: Optional[float] = None) -> None: + x, y = x.double(), y.double() + cos_diff = 1 - 2 * (x * y).sum().item() / max( + (x * x + y * y).sum().item(), 1e-12) + if diff_threshold is not None: + # directly compare the cos_diff with the threshold + assert cos_diff < diff_threshold + else: + # use the default threshold + if (use_fp8): + assert cos_diff < 1e-4 + else: + assert cos_diff < 1e-5 + + +CUTLASS_MLA_UNSUPPORTED_REASON = \ + "Cutlass MLA Requires compute capability of 10 or above." \ + if not current_platform.is_device_capability(100) \ + else "Cutlass MLA is supported" + + +@pytest.mark.skipif(not current_platform.has_device_capability(100), + reason=CUTLASS_MLA_UNSUPPORTED_REASON) +@pytest.mark.parametrize("b", [128]) +@pytest.mark.parametrize("s_q", [1]) +@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384]) +@pytest.mark.parametrize("h_q", [16, 32, 64, 128]) +@pytest.mark.parametrize("h_kv", [1]) +@pytest.mark.parametrize("d", [576]) +@pytest.mark.parametrize("dv", [512]) +@pytest.mark.parametrize("block_size", [64]) +@pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("varlen", [False, True]) +@pytest.mark.parametrize( + "torch_dtype", + [ + torch.bfloat16, + # fp8 can have occasional precision-related failures. 
+ pytest.param(torch.float8_e4m3fn, marks=pytest.mark.flaky(reruns=2)) + ]) +@torch.inference_mode() +def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, + causal, varlen, torch_dtype): + device = torch.device("cuda:0") + if torch_dtype == torch.float8_e4m3fn: + init_dtype = torch.bfloat16 + else: + init_dtype = torch_dtype + torch.set_default_dtype(init_dtype) + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(42) + random.seed(42) + + print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, " + f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}") + + use_fp8 = torch_dtype == torch.float8_e4m3fn + scale = math.sqrt(d)**(-1) + cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32) + if varlen: + for i in range(b): + cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), + s_q) + total_seqlens = cache_seqlens.sum().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + + q = torch.randn(b, s_q, h_q, d) + block_table = torch.arange(b * max_seqlen_pad // block_size, + dtype=torch.int32).view( + b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + blocked_v = blocked_k[..., :dv] + + init_dtype = q.dtype + if use_fp8: + fp8_dtype = torch.float8_e4m3fn + descale_q = torch.ones((1), dtype=torch.float32) + descale_k = torch.ones((1), dtype=torch.float32) + + q = q.to(fp8_dtype) + blocked_k = blocked_k.to(fp8_dtype) + blocked_v = blocked_v.to(fp8_dtype) + else: + descale_q = None + descale_k = None + + def cutlass_mla(): + MAX_HEADS = 128 + + q_reshaped = q.squeeze(1) + q_nope = q_reshaped[:, :, :dv].clone() + q_pe = q_reshaped[:, :, dv:].clone() + + if h_q < MAX_HEADS: + q_nope_padded = q_nope.new_empty((b, MAX_HEADS, dv)) + q_nope_padded[:, :h_q] = q_nope + q_nope = q_nope_padded + + q_pe_padded = q_pe.new_empty((b, MAX_HEADS, d - dv)) + q_pe_padded[:, :h_q] = q_pe + q_pe = q_pe_padded + + kv_cache_flat = blocked_k.squeeze(2) + device_properties = torch.cuda.get_device_properties( + torch.device("cuda:0")) + sm_count = device_properties.multi_processor_count + workspace_size = ops.sm100_cutlass_mla_get_workspace_size( + max_seqlen * block_size, b, sm_count, num_kv_splits=1) + workspace = torch.empty(workspace_size, + device="cuda", + dtype=torch.uint8) + + out_ans = torch.empty(b, MAX_HEADS, dv, dtype=init_dtype) + output_lse = torch.empty((b, MAX_HEADS), + dtype=torch.float32, + device=q_nope.device) + ops.sm100_cutlass_mla_decode(out_ans, output_lse, q_nope, q_pe, + kv_cache_flat, cache_seqlens, block_table, + workspace, scale, 1) + return out_ans[:, :h_q].contiguous(), output_lse[:, :h_q].contiguous() + + def scaled_dot_product_attention(query, key, value, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) + temp_mask = torch.ones(s_q, s_k, + dtype=torch.bool).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight @ value, lse + + def ref_mla(): + q_ = (q.to(torch.float) * 
descale_q).to(init_dtype) if use_fp8 else q + blocked_k_ = (blocked_k.to(torch.float) * + descale_k).to(init_dtype) if use_fp8 else blocked_k + blocked_v_ = (blocked_v.to(torch.float) * + descale_k).to(init_dtype) if use_fp8 else blocked_v + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + lse = torch.empty(b, h_q, s_q, dtype=torch.float32) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + out_i, lse_i = scaled_dot_product_attention( + q_[i].transpose(0, 1), + blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + is_causal=causal, + ) + out[i] = out_i.transpose(0, 1) + lse[i] = lse_i + return out, lse + + out_cutlass, lse_cutlass = cutlass_mla() + out_torch, lse_torch = ref_mla() + # Extract the single token (s_q=1) slice to match cutlass output shape + out_torch_slice = out_torch[:, 0, :, :] # [b, h_q, dv] + lse_torch_slice = lse_torch[:, 0, :] # [b, h_q] + cal_diff(out_cutlass, out_torch_slice, "out", use_fp8) + # lse has larger numerical error, so use a larger threshold + cal_diff(lse_cutlass, lse_torch_slice, "lse", use_fp8, diff_threshold=1e-3) + + t = triton.testing.do_bench(cutlass_mla) + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + + b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + ( + b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8) + print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,", + f"{bytes / 10 ** 6 / t:.0f} GB/s") diff --git a/tests/kernels/attention/test_encoder_decoder_attn.py b/tests/kernels/attention/test_encoder_decoder_attn.py deleted file mode 100644 index a2e6986460904..0000000000000 --- a/tests/kernels/attention/test_encoder_decoder_attn.py +++ /dev/null @@ -1,1105 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Tests: - -* E2E test of Encoder attention + Decoder self-attention + - Encoder/decoder cross-attention (collectively - "encoder/decoder attention") - -""" - -from typing import NamedTuple, Optional - -import pytest -import torch - -from tests.kernels.utils import * -from vllm.attention import Attention, AttentionMetadata, AttentionType -from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP -from vllm.attention.selector import (_Backend, _cached_get_attn_backend, - global_force_attn_backend_context_manager) -from vllm.config import VllmConfig, set_current_vllm_config -from vllm.forward_context import set_forward_context -from vllm.platforms import current_platform - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - Encoder-decoder is only supported on V0, so set - VLLM_USE_V1=0 for all tests in the module. - """ - monkeypatch.setenv('VLLM_USE_V1', '0') - - -# List of support backends for encoder/decoder models -LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN] -HEAD_SIZES = [64, 256] - -NUM_HEADS = [1, 16] - -BATCH_SIZES = [1, 16] -BLOCK_SIZES = [16] -CUDA_DEVICE = "cuda:0" - -MAX_DEC_SEQ_LENS = [128] -MAX_ENC_SEQ_LENS = [128] - -# Narrow test-cases for unsupported-scenario -# tests -HEAD_SIZES_FOR_UNSUPP = [HEAD_SIZES[0]] - - -class TestPoint(NamedTuple): - """ - Encapsulates the attributes which define a single invocation - of the test_e2e_enc_dec_attn() test - - Attributes: - num_heads: The number of heads in the model. - head_size: Head dimension - backend_name: Name of the backend framework used. 
- batch_size: Number of samples per batch. - block_size: Size of each block of data processed. - max_dec_seq_len: Maximum sequence length for the decoder. - max_enc_seq_len: Maximum sequence length for the encoder. - num_blocks: Number of blocks in the model. - """ - - num_heads: int - head_size: int - backend_name: str - batch_size: int - block_size: int - max_dec_seq_len: int - max_enc_seq_len: int - num_blocks: int - attn_type: AttentionType - - -class TestResources(NamedTuple): - ''' - Encapsulates key components for performing an - encoder/decoder attention test - - Note that - (1) attn automatically selects an attention backend - based on platform info & a set of canned - heuristics - (2) attn_backend is thus *not the same backend - instance* used by attn, but rather it is - intended to be a - *different instance* of the *same backend class*; - it is assumed that the user of TestResources - will leverage attn_backend for the purpose of - constructing backend-compatible attention - metadata instances - - Attributes: - - * scale: 1/sqrt(d) scale factor for attn - * attn_backend: implementations of abstraction - attention interface using - a particular kernel library - i.e. XFormers - * attn: Attention layer instance - * kv_cache: shared key/value cache for all attention - ''' - - scale: float - attn: Attention - kv_cache: torch.Tensor - - -def _make_test_resources(test_pt: TestPoint, ) -> TestResources: - ''' - Build key components for performing encoder/decoder attention test. - - Note that - (1) The Attention instance constructed here, automatically selects - an attention backend class based on platform info & a set of canned - heuristics, so - (2) The attention backend instance constructed here is thus *not - the same backend instance* used by attn, but rather it is - intended to be a *different instance* of the *same backend class*; - therefore, - (3) This function requires that test_pt.backend_name matches the backend - class that Attention will automatically select when it is constructed. - - - Arguments: - - * test_pt: TestPoint data structure; this function relies on the - following fields: num_heads, head_size, num_blocks, - block_size, backend_name - - Returns: - - * TestResources data structure. - ''' - - scale = float(1.0 / (test_pt.head_size**0.5)) - attn = Attention( - test_pt.num_heads, - test_pt.head_size, - scale=scale, - prefix=f"{test_pt.attn_type}", - attn_type=test_pt.attn_type, - ) - if test_pt.num_blocks is None or test_pt.num_heads is None: - # Caller does not require a KV cache - return TestResources( - scale, attn, - torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE)) - - # Construct KV cache - if test_pt.attn_type in (AttentionType.DECODER, - AttentionType.ENCODER_DECODER): - kv_cache = make_kv_cache(test_pt.num_blocks, - test_pt.num_heads, - test_pt.head_size, - test_pt.block_size, - device=CUDA_DEVICE, - backend=test_pt.backend_name) - else: - kv_cache = torch.tensor([]) - - attn.kv_cache = [kv_cache] - return TestResources(scale, attn, kv_cache) - - -def _encoder_attn_setup( - test_pt: TestPoint, - test_rsrcs: TestResources, -) -> PhaseTestParameters: - ''' - Set up test vectors & data structures for encoder attention test. - - A triplet of synthetic query/key/value tensors are constructed. - Given this is an encoder attention test, the key & value - sequences will have the same length as the corresponding queries. - - The query/key/value tensors are passed to an ideal reference - self-attention implementation to generate an ideal output tensor. 
- - Encoder inference does not populate the KV cache, therefore - no KV cache memory mapping is constructed - - Arguments: - - * test_pt: TestPoint data structure; this function relies on the - following fields: batch_size, num_heads, head_size, - block_size, max_q_seq_len - * test_rsrcs: TestResources data structure; this function relies on the - scale field - - - Returns: - - * PhaseTestParameters data structure comprising (1) packed query/key/value - tensors, (2) the ideal output of attention computed using a naive - implementation, and (3) KVCache field set to None - ''' - - ( - num_heads, - head_size, - _, - batch_size, - _, - _, - max_q_seq_len, - _, - _, - ) = test_pt - - scale = test_rsrcs.scale - - max_kv_seq_len = max_q_seq_len - - # Make test tensors - - qkv_in, _, _ = make_qkv(batch_size, - max_q_seq_len, - max_kv_seq_len, - num_heads, - head_size, - attn_type=AttentionType.ENCODER, - device=CUDA_DEVICE) - - # Compute correct answer using naive non-causal attention - # implementation - - ideal_output = ref_masked_attention(qkv_in.query, - qkv_in.key, - qkv_in.value, - scale=scale, - q_seq_lens=qkv_in.q_seq_lens, - kv_seq_lens=qkv_in.kv_seq_lens) - - packed_ideal_output, _ = pack_tensor(ideal_output, - qkv_in.q_seq_lens, - device=CUDA_DEVICE) - - packed_qkv = pack_qkv(qkv_in, device=CUDA_DEVICE) - - return PhaseTestParameters( - PackedQKVO(packed_qkv, packed_ideal_output), - None # No KV cache - ) - - -def _decoder_attn_setup( - test_pt: TestPoint, - test_rsrcs: TestResources, - block_base_addr: int = 0, -) -> tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]: - ''' - Set up test vectors & data structures for self-attention test. - - A triplet of synthetic query/key/value tensors are constructed ("baseline" - query/key/value). Given this is a self-attention test, the key & value - sequences will have the same length as the corresponding queries. - - "Prefill" query/key/value tensors are derived by masking out the last value - in each baseline query/key/value. These tensors are used to test prefill & - populate KV cache for a subsequent decode test. - - "Decode" query/key/value tensors are derived by extracting *only* the last - value from each baseline query/key/value (i.e. complement of the prefill - tensors.) These tensors are used to test decode, conditional on the kv cache - being populated during the prefill test. - - The baseline query/key/value tensors are passed to an ideal reference - self-attention implementation to generate a "Baseline" ideal output tensor. - This tensor is split into the "Prefill" ideal output tensor (all but the - last element of each output sequence) and the "Decode" ideal output tensor - (*only* the last element of each output sequence); the "Prefill" and - "Decode" ideal output tensors can be used to validate the prefill and decode - test results, respectively. 
- - This function also constructs the self-attention KV cache memory mapping - (slot mapping and block table), ensuring that the block table starts at - block_base_addr - - Arguments: - - * test_pt: TestPoint data structure; this function relies on the - following fields: batch_size, num_heads, head_size, - block_size, max_q_seq_len - * test_rsrcs: TestResources data structure; this function relies on the - scale field - * block_base_addr: decoder self-attention block-table base address - - Returns: - * qkv: Unpacked (batch_size x padded_seq_len x num_heads x - head_size) query/key/value tensors - * Prefill-phase decoder self-attention PhaseTestParameters data structure, - including (1) packed (number_of_tokens x num_heads x head_size) - query/key/value tensors along with (2) ideal attention output - computed using a naive implementation, and (3) memory-mapping data - structures appropriate for prefill phase. - * Decode-phase decoder self-attention PhaseTestParameters data structure, - including (1) packed (number_of_tokens x num_heads x head_size) - query/key/value tensors along with (2) ideal attention output - computed using a naive implementation, and (3) memory-mapping data - structures appropriate for decode phase. - * max_block_idx: max physical address in decoder self-attention block-table - (intended to be used as the base address for the encoder/ - decoder cross-attention block-table, which is not - constructed in this function) - ''' - - ( - num_heads, - head_size, - _, - batch_size, - block_size, - max_q_seq_len, - _, - _, - _, - ) = test_pt - - scale = test_rsrcs.scale - - max_kv_seq_len = max_q_seq_len - - # Build test tensors - - ( - qkv, - prefill_qkv, - decode_qkv, - ) = make_qkv(batch_size, - max_q_seq_len, - max_kv_seq_len, - num_heads, - head_size, - attn_type=AttentionType.DECODER, - device=CUDA_DEVICE) - - # Compute correct answer using naive attention implementation - # with causal attention mask - - causal_mask = make_causal_mask(max_q_seq_len, - max_kv_seq_len).to(CUDA_DEVICE) - - ideal_output = ref_masked_attention(qkv.query, - qkv.key, - qkv.value, - scale=scale, - custom_mask=causal_mask, - q_seq_lens=qkv.q_seq_lens, - kv_seq_lens=qkv.kv_seq_lens) - - # Split out the prefill- & decode-phase ideal answers & pack them - - prefill_ideal_output = torch.zeros_like(ideal_output) - decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) - for bdx, prefill_q_seq_len in enumerate(prefill_qkv.q_seq_lens): - prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[ - bdx, :prefill_q_seq_len] - decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:( - prefill_q_seq_len + 1)] - - prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, - prefill_qkv.q_seq_lens, - device=CUDA_DEVICE) - decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, - [1 for _ in range(batch_size)], - device=CUDA_DEVICE) - - # Build prefill- & decode-phase data structures - # for decoder self-attention. Block tables and - # slot mapping must be in a format compatible - # with KV caching & attention kernels - # - # Prefill-phase: - # - # * Empty block-tables tensor - # * Slot-mapping with entries for prompt tokens - # - # Decode-phase: - # * Block-tables tensor with minimum number of blocks - # required by total num. 
tokens in the entirety of all sequences - # (including both prefill & decode) - # * Slot-mapping with entries for tokens that will be decoded in the - # current decode iteration - # - # Note: the format described above is simply mirroring what ModelRunner - # produces - - prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE) - - ( - decode_block_tables, - slot_mapping_list, - max_block_idx, - ) = make_block_tables_slot_mapping(block_size, - qkv.q_seq_lens, - device=CUDA_DEVICE, - block_base_addr=block_base_addr) - - ( - prefill_slot_mapping, - decode_slot_mapping, - ) = split_slot_mapping(slot_mapping_list, - qkv.q_seq_lens, - device=CUDA_DEVICE) - - prefill_pckd_qkv = pack_qkv(prefill_qkv, device=CUDA_DEVICE) - - decode_pckd_qkv = pack_qkv(decode_qkv, device=CUDA_DEVICE) - - return ( - qkv, - PhaseTestParameters( # Prefill test params - PackedQKVO(prefill_pckd_qkv, prefill_packed_ideal_output), - KVMemoryMap(prefill_block_tables, prefill_slot_mapping)), - PhaseTestParameters( # Decode test params - PackedQKVO(decode_pckd_qkv, decode_packed_ideal_output), - KVMemoryMap(decode_block_tables, decode_slot_mapping)), - max_block_idx) - - -def _enc_dec_cross_attn_setup_reuses_query( - decoder_qkv: QKVInputs, - encoder_test_params: PhaseTestParameters, - prefill_decoder_phase_test_params: PhaseTestParameters, - test_pt: TestPoint, - test_rsrcs: TestResources, - block_base_addr: int = 0, -) -> tuple[PhaseTestParameters, PhaseTestParameters]: - ''' - Set up test vectors & data structures for cross-attention test. - - A triplet of synthetic cross-attention key/value tensors are constructed - ("baseline" key/value). Given this is a cross-attention test, we assume - query tensors were already synthesized for a prior self-attention test and - will be reused for cross-attention. The key & value sequences generated here - may have a different length than the corresponding queries (as is often - the case for cross-attention between decoder and encoder sequences.) - - Cross attention key & value tensors do not grow during autoregressive - inference; thus this function obtains a single key/value pair suitable for - both prefill and decode. - - The "baseline" query tensor is received as an argument. The "baseline" - query/key/value tensors are passed to an ideal reference cross-attention - implementation to generate a "baseline" ideal output tensor. This tensor is - split into the "Prefill" ideal output tensor (all but the last element of - each output sequence) and the "Decode" ideal output tensor (*only* the last - element of each output sequence); the "Prefill" and "Decode" ideal output - tensors can be used to validate the prefill and decode test results, - respectively. - - This function also constructs the cross-attention KV cache memory mapping - (slot mapping and block table), ensuring that the block table starts at - block_base_addr. 
- - Arguments: - - * decoder_qkv: pre-existing unpacked (batch_size x padded_seq_len x - num_heads x head_size) decoder self-attention inputs; - this function relies on the query and q_seq_lens - fields - * encoder_test_params: PhaseTestParameters data structure which was - used for encoder inference; KV cache field - is not used by this function - * prefill_decoder_phase_test_params: PhaseTestParameters data structure - used for prefill-phase decoder - self-attention; all fields - including KV cache required - * test_pt: TestPoint data structure; this function relies on the - following fields: batch_size, num_heads, head_size, - block_size, max_q_seq_len - * test_rsrcs: TestResources data structure; this function relies on the - scale field - * block_base_addr: decoder self-attention block-table base address - - Returns: - - * Prefill-phase encoder/decoder cross-attention PhaseTestParameters data - structure, including (1) packed - (number_of_tokens x num_heads x head_size) query/key/value tensors - along with (2) ideal attention output computed using a - naive implementation, and (3) memory-mapping data structures appropriate - for prefill phase. - * Decode-phase encoder/decoder cross-attention PhaseTestParameters data - structure, including (1) packed - (number_of_tokens x num_heads x head_size) query/key/value tensors - along with (2) ideal attention output computed using a - naive implementation, and (3) memory-mapping data structures appropriate - for decode phase. - ''' - - assert encoder_test_params.packed_qkvo.packed_qkv is not None - assert prefill_decoder_phase_test_params.packed_qkvo.packed_qkv is not None - - ( - num_heads, - head_size, - _, - batch_size, - block_size, - max_decoder_seq_len, - max_encoder_seq_len, - _, - _, - ) = test_pt - - scale = test_rsrcs.scale - - decoder_query = decoder_qkv.query - decoder_seq_lens = decoder_qkv.q_seq_lens - encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens - prefill_q_seq_lens = ( - prefill_decoder_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens) - - assert prefill_q_seq_lens is not None - - ( - cross_kv, - _, - _, - ) = make_qkv(batch_size, - max_decoder_seq_len, - max_encoder_seq_len, - num_heads, - head_size, - force_kv_seq_lens=encoder_seq_lens, - attn_type=AttentionType.ENCODER_DECODER, - device=CUDA_DEVICE) - - ideal_output = ref_masked_attention(decoder_query, - cross_kv.key, - cross_kv.value, - scale=scale, - q_seq_lens=decoder_seq_lens, - kv_seq_lens=cross_kv.kv_seq_lens) - - prefill_ideal_output = torch.zeros_like(ideal_output) - decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) - for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens): - prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[ - bdx, :prefill_q_seq_len] - decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:( - prefill_q_seq_len + 1)] - - prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, - prefill_q_seq_lens, - device=CUDA_DEVICE) - decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, - [1 for _ in range(batch_size)], - device=CUDA_DEVICE) - - # Build prefill- & decode-phase data structures - # for encoder/decoder cross-attention. 
Block tables and - # slot mapping must be in a format compatible - # with KV caching & attention kernels - # - # Whereas decoder self-attention extracts relationships between - # equal-length Q/K/V sequences, which mutually grow in length - # with each decoded token, cross-attention relates the Q sequence - # - which grows with each new decoded token - to fixed-length - # K and V sequences derived from the encoder hidden states. - # - # Prefill-phase: - # - # * Empty block-tables tensor - # * Slot-mapping with as many entries as there are tokens in the encoder - # prompt. - # - # Decode-phase: - # * Block-tables tensor with minimum number of blocks to - # accommodate K & V tensors which are equal in length - # to the encoder prompt length - # * Empty slot-mapping tensor (since K & V are fixed in size, - # new decoded tokens are not KV-cached and require no slot- - # mapping) - # - # Note: the format above is simply an extension of what ModelRunner - # produces for decoder-only models - - prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE) - decode_slot_mapping = make_empty_slot_mapping_tensor(device=CUDA_DEVICE) - - ( - decode_block_tables, - prefill_slot_mapping_list, - _, - ) = make_block_tables_slot_mapping(block_size, - cross_kv.kv_seq_lens, - block_base_addr=block_base_addr, - device=CUDA_DEVICE) - - prefill_slot_mapping = maybe_make_long_tensor(prefill_slot_mapping_list, - device=CUDA_DEVICE) - - # Packed key/value (query is already provided) - packed_cross_kv = pack_qkv(cross_kv, device=CUDA_DEVICE) - - return ( - PhaseTestParameters( # Prefill-phase test params - PackedQKVO(packed_cross_kv, prefill_packed_ideal_output), - KVMemoryMap(prefill_block_tables, prefill_slot_mapping)), - PhaseTestParameters( # Decode-phase test params - PackedQKVO(None, decode_packed_ideal_output), - KVMemoryMap(decode_block_tables, decode_slot_mapping))) - - -def _run_encoder_attention_test( - attn: Attention, - encoder_test_params: PhaseTestParameters, - attn_metadata: AttentionMetadata, - test_pt: TestPoint, - vllm_config: VllmConfig, -) -> torch.Tensor: - ''' - Run encoder attention. - - attn.forward() is passed attn_type=AttentionType.ENCODER in order - to configure the kernel invocation for encoder attention - - Requires attn_metadata.num_decode_tokens == 0 - (There is no encoder execution in the decode-phase) - - Arguments: - - * attn: Attention wrapper instance - * encoder_test_params: encoder PhaseTestParameters data structure; - this function relies on the packed - (number_of_tokens x num_heads x head_size) - query/key/value fields - * attn_metadata: attention metadata for encoder/decoder-self attention - * test_pt: The TestPoint object containing test details like number of - model heads, head size, name of the backend being used etc. - - Returns: - * Attention.forward() applied to packed {query,key,value} and - & attn_metadata - ''' - assert attn_metadata.num_decode_tokens == 0 - packed_qkv = encoder_test_params.packed_qkvo.packed_qkv - assert packed_qkv is not None - with set_forward_context(attn_metadata, vllm_config): - # In the test setup the shape of the query is - # [batch_size, seq_len, num_heads, head_size]. However - # the attention backend expects the shape to be - # [num_tokens, hidden_size]. Hence reshape the query before - # invoking the forward method. - # TODO - Update the way we construct the query so that it - # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
- reshaped_query = packed_qkv.query.view( - -1, test_pt.num_heads * test_pt.head_size) - return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value) - - -def _run_decoder_self_attention_test( - test_rsrcs: TestResources, - decoder_test_params: PhaseTestParameters, - attn_metadata: AttentionMetadata, - test_pt: TestPoint, - vllm_config: VllmConfig, -) -> torch.Tensor: - ''' - Run decoder self-attention test. - - attn.forward() is passed attn_type=AttentionType.DECODER - in order to configure the kernel invocation for decoder self-attention. - - Arguments: - - * test_rsrcs: TestResources instance; this function relies on the kv_cache - and attn (Attention wrapper instance) fields - * decoder_test_params: decoder PhaseTestParameters data structure; - this function relies on the packed - (number_of_tokens x num_heads x head_size) - query/key/value fields - * attn_metadata: attention metadata for decoder-self attention - (contains KV cache memory-mapping) - * test_pt: The TestPoint object containing test details like number of - model heads, head size, name of the backend being used etc. - - Returns: - * Attention.forward() applied to packed_{query,key,value}, kv_cache - & attn_metadata - ''' - attn = test_rsrcs.attn - packed_qkv = decoder_test_params.packed_qkvo.packed_qkv - assert packed_qkv is not None - with set_forward_context(attn_metadata, vllm_config): - # In the test setup the shape of the query is - # [batch_size, seq_len, num_heads, head_size]. However - # the attention backend expect the shape to be - # [num_tokens, hidden_size]. Hence reshape the query before - # invoking the forward method. - # TODO - Update the way we construct the query so that it - # is shaped as [num_tokens, hidden_size] and we can skip the reshape. - reshaped_query = packed_qkv.query.view( - -1, test_pt.num_heads * test_pt.head_size) - return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value) - - -def _run_encoder_decoder_cross_attention_test( - test_rsrcs: TestResources, - decoder_test_params: PhaseTestParameters, - cross_test_params: Optional[PhaseTestParameters], - attn_metadata: AttentionMetadata, - test_pt: TestPoint, - vllm_config: VllmConfig, -) -> torch.Tensor: - ''' - Run encoder/decoder cross-attention test. - - Via PhaseTestParameters data structures, consumes the same query utilized - for decoder self-attention, plus a key/value specific to cross-attention. - - if cross_test_params is None or cross_test_params.packed_qkvo.packed_qkv - is None, this reflects that in decode-phase cross attention there - is no growth in the key and value tensors. - - attn.forward() is passed attn_type=AttentionType.ENCODER_DECODER - in order to configure the kernel invocation for encoder/decoder cross- - attention. - - Arguments: - - * test_rsrcs: TestResources instance; this function relies on the kv_cache - and attn (Attention wrapper instance) fields - * decoder_test_params: decoder PhaseTestParameters data structure; - this function relies on the packed - (number_of_tokens x num_heads x head_size) - query field - * cross_test_params: encoder/decoder PhaseTestParameters data structure; - this function relies on the packed - (number_of_tokens x num_heads x head_size) - key/value fields - * attn_metadata: attention metadata for encoder/decoder-self attention - * test_pt: The TestPoint object containing test details like number of - model heads, head size, name of the backend being used etc. 
- - Returns: - * Attention.forward() applied to packed_{query,key,value}, kv_cache - & attn_metadata - ''' - assert decoder_test_params.packed_qkvo.packed_qkv is not None - - attn = test_rsrcs.attn - if cross_test_params is None: - key = None - value = None - else: - cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv - key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key) - value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value) - with set_forward_context(attn_metadata, vllm_config): - # In the test setup the shape of the query is - # [batch_size, seq_len, num_heads, head_size]. However - # the attention backend expect the shape to be - # [num_tokens, hidden_size]. Hence reshape the query before - # invoking the forward method. - # TODO - Update the way we construct the query so that it - # is shaped as [num_tokens, hidden_size] and we can skip the reshape. - reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view( - -1, test_pt.num_heads * test_pt.head_size) - return attn.forward(reshaped_query, key, value) - - -@pytest.fixture(autouse=True) -def set_reset_environment(attn_backend): - # Set the default torch datatype to bfloat16 to enable - # testing of the Flash Attention backend. Also clear the - # cached value of the backend. - default_dtype = torch.get_default_dtype() - if attn_backend.name == 'FLASH_ATTN': - torch.set_default_dtype(torch.bfloat16) - _cached_get_attn_backend.cache_clear() - yield - # Reset the torch datatype to what it was before the test - # so as not to impact the remaining tests. - torch.set_default_dtype(default_dtype) - - -@pytest.mark.skipif(current_platform.is_rocm(), - reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) -@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) -def test_encoder_only( - num_heads: int, - head_size: int, - attn_backend: _Backend, - batch_size: int, - block_size: int, - max_dec_seq_len: int, - max_enc_seq_len: int, -): - ''' - End-to-end encoder-only attention test: - - * Construct fake test vectors for (1) encoder attention - * Construct (1) attention metadata structure with prefill-phase - encoder attention, and (2) an analogous attention metadata - structure but for decode-phase - * Test & validate encoder attention against ideal output - - No KV cache is required for encoder-only attention. - - Note on ROCm/HIP: currently encoder/decoder models are not supported on - AMD GPUs, therefore this test simply is skipped if - current_platform.is_rocm(). - - This test globally forces an override of the usual backend - auto-selection process, forcing the specific backend-under-test - to be utilized. 
- - Arguments: - - * num_heads - * head_size, - * attn_backend: The attention backend to employ for testing - * batch_size - * block_size: KV cache block size - * max_dec_seq_len: max length of decoder input sequences - * max_enc_seq_len: max length of encoder input sequences - ''' - # Force Attention wrapper backend - with global_force_attn_backend_context_manager(attn_backend): - # Note: KV cache size of 4096 is arbitrary & chosen intentionally - # to be more than necessary, since exceeding the kv cache size - # is not part of this test - test_pt = TestPoint(num_heads, head_size, attn_backend.name, - batch_size, block_size, max_dec_seq_len, - max_enc_seq_len, 4096, AttentionType.ENCODER) - - # Attention scale factor, attention backend instance, attention wrapper - # instance, KV cache init - vllm_config = VllmConfig() - with set_current_vllm_config(vllm_config): - test_rsrcs = _make_test_resources(test_pt) - - # Construct encoder attention test params (only used - # during prefill) - - enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) - - # Shared prefill metadata structure - - prephase_attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, - True, - None, - decoder_test_params=None, - encoder_test_params=enc_test_params, - cross_test_params=None, - device=CUDA_DEVICE) - - # PREFILL: encoder attention - - enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test( - test_rsrcs.attn, - enc_test_params, - prephase_attn_metadata, - test_pt=test_pt, - vllm_config=vllm_config)) - - # - Is encoder attention result correct? - assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out, - attn_backend.name) - - -@pytest.mark.skipif(current_platform.is_rocm(), - reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) -@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) -def test_e2e_enc_dec_attn( - num_heads: int, - head_size: int, - attn_backend: _Backend, - batch_size: int, - block_size: int, - max_dec_seq_len: int, - max_enc_seq_len: int, -) -> None: - ''' - End-to-end encoder/decoder test: - - * Construct fake test vectors for (1) encoder attention, - (2) decoder self-attention, and (3) encoder/decoder cross-attention - * Construct (1) attention metadata structure with self- and cross-attention - attributes for prefill-phase, and (2) an analogous attention metadata - structure but for decode-phase - * Test attention steps in the following order - - * Encoder attention - * Prefill self-attention - * Prefill cross-attention - * Decode self-attention - * Decode cross-attention - * Besides being reflective of realistic use-cases, this order would - exacerbate any accidental overlap in the self-/cross-attention - block tables, which one hopes to avoid - - - * Validate output correctness against ideal reference attention - implementation - - Block tables are constructed such that cross-attention KV cache is in a - higher, non-intersecting address-space than self-attention KV cache. - - Self- and cross-attention share the same query tensor but not the K/V - tensors. Self-attention K/Vs must have the same seq len as Q while - cross-attention K/Vs are allowed to differ in seq len, as is often the case - for cross-attention. 
- - This test globally forces an override of the usual backend - auto-selection process, forcing the specific backend-under-test - to be utilized. - - Note on ROCm/HIP: currently encoder/decoder models are not supported on - AMD GPUs, therefore this test simply is skipped if - current_platform.is_rocm(). - - Note on metadata: there is a single attention metadata structure shared by - all prefill-phase attention operations (encoder, decoder, enc/dec cross), - and a single one shared by all decode-phase attention operations - (decoder & enc/dec cross.) This is intended to reflect the behavior - of EncoderDecoderModelRunner, which constructs a single attention metadata - structure for each prefill or decode run. A realistic scenario would rely - on the attention backend to utilize the appropriate attention metadata - fields according to the value of attn_metadata.attention_type. Thus, - this test is organized so as to confirm that the backend-under-test can - handle a shared prefill attention metadata structure & a shared decode\ - attention metadata structure. - - Arguments: - - * num_heads - * head_size, - * attn_backend: The attention backend to employ for testing - * batch_size - * block_size: KV cache block size - * max_dec_seq_len: max length of decoder input sequences - * max_enc_seq_len: max length of encoder input sequences - ''' - # Force Attention wrapper backend - with global_force_attn_backend_context_manager(attn_backend): - # Note: KV cache size of 4096 is arbitrary & chosen intentionally - # to be more than necessary, since exceeding the kv cache size - # is not part of this test - enc_test_pt = TestPoint(num_heads, head_size, attn_backend.name, - batch_size, block_size, max_dec_seq_len, - max_enc_seq_len, 4096, AttentionType.ENCODER) - enc_dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name, - batch_size, block_size, max_dec_seq_len, - max_enc_seq_len, 4096, - AttentionType.ENCODER_DECODER) - dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name, - batch_size, block_size, max_dec_seq_len, - max_enc_seq_len, 4096, AttentionType.DECODER) - - # Attention scale factor, attention backend instance, attention wrapper - # instance, KV cache init - vllm_config = VllmConfig() - with set_current_vllm_config(vllm_config): - enc_test_rsrcs = _make_test_resources(enc_test_pt) - enc_dec_test_rsrcs = _make_test_resources(enc_dec_test_pt) - dec_test_rsrcs = _make_test_resources(dec_test_pt) - - # Construct encoder attention test params (only used - # during prefill) - - enc_test_params = _encoder_attn_setup(enc_test_pt, enc_test_rsrcs) - - # Construct Decoder self-attention prefill-phase & decode-phase - # test params, including query/key/value tensors, decoder self-attention - # memory-mapping. cross_block_base_addr is the uppermost address in the - # decoder self-attention block-table, i.e. a base address which the - # encoder/decoder cross-attention block-table may build downward toward. 
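As an aside (illustrative only, not part of the patch): the non-overlap guarantee described above, namely that the decoder self-attention and encoder/decoder cross-attention KV caches occupy disjoint block-ID ranges separated by cross_block_base_addr, can be sketched in a few lines of plain Python with hypothetical block IDs.

# Hypothetical illustration only; the values and names below are assumptions, not vLLM APIs.
cross_block_base_addr = 16  # assumed: first block ID beyond the self-attention block table
self_attn_block_ids = set(range(0, cross_block_base_addr))
cross_attn_block_ids = set(range(cross_block_base_addr, cross_block_base_addr + 8))
# Disjoint ranges mean the two block tables can never map to the same KV-cache block.
assert self_attn_block_ids.isdisjoint(cross_attn_block_ids)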
- - ( - dec_qkv, - prephase_dec_test_params, - decphase_dec_test_params, - cross_block_base_addr, - ) = _decoder_attn_setup(dec_test_pt, dec_test_rsrcs) - - # Construct encoder/decoder cross-attention prefill-phase - # & decode-phase test params, including key/value tensors, - # cross-attention memory-mapping - - ( - prephase_cross_test_params, - decphase_cross_test_params, - ) = _enc_dec_cross_attn_setup_reuses_query( - dec_qkv, - enc_test_params, - prephase_dec_test_params, - enc_dec_test_pt, - enc_dec_test_rsrcs, - block_base_addr=cross_block_base_addr) - - # Shared prefill metadata structure - assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None - prephase_attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, - True, - prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens, - decoder_test_params=prephase_dec_test_params, - encoder_test_params=enc_test_params, - cross_test_params=prephase_cross_test_params, - device=CUDA_DEVICE) - - # PREFILL: encoder attention - - enc_pckd_act_out = _run_encoder_attention_test(enc_test_rsrcs.attn, - enc_test_params, - prephase_attn_metadata, - test_pt=enc_test_pt, - vllm_config=vllm_config) - - # - Is encoder attention result correct? - assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out, - attn_backend.name) - - # PREFILL: decoder self-attention test - - prephase_dec_pckd_act_out = _run_decoder_self_attention_test( - dec_test_rsrcs, - prephase_dec_test_params, - prephase_attn_metadata, - test_pt=dec_test_pt, - vllm_config=vllm_config) - - # - Is prefill decoder self-attention correct? - assert_actual_matches_ideal(prephase_dec_test_params, - prephase_dec_pckd_act_out, - attn_backend.name) - - # PREFILL: encoder/decoder cross-attention test - - prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( - enc_dec_test_rsrcs, - prephase_dec_test_params, - prephase_cross_test_params, - prephase_attn_metadata, - test_pt=enc_dec_test_pt, - vllm_config=vllm_config) - - # - Is prefill encoder/decoder cross-attention correct? - assert_actual_matches_ideal(prephase_cross_test_params, - prephase_cross_pckd_act_out, - attn_backend.name) - - # DECODE: build decode-phase attention metadata - - decphase_attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, - False, - dec_qkv.q_seq_lens, - decoder_test_params=decphase_dec_test_params, - encoder_test_params=enc_test_params, - cross_test_params=decphase_cross_test_params, - device=CUDA_DEVICE) - - # DECODE: decoder self-attention test - - decphase_dec_pckd_act_out = _run_decoder_self_attention_test( - dec_test_rsrcs, - decphase_dec_test_params, - decphase_attn_metadata, - test_pt=dec_test_pt, - vllm_config=vllm_config) - - # - Is decode-phase decoder self-attention correct? - assert_actual_matches_ideal(decphase_dec_test_params, - decphase_dec_pckd_act_out, - attn_backend.name) - - # DECODE: encoder/decoder cross-attention test - - decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( - enc_dec_test_rsrcs, - decphase_dec_test_params, - None, - decphase_attn_metadata, - test_pt=enc_dec_test_pt, - vllm_config=vllm_config) - - # - Is decode-phase encoder/decoder cross-attention correct? 
- assert_actual_matches_ideal(decphase_cross_test_params, - decphase_cross_pckd_act_out, - attn_backend.name) diff --git a/tests/kernels/attention/test_flashinfer_mla_decode.py b/tests/kernels/attention/test_flashinfer_mla_decode.py new file mode 100644 index 0000000000000..02225432f77fc --- /dev/null +++ b/tests/kernels/attention/test_flashinfer_mla_decode.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +import torch.nn.functional as F +from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla +from torch import Tensor + +from vllm.platforms import current_platform + +FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024 + +if not current_platform.has_device_capability(100): + pytest.skip( + reason="FlashInfer MLA Requires compute capability of 10 or above.", + allow_module_level=True) + + +def ref_mla( + out: Tensor, # (bs, num_heads, v_head_dim) + query: Tensor, # (bs, num_heads, head_dim) + kv_cache: Tensor, # (num_blocks, block_size, head_dim) + scale: float, + block_tables: Tensor, # (bs, max_num_blocks) + seq_lens: Tensor, # (bs,) +): + bs, num_heads, v_head_dim = out.shape + head_dim = query.shape[2] + + for i in range(bs): + # gather and flatten KV-cache + kv = kv_cache[ + block_tables[i]] # (max_num_blocks, block_size, head_dim) + kv = kv.view(1, -1, + head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim) + v = kv[:, :, :v_head_dim] + + q = query[i].view(num_heads, 1, head_dim) + o = F.scaled_dot_product_attention(q, + kv, + v, + scale=scale, + enable_gqa=True) + out[i] = o.view(num_heads, v_head_dim) + + return out + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("bs", [1, 2, 4, 16]) +@pytest.mark.parametrize("block_size", [32, 64]) +def test_flashinfer_mla_decode(dtype: torch.dtype, bs: int, block_size: int): + torch.set_default_device('cuda') + torch.manual_seed(42) + + # Deepseek R1 config + num_heads = 128 + kv_lora_rank = 512 + qk_nope_head_dim = 128 + qk_rope_head_dim = 64 + qk_head_dim = kv_lora_rank + qk_rope_head_dim + scale = (qk_nope_head_dim + qk_rope_head_dim)**-0.5 + + MAX_SEQ_LEN = 1024 + + seq_lens = [torch.randint(2, MAX_SEQ_LEN, (1, )).item() for _ in range(bs)] + seq_lens[-1] = MAX_SEQ_LEN + max_seq_len = max(seq_lens) + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32) + + # Generate block tables with random but unique block IDs + # From https://github.com/flashinfer-ai/flashinfer/pull/1222 + blocks_per_seq = (seq_lens_tensor + block_size - 1) // block_size + max_num_blocks_per_seq = max(blocks_per_seq.max().item(), 4) + total_blocks_needed = sum(blocks_per_seq) + # Get random unique IDs for all blocks + all_block_ids = torch.randperm(total_blocks_needed) + + block_id = 0 + block_tables = torch.zeros( + (bs, max_num_blocks_per_seq), + dtype=torch.int32, + ) + + # Populate block tables and track block assignments + block_id = 0 + for i in range(bs): + num_blocks_needed = blocks_per_seq[i] + block_tables[i, :num_blocks_needed] = all_block_ids[block_id:block_id + + num_blocks_needed] + block_id += num_blocks_needed + + kv_cache = torch.randn(block_tables.numel(), block_size, + qk_head_dim).to(dtype) + q = torch.randn(bs, num_heads, qk_head_dim).to(dtype) + + out_ref = q.new_zeros(bs, num_heads, kv_lora_rank) + ref_mla(out_ref, q, kv_cache, scale, block_tables, seq_lens_tensor) + + workspace_buffer = torch.zeros( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=q.device, + ) + # 
Flashinfer MLA expects the query to be of shape + # (bs, q_len_per_request, num_heads, qk_head_dim), + # where q_len_per_request is the MTP query length (=1 without MTP) + q = q.unsqueeze(1) + + out_ans = trtllm_batch_decode_with_kv_cache_mla( + query=q, + kv_cache=kv_cache.unsqueeze(1), + workspace_buffer=workspace_buffer, + qk_nope_head_dim=qk_nope_head_dim, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + block_tables=block_tables, + seq_lens=seq_lens_tensor, + max_seq_len=max_seq_len, + bmm1_scale=scale, + ) + out_ans = out_ans.squeeze(1) + torch.testing.assert_close(out_ans, out_ref, atol=1e-2, rtol=1e-2) diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py index 8d0a11d8eb8ab..bd3ba554b32e2 100644 --- a/tests/kernels/attention/test_flashinfer_trtllm_attention.py +++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py @@ -35,6 +35,7 @@ QUANT_DTYPES = [ # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) (None, None, None), (None, FP8_DTYPE, None), + (FP8_DTYPE, FP8_DTYPE, None), (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), ] @@ -44,6 +45,7 @@ NUM_HEADS = [(64, 8), (40, 8)] HEAD_SIZE = [128] KV_LAYOUT = ["HND"] # currently only HND is supported BLOCK_SIZE = [16] +WINDOW_LEFT = [-1, 127] SOFT_CAP = [None, 50.0] NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. @@ -57,6 +59,7 @@ NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. @pytest.mark.parametrize("head_size", HEAD_SIZE) @pytest.mark.parametrize("kv_layout", KV_LAYOUT) @pytest.mark.parametrize("block_size", BLOCK_SIZE) +@pytest.mark.parametrize("window_left", WINDOW_LEFT) @pytest.mark.parametrize("soft_cap", SOFT_CAP) @torch.inference_mode def test_flashinfer_trtllm_decode_with_baseline( @@ -69,6 +72,7 @@ def test_flashinfer_trtllm_decode_with_baseline( head_size: int, kv_layout: str, block_size: int, + window_left: int, soft_cap: Optional[float], ) -> None: torch.set_default_device("cuda") @@ -155,6 +159,7 @@ def test_flashinfer_trtllm_decode_with_baseline( sm_scale=sm_scale, q_data_type=dtype, kv_data_type=dtype, + window_left=window_left, logits_soft_cap=soft_cap) output = torch.empty(ref_query.shape, dtype=dtype) @@ -188,6 +193,7 @@ def test_flashinfer_trtllm_decode_with_baseline( max_seq_len=max_seq_len, bmm1_scale=q_scale * k_scale * sm_scale, bmm2_scale=v_scale / o_scale, + window_left=window_left, o_sf_scale=o_sf_scale, out=output_trtllm, ) @@ -222,6 +228,7 @@ def test_flashinfer_trtllm_decode_with_baseline( @pytest.mark.parametrize("head_size", HEAD_SIZE) @pytest.mark.parametrize("kv_layout", KV_LAYOUT) @pytest.mark.parametrize("block_size", BLOCK_SIZE) +@pytest.mark.parametrize("window_left", WINDOW_LEFT) @pytest.mark.parametrize("soft_cap", [None]) @torch.inference_mode def test_flashinfer_trtllm_prefill_with_baseline( @@ -234,6 +241,7 @@ def test_flashinfer_trtllm_prefill_with_baseline( head_size: int, kv_layout: str, block_size: int, + window_left: int, soft_cap: Optional[float], ) -> None: torch.set_default_device("cuda") @@ -334,6 +342,7 @@ def test_flashinfer_trtllm_prefill_with_baseline( sm_scale=sm_scale, q_data_type=dtype, kv_data_type=dtype, + window_left=window_left, logits_soft_cap=soft_cap) output = torch.empty(ref_query.shape, dtype=dtype) @@ -371,6 +380,7 @@ def test_flashinfer_trtllm_prefill_with_baseline( batch_size=batch_size, cum_seq_lens_q=q_indptr, cum_seq_lens_kv=kv_indptr, + window_left=window_left, o_sf_scale=o_sf_scale, 
out=output_trtllm, ) @@ -390,6 +400,8 @@ def test_flashinfer_trtllm_prefill_with_baseline( rtol, atol = 4e-1, 1e0 elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: rtol, atol = 5e-2, 7e-2 + elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype: + rtol, atol = 4e-2, 6e-2 else: rtol, atol = 1e-2, 1e-2 diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py index 53c37554b15a3..d37b968ed9792 100644 --- a/tests/kernels/attention/test_mha_attn.py +++ b/tests/kernels/attention/test_mha_attn.py @@ -23,6 +23,9 @@ def clear_cache(): """Clear lru cache to ensure each test case runs without caching. """ _cached_get_attn_backend.cache_clear() + # Clear xformers availability cache + import vllm.attention.layer as layer_module + layer_module.USE_XFORMERS_OPS = None @pytest.mark.parametrize("device", ["cpu", "hip", "cuda"]) @@ -33,22 +36,52 @@ def test_mha_attn_platform(device: str): torch.set_default_dtype(torch.float16) if device == "cpu": - with patch("vllm.attention.selector.current_platform", CpuPlatform()): + with patch("vllm.attention.layer.current_platform", CpuPlatform()), \ + patch("vllm.model_executor.models.vision.current_platform", + CpuPlatform()): attn = MultiHeadAttention(16, 64, scale=1) assert attn.attn_backend == _Backend.TORCH_SDPA elif device == "hip": - with patch("vllm.attention.selector.current_platform", RocmPlatform()): + with patch("vllm.attention.layer.current_platform", RocmPlatform()), \ + patch("vllm.model_executor.models.vision.current_platform", + RocmPlatform()): attn = MultiHeadAttention(16, 64, scale=1) assert attn.attn_backend == _Backend.TORCH_SDPA else: - with patch("vllm.attention.selector.current_platform", CudaPlatform()): + # Test CUDA with head_size=64 (divisible by 32) + # - should use vLLM's FlashAttention + with patch("vllm.attention.layer.current_platform", CudaPlatform()), \ + patch("vllm.model_executor.models.vision.current_platform", + CudaPlatform()): attn = MultiHeadAttention(16, 64, scale=1) - assert attn.attn_backend == _Backend.XFORMERS + assert attn.attn_backend == _Backend.FLASH_ATTN - with patch("vllm.attention.selector.current_platform", CudaPlatform()): + # Test CUDA with head_size=72 (not divisible by 32) + # - with upstream FA not available + # - should use xformers + with patch("vllm.attention.layer.current_platform", CudaPlatform()), \ + patch("vllm.model_executor.models.vision.current_platform", + CudaPlatform()), \ + patch("vllm.attention.layer.check_upstream_fa_availability", + return_value=False): attn = MultiHeadAttention(16, 72, scale=1) assert attn.attn_backend == _Backend.XFORMERS + # Test CUDA with head_size=72 (not divisible by 32) + # - with upstream FA available + # - should use upstream FA + with patch("vllm.attention.layer.current_platform", CudaPlatform()), \ + patch("vllm.model_executor.models.vision.current_platform", + CudaPlatform()), \ + patch("vllm.attention.layer.check_upstream_fa_availability", + return_value=True), \ + patch.dict('sys.modules', {'flash_attn': type('MockFlashAttn', (), + { + 'flash_attn_varlen_func': lambda *args, **kwargs: None + })()}): + attn = MultiHeadAttention(16, 72, scale=1) + assert attn.attn_backend == _Backend.FLASH_ATTN + def ref_attention( query: torch.Tensor, diff --git a/tests/kernels/attention/test_triton_unified_attention.py b/tests/kernels/attention/test_triton_unified_attention.py index 4b97d51e6ed21..ab91560e995c8 100644 --- a/tests/kernels/attention/test_triton_unified_attention.py +++ 
b/tests/kernels/attention/test_triton_unified_attention.py @@ -102,9 +102,6 @@ def test_triton_unified_attn( ) -> None: torch.set_default_device("cuda") - if q_dtype is not None and q_dtype.itemsize < 2 and block_size < 32: - pytest.skip("block size must be at least 32 for fp8") - current_platform.seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] diff --git a/tests/kernels/core/test_layernorm.py b/tests/kernels/core/test_layernorm.py index 02316ceaac735..53e6d793cf2f9 100644 --- a/tests/kernels/core/test_layernorm.py +++ b/tests/kernels/core/test_layernorm.py @@ -6,7 +6,7 @@ import torch from tests.kernels.quant_utils import FP8_DTYPE from tests.kernels.utils import opcheck -from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.layernorm import PolyNorm, RMSNorm from vllm.platforms import current_platform DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -70,6 +70,37 @@ def test_rms_norm( (out, x, layer.weight.data, layer.variance_epsilon)) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_poly_norm( + num_tokens: int, + hidden_size: int, + dtype: torch.dtype, + seed: int, + device: str, +) -> None: + current_platform.seed_everything(seed) + torch.set_default_device(device) + layer = PolyNorm().to(dtype=dtype) + layer.weight.data.normal_(mean=1.0, std=0.1) + layer.bias.data.normal_(mean=1.0, std=0.1) + scale = 1 / (2 * hidden_size) + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + x *= scale + + ref_out = layer.forward_native(x) + out = layer(x) + torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) + + opcheck( + torch.ops._C.poly_norm, + (out, x, layer.weight.data, layer.bias.data, layer.variance_epsilon)) + + @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("add_residual", ADD_RESIDUAL) diff --git a/tests/kernels/core/test_pos_encoding.py b/tests/kernels/core/test_pos_encoding.py index ab6f1ccf881fd..bf9b1d9b4401a 100644 --- a/tests/kernels/core/test_pos_encoding.py +++ b/tests/kernels/core/test_pos_encoding.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from itertools import accumulate, product +from itertools import product from typing import Callable, Optional import pytest @@ -111,151 +111,6 @@ def test_rotary_embedding( "expected returned key to be None" -@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) -@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("seq_len", SEQ_LENS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.parametrize("use_key", USE_KEY) -@torch.inference_mode() -def test_batched_rotary_embedding( - is_neox_style: bool, - tensor_shape_fn: Callable[[int, int, int, int], tuple[int]], - batch_size: int, - seq_len: int, - num_heads: int, - head_size: int, - rotary_dim: Optional[int], - dtype: torch.dtype, - seed: int, - device: str, - use_key: bool, - 
max_position: int = 8192, - base: float = 10000, -) -> None: - current_platform.seed_everything(seed) - torch.set_default_device(device) - if rotary_dim is None: - rotary_dim = head_size - rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { - "rope_type": "linear", - "factor": (1, ) - }) - rope = rope.to(dtype=dtype, device=torch.get_default_device()) - - positions = torch.randint(0, max_position, (batch_size, seq_len)) - query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size) - query = torch.randn(query_shape, dtype=dtype) - key = torch.randn_like(query) if use_key else None - - # slice tensor if required, noop otherwise - query = query[..., :head_size] - key = key[..., :head_size] if use_key else None - - # NOTE(woosuk): The reference implementation should be executed first - # because the custom kernel is in-place. - ref_query, ref_key = rope.forward_native(positions, query, key) - out_query, out_key = rope.forward(positions, - query, - key, - offsets=torch.zeros(batch_size * seq_len, - dtype=torch.long, - device=device)) - # Compare the results. - torch.testing.assert_close(out_query, - ref_query, - atol=get_default_atol(out_query), - rtol=get_default_rtol(out_query)) - if use_key: - torch.testing.assert_close(out_key, - ref_key, - atol=get_default_atol(out_key), - rtol=get_default_rtol(out_key)) - else: - assert ref_key is None and out_key is None, \ - "expected returned key to be None" - - -@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("seq_len", SEQ_LENS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.parametrize("use_key", USE_KEY) -@torch.inference_mode() -def test_batched_rotary_embedding_multi_lora( - is_neox_style: bool, - batch_size: int, - seq_len: int, - num_heads: int, - head_size: int, - rotary_dim: Optional[int], - dtype: torch.dtype, - seed: int, - device: str, - use_key: bool, - max_position: int = 8192, - base: float = 10000, -) -> None: - current_platform.seed_everything(seed) - torch.set_default_device(device) - if rotary_dim is None: - rotary_dim = head_size - scaling_factors: list[int] = [1, 2, 4] - rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { - "rope_type": "linear", - "factor": tuple(scaling_factors) - }) - rope = rope.to(dtype=dtype, device=torch.get_default_device()) - - positions = torch.randint(0, max_position, (batch_size, seq_len)) - query = torch.randn(batch_size, - seq_len, - num_heads * head_size, - dtype=dtype) - key = torch.randn_like(query) if use_key else None - - offset_map = torch.tensor( - list( - accumulate([0] + [ - max_position * scaling_factor * 2 - for scaling_factor in scaling_factors[:-1] - ]))) - query_types = torch.randint(0, - len(scaling_factors), (batch_size, seq_len), - device=device) - query_offsets = offset_map[query_types] - - # NOTE(woosuk): The reference implementation should be executed first - # because the custom kernel is in-place. - ref_query, ref_key = rope.forward_native(positions, query, key, - query_offsets) - out_query, out_key = rope.forward(positions, query, key, - query_offsets.flatten()) - # Compare the results. 
- torch.testing.assert_close(out_query, - ref_query, - atol=get_default_atol(out_query), - rtol=get_default_rtol(out_query)) - if use_key: - torch.testing.assert_close(out_key, - ref_key, - atol=get_default_atol(out_key), - rtol=get_default_rtol(out_key)) - else: - assert ref_key is None and out_key is None, \ - "expected returned key to be None" - - @torch.inference_mode() def test_rope_module_cache(): MAX_POSITIONS = [123, 1234] diff --git a/tests/kernels/core/test_rotary_embedding.py b/tests/kernels/core/test_rotary_embedding.py index d1fd960bf115c..5857dd5ba3fad 100644 --- a/tests/kernels/core/test_rotary_embedding.py +++ b/tests/kernels/core/test_rotary_embedding.py @@ -16,20 +16,14 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding def rotary_embedding_opcheck(rot, positions: torch.Tensor, query: torch.Tensor, - key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None): + key: Optional[torch.Tensor] = None): cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype) - # ops.rotary_embedding()/batched_rotary_embedding() - # are in-place operations that update the query and key tensors. - if offsets is not None: - opcheck(torch.ops._C.batched_rotary_embedding, - (positions, query, key, rot.head_size, cos_sin_cache, - rot.is_neox_style, rot.rotary_dim, offsets)) - else: - opcheck(torch.ops._C.rotary_embedding, - (positions, query, key, rot.head_size, cos_sin_cache, - rot.is_neox_style)) + # ops.rotary_embedding() is an in-place operation + # that updates the query and key tensors. + opcheck(torch.ops._C.rotary_embedding, + (positions, query, key, rot.head_size, cos_sin_cache, + rot.is_neox_style)) @pytest.mark.parametrize("device", ["cuda"]) @@ -65,10 +59,6 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position, key = key[..., :head_size] if use_key else None rotary_embedding_opcheck(rot, positions, query, key) - offsets = torch.zeros(batch_size * seq_len, - device=device, - dtype=torch.long) - rotary_embedding_opcheck(rot, positions, query, key, offsets) # if we have a contiguous head stride, test the alternate # [..., num_heads * head_dim] shape/layout diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py index 2c554baaff76c..fc60d5ac82b27 100644 --- a/tests/kernels/mamba/test_mamba_ssm_ssd.py +++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py @@ -115,21 +115,27 @@ def generate_continuous_batched_examples(example_lens_by_batch, n_heads, d_head, itype, - device='cuda'): + device='cuda', + return_naive_ref=True): # this function generates a random examples of certain length # and then cut according to "example_lens_by_batch" and feed - # them in continuous batches to the kernels + # them in continuous batches to the kernels. + # If return_naive_ref=True, the naive torch implementation + # ssd_minimal_discrete will be used to compute and return + # reference output.
# generate the full-length example A, dt, X, B, C = generate_random_inputs(num_examples, full_length, n_heads, d_head, itype) - Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), - A * dt, - B, - C, - block_len=full_length // 4) + if return_naive_ref: + Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), + A * dt, + B, + C, + block_len=full_length // + 4) # internal function that outputs a cont batch of examples # given a tuple of lengths for each example in the batch @@ -179,7 +185,8 @@ def generate_continuous_batched_examples(example_lens_by_batch, IND_S = [x % full_length for x in IND_E] IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)] - yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)], + yield ([Y_min[s, IND_S[s]:IND_E[s]] + for s in range(num_examples)] if return_naive_ref else None, cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2)) @@ -324,3 +331,213 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, if clear: states[i].fill_(0.) exhausted[i] = False + + +@pytest.mark.parametrize("chunk_size", [8, 256]) +@pytest.mark.parametrize("seqlens", [ + (16, 2, 8, 13), + (270, 88, 212, 203), + (16, 20), +]) +def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): + + # This test verifies the correctness of the chunked prefill implementation + # in the mamba2 ssd kernels, by comparing concatenation (in the sequence + # dimension) of chunked results with the full sequence result. + # It is different from test_mamba_chunk_scan_cont_batch by: + # 1. Not using the naive torch implementation (ssd_minimal_discrete) to get + # reference outputs. Instead, it compares chunked kernel outputs to full + # sequence kernel outputs. This is the most straightforward way to + # assert chunked prefill correctness. + # 2. It focuses on cases where sequences change in the middle of mamba + # chunks, and not necessarily on chunk boundaries. 
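As an aside (illustrative only, not part of the patch), the chunked-prefill property this test checks can be demonstrated with a toy diagonal linear recurrence in plain PyTorch: scanning a sequence in two chunks, while feeding the first chunk's final state in as the initial state of the second, must reproduce the full-sequence scan. The real test below does the same thing with mamba_chunk_scan_combined and initial_states.

# Toy sketch; a simple diagonal recurrence stands in for the mamba2 SSD kernel.
import torch

def toy_scan(x, a, state):
    # s_t = a * s_{t-1} + x_t; returns per-step outputs and the final state
    outs = []
    for t in range(x.shape[0]):
        state = a * state + x[t]
        outs.append(state)
    return torch.stack(outs), state

torch.manual_seed(0)
x, a = torch.randn(16, 4), torch.rand(4)
full_out, full_state = toy_scan(x, a, torch.zeros(4))

# Process the same sequence as two prefill chunks, carrying the state across.
part1, mid_state = toy_scan(x[:7], a, torch.zeros(4))
part2, end_state = toy_scan(x[7:], a, mid_state)

torch.testing.assert_close(torch.cat([part1, part2]), full_out)
torch.testing.assert_close(end_state, full_state)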
+ + max_seqlen = max(seqlens) + # This test can have larger error for longer sequences + if max_seqlen > 256: + atol, rtol = 1e-2, 5e-3 + else: + atol, rtol = 5e-3, 5e-3 + + num_sequences = len(seqlens) + n_heads = 16 + d_head = 64 + itype = torch.float32 + + # hold state during the cutting process so we know if an + # example has been exhausted and needs to cycle + last_taken: dict = {} # map: eg -> pointer to last taken sample + exhausted: dict = {} # map: eg -> boolean indicating example is exhausted + _, cu_seqlens, seq_idx, (A, dt, X, B, C) = next( + generate_continuous_batched_examples([seqlens], + num_sequences, + max_seqlen, + last_taken, + exhausted, + n_heads, + d_head, + itype, + return_naive_ref=False)) + seqlens = torch.tensor(seqlens, dtype=torch.int32, device=X.device) + device = X.device + + ## full seqlen computation + chunk_indices, chunk_offsets = \ + _query_start_loc_to_chunk_indices_offsets( + cu_seqlens, chunk_size, cu_seqlens[-1]) + Y_ref = torch.empty_like(X) + state_ref = mamba_chunk_scan_combined( + X, + dt, + A, + B, + C, + chunk_size, + D=None, + cu_seqlens=cu_seqlens, + seq_idx=seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + return_varlen_states=True, + initial_states=None, + out=Y_ref, + ) + + ## chunked seqlen computation + # first chunk + chunked_seqlens = seqlens // 2 + chunked_cu_seqlens = torch.cat([ + torch.tensor([0], device=device), + torch.cumsum(chunked_seqlens, dim=0) + ], + dim=0) + chunked_seq_idx = torch.repeat_interleave( + torch.arange(len(chunked_seqlens), device=device), + chunked_seqlens, + output_size=chunked_cu_seqlens[-1]).unsqueeze(0).to(torch.int32) + chunked_input_seq_len = chunked_cu_seqlens[-1] + X_chunked = torch.zeros_like(X)[:, :chunked_input_seq_len, ...] + dt_chunked = torch.zeros_like(dt)[:, :chunked_input_seq_len, ...] + B_chunked = torch.zeros_like(B)[:, :chunked_input_seq_len, ...] + C_chunked = torch.zeros_like(C)[:, :chunked_input_seq_len, ...] + for i in range(num_sequences): + # fmt: off + chunk_f = lambda x, i: x[:, cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] # noqa: E501 + + X_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i) # noqa: E501 + dt_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i) # noqa: E501 + B_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i) # noqa: E501 + C_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] 
= chunk_f(C, i) # noqa: E501 + # fmt: on + + chunk_indices, chunk_offsets = \ + _query_start_loc_to_chunk_indices_offsets( + chunked_cu_seqlens, chunk_size, chunked_cu_seqlens[-1]) + Y_partial = torch.empty_like(X_chunked) + partial_state = mamba_chunk_scan_combined( + X_chunked, + dt_chunked, + A, + B_chunked, + C_chunked, + chunk_size, + D=None, + cu_seqlens=chunked_cu_seqlens, + seq_idx=chunked_seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + return_varlen_states=True, + initial_states=None, + out=Y_partial, + ) + + # remaining chunk + remaining_chunked_seqlens = seqlens - chunked_seqlens + remaining_chunked_cu_seqlens = torch.cat([ + torch.tensor([0], device=device), + torch.cumsum(remaining_chunked_seqlens, dim=0) + ], + dim=0) + remaining_chunked_seq_idx = torch.repeat_interleave( + torch.arange(len(remaining_chunked_seqlens), device=device), + remaining_chunked_seqlens, + output_size=remaining_chunked_cu_seqlens[-1]).unsqueeze(0).to( + torch.int32) + remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1] + # fmt: off + remaining_X_chunked = torch.zeros_like(X)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 + remaining_dt_chunked = torch.zeros_like(dt)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 + remaining_B_chunked = torch.zeros_like(B)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 + remaining_C_chunked = torch.zeros_like(C)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 + for i in range(num_sequences): + remaining_chunk_f = lambda x, i: x[:, cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] # noqa: E501 + + remaining_X_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i) # noqa: E501 + remaining_dt_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i) # noqa: E501 + remaining_B_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i) # noqa: E501 + remaining_C_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] 
= remaining_chunk_f(C, i) # noqa: E501 + + # assert input chunking is correct + concat_chunk_f = lambda pt1, pt2, i: torch.cat([ + pt1[:,chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...], + pt2[:,remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...], + ], + dim=1) + concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=1) # noqa: E501 + # fmt: on + + assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X) + assert concat_batch_f(dt_chunked, remaining_dt_chunked).equal(dt) + assert concat_batch_f(B_chunked, remaining_B_chunked).equal(B) + assert concat_batch_f(C_chunked, remaining_C_chunked).equal(C) + + chunk_indices, chunk_offsets = \ + _query_start_loc_to_chunk_indices_offsets( + remaining_chunked_cu_seqlens, + chunk_size, + remaining_chunked_cu_seqlens[-1]) + + Y_chunked = torch.empty_like(remaining_X_chunked) + state_chunked = mamba_chunk_scan_combined( + remaining_X_chunked, + remaining_dt_chunked, + A, + remaining_B_chunked, + remaining_C_chunked, + chunk_size, + D=None, + cu_seqlens=remaining_chunked_cu_seqlens, + seq_idx=remaining_chunked_seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + return_varlen_states=True, + initial_states=partial_state, + out=Y_chunked, + ) + Y = concat_batch_f(Y_partial, Y_chunked) + + # kernel chunked is same as kernel overall + for i in range(num_sequences): + Y_seq = Y[:, cu_seqlens[i]:cu_seqlens[i + 1], ...] + Y_ref_seq = Y_ref[:, cu_seqlens[i]:cu_seqlens[i + 1], ...] + torch.testing.assert_close( + Y_seq[:, :chunked_seqlens[i], ...], + Y_ref_seq[:, :chunked_seqlens[i], ...], + atol=atol, + rtol=rtol, + msg=lambda x: f"seq{i} output part1 " + x) # noqa: B023 + torch.testing.assert_close( + Y_seq[:, chunked_seqlens[i]:, ...], + Y_ref_seq[:, chunked_seqlens[i]:, ...], + atol=atol, + rtol=rtol, + msg=lambda x: f"seq{i} output part2 " + x) # noqa: B023 + + state_seq = state_chunked[i] + state_seq_ref = state_ref[i] + torch.testing.assert_close( + state_seq, + state_seq_ref, + atol=atol, + rtol=rtol, + msg=lambda x: f"seq{i} state " + x) # noqa: B023 diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py index a10666b6ec9a7..b5fcc4cd70bf8 100644 --- a/tests/kernels/moe/modular_kernel_tools/common.py +++ b/tests/kernels/moe/modular_kernel_tools/common.py @@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx -from .mk_objects import (expert_info, make_fused_experts, +from .mk_objects import (TestMoEQuantConfig, expert_info, make_fused_experts, make_prepare_finalize, prepare_finalize_info) from .parallel_utils import ProcessGroupInfo @@ -40,7 +40,7 @@ class Config: E: int topks: Union[list[int], int] dtype: torch.dtype - quant_config: Optional[FusedMoEQuantConfig] + quant_config: Optional[TestMoEQuantConfig] prepare_finalize_type: mk.FusedMoEPrepareAndFinalize fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute @@ -52,7 +52,7 @@ class Config: def __post_init__(self): if self.quant_config is None: - self.quant_config = FusedMoEQuantConfig() + self.quant_config = TestMoEQuantConfig(None, False, False, None) def describe(self) -> str: s = "" @@ -275,21 +275,19 @@ class WeightTensors: or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8) def to_current_device(self): - self.w1 = self.w1.to(device=torch.cuda.current_device()) - self.w2 = 
self.w2.to(device=torch.cuda.current_device()) + device = torch.cuda.current_device() + self.w1 = self.w1.to(device=device) + self.w2 = self.w2.to(device=device) - if self.is_quantized(): - assert self.w1_scale is not None - assert self.w2_scale is not None - self.w1_scale = self.w1_scale.to( - device=torch.cuda.current_device()) - self.w2_scale = self.w2_scale.to( - device=torch.cuda.current_device()) + if self.w1_scale is not None: + self.w1_scale = self.w1_scale.to(device=device) + if self.w2_scale is not None: + self.w2_scale = self.w2_scale.to(device=device) if self.w1_gs is not None: - assert self.w2_gs is not None - self.w1_gs = self.w1_gs.to(device=torch.cuda.current_device()) - self.w2_gs = self.w2_gs.to(device=torch.cuda.current_device()) + self.w1_gs = self.w1_gs.to(device=device) + if self.w2_gs is not None: + self.w2_gs = self.w2_gs.to(device=device) def slice_weights(self, rank: int, num_local_experts: int) -> "WeightTensors": @@ -297,20 +295,12 @@ class WeightTensors: e = s + num_local_experts w1 = self.w1[s:e, :, :] w2 = self.w2[s:e, :, :] - - w1_scale, w2_scale = (None, None) - if self.is_quantized(): - assert self.w1_scale is not None - assert self.w2_scale is not None - w1_scale = self.w1_scale[s:e, :, :] - w2_scale = self.w2_scale[s:e, :, :] - - w1_gs = self.w1_gs - w2_gs = self.w2_gs - if w1_gs is not None: - assert w2_gs is not None - w1_gs = w1_gs[s:e] - w2_gs = w2_gs[s:e] + w1_scale = self.w1_scale[ + s:e, :, :] if self.w1_scale is not None else None + w2_scale = self.w2_scale[ + s:e, :, :] if self.w2_scale is not None else None + w1_gs = self.w1_gs[s:e] if self.w1_gs is not None else None + w2_gs = self.w2_gs[s:e] if self.w2_gs is not None else None return WeightTensors(w1, w2, w1_scale, w2_scale, w1_gs, w2_gs) @@ -323,7 +313,8 @@ class WeightTensors: in_dtype=config.dtype, quant_dtype=config.quant_dtype, block_shape=config.quant_block_shape, - per_act_token_quant=config.is_per_out_ch_quant, + per_out_ch_quant=config. 
+ is_per_act_token_quant, # or config.is_per_out_ch_quant ) return WeightTensors(w1=w1, w2=w2, @@ -342,8 +333,6 @@ class RankTensors: topk_ids: torch.Tensor expert_map: Optional[torch.Tensor] - quant_config: Optional[FusedMoEQuantConfig] - def describe(self): s = "" s += "== Rank Tensors: \n" @@ -426,7 +415,6 @@ class RankTensors: topk_weights=topk_weights, topk_ids=topk_ids, expert_map=expert_map, - quant_config=config.quant_config, ) @@ -522,10 +510,16 @@ def reference_moe_impl(config: Config, weights: WeightTensors, and config.supports_apply_weight_on_input()) +def _make_gscale(num_experts: int) -> torch.Tensor: + return torch.ones((num_experts, ), + device=torch.cuda.current_device(), + dtype=torch.float32) + + def make_modular_kernel( config: Config, vllm_config: VllmConfig, - weights: WeightTensors, + quant_config: FusedMoEQuantConfig, ) -> mk.FusedMoEModularKernel: def next_power_of_2(x): @@ -548,20 +542,20 @@ def make_modular_kernel( num_local_experts=config.num_local_experts, moe_parallel_config=moe_parallel_config, in_dtype=config.dtype, - quant_config=config.quant_config, max_num_tokens=next_power_of_2(config.M), ) # make modular kernel prepare_finalize = make_prepare_finalize(config.prepare_finalize_type, - config.all2all_backend(), moe) + config.all2all_backend(), moe, + quant_config) fused_experts = make_fused_experts( config.fused_experts_type, moe, + quant_config, prepare_finalize.num_dispatchers(), - weights.w1_gs, - weights.w2_gs, + config.N, ) modular_kernel = mk.FusedMoEModularKernel( @@ -583,12 +577,38 @@ def run_modular_kernel( # weights for rank rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts) - mk = make_modular_kernel(config, vllm_config, weights) + if config.quant_dtype == "nvfp4": + gscale = _make_gscale(config.num_local_experts) + else: + gscale = None + + quant_config = FusedMoEQuantConfig.make( + config.quant_dtype, + w1_scale=rank_weights.w1_scale, + w2_scale=rank_weights.w2_scale, + a1_scale=rank_tensors.hidden_states_scale, + g1_alphas=(1 / rank_weights.w1_gs) + if rank_weights.w1_gs is not None else None, + g2_alphas=(1 / rank_weights.w2_gs) + if rank_weights.w2_gs is not None else None, + a1_gscale=gscale, + a2_gscale=gscale, + block_shape=config.quant_block_shape, + per_act_token_quant=config.is_per_act_token_quant, + per_out_ch_quant=config.is_per_out_ch_quant, + ) + + mk = make_modular_kernel(config, vllm_config, quant_config) + + # impls might update the tensor in place + hidden_states = rank_tensors.hidden_states.clone() + + topk_ids = rank_tensors.topk_ids.to( + mk.prepare_finalize.topk_indices_dtype()) mk_kwargs = { "hidden_states": - rank_tensors.hidden_states.clone( - ), # impls might update the tensor in place + hidden_states, "w1": rank_weights.w1, "w2": @@ -596,15 +616,9 @@ def run_modular_kernel( "topk_weights": rank_tensors.topk_weights, "topk_ids": - rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype()), + topk_ids, "expert_map": rank_tensors.expert_map, - "w1_scale": - rank_weights.w1_scale, - "w2_scale": - rank_weights.w2_scale, - "a1_scale": - rank_tensors.hidden_states_scale, "global_num_experts": config.E, "apply_router_weight_on_input": diff --git a/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py b/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py index 5dbfdfc153f9f..c1037b60bf383 100644 --- a/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py +++ b/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py @@ -10,7 +10,8 @@ import torch from tqdm import 
tqdm from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FUSED_MOE_UNQUANTIZED_CONFIG) from vllm.platforms import current_platform from .common import (Config, RankTensors, WeightTensors, reference_moe_impl, @@ -86,7 +87,7 @@ def make_feature_matrix(csv_file_path: str): quant_config_dict = config_dict['quant_config'] del config_dict['quant_config'] if quant_config_dict is None: - quant_config = FusedMoEQuantConfig(None) + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG quant_config_dict = asdict(quant_config) config_dict |= quant_config_dict diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index aecffae36ae5e..7947391d03483 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -32,6 +32,14 @@ from vllm.utils.deep_gemm import is_deep_gemm_supported from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe +@dataclass +class TestMoEQuantConfig: + quant_dtype: Union[torch.dtype, str, None] + per_out_ch_quant: bool + per_act_token_quant: bool + block_shape: Optional[list[int]] + + @dataclass class PrepareFinalizeInfo: activation_format: mk.FusedMoEActivationFormat @@ -66,7 +74,7 @@ common_float_types: list[Union[torch.dtype, str]] = [ torch.float8_e4m3fn, torch.bfloat16, torch.float16, torch.float32 ] common_float_and_int_types = common_float_types + [torch.int8] -nv_fp4_types = ["nvfp4"] +nvfp4_types = ["nvfp4"] fp8_types = [torch.float8_e4m3fn] @@ -219,7 +227,7 @@ if (has_flashinfer_cutlass_fused_moe() register_prepare_and_finalize( FlashInferCutlassMoEPrepareAndFinalize, standard_format, - nv_fp4_types, + nvfp4_types, blocked_quantization_support=True, backend=None, force_multigpu=True, @@ -229,7 +237,7 @@ if (has_flashinfer_cutlass_fused_moe() register_experts( FlashInferExperts, standard_format, - nv_fp4_types, + nvfp4_types, blocked_quantization_support=True, supports_chunking=True, # Note: this is a hack to get it to run for now @@ -306,39 +314,39 @@ if cutlass_fp4_supported(): register_experts( CutlassExpertsFp4, standard_format, - nv_fp4_types, + nvfp4_types, blocked_quantization_support=True, supports_chunking=True, supports_expert_map=False, ) -MK_QUANT_CONFIGS = [ +MK_QUANT_CONFIGS: list[Optional[TestMoEQuantConfig]] = [ None, # per-channel / per-column weights and per-tensor activations - FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, - per_out_ch_quant=True, - per_act_token_quant=False, - block_shape=None), + TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=True, + per_act_token_quant=False, + block_shape=None), # per-channel / per-column weights and per-token activations - FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, - per_out_ch_quant=True, - per_act_token_quant=True, - block_shape=None), + TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=True, + per_act_token_quant=True, + block_shape=None), # per-tensor weights and per-tensor activations - FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, - per_out_ch_quant=False, - per_act_token_quant=False, - block_shape=None), + TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=False, + per_act_token_quant=False, + block_shape=None), # per-tensor weights and per-token activations - FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, - per_out_ch_quant=False, - 
per_act_token_quant=True, - block_shape=None), + TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=False, + per_act_token_quant=True, + block_shape=None), # block-quantized weights and 128 block per-token activations - FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, - per_out_ch_quant=False, - per_act_token_quant=False, - block_shape=[128, 128]), + TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=False, + per_act_token_quant=False, + block_shape=[128, 128]), # TODO (varun) : Should we test the following combinations ? # block-quantized weights and per-token activations # block-quantized weights and per-tensor activations @@ -346,33 +354,27 @@ MK_QUANT_CONFIGS = [ if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe(): MK_QUANT_CONFIGS += [ - FusedMoEQuantConfig(quant_dtype="nvfp4", - per_out_ch_quant=False, - per_act_token_quant=False, - block_shape=None), + TestMoEQuantConfig(quant_dtype="nvfp4", + per_out_ch_quant=False, + per_act_token_quant=False, + block_shape=None), ] -def _make_gscale(num_experts: int) -> torch.Tensor: - return torch.ones((num_experts, ), - device=torch.cuda.current_device(), - dtype=torch.float32) - - def make_prepare_finalize( prepare_finalize_type: mk.FusedMoEPrepareAndFinalize, backend: Optional[str], moe: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, ) -> mk.FusedMoEPrepareAndFinalize: if backend != "naive" and backend is not None: - prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(moe) + prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize( + moe, quant_config) assert prepare_finalize is not None return prepare_finalize elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize: return FlashInferCutlassMoEPrepareAndFinalize( - use_dp=moe.moe_parallel_config.dp_size > 1, - a1_gscale=_make_gscale(moe.num_local_experts), - ) + use_dp=moe.moe_parallel_config.dp_size > 1) else: return MoEPrepareAndFinalizeNoEP() @@ -383,34 +385,39 @@ def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor: return t[s:e] +def make_cutlass_strides( + e: int, + n: int, + k: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + return ab_strides1, ab_strides2, c_strides1, c_strides2 + + def make_fused_experts( fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute, moe: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, num_dispatchers: int, - w1_gs: Optional[torch.Tensor], - w2_gs: Optional[torch.Tensor], + N: int, ) -> mk.FusedMoEPermuteExpertsUnpermute: - use_fp8 = moe.quant_dtype == torch.float8_e4m3fn batch_kwargs = { "max_num_tokens": moe.max_num_tokens, "num_dispatchers": num_dispatchers, } quant_kwargs = { - "use_fp8_w8a8": use_fp8, - "use_int8_w8a8": False, - "use_int8_w8a16": False, - "use_int4_w4a16": False, - "block_shape": moe.block_shape, - "per_act_token_quant": moe.per_act_token_quant, + "quant_config": quant_config, } deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()} + torch.set_printoptions(threshold=0, edgeitems=0, linewidth=10000) + if fused_experts_type == BatchedDeepGemmExperts: - kwargs = batch_kwargs | { - "block_shape": moe.block_shape, - "per_act_token_quant": moe.per_act_token_quant, - } + kwargs = batch_kwargs | 
quant_kwargs print(f"Making BatchedDeepGemmExperts {kwargs} ...") experts = BatchedDeepGemmExperts(**kwargs) elif fused_experts_type == BatchedTritonExperts: @@ -422,8 +429,8 @@ def make_fused_experts( print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...") experts = BatchedTritonOrDeepGemmExperts(**kwargs) elif fused_experts_type == DeepGemmExperts: - print("Making DeepGemmExperts () ...") - experts = DeepGemmExperts() + print(f"Making DeepGemmExperts {quant_config} ...") + experts = DeepGemmExperts(quant_config) elif fused_experts_type == TritonExperts: kwargs = quant_kwargs print(f"Making TritonExperts {kwargs} ...") @@ -437,62 +444,50 @@ def make_fused_experts( print(f"Making NaiveBatchedExperts {kwargs} ...") experts = NaiveBatchedExperts(**kwargs) elif fused_experts_type == CutlassExpertsFp8: + strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim) kwargs = { "out_dtype": moe.in_dtype, - "per_act_token_quant": moe.per_act_token_quant, - "per_out_ch_quant": moe.per_out_ch_quant, - "block_shape": moe.block_shape, - } + "ab_strides1": strides[0], + "ab_strides2": strides[1], + "c_strides1": strides[2], + "c_strides2": strides[3], + } | quant_kwargs print(f"Making CutlassExpertsFp8 {kwargs} ...") experts = CutlassExpertsFp8(**kwargs) elif fused_experts_type == CutlassBatchedExpertsFp8: + strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim) kwargs = { "max_experts_per_worker": moe.num_local_experts, "num_dispatchers": num_dispatchers, "out_dtype": moe.in_dtype, - "per_act_token_quant": moe.per_act_token_quant, - "per_out_ch_quant": moe.per_out_ch_quant, - "block_shape": moe.block_shape, - } + "ab_strides1": strides[0], + "ab_strides2": strides[1], + "c_strides1": strides[2], + "c_strides2": strides[3], + } | quant_kwargs print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...") experts = CutlassBatchedExpertsFp8(**kwargs) elif fused_experts_type == CutlassExpertsFp4: - assert w1_gs is not None and w2_gs is not None - num_experts = moe.num_local_experts - rank = moe.moe_parallel_config.dp_rank kwargs = { - "g1_alphas": _slice(rank, num_experts, (1 / w1_gs)), - "g2_alphas": _slice(rank, num_experts, (1 / w2_gs)), - "a1_gscale": _make_gscale(num_experts), - "a2_gscale": _make_gscale(num_experts), - "max_experts_per_worker": num_experts, - "out_dtype": moe.in_dtype, - "per_act_token_quant": moe.per_act_token_quant, - "per_out_ch_quant": moe.per_out_ch_quant, - "block_shape": moe.block_shape, + "max_experts_per_worker": moe.num_local_experts, "num_dispatchers": num_dispatchers, - } + "out_dtype": moe.in_dtype, + } | quant_kwargs print(f"Making CutlassExpertsFp4 {kwargs} ...") experts = CutlassExpertsFp4(**kwargs) elif fused_experts_type == FlashInferExperts: - assert w1_gs is not None and w2_gs is not None - num_experts = moe.num_local_experts - rank = moe.moe_parallel_config.dp_rank kwargs = { - "g1_alphas": _slice(rank, num_experts, (1 / w1_gs)), - "g2_alphas": _slice(rank, num_experts, (1 / w2_gs)), - "a1_gscale": _make_gscale(num_experts), - "a2_gscale": _make_gscale(num_experts), "out_dtype": moe.in_dtype, - "quant_dtype": "nvfp4", "ep_rank": moe.ep_rank, "ep_size": moe.ep_size, "tp_rank": moe.tp_rank, "tp_size": moe.tp_size, - } + } | quant_kwargs print(f"Making FlashInferExperts {kwargs} ...") experts = FlashInferExperts(**kwargs) else: raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}") + torch.set_printoptions(threshold=1000, edgeitems=5, linewidth=80) + return experts diff --git a/tests/kernels/moe/test_batched_deepgemm.py
b/tests/kernels/moe/test_batched_deepgemm.py index 018d4c224f75e..afec97e8cffd0 100644 --- a/tests/kernels/moe/test_batched_deepgemm.py +++ b/tests/kernels/moe/test_batched_deepgemm.py @@ -6,6 +6,8 @@ import torch from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( BatchedDeepGemmExperts) +from vllm.model_executor.layers.fused_moe.config import ( + fp8_w8a8_moe_quant_config) from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedPrepareAndFinalize, BatchedTritonExperts) from vllm.model_executor.layers.fused_moe.modular_kernel import ( @@ -56,13 +58,18 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int, rank=0, ) + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_s, + w2_scale=w2_s, + per_act_token_quant=False, + block_shape=BLOCK_SIZE, + ) + # triton (reference) triton_experts = BatchedTritonExperts( max_num_tokens=max_num_tokens, num_dispatchers=1, - use_fp8_w8a8=True, - per_act_token_quant=False, - block_shape=BLOCK_SIZE, + quant_config=quant_config, ) mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts) @@ -73,8 +80,6 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int, topk_weights=topk_weights, topk_ids=topk_ids, inplace=False, - w1_scale=w1_s, - w2_scale=w2_s, global_num_experts=E, ) @@ -82,8 +87,7 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int, deepgemm_experts = BatchedDeepGemmExperts( max_num_tokens=max_num_tokens, num_dispatchers=1, - block_shape=BLOCK_SIZE, - per_act_token_quant=False, + quant_config=quant_config, ) mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts) @@ -94,8 +98,6 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int, topk_weights=topk_weights, topk_ids=topk_ids, inplace=False, - w1_scale=w1_s, - w2_scale=w2_s, global_num_experts=E, ) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 00b2d780e66f5..7e79828937c77 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -140,7 +140,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, in_dtype=act_dtype, quant_dtype=quant_dtype, block_shape=block_shape, - per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_act_token_quant, ) out_shape = (num_experts, max_tokens_per_expert, N) @@ -250,7 +250,7 @@ def test_fused_moe_batched_experts( block_shape=block_shape, in_dtype=act_dtype, quant_dtype=quant_dtype, - per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_act_token_quant, ) if input_scales and quant_dtype is not None: diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py index ecc57acc67963..da383e18c3721 100644 --- a/tests/kernels/moe/test_block_fp8.py +++ b/tests/kernels/moe/test_block_fp8.py @@ -4,7 +4,7 @@ import pytest import torch -from tests.kernels.moe.utils import make_test_weights +from tests.kernels.moe.utils import make_test_quant_config, make_test_weights from tests.kernels.quant_utils import (native_per_token_group_quant_fp8, native_w8a8_block_matmul) from vllm.config import VllmConfig, set_current_vllm_config @@ -161,22 +161,17 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, a = torch.randn((M, K), dtype=dtype) / 10 score = torch.randn((M, E), dtype=dtype) - (_, w1, w1_s, _), (_, w2, w2_s, - _) = make_test_weights(E, - N, - K, - dtype, - torch.float8_e4m3fn, - per_act_token_quant=False, - 
block_shape=block_size) + w1, w2, quant_config = make_test_quant_config( + E, + N, + K, + dtype, + quant_dtype=torch.float8_e4m3fn, + per_act_token_quant=False, + block_shape=block_size, + ) - m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - use_mxfp4_w4a4=False, - per_act_token_quant=False, - block_shape=block_size) + m_fused_moe = modular_triton_fused_moe(quant_config) topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False) @@ -186,37 +181,24 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, a, w1, w2, - w1_s, - w2_s, + quant_config.w1_scale, + quant_config.w2_scale, topk_weights, topk_ids, block_size, ) - out = fused_experts( - a, - w1, - w2, - topk_weights, - topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - ) + out = fused_experts(a, + w1, + w2, + topk_weights, + topk_ids, + quant_config=quant_config) - m_out = m_fused_moe( - a, - w1, - w2, - topk_weights, - topk_ids, - w1_scale=w1_s, - w2_scale=w2_s, - ) + m_out = m_fused_moe(a, w1, w2, topk_weights, topk_ids) - # 0.039 only needed for [40000-4608-7168-2-1-block_size852-dtype852-0] - tol = 0.035 if M < 40000 else 0.039 + # 0.039 only needed for M >= 8192 + tol = 0.035 if M < 8192 else 0.039 torch.testing.assert_close(out, ref_out, atol=tol, rtol=tol) torch.testing.assert_close(m_out, ref_out, atol=tol, rtol=tol) @@ -248,14 +230,15 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, a = torch.randn((M, K), dtype=dtype) / 10 score = torch.randn((M, E), dtype=dtype) - (_, w1, w1_s, _), (_, w2, w2_s, - _) = make_test_weights(E, - N, - K, - dtype, - torch.float8_e4m3fn, - per_act_token_quant=False, - block_shape=block_size) + (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights( + E, + N, + K, + dtype, + torch.float8_e4m3fn, + per_out_ch_quant=False, + block_shape=block_size, + ) # Note: for now use_compile will error out if the problem size is # large enough to trigger chunking. 
I'm leaving the flag and diff --git a/tests/kernels/moe/test_block_int8.py b/tests/kernels/moe/test_block_int8.py index 5e4a93963f8e8..041a13ca5585a 100644 --- a/tests/kernels/moe/test_block_int8.py +++ b/tests/kernels/moe/test_block_int8.py @@ -4,12 +4,12 @@ import pytest import torch -from tests.kernels.moe.utils import make_test_weights +from tests.kernels.moe.utils import make_test_quant_config from tests.kernels.quant_utils import (native_per_token_group_quant_int8, native_w8a8_block_matmul) from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk from vllm.platforms import current_platform if current_platform.get_device_capability() < (7, 0): @@ -50,7 +50,7 @@ MNK_FACTORS = [ (2048, 128, 128), (2048, 1024, 7168), (2048, 4096, 512), - (2048, 4096, 7168), + (2048, 4096, 4096), ] E = [8, 24] @@ -117,31 +117,28 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): a = torch.randn((M, K), dtype=dtype) / 10 score = torch.randn((M, E), dtype=dtype) + topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False) - (_, w1, w1_s, _), (_, w2, w2_s, - _) = make_test_weights(E, - N, - K, - dtype, - torch.int8, - per_act_token_quant=False, - block_shape=block_size) + w1, w2, quant_config = make_test_quant_config( + E, + N, + K, + dtype, + quant_dtype=torch.int8, + per_act_token_quant=False, + block_shape=block_size, + ) # Set the context to avoid lots of warning spam. with set_current_vllm_config(vllm_config): - out = fused_moe( - a, - w1, - w2, - score, - topk, - renormalize=False, - use_int8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - ) - ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, + out = fused_experts(a, + w1, + w2, + topk_weights, + topk_ids, + quant_config=quant_config) + ref_out = torch_w8a8_block_int8_moe(a, w1, w2, quant_config.w1_scale, + quant_config.w2_scale, score, topk, block_size) # Check results diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index c84f66383b902..ca6be767dab39 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy import dataclasses from math import prod from typing import Optional @@ -9,6 +10,8 @@ import torch from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.config import ( + FUSED_MOE_UNQUANTIZED_CONFIG, fp8_w8a8_moe_quant_config) from vllm.model_executor.layers.fused_moe.cutlass_moe import ( cutlass_moe_fp8, run_cutlass_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, @@ -154,7 +157,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int, def slice_experts(): slice_params = [ "w1_q", "w2_q", "ab_strides1", "ab_strides2", "c_strides1", - "c_strides2", "w1_scale", "w2_scale" + "c_strides2" ] full_tensors = { k: v @@ -162,6 +165,8 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int, if k in slice_params and k in cutlass_moe_kwargs } + quant_config = cutlass_moe_kwargs["quant_config"] + for i in range(0, num_experts, num_local_experts): s, e = i, i + num_local_experts @@ -178,6 +183,12 @@ def 
run_with_expert_maps(num_experts: int, num_local_experts: int, for k, t in full_tensors.items(): cutlass_moe_kwargs[k] = t[s:e] + new_quant_config = copy.deepcopy(quant_config) + new_quant_config._w1.scale = quant_config.w1_scale[s:e] + new_quant_config._w2.scale = quant_config.w2_scale[s:e] + + cutlass_moe_kwargs["quant_config"] = new_quant_config + yield cutlass_moe_kwargs out_tensor = torch.zeros_like(cutlass_moe_kwargs["a"]) @@ -191,6 +202,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit, topk_weights: torch.Tensor, topk_ids: torch.Tensor, per_act_token: bool, + per_out_ch: bool, num_local_experts: Optional[int] = None) -> torch.Tensor: assert not any([ t is None for t in [ @@ -199,20 +211,27 @@ def run_8_bit(moe_tensors: MOETensors8Bit, ] ]) + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=moe_tensors.w1_scale, + w2_scale=moe_tensors.w2_scale, + per_act_token_quant=per_act_token, + per_out_ch_quant=per_out_ch, + # Set to moe_tensors.a_scale iff static scales + per tensor. + # This is not currently being tested. + a1_scale=None, + ) + kwargs = { 'a': moe_tensors.a, 'w1_q': moe_tensors.w1_q, # type: ignore[union-attr] 'w2_q': moe_tensors.w2_q, # type: ignore[union-attr] 'topk_weights': topk_weights, 'topk_ids': topk_ids, - 'w1_scale': moe_tensors.w1_scale, - 'w2_scale': moe_tensors.w2_scale, 'ab_strides1': moe_tensors.ab_strides1, 'ab_strides2': moe_tensors.ab_strides2, 'c_strides1': moe_tensors.c_strides1, 'c_strides2': moe_tensors.c_strides2, - 'per_act_token': per_act_token, - 'a1_scale': None #moe_tensors.a_scale + 'quant_config': quant_config, } num_experts = moe_tensors.w1.size(0) @@ -261,16 +280,23 @@ def test_cutlass_moe_8_bit_no_graph( # Note that we are using the dequantized versions of the tensors. # Using a, w1 and w2 directly results in minor output differences. - triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights, - topk_ids) + + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG + triton_output = fused_experts(mt.a_d, + mt.w1_d, + mt.w2_d, + topk_weights, + topk_ids, + quant_config=quant_config) if ep_size is not None: assert e % ep_size == 0, "Cannot distribute experts evenly" number_local_experts = e // ep_size else: number_local_experts = None + cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token, - number_local_experts) + per_out_ch, number_local_experts) # Note 5.5 only needed for larger problem sizes, 5 works ok for # the rest. @@ -315,14 +341,19 @@ def test_cutlass_moe_8_bit_cuda_graph( # Note that we are using the dequantized versions of the tensors. # Using a, w1 and w2 directly results in minor output differences. 
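Note on the run_8_bit hunk above: the quantization metadata now travels in a single fp8_w8a8_moe_quant_config object instead of the old loose w1_scale / w2_scale / a1_scale / per_act_token keyword arguments. The following is a minimal sketch of that calling convention, assuming the helper names shown in this hunk; the tensors, strides and flags are placeholders rather than values from this patch.

    from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
    from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8

    def run_fp8_cutlass_moe(a, w1_q, w2_q, w1_scale, w2_scale, topk_weights,
                            topk_ids, strides, per_act_token, per_out_ch):
        # Bundle the per-expert weight scales and quantization flags once ...
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            per_act_token_quant=per_act_token,
            per_out_ch_quant=per_out_ch,
            a1_scale=None,  # only set for static per-tensor activation scales
        )
        # ... and hand the config to the kernel instead of separate scale kwargs.
        return cutlass_moe_fp8(a=a,
                               w1_q=w1_q,
                               w2_q=w2_q,
                               topk_weights=topk_weights,
                               topk_ids=topk_ids,
                               ab_strides1=strides[0],
                               ab_strides2=strides[1],
                               c_strides1=strides[2],
                               c_strides2=strides[3],
                               quant_config=quant_config)
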
- triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights, - topk_ids) + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG + triton_output = fused_experts(mt.a_d, + mt.w1_d, + mt.w2_d, + topk_weights, + topk_ids, + quant_config=quant_config) stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): cutlass_output = run_8_bit(mt, topk_weights, topk_ids, - per_act_token) + per_act_token, per_out_ch) torch.cuda.synchronize() graph.replay() diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 6558cab6a9eff..ced5457d4f53b 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -15,6 +15,8 @@ from torch.distributed import ProcessGroup from typing_extensions import ParamSpec from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, fp8_w8a8_moe_quant_config) from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) @@ -71,9 +73,12 @@ def make_block_quant_fp8_weights( Return weights w1q, w2q, w1_scale, w2_scale """ (_, w1q, w1_scale, _), (_, w2q, w2_scale, - _) = make_test_weights(e, n, k, torch.bfloat16, + _) = make_test_weights(e, + n, + k, + torch.bfloat16, torch.float8_e4m3fn, - block_size) + block_shape=block_size) return w1q, w2q, w1_scale, w2_scale @@ -130,10 +135,11 @@ class TestTensors: config=config) -def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, - max_tokens_per_rank: int, dp_size: int, - hidden_size: int, q_dtype: Optional[torch.dtype], - test_config: TestConfig) -> FusedMoEModularKernel: +def make_ll_modular_kernel( + pg: ProcessGroup, pgi: ProcessGroupInfo, max_tokens_per_rank: int, + dp_size: int, hidden_size: int, q_dtype: Optional[torch.dtype], + test_config: TestConfig, + quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel: assert test_config.low_latency assert test_config.use_fp8_dispatch is not None @@ -154,17 +160,18 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, fused_experts = BatchedDeepGemmExperts( max_num_tokens=max_tokens_per_rank, num_dispatchers=pgi.world_size // dp_size, - block_shape=test_config.block_size, - per_act_token_quant=test_config.per_act_token_quant) + quant_config=quant_config, + ) mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts) return mk -def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, - dp_size: int, num_local_experts: int, - q_dtype: Optional[torch.dtype], - test_config: TestConfig) -> FusedMoEModularKernel: +def make_ht_modular_kernel( + pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int, + num_local_experts: int, q_dtype: Optional[torch.dtype], + test_config: TestConfig, + quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel: assert not test_config.low_latency assert test_config.use_fp8_dispatch is None @@ -178,15 +185,16 @@ def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, q_dtype=q_dtype, block_shape=test_config.block_size) - fused_experts = DeepGemmExperts() + fused_experts = DeepGemmExperts(quant_config) mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts) return mk -def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int, - num_local_experts: int, - test_tensors: TestTensors) -> FusedMoEModularKernel: +def 
make_modular_kernel( + pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int, + num_local_experts: int, test_tensors: TestTensors, + quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel: q_dtype = torch.float8_e4m3fn test_config = test_tensors.config @@ -204,10 +212,16 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int, dp_size=dp_size, hidden_size=hidden_size, q_dtype=q_dtype, - test_config=test_config) + test_config=test_config, + quant_config=quant_config) else: - mk = make_ht_modular_kernel(pg, pgi, dp_size, num_local_experts, - q_dtype, test_config) + mk = make_ht_modular_kernel(pg, + pgi, + dp_size, + num_local_experts, + q_dtype, + test_config, + quant_config=quant_config) return mk @@ -233,17 +247,23 @@ def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32) + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + # Low-Latency kernels can't dispatch scales. + a1_scale=(None if test_config.low_latency else + test_tensors.rank_token_scales), + block_shape=test_config.block_size, + ) + # Make modular kernel mk: FusedMoEModularKernel = make_modular_kernel( pg=pg, pgi=pgi, dp_size=dp_size, num_local_experts=num_local_experts, - test_tensors=test_tensors) - - # Low-Latency kernels can't dispatch scales. - a1_scale = (None - if test_config.low_latency else test_tensors.rank_token_scales) + test_tensors=test_tensors, + quant_config=quant_config) out = mk.forward(hidden_states=test_tensors.rank_tokens, w1=w1, @@ -254,12 +274,6 @@ def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, activation="silu", global_num_experts=num_experts, expert_map=build_expert_map(), - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=None, - w2_zp=None, - a1_scale=a1_scale, - a2_scale=None, apply_router_weight_on_input=False) return out @@ -269,6 +283,13 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, a1_scale: torch.Tensor, block_shape: list[int]): + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + block_shape=block_shape, + ) + return fused_experts( hidden_states=a, w1=w1, @@ -276,11 +297,7 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor, topk_weights=topk_weights, topk_ids=topk_ids, inplace=False, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - block_shape=block_shape, + quant_config=quant_config, # Make sure this is set to False so we # don't end up comparing the same implementation. 
allow_deep_gemm=False) diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py index 6a53af68cd53a..54d3a62b03fcc 100644 --- a/tests/kernels/moe/test_deepep_moe.py +++ b/tests/kernels/moe/test_deepep_moe.py @@ -15,6 +15,7 @@ from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import TritonExperts +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedTritonExperts) from vllm.model_executor.layers.fused_moe.modular_kernel import ( @@ -129,11 +130,9 @@ def make_modular_kernel( num_local_experts: int, q_dtype: Optional[torch.dtype], use_fp8_dispatch: bool, - per_act_token_quant: bool, + quant_config: FusedMoEQuantConfig, ) -> FusedMoEModularKernel: - is_quantized = q_dtype is not None - ht_args: Optional[DeepEPHTArgs] = None ll_args: Optional[DeepEPLLArgs] = None @@ -159,24 +158,14 @@ def make_modular_kernel( num_dispatchers = pgi.world_size // dp_size if low_latency_mode: - assert not per_act_token_quant, "not supported in ll mode" + assert not quant_config.per_act_token_quant, "not supported in ll mode" fused_experts = BatchedTritonExperts( max_num_tokens=MAX_TOKENS_PER_RANK, num_dispatchers=num_dispatchers, - use_fp8_w8a8=is_quantized, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - per_act_token_quant=False, + quant_config=quant_config, ) else: - fused_experts = TritonExperts( - use_fp8_w8a8=is_quantized, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - per_act_token_quant=per_act_token_quant, - ) + fused_experts = TritonExperts(quant_config=quant_config) mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts) @@ -217,11 +206,6 @@ def deep_ep_moe_impl( if is_quantized: q_dtype = torch.float8_e4m3fn - # Make modular kernel - mk: FusedMoEModularKernel = make_modular_kernel( - pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts, - num_local_experts, q_dtype, use_fp8_dispatch, per_act_token_quant) - out_hidden_states = torch.empty_like(test_tensors.rank_tokens) total_num_tokens = test_tensors.rank_tokens.size(0) @@ -236,6 +220,19 @@ def deep_ep_moe_impl( rank_token_scales_chunk = rank_token_scales_chunk[ chunk_start:chunk_end] + quant_config = FusedMoEQuantConfig.make( + q_dtype, + w1_scale=w1_scale, + w2_scale=w2_scale, + per_act_token_quant=per_act_token_quant, + a1_scale=rank_token_scales_chunk, + ) + + # Make modular kernel + mk: FusedMoEModularKernel = make_modular_kernel( + pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts, + num_local_experts, q_dtype, use_fp8_dispatch, quant_config) + out = mk.forward(hidden_states=rank_tokens_chunk, w1=w1, w2=w2, @@ -245,12 +242,6 @@ def deep_ep_moe_impl( activation="silu", global_num_experts=num_experts, expert_map=build_expert_map(), - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=None, - w2_zp=None, - a1_scale=rank_token_scales_chunk, - a2_scale=None, apply_router_weight_on_input=False) if not skip_result_store: @@ -407,7 +398,7 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn] @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("mnk", MNKs) +@pytest.mark.parametrize("m,n,k", MNKs) @pytest.mark.parametrize("num_experts", [32]) @pytest.mark.parametrize("topk", [6]) @pytest.mark.parametrize("world_dp_size", [(2, 1)]) @@ -416,7 +407,9 @@ DTYPES = [torch.bfloat16, 
torch.float8_e4m3fn] @requires_deep_ep def test_deep_ep_moe( dtype: torch.dtype, - mnk: tuple[int, int, int], + m: int, + n: int, + k: int, num_experts: int, topk: int, world_dp_size: tuple[int, int], @@ -424,7 +417,6 @@ def test_deep_ep_moe( ): low_latency_mode = False use_fp8_dispatch = False - m, n, k = mnk current_platform.seed_everything(7) world_size, dp_size = world_dp_size @@ -456,20 +448,24 @@ USE_FP8_DISPATCH = [True, False] @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("mnk", MNKs) +@pytest.mark.parametrize("m,n,k", MNKs) @pytest.mark.parametrize("num_experts", [32]) @pytest.mark.parametrize("topk", [6]) @pytest.mark.parametrize("world_dp_size", [(2, 1)]) @pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH) @multi_gpu_test(num_gpus=2) @requires_deep_ep -def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int], - num_experts: int, topk: int, - world_dp_size: tuple[int, int], - use_fp8_dispatch: bool): - +def test_low_latency_deep_ep_moe( + dtype: torch.dtype, + m: int, + n: int, + k: int, + num_experts: int, + topk: int, + world_dp_size: tuple[int, int], + use_fp8_dispatch: bool, +): low_latency_mode = True - m, n, k = mnk if (low_latency_mode and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES): diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py index 4472f34a6291a..d575b6d4ca62c 100644 --- a/tests/kernels/moe/test_deepgemm.py +++ b/tests/kernels/moe/test_deepgemm.py @@ -11,6 +11,8 @@ import math import pytest import torch +from vllm.model_executor.layers.fused_moe.config import ( + fp8_w8a8_moe_quant_config) # vLLM fused-expert reference (Triton fallback + DeepGEMM option) from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.quantization.utils.fp8_utils import ( @@ -94,6 +96,13 @@ def run_single_case(m, n, k, topk, num_experts, block_size): topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1) topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1) + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_s, + w2_scale=w2_s, + a1_scale=a1_scale, + block_shape=block_size, + ) + # triton reference out_triton = fused_experts( hidden_states=tokens_bf16, @@ -102,11 +111,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size): topk_weights=topk_weights, topk_ids=topk_ids, inplace=False, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - a1_scale=a1_scale, - block_shape=block_size, + quant_config=quant_config, allow_deep_gemm=False, ) @@ -118,19 +123,14 @@ def run_single_case(m, n, k, topk, num_experts, block_size): topk_weights=topk_weights, topk_ids=topk_ids, inplace=False, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - a1_scale=a1_scale, - block_shape=block_size, + quant_config=quant_config, allow_deep_gemm=True, ) diff = calc_diff(out_deepgemm, out_triton) assert diff < 0.001, f"Diff exceeded 1%: {diff}" -# Note: W1 has shape (E, 2N, K), so N = 512 -# can trigger the deepgemm path. +# Note: N <= 512 will disable the deepgemm path due to performance issues. 
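For context on the test_deepgemm.py changes above: run_single_case builds one fp8_w8a8_moe_quant_config and reuses it for both fused_experts calls, so the Triton reference and the DeepGEMM path differ only in allow_deep_gemm. A compact sketch of that pattern, assuming the same helper names; the tensors and block size are placeholders.

    from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
    from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts

    def triton_vs_deepgemm(tokens, w1, w2, w1_s, w2_s, topk_weights, topk_ids,
                           block_size):
        # One config describes the fp8 block quantization for both code paths.
        quant_config = fp8_w8a8_moe_quant_config(w1_scale=w1_s,
                                                 w2_scale=w2_s,
                                                 block_shape=block_size)
        out_triton = fused_experts(hidden_states=tokens, w1=w1, w2=w2,
                                   topk_weights=topk_weights, topk_ids=topk_ids,
                                   inplace=False, quant_config=quant_config,
                                   allow_deep_gemm=False)
        out_deepgemm = fused_experts(hidden_states=tokens, w1=w1, w2=w2,
                                     topk_weights=topk_weights, topk_ids=topk_ids,
                                     inplace=False, quant_config=quant_config,
                                     allow_deep_gemm=True)
        return out_triton, out_deepgemm
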
MNKs = [ (1024, 768, 128), (1024, 768, 512), @@ -144,15 +144,15 @@ TOPKS = [2, 6] NUM_EXPERTS = [32] -@pytest.mark.parametrize("mnk", MNKs) +@pytest.mark.parametrize(("m", "n", "k"), MNKs) @pytest.mark.parametrize("topk", TOPKS) @pytest.mark.parametrize("num_experts", NUM_EXPERTS) @pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels") -def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch): +def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch): - with monkeypatch.context() as m: - m.setenv("VLLM_USE_DEEP_GEMM", "1") + with monkeypatch.context() as mp: + mp.setenv("VLLM_USE_DEEP_GEMM", "1") _fused_moe_mod = importlib.import_module( "vllm.model_executor.layers.fused_moe.fused_moe") @@ -168,8 +168,6 @@ def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch): monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8", _spy_deep_gemm_moe_fp8) - m, n, k = mnk - if topk > num_experts: pytest.skip(f"topk={topk} > num_experts={num_experts}") diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index 52a3d2ca3b422..5564db3cda0e3 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -6,6 +6,8 @@ import pytest import torch from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.config import ( + fp8_w8a8_moe_quant_config) from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( @@ -145,6 +147,14 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( custom_routing_function=Llama4MoE.custom_routing_function, scoring_func="softmax") + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=td.w13_weight_scale, + w2_scale=td.w2_weight_scale, + a1_scale=td.a1_scale, + a2_scale=td.a2_scale, + per_act_token_quant=False, + ) + output = fused_experts( td.hidden_states, td.w13_quantized, @@ -153,15 +163,10 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( topk_ids=topk_ids, inplace=False, activation="silu", - use_fp8_w8a8=True, - per_channel_quant=False, global_num_experts=e, expert_map=None, - w1_scale=td.w13_weight_scale, - w2_scale=td.w2_weight_scale, - a1_scale=td.a1_scale, - a2_scale=td.a2_scale, apply_router_weight_on_input=True, + quant_config=quant_config, ) flashinfer_output = apply_flashinfer_per_tensor_scale_fp8( @@ -210,6 +215,14 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( custom_routing_function=Llama4MoE.custom_routing_function, scoring_func="softmax") + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=td.w13_weight_scale, + w2_scale=td.w2_weight_scale, + a1_scale=td.a1_scale, + a2_scale=td.a2_scale, + per_act_token_quant=False, + ) + output = fused_experts( td.hidden_states, td.w13_quantized, @@ -218,15 +231,10 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( topk_ids=topk_ids, inplace=False, activation="silu", - use_fp8_w8a8=True, - per_channel_quant=False, global_num_experts=e, expert_map=None, - w1_scale=td.w13_weight_scale, - w2_scale=td.w2_weight_scale, - a1_scale=td.a1_scale, - a2_scale=td.a2_scale, apply_router_weight_on_input=True, + quant_config=quant_config, ) td.layer.dp_size = 1 diff --git a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py index 1c14df2b914aa..8bf096b798cb8 100644 --- a/tests/kernels/moe/test_flashinfer_moe.py +++ b/tests/kernels/moe/test_flashinfer_moe.py @@ -3,7 +3,7 @@ 
import pytest import torch -from tests.kernels.moe.utils import make_test_weights +from tests.kernels.moe.utils import make_test_quant_config from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, dequantize_nvfp4_to_dtype) @@ -41,7 +41,6 @@ MNK_FACTORS = [ @pytest.mark.parametrize("m,n,k", MNK_FACTORS) @pytest.mark.parametrize("e", [40, 64, 256]) -#@pytest.mark.parametrize("e", [128, 256]) @pytest.mark.parametrize("topk", [1, 6, 8]) @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) @torch.inference_mode() @@ -56,16 +55,15 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, quant_blocksize = 16 - (_, w1_q, w1_blockscale, - w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights( - e, - n, - k, - in_dtype=dtype, - quant_dtype="nvfp4", - block_shape=None, # use quant_blocksize? - per_act_token_quant=False, - ) + w1_q, w2_q, quant_config = make_test_quant_config( + e, + n, + k, + in_dtype=dtype, + quant_dtype="nvfp4", + block_shape=None, + per_act_token_quant=False, + ) score = torch.randn((m, e), device="cuda", dtype=dtype) topk_weights, topk_ids, _ = fused_topk(a, @@ -73,35 +71,17 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, topk, renormalize=False) - a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) - a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) - assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q) - assert w1_gs is not None - assert w2_gs is not None - assert w1_blockscale is not None - assert w2_blockscale is not None - flashinfer_experts = FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), - FlashInferExperts( - a1_gscale=a1_gs, - g1_alphas=(1 / w1_gs), - a2_gscale=a2_gs, - g2_alphas=(1 / w2_gs), - out_dtype=dtype, - quant_dtype="nvfp4", - )) + FlashInferExperts(out_dtype=dtype, quant_config=quant_config), + ) flashinfer_output = flashinfer_experts( hidden_states=a, w1=w1_q, - w1_scale=w1_blockscale, w2=w2_q, - w2_scale=w2_blockscale, - a1_scale=a1_gs, - a2_scale=a2_gs, topk_weights=topk_weights, topk_ids=topk_ids, ) @@ -122,18 +102,18 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype) for idx in range(0, e): - w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx], - w1_blockscale[idx], - w1_gs[idx], - dtype=dtype, - device=w1_q.device, - block_size=quant_blocksize) - w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx], - w2_blockscale[idx], - w2_gs[idx], - dtype=dtype, - device=w2_q.device, - block_size=quant_blocksize) + w1_d[idx] = dequantize_nvfp4_to_dtype( + w1_q[idx], + quant_config.w1_scale[idx], (1 / quant_config.g1_alphas[idx]), + dtype=dtype, + device=w1_q.device, + block_size=quant_blocksize) + w2_d[idx] = dequantize_nvfp4_to_dtype( + w2_q[idx], + quant_config.w2_scale[idx], (1 / quant_config.g2_alphas[idx]), + dtype=dtype, + device=w2_q.device, + block_size=quant_blocksize) torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk) diff --git a/tests/kernels/moe/test_gpt_oss_triton_kernels.py b/tests/kernels/moe/test_gpt_oss_triton_kernels.py index 54f2351bf6d9b..024993c7677dd 100644 --- a/tests/kernels/moe/test_gpt_oss_triton_kernels.py +++ b/tests/kernels/moe/test_gpt_oss_triton_kernels.py @@ -23,6 +23,7 @@ from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor from triton_kernels.tensor_details import layout from triton_kernels.testing import assert_close +from vllm.model_executor.layers.fused_moe.config import 
FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk @@ -293,6 +294,13 @@ def test_equiv(num_token, a_dtype, w_dtype, tp): pc2, ) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8) + quant_config = FusedMoEQuantConfig.make( + w1_bias=w1_bias_tri, + w2_bias=w2_bias_tri, + w1_precision=pc1, + w2_precision=pc2, + ) + out_triton_monolithic = triton_kernel_moe_forward( hidden_states=x_tri, w1=w1_tri, @@ -300,10 +308,7 @@ def test_equiv(num_token, a_dtype, w_dtype, tp): gating_output=exp_data_tri, topk=topk, renormalize=True, - w1_bias=w1_bias_tri, - w2_bias=w2_bias_tri, - w1_precision=pc1, - w2_precision=pc2, + quant_config=quant_config, ) out_triton_monolithic = out_triton_monolithic[..., :K] @@ -336,6 +341,13 @@ def batched_moe( ) -> torch.Tensor: max_num_tokens = round_up(a.shape[0], 64) + quant_config = FusedMoEQuantConfig.make( + w1_precision=w1_precision, + w2_precision=w2_precision, + w1_bias=w1_bias, + w2_bias=w2_bias, + ) + fused_experts = FusedMoEModularKernel( BatchedPrepareAndFinalize( max_num_tokens, @@ -344,19 +356,12 @@ def batched_moe( rank=0, ), BatchedOAITritonExperts( - None, max_num_tokens=max_num_tokens, num_dispatchers=1, - w1_precision=w1_precision, - w2_precision=w2_precision, + quant_config=quant_config, ), ) - extra_expert_args = { - "w1_bias": w1_bias, - "w2_bias": w2_bias, - } - topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize) return fused_experts( @@ -365,7 +370,6 @@ def batched_moe( w2, topk_weight, topk_ids, - extra_expert_args=extra_expert_args, ) diff --git a/tests/kernels/moe/test_modular_kernel_combinations.py b/tests/kernels/moe/test_modular_kernel_combinations.py index 6112183be5475..19c4301bd23d5 100644 --- a/tests/kernels/moe/test_modular_kernel_combinations.py +++ b/tests/kernels/moe/test_modular_kernel_combinations.py @@ -12,7 +12,6 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.config import VllmConfig, current_platform, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe @@ -22,7 +21,8 @@ from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors, run_modular_kernel) from .modular_kernel_tools.mk_objects import ( MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, - MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, expert_info) + MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, TestMoEQuantConfig, + expert_info) from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo, parallel_launch_with_config) @@ -55,7 +55,7 @@ def rank_worker( pgi: ProcessGroupInfo, vllm_config: VllmConfig, cpu_group, - config: Config, + base_config: Config, weights: WeightTensors, verbose: bool, ): @@ -63,42 +63,44 @@ def rank_worker( # sanity check from vllm import envs - if config.fused_moe_chunk_size is not None: - assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE) + if base_config.fused_moe_chunk_size is not None: + assert ( + base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE) # get weights to this device weights.to_current_device() - Ms = config.Ms + Ms = base_config.Ms assert isinstance(Ms, list) - TOPKs = config.topks + TOPKs = base_config.topks assert isinstance(TOPKs, list) exceptions = [] count = 0 for 
m, topk in product(Ms, TOPKs): + # override m and topk + config = copy.deepcopy(base_config) + config.Ms = m + config.topks = topk + try: print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...") count = count + 1 - # override m and topk - cfgx = copy.deepcopy(config) - cfgx.Ms = m - cfgx.topks = topk # inputs for rank - rank_tensors = RankTensors.make(cfgx, pgi) + rank_tensors = RankTensors.make(config, pgi) # modular kernel out - mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, + mk_out = run_modular_kernel(pgi, vllm_config, config, weights, rank_tensors) with set_current_vllm_config(vllm_config): - ref_out = reference_moe_impl(cfgx, weights, rank_tensors) + ref_out = reference_moe_impl(config, weights, rank_tensors) if config.quant_dtype == "nvfp4": - atol = 1e-1 - rtol = 1e-1 + atol = 1e-1 if config.K < 4096 else 2e-1 + rtol = 1e-1 if config.K < 4096 else 2e-1 else: atol = 3e-2 rtol = 3e-2 @@ -132,7 +134,7 @@ Ms = [32, 64] # hidden sizes, making this too large will cause fp4 tests to fail. # Also needs to be a multiple of 1024 for deep_gemm. Ks = [2048] -Ns = [2048] +Ns = [1024] TOPKs = [4, 1] Es = [32] DTYPEs = [torch.bfloat16] @@ -167,7 +169,7 @@ def is_nyi_config(config: Config) -> bool: @meets_multi_gpu_requirements def test_modular_kernel_combinations_multigpu( k: int, n: int, e: int, dtype: torch.dtype, - quant_config: Optional[FusedMoEQuantConfig], + quant_config: Optional[TestMoEQuantConfig], combination: tuple[mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute], fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig): @@ -208,7 +210,7 @@ def test_modular_kernel_combinations_multigpu( @pytest.mark.parametrize("world_size", [1]) def test_modular_kernel_combinations_singlegpu( k: int, n: int, e: int, dtype: torch.dtype, - quant_config: Optional[FusedMoEQuantConfig], + quant_config: Optional[TestMoEQuantConfig], combination: tuple[mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute], fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig): diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 0ea9667914fd5..00835bec9a15c 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -15,11 +15,14 @@ from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock import vllm.model_executor.layers.fused_moe # noqa +from tests.kernels.moe.utils import fused_moe from tests.kernels.utils import opcheck, stack_and_dev, torch_moe from vllm.config import VllmConfig, set_current_vllm_config from vllm.distributed.parallel_state import init_distributed_environment from vllm.forward_context import set_forward_context -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.config import ( + FUSED_MOE_UNQUANTIZED_CONFIG, int4_w4a16_moe_quant_config, + int8_w8a16_moe_quant_config) from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, modular_triton_fused_moe) from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( @@ -187,14 +190,9 @@ def test_fused_moe( # # Setup test functions # + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG - m_fused_moe_fn = modular_triton_fused_moe(use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - use_mxfp4_w4a4=False, - per_act_token_quant=False, - block_shape=None) + m_fused_moe_fn = modular_triton_fused_moe(quant_config) def m_fused_moe( a: torch.Tensor, @@ -340,6 +338,18 @@ def 
test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, else: e_map = None + if weight_bits == 4: + quant_config_builder = int4_w4a16_moe_quant_config + else: + assert weight_bits == 8 + quant_config_builder = int8_w8a16_moe_quant_config + + quant_config = quant_config_builder(w1_scale=w1_scales, + w2_scale=w2_scales, + w1_zp=w1_qzeros if has_zp else None, + w2_zp=w2_qzeros if has_zp else None, + block_shape=[0, group_size]) + with set_current_vllm_config(vllm_config): triton_output = fused_moe(a, w1_qweight, @@ -347,15 +357,9 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, score, topk, renormalize=False, - use_int4_w4a16=weight_bits == 4, - use_int8_w8a16=weight_bits == 8, global_num_experts=e, expert_map=e_map, - w1_scale=w1_scales, - w2_scale=w2_scales, - w1_zp=w1_qzeros if has_zp else None, - w2_zp=w2_qzeros if has_zp else None, - block_shape=[0, group_size]) + quant_config=quant_config) torch_output = torch_moe(a, w1_ref, w2_ref, @@ -371,8 +375,8 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, @pytest.mark.parametrize( "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) @torch.inference_mode() -def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, - monkeypatch): +def test_mixtral_moe(dist_init, dtype: torch.dtype, padding: bool, + use_rocm_aiter: bool, monkeypatch): """Make sure our Mixtral MoE implementation agrees with the one from huggingface.""" diff --git a/tests/kernels/moe/test_mxfp4_moe.py b/tests/kernels/moe/test_mxfp4_moe.py index 7bd1ffce58e96..a3b8f07638d9a 100644 --- a/tests/kernels/moe/test_mxfp4_moe.py +++ b/tests/kernels/moe/test_mxfp4_moe.py @@ -11,6 +11,7 @@ import torch from packaging import version from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer QUARK_MXFP4_AVAILABLE = importlib.util.find_spec( "quark") is not None and version.parse( @@ -19,11 +20,17 @@ QUARK_MXFP4_AVAILABLE = importlib.util.find_spec( TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda( ) and current_platform.is_device_capability(100) +HOPPER_MXFP4_BF16_AVAILABLE = (current_platform.is_cuda() + and current_platform.is_device_capability(90) + and has_flashinfer()) + if TRTLLM_GEN_MXFP4_AVAILABLE: from flashinfer import (fp4_quantize, mxfp8_quantize, next_positive_power_of_2, reorder_rows_for_gated_act_gemm, shuffle_matrix_a, shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe) + from flashinfer.fp4_quantization import nvfp4_block_scale_interleave + from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices @dataclass @@ -204,6 +211,7 @@ def tg_mxfp4_moe( alpha, beta, limit, + transpose_optimized: bool = False, ) -> torch.Tensor: sf_block_size = 32 assert (w13_weight.dim() == 3 and w13_weight.shape[0] == num_experts @@ -224,7 +232,7 @@ def tg_mxfp4_moe( assert (w2_bias.dim() == 2 and w2_bias.shape[0] == num_experts and w2_bias.shape[1] == hidden_size) - # Swap w1 and w3 as the defenition of + # Swap w1 and w3 as the definition of # swiglu is different in the trtllm-gen w13_weight_scale_ = w13_weight_scale.clone() w13_weight_ = w13_weight.clone() @@ -267,22 +275,85 @@ def tg_mxfp4_moe( gemm1_bias_shuffled = [] gemm2_bias_shuffled = [] epilogue_tile_m = 128 # FIXME: this depends on the kernel internals - for i in range(num_experts): - gemm1_weights_shuffled.append( - shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)) - gemm1_scales_shuffled.append( - shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8), - epilogue_tile_m)) 
+ _cache_permute_indices: dict[torch.Size, torch.Tensor] = {} + if transpose_optimized: + for i in range(num_experts): + # w13 weight shuffling + permute_indices = _maybe_get_cached_w2_permute_indices( + _cache_permute_indices, + w13_weight[i].view(torch.uint8), + epilogue_tile_m, + ) + gemm1_weights_shuffled.append(w13_weight[i].view( + torch.uint8)[permute_indices.to( + w13_weight.device)].contiguous()) + # w13 scale shuffling + permute_sf_indices = _maybe_get_cached_w2_permute_indices( + _cache_permute_indices, + w13_weight_scale[i].view(torch.uint8), + epilogue_tile_m, + num_elts_per_sf=16, + ) + gemm1_scales_shuffled.append( + nvfp4_block_scale_interleave(w13_weight_scale[i].view( + torch.uint8)[permute_sf_indices.to( + w13_weight_scale.device)].contiguous())) + # w13 bias shuffling + permute_bias_indices = _maybe_get_cached_w2_permute_indices( + _cache_permute_indices, + w13_bias[i].clone().reshape(-1, 1), + epilogue_tile_m, + ) + gemm1_bias_shuffled.append(w13_bias[i].clone().reshape( + -1, 1)[permute_bias_indices.to(w13_bias.device)].contiguous()) + # w2 weight shuffling + permute_indices = _maybe_get_cached_w2_permute_indices( + _cache_permute_indices, + w2_weight[i].view(torch.uint8), + epilogue_tile_m, + ) + gemm2_weights_shuffled.append(w2_weight[i].view( + torch.uint8)[permute_indices.to( + w2_weight.device)].contiguous()) + # w2 scale shuffling + permute_sf_indices = _maybe_get_cached_w2_permute_indices( + _cache_permute_indices, + w2_weight_scale[i].view(torch.uint8), + epilogue_tile_m, + num_elts_per_sf=16, + ) + gemm2_scales_shuffled.append( + nvfp4_block_scale_interleave(w2_weight_scale[i].view( + torch.uint8)[permute_sf_indices.to( + w2_weight_scale.device)].contiguous())) + # w2 bias shuffling + permute_indices = _maybe_get_cached_w2_permute_indices( + _cache_permute_indices, + w2_bias[i].clone().reshape(-1, 1), + epilogue_tile_m, + ) + gemm2_bias_shuffled.append(w2_bias[i].clone().reshape( + -1, 1)[permute_indices.to(w2_bias.device)].contiguous()) - gemm2_weights_shuffled.append( - shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)) - gemm2_scales_shuffled.append( - shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8), - epilogue_tile_m)) - gemm1_bias_shuffled.append( - shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m)) - gemm2_bias_shuffled.append( - shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m)) + else: + for i in range(num_experts): + gemm1_weights_shuffled.append( + shuffle_matrix_a(w13_weight[i].view(torch.uint8), + epilogue_tile_m)) + gemm1_scales_shuffled.append( + shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8), + epilogue_tile_m)) + + gemm2_weights_shuffled.append( + shuffle_matrix_a(w2_weight[i].view(torch.uint8), + epilogue_tile_m)) + gemm2_scales_shuffled.append( + shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8), + epilogue_tile_m)) + gemm1_bias_shuffled.append( + shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m)) + gemm2_bias_shuffled.append( + shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m)) w13_weight = torch.stack(gemm1_weights_shuffled) w13_weight_scale = torch.stack(gemm1_scales_shuffled).reshape( @@ -356,6 +427,7 @@ def check_accuracy(a, b, atol, rtol, percent): @pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)]) @pytest.mark.parametrize("act_type", ['mxfp8', 'bf16']) +@pytest.mark.parametrize("transpose_optimized", [False, True]) @pytest.mark.skipif( not TRTLLM_GEN_MXFP4_AVAILABLE, reason="nvidia gpu and compute capability sm100 
is required for this test") @@ -369,6 +441,7 @@ def test_trtllm_gen_mxfp4_fused_moe( beta: float, limit: Optional[float], act_type: str, + transpose_optimized: bool, ): seed = 42 torch.manual_seed(seed) @@ -470,6 +543,321 @@ def test_trtllm_gen_mxfp4_fused_moe( act_type, alpha=alpha, beta=beta, - limit=limit) + limit=limit, + transpose_optimized=transpose_optimized) # relatively loose check since the mxfp4 quantization is less accurate check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8) + + +def _interleave_scales_lastdim_by4(scales: torch.Tensor) -> torch.Tensor: + """Interleave scales on the last dimension by groups of 4, matching + the transformation in mxfp4.py's BF16 (Hopper) path.""" + s = scales.to(torch.uint8) + s_shape = s.shape + assert s_shape[-1] % 4 == 0 + s = s.reshape(*s_shape[:-1], s_shape[-1] // 4, 4) + # Move the 4-group dimension before the row dimension + permuted = s.permute(0, 2, 1, 3) + # Merge the row dim with the 4-group dim + return permuted.reshape(s_shape[0], s_shape[-1] // 4, s_shape[1] * 4) + + +@pytest.mark.parametrize("topk", [1, 4]) +@pytest.mark.parametrize("num_experts", [32]) +@pytest.mark.parametrize("num_tokens", [1, 128]) +@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)]) +@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), + (1.702, 1.0, 7.0)]) +@pytest.mark.skipif( + not HOPPER_MXFP4_BF16_AVAILABLE, + reason="nvidia gpu sm90 and flashinfer are required for this test", +) +def test_flashinfer_cutlass_mxfp4_fused_moe( + topk: int, + num_experts: int, + num_tokens: int, + intermediate_size: int, + hidden_size: int, + alpha: float, + beta: float, + limit: Optional[float], +): + torch.manual_seed(42) + device = "cuda:0" + + # Inputs + hidden_states = torch.randn(num_tokens, + hidden_size, + device=device, + dtype=torch.bfloat16) + # Random MXFP4 weights and scales (uint8), contiguous [w1; w3] + w13_q = torch.randint( + 0, + 256, (num_experts, 2 * intermediate_size, hidden_size // 2), + device=device, + dtype=torch.uint8) + w13_scale = torch.randint( + 118, + 123, (num_experts, 2 * intermediate_size, hidden_size // 32), + device=device, + dtype=torch.uint8) + + w2_q = torch.randint(0, + 256, + (num_experts, hidden_size, intermediate_size // 2), + device=device, + dtype=torch.uint8) + w2_scale = torch.randint( + 118, + 123, (num_experts, hidden_size, intermediate_size // 32), + device=device, + dtype=torch.uint8) + # Bias contiguous [b1; b3] + bias13 = (torch.randn(num_experts, + 2 * intermediate_size, + device=device, + dtype=torch.bfloat16) * 10) + bias2 = (torch.randn( + num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10) + router_logits = torch.rand(num_tokens, + num_experts, + dtype=torch.float32, + device=device) + + w13_ref = mxfp4_dequantize(w13_q.clone(), w13_scale.clone()).reshape( + num_experts, 2 * intermediate_size, hidden_size) + w2_ref = mxfp4_dequantize(w2_q.clone(), w2_scale.clone()).reshape( + num_experts, hidden_size, intermediate_size) + ref = reference_moe(router_logits.to(torch.float32), topk, num_experts, + hidden_states.to(torch.float32), w13_ref, + bias13.to(torch.float32), w2_ref, + bias2.to(torch.float32), alpha, beta, limit, 'bf16') + + from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe + + # Swap halves to arrange as [w3; w1] (kernel expectation) + w1_w, w3_w = torch.chunk(w13_q, 2, dim=1) + w13_q_swapped = torch.cat([w3_w, w1_w], dim=1) + + b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1) + w13_b = torch.cat([b3, b1], 
dim=-1).to(torch.bfloat16) + + w1_s, w3_s = torch.chunk(w13_scale, 2, dim=1) + w13_s = torch.cat([w3_s, w1_s], dim=1) + w13_s_inter = _interleave_scales_lastdim_by4(w13_s) + w2_s_inter = _interleave_scales_lastdim_by4(w2_scale) + + routing_weights = torch.nn.functional.softmax(router_logits, + dim=1, + dtype=torch.float32) + token_final_scales, token_selected_experts = torch.topk(routing_weights, + topk, + dim=-1) + token_final_scales = (token_final_scales / + token_final_scales.sum(dim=-1, keepdim=True)) + token_selected_experts = token_selected_experts.to(torch.int).contiguous() + + out = torch.empty_like(hidden_states, dtype=torch.bfloat16) + if alpha is not None: + alpha = torch.full((num_experts, ), alpha, device=hidden_states.device) + if beta is not None: + beta = torch.full((num_experts, ), beta, device=hidden_states.device) + if limit is not None: + limit = torch.full((num_experts, ), limit, device=hidden_states.device) + + _ = flashinfer_cutlass_fused_moe( + input=hidden_states, + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + fc1_expert_weights=w13_q_swapped, + fc2_expert_weights=w2_q, + output_dtype=torch.bfloat16, + output=out, + quant_scales=[w13_s_inter.to(torch.uint8), + w2_s_inter.to(torch.uint8)], + fc1_expert_biases=w13_b, + fc2_expert_biases=bias2.to(torch.bfloat16), + swiglu_alpha=alpha, + swiglu_beta=beta, + swiglu_limit=limit, + tp_size=1, + tp_rank=0, + ep_size=1, + ep_rank=0, + use_w4_group_scaling=True, + ) + + # Allow some mismatch due to MXFP4 quantization + check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8) + + +@pytest.mark.parametrize("topk", [1, 4]) +@pytest.mark.parametrize("num_experts", [32]) +@pytest.mark.parametrize("num_tokens", [1, 128]) +@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)]) +@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), + (1.702, 1.0, 7.0)]) +@pytest.mark.skipif( + not (current_platform.is_cuda() + and current_platform.is_device_capability(100) and has_flashinfer()), + reason="NVIDIA GPU sm100 and flashinfer are required for this test", +) +def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe( + topk: int, + num_experts: int, + num_tokens: int, + intermediate_size: int, + hidden_size: int, + alpha: Optional[float], + beta: Optional[float], + limit: Optional[float], +): + torch.manual_seed(42) + device = "cuda:0" + + # Inputs + hidden_states = torch.randn(num_tokens, + hidden_size, + device=device, + dtype=torch.bfloat16) + # Float weights in w13 format [w1; w3] + w13 = (torch.randn(num_experts, + 2 * intermediate_size, + hidden_size, + device=device, + dtype=torch.bfloat16) / 10) + w2 = (torch.randn(num_experts, + hidden_size, + intermediate_size, + device=device, + dtype=torch.bfloat16) / 10) + # Bias contiguous [b1; b3] + bias13 = (torch.randn(num_experts, + 2 * intermediate_size, + device=device, + dtype=torch.bfloat16) * 10) + bias2 = (torch.randn( + num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10) + router_logits = torch.rand(num_tokens, + num_experts, + dtype=torch.float32, + device=device) + + # Quantize weights to MXFP4 per expert (SM100 path) + from flashinfer import mxfp4_quantize + + def quant_mxfp4_batches(a: torch.Tensor, e: int): + qs, sfs = [], [] + for i in range(e): + q, sf = mxfp4_quantize(a[i].cuda()) + qs.append(q) + sfs.append(sf) + return torch.stack(qs), torch.stack(sfs) + + def dequant_mxfp4_batches(mat_fp4: torch.Tensor, + scale_tensor: torch.Tensor): + num_batches = mat_fp4.size(0) + scale_tensor = 
scale_tensor.view(num_batches, -1) + from flashinfer import mxfp4_dequantize + return torch.stack([ + mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :]) + for b in range(num_batches) + ]) + + w13_q, w13_scale = quant_mxfp4_batches(w13, num_experts) + w2_q, w2_scale = quant_mxfp4_batches(w2, num_experts) + + # Reference result using dequantized tensors and reference_moe + w13_ref = dequant_mxfp4_batches( + w13_q.view(torch.uint8), + w13_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape( + num_experts, 2 * intermediate_size, hidden_size).to(device) + w2_ref = dequant_mxfp4_batches( + w2_q.view(torch.uint8), + w2_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape( + num_experts, hidden_size, intermediate_size).to(device) + + # Quantize activations for SM100 path and dequantize for reference + hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32) + # Reference uses BF16 input but quantizes intermediate activation to MXFP8 + ref = reference_moe(router_logits.to(torch.float32), topk, num_experts, + hidden_states.to(torch.float32), w13_ref, + bias13.to(torch.float32), w2_ref, + bias2.to(torch.float32), alpha, beta, limit, 'mxfp8') + + # Prepare inputs for FlashInfer CUTLASS fused MoE + from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe + + # Swap halves to arrange as [w3; w1] (kernel expectation) + w1_w, w3_w = torch.chunk(w13_q, 2, dim=1) + w13_q_swapped = torch.cat([w3_w, w1_w], dim=1) + + # Swap scales halves to match swapped weights + s1, s3 = torch.chunk(w13_scale, 2, dim=1) + w13_scale_swapped = torch.cat([s3, s1], dim=1) + + b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1) + w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16) + + # Build routing for kernel + routing_weights = torch.nn.functional.softmax(router_logits, + dim=1, + dtype=torch.float32) + token_final_scales, token_selected_experts = torch.topk(routing_weights, + topk, + dim=-1) + token_final_scales = (token_final_scales / + token_final_scales.sum(dim=-1, keepdim=True)) + token_selected_experts = token_selected_experts.to(torch.int).contiguous() + + out = torch.empty_like(hidden_states, dtype=torch.bfloat16) + if alpha is not None: + alpha_t = torch.full((num_experts, ), + alpha, + device=hidden_states.device) + else: + alpha_t = None + if beta is not None: + beta_t = torch.full((num_experts, ), beta, device=hidden_states.device) + else: + beta_t = None + if limit is not None: + limit_t = torch.full((num_experts, ), + limit, + device=hidden_states.device) + else: + limit_t = None + + # Quant scales for SM100 MXFP8+MXFP4 path + fake_input_scale = torch.ones(num_experts, device=device) + quant_scales = [ + w13_scale_swapped.view(torch.int32), + fake_input_scale, + w2_scale.view(torch.int32), + fake_input_scale, + ] + + _ = flashinfer_cutlass_fused_moe( + input=hidden_states_q, + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + fc1_expert_weights=w13_q_swapped.contiguous().view(torch.long), + fc2_expert_weights=w2_q.contiguous().view(torch.long), + output_dtype=torch.bfloat16, + output=out, + quant_scales=quant_scales, + fc1_expert_biases=w13_b, + fc2_expert_biases=bias2.to(torch.bfloat16), + swiglu_alpha=alpha_t, + swiglu_beta=beta_t, + swiglu_limit=limit_t, + tp_size=1, + tp_rank=0, + ep_size=1, + ep_rank=0, + use_mxfp8_act_scaling=True, + input_sf=hidden_states_sf, + ) + + # Allow some mismatch due to MXFP4 quantization + check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8) diff --git 
a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py index 30388ef9375d4..a48bfeb10b2e6 100644 --- a/tests/kernels/moe/test_nvfp4_moe.py +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -10,6 +10,7 @@ from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, from tests.kernels.utils import torch_moe from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.platforms import current_platform @@ -56,7 +57,7 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, in_dtype=dtype, quant_dtype="nvfp4", block_shape=None, # use quant_blocksize? - per_act_token_quant=False, + per_out_ch_quant=False, ) score = torch.randn((m, e), device="cuda", dtype=dtype) @@ -73,18 +74,22 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, assert w1_blockscale is not None assert w2_blockscale is not None + quant_config = nvfp4_moe_quant_config( + g1_alphas=(1 / w1_gs), + g2_alphas=(1 / w2_gs), + a1_gscale=a1_gs, + a2_gscale=a2_gs, + w1_scale=w1_blockscale, + w2_scale=w2_blockscale, + ) + cutlass_output = cutlass_moe_fp4( a=a, - a1_gscale=a1_gs, w1_fp4=w1_q, - w1_blockscale=w1_blockscale, - g1_alphas=(1 / w1_gs), - a2_gscale=a2_gs, w2_fp4=w2_q, - w2_blockscale=w2_blockscale, - g2_alphas=(1 / w2_gs), topk_weights=topk_weights, topk_ids=topk_ids, + quant_config=quant_config, m=m, n=n, k=k, diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index 9e78f4d6e4da0..59126cef6adbb 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -9,6 +9,8 @@ import torch from tests.kernels.utils import torch_experts from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.config import ( + fp8_w8a8_moe_quant_config) from vllm.model_executor.layers.fused_moe.cutlass_moe import ( CutlassBatchedExpertsFp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk @@ -143,10 +145,16 @@ def pplx_cutlass_moe( device="cuda", dtype=torch.int64) - experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers, - out_dtype, per_act_token, per_out_ch, - ab_strides1, ab_strides2, c_strides1, - c_strides2) + experts = CutlassBatchedExpertsFp8( + num_local_experts, num_dispatchers, out_dtype, ab_strides1, + ab_strides2, c_strides1, c_strides2, + fp8_w8a8_moe_quant_config( + per_act_token_quant=per_act_token, + per_out_ch_quant=per_out_ch, + w1_scale=chunk_by_rank(w1_scale, rank, world_size), + w2_scale=chunk_by_rank(w2_scale, rank, world_size), + a1_scale=chunk_by_rank(a1_scale, rank, world_size) + if per_act_token else a1_scale[rank])) fused_cutlass_experts = FusedMoEModularKernel( prepare_finalize, @@ -167,10 +175,7 @@ def pplx_cutlass_moe( chunk_topk_ids, global_num_experts=num_experts, expert_map=None, #TODO - w1_scale=chunk_by_rank(w1_scale, rank, world_size), - w2_scale=chunk_by_rank(w2_scale, rank, world_size), - a1_scale=chunk_by_rank(a1_scale, rank, world_size) - if per_act_token else a1_scale[rank]) + ) torch.cuda.synchronize() diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 3f36d7ada2e94..4ca4a1e79c57c 100644 --- 
a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -4,10 +4,11 @@ Run `pytest tests/kernels/test_pplx_moe.py`. """ +import copy import itertools import textwrap import traceback -from typing import Callable, Optional +from typing import Callable, Optional, Union import pytest import torch @@ -21,7 +22,10 @@ try: except ImportError: has_pplx = False -from tests.kernels.moe.utils import make_test_weights, naive_batched_moe +from tests.kernels.moe.modular_kernel_tools.parallel_utils import ( + _set_vllm_config) +from tests.kernels.moe.utils import (make_shared_experts, make_test_weights, + naive_batched_moe) from tests.kernels.quant_utils import dequant from tests.kernels.utils import torch_experts from vllm.config import VllmConfig, set_current_vllm_config @@ -54,7 +58,7 @@ BATCHED_MOE_MNK_FACTORS = [ ] PPLX_COMBOS = [ - # TODO: figure out why this fails, seems to be test problem + # TODO(bnell): figure out why this fails, seems to be test problem #(1, 128, 128), (2, 128, 512), (3, 1024, 2048), @@ -356,18 +360,18 @@ def pplx_prepare_finalize( b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare( a_chunk, - a1_scale, - a2_scale, chunk_topk_weight, chunk_topk_ids, num_experts, None, False, - FusedMoEQuantConfig( + FusedMoEQuantConfig.make( quant_dtype, - per_act_token_quant, - False, - block_shape, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=False, + block_shape=block_shape, + a1_scale=a1_scale, + a2_scale=a2_scale, ), ) @@ -511,7 +515,8 @@ def pplx_moe( block_shape: Optional[list[int]] = None, use_compile: bool = False, use_cudagraphs: bool = True, -) -> torch.Tensor: + shared_experts: Optional[torch.nn.Module] = None, +) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: num_tokens, hidden_dim = a.shape num_experts = w1.shape[0] @@ -535,19 +540,6 @@ def pplx_moe( topk_ids = topk_ids.to(dtype=torch.uint32) - experts = BatchedTritonExperts( - max_num_tokens=max_num_tokens, - num_dispatchers=prepare_finalize.num_dispatchers(), - use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, - block_shape=block_shape, - per_act_token_quant=per_act_token_quant, - ) - - fused_experts = FusedMoEModularKernel( - prepare_finalize, - experts, - ) - # Note: workers with the same dp_rank must use the exact same inputs. a_chunk = chunk_by_rank(a, rank, world_size) chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size) @@ -561,6 +553,28 @@ def pplx_moe( a1_scale_chunk = chunk_scales_by_rank(a1_scale, rank, world_size) a2_scale_chunk = chunk_scales_by_rank(a2_scale, rank, world_size) + quant_config = FusedMoEQuantConfig.make( + quant_dtype, + block_shape=block_shape, + per_act_token_quant=per_act_token_quant, + w1_scale=w1_scale_chunk, + w2_scale=w2_scale_chunk, + a1_scale=a1_scale_chunk, + a2_scale=a2_scale_chunk, + ) + + experts = BatchedTritonExperts( + max_num_tokens=max_num_tokens, + num_dispatchers=prepare_finalize.num_dispatchers(), + quant_config=quant_config, + ) + + fused_experts = FusedMoEModularKernel( + prepare_finalize, + experts, + shared_experts, + ) + # Note: for now use_compile will error out if the problem size is # large enough to trigger chunking. I'm leaving the flag and # setup code in case we are able to revisit this later. 
@@ -579,14 +593,14 @@ def pplx_moe( w2_chunk, chunk_topk_weight, chunk_topk_ids, - w1_scale=w1_scale_chunk, - w2_scale=w2_scale_chunk, - a1_scale=a1_scale_chunk, - a2_scale=a2_scale_chunk, global_num_experts=num_experts) if use_cudagraphs: - out.fill_(0) + if isinstance(out, tuple): + out[0].fill_(0) + out[1].fill_(0) + else: + out.fill_(0) stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): @@ -595,10 +609,6 @@ def pplx_moe( w2_chunk, chunk_topk_weight, chunk_topk_ids, - w1_scale=w1_scale_chunk, - w2_scale=w2_scale_chunk, - a1_scale=a1_scale_chunk, - a2_scale=a2_scale_chunk, global_num_experts=num_experts) torch.cuda.synchronize() @@ -626,6 +636,7 @@ def _pplx_moe( per_act_token_quant: bool = False, block_shape: Optional[list[int]] = None, use_internode: bool = False, + shared_experts: Optional[torch.nn.Module] = None, ): try: if use_internode: @@ -666,6 +677,11 @@ def _pplx_moe( with set_current_vllm_config(vllm_config), override_config(moe_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) + if shared_experts is not None: + shared_output = shared_experts(a) + else: + shared_output = None + torch_output = torch_experts( a, w1, @@ -696,7 +712,7 @@ def _pplx_moe( block_shape=block_shape, ) - pplx_output = pplx_moe( + pplx_outputs = pplx_moe( group_name, rank, world_size, @@ -713,8 +729,24 @@ def _pplx_moe( quant_dtype=quant_dtype, per_act_token_quant=per_act_token_quant, block_shape=block_shape, + shared_experts=shared_experts, ) + if shared_experts is None: + pplx_shared_output = None + pplx_output = pplx_outputs + assert isinstance(pplx_output, torch.Tensor) + else: + pplx_shared_output, pplx_output = pplx_outputs + + if shared_output is not None: + assert pplx_shared_output is not None + chunked_shared_output = chunk_by_rank( + shared_output, pgi.rank, + pgi.world_size).to(pplx_shared_output.device) + else: + chunked_shared_output = None + chunked_batch_output = chunk_by_rank( batched_output, pgi.rank, pgi.world_size).to(pplx_output.device) @@ -727,6 +759,15 @@ def _pplx_moe( chunked_batch_output, atol=3e-2, rtol=3e-2) + + if shared_experts is not None: + assert chunked_shared_output is not None + assert pplx_shared_output is not None + torch.testing.assert_close(pplx_shared_output, + chunked_shared_output, + atol=3e-2, + rtol=3e-2) + finally: if use_internode: nvshmem_finalize() @@ -779,7 +820,7 @@ def test_pplx_moe_slow( k, quant_dtype=quant_dtype, block_shape=block_shape, - per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_act_token_quant, ) parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, e, @@ -788,7 +829,8 @@ def test_pplx_moe_slow( def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, - make_weights: bool, test_fn: Callable): + use_shared_experts: bool, make_weights: bool, + test_fn: Callable): def format_result(msg, ex=None): if ex is not None: @@ -803,6 +845,14 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, else: print(f"PASSED {msg}") + if use_shared_experts: + # Note: this config is only needed for the non-naive shared experts. 
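+ # The naive TestMLP shared experts need no distributed setup, but the RealMLP + # variant builds MergedColumnParallelLinear/RowParallelLinear layers that + # expect the DP/EP parallel state configured below to be in place.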
+ new_vllm_config = copy.deepcopy(vllm_config) + new_vllm_config.parallel_config.data_parallel_size = pgi.world_size + new_vllm_config.parallel_config.enable_expert_parallel = True + _set_vllm_config(new_vllm_config, pgi.world_size, pgi.rank, + pgi.local_rank) + current_platform.seed_everything(7) combos = itertools.product(PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES, [False, True], [None, [128, 128]]) @@ -819,9 +869,11 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, use_fp8_w8a8 = False quant_dtype = None - test_desc = (f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, " - f"dtype={dtype}, per_act_token={per_act_token_quant}, " - f"block_shape={block_shape}") + test_desc = ( + f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, " + f"dtype={dtype}, per_act_token={per_act_token_quant}, " + f"block_shape={block_shape}, use_internode={use_internode}, " + f"use_shared_experts={use_shared_experts}") if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None): @@ -845,13 +897,21 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, k, quant_dtype=quant_dtype, block_shape=block_shape, - per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_act_token_quant, ) args["w1"] = w1 args["w2"] = w2 args["w1_s"] = w1_s args["w2_s"] = w2_s + if use_shared_experts: + args["shared_experts"] = make_shared_experts( + n, + k, + in_dtype=a.dtype, + quant_dtype=quant_dtype, + ) + try: test_fn( pgi=pgi, @@ -891,18 +951,20 @@ def test_pplx_prepare_finalize( current_platform.seed_everything(7) world_size, dp_size = world_dp_size parallel_launch(world_size * dp_size, _pplx_test_loop, dp_size, - use_internode, False, _pplx_prepare_finalize) + use_internode, False, False, _pplx_prepare_finalize) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) @pytest.mark.parametrize("use_internode", [False]) +@pytest.mark.parametrize("use_shared_experts", [False, True]) @requires_pplx @multi_gpu_test(num_gpus=2) def test_pplx_moe( world_dp_size: tuple[int, int], use_internode: bool, + use_shared_experts: bool, ): current_platform.seed_everything(7) world_size, dp_size = world_dp_size - parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode, True, - _pplx_moe) + parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode, + use_shared_experts, True, _pplx_moe) diff --git a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py index 5a0379dfb4475..383b5ebfba9b7 100644 --- a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py +++ b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py @@ -5,28 +5,52 @@ import pytest import torch from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - silu_mul_fp8_quant_deep_gemm) + silu_mul_fp8_quant_deep_gemm_cuda) from vllm.platforms import current_platform +from vllm.utils import cdiv + +fp8_dtype = torch.float8_e4m3fn -# (E, T, H, group_size, seed) CASES = [ - (1, 1, 128, 64, 0), - (1, 4, 128, 128, 0), - (2, 4, 256, 128, 0), - (32, 64, 256, 128, 0), - (17, 31, 768, 128, 0), + (1, 1, 128, fp8_dtype), + (1, 4, 128, fp8_dtype), + (2, 4, 256, fp8_dtype), + (32, 64, 256, fp8_dtype), + (17, 31, 768, fp8_dtype), + (1, 1, 128 * 1, fp8_dtype), + (1, 1, 128 * 2, fp8_dtype), + (1, 1, 128 * 3, fp8_dtype), + (1, 1, 128 * 4, fp8_dtype), + (8, 16, 128 * 1, fp8_dtype), + (8, 16, 128 * 2, fp8_dtype), + (8, 16, 128 * 3, fp8_dtype), + (8, 16, 128 * 4, fp8_dtype), + (8, 64, 7168, fp8_dtype), + (8, 128, 7168, fp8_dtype), + (8, 256, 7168, 
fp8_dtype), + (8, 512, 7168, fp8_dtype), + (8, 1024, 7168, fp8_dtype), + (256, 8, 7168, fp8_dtype), + (256, 16, 7168, fp8_dtype), + (256, 32, 7168, fp8_dtype), + (256, 64, 7168, fp8_dtype), + + # Only add a few fnuz tests to help with long CI times. + (8, 512, 7168, torch.float8_e4m3fnuz), + (8, 1024, 7168, torch.float8_e4m3fnuz), ] -@pytest.mark.parametrize("E,T,H,group_size,seed", CASES) +@pytest.mark.parametrize("E,T,H,fp8_type", CASES) @torch.inference_mode() -def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed): - current_platform.seed_everything(seed) +def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type): + group_size = 128 + current_platform.seed_everything(42) # Input tensor of shape (E, T, 2*H) y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda") tokens_per_expert = torch.randint( - low=0, + low=T // 2, high=T, size=(E, ), dtype=torch.int32, @@ -34,45 +58,59 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed): ) # Run the Triton kernel - y_q, y_s = silu_mul_fp8_quant_deep_gemm(y, - tokens_per_expert, - group_size=group_size, - eps=1e-10) + y_q, y_s = silu_mul_fp8_quant_deep_gemm_cuda(y, + tokens_per_expert, + group_size=group_size) - # Reference implementation - fp8_info = torch.finfo(torch.float8_e4m3fn) + torch.cuda.synchronize() + fp8_info = torch.finfo(fp8_dtype) fp8_max = fp8_info.max fp8_min = fp8_info.min eps = 1e-10 - # Compute silu activation and elementwise multiplication - y1 = y[..., :H] + y1 = y[..., :H].float() y2 = y[..., H:] silu_x = y1 * torch.sigmoid(y1) merged = silu_x * y2 - # Compute reference scales and quantized output, skipping padded tokens for e in range(E): nt = tokens_per_expert[e].item() - ref_s = torch.empty((T, H // group_size), + ref_s = torch.empty((T, cdiv(H, group_size)), dtype=torch.float32, device="cuda") - ref_q = torch.empty((T, H), dtype=torch.float8_e4m3fn, device="cuda") + ref_q = torch.empty((T, H), dtype=fp8_dtype, device="cuda") + for t in range(nt): - data = merged[e, t] - data_grp = data.view(H // group_size, group_size) - amax = data_grp.abs().amax(dim=1).clamp(min=eps) - scale = amax / fp8_max + data = merged[e, t].float() + ref_q_row = torch.empty_like(data) - scaled = data / scale.repeat_interleave(group_size) - clamped = scaled.clamp(fp8_min, fp8_max) - q = clamped.to(torch.float8_e4m3fn) + # process full groups + n_full_groups = H // group_size + if n_full_groups > 0: + data_grp = data[:n_full_groups * group_size].view( + n_full_groups, group_size) + amax = data_grp.abs().amax(dim=1).clamp(min=eps) + scale = amax / fp8_max + scaled = data[:n_full_groups * + group_size] / scale.repeat_interleave(group_size) + ref_q_row[:n_full_groups * group_size] = scaled.clamp( + fp8_min, fp8_max).to(fp8_dtype) + ref_s[t, :n_full_groups] = scale - ref_s[t] = scale - ref_q[t] = q + # process remainder group + rem = H % group_size + if rem > 0: + data_rem = data[-rem:] + amax = data_rem.abs().amax().clamp(min=eps) + scale = amax / fp8_max + scaled = data_rem / scale + ref_q_row[-rem:] = scaled.clamp(fp8_min, fp8_max).to(fp8_dtype) + ref_s[t, -1] = scale - y_se = y_s[e] - y_qe = y_q[e] + ref_q[t] = ref_q_row + + y_se = y_s[e].float() + y_qe = y_q[e].float() torch.testing.assert_close(y_se[:nt], ref_s[:nt], atol=1e-4, rtol=1e-2) torch.testing.assert_close( diff --git a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py index dfd0f35c8da3d..1c31464b30e7f 100644 --- a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py +++ 
b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py @@ -7,10 +7,12 @@ import itertools import pytest import torch +from tests.kernels.moe.utils import fused_moe from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.config import ( + fp8_w8a8_moe_quant_config) from vllm.platforms import current_platform if current_platform.get_device_capability() < (9, 0): @@ -152,11 +154,12 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): score, topk, renormalize=False, - use_fp8_w8a8=True, # using fp8 - per_channel_quant=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=None, # Not using block quantization + quant_config=fp8_w8a8_moe_quant_config( + per_act_token_quant=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=None, # Not using block quantization + ), ) # Check results diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py index 82960bd57345d..7a0feb6a20795 100644 --- a/tests/kernels/moe/utils.py +++ b/tests/kernels/moe/utils.py @@ -8,7 +8,9 @@ import vllm._custom_ops as ops from tests.kernels.quant_utils import per_block_cast_to_int8 from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX) -from vllm.model_executor.layers.fused_moe import fused_experts +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts) from vllm.model_executor.layers.fused_moe.modular_kernel import ( @@ -33,18 +35,22 @@ def triton_moe( per_act_token_quant=False, block_shape: Optional[list[int]] = None, ) -> torch.Tensor: + quant_config = FusedMoEQuantConfig.make( + quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) + return fused_experts(a, w1, w2, topk_weight, topk_ids, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - per_channel_quant=per_act_token_quant, - use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, - block_shape=block_shape) + quant_config=quant_config) def batched_moe( @@ -63,6 +69,16 @@ def batched_moe( ) -> torch.Tensor: max_num_tokens = round_up(a.shape[0], 64) + quant_config = FusedMoEQuantConfig.make( + quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) + fused_experts = FusedMoEModularKernel( BatchedPrepareAndFinalize(max_num_tokens, num_dispatchers=1, @@ -71,21 +87,11 @@ def batched_moe( BatchedTritonExperts( max_num_tokens=max_num_tokens, num_dispatchers=1, - use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, + quant_config=quant_config, ), ) - return fused_experts(a, - w1, - w2, - topk_weight, - topk_ids, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale) + return fused_experts(a, w1, w2, topk_weight, topk_ids) def naive_batched_moe( @@ -104,6 +110,16 @@ def naive_batched_moe( ) -> torch.Tensor: max_num_tokens = round_up(a.shape[0], 64) + quant_config = FusedMoEQuantConfig.make( + 
quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) + fused_experts = FusedMoEModularKernel( BatchedPrepareAndFinalize(max_num_tokens, num_dispatchers=1, @@ -112,21 +128,11 @@ def naive_batched_moe( NaiveBatchedExperts( max_num_tokens=max_num_tokens, num_dispatchers=1, - use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, + quant_config=quant_config, ), ) - return fused_experts(a, - w1, - w2, - topk_weight, - topk_ids, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale) + return fused_experts(a, w1, w2, topk_weight, topk_ids) def chunk_scales(scales: Optional[torch.Tensor], start: int, @@ -215,7 +221,7 @@ def make_test_weight( in_dtype: torch.dtype = torch.bfloat16, quant_dtype: Union[torch.dtype, str, None] = None, block_shape: Optional[list[int]] = None, - per_act_token_quant: bool = False, + per_out_ch_quant: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15 @@ -227,7 +233,7 @@ def make_test_weight( w_gs_l = [None] * e for idx in range(e): w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights( - w_16[idx], None, quant_dtype, per_act_token_quant, block_shape) + w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape) w = torch.stack(w_l) w_s = torch.stack(w_s_l) @@ -257,16 +263,16 @@ def make_test_weights( in_dtype: torch.dtype = torch.bfloat16, quant_dtype: Union[torch.dtype, str, None] = None, block_shape: Optional[list[int]] = None, - per_act_token_quant: bool = False, + per_out_ch_quant: bool = False, ) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]], tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]]: return ( make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape, - per_act_token_quant), + per_out_ch_quant), make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, - per_act_token_quant), + per_out_ch_quant), ) @@ -282,3 +288,221 @@ def per_token_cast_to_fp8( x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn) return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) + + +def make_test_quant_config( + e: int, + n: int, + k: int, + in_dtype: torch.dtype, + quant_dtype: Union[torch.dtype, str, None] = None, + per_act_token_quant: bool = False, + block_shape: Optional[list[int]] = None, +) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]: + (_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights( + e, + n, + k, + in_dtype, + quant_dtype, + per_out_ch_quant=per_act_token_quant, + block_shape=block_shape, + ) + + # Hacky/trivial scales for nvfp4. 
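+ # The all-ones a1/a2 global scales below are effectively no-ops; they only + # keep the nvfp4 FusedMoEQuantConfig fully populated for the test.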
+ a1_gscale: Optional[torch.Tensor] = None + a2_gscale: Optional[torch.Tensor] = None + if quant_dtype == "nvfp4": + a1_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32) + a2_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32) + a1_scale = a1_gscale + a2_scale = a2_gscale + else: + a1_scale = None + a2_scale = None + + return w1, w2, FusedMoEQuantConfig.make( + quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + w1_scale=w1_s, + w2_scale=w2_s, + a1_gscale=a1_gscale, + a2_gscale=a2_gscale, + a1_scale=a1_scale, + a2_scale=a2_scale, + # TODO: make sure this is handled properly + g1_alphas=(1 / w1_gs) if w1_gs is not None else None, + g2_alphas=(1 / w2_gs) if w2_gs is not None else None, + ) + + +def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, + renormalize: bool = False, + quant_config: Optional[FusedMoEQuantConfig] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, +) -> torch.Tensor: + topk_weights, topk_ids, _ = fused_topk(hidden_states, score.float(), topk, + renormalize) + return fused_experts(hidden_states, + w1, + w2, + topk_weights, + topk_ids, + global_num_experts=global_num_experts, + expert_map=expert_map, + quant_config=quant_config) + + +# CustomOp? +class BaselineMM(torch.nn.Module): + + def __init__( + self, + b: torch.Tensor, + out_dtype: torch.dtype, + ): + super().__init__() + self.b = b.to(dtype=torch.float32) + self.out_dtype = out_dtype + + def forward( + self, + a: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + return torch.mm(a.to(dtype=torch.float32), + self.b).to(self.out_dtype), None + + +class TestMLP(torch.nn.Module): + + def __init__( + self, + w1: torch.Tensor, + w2: torch.Tensor, + out_dtype: torch.dtype, + ): + super().__init__() + self.gate_up_proj = BaselineMM(w1, out_dtype) + self.down_proj = BaselineMM(w2, out_dtype) + self.act_fn = SiluAndMul() + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +def make_naive_shared_experts( + N: int, + K: int, + in_dtype: torch.dtype = torch.bfloat16, +) -> torch.nn.Module: + w1 = torch.randn((K, N * 2), device="cuda", dtype=in_dtype) / 15 + w2 = torch.randn((N, K), device="cuda", dtype=in_dtype) / 15 + return TestMLP(w1, w2, out_dtype=in_dtype) + + +class RealMLP(torch.nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + w1: torch.Tensor, + w2: torch.Tensor, + hidden_act: str = "silu", + quant_config=None, + reduce_results: bool = True, + prefix: str = "", + w1_s: Optional[torch.Tensor] = None, + w2_s: Optional[torch.Tensor] = None, + ) -> None: + from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, RowParallelLinear) + + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.gate_up_proj.register_parameter( + "weight", torch.nn.Parameter(w1, requires_grad=False)) + self.gate_up_proj.register_parameter( + "weight_scale", torch.nn.Parameter(w1_s, requires_grad=False)) + self.gate_up_proj.register_parameter( + "input_scale", + None) #torch.nn.Parameter(None, requires_grad=False)) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + 
self.down_proj.register_parameter( + "weight", torch.nn.Parameter(w2, requires_grad=False)) + self.down_proj.register_parameter( + "weight_scale", torch.nn.Parameter(w2_s, requires_grad=False)) + self.down_proj.register_parameter( + "input_scale", + None) #torch.nn.Parameter(None, requires_grad=False)) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +def make_shared_experts( + N: int, + K: int, + in_dtype: torch.dtype = torch.bfloat16, + quant_dtype: Union[torch.dtype, str, None] = None, +) -> torch.nn.Module: + from vllm.model_executor.layers.quantization.fp8 import Fp8Config + + (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights( + 1, + N, + K, + in_dtype=in_dtype, + quant_dtype=quant_dtype, + ) + old_dtype = torch.get_default_dtype() + try: + torch.set_default_dtype(in_dtype) + if quant_dtype == torch.float8_e4m3fn: + w1 = w1[0].transpose(0, 1) + w2 = w2[0].transpose(0, 1) + w1_s = w1_s[0].transpose(0, 1) if w1_s is not None else None + w2_s = w2_s[0].transpose(0, 1) if w2_s is not None else None + quant_config = Fp8Config(True) + else: + w1 = w1[0] + w2 = w2[0] + w1_s = None + w2_s = None + quant_config = None + + return RealMLP(K, + N, + w1, + w2, + "silu", + quant_config, + w1_s=w1_s, + w2_s=w2_s) + finally: + torch.set_default_dtype(old_dtype) diff --git a/tests/kernels/quantization/nvfp4_utils.py b/tests/kernels/quantization/nvfp4_utils.py index 1095975ab2b41..fc4e125550180 100644 --- a/tests/kernels/quantization/nvfp4_utils.py +++ b/tests/kernels/quantization/nvfp4_utils.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch +from vllm._custom_ops import scaled_fp4_quant from vllm.scalar_type import scalar_types FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() @@ -65,3 +66,10 @@ def break_fp4_bytes(a, dtype): # Reshape to final form return values.reshape(m, n * 2).to(dtype=dtype) + + +def quant_nvfp4_tensor(a: torch.Tensor): + a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.abs(a).max().to(torch.float32)) + a_quant, a_block_scale = scaled_fp4_quant(a, a_global_scale) + return a_quant, a_block_scale, a_global_scale diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index d9154d3fd7f33..c440747316b80 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -11,8 +11,8 @@ from tests.kernels.quant_utils import (native_per_token_group_quant_fp8, native_w8a8_block_matmul) from vllm.config import VllmConfig from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - get_col_major_tma_aligned_tensor, per_token_group_quant_fp8, - w8a8_block_fp8_matmul) + cutlass_scaled_mm, get_col_major_tma_aligned_tensor, + per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform from vllm.utils import has_deep_gemm from vllm.utils.deep_gemm import fp8_gemm_nt, per_block_cast_to_fp8 @@ -98,6 +98,54 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): assert rel_diff < 0.001 +@torch.inference_mode() +def test_w8a8_block_fp8_cutlass_matmul(): + # Test simple case where weight.shape % 128 != 0, + # like in DSV3 kv_a_proj_with_mqa + M = 32 + N = 576 + K = 7168 + block_size = [128, 128] + out_dtype = torch.bfloat16 + seed = 0 + + 
torch.manual_seed(seed) + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + + B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + block_n, block_k = block_size[0], block_size[1] + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + + Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale + # Hopper requires row-major format for scales + Bs_cutlass = Bs.T.contiguous() if current_platform.is_device_capability( + 90) else Bs + + A_fp8, As = per_token_group_quant_fp8(A_fp32, + block_size[1], + column_major_scales=False) + # CUTLASS uses column-major format for scales + A_fp8_cutlass, As_cutlass = per_token_group_quant_fp8( + A_fp32, block_size[1], column_major_scales=True) + + ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, + out_dtype) + out = cutlass_scaled_mm(A_fp8_cutlass, B_fp8, As_cutlass, Bs_cutlass, + block_size, out_dtype) + + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.001 + + @pytest.mark.parametrize( "M,N,K,block_size,out_dtype,seed", itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) diff --git a/tests/kernels/quantization/test_fp8_quant_group.py b/tests/kernels/quantization/test_fp8_quant_group.py new file mode 100644 index 0000000000000..720eee62760db --- /dev/null +++ b/tests/kernels/quantization/test_fp8_quant_group.py @@ -0,0 +1,150 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for QuantFP8 Group Quantization implementation.""" + +import pytest +import torch + +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) +from vllm.platforms import current_platform + + +@pytest.mark.parametrize( + "batch_size,hidden_dim,group_size", + [ + (16, 256, 32), # Small + (64, 1024, 64), # Medium + (128, 2048, 128), # Large + (8, 513, 64), # Non-divisible (native only) + ]) +@pytest.mark.parametrize("seed", [42]) +@torch.inference_mode() +def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, + group_size: int, seed: int) -> None: + """Test QuantFP8 group quantization with various configurations. + + Tests both CUDA and native implementations, column-major scales, + and verifies consistency between implementations. + """ + current_platform.seed_everything(seed) + + x = torch.randn( + (batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8 + expected_num_groups = (hidden_dim + group_size - 1) // group_size + is_divisible = hidden_dim % group_size == 0 + + group_shape = GroupShape(1, group_size) + quant_op = QuantFP8(static=False, + group_shape=group_shape, + column_major_scales=False) + + # 1. Test native implementation (always available) + x_quant_native, scales_native = quant_op.forward_native(x.clone()) + assert x_quant_native.shape == x.shape + assert scales_native.shape == (batch_size, expected_num_groups) + + # 2. 
Test column-major scales configuration + quant_op_col = QuantFP8(static=False, + group_shape=group_shape, + column_major_scales=True) + _, scales_col = quant_op_col.forward_native(x.clone()) + assert scales_col.shape == (expected_num_groups, batch_size) + + # 3. Test CUDA implementation (only for divisible dimensions) + if is_divisible: + x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone()) + assert x_quant_cuda.shape == x.shape + assert scales_cuda.shape == (batch_size, expected_num_groups) + + # Verify CUDA/native consistency + assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8) + + # Quantized values should mostly match + diff_count = (x_quant_cuda != x_quant_native).sum().item() + diff_ratio = diff_count / x_quant_cuda.numel() + assert diff_ratio < 0.002, f"Too many differences: {diff_ratio:.4%}" + + +@pytest.mark.parametrize("seed", [42]) +@torch.inference_mode() +def test_quantfp8_group_multidimensional(seed: int) -> None: + current_platform.seed_everything(seed) + + group_size = 64 + + # Test with 3D input + batch1, batch2, hidden_dim = 4, 8, 512 + x_3d = torch.randn( + (batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8 + + group_shape = GroupShape(1, group_size) + quant_op = QuantFP8(static=False, + group_shape=group_shape, + column_major_scales=False) + + x_quant, scales = quant_op.forward_native(x_3d.clone()) + assert x_quant.shape == x_3d.shape + assert scales.shape == (batch1, batch2, hidden_dim // group_size) + + # Test column_major_scales with multi-dim + quant_op_col = QuantFP8(static=False, + group_shape=group_shape, + column_major_scales=True) + _, scales_col = quant_op_col.forward_native(x_3d.clone()) + assert scales_col.shape == (batch1, hidden_dim // group_size, batch2) + + # Test with 4D input + batch1, batch2, batch3, hidden_dim = 2, 3, 4, 256 + x_4d = torch.randn((batch1, batch2, batch3, hidden_dim), + dtype=torch.bfloat16, + device="cuda") * 8 + + x_quant_4d, scales_4d = quant_op.forward_native(x_4d.clone()) + assert x_quant_4d.shape == x_4d.shape + assert scales_4d.shape == (batch1, batch2, batch3, + hidden_dim // group_size) + + _, scales_4d_col = quant_op_col.forward_native(x_4d.clone()) + assert scales_4d_col.shape == (batch1, batch2, hidden_dim // group_size, + batch3) + + +@pytest.mark.parametrize("seed", [42]) +@torch.inference_mode() +def test_quantfp8_group_edge_cases(seed: int) -> None: + current_platform.seed_everything(seed) + + batch_size = 16 + group_size = 64 + + # Test with single group (group_size >= hidden_dim) + x_small = torch.randn( + (batch_size, 32), dtype=torch.bfloat16, device="cuda") * 8 + group_shape = GroupShape(1, group_size) + quant_op = QuantFP8(static=False, + group_shape=group_shape, + column_major_scales=False) + + x_quant_small, scales_small = quant_op.forward_native(x_small.clone()) + assert x_quant_small.shape == x_small.shape + assert scales_small.shape == (batch_size, 1) + + # Test with zero inputs + x_zero = torch.zeros((batch_size, 256), + dtype=torch.bfloat16, + device="cuda") + x_quant_zero, scales_zero = quant_op.forward_native(x_zero.clone()) + assert x_quant_zero.shape == x_zero.shape + assert (scales_zero > 0).all(), "Scales should be clamped to minimum" + + # Test very large values + x_large = torch.full((batch_size, 256), + 1000.0, + dtype=torch.bfloat16, + device="cuda") + x_quant_large, scales_large = quant_op.forward_native(x_large.clone()) + assert x_quant_large.shape == x_large.shape + # FP8 max is typically 448 or 224, so scales should be > 1 + assert (scales_large 
> 1.0).all(), "Large values should have scales > 1" diff --git a/tests/kernels/quantization/test_hadacore.py b/tests/kernels/quantization/test_hadacore.py new file mode 100644 index 0000000000000..127d68072e3f3 --- /dev/null +++ b/tests/kernels/quantization/test_hadacore.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math + +import pytest +import torch +from compressed_tensors.transform import deterministic_hadamard_matrix + +from vllm import _custom_ops as ops + + +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("hidden_dim", [2**n for n in range(10)]) +def test_hadacore(batch_size, hidden_dim, dtype=torch.bfloat16, device="cuda"): + x = torch.eye(hidden_dim, dtype=dtype, device=device) + hadamard = deterministic_hadamard_matrix( + hidden_dim, dtype=torch.float64, device="cuda") / math.sqrt(hidden_dim) + + y = ops.hadacore_transform(x.clone()) + y_true = (x.to(hadamard.dtype) @ hadamard.T).to(y.dtype) + assert torch.allclose(y, y_true) + + y = ops.hadacore_transform(y) + assert torch.allclose(y, x) diff --git a/tests/kernels/quantization/test_int8_kernel.py b/tests/kernels/quantization/test_int8_kernel.py index dc5fecbf4ccc8..f2271e6be5420 100644 --- a/tests/kernels/quantization/test_int8_kernel.py +++ b/tests/kernels/quantization/test_int8_kernel.py @@ -8,7 +8,8 @@ import pytest import torch from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe import fused_experts +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.quantization.utils.int8_utils import ( per_token_quant_int8) from vllm.platforms import current_platform @@ -42,7 +43,8 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): return C.reshape(origin_C_shape).to(output_dtype) -def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk): +def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, topk_weight, + topk_ids): """This function performs fused moe with per-column int8 quantization using native torch.""" @@ -57,8 +59,6 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk): out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) # Calculate routing - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) topk_weight = topk_weight.view(-1) topk_ids = topk_ids.view(-1) # Process each expert @@ -127,20 +127,27 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale score = torch.randn((M, E), dtype=dtype) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weights, topk_ids = torch.topk(score, topk) - ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk) - out = fused_moe( + ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, + topk_weights, topk_ids) + + quant_config = FusedMoEQuantConfig.make( + torch.int8, + per_act_token_quant=True, + block_shape=None, + w1_scale=w1_s, + w2_scale=w2_s, + ) + + out = fused_experts( a, w1, w2, - score, - topk, - renormalize=False, - use_int8_w8a8=True, # Using int8-w8a8 - per_channel_quant=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=None, # Not using block quantization + topk_weights, + topk_ids, + 
quant_config=quant_config, ) # Check results diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py index 03d5d98739c50..a9b1c71ef0718 100644 --- a/tests/kernels/quantization/test_rocm_skinny_gemms.py +++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py @@ -5,6 +5,8 @@ import torch import vllm._custom_ops as ops from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + rocm_per_tensor_w8a8_scaled_mm_impl) from vllm.platforms import current_platform DTYPES = [torch.bfloat16, torch.float16] @@ -116,3 +118,32 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed): current_platform.get_cu_count()) assert torch.allclose(out, ref_out, rtol=0.01) + + +@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("use_bias", [True, False]) +@pytest.mark.skipif( + not (current_platform.is_rocm() and current_platform.supports_fp8()), + reason="only test for rocm fp8") +def test_rocm_per_tensor_w8a8_scaled_mm_impl(n, k, m, dtype, seed, use_bias): + torch.manual_seed(seed) + + A = torch.rand(n, k, device="cuda") + B = torch.rand(m, k, device="cuda") + + A, scale_a = ref_dynamic_per_tensor_fp8_quant(A) + B, scale_b = ref_dynamic_per_tensor_fp8_quant(B) + + bias = torch.rand(1, m, dtype=dtype, device="cuda") if use_bias else None + + output = rocm_per_tensor_w8a8_scaled_mm_impl(A, B.t(), dtype, scale_a, + scale_b, bias) + ref_out = torch._scaled_mm(A, + B.t(), + out_dtype=dtype, + scale_a=scale_a, + scale_b=scale_b, + bias=bias) + assert torch.allclose(output, ref_out, rtol=0.01) diff --git a/tests/kernels/quantization/test_silu_mul_nvfp4_quant.py b/tests/kernels/quantization/test_silu_mul_nvfp4_quant.py new file mode 100644 index 0000000000000..a40d0c4ef1224 --- /dev/null +++ b/tests/kernels/quantization/test_silu_mul_nvfp4_quant.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch + +from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype) +from vllm._custom_ops import scaled_fp4_quant +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.platforms import current_platform + +if not current_platform.has_device_capability(100): + pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True) + +FP4_DTYPE = torch.uint8 +FP8_DTYPE = current_platform.fp8_dtype() + +DTYPES = [torch.float16, torch.bfloat16] +SHAPES = [(128, 256), (128, 128), (256, 256), (256, 128)] +BLOCK_SIZE = 16 + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("shape", SHAPES) +@torch.inference_mode() +def test_silu_mul_nvfp4_quant( + dtype: torch.dtype, + shape: tuple[int, int], +) -> None: + current_platform.seed_everything(42) + device = 'cuda:0' + torch.set_default_device(device) + + x = torch.randn(shape, dtype=dtype) + + # ref op + ref_output = SiluAndMul().forward_native(x) + ref_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.abs(ref_output).max().to(torch.float32)) + ref_output_quant, ref_block_scale = scaled_fp4_quant( + ref_output, ref_global_scale) + + # fused op + fused_output_quant = torch.empty_like(ref_output_quant) + fused_block_scale = torch.empty_like(ref_block_scale) + 
torch.ops._C.silu_and_mul_nvfp4_quant(fused_output_quant, + fused_block_scale, x, + ref_global_scale) + + # check dtype + assert ref_output_quant.dtype == FP4_DTYPE + assert fused_output_quant.dtype == FP4_DTYPE + assert ref_output_quant.shape == fused_output_quant.shape + + assert ref_block_scale.dtype == FP8_DTYPE + assert fused_block_scale.dtype == FP8_DTYPE + assert ref_block_scale.shape == fused_block_scale.shape + + # check dequantized output + ref_output_dequant = dequantize_nvfp4_to_dtype(ref_output_quant, + ref_block_scale, + ref_global_scale, dtype, + device) + fused_output_dequant = dequantize_nvfp4_to_dtype(fused_output_quant, + fused_block_scale, + ref_global_scale, dtype, + device) + + atol, rtol = 3e-1, 3e-1 + torch.testing.assert_close(ref_output_dequant, + fused_output_dequant, + atol=atol, + rtol=rtol) diff --git a/tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py b/tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py deleted file mode 100644 index 969f14cc3fe62..0000000000000 --- a/tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py +++ /dev/null @@ -1,126 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest -import torch - -from tests.kernels.utils import opcheck -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.platforms import current_platform -from vllm.scalar_type import scalar_types - -if not current_platform.has_device_capability(100): - pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", - allow_module_level=True) - -DTYPES = [torch.float16, torch.bfloat16] -SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)] -SEEDS = [42] -CUDA_DEVICES = ['cuda:0'] - -FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() -FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max - -BLOCK_SIZE = 16 - - -def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor, - global_scale: torch.Tensor, - ref_output_scale: torch.Tensor) -> torch.Tensor: - silu_and_mul_out = silu_and_mul.forward_native(x) - assert not current_platform.is_rocm() - assert silu_and_mul_out.ndim >= 1, ( - f'input.ndim needs to be >= 1, but got {silu_and_mul_out.ndim}.') - other_dims = 1 if silu_and_mul_out.ndim == 1 else -1 - silu_and_mul_out = silu_and_mul_out.reshape(other_dims, - silu_and_mul_out.shape[-1]) - m, n = silu_and_mul_out.shape - device = silu_and_mul_out.device - - # Two fp4 values will be packed into an uint8. 
- out = torch.empty((m, n // 2), device=device, dtype=torch.uint8) - - output_scale = ref_output_scale - - torch.ops._C.scaled_fp4_quant(out, silu_and_mul_out, output_scale, - global_scale) - - return out, output_scale - - -def ops_impl(x: torch.Tensor, global_scale: torch.Tensor, - ref_output_scale: torch.Tensor) -> torch.Tensor: - out_shape = (x.shape[0], x.shape[1] // 4) - output_scale = ref_output_scale - out = torch.empty(out_shape, dtype=torch.uint8, device=x.device) - torch.ops._C.silu_and_mul_nvfp4_quant(out, output_scale, x, global_scale) - return out, output_scale - - -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("shape", SHAPES) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@torch.inference_mode() -def test_quantize_to_fp4( - dtype: torch.dtype, - shape: tuple[int, int], - seed: int, - device: str, -) -> None: - current_platform.seed_everything(seed) - torch.set_default_device(device) - - m, n = shape - - x = torch.randn((m, n), dtype=dtype) - tensor_amax = torch.abs(x).max().to(torch.float32) - global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax - - block_size = 16 - - assert n % block_size == 0, ( - f'last dim has to be multiple of 16, but got {n}.') - assert x.dtype in (torch.float16, torch.bfloat16), ( - f'input.dtype needs to be fp16 or bf16 but got {x.dtype}.') - - round_up = lambda x, y: (x + y - 1) // y * y - rounded_m = round_up(x.shape[0], 128) - scale_n = x.shape[1] // (2 * block_size) - rounded_n = round_up(scale_n, 4) - output_scale = torch.empty((rounded_m, rounded_n // 4), - device=x.device, - dtype=torch.int32) - - layer = SiluAndMul() - - ref_out, ref_out_scale = ref_impl(layer, x, global_scale, output_scale) - - fusion_out, fusion_out_scale = ops_impl(x, global_scale, output_scale) - - assert ref_out.dtype == torch.uint8 - assert fusion_out.dtype == torch.uint8 - assert ref_out.shape == fusion_out.shape - - assert ref_out_scale.dtype == torch.int32 - assert fusion_out_scale.dtype == torch.int32 - assert ref_out_scale.shape == fusion_out_scale.shape - - # Allow up to 2% of mismatched values since BF16 has accuracy issues. 
- mis_threshold = 0.02 - atol = 0.4 - rtol = 0.4 - ref_logits = ref_out[-1] - fusion_logits = fusion_out[-1] - - mis_count = torch.sum( - torch.abs(fusion_logits - ref_logits) > (atol + - rtol * torch.abs(ref_logits))) - mis_ratio = mis_count / fusion_logits.numel() - - assert mis_ratio < mis_threshold, \ - f"Mismatch ratio {mis_ratio} exceeds threshold {mis_threshold}" - - torch.testing.assert_close(ref_out_scale, fusion_out_scale) - - opcheck(torch.ops._C.silu_and_mul_nvfp4_quant, - (fusion_out, fusion_out_scale, x, global_scale)) diff --git a/tests/kernels/test_cutlass_mla_decode.py b/tests/kernels/test_cutlass_mla_decode.py deleted file mode 100644 index 2b745b84dae6c..0000000000000 --- a/tests/kernels/test_cutlass_mla_decode.py +++ /dev/null @@ -1,96 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest -import torch -import torch.nn.functional as F -from torch import Tensor - -import vllm._custom_ops as ops -from vllm.platforms import current_platform - -if not current_platform.has_device_capability(100): - pytest.skip( - reason="Cutlass MLA Requires compute capability of 10 or above.", - allow_module_level=True) - - -def ref_mla( - out: Tensor, # (bs, num_heads, v_head_dim) - query: Tensor, # (bs, num_heads, head_dim) - kv_cache: Tensor, # (num_blocks, block_size, head_dim) - scale: float, - block_tables: Tensor, # (bs, max_num_blocks) - seq_lens: Tensor, # (bs,) -): - bs, num_heads, v_head_dim = out.shape - head_dim = query.shape[2] - - for i in range(bs): - # gather and flatten KV-cache - kv = kv_cache[ - block_tables[i]] # (max_num_blocks, block_size, head_dim) - kv = kv.view(1, -1, - head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim) - v = kv[:, :, :v_head_dim] - - q = query[i].view(num_heads, 1, head_dim) - o = F.scaled_dot_product_attention(q, - kv, - v, - scale=scale, - enable_gqa=True) - out[i] = o.view(num_heads, v_head_dim) - - return out - - -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) -@pytest.mark.parametrize("mean_seq_len", [128, 1024, 4096]) -@pytest.mark.parametrize("bs", [1, 2, 4]) -@pytest.mark.parametrize("varlen", [False, True]) -@pytest.mark.parametrize("block_size", [16, 64, 128]) -def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int, - varlen: bool, block_size: int): - torch.set_default_dtype(dtype) - torch.set_default_device('cuda') - torch.manual_seed(42) - - d = 576 - h_q = 128 - dv = 512 - - q_nope_dim = 128 - q_pe_dim = 64 - scale = (q_nope_dim + q_pe_dim)**(-0.5) - if varlen: - seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2) - seq_lens = seq_lens.clip(2).to(torch.int32) - else: - seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32) - max_seq_len = seq_lens.max().item() - block_num = (max_seq_len + block_size - 1) // block_size - - # Pad block_num so that small blocks can be packed into full 128-sized - # CUTLASS tiles. One 128-wide tile can hold (128 // block_size) small - # blocks. - pack_factor = 128 // block_size - block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor - - # Amplify input values to ensure test coverage of edge cases where CUTLASS - # kernel errors occur with split_k settings. 
- q = torch.randn(bs, h_q, d) * 100 - block_table = torch.randint(0, - bs * block_num, (bs, block_num), - dtype=torch.int32) - - kv_cache = torch.randn(block_table.numel(), block_size, d) - - out_ref = q.new_zeros(bs, h_q, dv) - ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens) - out_ans = torch.zeros_like(out_ref) - q_nope = q[:, :, :dv].clone() - q_pe = q[:, :, dv:].clone() - ops.cutlass_mla_decode(out_ans, q_nope, q_pe, kv_cache, seq_lens, - block_table, scale) - - torch.testing.assert_close(out_ans, out_ref, atol=1e-2, rtol=1e-2) diff --git a/tests/kernels/test_onednn.py b/tests/kernels/test_onednn.py index 17692384ac9a9..37772464a209b 100644 --- a/tests/kernels/test_onednn.py +++ b/tests/kernels/test_onednn.py @@ -111,6 +111,49 @@ def onednn_int8_gemm_test_helper(primitive_cache_size: int, torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) +def onednn_gemm_test_helper(primitive_cache_size: int, + m: int, + n: int, + k: int, + use_bias: bool, + use_stride: bool, + dtype: torch.dtype = torch.bfloat16, + device: str = "cpu"): + if use_stride: + a = torch.rand((m, 2 * k), dtype=dtype, device=device) * 1.5 + a = a[:, :k] + else: + a = torch.rand((m, k), dtype=dtype, device=device) * 1.5 + + b = torch.rand((n, k), dtype=dtype, device=device) * 1.5 + + if use_bias: + bias = torch.rand((n, ), device=device, dtype=dtype) * 5 + bias_f32 = bias.float() + else: + bias = None + bias_f32 = None + + handler = ops.create_onednn_mm( + b.t(), + primitive_cache_size, + ) + + out = ops.onednn_mm(handler, a, bias) + baseline = torch.nn.functional.linear(a.float(), b.float(), + bias_f32).to(dtype=a.dtype) + + torch.testing.assert_close(out, baseline) + + if use_bias: + # To test runtime bias setting + out = ops.onednn_mm(handler, a, None) + baseline = torch.nn.functional.linear(a.float(), b.float(), + None).to(dtype=a.dtype) + + torch.testing.assert_close(out, baseline) + + @pytest.mark.parametrize("n,k", NK_FACTORS) @pytest.mark.parametrize("m_list", M_FACTORS) @pytest.mark.parametrize("per_tensor_a_scale", [True, False]) @@ -142,3 +185,30 @@ def test_onednn_int8_scaled_gemm( use_azp=use_azp, out_dtype=output_type, ) + + +@pytest.mark.parametrize("n,k", NK_FACTORS) +@pytest.mark.parametrize("m_list", M_FACTORS) +@pytest.mark.parametrize("use_bias", [True, False]) +@pytest.mark.parametrize("use_stride", [True, False]) +@pytest.mark.parametrize("dtype", DTYPE) +@pytest.mark.parametrize("primitive_cache_size", CACHE_SIZES) +def test_onednn_gemm( + n: int, + k: int, + m_list: tuple[int], + use_bias: bool, + use_stride: bool, + dtype: torch.dtype, + primitive_cache_size: int, +): + for m in m_list: + onednn_gemm_test_helper( + primitive_cache_size=primitive_cache_size, + m=m, + n=n, + k=k, + use_bias=use_bias, + use_stride=use_stride, + dtype=dtype, + ) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index fa4125840a010..c9bf85f6e2a5c 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -1236,7 +1236,7 @@ def baseline_scaled_mm(a: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: # We treat N-dimensional group scaling as extended numpy-style broadcasting - # in numpy simply stretches dimensions with an extent of 1 to match the + # in numpy simply stretches dimensions with an extent of 1 to match # the target shape by repeating the data along that dimension (broadcasting) # , we extend these semantics to say if the extent of a dimension in the # source shape is not 1 and does not match the target shape we repeat each @@ -1247,7 +1247,7 @@ def 
baseline_scaled_mm(a: torch.Tensor, # then we would expand a to: # a = [[1, 1, 2, 2], # [3, 3, 4, 4]] - # NOTE this function this function does not explicitly broadcast dimensions + # NOTE this function does not explicitly broadcast dimensions # with an extent of 1, since this can be done implicitly by pytorch def group_broadcast(t, shape): for i, s in enumerate(shape): diff --git a/tests/kv_transfer/test_lookup_buffer.py b/tests/kv_transfer/test_lookup_buffer.py index 352ab63552de7..ca2f04dabfc98 100644 --- a/tests/kv_transfer/test_lookup_buffer.py +++ b/tests/kv_transfer/test_lookup_buffer.py @@ -128,7 +128,7 @@ if __name__ == "__main__": print(f"initialized! My rank is {my_rank}") config = KVTransferConfig( - kv_connector='PyNcclConnector', + kv_connector='P2pNcclConnector', kv_buffer_device='cuda', kv_buffer_size=1e9, kv_rank=my_rank, diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py index 32116608a2177..99ad2b43aeac8 100644 --- a/tests/kv_transfer/test_send_recv.py +++ b/tests/kv_transfer/test_send_recv.py @@ -137,7 +137,7 @@ if __name__ == "__main__": ) config = KVTransferConfig( - kv_connector='PyNcclConnector', + kv_connector='P2pNcclConnector', kv_buffer_device='cuda', kv_buffer_size=1e9, kv_rank=my_rank, diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 6e2dda464d8eb..6735b7cd9e436 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -11,21 +11,21 @@ import pytest import torch import torch.nn.functional as F -from vllm.config import LoRAConfig -from vllm.lora.fully_sharded_layers import ( - ColumnParallelLinearWithShardedLoRA, - MergedColumnParallelLinearWithShardedLoRA, - MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA, - RowParallelLinearWithShardedLoRA) +from vllm.config.lora import LoRAConfig # yapf conflicts with isort for this block # yapf: disable from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, + ColumnParallelLinearWithShardedLoRA, LogitsProcessorWithLoRA, LoRAMapping, MergedColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithShardedLoRA, MergedQKVParallelLinearWithLoRA, + MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithLoRA, + QKVParallelLinearWithShardedLoRA, ReplicatedLinearWithLoRA, RowParallelLinearWithLoRA, + RowParallelLinearWithShardedLoRA, VocabParallelEmbeddingWithLoRA) # yapf: enable from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights @@ -60,9 +60,9 @@ DEVICES = ([ # prefill stage(True) or decode stage(False) STAGES = [True, False] -NUM_RANDOM_SEEDS = 6 +NUM_RANDOM_SEEDS = 2 -VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128 +VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 2 @pytest.fixture(autouse=True) diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index 06196cc697cec..a6770e6d32af8 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -13,14 +13,6 @@ from ..utils import VLLM_PATH, create_new_process_for_each_test, multi_gpu_test MODEL_PATH = "meta-llama/Llama-2-7b-hf" -EXPECTED_NO_LORA_OUTPUT = [ - "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international 
airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", # noqa: E501 - " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ", # noqa: E501 - "\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m", # noqa: E501 - " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ", # noqa: E501 - " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? 
", # noqa: E501 - "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE", # noqa: E501 -] EXPECTED_LORA_OUTPUT = [ " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501 " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", # noqa: E501 @@ -79,23 +71,12 @@ def generate_and_test(llm, sql_lora_files, tensorizer_config_dict: Union[dict, None] = None): print("lora adapter created") - assert do_sample(llm, - sql_lora_files, - tensorizer_config_dict=tensorizer_config_dict, - lora_id=0) == EXPECTED_NO_LORA_OUTPUT - print("lora 1") assert do_sample(llm, sql_lora_files, tensorizer_config_dict=tensorizer_config_dict, lora_id=1) == EXPECTED_LORA_OUTPUT - print("no lora") - assert do_sample(llm, - sql_lora_files, - tensorizer_config_dict=tensorizer_config_dict, - lora_id=0) == EXPECTED_NO_LORA_OUTPUT - print("lora 2") assert do_sample(llm, sql_lora_files, @@ -110,6 +91,7 @@ def test_llama_lora(sql_lora_files): llm = vllm.LLM( MODEL_PATH, + tokenizer=sql_lora_files, enable_lora=True, # also test odd max_num_seqs max_num_seqs=13, @@ -123,6 +105,7 @@ def test_llama_lora_tp4(sql_lora_files): llm = vllm.LLM( MODEL_PATH, + tokenizer=sql_lora_files, enable_lora=True, max_num_seqs=16, max_loras=4, @@ -137,6 +120,7 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files): llm = vllm.LLM( MODEL_PATH, + tokenizer=sql_lora_files, enable_lora=True, max_num_seqs=16, max_loras=4, @@ -184,6 +168,7 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files, tensorizer_config = TensorizerConfig(tensorizer_uri=str(model_uri)) loaded_llm = LLM(model=model_ref, + tokenizer=sql_lora_files, load_format="tensorizer", enable_lora=True, enforce_eager=True, @@ -195,11 +180,6 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files, tc_as_dict = tensorizer_config.to_serializable() print("lora adapter created") - assert do_sample(loaded_llm, - sql_lora_files, - tensorizer_config_dict=tc_as_dict, - lora_id=0) == EXPECTED_NO_LORA_OUTPUT - print("lora 1") assert do_sample(loaded_llm, sql_lora_files, diff --git a/tests/lora/test_lora_allowed_token_ids.py b/tests/lora/test_lora_allowed_token_ids.py deleted file mode 100644 index e77eae70445db..0000000000000 --- a/tests/lora/test_lora_allowed_token_ids.py +++ /dev/null @@ -1,135 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - VllmConfig) -from vllm.lora.request import LoRARequest -from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs -from vllm.v1.engine.processor 
import Processor - - -def test_allowed_token_ids_with_lora_vocab(llama_2_7b_base_huggingface_id, - sql_lora_files): - """ - Test that we properly resolve the range of allowed token ids for lora - adapters that define additional tokens. - """ - - # Set up a base model compatible with the sql_lora_files adapter and - # a known number of tokens in the base model. - model_config = ModelConfig( - model=llama_2_7b_base_huggingface_id, - tokenizer=llama_2_7b_base_huggingface_id, - tokenizer_mode="auto", - ) - - vllm_config = VllmConfig( - model_config=model_config, - cache_config=CacheConfig(), - device_config=DeviceConfig(), - lora_config=LoRAConfig(), - ) - - tokenizer = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - lora_config=vllm_config.lora_config) - processor = Processor(vllm_config, tokenizer) - - lora_request = LoRARequest("1", 1, str(sql_lora_files)) - request_id = "1" - prompt = "a prompt" - - # tokens added in the lora adapter should not raise an error - lora_token_ids = [32000, 32001, 32002, 32003] - processor.process_inputs( - request_id, - prompt, - params=SamplingParams(allowed_token_ids=lora_token_ids), - lora_request=lora_request) - - # tokens in the base model should not raise an error - base_token_ids = [1000, 1001, 1002, 1003] - processor.process_inputs( - request_id, - prompt, - params=SamplingParams(allowed_token_ids=base_token_ids), - lora_request=lora_request) - - # tokens not in the lora adapter should raise an error - invalid_token_ids = [35000, 35001, 35002, 35003] - with pytest.raises(ValueError): - processor.process_inputs( - request_id, - prompt, - params=SamplingParams(allowed_token_ids=invalid_token_ids), - lora_request=lora_request) - - # tokens in the lora adapter with no lora request should raise an error - with pytest.raises(ValueError): - processor.process_inputs( - request_id, - prompt, - params=SamplingParams(allowed_token_ids=lora_token_ids), - ) - - -def test_allowed_token_ids_with_lora_adapter_no_vocab( - qwen25vl_base_huggingface_id, qwen25vl_lora_files): - """ - Test that we properly resolve the range of allowed token ids for lora - adapters that do not define additional tokens. - """ - - # Set up a base model compatible with the qwen25vl_lora_files adapter and - # a known number of tokens in the base model. 
- model_config = ModelConfig( - model=qwen25vl_base_huggingface_id, - tokenizer=qwen25vl_base_huggingface_id, - tokenizer_mode="auto", - ) - - vllm_config = VllmConfig( - model_config=model_config, - cache_config=CacheConfig(), - device_config=DeviceConfig(), - lora_config=LoRAConfig(), - ) - - tokenizer = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - lora_config=vllm_config.lora_config) - processor = Processor(vllm_config, tokenizer) - - lora_request = LoRARequest("1", 1, str(qwen25vl_lora_files)) - request_id = "1" - prompt = "a prompt" - - # tokens in the base model should not raise an error - base_token_ids = [1000, 1001, 1002, 1003] - processor.process_inputs( - request_id, - prompt, - params=SamplingParams(allowed_token_ids=base_token_ids), - lora_request=lora_request) - - # tokens in the base model with no lora request should not raise an error - base_token_ids = [1000, 1001, 1002, 1003] - processor.process_inputs( - request_id, - prompt, - params=SamplingParams(allowed_token_ids=base_token_ids), - ) - - # tokens not in the base model should raise an error - invalid_token_ids = [200000, 200001, 200002, 200003] - with pytest.raises(ValueError): - processor.process_inputs( - request_id, - prompt, - params=SamplingParams(allowed_token_ids=invalid_token_ids), - lora_request=lora_request) diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index c9ab32edc7f32..a5802c108c6be 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -8,7 +8,7 @@ import torch from safetensors.torch import load_file from torch import nn -from vllm.config import LoRAConfig +from vllm.config.lora import LoRAConfig from vllm.lora.layers import (ColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA, RowParallelLinearWithLoRA) diff --git a/tests/lora/test_peft_helper.py b/tests/lora/test_peft_helper.py index df8696cf58e0f..ffffb5d8eab90 100644 --- a/tests/lora/test_peft_helper.py +++ b/tests/lora/test_peft_helper.py @@ -7,7 +7,7 @@ import shutil import pytest -from vllm.config import LoRAConfig +from vllm.config.lora import LoRAConfig from vllm.lora.peft_helper import PEFTHelper ERROR_CASES = [ diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py index caa31fdb0e73e..2b54b2edd6a9c 100644 --- a/tests/lora/test_quant_model.py +++ b/tests/lora/test_quant_model.py @@ -82,31 +82,20 @@ def test_quant_model_lora(tinyllama_lora_files, model): gpu_memory_utilization=0.2, #avoid OOM quantization=model.quantization, trust_remote_code=True, - enable_chunked_prefill=True) + enable_chunked_prefill=True, + tokenizer=tinyllama_lora_files) if model.quantization is None: - expected_no_lora_output = [ - "Here are some examples of orange-brown colors", - "I'm sorry, I don't have" - ] expected_lora_output = [ "#ff8050", "#ff8080", ] elif model.quantization == "awq": - expected_no_lora_output = [ - "I'm sorry, I don't understand", - "I'm sorry, I don't understand", - ] expected_lora_output = [ "#f07700: A v", "#f00000: A v", ] elif model.quantization == "gptq": - expected_no_lora_output = [ - "I'm sorry, I don't have", - "I'm sorry, I don't have", - ] expected_lora_output = [ "#f08800: This is", "#f07788 \n#", @@ -117,7 +106,6 @@ def test_quant_model_lora(tinyllama_lora_files, model): # Assert that the outputs changed. 
if (model.quantization == "gptq" and expected_output is expected_lora_output): - assert output != expected_no_lora_output for i, o in enumerate(output): assert o.startswith( '#'), f"Expected example {i} to start with # but got {o}" @@ -127,12 +115,6 @@ def test_quant_model_lora(tinyllama_lora_files, model): max_tokens = 10 print("lora adapter created") - output = do_sample(llm, - tinyllama_lora_files, - lora_id=0, - max_tokens=max_tokens) - expect_match(output, expected_no_lora_output) - print("lora 1") output = do_sample(llm, tinyllama_lora_files, @@ -140,13 +122,6 @@ def test_quant_model_lora(tinyllama_lora_files, model): max_tokens=max_tokens) expect_match(output, expected_lora_output) - print("no lora") - output = do_sample(llm, - tinyllama_lora_files, - lora_id=0, - max_tokens=max_tokens) - expect_match(output, expected_no_lora_output) - print("lora 2") output = do_sample(llm, tinyllama_lora_files, diff --git a/tests/lora/test_tokenizer_group.py b/tests/lora/test_tokenizer_group.py deleted file mode 100644 index 6cfdaf50d33c4..0000000000000 --- a/tests/lora/test_tokenizer_group.py +++ /dev/null @@ -1,72 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -from transformers import AutoTokenizer, PreTrainedTokenizerBase - -from vllm.lora.request import LoRARequest -from vllm.transformers_utils.tokenizer import get_lora_tokenizer -from vllm.transformers_utils.tokenizer_group import TokenizerGroup - - -@pytest.mark.asyncio -@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"]) -async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type): - reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files) - tokenizer_group = TokenizerGroup( - tokenizer_id="gpt2", - enable_lora=True, - max_num_seqs=1, - max_loras=1, - max_input_length=None, - ) - lora_request = LoRARequest("1", 1, sql_lora_files) - assert reference_tokenizer.encode("prompt") == tokenizer_group.encode( - prompt="prompt", lora_request=lora_request) - assert reference_tokenizer.encode( - "prompt") == await tokenizer_group.encode_async( - prompt="prompt", lora_request=lora_request) - assert isinstance(tokenizer_group.get_lora_tokenizer(None), - PreTrainedTokenizerBase) - assert tokenizer_group.get_lora_tokenizer( - None) == await tokenizer_group.get_lora_tokenizer_async(None) - - assert isinstance(tokenizer_group.get_lora_tokenizer(lora_request), - PreTrainedTokenizerBase) - assert tokenizer_group.get_lora_tokenizer( - lora_request) != tokenizer_group.get_lora_tokenizer(None) - assert tokenizer_group.get_lora_tokenizer( - lora_request) == await tokenizer_group.get_lora_tokenizer_async( - lora_request) - - -def test_get_lora_tokenizer(sql_lora_files, tmp_path): - lora_request = None - tokenizer = get_lora_tokenizer(lora_request) - assert not tokenizer - - lora_request = LoRARequest("1", 1, sql_lora_files) - tokenizer = get_lora_tokenizer(lora_request) - assert tokenizer.get_added_vocab() - - lora_request = LoRARequest("1", 1, str(tmp_path)) - tokenizer = get_lora_tokenizer(lora_request) - assert not tokenizer - - -@pytest.mark.parametrize("enable_lora", [True, False]) -@pytest.mark.parametrize("max_num_seqs", [1, 2]) -@pytest.mark.parametrize("max_loras", [1, 2]) -def test_lora_tokenizers(enable_lora, max_num_seqs, max_loras): - tokenizer_group = TokenizerGroup( - tokenizer_id="gpt2", - enable_lora=enable_lora, - max_num_seqs=max_num_seqs, - max_loras=max_loras, - max_input_length=None, - ) - if enable_lora: - assert 
tokenizer_group.lora_tokenizers.capacity == max( - max_num_seqs, max_loras) - else: - assert tokenizer_group.lora_tokenizers.capacity == 0 diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index a836ff94ba3ed..9c47abf8f4dce 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -6,9 +6,10 @@ import random import tempfile from unittest.mock import patch -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - VllmConfig) +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, SchedulerConfig, VllmConfig) +from vllm.config.load import LoadConfig +from vllm.config.lora import LoRAConfig from vllm.lora.models import LoRAMapping from vllm.lora.request import LoRARequest from vllm.v1.worker.gpu_worker import Worker diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py deleted file mode 100644 index dbd9c518e0200..0000000000000 --- a/tests/metrics/test_metrics.py +++ /dev/null @@ -1,268 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -import ray -from prometheus_client import REGISTRY - -import vllm.envs as envs -from vllm import EngineArgs, LLMEngine -from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.engine.metrics import RayPrometheusStatLogger -from vllm.sampling_params import SamplingParams -from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - This module tests V0 internals, so set VLLM_USE_V1=0. - """ - monkeypatch.setenv('VLLM_USE_V1', '0') - - -MODELS = [ - "distilbert/distilgpt2", -] - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [128]) -def test_metric_counter_prompt_tokens( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - with vllm_runner(model, - dtype=dtype, - disable_log_stats=False, - gpu_memory_utilization=0.4) as vllm_model: - tokenizer = vllm_model.llm.get_tokenizer() - prompt_token_counts = [ - len(tokenizer.encode(p)) for p in example_prompts - ] - # This test needs at least 2 prompts in a batch of different lengths to - # verify their token count is correct despite padding. 
- assert len(example_prompts) > 1, "at least 2 prompts are required" - assert prompt_token_counts[0] != prompt_token_counts[1], ( - "prompts of different lengths are required") - vllm_prompt_token_count = sum(prompt_token_counts) - - _ = vllm_model.generate_greedy(example_prompts, max_tokens) - stat_logger = vllm_model.llm.llm_engine.stat_loggers['prometheus'] - metric_count = stat_logger.metrics.counter_prompt_tokens.labels( - **stat_logger.labels)._value.get() - - assert vllm_prompt_token_count == metric_count, ( - f"prompt token count: {vllm_prompt_token_count!r}\n" - f"metric: {metric_count!r}") - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [128]) -def test_metric_counter_generation_tokens( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - with vllm_runner(model, - dtype=dtype, - disable_log_stats=False, - gpu_memory_utilization=0.4) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - tokenizer = vllm_model.llm.get_tokenizer() - stat_logger = vllm_model.llm.llm_engine.stat_loggers['prometheus'] - metric_count = stat_logger.metrics.counter_generation_tokens.labels( - **stat_logger.labels)._value.get() - vllm_generation_count = 0 - for i in range(len(example_prompts)): - vllm_output_ids, vllm_output_str = vllm_outputs[i] - prompt_ids = tokenizer.encode(example_prompts[i]) - # vllm_output_ids contains both prompt tokens and generation tokens. - # We're interested only in the count of the generation tokens. - vllm_generation_count += len(vllm_output_ids) - len(prompt_ids) - - assert vllm_generation_count == metric_count, ( - f"generation token count: {vllm_generation_count!r}\n" - f"metric: {metric_count!r}") - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize( - "served_model_name", - [None, [], ["ModelName0"], ["ModelName0", "ModelName1", "ModelName2"]]) -def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str, - served_model_name: list[str]) -> None: - with vllm_runner(model, - dtype=dtype, - disable_log_stats=False, - gpu_memory_utilization=0.3, - served_model_name=served_model_name) as vllm_model: - stat_logger = vllm_model.llm.llm_engine.stat_loggers['prometheus'] - metrics_tag_content = stat_logger.labels["model_name"] - - if envs.VLLM_CI_USE_S3: - model = f"{MODEL_WEIGHTS_S3_BUCKET}/{model}" - if served_model_name is None or served_model_name == []: - assert metrics_tag_content == model, ( - f"Metrics tag model_name is wrong! expect: {model!r}\n" - f"actual: {metrics_tag_content!r}") - else: - assert metrics_tag_content == served_model_name[0], ( - f"Metrics tag model_name is wrong! 
expect: " - f"{served_model_name[0]!r}\n" - f"actual: {metrics_tag_content!r}") - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [4]) -@pytest.mark.parametrize("disable_log_stats", [True, False]) -@pytest.mark.asyncio -async def test_async_engine_log_metrics_regression( - example_prompts, - model: str, - dtype: str, - max_tokens: int, - disable_log_stats: bool, -) -> None: - """ - Regression test ensuring async engine generates metrics - when disable_log_stats=False - (see: https://github.com/vllm-project/vllm/pull/4150#pullrequestreview-2008176678) - """ - engine_args = AsyncEngineArgs( - model=model, - dtype=dtype, - disable_log_stats=disable_log_stats, - ) - async_engine = AsyncLLMEngine.from_engine_args(engine_args) - for i, prompt in enumerate(example_prompts): - results = async_engine.generate( - prompt, - SamplingParams(max_tokens=max_tokens), - f"request-id-{i}", - ) - # Exhaust the async iterator to make the async engine work - async for _ in results: - pass - - assert_metrics(model, async_engine.engine, disable_log_stats, - len(example_prompts)) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [4]) -@pytest.mark.parametrize("disable_log_stats", [True, False]) -def test_engine_log_metrics_regression( - example_prompts, - model: str, - dtype: str, - max_tokens: int, - disable_log_stats: bool, -) -> None: - engine_args = EngineArgs( - model=model, - dtype=dtype, - disable_log_stats=disable_log_stats, - ) - engine = LLMEngine.from_engine_args(engine_args) - for i, prompt in enumerate(example_prompts): - engine.add_request( - f"request-id-{i}", - prompt, - SamplingParams(max_tokens=max_tokens), - ) - while engine.has_unfinished_requests(): - engine.step() - - if envs.VLLM_CI_USE_S3: - model = f"{MODEL_WEIGHTS_S3_BUCKET}/{model}" - assert_metrics(model, engine, disable_log_stats, len(example_prompts)) - - -def assert_metrics(model: str, engine: LLMEngine, disable_log_stats: bool, - num_requests: int) -> None: - if disable_log_stats: - with pytest.raises(AttributeError): - _ = engine.stat_loggers - else: - assert (engine.stat_loggers - is not None), "engine.stat_loggers should be set" - # Ensure the count bucket of request-level histogram metrics matches - # the number of requests as a simple sanity check to ensure metrics are - # generated - labels = {'model_name': model} - request_histogram_metrics = [ - "vllm:e2e_request_latency_seconds", - "vllm:request_prompt_tokens", - "vllm:request_generation_tokens", - "vllm:request_params_n", - "vllm:request_params_max_tokens", - ] - for metric_name in request_histogram_metrics: - metric_value = REGISTRY.get_sample_value(f"{metric_name}_count", - labels) - assert ( - metric_value == num_requests), "Metrics should be collected" - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [16]) -def test_engine_log_metrics_ray( - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - # This test is quite weak - it only checks that we can use - # RayPrometheusStatLogger without exceptions. - # Checking whether the metrics are actually emitted is unfortunately - # non-trivial. 
- - # We have to run in a Ray task for Ray metrics to be emitted correctly - @ray.remote(num_gpus=1) - def _inner(): - - class _RayPrometheusStatLogger(RayPrometheusStatLogger): - - def __init__(self, *args, **kwargs): - self._i = 0 - super().__init__(*args, **kwargs) - - def log(self, *args, **kwargs): - self._i += 1 - return super().log(*args, **kwargs) - - engine_args = EngineArgs( - model=model, - dtype=dtype, - disable_log_stats=False, - ) - engine = LLMEngine.from_engine_args(engine_args) - logger = _RayPrometheusStatLogger( - local_interval=0.5, - labels=dict(model_name=engine.model_config.served_model_name), - vllm_config=engine.vllm_config) - engine.add_logger("ray", logger) - for i, prompt in enumerate(example_prompts): - engine.add_request( - f"request-id-{i}", - prompt, - SamplingParams(max_tokens=max_tokens), - ) - while engine.has_unfinished_requests(): - engine.step() - assert logger._i > 0, ".log must be called at least once" - - ray.get(_inner.remote()) diff --git a/tests/model_executor/model_loader/test_registry.py b/tests/model_executor/model_loader/test_registry.py index 93a3e34835b5a..639ee6db9270f 100644 --- a/tests/model_executor/model_loader/test_registry.py +++ b/tests/model_executor/model_loader/test_registry.py @@ -4,7 +4,8 @@ import pytest from torch import nn -from vllm.config import LoadConfig, ModelConfig +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig from vllm.model_executor.model_loader import (get_model_loader, register_model_loader) from vllm.model_executor.model_loader.base_loader import BaseModelLoader diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 140f00294765d..86139d598582d 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -13,13 +13,15 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (dispatch_topk_func, vllm_topk_softmax) from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( is_rocm_aiter_moe_enabled) -from vllm.model_executor.layers.layernorm import ( - RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm, - rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm) +from vllm.model_executor.layers.layernorm import (RMSNorm, + dispatch_rocm_rmsnorm_func, + fused_add_rms_norm, rms_norm) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul) from vllm.platforms import current_platform +RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16] + # Registered subclass for test @CustomOp.register("relu3") @@ -149,24 +151,27 @@ def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): @pytest.mark.parametrize("add_residual", [True, False]) +@pytest.mark.parametrize("dtype", + [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) @pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"]) @pytest.mark.skipif(not current_platform.is_rocm(), reason="AITER is a feature exclusive for ROCm") -def test_rms_norm_dispatch(add_residual: bool, use_rocm_aiter: str, - use_rocm_aiter_norm: str, monkeypatch): +def test_rms_norm_dispatch(add_residual: bool, dtype: torch.dtype, + use_rocm_aiter: str, use_rocm_aiter_norm: str, + monkeypatch): monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm) - rms_norm_func = dispatch_cuda_rmsnorm_func(add_residual) + 
rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype) - if not add_residual: - if current_platform.is_rocm() and int(use_rocm_aiter) and int( - use_rocm_aiter_norm): - assert rms_norm_func == rocm_aiter_rms_norm - else: - assert rms_norm_func == rms_norm - elif current_platform.is_rocm() and int(use_rocm_aiter) and int( - use_rocm_aiter_norm): - assert rms_norm_func == rocm_aiter_fused_add_rms_norm - else: + should_use_rocm_aiter = current_platform.is_rocm() and int(use_rocm_aiter) \ + and int(use_rocm_aiter_norm) and dtype in RMS_NORM_SUPPORTED_DTYPES + + if add_residual and should_use_rocm_aiter: + assert rms_norm_func == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add + elif should_use_rocm_aiter: + assert rms_norm_func == torch.ops.vllm.rocm_aiter_rms_norm + elif add_residual: assert rms_norm_func == fused_add_rms_norm + else: + assert rms_norm_func == rms_norm diff --git a/tests/model_executor/test_logits_processor.py b/tests/model_executor/test_logits_processor.py deleted file mode 100644 index 532ebba038d38..0000000000000 --- a/tests/model_executor/test_logits_processor.py +++ /dev/null @@ -1,98 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import random -from unittest.mock import patch - -import pytest -import torch - -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.utils import set_random_seed -from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.utils import is_pin_memory_available - - -class MockLogitsProcessor(LogitsProcessor): - - def __init__(self, vocab_size: int, scale: float, - fake_logits: torch.Tensor): - super().__init__(vocab_size=vocab_size, scale=scale) - self.fake_logits = fake_logits.clone() - - def forward(self, *args, **kwargs): - with patch( - "vllm.model_executor.layers.logits_processor._prune_hidden_states", - lambda x, y: x - ), patch( - "vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits", - lambda *args, **kwargs: self.fake_logits): - return super().forward(*args, **kwargs) - - -def _prepare_test( - batch_size: int -) -> tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]: - vocab_size = 32000 - input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) - fake_logits = torch.full((batch_size, vocab_size), - 1e-2, - dtype=input_tensor.dtype) - logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits) - return input_tensor, fake_logits, logits_processor - - -RANDOM_SEEDS = list(range(128)) -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] - - -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_logits_processors(seed: int, device: str): - set_random_seed(seed) - torch.set_default_device(device) - batch_size = random.randint(1, 256) - input_tensor, fake_logits, logits_processor = _prepare_test(batch_size) - - # This sample logits processor gives infinite score to the i-th token, - # where i is the length of the input sequence. - # We therefore expect the output token sequence to be [0, 1, 2, ...] 
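# For example (illustrative walk-through of the comment above): on the first
# decode step token_ids is empty, so pick_ith([], logits) places +inf at
# index 0 and greedy sampling picks token 0; on the next step
# pick_ith([0], logits) places +inf at index 1, picking token 1, and so on,
# which is how the [0, 1, 2, ...] output sequence described above arises.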
- def pick_ith(token_ids, logits): - logits[len(token_ids)] = float("inf") - return logits - - seq_group_metadata_list = [] - seq_lens = [] - for i in range(batch_size): - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: SequenceData.from_seqs([1, 2, 3])}, - sampling_params=SamplingParams(temperature=0, - logits_processors=[pick_ith]), - block_tables={0: [1]}, - )) - seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - query_lens=seq_lens, - device=device, - pin_memory=is_pin_memory_available()) - logits_processor_output = logits_processor( - lm_head=None, - hidden_states=input_tensor, - sampling_metadata=sampling_metadata) - - assert torch.isinf(logits_processor_output[:, 0]).all() - - fake_logits *= logits_processor.scale - torch.testing.assert_close(logits_processor_output[:, 1], - fake_logits[:, 1], - rtol=1e-4, - atol=0.0) diff --git a/tests/model_executor/test_model_load_with_params.py b/tests/model_executor/test_model_load_with_params.py index 0ade75b7e6228..c7b15c6ae1186 100644 --- a/tests/model_executor/test_model_load_with_params.py +++ b/tests/model_executor/test_model_load_with_params.py @@ -47,8 +47,8 @@ def test_model_loading_with_params(vllm_runner, monkeypatch): assert model_config.pooler_config.normalize # asserts on the tokenizer loaded - assert model_tokenizer.tokenizer_id == "BAAI/bge-base-en-v1.5" - assert model_tokenizer.tokenizer.model_max_length == 512 + assert model_config.tokenizer == "BAAI/bge-base-en-v1.5" + assert model_tokenizer.model_max_length == 512 def check_model(model): assert isinstance(model, BertEmbeddingModel) @@ -87,8 +87,8 @@ def test_roberta_model_loading_with_params(vllm_runner, monkeypatch): assert model_config.pooler_config.normalize # asserts on the tokenizer loaded - assert model_tokenizer.tokenizer_id == "intfloat/multilingual-e5-base" - assert model_tokenizer.tokenizer.model_max_length == 512 + assert model_config.tokenizer == "intfloat/multilingual-e5-base" + assert model_tokenizer.model_max_length == 512 def check_model(model): assert isinstance(model, RobertaEmbeddingModel) @@ -116,8 +116,7 @@ def test_facebook_roberta_model_loading_with_params(vllm_runner, monkeypatch): output = vllm_model.embed("Write a short story about a robot that" " dreams for the first time.\n") - model_tokenizer = vllm_model.llm.llm_engine.tokenizer - assert model_tokenizer.tokenizer_id == model_name + assert vllm_model.llm.llm_engine.model_config.tokenizer == model_name def check_model(model): assert isinstance(model, RobertaEmbeddingModel) diff --git a/tests/models/language/generation/test_bart.py b/tests/models/language/generation/test_bart.py deleted file mode 100644 index b4c771840196c..0000000000000 --- a/tests/models/language/generation/test_bart.py +++ /dev/null @@ -1,220 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional - -import pytest -from transformers import AutoModelForSeq2SeqLM - -from vllm.sequence import SampleLogprobs - -from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt, - HfRunner, VllmRunner) -from ....utils import multi_gpu_test -from ...utils import check_logprobs_close - - -def vllm_to_hf_output( - vllm_output: tuple[list[int], str, Optional[SampleLogprobs]], - decoder_prompt_type: DecoderPromptType, -): - """Sanitize vllm output to be comparable with hf 
output.""" - output_ids, output_str, out_logprobs = vllm_output - - hf_output_str = output_str + "</s>" - if decoder_prompt_type == DecoderPromptType.NONE: - hf_output_str = "<s>" + hf_output_str - - return output_ids, hf_output_str, out_logprobs - - -def run_test( - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - prompts: list[ExplicitEncoderDecoderPrompt[str, str]], - decoder_prompt_type: DecoderPromptType, - model: str, - *, - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -) -> None: - ''' - Test the vLLM BART model for a variety of encoder/decoder input prompts, - by validating it against HuggingFace (HF) BART. - - Arguments: - - * hf_runner: HuggingFace (HF) test model runner - * vllm_runner: vLLM test model runner - * example_encoder_decoder_prompts: test fixture which provides a - dictionary of dummy prompts - * model: the HF ID of the specific BART variant under test - * dtype: the tensor datatype to employ - * max_tokens - * num_logprobs - * decoder_prompt_type: key into the example_encoder_decoder_prompts - dictionary; selects specific encoder/decoder - prompt scenarios to test - - A note on using HF BART as a baseline for validating vLLM BART, - specifically when the decoder prompt is None. - - The HF GenerationMixin's default behavior is to force the first - decoded token to be <BOS> if the prompt does not already contain - <BOS> (this is accomplished using a logit - processor setting.) - - So when we use HF BART as our baseline for comparison, note that - when the user provides a request with a None decoder prompt - (i.e. a singleton encoder prompt, or else an explicit encoder/ - decoder prompt with the decoder sub-prompt set to None), HF and - vLLM handle this in different ways: - - * HF will (1) tokenize the None prompt as an empty token-list, - (2) append <decoder-start-token> to the beginning, yielding - [<decoder-start-token>], (3) pass this token list to the model, and - then (4) after computing logits during prefill, override the model - logits & force <BOS> to be the first generated token. - - * vLLM will (1) tokenize the None prompt as [<BOS>], (2) append decoder- - start-token to the beginning, yielding [<decoder-start-token><BOS>], - (3) pass these tokens to the model & proceed with generation. - - The net effect is that compared to vLLM, the list of HF *decoded* tokens - will contain one more initial <BOS> than the vLLM generated tokens, - because vLLM's <BOS> token is injected into the prompt rather than into - the generated output. This is in spite of the fact that overall, the - complete sequences (prompt + decoded tokens) produced by vLLM will match - HF. - - So when we use HF decoded token output to validate vLLM's decoded token - output, the testing process must account for the difference in decoded - token sequences between vLLM and HF specifically in the - decoder-prompt-is-None case. - - One option is to disable the logit processor feature that forces the - <BOS> token to be decoded (forced_bos_token_id = None), eliminating - the problem entirely. However this is not "normal" BART usage. - - The other option is - only in the decoder-prompt-is-None case - to - discard the first decoded token from the HF output before comparing it - to vLLM. - - To that end, when testing the scenario where the decoder prompt is None - (and only in that one scenario), this test skips the first HF decoded - token during the process of validating the vLLM decoded output. 
- ''' - - # NOTE: take care of the order. run vLLM first, and then run HF. - # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with fork method (the default). - - # Note: currently encoder/decoder models are only compatible with - # enforce_eager=True. Normally this is not a problem because - # for encoder/decoder models vLLM will - # default to enforce_eager=True if enforce_eager - # is left unspecified. However, the - # VllmRunner test fixture (which wraps around the LLM class) defaults to - # enforce_eager=False (a behavior which a number of already-existing - # decoder-only unit tests expect), so when testing an encoder/decoder - # model we must explicitly specify enforce_eager=True in the VllmRunner - # constructor. - with vllm_runner(model, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: - vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( - prompts, max_tokens, num_logprobs) - - # Configuration settings for HF baseline - hf_kwargs = { - "top_k": None, - "num_beams": 1, - "repetition_penalty": 1.0, - "top_p": 1.0, - "length_penalty": 1.0, - "early_stopping": False, - "no_repeat_ngram_size": None, - "min_length": 0 - } - - with hf_runner(model, dtype=dtype, - auto_cls=AutoModelForSeq2SeqLM) as hf_model: - hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit( - prompts, - max_tokens, - num_logprobs, - **hf_kwargs, - )) - - hf_skip_tokens = (1 - if decoder_prompt_type == DecoderPromptType.NONE else 0) - - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(vllm_output, decoder_prompt_type) - for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - num_outputs_0_skip_tokens=hf_skip_tokens, - ) - - -@pytest.mark.parametrize( - "model", - [ - pytest.param("facebook/bart-base", - marks=[pytest.mark.core_model, pytest.mark.cpu_model]), - pytest.param("facebook/bart-large-cnn"), - ], -) -@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) -@pytest.mark.parametrize("max_tokens", [64]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) -def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, - dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None: - - run_test( - hf_runner, - vllm_runner, - example_encoder_decoder_prompts[decoder_prompt_type], - decoder_prompt_type, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=1, - ) - - -@multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) -@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [64]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM]) -def test_models_distributed(hf_runner, vllm_runner, - example_encoder_decoder_prompts, - distributed_executor_backend, model, dtype, - max_tokens, num_logprobs, - decoder_prompt_type) -> None: - run_test( - hf_runner, - vllm_runner, - example_encoder_decoder_prompts[decoder_prompt_type], - decoder_prompt_type, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=2, - 
distributed_executor_backend=distributed_executor_backend, - ) diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py index 8a04946b2ffb3..a5aa1e3f49743 100644 --- a/tests/models/language/generation/test_common.py +++ b/tests/models/language/generation/test_common.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os from typing import Optional import pytest @@ -39,7 +38,7 @@ AITER_MODEL_LIST = [ [ pytest.param( "bigscience/bloom-560m", # bloom - testing alibi slopes - marks=[pytest.mark.core_model], + marks=[pytest.mark.core_model, pytest.mark.slow_test], ), pytest.param( "openai-community/gpt2", # gpt2 @@ -50,7 +49,10 @@ AITER_MODEL_LIST = [ pytest.param("EleutherAI/pythia-70m"), # gpt_neox pytest.param( "google/gemma-1.1-2b-it", # gemma - marks=[pytest.mark.core_model, pytest.mark.cpu_model], + marks=[ + pytest.mark.core_model, pytest.mark.cpu_model, + pytest.mark.slow_test + ], ), pytest.param( "zai-org/chatglm3-6b", # chatglm (text-only) @@ -71,14 +73,17 @@ AITER_MODEL_LIST = [ ), pytest.param( "microsoft/phi-2", # phi - marks=[pytest.mark.core_model], + marks=[pytest.mark.core_model, pytest.mark.slow_test], ), pytest.param( "Qwen/Qwen-7B-Chat", # qwen (text-only) ), pytest.param( "Qwen/Qwen2.5-0.5B-Instruct", # qwen2 - marks=[pytest.mark.core_model, pytest.mark.cpu_model], + marks=[ + pytest.mark.core_model, pytest.mark.cpu_model, + pytest.mark.slow_test + ], ), pytest.param( "Qwen/Qwen3-8B", # qwen (text-only) @@ -93,15 +98,16 @@ AITER_MODEL_LIST = [ "allenai/OLMoE-1B-7B-0924-Instruct", marks=[pytest.mark.cpu_model], ), - pytest.param("swiss-ai/Apertus-8B"), # apertus + pytest.param("swiss-ai/Apertus-8B-2509"), # apertus ]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize( "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) +@pytest.mark.parametrize("use_prompt_embeds", [True, False]) def test_models(hf_runner, vllm_runner, example_prompts, model: str, max_tokens: int, num_logprobs: int, use_rocm_aiter: bool, - monkeypatch) -> None: + use_prompt_embeds: bool, monkeypatch) -> None: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") @@ -119,7 +125,11 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, # in parts of the operators pytest.skip(f"Skipping '{model}' model test with AITER kernel.") - use_prompt_embeds = os.getenv("VLLM_USE_V1") == "0" + # Note: can be removed when + # https://github.com/vllm-project/vllm/pull/24278 finished + if current_platform.is_cpu() and use_prompt_embeds: + pytest.skip("Skipping use_prompt_embeds=True with " + "V1-only CPU backend.") with hf_runner(model) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 3cacbdcfbe86e..206ad1352e06e 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -25,8 +25,7 @@ SSM_MODELS = [ HYBRID_MODELS = [ "ai21labs/Jamba-tiny-dev", - # skipping until vLLM implementation issues are resolved - # "pfnet/plamo-2-1b", + "pfnet/plamo-2-1b", "Zyphra/Zamba2-1.2B-instruct", "hmellor/tiny-random-BambaForCausalLM", "ibm-granite/granite-4.0-tiny-preview", @@ -34,20 +33,10 @@ HYBRID_MODELS = [ "LiquidAI/LFM2-1.2B", ] -HF_UNSUPPORTED_MODELS = [ - # 
The HF transformers implementation of - # Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test - # doesn't compare vLLM output with HF output. - # See https://github.com/huggingface/transformers/pull/35943 - "yujiepan/mamba2-codestral-v0.1-tiny-random", - # transformers 4.55 is still producing garbage for this model - # TODO(tdoublep): follow-up on transformers side - "ibm-granite/granite-4.0-tiny-preview" -] - V1_SUPPORTED_MODELS = [ "state-spaces/mamba-130m-hf", "ai21labs/Jamba-tiny-dev", + "pfnet/plamo-2-1b", "yujiepan/mamba2-codestral-v0.1-tiny-random", "Zyphra/Zamba2-1.2B-instruct", "hmellor/tiny-random-BambaForCausalLM", @@ -58,6 +47,7 @@ V1_SUPPORTED_MODELS = [ FULL_CUDA_GRAPH_MODELS = [ "ai21labs/Jamba-tiny-dev", + "pfnet/plamo-2-1b", "Zyphra/Zamba2-1.2B-instruct", ] @@ -90,20 +80,13 @@ def test_models( try: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") - hf_version_check = model_info.check_transformers_version( - on_fail="return") + model_info.check_transformers_version(on_fail="skip") except ValueError: - hf_version_check = None - - if hf_version_check is not None: - print(f"Skipping transformers comparison because: {hf_version_check}") + pass with hf_runner(model) as hf_model: - if model not in HF_UNSUPPORTED_MODELS and hf_version_check is None: - hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) - else: - hf_outputs = None + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "0") @@ -121,7 +104,7 @@ def test_models( else: vllm_v1_outputs = None - if hf_outputs is not None and vllm_v0_outputs is not None: + if vllm_v0_outputs is not None: check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=vllm_v0_outputs, @@ -130,12 +113,10 @@ def test_models( ) if model in V1_SUPPORTED_MODELS: - ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs - assert ref_outputs is not None check_logprobs_close( - outputs_0_lst=ref_outputs, + outputs_0_lst=hf_outputs, outputs_1_lst=vllm_v1_outputs, - name_0="hf" if hf_outputs is not None else "vllm-v0", + name_0="hf", name_1="vllm-v1", ) @@ -320,7 +301,7 @@ def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( finished_requests_ids is larger than the maximum mamba block capacity. This could generally happen due to the fact that hybrid does support - statelessness mechanism where it can cleanup new incoming requests in + statelessness mechanism where it can clean up new incoming requests in a single step. """ try: @@ -341,7 +322,7 @@ def test_state_cleanup( This test is for verifying that the Hybrid state is cleaned up between steps. - If its not cleaned, an error would be expected. + If it's not cleaned, an error would be expected. 
""" try: with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: @@ -402,11 +383,8 @@ def test_full_cuda_graph( pass with hf_runner(model) as hf_model: - if model not in HF_UNSUPPORTED_MODELS: - hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) - else: - hf_outputs = None + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "0") @@ -421,7 +399,7 @@ def test_full_cuda_graph( vllm_v1_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) - if hf_outputs is not None and vllm_v0_outputs is not None: + if vllm_v0_outputs is not None: check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=vllm_v0_outputs, @@ -429,12 +407,10 @@ def test_full_cuda_graph( name_1="vllm-v0", ) - ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs - assert ref_outputs is not None check_logprobs_close( - outputs_0_lst=ref_outputs, + outputs_0_lst=hf_outputs, outputs_1_lst=vllm_v1_outputs, - name_0="hf" if hf_outputs is not None else "vllm-v0", + name_0="hf", name_1="vllm-v1", ) @@ -442,7 +418,9 @@ def test_full_cuda_graph( @pytest.mark.parametrize("model", FP32_STATE_MODELS) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_fp32_state( +@pytest.mark.parametrize("cache_dtype_param", + ["mamba_ssm_cache_dtype", "mamba_cache_dtype"]) +def test_fp32_cache_state( hf_runner, vllm_runner, example_prompts, @@ -450,6 +428,7 @@ def test_fp32_state( model: str, max_tokens: int, num_logprobs: int, + cache_dtype_param: str, ) -> None: try: @@ -460,38 +439,33 @@ def test_fp32_state( pass with hf_runner(model) as hf_model: - if model not in HF_UNSUPPORTED_MODELS: - hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) - else: - hf_outputs = None + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "0") with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS, - mamba_ssm_cache_dtype="float32") as vllm_model: + **{cache_dtype_param: "float32"}) as vllm_model: vllm_v0_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS, - mamba_ssm_cache_dtype="float32") as vllm_model: + **{cache_dtype_param: "float32"}) as vllm_model: vllm_v1_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) - if hf_outputs is not None: - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_v0_outputs, - name_0="hf", - name_1="vllm-v0", - ) - - ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs check_logprobs_close( - outputs_0_lst=ref_outputs, + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_v0_outputs, + name_0="hf", + name_1="vllm-v0", + ) + + check_logprobs_close( + outputs_0_lst=hf_outputs, outputs_1_lst=vllm_v1_outputs, - name_0="hf" if hf_outputs is not None else "vllm-v0", + name_0="hf", name_1="vllm-v1", ) diff --git a/tests/models/language/generation/test_mbart.py b/tests/models/language/generation/test_mbart.py deleted file mode 100644 index 854a72713943b..0000000000000 --- a/tests/models/language/generation/test_mbart.py +++ /dev/null @@ -1,123 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import 
Optional - -import pytest -from transformers import AutoModelForSeq2SeqLM - -from vllm.sequence import SampleLogprobs - -from ....conftest import DecoderPromptType, HfRunner, VllmRunner -from ...utils import check_logprobs_close - - -def vllm_to_hf_output( - vllm_output: tuple[list[int], str, Optional[SampleLogprobs]], - decoder_prompt_type: DecoderPromptType, -): - """Sanitize vllm output to be comparable with hf output.""" - output_ids, output_str, out_logprobs = vllm_output - hf_output_str = output_str + "</s>" - return output_ids, hf_output_str, out_logprobs - - -def run_test( - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - prompts: list[dict[str, str]], - decoder_prompt_type: DecoderPromptType, - model: str, - *, - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -) -> None: - ''' - Test the vLLM mBART model by validating it against HuggingFace (HF). - (Docstring content is omitted for brevity) - ''' - - vllm_prompts = prompts - if decoder_prompt_type == DecoderPromptType.NONE: - vllm_prompts = [{ - "encoder_prompt": p['encoder_prompt'], - "decoder_prompt": "" - } for p in prompts] - - vllm_kwargs = { - "hf_overrides": { - "architectures": ["MBartForConditionalGeneration"] - } - } - - with vllm_runner(model, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True, - **vllm_kwargs) as vllm_model: # type: ignore - vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( - vllm_prompts, max_tokens, num_logprobs) - - hf_kwargs = { - "top_k": None, - "num_beams": 1, - "repetition_penalty": 1.0, - "top_p": 1.0, - "length_penalty": 1.0, - "early_stopping": False, - "no_repeat_ngram_size": None, - "min_length": 0 - } - - with hf_runner(model, dtype=dtype, - auto_cls=AutoModelForSeq2SeqLM) as hf_model: - hf_kwargs["decoder_start_token_id"] = ( - hf_model.tokenizer.lang_code_to_id["ro_RO"]) - - hf_outputs = ( - hf_model.generate_encoder_decoder_greedy_logprobs_limit( - prompts, # HF runner still uses the original prompts - max_tokens, - num_logprobs, - **hf_kwargs, - )) - - hf_skip_tokens = 0 - - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(vllm_output, decoder_prompt_type) - for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - num_outputs_0_skip_tokens=hf_skip_tokens, - ) - - -@pytest.mark.parametrize( - "model", - [pytest.param("facebook/mbart-large-en-ro")], -) -@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) -@pytest.mark.parametrize("max_tokens", [64]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) -def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, - dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None: - - run_test( - hf_runner, - vllm_runner, - example_encoder_decoder_prompts[decoder_prompt_type], - decoder_prompt_type, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=1, - ) diff --git a/tests/encoder_decoder/__init__.py b/tests/models/language/generation_ppl_test/__init__.py similarity index 100% rename from tests/encoder_decoder/__init__.py rename to tests/models/language/generation_ppl_test/__init__.py diff --git a/tests/models/language/generation_ppl_test/ppl_utils.py b/tests/models/language/generation_ppl_test/ppl_utils.py new file mode 100644 index 
0000000000000..6225bbe3377bd --- /dev/null +++ b/tests/models/language/generation_ppl_test/ppl_utils.py @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from https://huggingface.co/docs/transformers/perplexity +from typing import Optional, cast + +import pytest +import torch +from datasets import load_dataset + +import tests.ci_envs as ci_envs +from tests.models.utils import (GenerateModelInfo, + TokensTextLogprobsPromptLogprobs) +from vllm.logprobs import Logprob + +# See #24485 +PPL_TOL = 0.01 +MAX_LENGTH = 1024 + + +@torch.inference_mode +def wikitext_ppl_test(hf_runner, + vllm_runner, + model_info: GenerateModelInfo, + max_length=MAX_LENGTH, + vllm_extra_kwargs=None, + atol=PPL_TOL): + + # A model family has many models with the same architecture, + # and we don't need to test each one. + if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test: + pytest.skip("Skipping test.") + + dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + + # Allow vllm to test using the given dtype, such as float32 + vllm_extra_kwargs = vllm_extra_kwargs or {} + vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype + + # Allow vllm to test using hf_overrides + if model_info.hf_overrides is not None: + vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides + + # Allow changing the head dtype used by vllm in tests + if ci_envs.VLLM_CI_HEAD_DTYPE is not None: + if "hf_overrides" not in vllm_extra_kwargs: + vllm_extra_kwargs["hf_overrides"] = {} + vllm_extra_kwargs["hf_overrides"][ + "head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE + + with vllm_runner(model_info.name, + gpu_memory_utilization=0.7, + max_model_len=max_length, + max_num_seqs=1, + enforce_eager=True, + **vllm_extra_kwargs) as vllm_model: + # Use max_num_seqs=1 to avoid OOM, + # and avoid batch different requests together. 
+ + model_config = vllm_model.llm.llm_engine.model_config + + # Confirm whether vllm is using the correct architecture + if model_info.architecture: + assert (model_info.architecture in model_config.architectures) + + max_length = min(model_config.max_model_len - 1, max_length) + stride = max_length + + tokenizer = vllm_model.llm.get_tokenizer() + tokens = tokenizer.encode("\n\n".join(dataset["text"])) + n_tokens = len(tokens) + + chunks = [] + for begin_loc in range(0, n_tokens, stride): + end_loc = min(begin_loc + max_length, n_tokens) + chunks.append(tokens[begin_loc:end_loc]) + + outputs = vllm_model.generate_greedy_logprobs(prompts=chunks, + max_tokens=1, + num_logprobs=None, + num_prompt_logprobs=0, + use_tqdm=False) + nll_sum = torch.tensor(0., dtype=torch.float32, device="cpu") + n_tokens = 0 + for output in outputs: + output = cast(TokensTextLogprobsPromptLogprobs, output) + token_datas = cast(list[Optional[dict[int, Logprob]]], output[3]) + + assert token_datas[0] is None + token_log_probs = [] + for token_data in token_datas[1:]: + assert token_data is not None + assert len(token_data) == 1 + token_log_prob = list(token_data.values())[0].logprob + token_log_probs.append(token_log_prob) + + neg_log_likelihood = -torch.tensor( + token_log_probs, dtype=torch.float32, device="cpu").sum() + nll_sum += neg_log_likelihood + n_tokens += len(token_log_probs) + vllm_ppl = float(torch.exp(nll_sum / n_tokens)) + vllm_dtype = model_config.dtype + head_dtype = model_config.head_dtype + + # Accelerate ppl test by setting Transformers ppl score to a constant + if model_info.hf_ppl is None: + with hf_runner( + model_info.name, + dtype=ci_envs.VLLM_CI_HF_DTYPE or model_info.hf_dtype, + ) as hf_model: + nll_sum = torch.tensor(0., dtype=torch.float32, device="cpu") + n_tokens = 0 + for chunk in chunks: + inputs = hf_model.wrap_device( + {"input_ids": torch.tensor([chunk])}) + input_ids = inputs["input_ids"] + outputs = hf_model.model(input_ids, labels=input_ids) + neg_log_likelihood = outputs.loss + + neg_log_likelihood = neg_log_likelihood.to(torch.float32).cpu() + + num_loss_tokens = len(chunk) - 1 + nll_sum += neg_log_likelihood * num_loss_tokens + n_tokens += num_loss_tokens + + hf_ppl = float(torch.exp(nll_sum / n_tokens)) + hf_dtype = next(hf_model.model.parameters()).dtype + else: + hf_ppl = model_info.hf_ppl + hf_dtype = "Constant" + + differ = (vllm_ppl - hf_ppl) / hf_ppl + print("Model:", model_info.name) + print("VLLM:", f"dtype:{vllm_dtype}", f"head_dtype:{head_dtype}", vllm_ppl) + print("Transformers:", hf_dtype, hf_ppl) + print("Difference (%):", differ * 100) + + # PPL the smaller, the better + # We are not concerned that the vllm PPL is less than Transformers, + # so we only perform one-sided testing. 
+ assert differ < atol diff --git a/tests/models/language/generation_ppl_test/test_gemma.py b/tests/models/language/generation_ppl_test/test_gemma.py new file mode 100644 index 0000000000000..5324de143d674 --- /dev/null +++ b/tests/models/language/generation_ppl_test/test_gemma.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from tests.models.utils import GenerateModelInfo + +from .ppl_utils import wikitext_ppl_test + +MODELS = [ + GenerateModelInfo("google/gemma-2b"), + GenerateModelInfo("google/gemma-2-2b"), + GenerateModelInfo("google/gemma-3-4b-it"), +] + + +@pytest.mark.parametrize("model_info", MODELS) +def test_ppl(hf_runner, vllm_runner, model_info: GenerateModelInfo): + wikitext_ppl_test(hf_runner, vllm_runner, model_info) diff --git a/tests/models/language/generation_ppl_test/test_gpt.py b/tests/models/language/generation_ppl_test/test_gpt.py new file mode 100644 index 0000000000000..f3f9e55a24234 --- /dev/null +++ b/tests/models/language/generation_ppl_test/test_gpt.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from tests.models.utils import GenerateModelInfo + +from .ppl_utils import wikitext_ppl_test + +MODELS = [GenerateModelInfo("openai-community/gpt2-large")] + + +@pytest.mark.parametrize("model_info", MODELS) +def test_ppl(hf_runner, vllm_runner, model_info: GenerateModelInfo): + wikitext_ppl_test(hf_runner, vllm_runner, model_info) diff --git a/tests/models/language/generation_ppl_test/test_qwen.py b/tests/models/language/generation_ppl_test/test_qwen.py new file mode 100644 index 0000000000000..0d3127cbaac47 --- /dev/null +++ b/tests/models/language/generation_ppl_test/test_qwen.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from tests.models.utils import GenerateModelInfo + +from .ppl_utils import wikitext_ppl_test + +MODELS = [ + GenerateModelInfo("Qwen/Qwen3-0.6B"), + GenerateModelInfo("Qwen/Qwen3-0.6B-FP8"), + # transformers: + # Loading a GPTQ quantized model requires optimum, gptqmodel + # GenerateModelInfo("Qwen/Qwen3-0.6B-GPTQ-Int8"), +] + + +@pytest.mark.parametrize("model_info", MODELS) +def test_ppl(hf_runner, vllm_runner, model_info: GenerateModelInfo): + wikitext_ppl_test(hf_runner, vllm_runner, model_info) diff --git a/tests/models/language/pooling/embed_utils.py b/tests/models/language/pooling/embed_utils.py index a74ad2aa25972..86751e0a4d5f4 100644 --- a/tests/models/language/pooling/embed_utils.py +++ b/tests/models/language/pooling/embed_utils.py @@ -35,10 +35,7 @@ def correctness_test_embed_models(hf_runner, example_prompts, vllm_extra_kwargs=None, hf_model_callback=None): - if not model_info.enable_test: - # A model family has many models with the same architecture, - # and we don't need to test each one. 
- pytest.skip("Skipping test.") + pytest.skip("Debug only, ci prefers to use mteb test.") # The example_prompts has ending "\n", for example: # "Write a short story about a robot that dreams for the first time.\n" @@ -62,7 +59,7 @@ def correctness_test_embed_models(hf_runner, with hf_runner( model_info.name, - dtype="float32", + dtype=model_info.hf_dtype, is_sentence_transformer=True, ) as hf_model: diff --git a/tests/models/language/pooling/test_classification.py b/tests/models/language/pooling/test_classification.py index c71fa96275335..8e398830d39df 100644 --- a/tests/models/language/pooling/test_classification.py +++ b/tests/models/language/pooling/test_classification.py @@ -11,7 +11,10 @@ from vllm.platforms import current_platform "model", [ pytest.param("jason9693/Qwen2.5-1.5B-apeach", - marks=[pytest.mark.core_model, pytest.mark.cpu_model]), + marks=[ + pytest.mark.core_model, pytest.mark.cpu_model, + pytest.mark.slow_test + ]), ], ) @pytest.mark.parametrize("dtype", diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py index f918b2b91bcc3..d61ac08475e3c 100644 --- a/tests/models/language/pooling/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -7,7 +7,7 @@ import pytest from vllm.config import PoolerConfig from vllm.platforms import current_platform -from ...utils import check_embeddings_close, check_transformers_version +from ...utils import check_embeddings_close @pytest.mark.parametrize( @@ -19,7 +19,7 @@ from ...utils import check_embeddings_close, check_transformers_version # model code with bidirectional attention. # [Decoder-only] pytest.param("BAAI/bge-multilingual-gemma2", - marks=[pytest.mark.core_model]), + marks=[pytest.mark.core_model, pytest.mark.slow_test]), pytest.param( "intfloat/e5-mistral-7b-instruct", # CPU v1 doesn't support sliding window @@ -27,12 +27,20 @@ from ...utils import check_embeddings_close, check_transformers_version pytest.param("ssmits/Qwen2-7B-Instruct-embed-base", marks=[pytest.mark.cpu_model]), # [Encoder-only] - pytest.param("BAAI/bge-base-en-v1.5", marks=[pytest.mark.core_model]), + pytest.param( + "BAAI/bge-base-en-v1.5", + marks=[ + pytest.mark.core_model, pytest.mark.cpu_model, + pytest.mark.slow_test + ], + ), pytest.param("sentence-transformers/all-MiniLM-L12-v2"), pytest.param("intfloat/multilingual-e5-small"), - pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"), # [Cross-Encoder] - pytest.param("sentence-transformers/stsb-roberta-base-v2"), + pytest.param( + "sentence-transformers/stsb-roberta-base-v2", + marks=[pytest.mark.core_model, pytest.mark.cpu_model], + ), ], ) def test_models( @@ -42,8 +50,6 @@ def test_models( model, monkeypatch, ) -> None: - if model == "Alibaba-NLP/gte-Qwen2-1.5B-instruct": - check_transformers_version(model, max_transformers_version="4.53.2") if model == "BAAI/bge-multilingual-gemma2" and current_platform.is_rocm(): # ROCm Triton FA does not currently support sliding window attention diff --git a/tests/models/language/pooling/test_mm_classifier_conversion.py b/tests/models/language/pooling/test_mm_classifier_conversion.py new file mode 100644 index 0000000000000..166b953de43e7 --- /dev/null +++ b/tests/models/language/pooling/test_mm_classifier_conversion.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.platforms import current_platform + + +def test_idefics_multimodal( + vllm_runner, + monkeypatch, +) -> None: + if 
current_platform.is_rocm(): + # ROCm Triton FA does not currently support sliding window attention + # switch to use ROCm CK FA backend + monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False") + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + with vllm_runner(model_name="HuggingFaceM4/Idefics3-8B-Llama3", + runner="pooling", + task="classify", + convert="classify", + load_format="dummy", + max_model_len=512, + enforce_eager=True, + tensor_parallel_size=1, + disable_log_stats=True, + dtype="bfloat16") as vllm_model: + llm = vllm_model.get_llm() + outputs = llm.classify(prompts) + for output in outputs: + assert len(output.outputs.probs) == 2 + + +def update_config(config): + config.text_config.update({ + "architectures": ["Gemma3ForSequenceClassification"], + "classifier_from_token": ["A", "B", "C", "D", "E"], + "method": + "no_post_processing", + "id2label": { + "A": "Chair", + "B": "Couch", + "C": "Table", + "D": "Bed", + "E": "Cupboard" + }, + }) + return config + + +def test_gemma_multimodal( + vllm_runner, + monkeypatch, +) -> None: + if current_platform.is_rocm(): + # ROCm Triton FA does not currently support sliding window attention + # switch to use ROCm CK FA backend + monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False") + + messages = [{ + "role": + "system", + "content": + """ + You are a helpful assistant. You will be given a product description + which may also include an image. Classify the following product into + one of the categories: + + A = chair + B = couch + C = table + D = bed + E = cupboard + + You'll answer with exactly one letter (A, B, C, D, or E).""" + }, { + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": + "https://upload.wikimedia.org/wikipedia/commons/c/c6/Set_of_fourteen_side_chairs_MET_DP110780.jpg" + } + }, { + "type": "text", + "text": "A fine 19th century piece of furniture." + }] + }] + + with vllm_runner(model_name="google/gemma-3-4b-it", + runner="pooling", + task="classify", + convert="classify", + load_format="auto", + hf_overrides=update_config, + override_pooler_config={"pooling_type": "LAST"}, + max_model_len=512, + enforce_eager=True, + tensor_parallel_size=1, + disable_log_stats=True, + dtype="bfloat16") as vllm_model: + + llm = vllm_model.get_llm() + prompts = llm.preprocess_chat(messages) + + result = llm.classify(prompts) + assert result[0].outputs.probs[0] > 0.95 + assert all(c < 0.05 for c in result[0].outputs.probs[1:]) \ No newline at end of file diff --git a/tests/models/language/pooling/test_token_classification.py b/tests/models/language/pooling/test_token_classification.py new file mode 100644 index 0000000000000..fd5e48a8b1449 --- /dev/null +++ b/tests/models/language/pooling/test_token_classification.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +from transformers import AutoModelForTokenClassification + +from tests.models.utils import softmax + + +@pytest.mark.parametrize("model", ["boltuix/NeuroBERT-NER"]) +# The float32 is required for this tiny model to pass the test. 
+@pytest.mark.parametrize("dtype", ["float"]) +@torch.inference_mode +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.encode(example_prompts) + + with hf_runner(model, + dtype=dtype, + auto_cls=AutoModelForTokenClassification) as hf_model: + tokenizer = hf_model.tokenizer + hf_outputs = [] + for prompt in example_prompts: + inputs = tokenizer([prompt], return_tensors="pt") + inputs = hf_model.wrap_device(inputs) + output = hf_model.model(**inputs) + hf_outputs.append(softmax(output.logits[0])) + + # check logits difference + for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): + hf_output = torch.tensor(hf_output).cpu().float() + vllm_output = torch.tensor(vllm_output).cpu().float() + assert torch.allclose(hf_output, vllm_output, 1e-2) diff --git a/tests/metrics/__init__.py b/tests/models/language/pooling_mteb_test/__init__.py similarity index 100% rename from tests/metrics/__init__.py rename to tests/models/language/pooling_mteb_test/__init__.py diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling_mteb_test/mteb_utils.py similarity index 67% rename from tests/models/language/pooling/mteb_utils.py rename to tests/models/language/pooling_mteb_test/mteb_utils.py index 640858125bfca..7b3c02fbbd9f8 100644 --- a/tests/models/language/pooling/mteb_utils.py +++ b/tests/models/language/pooling_mteb_test/mteb_utils.py @@ -9,8 +9,11 @@ import mteb import numpy as np import pytest import requests +import torch -from tests.models.utils import EmbedModelInfo, RerankModelInfo +import tests.ci_envs as ci_envs +from tests.models.utils import (EmbedModelInfo, RerankModelInfo, + check_embeddings_close) # Most embedding models on the STS12 task (See #17175): # - Model implementation and minor changes in tensor dtype @@ -18,7 +21,7 @@ from tests.models.utils import EmbedModelInfo, RerankModelInfo # - Different model results in differences more than 1e-3 # 1e-4 is a good tolerance threshold MTEB_EMBED_TASKS = ["STS12"] -MTEB_EMBED_TOL = 0.02 +MTEB_EMBED_TOL = 1e-4 # See #19344 MTEB_RERANK_TASKS = ["NFCorpus"] @@ -163,18 +166,30 @@ def mteb_test_embed_models(hf_runner, model_info: EmbedModelInfo, vllm_extra_kwargs=None, hf_model_callback=None, - atol=MTEB_RERANK_TOL): - if not model_info.enable_test: - # A model family has many models with the same architecture, - # and we don't need to test each one. + atol=MTEB_EMBED_TOL): + # A model family has many models with the same architecture, + # and we don't need to test each one. + if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test: pytest.skip("Skipping test.") - vllm_extra_kwargs = vllm_extra_kwargs or {} - vllm_extra_kwargs["dtype"] = model_info.dtype + # Test embed_dims, isnan and whether to use normalize + example_prompts = ["The chef prepared a delicious meal." 
* 1000] + # Allow vllm to test using the given dtype, such as float32 + vllm_extra_kwargs = vllm_extra_kwargs or {} + vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype + + # Allow vllm to test using hf_overrides if model_info.hf_overrides is not None: vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides + # Allow changing the head dtype used by vllm in tests + if ci_envs.VLLM_CI_HEAD_DTYPE is not None: + if "hf_overrides" not in vllm_extra_kwargs: + vllm_extra_kwargs["hf_overrides"] = {} + vllm_extra_kwargs["hf_overrides"][ + "head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE + with vllm_runner(model_info.name, runner="pooling", max_model_len=None, @@ -183,31 +198,63 @@ def mteb_test_embed_models(hf_runner, model_config = vllm_model.llm.llm_engine.model_config + # Confirm whether vllm is using the correct architecture if model_info.architecture: assert model_info.architecture in model_config.architectures + + # Confirm whether vllm uses the correct default_pooling_type, which + # relates to whether chunked prefill and prefix caching are enabled assert (model_config._model_info.default_pooling_type == model_info.default_pooling_type) vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model), MTEB_EMBED_TASKS) vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype + head_dtype = model_config.head_dtype - with hf_runner(model_info.name, - is_sentence_transformer=True, - dtype="float32") as hf_model: + # Test embed_dims, isnan and whether to use normalize + vllm_outputs = vllm_model.embed(example_prompts, + truncate_prompt_tokens=-1) + assert not torch.any(torch.isnan(torch.tensor(vllm_outputs))) - if hf_model_callback is not None: - hf_model_callback(hf_model) + # Accelerate mteb test by setting + # SentenceTransformers mteb score to a constant + if model_info.mteb_score is None: + with hf_runner( + model_info.name, + is_sentence_transformer=True, + dtype=ci_envs.VLLM_CI_HF_DTYPE or model_info.hf_dtype, + ) as hf_model: - st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS) - st_dtype = next(hf_model.model.parameters()).dtype + # e.g. setting default parameters for the encode method of hf_runner + if hf_model_callback is not None: + hf_model_callback(hf_model) + + st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS) + st_dtype = next(hf_model.model.parameters()).dtype + + # Test embed_dims and whether to use normalize + hf_outputs = hf_model.encode(example_prompts) + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + tol=1e-2, + ) + else: + st_main_score = model_info.mteb_score + st_dtype = "Constant" print("Model:", model_info.name) - print("VLLM:", vllm_dtype, vllm_main_score) + print("VLLM:", f"dtype:{vllm_dtype}", f"head_dtype:{head_dtype}", + vllm_main_score) print("SentenceTransformers:", st_dtype, st_main_score) print("Difference:", st_main_score - vllm_main_score) - assert st_main_score == pytest.approx(vllm_main_score, abs=atol) + # We are not concerned that the vllm mteb results are better + # than SentenceTransformers, so we only perform one-sided testing. 
+ assert st_main_score - vllm_main_score < atol def run_mteb_rerank(cross_encoder, tasks, languages): @@ -243,9 +290,12 @@ def run_mteb_rerank(cross_encoder, tasks, languages): return main_score -def mteb_test_rerank_models_hf(hf_runner, model_name, hf_model_callback=None): +def mteb_test_rerank_models_hf(hf_runner, + model_name, + hf_dtype="float32", + hf_model_callback=None): with hf_runner(model_name, is_cross_encoder=True, - dtype="float32") as hf_model: + dtype=hf_dtype) as hf_model: original_predict = hf_model.predict @@ -279,17 +329,26 @@ def mteb_test_rerank_models(hf_runner, hf_model_callback=None, vllm_mteb_encoder=VllmMtebEncoder, atol=MTEB_RERANK_TOL): - if not model_info.enable_test: - # A model family has many models with the same architecture, - # and we don't need to test each one. + # A model family has many models with the same architecture, + # and we don't need to test each one. + if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test: pytest.skip("Skipping test.") + # Allow vllm to test using the given dtype, such as float32 vllm_extra_kwargs = vllm_extra_kwargs or {} - vllm_extra_kwargs["dtype"] = model_info.dtype + vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype + # Allow vllm to test using hf_overrides if model_info.hf_overrides is not None: vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides + # Allow changing the head dtype used by vllm in tests + if ci_envs.VLLM_CI_HEAD_DTYPE is not None: + if "hf_overrides" not in vllm_extra_kwargs: + vllm_extra_kwargs["hf_overrides"] = {} + vllm_extra_kwargs["hf_overrides"][ + "head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE + with vllm_runner(model_info.name, runner="pooling", max_model_len=None, @@ -299,9 +358,15 @@ def mteb_test_rerank_models(hf_runner, model_config = vllm_model.llm.llm_engine.model_config + # Confirm whether vllm is using the correct architecture if model_info.architecture: assert (model_info.architecture in model_config.architectures) + + # Score API is only enabled for num_labels == 1 assert model_config.hf_config.num_labels == 1 + + # Confirm whether vllm uses the correct default_pooling_type, which + # relates to whether chunked prefill and prefix caching are enabled assert (model_config._model_info.default_pooling_type == model_info.default_pooling_type) @@ -309,13 +374,23 @@ def mteb_test_rerank_models(hf_runner, tasks=MTEB_RERANK_TASKS, languages=MTEB_RERANK_LANGS) vllm_dtype = model_config.dtype + head_dtype = model_config.head_dtype - st_main_score, st_dtype = mteb_test_rerank_models_hf( - hf_runner, model_info.name, hf_model_callback) + # Accelerate mteb test by setting + # SentenceTransformers mteb score to a constant + if model_info.mteb_score is None: + st_main_score, st_dtype = mteb_test_rerank_models_hf( + hf_runner, model_info.name, model_info.hf_dtype, hf_model_callback) + else: + st_main_score = model_info.mteb_score + st_dtype = "Constant" print("Model:", model_info.name) - print("VLLM:", vllm_dtype, vllm_main_score) + print("VLLM:", f"dtype:{vllm_dtype}", f"head_dtype:{head_dtype}", + vllm_main_score) print("SentenceTransformers:", st_dtype, st_main_score) print("Difference:", st_main_score - vllm_main_score) - assert st_main_score == pytest.approx(vllm_main_score, abs=atol) + # We are not concerned that the vllm mteb results are better + # than SentenceTransformers, so we only perform one-sided testing. 
+ assert st_main_score - vllm_main_score < atol diff --git a/tests/models/language/pooling/test_baai.py b/tests/models/language/pooling_mteb_test/test_baai.py similarity index 89% rename from tests/models/language/pooling/test_baai.py rename to tests/models/language/pooling_mteb_test/test_baai.py index 6fbe0e82d7f8a..e131c9b1038de 100644 --- a/tests/models/language/pooling/test_baai.py +++ b/tests/models/language/pooling_mteb_test/test_baai.py @@ -2,16 +2,19 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo, - EmbedModelInfo, LASTPoolingEmbedModelInfo, - RerankModelInfo) -from .embed_utils import correctness_test_embed_models +from tests.models.language.pooling.embed_utils import ( + correctness_test_embed_models) +from tests.models.utils import (CLSPoolingEmbedModelInfo, + CLSPoolingRerankModelInfo, EmbedModelInfo, + LASTPoolingEmbedModelInfo, RerankModelInfo) + from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models MODELS = [ ########## BertModel CLSPoolingEmbedModelInfo("BAAI/bge-base-en", architecture="BertModel", + mteb_score=0.779336792, enable_test=True), CLSPoolingEmbedModelInfo("BAAI/bge-base-zh", architecture="BertModel", @@ -52,10 +55,12 @@ MODELS = [ ########## XLMRobertaModel CLSPoolingEmbedModelInfo("BAAI/bge-m3", architecture="XLMRobertaModel", + mteb_score=0.787343078, enable_test=True), ########## Qwen2Model LASTPoolingEmbedModelInfo("BAAI/bge-code-v1", architecture="Qwen2Model", + mteb_score=0.75724465, dtype="float32", enable_test=True), ] @@ -65,6 +70,7 @@ RERANK_MODELS = [ CLSPoolingRerankModelInfo( "BAAI/bge-reranker-base", architecture="XLMRobertaForSequenceClassification", + mteb_score=0.32398, enable_test=True), CLSPoolingRerankModelInfo( "BAAI/bge-reranker-large", diff --git a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py b/tests/models/language/pooling_mteb_test/test_bge_reranker_v2_gemma.py similarity index 95% rename from tests/models/language/pooling/test_bge_reranker_v2_gemma.py rename to tests/models/language/pooling_mteb_test/test_bge_reranker_v2_gemma.py index f473e0ba01ffa..1eca2a2c0abd9 100644 --- a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py +++ b/tests/models/language/pooling_mteb_test/test_bge_reranker_v2_gemma.py @@ -7,13 +7,14 @@ import pytest import torch from tests.conftest import HfRunner - -from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo -from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models +from tests.models.language.pooling_mteb_test.mteb_utils import ( + VllmMtebEncoder, mteb_test_rerank_models) +from tests.models.utils import LASTPoolingRerankModelInfo, RerankModelInfo RERANK_MODELS = [ LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma", architecture="GemmaForSequenceClassification", + mteb_score=0.33757, hf_overrides={ "architectures": ["GemmaForSequenceClassification"], @@ -104,7 +105,6 @@ class GemmaMtebEncoder(VllmMtebEncoder): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.prompt = PROMPT self.query_template = "A: {query}\n" self.document_template = "B: {doc}\n{prompt}" @@ -119,7 +119,7 @@ class GemmaMtebEncoder(VllmMtebEncoder): _sentences = [] for query, corpus, prompt in sentences: query = self.query_template.format(query=query) - corpus = self.document_template.format(doc=corpus, prompt=prompt) + corpus = self.document_template.format(doc=corpus, prompt=PROMPT) _sentences.append((query, corpus, prompt)) return 
super().predict(_sentences, *args, **kwargs) diff --git a/tests/models/language/pooling/test_cross_encoder.py b/tests/models/language/pooling_mteb_test/test_cross_encoder.py similarity index 75% rename from tests/models/language/pooling/test_cross_encoder.py rename to tests/models/language/pooling_mteb_test/test_cross_encoder.py index 8c1bc5779b8a1..ad320fae0c85a 100644 --- a/tests/models/language/pooling/test_cross_encoder.py +++ b/tests/models/language/pooling_mteb_test/test_cross_encoder.py @@ -2,14 +2,17 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from ...utils import (CLSPoolingRerankModelInfo, LASTPoolingRerankModelInfo, - RerankModelInfo) +from tests.models.utils import (CLSPoolingRerankModelInfo, + LASTPoolingRerankModelInfo, RerankModelInfo) + from .mteb_utils import mteb_test_rerank_models RERANK_MODELS = [ CLSPoolingRerankModelInfo("cross-encoder/ms-marco-TinyBERT-L-2-v2", + mteb_score=0.32898, architecture="BertForSequenceClassification"), LASTPoolingRerankModelInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", + mteb_score=0.25736, architecture="Qwen3ForSequenceClassification") ] diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling_mteb_test/test_gte.py similarity index 80% rename from tests/models/language/pooling/test_gte.py rename to tests/models/language/pooling_mteb_test/test_gte.py index 9911620c018ef..9ae43fd05bf78 100644 --- a/tests/models/language/pooling/test_gte.py +++ b/tests/models/language/pooling_mteb_test/test_gte.py @@ -3,15 +3,18 @@ import pytest -from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo, - EmbedModelInfo, LASTPoolingEmbedModelInfo, - RerankModelInfo, check_transformers_version) -from .embed_utils import correctness_test_embed_models +from tests.models.language.pooling.embed_utils import ( + correctness_test_embed_models) +from tests.models.utils import (CLSPoolingEmbedModelInfo, + CLSPoolingRerankModelInfo, EmbedModelInfo, + LASTPoolingEmbedModelInfo, RerankModelInfo) + from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models MODELS = [ ########## BertModel CLSPoolingEmbedModelInfo("thenlper/gte-large", + mteb_score=0.76807651, architecture="BertModel", enable_test=True), CLSPoolingEmbedModelInfo("thenlper/gte-base", @@ -30,28 +33,37 @@ MODELS = [ architecture="BertModel", enable_test=False), ########### NewModel + # These three architectures are almost the same, but not exactly the same. 
+ # For example, + # - whether to use token_type_embeddings + # - whether to use context expansion + # So only test one (the most widely used) model CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-multilingual-base", architecture="GteNewModel", + mteb_score=0.775074696, hf_overrides={"architectures": ["GteNewModel"]}, enable_test=True), CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5", architecture="GteNewModel", hf_overrides={"architectures": ["GteNewModel"]}, - enable_test=True), + enable_test=False), CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5", architecture="GteNewModel", hf_overrides={"architectures": ["GteNewModel"]}, - enable_test=True), + enable_test=False), ########### Qwen2ForCausalLM LASTPoolingEmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct", + mteb_score=0.758473459018872, architecture="Qwen2ForCausalLM", enable_test=True), ########## ModernBertModel CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-modernbert-base", + mteb_score=0.748193353, architecture="ModernBertModel", enable_test=True), ########## Qwen3ForCausalLM LASTPoolingEmbedModelInfo("Qwen/Qwen3-Embedding-0.6B", + mteb_score=0.771163695, architecture="Qwen3ForCausalLM", dtype="float32", enable_test=True), @@ -65,10 +77,12 @@ RERANK_MODELS = [ CLSPoolingRerankModelInfo( # classifier_pooling: mean "Alibaba-NLP/gte-reranker-modernbert-base", + mteb_score=0.33386, architecture="ModernBertForSequenceClassification", enable_test=True), CLSPoolingRerankModelInfo( "Alibaba-NLP/gte-multilingual-reranker-base", + mteb_score=0.33062, architecture="GteNewForSequenceClassification", hf_overrides={"architectures": ["GteNewForSequenceClassification"]}, enable_test=True), @@ -78,10 +92,6 @@ RERANK_MODELS = [ @pytest.mark.parametrize("model_info", MODELS) def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: - if model_info.name == "Alibaba-NLP/gte-Qwen2-1.5B-instruct": - check_transformers_version(model_info.name, - max_transformers_version="4.53.2") - mteb_test_embed_models(hf_runner, vllm_runner, model_info) @@ -89,10 +99,6 @@ def test_embed_models_mteb(hf_runner, vllm_runner, def test_embed_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts) -> None: - if model_info.name == "Alibaba-NLP/gte-Qwen2-1.5B-instruct": - check_transformers_version(model_info.name, - max_transformers_version="4.53.2") - correctness_test_embed_models(hf_runner, vllm_runner, model_info, example_prompts) diff --git a/tests/models/language/pooling/test_intfloat.py b/tests/models/language/pooling_mteb_test/test_intfloat.py similarity index 85% rename from tests/models/language/pooling/test_intfloat.py rename to tests/models/language/pooling_mteb_test/test_intfloat.py index 6cae53a660ad8..0d6026898ad4a 100644 --- a/tests/models/language/pooling/test_intfloat.py +++ b/tests/models/language/pooling_mteb_test/test_intfloat.py @@ -2,14 +2,17 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo -from .embed_utils import correctness_test_embed_models +from tests.models.language.pooling.embed_utils import ( + correctness_test_embed_models) +from tests.models.utils import CLSPoolingEmbedModelInfo, EmbedModelInfo + from .mteb_utils import mteb_test_embed_models MODELS = [ ########## BertModel CLSPoolingEmbedModelInfo("intfloat/e5-small", architecture="BertModel", + mteb_score=0.742285423, enable_test=True), CLSPoolingEmbedModelInfo("intfloat/e5-base", architecture="BertModel", @@ -23,6 +26,7 @@ 
MODELS = [ ########## XLMRobertaModel CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-base", architecture="XLMRobertaModel", + mteb_score=0.779325955, enable_test=True), CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-large", architecture="XLMRobertaModel", @@ -36,7 +40,7 @@ MODELS = [ @pytest.mark.parametrize("model_info", MODELS) def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: - mteb_test_embed_models(hf_runner, vllm_runner, model_info, atol=0.02) + mteb_test_embed_models(hf_runner, vllm_runner, model_info) @pytest.mark.parametrize("model_info", MODELS) diff --git a/tests/models/language/pooling/test_jina.py b/tests/models/language/pooling_mteb_test/test_jina.py similarity index 90% rename from tests/models/language/pooling/test_jina.py rename to tests/models/language/pooling_mteb_test/test_jina.py index 37c5bdc97dd98..0a77a78bb31b6 100644 --- a/tests/models/language/pooling/test_jina.py +++ b/tests/models/language/pooling_mteb_test/test_jina.py @@ -4,16 +4,18 @@ from functools import partial import pytest +from tests.models.language.pooling.embed_utils import ( + check_embeddings_close, correctness_test_embed_models, matryoshka_fy) +from tests.models.utils import (CLSPoolingEmbedModelInfo, + CLSPoolingRerankModelInfo, EmbedModelInfo, + RerankModelInfo) from vllm import PoolingParams -from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo, - EmbedModelInfo, RerankModelInfo) -from .embed_utils import (check_embeddings_close, - correctness_test_embed_models, matryoshka_fy) from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models EMBEDDING_MODELS = [ CLSPoolingEmbedModelInfo("jinaai/jina-embeddings-v3", + mteb_score=0.824413164, architecture="XLMRobertaModel", is_matryoshka=True) ] @@ -21,6 +23,7 @@ EMBEDDING_MODELS = [ RERANK_MODELS = [ CLSPoolingRerankModelInfo( "jinaai/jina-reranker-v2-base-multilingual", + mteb_score=0.33643, architecture="XLMRobertaForSequenceClassification") ] diff --git a/tests/models/language/pooling/test_mxbai_rerank.py b/tests/models/language/pooling_mteb_test/test_mxbai_rerank.py similarity index 96% rename from tests/models/language/pooling/test_mxbai_rerank.py rename to tests/models/language/pooling_mteb_test/test_mxbai_rerank.py index 73823deeff4e0..05ebb4ec4d3f5 100644 --- a/tests/models/language/pooling/test_mxbai_rerank.py +++ b/tests/models/language/pooling_mteb_test/test_mxbai_rerank.py @@ -6,8 +6,8 @@ import pytest import torch from tests.conftest import HfRunner +from tests.models.utils import LASTPoolingRerankModelInfo, RerankModelInfo -from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo from .mteb_utils import mteb_test_rerank_models mxbai_rerank_hf_overrides = { @@ -20,6 +20,7 @@ RERANK_MODELS = [ LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2", architecture="Qwen2ForSequenceClassification", hf_overrides=mxbai_rerank_hf_overrides, + mteb_score=0.273, enable_test=True), LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2", architecture="Qwen2ForSequenceClassification", diff --git a/tests/models/language/pooling/test_nomic.py b/tests/models/language/pooling_mteb_test/test_nomic.py similarity index 84% rename from tests/models/language/pooling/test_nomic.py rename to tests/models/language/pooling_mteb_test/test_nomic.py index 2d05958e9bcda..61512fd0dff18 100644 --- a/tests/models/language/pooling/test_nomic.py +++ b/tests/models/language/pooling_mteb_test/test_nomic.py @@ -3,13 +3,16 @@ import pytest -from ...utils import 
CLSPoolingEmbedModelInfo, EmbedModelInfo -from .embed_utils import correctness_test_embed_models +from tests.models.language.pooling.embed_utils import ( + correctness_test_embed_models) +from tests.models.utils import CLSPoolingEmbedModelInfo, EmbedModelInfo + from .mteb_utils import mteb_test_embed_models MODELS = [ CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v1", architecture="NomicBertModel", + mteb_score=0.737568559, enable_test=True), CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v1.5", architecture="NomicBertModel", @@ -19,6 +22,7 @@ MODELS = [ enable_test=False), CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe", architecture="NomicBertModel", + mteb_score=0.715488912, enable_test=True) ] diff --git a/tests/models/language/pooling/test_qwen3_reranker.py b/tests/models/language/pooling_mteb_test/test_qwen3_reranker.py similarity index 96% rename from tests/models/language/pooling/test_qwen3_reranker.py rename to tests/models/language/pooling_mteb_test/test_qwen3_reranker.py index 5dd2d9eae9115..65403081dc0f8 100644 --- a/tests/models/language/pooling/test_qwen3_reranker.py +++ b/tests/models/language/pooling_mteb_test/test_qwen3_reranker.py @@ -6,9 +6,9 @@ import pytest import torch from tests.conftest import HfRunner +from tests.models.utils import LASTPoolingRerankModelInfo, RerankModelInfo from tests.utils import multi_gpu_test -from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo from .mteb_utils import mteb_test_rerank_models qwen3_reranker_hf_overrides = { @@ -20,6 +20,7 @@ qwen3_reranker_hf_overrides = { RERANK_MODELS = [ LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-0.6B", architecture="Qwen3ForSequenceClassification", + mteb_score=0.25736, hf_overrides=qwen3_reranker_hf_overrides, enable_test=True), LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-4B", diff --git a/tests/models/language/pooling/test_snowflake_arctic_embed.py b/tests/models/language/pooling_mteb_test/test_snowflake_arctic_embed.py similarity index 83% rename from tests/models/language/pooling/test_snowflake_arctic_embed.py rename to tests/models/language/pooling_mteb_test/test_snowflake_arctic_embed.py index c22c78592e535..91bad2c4e42fc 100644 --- a/tests/models/language/pooling/test_snowflake_arctic_embed.py +++ b/tests/models/language/pooling_mteb_test/test_snowflake_arctic_embed.py @@ -3,14 +3,17 @@ import pytest -from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo -from .embed_utils import correctness_test_embed_models +from tests.models.language.pooling.embed_utils import ( + correctness_test_embed_models) +from tests.models.utils import CLSPoolingEmbedModelInfo, EmbedModelInfo + from .mteb_utils import mteb_test_embed_models MODELS = [ CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-xs", is_matryoshka=False, architecture="BertModel", + mteb_score=0.714927797, enable_test=True), CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-s", is_matryoshka=False, @@ -23,6 +26,7 @@ MODELS = [ CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long", is_matryoshka=False, architecture="NomicBertModel", + mteb_score=0.681146831, enable_test=True), CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-l", is_matryoshka=False, @@ -31,14 +35,17 @@ MODELS = [ CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5", is_matryoshka=True, architecture="BertModel", + mteb_score=0.649088363, enable_test=True), CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0", is_matryoshka=True, architecture="XLMRobertaModel", + 
mteb_score=0.712258299, enable_test=True), CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0", is_matryoshka=True, architecture="GteModel", + mteb_score=0.706622444, enable_test=True), ] @@ -46,7 +53,7 @@ MODELS = [ @pytest.mark.parametrize("model_info", MODELS) def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: - mteb_test_embed_models(hf_runner, vllm_runner, model_info, atol=0.02) + mteb_test_embed_models(hf_runner, vllm_runner, model_info) @pytest.mark.parametrize("model_info", MODELS) diff --git a/tests/models/language/pooling/test_st_projector.py b/tests/models/language/pooling_mteb_test/test_st_projector.py similarity index 60% rename from tests/models/language/pooling/test_st_projector.py rename to tests/models/language/pooling_mteb_test/test_st_projector.py index 51ddbcc5ab249..bd493e7e2ba09 100644 --- a/tests/models/language/pooling/test_st_projector.py +++ b/tests/models/language/pooling_mteb_test/test_st_projector.py @@ -2,7 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo +from tests.models.utils import (CLSPoolingEmbedModelInfo, EmbedModelInfo, + LASTPoolingEmbedModelInfo) + from .mteb_utils import mteb_test_embed_models # ST models with projector (Dense) layers @@ -10,8 +12,13 @@ ST_PROJECTOR_MODELS = [ CLSPoolingEmbedModelInfo( "TencentBAC/Conan-embedding-v1", architecture="BertModel", + mteb_score=0.688611955, enable_test=True, ), + LASTPoolingEmbedModelInfo("google/embeddinggemma-300m", + architecture="Gemma3TextModel", + mteb_score=0.7473819294684156, + enable_test=True) ] diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index d61b182761e44..79f9d607f3386 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -10,7 +10,7 @@ from pathlib import PosixPath import pytest from transformers import (AutoModel, AutoModelForImageTextToText, - AutoModelForTextToWaveform, AutoModelForVision2Seq) + AutoModelForTextToWaveform) from vllm.platforms import current_platform from vllm.utils import identity @@ -137,7 +137,7 @@ VLM_TEST_SETTINGS = { video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", # noqa: E501 max_model_len=4096, max_num_seqs=2, - auto_cls=AutoModelForVision2Seq, + auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], marks=[pytest.mark.core_model, pytest.mark.cpu_model], @@ -502,7 +502,7 @@ VLM_TEST_SETTINGS = { num_video_frames=16, max_model_len=16384, hf_model_kwargs=model_utils.llava_onevision_hf_model_kwargs("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501 - auto_cls=AutoModelForVision2Seq, + auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output, custom_test_opts=[CustomTestOptions( inputs=custom_inputs.multi_video_multi_aspect_ratio_inputs( @@ -518,7 +518,7 @@ VLM_TEST_SETTINGS = { num_video_frames=16, max_model_len=4096, max_num_seqs=2, - auto_cls=AutoModelForVision2Seq, + auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output, ), "mantis": VLMTestInfo( @@ -680,7 +680,7 @@ VLM_TEST_SETTINGS = { multi_image_prompt="Picture 1: <vlm_image>\nPicture 2: <vlm_image>\nDescribe these two images with one paragraph respectively.", # noqa: 
E501 max_model_len=4096, max_num_seqs=2, - auto_cls=AutoModelForVision2Seq, + auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], marks=[pytest.mark.cpu_model], @@ -784,7 +784,7 @@ VLM_TEST_SETTINGS = { test_type=VLMTestType.CUSTOM_INPUTS, max_model_len=16384, max_num_seqs=2, - auto_cls=AutoModelForVision2Seq, + auto_cls=AutoModelForImageTextToText, hf_model_kwargs=model_utils.llava_onevision_hf_model_kwargs("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501 vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output, custom_test_opts=[CustomTestOptions( @@ -800,7 +800,7 @@ VLM_TEST_SETTINGS = { test_type=VLMTestType.CUSTOM_INPUTS, max_model_len=4096, max_num_seqs=2, - auto_cls=AutoModelForVision2Seq, + auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, custom_test_opts=[CustomTestOptions( inputs=custom_inputs.windows_attention_image_qwen2_5_vl(), diff --git a/tests/models/multimodal/generation/test_florence2.py b/tests/models/multimodal/generation/test_florence2.py deleted file mode 100644 index a622957f96f69..0000000000000 --- a/tests/models/multimodal/generation/test_florence2.py +++ /dev/null @@ -1,147 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Optional - -import pytest -from PIL import Image - -from vllm.inputs.data import ExplicitEncoderDecoderPrompt, TextPrompt -from vllm.multimodal.image import rescale_image_size -from vllm.sequence import SampleLogprobs - -from ....conftest import IMAGE_ASSETS, HfRunner, ImageTestAssets, VllmRunner -from ...utils import check_logprobs_close - -MODELS = ["microsoft/Florence-2-base"] -# Florence-2 model repo's tokenizer config is missing some special tokens. 
-# Therefore, we use a converted tokenizer from a forked repo -TOKENIZER = "Isotr0py/Florence-2-tokenizer" -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": - "<OD>", # special task token which will output special tokens - "cherry_blossom": - "Describe in detail what is shown in the image.", -}) - - -def get_hf_images_prompts( - prompts_: list[ExplicitEncoderDecoderPrompt[str, TextPrompt]], -) -> tuple[list[ExplicitEncoderDecoderPrompt[str, str]], list[Image.Image]]: - prompts, images = [], [] - for prompt in prompts_: - encoder_prompt = prompt["encoder_prompt"] - prompts.append( - ExplicitEncoderDecoderPrompt( - encoder_prompt=encoder_prompt["prompt"], - decoder_prompt=None, - )) - images.append(encoder_prompt["multi_modal_data"]["image"]) - return prompts, images - - -def hf_to_vllm_output(hf_output: tuple[list[int], str, - Optional[SampleLogprobs]]): - """Sanitize hf output to be comparable with vllm output.""" - output_ids, output_str, out_logprobs = hf_output - - output_str = output_str.replace("</s>", "").replace("<s>", "") - - return output_ids, output_str, out_logprobs - - -def run_test( - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - inputs: list[list[ExplicitEncoderDecoderPrompt]], - model: str, - *, - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -) -> None: - with vllm_runner(model, - max_num_seqs=8, - tokenizer_name=TOKENIZER, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: - vllm_outputs_per_case = [ - vllm_model.generate_encoder_decoder_greedy_logprobs( - prompts, - max_tokens, - num_logprobs=num_logprobs, - skip_special_tokens=False, - ) for prompts in inputs - ] - - hf_inputs = [get_hf_images_prompts(prompts) for prompts in inputs] - - with hf_runner(model, dtype=dtype, skip_tokenizer_init=True) as hf_model: - hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.language_model.lm_head - hf_outputs_per_case = [ - hf_model.generate_encoder_decoder_greedy_logprobs_limit( - prompts, max_tokens, num_logprobs=num_logprobs, images=images) - for prompts, images in hf_inputs - ] - - for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, - vllm_outputs_per_case): - check_logprobs_close( - outputs_0_lst=[hf_to_vllm_output(output) for output in hf_outputs], - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - num_outputs_0_skip_tokens=1, - ) - - -# FIXME: https://github.com/huggingface/transformers/issues/38358 -@pytest.mark.skip("Model initialization fails") -@pytest.mark.core_model -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize( - "size_factors", - [ - # No image - [], - # Single-scale - [1.0], - # Single-scale, batched - [1.0, 1.0, 1.0], - # Multi-scale - [0.25, 0.5, 1.0], - ], -) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [64]) -@pytest.mark.parametrize("num_logprobs", [5]) -def test_models(hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets, model: str, - size_factors: list[int], dtype: str, max_tokens: int, - num_logprobs: int) -> None: - images = [asset.pil_image for asset in image_assets] - - inputs_per_image = [[ - ExplicitEncoderDecoderPrompt( - encoder_prompt=TextPrompt( - prompt=prompt, - multi_modal_data={"image": rescale_image_size(image, factor)}), - decoder_prompt=None, - ) for factor in size_factors - ] for image, prompt 
in zip(images, HF_IMAGE_PROMPTS)] - - run_test( - hf_runner, - vllm_runner, - inputs_per_image, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=1, - ) diff --git a/tests/models/multimodal/generation/test_mllama.py b/tests/models/multimodal/generation/test_mllama.py deleted file mode 100644 index 1c32cc6d71c04..0000000000000 --- a/tests/models/multimodal/generation/test_mllama.py +++ /dev/null @@ -1,768 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Optional, overload - -import pytest -import torch -from packaging.version import Version -from transformers import AutoConfig, AutoModelForImageTextToText, AutoTokenizer -from transformers import __version__ as TRANSFORMERS_VERSION - -from vllm import LLM, SamplingParams -from vllm.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.attention.selector import (_Backend, _cached_get_attn_backend, - global_force_attn_backend_context_manager) -from vllm.model_executor.models.mllama import MllamaForConditionalGeneration -from vllm.multimodal.image import rescale_image_size -from vllm.sequence import SampleLogprobs - -from ....conftest import (IMAGE_ASSETS, HfRunner, ImageTestAssets, - PromptImageInput, VllmRunner) -from ....quantization.utils import is_quant_method_supported -from ....utils import (create_new_process_for_each_test, large_gpu_test, - multi_gpu_test) -from ...utils import check_logprobs_close - -_LIMIT_IMAGE_PER_PROMPT = 3 -MLLAMA_IMAGE_TOKEN_ID = 128256 - -LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN] - -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": - "<|image|><|begin_of_text|>The meaning of the image is", - "cherry_blossom": - "<|image|><|begin_of_text|>The city is", -}) - -text_only_prompts = [ - "The color of the sky is blue but sometimes it can also be", -] - -models = [ - "meta-llama/Llama-3.2-11B-Vision-Instruct", -] - -# Indices for inputs -TEXT_ONLY = '0' -IMAGE_AT_BEG = '1' -IMAGE_AT_MIDDLE = '2' -TWO_IMAGES = '3' - -# Input tokenized -prompt_data = { - # Tell me a story - TEXT_ONLY: [41551, 757, 264, 3446], - # <|image|> What's the content of this image - IMAGE_AT_BEG: - [MLLAMA_IMAGE_TOKEN_ID, 3639, 596, 279, 2262, 315, 420, 2217, 220], - # Hello <|image|>What' the content of this image - IMAGE_AT_MIDDLE: - [9906, 220, MLLAMA_IMAGE_TOKEN_ID, 3923, 6, 279, 2262, 315, 420, 2217], - #<|image|>Is there a duck in this image?<|image|>What's the animal in this image? 
# noqa: E501 - TWO_IMAGES: [ - MLLAMA_IMAGE_TOKEN_ID, 3957, 1070, 264, 37085, 304, 420, 2217, 30, - MLLAMA_IMAGE_TOKEN_ID, 3923, 596, 279, 10065, 304, 420, 2217, 30 - ] -} - - -def vllm_to_hf_output(vllm_output: tuple[list[int], str, - Optional[SampleLogprobs]], - model: str): - """Sanitize vllm output to be comparable with hf output.""" - output_ids, output_str, out_logprobs = vllm_output - - config = AutoConfig.from_pretrained(model) - image_token_id = config.image_token_index - - tokenizer = AutoTokenizer.from_pretrained(model) - eos_token_id = tokenizer.eos_token_id - - hf_output_ids = [ - token_id for idx, token_id in enumerate(output_ids) - if token_id != image_token_id or output_ids[idx - 1] != image_token_id - ] - - hf_output_str = output_str - if hf_output_ids[-1] == eos_token_id: - hf_output_str = hf_output_str + tokenizer.decode(eos_token_id) - - return hf_output_ids, hf_output_str, out_logprobs - - -def _get_inputs( - image_assets: ImageTestAssets, - *, - size_factors: Optional[list[float]] = None, - sizes: Optional[list[tuple[int, int]]] = None, -) -> list[tuple[list[str], PromptImageInput]]: - images = [asset.pil_image for asset in image_assets] - - if size_factors is not None: - inputs_per_image = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] - elif sizes is not None: - inputs_per_image = [( - [ - prompt if size is not None else text_only_prompts[0] - for size in sizes - ], - [ - image.resize(size) if size is not None else None - for size in sizes - ], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] - if len(sizes) == 0: - inputs_per_image.append( - (text_only_prompts, [None] * len(text_only_prompts))) - else: - raise ValueError("You must provide either `size_factors` or `sizes`") - - return inputs_per_image - - -@overload -def run_test( - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets, - model: str, - *, - size_factors: list[float], - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -): - ... - - -@overload -def run_test( - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets, - model: str, - *, - sizes: list[tuple[int, int]], - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -): - ... - - -def run_test( - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets, - model: str, - *, - size_factors: Optional[list[float]] = None, - sizes: Optional[list[tuple[int, int]]] = None, - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -): - _run_test( - hf_runner, - vllm_runner, - _get_inputs(image_assets, size_factors=size_factors, sizes=sizes), - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - ) - - -def _run_test( - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - inputs: list[tuple[list[str], PromptImageInput]], - model: str, - *, - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -): - """Inference result should be the same between hf and vllm. 
- - All the image fixtures for the test are from IMAGE_ASSETS. - For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalDataDict objects - and corresponding MultiModalConfig as input. - Note, the text input is also adjusted to abide by vllm contract. - The text output is sanitized to be able to compare with hf. - """ - # NOTE: take care of the order. run vLLM first, and then run HF. - # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with fork method (the default method). - - # max_model_len should be greater than image_feature_size - with vllm_runner( - model, - dtype=dtype, - max_model_len=19212, # 3 max size images - max_num_seqs=3, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - limit_mm_per_prompt={"image": - _LIMIT_IMAGE_PER_PROMPT}) as vllm_model: - vllm_outputs_per_image = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images) - for prompts, images in inputs - ] - - with hf_runner(model, - dtype=dtype, - model_kwargs={"device_map": "auto"}, - auto_cls=AutoModelForImageTextToText) as hf_model: - hf_outputs_per_image = [ - hf_model.generate_greedy_logprobs_limit(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images) - for prompts, images in inputs - ] - - for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, - vllm_outputs_per_image): - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(vllm_output, model) - for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - ) - - -@pytest.fixture(autouse=True) -def clear_cache(): - """Fixture to clear backend cache before each test.""" - _cached_get_attn_backend.cache_clear() # Clear the cache - yield # This allows the test to run - - -@large_gpu_test(min_gb=48) -@pytest.mark.core_model -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize( - "sizes", - [ - # Text only - [], - # Single-size - [(512, 512)], - # Single-size, batched - [(512, 512), (512, 512), (512, 512)], - # Multi-size, batched - [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024), - (1024, 1024), (512, 1536), (512, 2028)], - # Multi-size, batched, including text only - [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024), - (1024, 1024), (512, 1536), (512, 2028), None], - # mllama has 8 possible aspect ratios, carefully set the sizes - # to cover all of them - ]) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) -@pytest.mark.skipif( - Version(TRANSFORMERS_VERSION) <= Version("4.55.2"), - reason="Transformers v4.55 has a regression issue on mllama, " - "see: https://github.com/huggingface/transformers/pull/40083") -def test_models_single_leading_image(hf_runner, vllm_runner, image_assets, - model, sizes, dtype, max_tokens, - num_logprobs, - attn_backend: _Backend) -> None: - with global_force_attn_backend_context_manager(attn_backend): - if attn_backend == _Backend.FLASH_ATTN: - # Flash Attention works only with bfloat16 data-type - dtype = 'bfloat16' - run_test( - hf_runner, - vllm_runner, - image_assets, - model, - sizes=sizes, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=1, - ) - - 
-@large_gpu_test(min_gb=48) -@pytest.mark.core_model -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) -@pytest.mark.skipif( - Version(TRANSFORMERS_VERSION) <= Version("4.55.2"), - reason="Transformers v4.55 has a regression issue on mllama, " - "see: https://github.com/huggingface/transformers/pull/40083") -def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets, - model, dtype, max_tokens, num_logprobs, - attn_backend: _Backend) -> None: - - stop_sign = image_assets[0].pil_image - cherry_blossom = image_assets[1].pil_image - - inputs = [( - [ - "<|image|><|image|><|begin_of_text|>Describe 2 images.", # noqa: E501 - "<|image|><|image|><|begin_of_text|>Describe 2 images.", # noqa: E501 - "<|image|><|image|><|image|><|begin_of_text|>Describe 3 images.", # noqa: E501 - ], - [ - [stop_sign, cherry_blossom], - # Images with different sizes. - [ - stop_sign.resize((512, 512)), - stop_sign, - ], - [ - stop_sign, - stop_sign.resize((512, 1536)), - cherry_blossom.resize((512, 1024)), - ], - ])] - with global_force_attn_backend_context_manager(attn_backend): - if attn_backend == _Backend.FLASH_ATTN: - # Flash Attention works only with bfloat16 data-type - dtype = 'bfloat16' - _run_test( - hf_runner, - vllm_runner, - inputs, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=1, - ) - - -@large_gpu_test(min_gb=48) -@pytest.mark.core_model -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) -@pytest.mark.skipif( - Version(TRANSFORMERS_VERSION) <= Version("4.55.2"), - reason="Transformers v4.55 has a regression issue on mllama, " - "see: https://github.com/huggingface/transformers/pull/40083") -def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model, - dtype, max_tokens, num_logprobs, - attn_backend: _Backend) -> None: - - stop_sign = image_assets[0].pil_image - cherry_blossom = image_assets[1].pil_image - - inputs = [( - [ - "<|begin_of_text|>The content of the image <|image|> is", # noqa: E501 - "<|begin_of_text|>Between the first image <|image|> and the second image<|image|>, " # noqa: E501 - "which is a stop sign and which is a cherry blossom?", # noqa: E501 - ], - [ - [stop_sign], - [stop_sign, cherry_blossom], - ])] - with global_force_attn_backend_context_manager(attn_backend): - if attn_backend == _Backend.FLASH_ATTN: - # Flash Attention works only with bfloat16 data-type - dtype = 'bfloat16' - _run_test( - hf_runner, - vllm_runner, - inputs, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=1, - ) - - -@create_new_process_for_each_test() -@multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [64]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.skipif( - Version(TRANSFORMERS_VERSION) <= Version("4.55.2"), - reason="Transformers v4.55 has a regression issue on mllama, " - "see: https://github.com/huggingface/transformers/pull/40083") -def 
test_models_distributed( - hf_runner, - vllm_runner, - image_assets, - distributed_executor_backend, - model, - dtype, - max_tokens, - num_logprobs, -) -> None: - run_test( - hf_runner, - vllm_runner, - image_assets, - model=model, - size_factors=[0.25, 0.5, 1.0], - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=2, - distributed_executor_backend=distributed_executor_backend, - ) - - -@large_gpu_test(min_gb=48) -@pytest.mark.core_model -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("dtype", ["float16"]) -@pytest.mark.parametrize("max_tokens", [32]) -@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), - reason='bitsandbytes is not supported on this GPU type.') -def test_bnb_regression( - image_assets: ImageTestAssets, - model: str, - dtype: str, - max_tokens: int, -): - stop_sign = image_assets[0].pil_image - prompts = [ - { - "prompt": "<|begin_of_text|>The content of the image <|image|> is", - "multi_modal_data": { - "image": stop_sign - }, - }, - { - "prompt": - "The color of the sky is blue but sometimes it can also be", - }, - ] - # Test regression about QKVCrossParallelLinear - llm = LLM( - model=model, - dtype=dtype, - max_model_len=8192, - max_num_seqs=2, - quantization="bitsandbytes", - ) - sampling_params = SamplingParams( - temperature=0, - max_tokens=max_tokens, - ) - outputs = llm.generate(prompts, sampling_params) - assert outputs - - -@large_gpu_test(min_gb=48) -@pytest.mark.core_model -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [32]) -def test_explicit_implicit_prompt( - image_assets: ImageTestAssets, - model: str, - dtype: str, - max_tokens: int, -): - stop_sign = image_assets[0].pil_image - # yapf: disable - prompts = [ - # explicit prompt - { - "encoder_prompt": { - "prompt": "<|image|>", - "multi_modal_data": {"image": stop_sign}, - }, - "decoder_prompt": { - "prompt_token_ids": [128000, 791, 2262, 315, 279, 2217, 220, 128256, 374], # noqa: E501 - } - }, - { - "encoder_prompt": "Not <|image|>", - "decoder_prompt": "The color of the sky is blue but sometimes it can also be", # noqa: E501 - }, - # implicit prompt - { - "prompt": "<|begin_of_text|>The content of the image <|image|> is", # noqa: E501 - "multi_modal_data": {"image": stop_sign}, - }, - { - "prompt": "The color of the sky is blue but sometimes it can also be", # noqa: E501 - }, - ] - # yapf: enable - llm = LLM( - model=model, - dtype=dtype, - max_model_len=8192, - max_num_seqs=2, - tensor_parallel_size=1, - ) - sampling_params = SamplingParams( - temperature=0, - max_tokens=max_tokens, - ) - outputs = llm.generate(prompts, sampling_params) - n_prompts = len(prompts) - explicit_outputs = outputs[:n_prompts // 2] - implicit_outputs = outputs[n_prompts // 2:] - for exp_output, imp_output in zip(explicit_outputs, implicit_outputs): - assert exp_output.outputs[0].text == imp_output.outputs[0].text - - -@large_gpu_test(min_gb=48) -@pytest.mark.core_model -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) -def test_regression(vllm_runner, image_assets, model, dtype, max_tokens, - num_logprobs, attn_backend: _Backend) -> None: - - stop_sign = image_assets[0].pil_image - - with global_force_attn_backend_context_manager(attn_backend), vllm_runner( - 
model, - dtype=dtype, - max_model_len=8192, - max_num_seqs=4, - tensor_parallel_size=1, - limit_mm_per_prompt={"image": - _LIMIT_IMAGE_PER_PROMPT}) as vllm_model: - - # Regression tests for https://github.com/vllm-project/vllm/issues/10648 - - # Number of groups of image tokens is greater than the number of images - # provided (the whitespace between the tags is necessary) - prompt = "<|begin_of_text|><|image|> <|image|> Compare the two images" # noqa: E501 - image = stop_sign - with pytest.raises(ValueError): - vllm_model.generate_greedy_logprobs([prompt], - max_tokens, - num_logprobs, - images=[image]) - - # Batch of a text-only and image request that requires cross-attention - prompts = [ - "What is the capital of spain?", - "Text before the image...<|image|>What is in the image?", # noqa: E501 - ] - images = [ - None, - [stop_sign], - ] - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs, - images=images) - - # Test the reverse order too for good measure - prompts = [ - "<|begin_of_text|>Text before the image...<|image|>What is in the image?", # noqa: E501 - "<|begin_of_text|>Hello!", - ] - images = [ - [stop_sign], - None, - ] - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs, - images=images) - - # Mixed batch with text and images with different numbers of tiles - prompts = [ - "<|begin_of_text|>Hello!", - "<|begin_of_text|>Some text before.<|image|>What is in the image?", # noqa: E501 - "<|begin_of_text|>Some text before.<|image|>What is in the image?", # noqa: E501 - ] - images = [ - None, - [stop_sign], - # smaller image must be 2nd for the repro - [stop_sign.resize((448, 448))], - ] - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs, - images=images) - - -class DummyModel: - image_token_id = MLLAMA_IMAGE_TOKEN_ID - - -@pytest.mark.core_model -@pytest.mark.parametrize( - "input_indices_and_output", - # inputs, (cross_attention_mask, kv_range_for_decode) - [([TEXT_ONLY], (None, None)), ([IMAGE_AT_BEG], (None, None)), - ([TEXT_ONLY, IMAGE_AT_BEG], (None, None)), - ([IMAGE_AT_MIDDLE], ((10, 12), [[0, 6]])), - ([TEXT_ONLY, IMAGE_AT_MIDDLE], ((14, 12), [[0, 6]])), - ([TEXT_ONLY, IMAGE_AT_BEG, IMAGE_AT_MIDDLE], - ((23, 24), [[0, 6], [6, 12]])), - ([IMAGE_AT_MIDDLE, TEXT_ONLY], ((14, 12), [[0, 6]])), - ([TWO_IMAGES], ((18, 12), [[6, 12]])), - ([TEXT_ONLY, TWO_IMAGES], ((22, 12), [[6, 12]]))]) -def test_get_cross_attention_mask(input_indices_and_output) -> None: - - input_indices, expected_output = input_indices_and_output - - sequences = [torch.tensor(prompt_data[i]) for i in input_indices] - num_tiles = [[2, 2] if i != TEXT_ONLY else [] for i in input_indices - if i != TEXT_ONLY] - input = torch.cat(sequences) - - seq_lens = [len(s) for s in sequences] - - attn_data = FlashAttentionMetadata( - seq_lens=seq_lens, - # Dummy values - enable_kv_scales_calculation=False, - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=0, - slot_mapping=0, - multi_modal_placeholder_index_maps=None, - seq_lens_tensor=0, - max_prefill_seq_len=0, - max_decode_seq_len=0, - context_lens_tensor=None, - block_tables=None, - use_cuda_graph=False, - ) - - dummy = DummyModel() - - cross_attention_mask, kv_range_for_decode = MllamaForConditionalGeneration\ - .get_cross_attention_mask(dummy, - input, - attn_data, - num_tiles=num_tiles, - num_tokens_per_tile=3, - dtype=torch.bfloat16) - - expected_cross_attention_mask, expected_kv_range_for_decode = \ - expected_output - - assert kv_range_for_decode == expected_kv_range_for_decode - 
if expected_cross_attention_mask is not None: - assert cross_attention_mask is not None - assert cross_attention_mask.shape == expected_cross_attention_mask - else: - assert cross_attention_mask is None - - -@pytest.mark.core_model -@pytest.mark.parametrize( - "input_indices", - [[TEXT_ONLY], [IMAGE_AT_BEG], [TEXT_ONLY, IMAGE_AT_BEG], [IMAGE_AT_MIDDLE], - [TEXT_ONLY, IMAGE_AT_MIDDLE], [TEXT_ONLY, IMAGE_AT_BEG, IMAGE_AT_MIDDLE], - [IMAGE_AT_MIDDLE, TEXT_ONLY], [TWO_IMAGES], [TEXT_ONLY, TWO_IMAGES]]) -def test_get_full_text_row_masked_out_mask(input_indices) -> None: - - sequences = [torch.tensor(prompt_data[i]) for i in input_indices] - - seq_lens = [len(s) for s in sequences] - - num_prefill_tokens = sum(seq_lens) - - # TEXT_ONLY is zero, so it will be masked out, - # other instances should not be. - encoder_seq_lens = [int(i) for i in input_indices] - - attn_data = FlashAttentionMetadata( - seq_lens=seq_lens, - encoder_seq_lens=encoder_seq_lens, - num_prefill_tokens=num_prefill_tokens, - # Dummy values - enable_kv_scales_calculation=False, - num_prefills=0, - num_decode_tokens=0, - slot_mapping=0, - multi_modal_placeholder_index_maps=None, - seq_lens_tensor=0, - max_prefill_seq_len=0, - max_decode_seq_len=0, - context_lens_tensor=None, - block_tables=None, - use_cuda_graph=False, - ) - - dummy = DummyModel() - - full_text_row_masked_out_mask = MllamaForConditionalGeneration\ - .get_full_text_row_masked_out_mask(dummy, - attn_data, - torch.get_default_device()) - - full_text_row_masked_out_mask = full_text_row_masked_out_mask.squeeze() - full_text_row_masked_out_mask = full_text_row_masked_out_mask.tolist() - - idx = 0 - assert len(full_text_row_masked_out_mask) == num_prefill_tokens - for i, seq_len in enumerate(seq_lens): - must_be_masked = input_indices[i] != TEXT_ONLY - for _ in range(seq_len): - assert full_text_row_masked_out_mask[idx] == must_be_masked, \ - f"full_text_row_masked_out_mask[{idx}] must be " \ - f"'{must_be_masked}' " - idx += 1 - - -@pytest.mark.core_model -@pytest.mark.parametrize("encoder_seq_lens, num_tiles, expected", [ - ([6404], [[4]], [6404]), - ([0, 6404], [[4]], [6404]), - ([0, 1601, 8005], [[1], [4, 1]], [1601, 8005]), - ([0, 19212, 0, 3202], [[4, 4, 4], [2]], [19212, 3202]), -]) -def test_parse_and_validate_encoder_lens(encoder_seq_lens, num_tiles, - expected) -> None: - - dummy = DummyModel() - num_tokens_per_tile = 1601 - actual_encoder_seq_lens = MllamaForConditionalGeneration \ - ._get_and_validate_encoder_lens( - dummy, - encoder_seq_lens, - num_tiles, - num_tokens_per_tile, - ) - assert actual_encoder_seq_lens == expected, \ - f"Expected {expected} but got {actual_encoder_seq_lens}" diff --git a/tests/models/multimodal/generation/test_pixtral.py b/tests/models/multimodal/generation/test_pixtral.py index d39cf706786e2..a4e21aface41f 100644 --- a/tests/models/multimodal/generation/test_pixtral.py +++ b/tests/models/multimodal/generation/test_pixtral.py @@ -29,10 +29,10 @@ MISTRAL_SMALL_3_1_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" MODELS = [PIXTRAL_ID, MISTRAL_SMALL_3_1_ID] IMG_URLS = [ - "https://huggingface.co/datasets/Isotr0py/mistral-test-images/resolve/main/237-400x300.jpg", - "https://huggingface.co/datasets/Isotr0py/mistral-test-images/resolve/main/231-200x300.jpg", - "https://huggingface.co/datasets/Isotr0py/mistral-test-images/resolve/main/27-500x500.jpg", - "https://huggingface.co/datasets/Isotr0py/mistral-test-images/resolve/main/17-150x600.jpg", + "237-400x300.jpg", # 
"https://huggingface.co/datasets/Isotr0py/mistral-test-images/resolve/main/237-400x300.jpg", + "231-200x300.jpg", # "https://huggingface.co/datasets/Isotr0py/mistral-test-images/resolve/main/237-400x300.jpg", + "27-500x500.jpg", # "https://huggingface.co/datasets/Isotr0py/mistral-test-images/resolve/main/237-400x300.jpg", + "17-150x600.jpg", # "https://huggingface.co/datasets/Isotr0py/mistral-test-images/resolve/main/237-400x300.jpg", ] PROMPT = "Describe each image in one short sentence." @@ -105,12 +105,6 @@ def _create_engine_inputs_hf(urls: list[str]) -> TextPrompt: return engine_inputs -MSGS = [ - _create_msg_format(IMG_URLS[:1]), - _create_msg_format(IMG_URLS[:2]), - _create_msg_format(IMG_URLS), -] - SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5) LIMIT_MM_PER_PROMPT = dict(image=4) @@ -156,12 +150,8 @@ def load_outputs_w_logprobs(filename: "StrPath") -> OutputsLogprobs: @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("max_model_len", MAX_MODEL_LEN) @pytest.mark.parametrize("dtype", ["bfloat16"]) -def test_chat( - vllm_runner, - max_model_len: int, - model: str, - dtype: str, -) -> None: +def test_chat(vllm_runner, max_model_len: int, model: str, dtype: str, + local_asset_server) -> None: EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs( FIXTURE_LOGPROBS_CHAT[model]) with vllm_runner( @@ -174,7 +164,14 @@ def test_chat( limit_mm_per_prompt=LIMIT_MM_PER_PROMPT, ) as vllm_model: outputs = [] - for msg in MSGS: + + urls_all = [local_asset_server.url_for(u) for u in IMG_URLS] + msgs = [ + _create_msg_format(urls_all[:1]), + _create_msg_format(urls_all[:2]), + _create_msg_format(urls_all), + ] + for msg in msgs: output = vllm_model.llm.chat(msg, sampling_params=SAMPLING_PARAMS) outputs.extend(output) @@ -190,17 +187,19 @@ def test_chat( name_1="output") -@pytest.mark.parametrize("prompt,expected_ranges", - [(_create_engine_inputs_hf(IMG_URLS[:1]), - [PlaceholderRange(offset=11, length=494)]), - (_create_engine_inputs_hf(IMG_URLS[1:4]), [ - PlaceholderRange(offset=11, length=266), - PlaceholderRange(offset=277, length=1056), - PlaceholderRange(offset=1333, length=418) - ])]) -def test_multi_modal_placeholders(vllm_runner, prompt: TextPrompt, +@pytest.mark.parametrize( + "image_urls,expected_ranges", + [(IMG_URLS[:1], [PlaceholderRange(offset=11, length=494)]), + (IMG_URLS[1:4], [ + PlaceholderRange(offset=11, length=266), + PlaceholderRange(offset=277, length=1056), + PlaceholderRange(offset=1333, length=418) + ])]) +def test_multi_modal_placeholders(vllm_runner, image_urls: list[str], expected_ranges: list[PlaceholderRange], - monkeypatch) -> None: + local_asset_server, monkeypatch) -> None: + local_image_urls = [local_asset_server.url_for(u) for u in image_urls] + prompt = _create_engine_inputs_hf(local_image_urls) # This placeholder checking test only works with V0 engine # where `multi_modal_placeholders` is returned with `RequestOutput` diff --git a/tests/models/multimodal/generation/test_whisper.py b/tests/models/multimodal/generation/test_whisper.py index 4a65e8c95204e..e0e9980b88339 100644 --- a/tests/models/multimodal/generation/test_whisper.py +++ b/tests/models/multimodal/generation/test_whisper.py @@ -122,8 +122,7 @@ def run_test( @pytest.mark.core_model -@pytest.mark.parametrize( - "model", ["openai/whisper-small", "openai/whisper-large-v3-turbo"]) +@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"]) @create_new_process_for_each_test() def test_models(vllm_runner, model) -> None: run_test( diff --git 
a/tests/models/multimodal/generation/vlm_utils/builders.py b/tests/models/multimodal/generation/vlm_utils/builders.py index 03c08240d6a81..133d5d6ee2ef8 100644 --- a/tests/models/multimodal/generation/vlm_utils/builders.py +++ b/tests/models/multimodal/generation/vlm_utils/builders.py @@ -250,7 +250,7 @@ def build_video_inputs_from_test_info( def apply_image_size_scaling(image, size: Union[float, tuple[int, int]], size_type: SizeType): - """Applies a size scaler to one image; this can be a an image size factor, + """Applies a size scaler to one image; this can be an image size factor, which scales the image while maintaining the aspect ratio""" # Special case for embeddings; if it's a tensor, it's only valid if we # are considering size factors at constant scale, i.e., we just clone diff --git a/tests/models/multimodal/generation/vlm_utils/case_filtering.py b/tests/models/multimodal/generation/vlm_utils/case_filtering.py index 336e2dd2b1201..1edb512135343 100644 --- a/tests/models/multimodal/generation/vlm_utils/case_filtering.py +++ b/tests/models/multimodal/generation/vlm_utils/case_filtering.py @@ -42,7 +42,7 @@ def get_filtered_test_settings( else: assert test_info.prompt_formatter is not None - # Everything looks okay; keep if this is has correct proc handling + # Everything looks okay; keep if this is correct proc handling if (test_info.distributed_executor_backend is not None) == new_proc_per_test: matching_tests[test_name] = test_info diff --git a/tests/models/multimodal/generation/vlm_utils/core.py b/tests/models/multimodal/generation/vlm_utils/core.py index a5d6948f06efd..11d44120b875f 100644 --- a/tests/models/multimodal/generation/vlm_utils/core.py +++ b/tests/models/multimodal/generation/vlm_utils/core.py @@ -42,7 +42,7 @@ def run_test( tensor_parallel_size: int = 1, vllm_embeddings: Optional[torch.Tensor] = None, ): - """Modality agnostic test test executor for comparing HF/vLLM outputs.""" + """Modality agnostic test executor for comparing HF/vLLM outputs.""" # In the case of embeddings, vLLM takes separate input tensors vllm_inputs = vllm_embeddings if vllm_embeddings is not None else inputs @@ -69,6 +69,9 @@ def run_test( vllm_runner_kwargs_["tokenizer_mode"] = model_info.tokenizer_mode if model_info.hf_overrides: vllm_runner_kwargs_["hf_overrides"] = model_info.hf_overrides + if model_info.skip_tokenizer_init: + vllm_runner_kwargs_[ + "skip_tokenizer_init"] = model_info.skip_tokenizer_init if vllm_runner_kwargs: vllm_runner_kwargs_.update(vllm_runner_kwargs) diff --git a/tests/models/multimodal/pooling/test_prithvi_mae.py b/tests/models/multimodal/pooling/test_prithvi_mae.py index e9be79fba911f..b503d42567022 100644 --- a/tests/models/multimodal/pooling/test_prithvi_mae.py +++ b/tests/models/multimodal/pooling/test_prithvi_mae.py @@ -46,7 +46,7 @@ def _run_test( vllm_model.encode(prompt) -MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"] +MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"] @pytest.mark.core_model diff --git a/tests/models/multimodal/pooling/test_radio.py b/tests/models/multimodal/pooling/test_radio.py new file mode 100644 index 0000000000000..27b9fe369e800 --- /dev/null +++ b/tests/models/multimodal/pooling/test_radio.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +import torch.nn as nn +from huggingface_hub import snapshot_download +from transformers import AutoConfig, AutoModel, CLIPImageProcessor + +from vllm.distributed import 
cleanup_dist_env_and_memory +from vllm.model_executor.models.radio import RadioModel +from vllm.transformers_utils.configs.radio import RadioConfig +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE + +from ....conftest import ImageTestAssets + +# we use snapshot_download to prevent conflicts between +# dynamic_module and trust_remote_code for hf_runner +DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"] + + +@torch.inference_mode() +def run_radio_test( + image_assets: ImageTestAssets, + model_id: str, + *, + dtype: str, +): + model = snapshot_download(model_id, allow_patterns=DOWNLOAD_PATTERN) + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] + + img_processor = CLIPImageProcessor.from_pretrained(model) + images = [asset.pil_image for asset in image_assets] + # Input resolution must be a multiple of `self.min_resolution_step`. + # Using `self.get_nearest_supported_resolution`, for assets 432x642 the + # nearest supported resolution is 432x640. + pixel_values = [ + img_processor( + image, + return_tensors='pt').pixel_values.to(torch_dtype)[:, :, :, :640] + for image in images + ] + + config = AutoConfig.from_pretrained(model_id, trust_remote_code=True) + + hf_model = AutoModel.from_pretrained( + model_id, + config=config, + torch_dtype=torch_dtype, + trust_remote_code=True, + ).to("cuda") + hf_model.eval() + + hf_outputs_per_image = [ + hf_model(pixel_value.to("cuda")).features + for pixel_value in pixel_values + ] + + radio_config = RadioConfig(model_name=config.args["model"], + reg_tokens=config.args["register_multiple"]) + vllm_model = RadioModel(radio_config) + vllm_model.load_weights(hf_model.state_dict()) + vllm_model = vllm_model.to("cuda", torch_dtype) + + vllm_outputs_per_image = [ + vllm_model(pixel_values=pixel_value.to("cuda")) + for pixel_value in pixel_values + ] + del vllm_model, hf_model + cleanup_dist_env_and_memory() + + cos_similar = nn.CosineSimilarity(dim=-1) + for vllm_output, hf_output in zip(vllm_outputs_per_image, + hf_outputs_per_image): + assert cos_similar(vllm_output, hf_output).mean() > 0.99 + + +@pytest.mark.parametrize("model_id", [ + "nvidia/C-RADIOv2-H", +]) +@pytest.mark.parametrize("dtype", ["half"]) +def test_radio(dist_init, image_assets, model_id, dtype: str) -> None: + run_radio_test( + image_assets, + model_id, + dtype=dtype, + ) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 16c0428c6d8f1..0941cc3f608e7 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -31,16 +31,48 @@ def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict: """ # Ensure video metadata is included if "video" in mm_data: + # GLM4.1V doesn't support multiple videos video = mm_data["video"] + num_frames = len(video) mm_data["video"] = (video, { - "total_num_frames": len(video), - "fps": len(video), + "total_num_frames": num_frames, + "fps": num_frames, "duration": 1, - "video_backend": "opencv" + "frames_indices": [i for i in range(num_frames)], + "video_backend": "opencv", + "do_sample_frames": True, }) return mm_data +def qwen3_vl_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict: + """ + Patch the multimodal data for Qwen3-VL model. 
+ """ + + def create_metadata(frames: np.ndarray): + num_frames = len(frames) + return { + "total_num_frames": num_frames, + "fps": 2.0, + "duration": num_frames / 2.0, + "video_backend": "opencv", + "frames_indices": list(range(num_frames)), + "do_sample_frames": True, + } + + # Ensure video metadata is included + if "video" in mm_data: + video = mm_data["video"] + if isinstance(video, list): + # multiple videos + mm_data["video"] = [(vid, create_metadata(vid)) for vid in video] + else: + # single video + mm_data["video"] = (video, create_metadata(video)) + return mm_data + + def _test_processing_correctness( model_id_or_arch: str, hit_rate: float, @@ -66,7 +98,9 @@ def _test_processing_correctness( hf_overrides=model_info.hf_overrides, # Ensure that the cache can fit all of the data mm_processor_cache_gb=2048, - ) + skip_tokenizer_init=model_info.skip_tokenizer_init, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype) model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) factories = MULTIMODAL_REGISTRY._processor_factories[model_cls] @@ -162,8 +196,6 @@ def _test_processing_correctness( # incorrect token ids. So we need use `add_special_tokens=False` here # to leave bos_token to be added by the processor. _ADD_SPECIAL_TOKENS_OVERRIDES = { - "donut": False, - "mllama": False, "ovis": False, "ovis2_5": False, "paligemma": False, @@ -179,8 +211,10 @@ _IGNORE_MM_KEYS = { } MM_DATA_PATCHES = { - # GLM4.1V requires video metadata to be included in the input + # GLM4.1V and Qwen3-VL requires video metadata to be included in the input "glm4v": glm4_1v_patch_mm_data, + "qwen3_vl": qwen3_vl_patch_mm_data, + "qwen3_vl_moe": qwen3_vl_patch_mm_data, } @@ -273,9 +307,7 @@ def _test_processing_correctness_one( "facebook/chameleon-7b", "CohereLabs/command-a-vision-07-2025", "deepseek-ai/deepseek-vl2-tiny", - "naver-clova-ix/donut-base-finetuned-docvqa", "baidu/ERNIE-4.5-VL-28B-A3B-PT", - "microsoft/Florence-2-base", "adept/fuyu-8b", "google/gemma-3-4b-it", "google/gemma-3n-E2B-it", @@ -300,8 +332,8 @@ def _test_processing_correctness_one( "llava-hf/llava-v1.6-mistral-7b-hf", "llava-hf/LLaVA-NeXT-Video-7B-hf", "llava-hf/llava-onevision-qwen2-0.5b-ov-hf", - "meta-llama/Llama-3.2-11B-Vision-Instruct", "TIGER-Lab/Mantis-8B-siglip-llama3", + "mispeech/midashenglm-7b", "openbmb/MiniCPM-Llama3-V-2_5", "openbmb/MiniCPM-o-2_6", "openbmb/MiniCPM-V-2_6", @@ -325,6 +357,8 @@ def _test_processing_correctness_one( "Qwen/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2-Audio-7B-Instruct", "Qwen/Qwen2.5-Omni-3B", + "Qwen/Qwen3-VL-4B-Instruct", + "Qwen/Qwen3-VL-30B-A3B-Instruct", "YannQi/R-4B", "Skywork/Skywork-R1V-38B", "HuggingFaceTB/SmolVLM2-2.2B-Instruct", diff --git a/tests/models/multimodal/processing/test_glm4_1v.py b/tests/models/multimodal/processing/test_glm4_1v.py index a49842e1099c2..070ddcd89ee96 100644 --- a/tests/models/multimodal/processing/test_glm4_1v.py +++ b/tests/models/multimodal/processing/test_glm4_1v.py @@ -5,14 +5,26 @@ import pytest from vllm.assets.video import VideoAsset from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.video import OpenCVDynamicVideoBackend, OpenCVVideoBackend from ...utils import build_model_context @pytest.mark.parametrize("model_id", ["zai-org/GLM-4.1V-9B-Thinking"]) @pytest.mark.parametrize("expected_toks_per_frame", [299]) -@pytest.mark.parametrize("num_frames", [32, 128]) -@pytest.mark.parametrize("fps, expected_grid_t", [(1, 5), (2, 10)]) +@pytest.mark.parametrize( + "num_frames, fps, expected_grid_t", + [ + # pre-sampled fixed frames 
(unexpected behavior, + # but we still expect it to work without errors) + (32, 1, 16), + (32, 2, 16), + (128, 1, 64), + (128, 2, 64), + # post-sampled frames (expected behavior) + (-1, 1, 5), + (-1, 2, 10), + ]) def test_processor_override( model_id: str, expected_toks_per_frame: int, @@ -50,3 +62,49 @@ def test_processor_override( assert grid_t == expected_grid_t assert video_tok_count == expected_toks_per_frame * grid_t + + +@pytest.mark.parametrize("model_id", ["zai-org/GLM-4.1V-9B-Thinking"]) +@pytest.mark.parametrize("fps", [2]) +def test_video_loader_consistency( + model_id: str, + fps: int, +): + """ + Ensure dynamic video loader (pre-sampled by loader) and normal video + loader (post-sampled by processor) produce same video processing outputs. + """ + ctx = build_model_context( + model_id, + mm_processor_kwargs=None, + limit_mm_per_prompt={"video": 1}, + ) + processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) + hf_processor_mm_kwargs = {"fps": fps} + + # Build the image str / prompt based on the number of images we pass + prompt = "<|begin_of_video|><|video|><|end_of_video|>" + + video_path = VideoAsset(name="baby_reading", num_frames=-1).video_path + with open(video_path, "rb") as f: + video_bytes = f.read() + + static_video, static_metadata = OpenCVVideoBackend.load_bytes(video_bytes) + dynamic_video, dynamic_metadata = OpenCVDynamicVideoBackend.load_bytes( + video_bytes, fps=fps) + + # pre-sampled loader shouldn't read all frames + assert len(dynamic_video) < len(static_video) + + static_mm_data = {"video": [(static_video, static_metadata)]} + dynamic_mm_data = {"video": [(dynamic_video, dynamic_metadata)]} + + static_outputs = processor.apply(prompt, static_mm_data, + hf_processor_mm_kwargs) + dynamic_outputs = processor.apply(prompt, dynamic_mm_data, + hf_processor_mm_kwargs) + + assert static_outputs["prompt_token_ids"] == dynamic_outputs[ + "prompt_token_ids"] + assert static_outputs["mm_kwargs"].get_data( + ) == dynamic_outputs["mm_kwargs"].get_data() diff --git a/tests/models/multimodal/processing/test_mllama.py b/tests/models/multimodal/processing/test_mllama.py deleted file mode 100644 index b42d3f89f3cbf..0000000000000 --- a/tests/models/multimodal/processing/test_mllama.py +++ /dev/null @@ -1,72 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Tests for mllama's multimodal preprocessing and profiling.""" -import pytest -from transformers import MllamaConfig - -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.profiling import MultiModalProfiler - -from ...utils import build_model_context - - -@pytest.mark.parametrize("model_id", - ["meta-llama/Llama-3.2-11B-Vision-Instruct"]) -@pytest.mark.parametrize("max_model_len", [4096, 8192, 25600, 131072]) -@pytest.mark.parametrize("max_num_seqs", [1, 2, 8]) -def test_profiling( - model_id: str, - max_model_len: int, - max_num_seqs: int, -): - # regression test for https://github.com/vllm-project/vllm/issues/13929 - from vllm.model_executor.models.mllama import calc_token_per_chunk - - model_config_kwargs = { - "max_model_len": max_model_len, - } - ctx = build_model_context( - model_id, - model_config_kwargs=model_config_kwargs, - limit_mm_per_prompt={"image": 1}, - ) - - mm_config = ctx.get_mm_config() - processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) - profiler = MultiModalProfiler(processor) - - dummy_encoder_data = profiler.get_encoder_dummy_data( - max_model_len, - 
mm_counts=mm_config.limit_per_prompt, - ) - dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs( - max_model_len, - mm_counts=mm_config.limit_per_prompt, - ) - - hf_config = ctx.get_hf_config(MllamaConfig) - image_size = hf_config.vision_config.image_size - encoder_seq_lens = [len(dummy_encoder_data.prompt_token_ids) - ] * max_num_seqs - - mm_data = processor.apply( - prompt=dummy_mm_data.prompt, - mm_data=dummy_mm_data.mm_data, - hf_processor_mm_kwargs=dict(), - )["mm_kwargs"].get_data() - - # Get the actual number of encoder tokens for each sample. - # Because attn_metadata.encoder_seq_lens only counts the last - # group of images for each sample, which is used to cheat the - # block manager to allocate blocks for those images only. - # See MllamaMultiModalProcessor for more details. - num_tiles = [[t] for t in mm_data.pop("num_tiles")] - num_tokens_per_tile = calc_token_per_chunk(image_size) - actual_encoder_seq_lens = [ - sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles - ] - - # simulate mllama image-present prefill. - for actual_len, last_group_len in zip(actual_encoder_seq_lens, - encoder_seq_lens): - assert actual_len >= last_group_len diff --git a/tests/models/multimodal/processing/test_mllama4.py b/tests/models/multimodal/processing/test_mllama4.py index 3be77b5da63f2..e7b28ff8ec7f0 100644 --- a/tests/models/multimodal/processing/test_mllama4.py +++ b/tests/models/multimodal/processing/test_mllama4.py @@ -52,7 +52,7 @@ def test_profiling(model_id: str, max_model_len: int): chunks_per_image = prod(mm_data["patches_per_image"]) total_num_patches = chunks_per_image * tokens_per_patch num_tiles = mm_data["aspect_ratios"][0][0] * mm_data["aspect_ratios"][0][ - 1] # x-y seperator tokens + 1] # x-y separator tokens total_tokens = total_num_patches.item() + num_tiles.item( ) + 3 # image start, image, image end diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py index 1a11fa3d2b824..b678313752d65 100644 --- a/tests/models/multimodal/processing/test_tensor_schema.py +++ b/tests/models/multimodal/processing/test_tensor_schema.py @@ -41,9 +41,6 @@ ARCH_NEEDS_EXTRAS = [ ] REPO_ID_TO_SKIP = { "nm-testing/pixtral-12b-FP8-dynamic": "duplicated test", - # FIXME(Isotr0py): enable GPT-OSS based InternVL3.5 model - # after support PP for GPT-OSS - "OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview": "Broken model", } ImageInput = list[Image.Image] @@ -199,7 +196,9 @@ def test_model_tensor_schema(model_arch: str, model_id: str): revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, hf_overrides=hf_overrides_fn, - ) + skip_tokenizer_init=model_info.skip_tokenizer_init, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype) model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) factories = MULTIMODAL_REGISTRY._processor_factories[model_cls] diff --git a/tests/models/multimodal/test_mapping.py b/tests/models/multimodal/test_mapping.py index 7096810d8e15c..caf1966ab513f 100644 --- a/tests/models/multimodal/test_mapping.py +++ b/tests/models/multimodal/test_mapping.py @@ -59,7 +59,9 @@ def test_hf_model_weights_mapper(model_arch: str): revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, hf_overrides=model_info.hf_overrides, - ) + skip_tokenizer_init=model_info.skip_tokenizer_init, + enforce_eager=model_info.enforce_eager, + dtype=model_info.dtype) model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) original_weights = 
create_repo_dummy_weights(model_id) diff --git a/tests/models/registry.py b/tests/models/registry.py index 3b5cec2dc7022..e9cc5170ade74 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -6,10 +6,11 @@ from dataclasses import dataclass, field from typing import Any, Literal, Optional import pytest +import torch from packaging.version import Version from transformers import __version__ as TRANSFORMERS_VERSION -from vllm.config import TokenizerMode +from vllm.config import ModelDType, TokenizerMode @dataclass(frozen=True) @@ -47,6 +48,23 @@ class _HfExamplesInfo: The reason for the minimum/maximum version requirement. """ + skip_tokenizer_init: bool = False + """ + If true, skip initialization of tokenizer and detokenizer. + """ + + dtype: ModelDType = "auto" + """ + The data type for the model weights and activations. + """ + + enforce_eager: bool = False + """ + Whether to enforce eager execution. If True, we will + disable CUDA graph and always execute the model in eager mode. + If False, we will use CUDA graph and eager execution in hybrid. + """ + is_available_online: bool = True """ Set this to ``False`` if the name of this architecture no longer exists on @@ -76,6 +94,15 @@ class _HfExamplesInfo: If not specified, the default revision will be used. """ + max_num_seqs: Optional[int] = None + """Maximum number of sequences to be processed in a single iteration.""" + + use_original_num_layers: bool = False + """ + If True, use the original number of layers from the model config + instead of minimal layers for testing. + """ + def check_transformers_version( self, *, @@ -137,7 +164,7 @@ class _HfExamplesInfo: # yapf: disable _TEXT_GENERATION_EXAMPLE_MODELS = { # [Decoder-only] - "ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B", + "ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B-2509", min_transformers_version="4.56.0", trust_remote_code=True), "AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B", @@ -153,8 +180,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { trust_remote_code=True), "BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5", trust_remote_code=True), + "BailingMoeV2ForCausalLM": _HfExamplesInfo("inclusionAI/Ling-mini-2.0", + trust_remote_code=True), "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B-v1", - min_transformers_version="4.56.0", + min_transformers_version="4.55.3", extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}), # noqa: E501 "BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m", {"1b": "bigscience/bloomz-1b1"}), @@ -208,7 +237,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "GptOssForCausalLM": _HfExamplesInfo("lmsys/gpt-oss-20b-bf16"), "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"), "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"), - "GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview"), # noqa: E501 + "GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview", # noqa: E501 + min_transformers_version="4.55.3"), "GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts"), # noqa: E501 "Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1", trust_remote_code=True), @@ -228,7 +258,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { trust_remote_code=True), "JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"), "JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini", - min_transformers_version="4.56.0", + min_transformers_version="4.55.3", extras={ "tiny": 
"ai21labs/Jamba-tiny-dev", "random": "ai21labs/Jamba-tiny-random", # noqa: E501 @@ -244,7 +274,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "Llama4ForCausalLM": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501 is_available_online=False), "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"), - "Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1"), + "Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1", + min_transformers_version="4.55.3", + extras={ + "random": "yujiepan/mamba2-codestral-v0.1-tiny-random", # noqa: E501 + }), "FalconMambaForCausalLM": _HfExamplesInfo("tiiuae/falcon-mamba-7b-instruct"), # noqa: E501 "MiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-2B-sft-bf16", trust_remote_code=True), @@ -259,7 +293,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"), "MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1", # noqa: E501 {"tiny": "TitanML/tiny-mixtral"}), # noqa: E501 - "QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"), # noqa: E501 + "MotifForCausalLM": _HfExamplesInfo("Motif-Technologies/Motif-2.6B", + trust_remote_code=True, + v0_only=True), "MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False), "MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"), "NemotronForCausalLM": _HfExamplesInfo("nvidia/Minitron-8B-Base"), @@ -267,6 +303,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { trust_remote_code=True), "OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"), "Olmo2ForCausalLM": _HfExamplesInfo("allenai/OLMo-2-0425-1B"), + "Olmo3ForCausalLM": _HfExamplesInfo("shanearora/2025-sep-a-base-model"), "OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"), "OPTForCausalLM": _HfExamplesInfo("facebook/opt-125m", {"1b": "facebook/opt-iml-max-1.3b"}), @@ -282,8 +319,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct", trust_remote_code=True), "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b", - max_transformers_version="4.53", - transformers_version_reason="vLLM impl inherits PreTrainedModel and clashes with get_input_embeddings", # noqa: E501 trust_remote_code=True), "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat", max_transformers_version="4.53", @@ -294,6 +329,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"), "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"), "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"), + "Qwen3NextForCausalLM": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct", + min_transformers_version="4.56.2"), "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"), "SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501 trust_remote_code=True, @@ -317,17 +354,13 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", trust_remote_code=True), "Dots1ForCausalLM": _HfExamplesInfo("rednote-hilab/dots.llm1.inst"), - # [Encoder-decoder] - "BartModel": _HfExamplesInfo("facebook/bart-base"), - "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"), - "MBartForConditionalGeneration": _HfExamplesInfo("facebook/mbart-large-en-ro", # noqa: E501 - hf_overrides={"architectures": ["MBartForConditionalGeneration"]}), # noqa: E501 } _EMBEDDING_EXAMPLE_MODELS = { # [Text-only] "BertModel": 
_HfExamplesInfo("BAAI/bge-base-en-v1.5"), "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), # noqa: E501 + "Gemma3TextModel": _HfExamplesInfo("google/embeddinggemma-300m"), "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"), "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0", trust_remote_code=True), @@ -359,7 +392,20 @@ _EMBEDDING_EXAMPLE_MODELS = { trust_remote_code=True), "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501 "PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501 - is_available_online=False), # noqa: E501 + dtype=torch.float16, + enforce_eager=True, + skip_tokenizer_init=True, + # This is to avoid the model + # going OOM in CI + max_num_seqs=32, + ), + "Terratorch": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501 + dtype=torch.float16, + enforce_eager=True, + skip_tokenizer_init=True, + # This is to avoid the model going OOM in CI + max_num_seqs=32, + ), } _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = { @@ -368,6 +414,7 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = { # [Cross-encoder] "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501 + "BertForTokenClassification": _HfExamplesInfo("boltuix/NeuroBERT-NER"), "GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base", # noqa: E501 trust_remote_code=True, hf_overrides={ @@ -445,7 +492,7 @@ _MULTIMODAL_EXAMPLE_MODELS = { trust_remote_code=True), "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501 max_model_len=10240, - extras={"llama-guard-4": "meta-llama/Llama-Guard-4-12B"}, # noqa: E501 + extras={"llama-guard-4": "meta-llama/Llama-Guard-4-12B"}, # noqa: E501 ), "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf", extras={"mistral": "mistral-community/pixtral-12b", # noqa: E501 @@ -457,6 +504,8 @@ _MULTIMODAL_EXAMPLE_MODELS = { max_transformers_version="4.48", # noqa: E501 transformers_version_reason="HF model is not compatible.", # noqa: E501 hf_overrides={"architectures": ["MantisForConditionalGeneration"]}), # noqa: E501 + "MiDashengLMModel": _HfExamplesInfo("mispeech/midashenglm-7b", + trust_remote_code=True), "MiniCPMO": _HfExamplesInfo("openbmb/MiniCPM-o-2_6", trust_remote_code=True), "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5", @@ -476,6 +525,9 @@ _MULTIMODAL_EXAMPLE_MODELS = { trust_remote_code=True), "Llama_Nemotron_Nano_VL" : _HfExamplesInfo("nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1", # noqa: E501 trust_remote_code=True), + "NemotronH_Nano_VL": _HfExamplesInfo("nano_vl_dummy", + is_available_online=False, + trust_remote_code=True), "Ovis": _HfExamplesInfo("AIDC-AI/Ovis2-1B", trust_remote_code=True, max_transformers_version="4.53", transformers_version_reason="HF model is not compatible", # noqa: E501 @@ -506,6 +558,12 @@ _MULTIMODAL_EXAMPLE_MODELS = { max_model_len=4096), "Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B"), "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501 + "Qwen3VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-4B-Instruct", # noqa: E501 + max_model_len=4096, + min_transformers_version="4.57"), # noqa: E501 + "Qwen3VLMoeForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-30B-A3B-Instruct", # noqa: E501 + max_model_len=4096, + min_transformers_version="4.57"), 
"RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B", trust_remote_code=True), "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B", @@ -527,15 +585,6 @@ _MULTIMODAL_EXAMPLE_MODELS = { is_available_online=False, ), # [Encoder-decoder] - "DonutForConditionalGeneration": _HfExamplesInfo("naver-clova-ix/donut-base-finetuned-docvqa", # noqa: E501 - hf_overrides={"architectures": ["DonutForConditionalGeneration"], "model_type": "donut"}, # noqa: E501 - extras={"dolphin": "ByteDance/Dolphin"}), # noqa: E501 - # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer - # Therefore, we borrow the BartTokenizer from the original Bart model - "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501 - tokenizer="Isotr0py/Florence-2-tokenizer", # noqa: E501 - trust_remote_code=True), # noqa: E501 - "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501 # [Cross-encoder] "JinaVLForRanking": _HfExamplesInfo("jinaai/jina-reranker-m0"), # noqa: E501 @@ -555,19 +604,21 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { "EagleDeepSeekMTPModel": _HfExamplesInfo("eagle618/deepseek-v3-random", speculative_model="eagle618/eagle-deepseek-v3-random", # noqa: E501 trust_remote_code=True), - "EagleLlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE-LLaMA3-Instruct-8B", + "EagleLlamaForCausalLM": _HfExamplesInfo("meta-llama/Meta-Llama-3-8B-Instruct", # noqa: E501 trust_remote_code=True, speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B", - tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501 - "Eagle3LlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", # noqa: E501 + tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501 + "Eagle3LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.1-8B-Instruct", # noqa: E501 trust_remote_code=True, - speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", - tokenizer="meta-llama/Llama-3.1-8B-Instruct"), - # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 - # "LlamaForCausalLMEagle3": _HfExamplesInfo("AngelSlim/Qwen3-8B_eagle3", # noqa: E501 - # trust_remote_code=True, - # speculative_model="AngelSlim/Qwen3-8B_eagle3", # noqa: E501 - # tokenizer="Qwen/Qwen3-8B"), + speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", # noqa: E501 + tokenizer="meta-llama/Llama-3.1-8B-Instruct", + use_original_num_layers=True, + max_model_len=10240), + "LlamaForCausalLMEagle3": _HfExamplesInfo("Qwen/Qwen3-8B", # noqa: E501 + trust_remote_code=True, + speculative_model="AngelSlim/Qwen3-8B_eagle3", # noqa: E501 + tokenizer="Qwen/Qwen3-8B", + use_original_num_layers=True), "EagleLlama4ForCausalLM": _HfExamplesInfo( "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", trust_remote_code=True, @@ -587,7 +638,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { is_available_online=False), "MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", trust_remote_code=True, - speculative_model="XiaomiMiMo/MiMo-7B-RL") + speculative_model="XiaomiMiMo/MiMo-7B-RL"), + "Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct", + min_transformers_version="4.56.2"), } _TRANSFORMERS_BACKEND_MODELS = { diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index b4d516233b4bf..56b5d32d16536 100644 --- a/tests/models/test_initialization.py +++ 
b/tests/models/test_initialization.py @@ -10,7 +10,7 @@ from vllm import LLM from vllm.config import ModelImpl from vllm.engine.llm_engine import LLMEngine as V0LLMEngine from vllm.utils import GiB_bytes -from vllm.v1.core.kv_cache_utils import get_kv_cache_config +from vllm.v1.core.kv_cache_utils import get_kv_cache_configs from vllm.v1.engine.core import EngineCore as V1EngineCore from ..utils import create_new_process_for_each_test @@ -18,6 +18,26 @@ from .registry import (_TRANSFORMERS_BACKEND_MODELS, AUTO_EXAMPLE_MODELS, HF_EXAMPLE_MODELS, HfExampleModels) from .utils import dummy_hf_overrides +# This minimal list of model architectures is smaller than the total list of +# supported models. The intention is that in the "typical" regression testing +# scenario, we only test initializing these models. This subset was chosen +# to include representative examples of model varieties/workloads (conditional +# generation, sequence classification, causal LM, ranking, chat, reward model, +# multimodal, geospatial, voice, embedding, MTP) +MINIMAL_MODEL_ARCH_LIST = [ + "LlavaForConditionalGeneration", "Llama4ForConditionalGeneration", + "BertForSequenceClassification", "Gemma3nForCausalLM", "JinaVLForRanking", + "InternVLChatModel", "InternLM2ForRewardModel", + "TransformersForMultimodalLM", "PrithviGeoSpatialMAE", "UltravoxModel", + "DeepSeekMTPModel", "XLMRobertaModel" +] + +# This list is the complement of the minimal list above. The intention is that +# this list of models is only tested in a "special case" i.e. most PRs should +# not test these models +OTHER_MODEL_ARCH_LIST = (set(HF_EXAMPLE_MODELS.get_supported_archs()) - + set(MINIMAL_MODEL_ARCH_LIST)) + @create_new_process_for_each_test() def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, @@ -36,7 +56,10 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, hf_overrides_fn = partial(dummy_hf_overrides, model_arch=model_arch, - exist_overrides=model_info.hf_overrides) + exist_overrides=model_info.hf_overrides, + use_original_num_layers=getattr( + model_info, 'use_original_num_layers', + False)) # Avoid calling model.forward() def _initialize_kv_caches_v0(self) -> None: @@ -45,11 +68,11 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, def _initialize_kv_caches_v1(self, vllm_config): kv_cache_specs = self.model_executor.get_kv_cache_specs() - scheduler_kv_cache_config = get_kv_cache_config( + scheduler_kv_cache_config = get_kv_cache_configs( vllm_config, - kv_cache_specs[0], - 10 * GiB_bytes, - ) + kv_cache_specs, + [10 * GiB_bytes], + )[0] # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config return 1, 0, scheduler_kv_cache_config @@ -60,19 +83,25 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, _initialize_kv_caches_v1), monkeypatch.context() as m): if model_info.v0_only: m.setenv("VLLM_USE_V1", "0") - if model_arch == "Phi4FlashForCausalLM": - # Phi4FlashForCausalLM only supports DIFFERENTIAL_FLASH_ATTN backend + if model_arch in ("Phi4FlashForCausalLM", "MotifForCausalLM"): + # Phi4FlashForCausalLM and MotifForCausalLM + # only supports DIFFERENTIAL_FLASH_ATTN backend m.setenv("VLLM_ATTENTION_BACKEND", "DIFFERENTIAL_FLASH_ATTN") if model_arch == "GptOssForCausalLM": # FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU # has cc==8.9 which hasn't supported FA3 yet. Remove this hack when # L4 supports FA3. 
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1") + if model_arch == "WhisperForConditionalGeneration": + m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") LLM( model_info.default, tokenizer=model_info.tokenizer, tokenizer_mode=model_info.tokenizer_mode, revision=model_info.revision, + enforce_eager=model_info.enforce_eager, + skip_tokenizer_init=model_info.skip_tokenizer_init, + dtype=model_info.dtype, speculative_config={ "model": model_info.speculative_model, "num_speculative_tokens": 1, @@ -85,11 +114,26 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, model_impl=ModelImpl.TRANSFORMERS if model_arch in _TRANSFORMERS_BACKEND_MODELS else ModelImpl.VLLM, hf_overrides=hf_overrides_fn, - ) + max_num_seqs=model_info.max_num_seqs) -@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs()) -def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch): +@pytest.mark.parametrize("model_arch", MINIMAL_MODEL_ARCH_LIST) +def test_can_initialize_small_subset(model_arch: str, + monkeypatch: pytest.MonkeyPatch): + """Test initializing small subset of supported models""" + if model_arch == "Lfm2ForCausalLM": + pytest.skip("Skipping until test supports V1-only models") + can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS) + + +@pytest.mark.parametrize("model_arch", OTHER_MODEL_ARCH_LIST) +def test_can_initialize_large_subset(model_arch: str, + monkeypatch: pytest.MonkeyPatch): + """Test initializing large subset of supported models + + This test covers the complement of the tests covered in the "small subset" + test. + """ if model_arch == "Lfm2ForCausalLM": pytest.skip("Skipping until test supports V1-only models") can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS) diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index 36882aba5e941..f67d4017eeeec 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -50,7 +50,6 @@ def test_registry_imports(model_arch): @create_new_process_for_each_test() @pytest.mark.parametrize("model_arch,is_mm,init_cuda,is_ce", [ ("LlamaForCausalLM", False, False, False), - ("MllamaForConditionalGeneration", True, False, False), ("LlavaForConditionalGeneration", True, True, False), ("BertForSequenceClassification", False, False, True), ("RobertaForSequenceClassification", False, False, True), diff --git a/tests/models/test_terratorch.py b/tests/models/test_terratorch.py new file mode 100644 index 0000000000000..d6d43ca2f7e15 --- /dev/null +++ b/tests/models/test_terratorch.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from tests.conftest import VllmRunner +from vllm.utils import set_default_torch_num_threads + + +@pytest.mark.parametrize( + "model", + [ + "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", + "mgazz/Prithvi_v2_eo_300_tl_unet_agb" + ], +) +def test_inference( + vllm_runner: type[VllmRunner], + model: str, +) -> None: + + pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16) + location_coords = torch.full((1, 2), 1.0, dtype=torch.float16) + prompt = dict(prompt_token_ids=[1], + multi_modal_data=dict(pixel_values=pixel_values, + location_coords=location_coords)) + with ( + set_default_torch_num_threads(1), + vllm_runner( + model, + runner="pooling", + dtype=torch.float16, + enforce_eager=True, + skip_tokenizer_init=True, + # Limit the maximum number of sequences to avoid the + # test going OOM during 
the warmup run + max_num_seqs=32, + ) as vllm_model, + ): + + vllm_output = vllm_model.llm.encode(prompt) + assert torch.equal( + torch.isnan(vllm_output[0].outputs.data).any(), + torch.tensor(False)) diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index 66ff8f7a54d31..ba9c3bebc437e 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -8,8 +8,7 @@ import pytest from vllm.platforms import current_platform from ..conftest import HfRunner, VllmRunner -from ..core.block.e2e.test_correctness_sliding_window import prep_prompts -from ..utils import multi_gpu_test +from ..utils import multi_gpu_test, prep_prompts from .utils import check_logprobs_close diff --git a/tests/models/utils.py b/tests/models/utils.py index 0fb1f5b3753b5..76c6e4823a12c 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -294,6 +294,8 @@ def build_model_context( limit_mm_per_prompt=limit_mm_per_prompt, mm_processor_cache_gb=mm_processor_cache_gb, hf_overrides=model_info.hf_overrides, + skip_tokenizer_init=model_info.skip_tokenizer_init, + enforce_eager=model_info.enforce_eager, **model_config_kwargs, ) return InputContext(model_config) @@ -345,6 +347,7 @@ class ModelInfo: name: str architecture: str = "" dtype: str = "auto" + hf_dtype: str = "float32" hf_overrides: Optional[dict[str, Any]] = None default_pooling_type: str = "" enable_test: bool = True @@ -352,6 +355,7 @@ class ModelInfo: @dataclass class EmbedModelInfo(ModelInfo): + mteb_score: Optional[float] = None is_matryoshka: bool = False matryoshka_dimensions: Optional[list[int]] = None @@ -368,7 +372,7 @@ class LASTPoolingEmbedModelInfo(EmbedModelInfo): @dataclass class RerankModelInfo(ModelInfo): - pass + mteb_score: Optional[float] = None @dataclass @@ -381,11 +385,18 @@ class LASTPoolingRerankModelInfo(RerankModelInfo): default_pooling_type: str = "LAST" +@dataclass +class GenerateModelInfo(ModelInfo): + hf_dtype: str = "auto" + hf_ppl: Optional[float] = None + + def dummy_hf_overrides( hf_config: PretrainedConfig, *, model_arch: str = "", exist_overrides: Optional[dict[str, Any]] = None, + use_original_num_layers: bool = False, ) -> PretrainedConfig: """ Dummy HF overrides function used to create dummy model @@ -402,10 +413,18 @@ def dummy_hf_overrides( # we use three layers for Gemma-3n to check # both normal layer and kv_shared_layer - num_hidden_layers = (3 if model_arch == "Gemma3nForConditionalGeneration" - else 1) + if use_original_num_layers: + # Use the original number of layers from the config + num_layers = getattr(text_config, 'num_layers', 1) + num_hidden_layers = getattr(text_config, 'num_hidden_layers', 1) + else: + # Use minimal layers for testing + num_layers = 1 + num_hidden_layers = (3 if model_arch + == "Gemma3nForConditionalGeneration" else 1) + text_config.update({ - "num_layers": 1, + "num_layers": num_layers, "num_hidden_layers": num_hidden_layers, "num_experts": num_experts, "num_experts_per_tok": 2, diff --git a/tests/mq_llm_engine/conftest.py b/tests/mq_llm_engine/conftest.py deleted file mode 100644 index 375b248ebedaa..0000000000000 --- a/tests/mq_llm_engine/conftest.py +++ /dev/null @@ -1,12 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - Since this module is V0 only, set VLLM_USE_V1=0 for - all tests in the module. 
- """ - monkeypatch.setenv('VLLM_USE_V1', '0') diff --git a/tests/mq_llm_engine/test_abort.py b/tests/mq_llm_engine/test_abort.py deleted file mode 100644 index 5ff08cbb32487..0000000000000 --- a/tests/mq_llm_engine/test_abort.py +++ /dev/null @@ -1,69 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Test that aborting is handled properly.""" - -import asyncio -import tempfile -import uuid - -import pytest - -from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate -from vllm.engine.arg_utils import AsyncEngineArgs - -MODEL = "google/gemma-1.1-2b-it" -ENGINE_ARGS = AsyncEngineArgs(model=MODEL) -RAISED_ERROR = KeyError -RAISED_VALUE = "foo" -EXPECTED_TOKENS = 250 - - -@pytest.fixture(scope="function") -def tmp_socket(): - with tempfile.TemporaryDirectory() as td: - yield f"ipc://{td}/{uuid.uuid4()}" - - -@pytest.mark.asyncio -async def test_abort(tmp_socket): - with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, - ipc_path=tmp_socket) as engine: - - client = await engine.make_client() - - request_id_to_be_aborted = "request-aborted" - request_ids_a = [f"request-a-{idx}" for idx in range(10)] - request_ids_b = [f"request-b-{idx}" for idx in range(10)] - - # Requests started before one to be aborted. - tasks = [] - for request_id in request_ids_a: - tasks.append( - asyncio.create_task( - generate(client, request_id, EXPECTED_TOKENS))) - - # Aborted. - task_aborted = asyncio.create_task( - generate(client, request_id_to_be_aborted, EXPECTED_TOKENS)) - - # Requests started after one to be aborted. - for request_id in request_ids_b: - tasks.append( - asyncio.create_task( - generate(client, request_id, EXPECTED_TOKENS))) - - # Actually abort. - await asyncio.sleep(0.5) - await client.abort(request_id_to_be_aborted) - - # Confirm that we got all the EXPECTED tokens from the requests. - for task in tasks: - count, request_id = await task - assert count == EXPECTED_TOKENS, ( - f"{request_id} generated only {count} tokens") - - # Cancel task (this will hang indefinitely if not). - task_aborted.cancel() - - # Shutdown. 
- client.close() diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py deleted file mode 100644 index 77e3732cd06c6..0000000000000 --- a/tests/mq_llm_engine/test_error_handling.py +++ /dev/null @@ -1,376 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Test that various errors are handled properly.""" - -import asyncio -import tempfile -import time -import uuid -from unittest.mock import Mock - -import pytest - -from tests.mq_llm_engine.utils import RemoteMQLLMEngine -from vllm import SamplingParams -from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.llm_engine import LLMEngine -from vllm.engine.multiprocessing import MQEngineDeadError -from vllm.engine.multiprocessing.engine import MQLLMEngine -from vllm.entrypoints.openai.api_server import build_async_engine_client -from vllm.entrypoints.openai.cli_args import make_arg_parser -from vllm.lora.request import LoRARequest -from vllm.sequence import SequenceGroupMetadata -from vllm.usage.usage_lib import UsageContext -from vllm.utils import FlexibleArgumentParser - -MODEL = "google/gemma-1.1-2b-it" -ENGINE_ARGS = AsyncEngineArgs(model=MODEL, enforce_eager=True) -RAISED_ERROR = KeyError -RAISED_VALUE = "foo" - - -@pytest.fixture(scope="function") -def tmp_socket(): - with tempfile.TemporaryDirectory() as td: - yield f"ipc://{td}/{uuid.uuid4()}" - - -def run_with_evil_forward(engine_args: AsyncEngineArgs, ipc_path: str): - # Make engine. - engine = MQLLMEngine.from_engine_args( - engine_args=engine_args, - usage_context=UsageContext.UNKNOWN_CONTEXT, - ipc_path=ipc_path) - - # Raise error during first forward pass. - engine.engine.model_executor.execute_model = Mock( - side_effect=RAISED_ERROR(RAISED_VALUE)) - - # Run engine. - engine.start() - - -@pytest.mark.asyncio -async def test_evil_forward(tmp_socket): - with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, - ipc_path=tmp_socket, - run_fn=run_with_evil_forward) as engine: - - client = await engine.make_client() - - # Server should be healthy after initial probe. - await asyncio.sleep(2.0) - await client.check_health() - - # Throws an error that should get ENGINE_DEAD_ERROR. - with pytest.raises(MQEngineDeadError): - async for _ in client.generate(prompt="Hello my name is", - sampling_params=SamplingParams(), - request_id=str(uuid.uuid4())): - pass - assert client.errored - - await asyncio.sleep(1.0) - with pytest.raises(RAISED_ERROR): - await client.check_health() - assert client.errored - - # Shutdown. - client.close() - - -def run_with_evil_model_executor_health(engine_args: AsyncEngineArgs, - ipc_path: str): - # Make engine. - engine = MQLLMEngine.from_engine_args( - engine_args=engine_args, - usage_context=UsageContext.UNKNOWN_CONTEXT, - ipc_path=ipc_path) - - # Raise error during first forward pass. - engine.engine.model_executor.check_health = Mock(side_effect=RAISED_ERROR) - - # Run engine. - engine.start() - - -@pytest.mark.asyncio -async def test_failed_health_check(tmp_socket): - with RemoteMQLLMEngine( - engine_args=ENGINE_ARGS, - ipc_path=tmp_socket, - run_fn=run_with_evil_model_executor_health) as engine: - - client = await engine.make_client() - assert client.is_running - - # Health probe should throw RAISED_ERROR. - await asyncio.sleep(15.) 
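The error-handling tests deleted in this hunk all rely on the same fault-injection idiom: replace one method on a live object with unittest.mock.Mock(side_effect=...) so the next call raises a chosen exception, then assert that the failure propagates the way the caller expects. A self-contained sketch of the idiom, with a toy Engine standing in for the real MQLLMEngine internals:

from unittest.mock import Mock

class Engine:
    def execute_model(self) -> str:
        return "ok"

    def step(self) -> str:
        # The real code path would call into the model executor here.
        return self.execute_model()

engine = Engine()
# Inject a failure at the same seam the deleted tests patched.
engine.execute_model = Mock(side_effect=KeyError("foo"))

try:
    engine.step()
except KeyError as exc:
    assert "foo" in repr(exc)
else:
    raise AssertionError("expected the injected KeyError to propagate")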
- - with pytest.raises(RAISED_ERROR): - await client.check_health() - assert client.errored - - # Generate call should throw ENGINE_DEAD_ERROR - with pytest.raises(MQEngineDeadError): - async for _ in client.generate(prompt="Hello my name is", - sampling_params=SamplingParams(), - request_id=str(uuid.uuid4())): - pass - - client.close() - - -def run_with_evil_abort(engine_args: AsyncEngineArgs, ipc_path: str): - # Make engine. - engine = MQLLMEngine.from_engine_args( - engine_args=engine_args, - usage_context=UsageContext.UNKNOWN_CONTEXT, - ipc_path=ipc_path) - - # Raise error during abort call. - engine.engine.abort_request = Mock(side_effect=RAISED_ERROR) - - # Run engine. - engine.start() - - -@pytest.mark.asyncio -async def test_failed_abort(tmp_socket): - with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, - ipc_path=tmp_socket, - run_fn=run_with_evil_abort) as engine: - - client = await engine.make_client() - assert client.is_running - - # First check health should work. - await client.check_health() - - # Trigger an abort on the client side. - # This request ID does not exist, and will cause the engine to error - await client.abort(request_id="foo") - - # Future generation requests will now fail - # with reference to the original KeyError("foo") - with pytest.raises(MQEngineDeadError) as execinfo: - async for _ in client.generate( - prompt="Hello my name is", - sampling_params=SamplingParams(max_tokens=10), - request_id=str(uuid.uuid4())): - pass - assert "KeyError" in repr(execinfo.value) - assert client.errored - - # This should raise the original error. - with pytest.raises(RAISED_ERROR): - await client.check_health() - - client.close() - - -@pytest.mark.asyncio -async def test_batch_error(tmp_socket): - with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, - ipc_path=tmp_socket, - run_fn=run_with_evil_abort) as engine: - - client = await engine.make_client() - assert client.is_running - - # First check health should work. - await client.check_health() - - # Batch of requests - async def do_generate(client): - # min_tokens=2048 to keep busy the engine busy - # to get enough time to get process a request - # that will crash the engine - params = SamplingParams(min_tokens=2048, max_tokens=2048) - async for _ in client.generate(prompt="Hello my name is", - sampling_params=params, - request_id=str(uuid.uuid4())): - pass - - tasks = [asyncio.create_task(do_generate(client)) for _ in range(10)] - - # This request will force a processing batch to raise - # an exception and next the engine get errored - await client.abort(request_id="foo") - - # The batch of those request failed, then they - # should get the same exception as a MQEngineDeadError. - errors = await asyncio.gather(*tasks, return_exceptions=True) - for e in errors: - assert isinstance(e, MQEngineDeadError) - assert "KeyError" in repr(e) - - client.close() - - -@pytest.mark.asyncio -async def test_bad_request(tmp_socket): - with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, - ipc_path=tmp_socket) as engine: - - client = await engine.make_client() - - # Invalid request should fail, but not crash the server. - with pytest.raises(ValueError): - async for _ in client.generate(prompt="Hello my name is", - sampling_params=SamplingParams(), - request_id="abcd-1", - lora_request=LoRARequest( - "invalid-lora", 1, - "invalid-path")): - pass - - # This request should be okay. - async for _ in client.generate(prompt="Hello my name is", - sampling_params=SamplingParams(), - request_id="abcd-2"): - pass - - # Shutdown. 
- client.close() - - -@pytest.mark.asyncio -async def test_mp_crash_detection(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as m: - - parser = FlexibleArgumentParser( - description="vLLM's remote OpenAI server.") - parser = make_arg_parser(parser) - args = parser.parse_args([]) - - # When LLMEngine is loaded, it will crash. - def mock_init(): - raise ValueError - - m.setattr(LLMEngine, "__init__", mock_init) - - start = time.perf_counter() - async with build_async_engine_client(args): - pass - end = time.perf_counter() - - assert end - start < 100, ( - "Expected vLLM to gracefully shutdown in <100s " - "if there is an error in the startup.") - - -@pytest.mark.asyncio -async def test_mp_cuda_init(): - # it should not crash, when cuda is initialized - # in the API server process - import torch - torch.cuda.init() - parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") - parser = make_arg_parser(parser) - args = parser.parse_args([]) - - async with build_async_engine_client(args): - pass - - -@pytest.mark.asyncio -async def test_engine_process_death(tmp_socket): - with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, - ipc_path=tmp_socket) as engine: - - client = await engine.make_client() - assert client.is_running - - # kill the engine process - engine.proc.kill() - - # Generate call should fail - with pytest.raises(MQEngineDeadError): - async for _ in client.generate(prompt="Hello my name is", - sampling_params=SamplingParams(), - request_id=str(uuid.uuid4())): - pass - - # And the health check should show the engine is dead - with pytest.raises(RuntimeError, match="Engine process .* died"): - await client.check_health() - - client.close() - - -def run_with_evil_input_processing(engine_args: AsyncEngineArgs, - ipc_path: str): - """Simulate an exception while preparing inputs for the model. - In the wild, this could be something like a multimodal input processor - failing on invalid image data.""" - - # Make engine. - engine = MQLLMEngine.from_engine_args( - engine_args=engine_args, - usage_context=UsageContext.UNKNOWN_CONTEXT, - ipc_path=ipc_path) - - runner = engine.engine.model_executor.driver_worker.worker.model_runner - - # Raise error in the model runner when adding a sequence group. - # See class ModelInputForGPUBuilder - def raiser(_, seq_group_metadata: SequenceGroupMetadata): - if seq_group_metadata.request_id.startswith("evil"): - raise RAISED_ERROR(RAISED_VALUE) - - runner.builder.per_seq_group_compute_fns.append(raiser) - - # Run engine. 
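test_engine_process_death above boils down to killing a spawned worker process and then observing its death from the parent. A minimal sketch of that mechanic with the standard library only, using the same "spawn" context the removed RemoteMQLLMEngine helper used (the serve() loop is a placeholder for the real engine process):

import multiprocessing
import time

def serve() -> None:
    # Placeholder for the engine loop; just stays alive until killed.
    while True:
        time.sleep(0.1)

if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    proc = ctx.Process(target=serve)
    proc.start()
    assert proc.is_alive()

    proc.kill()                    # simulate an engine crash
    proc.join(timeout=5)

    assert not proc.is_alive()
    # A killed process reports a non-zero (on POSIX, negative) exit code.
    assert proc.exitcode is not None and proc.exitcode != 0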
- engine.start() - - -@pytest.mark.asyncio -async def test_failed_inputs(tmp_socket): - with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, - ipc_path=tmp_socket, - run_fn=run_with_evil_input_processing) as engine: - - client = await engine.make_client() - assert client.is_running - - # Engine should be healthy - await client.check_health() - - async def run_failing_request(): - async for _ in client.generate( - prompt="Hello my name is", - sampling_params=SamplingParams(max_tokens=10), - request_id="evil" + str(uuid.uuid4())): - pass - - async def run_passing_request(): - async for _ in client.generate( - prompt="Hello my name is", - sampling_params=SamplingParams(max_tokens=10), - request_id=str(uuid.uuid4())): - pass - - passing_tasks = [ - asyncio.create_task(run_passing_request()) for _ in range(10) - ] - failing_tasks = [ - asyncio.create_task(run_failing_request()) for _ in range(10) - ] - await asyncio.gather(*failing_tasks, return_exceptions=True) - await asyncio.gather(*passing_tasks) - - # All the bad inputs should have raised - for task in failing_tasks: - with pytest.raises(RAISED_ERROR): - task.result() - - # But all good inputs should have still succeeded - for task in passing_tasks: - task.result() - - # And the engine should remain healthy - assert not client.errored - await client.check_health() - - client.close() diff --git a/tests/mq_llm_engine/test_load.py b/tests/mq_llm_engine/test_load.py deleted file mode 100644 index c934706611ae3..0000000000000 --- a/tests/mq_llm_engine/test_load.py +++ /dev/null @@ -1,59 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Test that the MQLLMEngine is able to handle 10k concurrent requests.""" - -import asyncio -import tempfile -import uuid - -import pytest - -from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate -from vllm.engine.arg_utils import AsyncEngineArgs - -MODEL = "google/gemma-1.1-2b-it" -NUM_EXPECTED_TOKENS = 10 -NUM_REQUESTS = 10000 - -# Scenarios to test for num generated token. -ENGINE_ARGS = AsyncEngineArgs(model=MODEL) - - -@pytest.fixture(scope="function") -def tmp_socket(): - with tempfile.TemporaryDirectory() as td: - yield f"ipc://{td}/{uuid.uuid4()}" - - -@pytest.mark.asyncio -async def test_load(tmp_socket): - with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, - ipc_path=tmp_socket) as engine: - - client = await engine.make_client() - - request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)] - - # Create concurrent requests. - tasks = [] - for request_id in request_ids: - tasks.append( - asyncio.create_task( - generate(client, request_id, NUM_EXPECTED_TOKENS))) - - # Confirm that we got all the EXPECTED tokens from the requests. - failed_request_id = None - tokens = None - for task in tasks: - num_generated_tokens, request_id = await task - if (num_generated_tokens != NUM_EXPECTED_TOKENS - and failed_request_id is None): - failed_request_id = request_id - tokens = num_generated_tokens - - assert failed_request_id is None, ( - f"{failed_request_id} generated {tokens} but " - f"expected {NUM_EXPECTED_TOKENS}") - - # Shutdown. 
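test_failed_inputs above separates "evil" requests from healthy ones by gathering the failing batch with return_exceptions=True, so one bad input cannot abort the whole gather. A small standalone sketch of that technique (the request() coroutine is a stand-in for client.generate()):

import asyncio

async def request(request_id: str) -> str:
    await asyncio.sleep(0)
    if request_id.startswith("evil"):
        raise KeyError(request_id)
    return request_id

async def main() -> None:
    passing = [asyncio.create_task(request(f"good-{i}")) for i in range(3)]
    failing = [asyncio.create_task(request(f"evil-{i}")) for i in range(3)]

    # return_exceptions=True collects failures instead of raising on the first one.
    results = await asyncio.gather(*failing, return_exceptions=True)
    assert all(isinstance(r, KeyError) for r in results)

    # Healthy requests are unaffected and resolve normally.
    assert await asyncio.gather(*passing) == ["good-0", "good-1", "good-2"]

asyncio.run(main())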
- client.close() diff --git a/tests/mq_llm_engine/utils.py b/tests/mq_llm_engine/utils.py deleted file mode 100644 index 7976d5031aea1..0000000000000 --- a/tests/mq_llm_engine/utils.py +++ /dev/null @@ -1,81 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import asyncio -import multiprocessing -from typing import Callable, Union - -from vllm import SamplingParams -from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.multiprocessing.client import MQLLMEngineClient -from vllm.engine.multiprocessing.engine import MQLLMEngine -from vllm.outputs import RequestOutput -from vllm.usage.usage_lib import UsageContext - - -async def generate( - client: MQLLMEngineClient, - request_id: str, - num_tokens: int, - return_output: bool = False) -> Union[RequestOutput, tuple[int, str]]: - - final_output = None - count = 0 - async for out in client.generate( - request_id=request_id, - prompt="Hello my name is Robert and", - sampling_params=SamplingParams(max_tokens=num_tokens, - temperature=0)): - - count += 1 - final_output = out - await asyncio.sleep(0.) - - if return_output: - return final_output - - # Confirm we generated all the tokens we expected. - return count, request_id - - -def run_normal(engine_args: AsyncEngineArgs, ipc_path: str): - # Make engine. - engine = MQLLMEngine.from_engine_args( - engine_args=engine_args, - usage_context=UsageContext.UNKNOWN_CONTEXT, - ipc_path=ipc_path) - - # Run engine. - engine.start() - - -class RemoteMQLLMEngine: - - def __init__(self, - engine_args: AsyncEngineArgs, - ipc_path: str, - run_fn: Callable = run_normal) -> None: - - self.engine_args = engine_args - self.ipc_path = ipc_path - context = multiprocessing.get_context("spawn") - self.proc = context.Process(target=run_fn, - args=(engine_args, ipc_path)) - self.proc.start() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.proc.kill() - - async def make_client(self) -> MQLLMEngineClient: - engine_config = self.engine_args.create_engine_config() - client = MQLLMEngineClient(self.ipc_path, engine_config, self.proc.pid) - while True: - try: - await client.setup() - break - except TimeoutError: - assert self.proc.is_alive() - return client diff --git a/tests/multimodal/test_cache.py b/tests/multimodal/test_cache.py index 44c05db2278f7..3c737acfbfe28 100644 --- a/tests/multimodal/test_cache.py +++ b/tests/multimodal/test_cache.py @@ -7,17 +7,17 @@ import pytest import torch from vllm.config import ModelConfig, ParallelConfig, VllmConfig +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import (MultiModalCache, MultiModalProcessorCacheItem, MultiModalProcessorCacheItemMetadata, - processor_cache_from_config, - receiver_cache_from_config) + engine_receiver_cache_from_config, + processor_cache_from_config) from vllm.multimodal.hasher import MultiModalHasher from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem, MultiModalKwargsItems, MultiModalSharedField) from vllm.multimodal.processing import PromptInsertion -from vllm.multimodal.registry import MultiModalRegistry def _dummy_elem( @@ -96,7 +96,9 @@ def _create_vllm_config( enable_ipc: bool, ): return VllmConfig( - model_config=ModelConfig(mm_processor_cache_gb=mm_processor_cache_gb), + model_config=ModelConfig( + model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf", + mm_processor_cache_gb=mm_processor_cache_gb), parallel_config=ParallelConfig( data_parallel_size=1 if enable_ipc 
else 2), ) @@ -113,15 +115,16 @@ def _compare_caches( n_iter: int = 100, seed: int = 0, ): - mm_registry = MultiModalRegistry() - cache_0_p0 = processor_cache_from_config(config_0, mm_registry) - cache_0_p1 = receiver_cache_from_config(config_0, mm_registry) - cache_1_p0 = processor_cache_from_config(config_1, mm_registry) - cache_1_p1 = receiver_cache_from_config(config_1, mm_registry) + cache_0_p0 = processor_cache_from_config(config_0, MULTIMODAL_REGISTRY) + cache_0_p1 = engine_receiver_cache_from_config(config_0, + MULTIMODAL_REGISTRY) + cache_1_p0 = processor_cache_from_config(config_1, MULTIMODAL_REGISTRY) + cache_1_p1 = engine_receiver_cache_from_config(config_1, + MULTIMODAL_REGISTRY) cache_size_gb = max( - config_0.model_config.mm_processor_cache_gb, - config_1.model_config.mm_processor_cache_gb, + config_0.model_config.multimodal_config.mm_processor_cache_gb, + config_1.model_config.multimodal_config.mm_processor_cache_gb, ) item_size_gb = int(cache_size_gb / item_capacity) diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index 05e68a961a548..e1e8282dd66d4 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -31,11 +31,11 @@ if TYPE_CHECKING: from vllm.multimodal.inputs import MultiModalPlaceholderDict # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) -TEST_IMAGE_URLS = [ - "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", - "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", - "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", - "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", +TEST_IMAGE_ASSETS = [ + "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + "Grayscale_8bits_palette_sample_image.png", # "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", + "1280px-Venn_diagram_rgb.svg.png", # "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", + "RGBA_comp.png", # "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", ] TEST_VIDEO_URLS = [ @@ -45,12 +45,11 @@ TEST_VIDEO_URLS = [ @pytest.fixture(scope="module") -def url_images() -> dict[str, Image.Image]: - connector = MediaConnector() +def url_images(local_asset_server) -> dict[str, Image.Image]: return { - image_url: connector.fetch_image(image_url) - for image_url in TEST_IMAGE_URLS + image_url: local_asset_server.get_image_asset(image_url) + for image_url in TEST_IMAGE_ASSETS } @@ -69,7 +68,7 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool: @pytest.mark.asyncio -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) async def test_fetch_image_http(image_url: str): connector = MediaConnector() @@ -79,12 +78,12 @@ async def test_fetch_image_http(image_url: str): @pytest.mark.asyncio -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("raw_image_url", TEST_IMAGE_ASSETS) @pytest.mark.parametrize("suffix", get_supported_suffixes()) async def test_fetch_image_base64(url_images: dict[str, Image.Image], - image_url: str, 
suffix: str): + raw_image_url: str, suffix: str): connector = MediaConnector() - url_image = url_images[image_url] + url_image = url_images[raw_image_url] try: mime_type = Image.MIME[Image.registered_extensions()[suffix]] @@ -117,7 +116,7 @@ async def test_fetch_image_base64(url_images: dict[str, Image.Image], @pytest.mark.asyncio -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) async def test_fetch_image_local_files(image_url: str): connector = MediaConnector() @@ -152,8 +151,8 @@ async def test_fetch_image_local_files(image_url: str): @pytest.mark.asyncio -async def test_fetch_image_local_files_with_space_in_name(): - image_url = TEST_IMAGE_URLS[0] +@pytest.mark.parametrize("image_url", [TEST_IMAGE_ASSETS[0]], indirect=True) +async def test_fetch_image_local_files_with_space_in_name(image_url: str): connector = MediaConnector() with TemporaryDirectory() as temp_dir: @@ -205,6 +204,32 @@ async def test_fetch_video_http(video_url: str, num_frames: int): assert metadata_sync == metadata_async +@pytest.mark.asyncio +@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) +@pytest.mark.parametrize("max_duration", [1, 60, 1800]) +@pytest.mark.parametrize("requested_fps", [2, 24]) +async def test_fetch_video_http_with_dynamic_loader( + video_url: str, max_duration: int, requested_fps: int, + monkeypatch: pytest.MonkeyPatch): + with monkeypatch.context() as m: + m.setenv("VLLM_VIDEO_LOADER_BACKEND", "opencv_dynamic") + connector = MediaConnector( + media_io_kwargs={ + "video": { + "max_duration": max_duration, + "requested_fps": requested_fps, + } + }) + + video_sync, metadata_sync = connector.fetch_video(video_url) + video_async, metadata_async = await connector.fetch_video_async( + video_url) + + assert np.array_equal(video_sync, video_async) + assert metadata_sync == metadata_async + assert metadata_sync["video_backend"] == "opencv_dynamic" + + # Used for `test_argsort_mm_positions`. 
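The switch from raw URLs to TEST_IMAGE_ASSETS with indirect=True routes each parametrized value through a fixture of the same name, which can resolve the asset against the local asset server before the test body runs. A minimal sketch of how indirect parametrization wires up; the local URL prefix below is a placeholder, not the actual conftest fixture:

import pytest

TEST_IMAGE_ASSETS = ["RGBA_comp.png", "Grayscale_8bits_palette_sample_image.png"]

@pytest.fixture
def image_url(request) -> str:
    # With indirect=True, each parametrized value arrives here as request.param.
    # The real fixture would ask the local asset server for a servable URL;
    # a fixed prefix stands in for that here.
    return f"http://localhost:8000/assets/{request.param}"

@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True)
def test_fetch_image(image_url: str) -> None:
    assert image_url.endswith(".png")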
class TestCase(NamedTuple): mm_positions: "MultiModalPlaceholderDict" @@ -458,7 +483,7 @@ def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int, with torch.inference_mode(): sharded_output = run_dp_sharded_vision_model(image_input, vision_model) - # Check that the world size is setup correctly + # Check that the world size is set up correctly assert get_tensor_model_parallel_world_size() == world_size # Check that the outputs have the same shape @@ -642,7 +667,7 @@ def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int, rope_type="rope_3d") sharded_output = torch.cat(sharded_output, dim=0) - # Check that the world size is setup correctly + # Check that the world size is set up correctly assert get_tensor_model_parallel_world_size() == world_size # Compare outputs (only on rank 0) diff --git a/tests/neuron/1_core/test_activation.py b/tests/neuron/1_core/test_activation.py deleted file mode 100644 index 2d6e5f523cb85..0000000000000 --- a/tests/neuron/1_core/test_activation.py +++ /dev/null @@ -1,43 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -import torch -import torch.nn.functional as F - -from vllm.model_executor.layers.activation import FastGELU, SiluAndMul -from vllm.platforms import current_platform - - -@pytest.mark.parametrize("activation", ["silu_and_mul", "gelu_fast"]) -@pytest.mark.parametrize("num_tokens,d,dtype", [ - (7, 512, torch.half), - (7, 512, torch.float), - (83, 512, torch.half), -]) -@torch.inference_mode() -def test_act_and_mul( - activation: str, - num_tokens: int, - d: int, - dtype: torch.dtype, -) -> None: - import torch_xla.core.xla_model as xm - - device = xm.xla_device() - current_platform.seed_everything(0) - torch.set_default_device("cpu") - x = torch.randn(num_tokens, 2 * d, dtype=dtype).to(device=device) - if activation == "silu_and_mul": - layer = SiluAndMul() - fn = layer.forward_native - elif activation == "gelu_fast": - layer = FastGELU() - fn = F.gelu - else: - raise NotImplementedError( - f"activation {activation} is not implemented.") - assert x.is_xla, "input tensor under testing is expected to be XLA tensor." 
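The deleted Neuron activation test compared the device kernel against SiluAndMul's native path. For reference, that reference computation is simply SiLU on the first half of the packed gate/up projection multiplied by the second half; a CPU-only sketch (assuming the usual last-dimension packing):

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # Gate and up projections are concatenated along the last dimension.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

x = torch.randn(7, 2 * 512)
out = silu_and_mul_ref(x)
assert out.shape == (7, 512)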
- out = layer.to(device=device).forward_neuron(x) - ref_out = fn(x.cpu()) - torch.testing.assert_close(out.cpu(), ref_out, atol=0.01, rtol=0.0) diff --git a/tests/neuron/1_core/test_block_table.py b/tests/neuron/1_core/test_block_table.py deleted file mode 100644 index efec56360c142..0000000000000 --- a/tests/neuron/1_core/test_block_table.py +++ /dev/null @@ -1,154 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import neuronxcc.nki.language as nl -import pytest -import torch -import torch.nn.functional as F -from neuronxcc import nki - -from vllm.attention.ops.nki_flash_attn import ( - load_block_tables, transform_block_tables_for_indirect_load) - - -def is_power_of_2(n): - return n > 0 and (n & (n - 1) == 0) - - -def nki_load_and_transform_block_tables( - block_tables, - num_tiles, - num_blocks_per_tile, - num_head, - head_id, - block_size_tiling_factor, -): - assert is_power_of_2( - num_blocks_per_tile), f"{num_blocks_per_tile=} must be power of 2" - block_tables_sbuf = load_block_tables(block_tables, num_tiles, - num_blocks_per_tile) - - # we need to pass an Index as head_id - head_id = nl.arange(1)[None, :] + head_id - - block_tables_transposed = transform_block_tables_for_indirect_load( - block_tables_sbuf, block_size_tiling_factor, num_head, head_id) - B_P_SIZE = 128 - assert block_tables_transposed.shape[1] == B_P_SIZE - - out = nl.ndarray( - block_tables_transposed.shape, - dtype=nl.int32, - buffer=nl.shared_hbm, - ) - for i in nl.affine_range(block_tables_transposed.shape[0]): - nl.store(dst=out[i], value=block_tables_transposed[i]) - return out - - -def ref_block_tables_transform( - block_tables, - num_tiles, - num_blocks_per_tile, - num_head, - head_id, - block_size_tiling_factor, -): - assert block_tables.numel() == num_tiles * num_blocks_per_tile - block_tables = block_tables.view(num_tiles, num_blocks_per_tile) - B_F_SIZE = 128 - num_tiles_padded = (num_tiles + B_F_SIZE - 1) // B_F_SIZE * B_F_SIZE - block_tables = F.pad( - block_tables, - (0, 0, 0, num_tiles_padded - num_tiles), - "constant", - 0, - ) - - block_tables = block_tables * num_head + head_id - block_tables = block_tables.view(num_tiles_padded, num_blocks_per_tile, 1) - offset = torch.arange(0, block_size_tiling_factor).view(1, 1, -1) - block_tables = block_tables * block_size_tiling_factor + offset - block_tables_transposed = block_tables.view(num_tiles_padded, -1).t() - - num_blocks_per_tile = block_tables_transposed.shape[0] - assert num_blocks_per_tile % B_F_SIZE == 0 - return block_tables_transposed.view(num_blocks_per_tile // B_F_SIZE, - B_F_SIZE, num_tiles_padded) - - -@pytest.mark.parametrize( - "q_head_per_kv_head,head_id", - [ - (1, 0), - (3, 1), - ], -) -@pytest.mark.parametrize( - "num_tiles,num_blocks_per_tile", - [ - (1, 1), - (13, 16), - (17, 128), - (35, 512), - (128, 128), - (130, 64), - (280, 256), - (315, 1), - ], -) -@torch.inference_mode() -def test_load_and_transform_block_tables( - monkeypatch: pytest.MonkeyPatch, - num_tiles, - num_blocks_per_tile, - q_head_per_kv_head, - head_id, -) -> None: - import torch_xla.core.xla_model as xm - - device = xm.xla_device() - - compiler_flags_str = " ".join([ - "-O1", - "--retry_failed_compilation", - ]) - with monkeypatch.context() as m: - m.setenv("NEURON_CC_FLAGS", compiler_flags_str) - - torch.manual_seed(10000) - torch.set_printoptions(sci_mode=False) - - # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient - B_P_SIZE = 128 - if num_blocks_per_tile < B_P_SIZE: - 
assert B_P_SIZE % num_blocks_per_tile == 0 - block_size_tiling_factor = B_P_SIZE // num_blocks_per_tile - else: - block_size_tiling_factor = 1 - max_num_blocks = 100000 - block_tables = torch.randint( - 0, - max_num_blocks, - (num_tiles * num_blocks_per_tile, ), - dtype=torch.int32, - ) - nki_out = nki.jit(nki_load_and_transform_block_tables)[1, 1]( - block_tables.to(device=device), - num_tiles, - num_blocks_per_tile, - q_head_per_kv_head, - head_id, - block_size_tiling_factor, - ).cpu() - ref_out = ref_block_tables_transform( - block_tables, - num_tiles, - num_blocks_per_tile, - q_head_per_kv_head, - head_id, - block_size_tiling_factor, - ) - assert (nki_out.shape == ref_out.shape - ), f"{nki_out.shape=} != {ref_out.shape=}" - assert torch.all(nki_out == ref_out) diff --git a/tests/neuron/1_core/test_cache.py b/tests/neuron/1_core/test_cache.py deleted file mode 100644 index 670889ad6b58d..0000000000000 --- a/tests/neuron/1_core/test_cache.py +++ /dev/null @@ -1,86 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -import torch - -from vllm.attention.ops.nki_flash_attn import reshape_and_cache - - -@pytest.mark.parametrize( - "num_tokens, n_kv_head, d_head, num_blocks, block_size", - [ - # Small model configuration (e.g., GPT-2 small) - (32, 12, 64, 4, 128), # Typical sequence processing - (1, 12, 64, 4, 128), # Single token update - (128, 12, 64, 4, 128), # Longer sequence - - # Medium model configuration (e.g., GPT-2 medium) - (64, 16, 96, 8, 256), # Standard batch - (256, 16, 96, 8, 256), # Large batch - - # Large model configuration (e.g., GPT-3 style) - (48, 32, 128, 16, 512), # Typical processing window - (512, 32, 128, 16, 512), # Full context window - - # Edge cases and stress tests - (1024, 8, 32, 32, 32), # Many tokens, small heads - (16, 64, 256, 4, 64), # Few tokens, many heads - (2048, 24, 128, 64, 128), # Large scale test - - # Minimal configurations for debugging - (4, 2, 16, 2, 16), # Tiny test case - (1, 1, 8, 1, 8), # Minimal possible - ]) -def test_reshape_and_cache(num_tokens, n_kv_head, d_head, num_blocks, - block_size): - # Set random seed for reproducibility - torch.manual_seed(42) - - # Create CPU tensors for reference implementation - key_cpu = torch.randn(num_tokens, n_kv_head, d_head) / torch.sqrt( - torch.tensor(d_head)) - value_cpu = torch.randn(num_tokens, n_kv_head, d_head) / torch.sqrt( - torch.tensor(d_head)) - key_cache_cpu = torch.zeros(num_blocks, n_kv_head, block_size, d_head) - value_cache_cpu = torch.zeros(num_blocks, n_kv_head, block_size, d_head) - slot_mapping_cpu = torch.randperm(num_blocks * block_size)[:num_tokens] - - # Run reference implementation on CPU - block_indices = torch.div(slot_mapping_cpu, - block_size, - rounding_mode="floor") - block_offsets = slot_mapping_cpu % block_size - - for i in range(num_tokens): - block_idx = block_indices[i] - block_offset = block_offsets[i] - key_cache_cpu[block_idx, :, block_offset, :] = key_cpu[i] - value_cache_cpu[block_idx, :, block_offset, :] = value_cpu[i] - - # Create XLA device tensors - device = torch.device('xla') - key = key_cpu.to(device) - value = value_cpu.to(device) - key_cache = torch.zeros_like(key_cache_cpu, device=device) - value_cache = torch.zeros_like(value_cache_cpu, device=device) - slot_mapping = slot_mapping_cpu.to(device) - kv_cache = torch.stack([key_cache, value_cache]) - - # Run vectorized implementation on XLA device - reshape_and_cache(key, value, kv_cache, slot_mapping) - key_cache, 
value_cache = torch.unbind(kv_cache, dim=0) - - # Move results back to CPU for comparison - key_cache_result = key_cache.cpu() - value_cache_result = value_cache.cpu() - - # Assert results match - torch.testing.assert_close(key_cache_result, - key_cache_cpu, - rtol=1e-5, - atol=1e-5) - torch.testing.assert_close(value_cache_result, - value_cache_cpu, - rtol=1e-5, - atol=1e-5) diff --git a/tests/neuron/1_core/test_layernorm.py b/tests/neuron/1_core/test_layernorm.py deleted file mode 100644 index c6fce1d1a0630..0000000000000 --- a/tests/neuron/1_core/test_layernorm.py +++ /dev/null @@ -1,57 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -import torch - -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.platforms import current_platform - - -@pytest.mark.parametrize("num_tokens,hidden_size,add_residual,dtype", [ - (7, 8, False, torch.half), - (83, 768, False, torch.half), - (83, 768, True, torch.half), - (83, 768, True, torch.bfloat16), - (83, 768, True, torch.float32), -]) -@torch.inference_mode() -def test_rms_norm( - num_tokens: int, - hidden_size: int, - add_residual: bool, - dtype: torch.dtype, -) -> None: - import torch_xla.core.xla_model as xm - - device = xm.xla_device() - current_platform.seed_everything(0) - torch.set_default_device("cpu") - layer = RMSNorm(hidden_size).to(dtype=dtype) - layer.weight.data.normal_(mean=1.0, std=0.1) - scale = 1 / (2 * hidden_size) - x = torch.randn(num_tokens, hidden_size, dtype=dtype).to(device=device) - x *= scale - residual = torch.randn_like(x) * scale if add_residual else None - - residual_cpu = residual.cpu() if add_residual else None - ref_out = layer.to(device="cpu").forward_native(x.cpu(), residual_cpu) - assert x.is_xla, "input tensor under testing is expected to be XLA tensor." - out = layer.to(device=device)(x, residual) - - # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger - # numerical errors than other operators because they involve reductions. - # Therefore, we use a larger tolerance. 
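The deleted Neuron layernorm test checks the device RMSNorm against the native implementation with a loose tolerance. The underlying reference is the textbook RMSNorm: normalize by the root-mean-square over the hidden dimension, then scale by the learned weight. A sketch of that reference (eps and the exact dtype handling may differ slightly from RMSNorm.forward_native):

import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor,
                 eps: float = 1e-6) -> torch.Tensor:
    # Compute the statistics in float32 for stability, cast back at the end.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    x_normed = x.float() * torch.rsqrt(variance + eps)
    return (x_normed * weight.float()).to(x.dtype)

x = torch.randn(83, 768, dtype=torch.half)
weight = torch.ones(768, dtype=torch.half)
out = rms_norm_ref(x, weight)
assert out.shape == x.shape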
- if add_residual: - assert out[0].is_xla, "output tensor is expected to be XLA tensor" - torch.testing.assert_close(out[0].cpu(), - ref_out[0], - atol=1e-2, - rtol=1e-2) - torch.testing.assert_close(out[1].cpu(), - ref_out[1], - atol=1e-2, - rtol=1e-2) - else: - assert out.is_xla, "output tensor is expected to be XLA tensor" - torch.testing.assert_close(out.cpu(), ref_out, atol=1e-2, rtol=1e-2) diff --git a/tests/neuron/1_core/test_logits_processor.py b/tests/neuron/1_core/test_logits_processor.py deleted file mode 100644 index ce9eadf5a883e..0000000000000 --- a/tests/neuron/1_core/test_logits_processor.py +++ /dev/null @@ -1,95 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import random -from unittest.mock import patch - -import pytest -import torch - -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.utils import set_random_seed -from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.utils import is_pin_memory_available - - -class MockLogitsProcessor(LogitsProcessor): - - def __init__(self, vocab_size: int, scale: float, - fake_logits: torch.Tensor): - super().__init__(vocab_size=vocab_size, scale=scale) - self.fake_logits = fake_logits.clone() - - def forward(self, *args, **kwargs): - with patch( - "vllm.model_executor.layers.logits_processor._prune_hidden_states", - lambda x, y: x - ), patch( - "vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits", - lambda *args, **kwargs: self.fake_logits): - return super().forward(*args, **kwargs) - - -def _prepare_test( - batch_size: int -) -> tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]: - vocab_size = 32000 - input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) - fake_logits = torch.full((batch_size, vocab_size), - 1e-2, - dtype=input_tensor.dtype) - logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits) - return input_tensor, fake_logits, logits_processor - - -RANDOM_SEEDS = list(range(8)) - - -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -def test_logits_processors(seed: int): - import torch_xla.core.xla_model as xm - - device = xm.xla_device() - set_random_seed(seed) - torch.set_default_device("cpu") - batch_size = random.randint(1, 256) - input_tensor, fake_logits, logits_processor = _prepare_test(batch_size) - - # This sample logits processor gives infinite score to the i-th token, - # where i is the length of the input sequence. - # We therefore expect the output token sequence to be [0, 1, 2, ...] 
- def pick_ith(token_ids, logits): - logits[len(token_ids)] = float("inf") - return logits - - seq_group_metadata_list = [] - seq_lens = [] - for i in range(batch_size): - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: SequenceData.from_seqs([1, 2, 3])}, - sampling_params=SamplingParams(temperature=0, - logits_processors=[pick_ith]), - block_tables={0: [1]}, - )) - seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - query_lens=seq_lens, - device=device, - pin_memory=is_pin_memory_available()) - logits_processor_output = logits_processor( - lm_head=None, - hidden_states=input_tensor, - sampling_metadata=sampling_metadata) - - fake_logits *= logits_processor.scale - torch.testing.assert_close(logits_processor_output[:, 1], - fake_logits[:, 1], - rtol=1e-4, - atol=0.0) diff --git a/tests/neuron/1_core/test_neuron_model_runner.py b/tests/neuron/1_core/test_neuron_model_runner.py deleted file mode 100644 index 5f3268810f9fe..0000000000000 --- a/tests/neuron/1_core/test_neuron_model_runner.py +++ /dev/null @@ -1,127 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os -from unittest.mock import MagicMock - -from vllm.config import VllmConfig -from vllm.engine.arg_utils import EngineArgs -from vllm.platforms import current_platform -from vllm.platforms.neuron import NeuronFramework -from vllm.sampling_params import SamplingParams -from vllm.sequence import SequenceData, SequenceGroupMetadata -from vllm.worker.neuron_model_runner import NeuronModelRunner - -os.environ[ - 'VLLM_NEURON_FRAMEWORK'] = NeuronFramework.TRANSFORMERS_NEURONX.value - - -def _create_neuron_model_runner(model: str, *args, - **kwargs) -> NeuronModelRunner: - engine_args = EngineArgs(model, *args, **kwargs) - engine_config = engine_args.create_engine_config() - vllm_config = VllmConfig( - model_config=engine_config.model_config, - parallel_config=engine_config.parallel_config, - scheduler_config=engine_config.scheduler_config, - device_config=engine_config.device_config, - ) - neuron_model_runner = NeuronModelRunner(vllm_config=vllm_config) - return neuron_model_runner - - -def test_update_neuron_sampling_params_not_full_batch(): - os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0" - model_runner = _create_neuron_model_runner( - "facebook/opt-125m", - seed=0, - dtype="float16", - max_num_seqs=2, - ) - assert not model_runner._on_device_sampling_disabled - # Test sampling param updating only when TNx is framework - # NxDI handles sampling parameter updating inside model - if current_platform.use_transformers_neuronx(): - model_mock = MagicMock() - model_runner.model = model_mock - - seq_group_metadata_list = [ - SequenceGroupMetadata( - request_id="test_0", - is_prompt=True, - seq_data={0: SequenceData.from_seqs([1, 2, 3])}, - sampling_params=SamplingParams(temperature=0.5, - top_k=1, - top_p=0.5), - block_tables={0: [1]}, - ) - ] - - model_runner.prepare_model_input(seq_group_metadata_list) - - # Index neuron sampling parameters based on block_tables indices. - # The first block_id of the sequence 0 is 1, so its parameters are - # placed at index 1. So the sampling parameters will be: - # Index 0: default sampling parameters - # Index 1: sequecne 0's sampling parameters. 
- neuron_sampling_params = ( - model_runner.model_config.neuron_sampling_params) - assert neuron_sampling_params.temperature == [1.0, 0.5] - assert neuron_sampling_params.top_k == [ - model_runner._MAX_NEURON_SAMPLING_TOP_K, 1 - ] - assert neuron_sampling_params.top_p == [1.0, 0.5] - model_mock.model.update_generation_config.assert_called_once_with( - neuron_sampling_params) - - -def test_update_neuron_sampling_params_full_batch(): - os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0" - model_runner = _create_neuron_model_runner( - "facebook/opt-125m", - seed=0, - dtype="float16", - max_num_seqs=2, - ) - assert not model_runner._on_device_sampling_disabled - - # Test sampling param updating only when TNx is framework - # NxDI handles sampling parameter updating inside model - if current_platform.use_transformers_neuronx(): - model_mock = MagicMock() - model_runner.model = model_mock - - seq_group_metadata_list = [ - SequenceGroupMetadata( - request_id="test_0", - is_prompt=True, - seq_data={0: SequenceData.from_seqs([1, 2, 3])}, - sampling_params=SamplingParams(temperature=0.5, - top_k=1, - top_p=0.5), - block_tables={0: [1]}, - ), - SequenceGroupMetadata( - request_id="test_0", - is_prompt=True, - seq_data={1: SequenceData.from_seqs([4, 5, 6])}, - sampling_params=SamplingParams(temperature=0.2, - top_k=2, - top_p=0.2), - block_tables={1: [0]}, - ) - ] - - model_runner.prepare_model_input(seq_group_metadata_list) - - # Index neuron sampling parameters based on block_tables indices. - # The first block_id of the sequence 0 is 1, so its parameters are - # placed at index 1. So the sampling parameters will be: - # Index 0: sequence 1's sampling parameters - # Index 1: sequecne 0's sampling parameters. - neuron_sampling_params = ( - model_runner.model_config.neuron_sampling_params) - assert neuron_sampling_params.temperature == [0.2, 0.5] - assert neuron_sampling_params.top_k == [2, 1] - assert neuron_sampling_params.top_p == [0.2, 0.5] - model_mock.model.update_generation_config.assert_called_once_with( - neuron_sampling_params) diff --git a/tests/neuron/1_core/test_neuron_quant.py b/tests/neuron/1_core/test_neuron_quant.py deleted file mode 100644 index 0863002695928..0000000000000 --- a/tests/neuron/1_core/test_neuron_quant.py +++ /dev/null @@ -1,12 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.model_executor.layers.quantization.neuron_quant import ( - NeuronQuantConfig) - - -def test_get_supported_act_dtypes(): - neuron_quant_config = NeuronQuantConfig() - supported_act_dtypes = neuron_quant_config.get_supported_act_dtypes() - target_list = ["any_dtype1", "any_dtype2"] - for dtype in target_list: - assert dtype in supported_act_dtypes diff --git a/tests/neuron/1_core/test_prefix_prefill.py b/tests/neuron/1_core/test_prefix_prefill.py deleted file mode 100644 index abf7febc2955c..0000000000000 --- a/tests/neuron/1_core/test_prefix_prefill.py +++ /dev/null @@ -1,514 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Optional - -import pytest -import torch -import torch.nn.functional as F - -from vllm.utils import cdiv - - -class BlockDiagonalCausalFromBottomRightMask: - - @staticmethod - def _from_seqlens(query_lens, seq_lens, block_size=None): - from torch import logical_and, logical_or - - contexted = block_size is None - context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens) - n_queries = sum(query_lens) - 
num_seqs = len(query_lens) - if contexted: - key_lens_blockaligned = seq_lens - else: - n_blocks_per_seq = (context_lens + block_size - 1) // block_size - offset_per_seq = n_blocks_per_seq * block_size - key_lens_blockaligned = offset_per_seq[:num_seqs].tolist() - n_keys = sum(key_lens_blockaligned) - - a = (torch.arange(n_queries).reshape(n_queries, - 1).expand(n_queries, n_keys)) - b = torch.arange(n_keys).reshape(1, n_keys).expand(n_queries, n_keys) - q_cumsum = torch.tensor([0] + query_lens).cumsum(dim=0) - k_cumsum = torch.tensor([0] + key_lens_blockaligned).cumsum(dim=0) - - prior_mask = torch.zeros(n_queries, n_keys) - new_masks: list[torch.Tensor] = [] - for seq_id in range(num_seqs): - ri = q_cumsum[seq_id] - ci = k_cumsum[seq_id] - nr = query_lens[seq_id] - - if contexted: - nc = seq_lens[seq_id] - a_offset = ci + nc - ri - nr - new_mask = (a + a_offset) >= b - else: - nc = context_lens[seq_id] - a_offset = ci + nc - 1 - new_mask = a_offset >= b - - left_mask = b >= ci - top_mask = a >= ri - bottom_mask = a < (ri + nr) - - new_mask = logical_and( - logical_and(logical_and(new_mask, left_mask), top_mask), - bottom_mask, - ) - prior_mask = logical_or(prior_mask, new_mask) - new_masks = new_masks + [new_mask] - return prior_mask - - @staticmethod - def from_seqlens(query_lens, seq_lens, block_size=None): - contexted = block_size is None - if contexted: - prior_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens( - query_lens, seq_lens) - active_mask = None - else: - prior_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens( - query_lens, seq_lens, block_size) - active_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens( - query_lens, query_lens) - return prior_mask, active_mask - - -def ref_softmax(x: torch.Tensor, - dim: int, - mixed_precision=False, - return_max_reduce=False): - max_value = torch.amax(x, dim=dim, keepdims=True) - exp = torch.exp(x - max_value) - if mixed_precision: - sum_value = torch.sum(exp.astype(torch.float32), - dim=dim, - keepdims=True).astype(x.dtype) - else: - sum_value = torch.sum(exp, dim=dim, keepdims=True) - if return_max_reduce: - return exp / sum_value, max_value, torch.reciprocal(sum_value) - return exp / sum_value - - -def ref_masked_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: float, - attn_mask: Optional[torch.Tensor] = None, - return_max_reduce: Optional[bool] = False, -) -> torch.Tensor: - scaled_qk = scale * torch.einsum("qhd,khd->hqk", query, key).float() - if attn_mask is not None: - masked_score = scaled_qk + attn_mask.float() - if return_max_reduce: - norm_score, cached_max, cached_sum_reciprocal = ref_softmax( - masked_score, dim=-1, return_max_reduce=True) - else: - norm_score = ref_softmax(masked_score, dim=-1) - out = torch.einsum("hqk,khd->qhd", norm_score.to(value.dtype), value) - if return_max_reduce: - return ( - out, - cached_max, - cached_sum_reciprocal, - norm_score, - masked_score, - scaled_qk, - ) - else: - return (out, ) - - -def ref_context_attention( - query, - key, - value, - query_lens, - seq_lens, - head_size, - num_queries_per_kv, - return_max_reduce=False, -): - scale = float(1.0 / (head_size**0.5)) - if num_queries_per_kv > 1: - # Handle MQA and GQA - key = torch.repeat_interleave(key, num_queries_per_kv, dim=1) - value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) - - attn_mask, _ = BlockDiagonalCausalFromBottomRightMask.from_seqlens( - query_lens, seq_lens) - - # convert binary mask to -inf values - attn_mask = 
torch.logical_not(attn_mask) - attn_mask = attn_mask.float() * -30000 - - output, *debug_tensors = ref_masked_attention( - query, - key, - value, - scale, - attn_mask, - return_max_reduce=return_max_reduce, - ) - - output = output.unsqueeze(1) - if return_max_reduce: - cached_max, cached_sum_reciprocal, lse, masked_score, scaled_qk = ( - debug_tensors) - return ( - output, - cached_max, - cached_sum_reciprocal, - lse, - masked_score, - scaled_qk, - ) - else: - return output - - -def sample_inputs( - prefill_batch_size, - decode_batch_size, - min_query_len, - max_query_len, - min_ctx_len, - max_ctx_len, - block_size, - num_heads, - num_kv_heads, - head_size, - dtype, -): - batch_size = prefill_batch_size + decode_batch_size - max_model_len = (max_query_len + max_ctx_len) * 4 - max_block_per_request = max_model_len // block_size - cache_size = (batch_size * max_block_per_request) + 2 - prefill_ctx_lens = torch.randint(min_ctx_len, - max_ctx_len + 1, (prefill_batch_size, ), - dtype=torch.long).tolist() - decode_ctx_lens = torch.randint(min_ctx_len, - max_ctx_len + 1, (decode_batch_size, ), - dtype=torch.long).tolist() - ctx_lens = prefill_ctx_lens + decode_ctx_lens - query_lens = torch.randint( - min_query_len, - max_query_len + 1, - (prefill_batch_size, ), - dtype=torch.long, - ).tolist() + [1 for _ in range(decode_batch_size)] - seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] - - num_tokens = sum(query_lens) - query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) - query.uniform_(-1, 1) - torch.empty(num_tokens, num_heads, head_size, dtype=dtype) - - kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype) - kv.uniform_(-1, 1) - key, value = kv.unbind(dim=1) - - k_cache = torch.zeros(cache_size, - block_size, - num_kv_heads, - head_size, - dtype=dtype) - v_cache = torch.zeros(cache_size, - block_size, - num_kv_heads, - head_size, - dtype=dtype) - k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) - v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) - values = torch.arange(0, cache_size, dtype=torch.long) - values = values[torch.randperm(cache_size)] - block_table = values[:batch_size * max_block_per_request].view( - batch_size, max_block_per_request) - b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], - dtype=torch.long), - dim=0) - # copy kv to cache - b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], - dtype=torch.long), - dim=0) - for i in range(batch_size): - for j in range(query_lens[i]): - k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + - j]) - v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + - b_ctx_len[i] + j]) - cur_ctx = 0 - block_id = 0 - while cur_ctx < b_ctx_len[i]: - start_loc = b_seq_start_loc[i] + cur_ctx - if cur_ctx + block_size > b_ctx_len[i]: - end_loc = b_seq_start_loc[i] + b_ctx_len[i] - else: - end_loc = start_loc + block_size - start_slot = block_table[i, block_id] * block_size - end_slot = start_slot + end_loc - start_loc - k_cache.view(-1, num_kv_heads, - head_size)[start_slot:end_slot].copy_( - key[start_loc:end_loc]) - v_cache.view(-1, num_kv_heads, - head_size)[start_slot:end_slot].copy_( - value[start_loc:end_loc]) - cur_ctx += block_size - block_id += 1 - kv_cache = torch.stack([k_cache, v_cache]) - - return ( - query, - k, - v, - kv_cache, - block_table, - key, - value, - query_lens, - seq_lens, - ) - - -def get_active_block_tables(block_tables, query_lens, seq_lens, 
block_size, - num_blocks): - context_lens = seq_lens - query_lens - blocks_per_seq = (context_lens + block_size - 1) // block_size - num_seqs = len(seq_lens) - active_blocks: list[int] = [] - for seq_id in range(num_seqs): - active_blocks = ( - active_blocks + - block_tables[seq_id, :blocks_per_seq[seq_id]].tolist()) - return F.pad( - torch.tensor(active_blocks, dtype=torch.int32), - (0, num_blocks - len(active_blocks)), - "constant", - 0, - ) - - -@pytest.mark.parametrize( - "prefill_batch_size,decode_batch_size,block_size,large_tile_size,num_heads,num_queries_per_kv,head_size,mixed_precision", - [ - # Test minimal configurations (small block size) - (1, 199, 1, 512, 4, 2, 8, False - ), # minimal block size, small dimensions - (1, 199, 1, 512, 4, 2, 8, True), # same with mixed precision - - # Test common/medium configurations - (4, 12, 32, 2048, 32, 8, 64, False), # common case, larger heads - (4, 12, 32, 2048, 16, 4, 32, - True), # medium size, mixed precision, grouped-query attention (GQA) - - # Test large configurations - (4, 12, 256, 8192, 8, 1, 128, False), # large blocks, large head size - (4, 12, 256, 8192, 64, 8, 64, True), # large blocks, many heads - - # Test asymmetric configurations - (2, 24, 64, 4096, 12, 4, 96, False), # varied batch sizes - (8, 8, 128, 2048, 24, 2, 48, True), # balanced batches - - # Test edge cases - (1, 128, 16, 1024, 4, 2, 16, False), # large decode batch - (16, 4, 8, 1024, 4, 2, 128, True), # large prefill batch - (4, 12, 32, 2048, 16, 1, 32, True), # multi-head attention (MHA) - (4, 12, 32, 2048, 16, 16, 32, True), # multi-query attention (MQA) - ]) -@torch.inference_mode() -def test_contexted_kv_attention( - monkeypatch: pytest.MonkeyPatch, - prefill_batch_size: int, - decode_batch_size: int, - num_heads: int, - num_queries_per_kv: int, - head_size: int, - block_size: int, - large_tile_size, - mixed_precision: bool, -) -> None: - - import torch_xla.core.xla_model as xm - - from vllm.attention.ops.nki_flash_attn import (flash_attn_varlen_nkifunc, - reorder_context_mask) - - assert large_tile_size % block_size == 0 - - device = xm.xla_device() - - compiler_flags_str = " ".join([ - "-O1", - "--retry_failed_compilation", - ]) - with monkeypatch.context() as m: - m.setenv("NEURON_CC_FLAGS", compiler_flags_str) - - torch.manual_seed(0) - torch.set_printoptions(sci_mode=False) - torch.set_default_device("cpu") - dtype = torch.float32 - - min_ctx_len = 32 - max_ctx_len = 1024 - min_query_len = 16 - max_query_len = 512 - num_kv_heads = num_heads // num_queries_per_kv - ( - query, - k_active, - v_active, - kv_cache, - block_table, - key, - value, - query_lens, - seq_lens, - ) = sample_inputs( - prefill_batch_size=prefill_batch_size, - decode_batch_size=decode_batch_size, - min_query_len=min_query_len, - max_query_len=max_query_len, - min_ctx_len=min_ctx_len, - max_ctx_len=max_ctx_len, - block_size=block_size, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_size=head_size, - dtype=dtype, - ) - - output_ref = ref_context_attention( - query, - key, - value, - query_lens, - seq_lens, - head_size, - num_queries_per_kv, - return_max_reduce=False, - ) - - # build neuron program - B_P_SIZE = 128 - assert (large_tile_size >= B_P_SIZE - ), f"Expect {large_tile_size=} to be larger than {B_P_SIZE=}" - - def pad_to_multiple(a, b): - return cdiv(a, b) * b - - def pad_to_next_power_of_2(a): - assert a > 0 - return 2**int(a - 1).bit_length() - - # calculate input shapes - max_num_queries = pad_to_next_power_of_2(sum(query_lens)) - context_lens = 
torch.tensor(seq_lens) - torch.tensor(query_lens) - num_active_blocks = cdiv(context_lens, block_size).sum().item() - num_active_blocks = pad_to_multiple(num_active_blocks, - large_tile_size // block_size) - context_kv_len = num_active_blocks * block_size - assert ( - context_kv_len % - large_tile_size == 0), f"invalid context_kv_len={context_kv_len}" - - # pad QKV tensors - pad_dims = ( - 0, - 0, - 0, - 0, - 0, - max_num_queries - query.shape[0], - ) - query = F.pad(query, pad_dims, "constant", 0) - k = F.pad(k_active, pad_dims, "constant", 0) - v = F.pad(v_active, pad_dims, "constant", 0) - - # permute QKV tensors - # query: (1, n_heads, d, seq_q) - # key: (1, n_kv_heads, d, seq_k) - # value: (1, n_kv_heads, seq_v, d) - query = query.unsqueeze(0).permute(0, 2, 3, 1).contiguous() - k = k.unsqueeze(0).permute(0, 2, 3, 1).contiguous() - v = v.unsqueeze(0).permute(0, 2, 1, 3).contiguous() - kv_cache = kv_cache.permute(0, 1, 3, 2, 4).contiguous() - - # transform block table - active_block_table = get_active_block_tables( - block_table.cpu(), - torch.tensor(query_lens).cpu(), - torch.tensor(seq_lens).cpu(), - block_size, - num_active_blocks, - ) - - # Build attention masks - prior_mask, active_mask = ( - BlockDiagonalCausalFromBottomRightMask.from_seqlens( - query_lens, seq_lens, block_size=block_size)) - prior_mask_padded = F.pad( - prior_mask, - ( - 0, - context_kv_len - prior_mask.shape[1], - 0, - max_num_queries - prior_mask.shape[0], - ), - "constant", - 0, - ).bool() - active_mask_padded = F.pad( - active_mask, - ( - 0, - max_num_queries - active_mask.shape[1], - 0, - max_num_queries - active_mask.shape[0], - ), - "constant", - 0, - ).bool() - attn_mask = torch.concat([prior_mask_padded, active_mask_padded], - dim=1) - - attn_mask = reorder_context_mask(attn_mask, large_tile_size, - block_size) - - input_args = ( - query.to(device=device), - k.to(device=device), - v.to(device=device), - kv_cache.to(device=device), - active_block_table.to(device=device), - attn_mask.to(device=device), - ) - input_kwargs = dict( - n_kv_head=num_kv_heads, - head_size=head_size, - mixed_precision=mixed_precision, - LARGE_TILE_SZ=large_tile_size, - ) - - output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs) - - num_actual_tokens = sum(query_lens) - # - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d) - output_nki = output_nki.cpu().permute(0, 2, 1, 3) - output_nki = output_nki[0, :num_actual_tokens, :, :] - output_ref_padded = F.pad( - output_ref, - (0, 0, 0, 0, 0, 0, 0, max_num_queries - output_ref.shape[0]), - "constant", - 0, - ) - output_ref = output_ref_padded.transpose( - 0, 1)[0, :num_actual_tokens, :, :] - - torch.testing.assert_close(output_nki, output_ref, atol=1e-2, rtol=0) diff --git a/tests/neuron/1_core/test_rotary_embedding.py b/tests/neuron/1_core/test_rotary_embedding.py deleted file mode 100644 index a7ac79729986d..0000000000000 --- a/tests/neuron/1_core/test_rotary_embedding.py +++ /dev/null @@ -1,68 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Tests for miscellaneous utilities -""" - -import pytest -import torch - -from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding -from vllm.platforms import current_platform - - -@pytest.mark.parametrize( - "max_position,is_neox_style,rotary_dim,head_size,seq_len,use_key", [ - (16, False, 32, 32, 1024, True), - (16, False, 32, 128, 1024, True), - (16, True, 32, 32, 1024, True), - (16, True, 32, 128, 1024, True), - (16, False, 32, 
128, 1024, False), - (16, True, 32, 128, 1024, False), - ]) -def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim, - head_size, seq_len, use_key): - import torch_xla.core.xla_model as xm - - device = xm.xla_device() - current_platform.seed_everything(0) - torch.set_default_device("cpu") - - batch_size = 1 - base = 10000 - num_heads = 8 - - rot = RotaryEmbedding(head_size, rotary_dim, max_position, base, - is_neox_style, torch.float32) - - positions = torch.randint(0, - max_position, (batch_size, seq_len), - device="cpu") - query = torch.randn(batch_size, - seq_len, - num_heads * head_size, - dtype=torch.float32, - device="cpu") - key = torch.randn_like(query) if use_key else None - assert positions.is_cpu, \ - "reference input tensor is expected to be CPU tensor." - ref_query, ref_key = rot.to(device="cpu").forward_native( - positions, query, key) - out_query, out_key = rot.to(device=device).forward_neuron( - positions.to(device=device), query.to(device=device), - key.to(device=device) if key is not None else None) - if use_key: - assert out_query.is_xla and out_key.is_xla, \ - "output tensor is expected to be XLA tensor" - torch.testing.assert_close(out_key.cpu(), - ref_key, - atol=1e-2, - rtol=1e-2) - else: - assert out_key is None, "expected returned key to be None" - assert out_query.is_xla, \ - "output tensor is expected to be XLA tensor" - torch.testing.assert_close(out_query.cpu(), - ref_query, - atol=1e-2, - rtol=1e-2) diff --git a/tests/neuron/2_core/test_comm_ops.py b/tests/neuron/2_core/test_comm_ops.py deleted file mode 100644 index 85a48dae58aaf..0000000000000 --- a/tests/neuron/2_core/test_comm_ops.py +++ /dev/null @@ -1,101 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import functools -from typing import Callable -from unittest.mock import patch - -import pytest -import torch -import torch_xla.distributed.xla_multiprocessing as xmp -from typing_extensions import ParamSpec - -from vllm.distributed.communication_op import ( - tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, - init_distributed_environment) -from vllm.utils import get_distributed_init_method, get_open_port - -_P = ParamSpec("_P") - - -def reinitialize_neuron_runtime(f: Callable[_P, None]) -> Callable[_P, None]: - """Decorator to reinitialize the Neuron Runtime before executing a test. - This is necessary for distributed tests which need to reallocate Neuron - Cores to separate subprocesses. 
- """ - - @functools.wraps(f) - def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: - runtime = torch.classes.neuron.Runtime() - runtime.initialize() - runtime.unsafe_close() - - f(*args, **kwargs) - runtime.initialize() - - return wrapper - - -def all_gather_test_worker(index, tp_degree, distributed_init_method): - init_distributed_environment(tp_degree, - index, - distributed_init_method, - index, - backend="xla") - ensure_model_parallel_initialized(tp_degree, 1) - - num_dimensions = 3 - tensor_size = list(range(2, num_dimensions + 2)) - total_size = 1 - for s in tensor_size: - total_size *= s - - all_gather_dimension = -1 - all_tensors = [ - torch.arange(total_size, dtype=torch.float32, - device="xla").reshape(tensor_size) * (r + 1) - for r in range(tp_degree) - ] - expected = torch.cat(all_tensors, dim=all_gather_dimension) - t = all_tensors[index % tp_degree] - t = tensor_model_parallel_all_gather(t, all_gather_dimension) - torch.testing.assert_close(t, expected) - - -def all_reduce_test_worker(index, tp_degree, distributed_init_method): - init_distributed_environment(tp_degree, - index, - distributed_init_method, - index, - backend="xla") - ensure_model_parallel_initialized(tp_degree, 1) - - num_elements = 8 - all_tensors = [ - torch.arange(num_elements, dtype=torch.float32, device="xla") * (r + 1) - for r in range(tp_degree) - ] - expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0) - t = all_tensors[index % tp_degree] - t = tensor_model_parallel_all_reduce(t) - torch.testing.assert_close(t, expected) - - -@pytest.mark.parametrize("tp_size", [2]) -@pytest.mark.parametrize("test_target", - [all_reduce_test_worker, all_gather_test_worker]) -@reinitialize_neuron_runtime -def test_neuron_multi_process_tensor_parallel(monkeypatch, tp_size, - test_target): - - with patch('torch_xla._XLAC._xla_runtime_is_initialized', - return_value=False): - distributed_init_method = get_distributed_init_method( - "127.0.0.1", get_open_port()) - - monkeypatch.setenv("VLLM_USE_V1", "1") - monkeypatch.setenv("NEURONCORE_NUM_DEVICES", str(tp_size)) - monkeypatch.setenv("NEURON_PJRT_PROCESSES_NUM_DEVICES", - ','.join(['1' for _ in range(tp_size)])) - - xmp.spawn(test_target, args=(tp_size, distributed_init_method)) diff --git a/tests/neuron/2_core/test_eagle.py b/tests/neuron/2_core/test_eagle.py deleted file mode 100644 index cac642af03101..0000000000000 --- a/tests/neuron/2_core/test_eagle.py +++ /dev/null @@ -1,83 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import json -import os -import shutil -import tempfile - -import torch -from huggingface_hub import snapshot_download -from safetensors import safe_open - -from vllm import LLM, SamplingParams - - -def patch_eagle_draft_with_lm_head(target_model_id: str, - draft_model_id: str) -> str: - # In NxDI, draft model checkpoint must include lm_head weights from target - # model. 
For more details see https://awsdocs-neuron.readthedocs-hosted.com - # /en/latest/libraries/nxd-inference/developer_guides/feature-guide.html - # #eagle-checkpoint-compatibility - final_draft_dir = "/tmp/patched_eagle_draft" - - with tempfile.TemporaryDirectory() as tmp_dir: - target_dir = snapshot_download(repo_id=target_model_id, - local_dir=os.path.join( - tmp_dir, "target")) - draft_dir = snapshot_download(repo_id=draft_model_id, - local_dir=os.path.join(tmp_dir, "draft")) - - lm_head_key = "lm_head.weight" - index_path = os.path.join(target_dir, "model.safetensors.index.json") - with open(index_path) as f: - index = json.load(f) - shard_name = index["weight_map"][lm_head_key] - target_safetensor_path = os.path.join(target_dir, shard_name) - - with safe_open(target_safetensor_path, framework="pt") as f: - target_lm_head = f.get_tensor(lm_head_key) - - draft_path = os.path.join(draft_dir, "pytorch_model.bin") - draft_state_dict = torch.load(draft_path, map_location="cpu") - draft_state_dict[lm_head_key] = target_lm_head.to(torch.float16) - torch.save(draft_state_dict, draft_path) - - shutil.copytree(draft_dir, final_draft_dir, dirs_exist_ok=True) - - return final_draft_dir - - -def test_eagle(): - patched_draft_path = patch_eagle_draft_with_lm_head( - target_model_id="meta-llama/Llama-2-7b-hf", - draft_model_id="yuhuili/EAGLE-llama2-chat-7B") - llm = LLM( - model="meta-llama/Llama-2-7b-hf", - speculative_config={ - "model": patched_draft_path, - "num_speculative_tokens": 5, - "max_model_len": 128 - }, - max_num_seqs=1, - max_model_len=128, - tensor_parallel_size=2, - override_neuron_config={ - "enable_eagle_speculation": True, - "enable_fused_speculation": True, - "fused_qkv": True - }, - ) - prompts = [ - "The president of the United States is", - ] - outputs = llm.generate(prompts, SamplingParams(top_k=1)) - expected_output = " the head of state and head of government of " \ - "the United States. The president direct" - - for output in outputs: - generated_text = output.outputs[0].text - print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}") - assert (expected_output == generated_text) - - print("Neuron Eagle speculation test passed.") diff --git a/tests/neuron/2_core/test_mistral.py b/tests/neuron/2_core/test_mistral.py deleted file mode 100644 index ff59be1725b6c..0000000000000 --- a/tests/neuron/2_core/test_mistral.py +++ /dev/null @@ -1,64 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from vllm import LLM, SamplingParams - - -def test_mistral(): - llm = LLM(model="mistralai/Mistral-7B-v0.1", - tensor_parallel_size=2, - max_num_seqs=4, - max_model_len=128, - override_neuron_config={ - "sequence_parallel_enabled": False, - "skip_warmup": True - }) - - # Send more prompts than the compiled batch size (4) and request - # varying generation lengths to test accuracy related to Neuron - # specific sequence id sorting. 
- prompts = [ - "The president of the United States is", - "The capital of France is", - "What is Annapurna labs?", - "I believe the meaning of life is", - "Tell me a story about a brave knight", - "Hello, my name is Llama", - ] - - sampling_params = [ - SamplingParams(top_k=1, max_tokens=10), - SamplingParams(top_k=1, max_tokens=20), - SamplingParams(top_k=1, max_tokens=30), - SamplingParams(top_k=1, max_tokens=40), - SamplingParams(top_k=1, max_tokens=50), - SamplingParams(top_k=1, max_tokens=60) - ] - - outputs = llm.generate(prompts, sampling_params) - - expected_outputs = [ - " the most powerful person in the world. He is", - " a city of many faces. It is a city of history, culture, art, " - "fashion, and", - "\n\nAnnapurna Labs is a semiconductor company that was founded " - "in 2013 by Amazon. The company is", - " to be happy.\n\nI believe that happiness is a choice.\n\nI " - "believe that happiness is a state of mind.\n\nI believe that " - "happiness is a journey.\n\nI believe", - " who rescued a princess from a dragon.\n\nTell me a story about" - " a princess who rescued herself from a dragon.\n\nTell me a " - "story about a princess who rescued herself from a dragon and " - "then rescued a knight from", - " and I am a 10 year old male. I am a very friendly and " - "affectionate boy who loves to be around people. I am a very " - "active boy who loves to play and run around. I am a very smart " - "boy who loves to learn new things. I am a very loyal boy" - ] - - for expected_output, output in zip(expected_outputs, outputs): - generated_text = output.outputs[0].text - print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}") - assert (expected_output == generated_text) - - print("Neuron Mistral test passed.") diff --git a/tests/neuron/2_core/test_multi_lora.py b/tests/neuron/2_core/test_multi_lora.py deleted file mode 100644 index 52ca9fe7b6667..0000000000000 --- a/tests/neuron/2_core/test_multi_lora.py +++ /dev/null @@ -1,97 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from huggingface_hub import snapshot_download - -from vllm import LLM, SamplingParams -from vllm.lora.request import LoRARequest - - -def test_llama_single_lora(): - sql_lora_files = snapshot_download( - repo_id="yard1/llama-2-7b-sql-lora-test") - llm = LLM(model="meta-llama/Llama-2-7b-hf", - tensor_parallel_size=2, - max_num_seqs=4, - max_model_len=512, - override_neuron_config={ - "sequence_parallel_enabled": False, - "skip_warmup": True, - "lora_modules": [{ - "name": "lora_id_1", - "path": sql_lora_files - }] - }, - enable_lora=True, - max_loras=1, - max_lora_rank=256, - device="neuron") - """For multi-lora requests using NxDI as the backend, only the lora_name - needs to be specified. The lora_id and lora_path are supplied at the LLM - class/server initialization, after which the paths are handled by NxDI""" - lora_req_1 = LoRARequest("lora_id_1", 0, " ") - prompts = [ - "The president of the United States is", - "The capital of France is", - ] - outputs = llm.generate(prompts, - SamplingParams(top_k=1), - lora_request=[lora_req_1, lora_req_1]) - - expected_outputs = [ - " the head of state and head of government of the United States. " - "The president direct", - " a city of contrasts. 
The city is home to the Eiffel Tower" - ] - - for expected_output, output in zip(expected_outputs, outputs): - generated_text = output.outputs[0].text - assert (expected_output == generated_text) - - -def test_llama_multiple_lora(): - sql_lora_files = snapshot_download( - repo_id="yard1/llama-2-7b-sql-lora-test") - llm = LLM(model="meta-llama/Llama-2-7b-hf", - tensor_parallel_size=2, - max_num_seqs=4, - max_model_len=512, - override_neuron_config={ - "sequence_parallel_enabled": - False, - "skip_warmup": - True, - "lora_modules": [{ - "name": "lora_id_1", - "path": sql_lora_files - }, { - "name": "lora_id_2", - "path": sql_lora_files - }] - }, - enable_lora=True, - max_loras=2, - max_lora_rank=256, - device="neuron") - """For multi-lora requests using NxDI as the backend, only the lora_name - needs to be specified. The lora_id and lora_path are supplied at the LLM - class/server initialization, after which the paths are handled by NxDI""" - lora_req_1 = LoRARequest("lora_id_1", 0, " ") - lora_req_2 = LoRARequest("lora_id_2", 1, " ") - prompts = [ - "The president of the United States is", - "The capital of France is", - ] - outputs = llm.generate(prompts, - SamplingParams(top_k=1), - lora_request=[lora_req_1, lora_req_2]) - - expected_outputs = [ - " the head of state and head of government of the United States. " - "The president direct", - " a city of contrasts. The city is home to the Eiffel Tower" - ] - - for expected_output, output in zip(expected_outputs, outputs): - generated_text = output.outputs[0].text - assert (expected_output == generated_text) diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/__init__.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/__init__.py index a750c756c11a2..4bbb79c98a82a 100644 --- a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/__init__.py +++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/__init__.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -def register_prithvi_india(): - return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessorIndia" # noqa: E501 -def register_prithvi_valencia(): - return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessorValencia" # noqa: E501 +def register_prithvi(): + return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessor" # noqa: E501 diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py index 0ebaafda94dc5..42874f0398f0a 100644 --- a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py +++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py @@ -234,6 +234,8 @@ def load_image( class PrithviMultimodalDataProcessor(IOProcessor): + indices = [0, 1, 2, 3, 4, 5] + def __init__(self, vllm_config: VllmConfig): super().__init__(vllm_config) @@ -412,21 +414,3 @@ class PrithviMultimodalDataProcessor(IOProcessor): format="tiff", data=out_data, request_id=request_id) - - -class PrithviMultimodalDataProcessorIndia(PrithviMultimodalDataProcessor): - - def __init__(self, vllm_config: VllmConfig): - - super().__init__(vllm_config) - - self.indices = [1, 2, 3, 8, 11, 12] - - -class PrithviMultimodalDataProcessorValencia(PrithviMultimodalDataProcessor): - - def __init__(self, vllm_config: VllmConfig): - - super().__init__(vllm_config) 
- - self.indices = [0, 1, 2, 3, 4, 5] diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py index d480aef704c61..d4c6628211fb2 100644 --- a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py +++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py @@ -22,7 +22,7 @@ class DataModuleConfig(TypedDict): class ImagePrompt(BaseModel): - data_format: Literal["b64_json", "bytes", "url"] + data_format: Literal["b64_json", "bytes", "url", "path"] """ This is the data type for the input image """ diff --git a/tests/plugins/prithvi_io_processor_plugin/setup.py b/tests/plugins/prithvi_io_processor_plugin/setup.py index a03b1fbbd4a80..3ddda1a47bbe4 100644 --- a/tests/plugins/prithvi_io_processor_plugin/setup.py +++ b/tests/plugins/prithvi_io_processor_plugin/setup.py @@ -9,8 +9,7 @@ setup( packages=["prithvi_io_processor"], entry_points={ "vllm.io_processor_plugins": [ - "prithvi_to_tiff_india = prithvi_io_processor:register_prithvi_india", # noqa: E501 - "prithvi_to_tiff_valencia = prithvi_io_processor:register_prithvi_valencia", # noqa: E501 + "prithvi_to_tiff = prithvi_io_processor:register_prithvi", # noqa: E501 ] }, ) diff --git a/tests/plugins_tests/test_io_processor_plugins.py b/tests/plugins_tests/test_io_processor_plugins.py index b2fbef2ee25cb..3567a701a3afa 100644 --- a/tests/plugins_tests/test_io_processor_plugins.py +++ b/tests/plugins_tests/test_io_processor_plugins.py @@ -7,12 +7,11 @@ import requests from tests.utils import RemoteOpenAIServer from vllm.config import VllmConfig -from vllm.entrypoints.llm import LLM from vllm.entrypoints.openai.protocol import IOProcessorResponse from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams -MODEL_NAME = "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM" +MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11" image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501 @@ -23,61 +22,7 @@ def test_loading_missing_plugin(): get_io_processor(vllm_config, "wrong_plugin") -def test_loading_engine_with_wrong_plugin(): - - with pytest.raises(ValueError): - LLM( - model=MODEL_NAME, - skip_tokenizer_init=True, - trust_remote_code=True, - enforce_eager=True, - # Limit the maximum number of parallel requests - # to avoid the model going OOM in CI. - max_num_seqs=32, - io_processor_plugin="wrong_plugin", - ) - - -@pytest.mark.parametrize("model_name", [MODEL_NAME]) -def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str): - - img_prompt = dict( - data=image_url, - data_format="url", - image_format="tiff", - out_data_format="b64_json", - ) - - pooling_params = PoolingParams(task="encode", softmax=False) - - with vllm_runner( - model_name, - runner="pooling", - skip_tokenizer_init=True, - trust_remote_code=True, - enforce_eager=True, - # Limit the maximum number of parallel requests - # to avoid the model going OOM in CI. - max_num_seqs=1, - io_processor_plugin="prithvi_to_tiff_valencia", - ) as llm_runner: - pooler_output = llm_runner.get_llm().encode( - img_prompt, - pooling_params=pooling_params, - ) - output = pooler_output[0].outputs - - # verify the output is formatted as expected for this plugin - assert all( - hasattr(output, attr) - for attr in ["type", "format", "data", "request_id"]) - - # We just check that the output is a valid base64 string. 
- # Raises an exception and fails the test if the string is corrupted. - base64.b64decode(output.data) - - -@pytest.fixture(scope="module") +@pytest.fixture(scope="function") def server(): args = [ "--runner", @@ -90,7 +35,9 @@ def server(): "--max-num-seqs", "32", "--io-processor-plugin", - "prithvi_to_tiff_valencia" + "prithvi_to_tiff", + "--model-impl", + "terratorch", ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -136,3 +83,43 @@ async def test_prithvi_mae_plugin_online( # We just check that the output is a valid base64 string. # Raises an exception and fails the test if the string is corrupted. base64.b64decode(plugin_data["data"]) + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str): + + img_prompt = dict( + data=image_url, + data_format="url", + image_format="tiff", + out_data_format="b64_json", + ) + + pooling_params = PoolingParams(task="encode", softmax=False) + + with vllm_runner( + model_name, + runner="pooling", + skip_tokenizer_init=True, + trust_remote_code=True, + enforce_eager=True, + # Limit the maximum number of parallel requests + # to avoid the model going OOM in CI. + max_num_seqs=1, + model_impl="terratorch", + io_processor_plugin="prithvi_to_tiff", + ) as llm_runner: + pooler_output = llm_runner.get_llm().encode( + img_prompt, + pooling_params=pooling_params, + ) + output = pooler_output[0].outputs + + # verify the output is formatted as expected for this plugin + assert all( + hasattr(output, attr) + for attr in ["type", "format", "data", "request_id"]) + + # We just check that the output is a valid base64 string. + # Raises an exception and fails the test if the string is corrupted. + base64.b64decode(output.data) diff --git a/tests/quantization/test_modelopt.py b/tests/quantization/test_modelopt.py index fcbfa681d75c9..c60a03f44baec 100644 --- a/tests/quantization/test_modelopt.py +++ b/tests/quantization/test_modelopt.py @@ -27,7 +27,7 @@ def use_v0_only(monkeypatch): reason="ModelOpt FP8 is not supported on this GPU type.") def test_modelopt_fp8_checkpoint_setup(vllm_runner): """Test ModelOpt FP8 checkpoint loading and structure validation.""" - # TODO: provide a small publically available test checkpoint + # TODO: provide a small publicly available test checkpoint model_path = ("/home/scratch.omniml_data_1/zhiyu/ckpts/test_ckpts/" "TinyLlama-1.1B-Chat-v1.0-fp8-0710") diff --git a/tests/quantization/test_quark.py b/tests/quantization/test_quark.py index 4a0c8ba4d8a95..c09931971e6fb 100644 --- a/tests/quantization/test_quark.py +++ b/tests/quantization/test_quark.py @@ -77,6 +77,31 @@ def test_quark_fp8_w_per_tensor_a_per_tensor(vllm_runner, kv_cache_dtype, tp): assert output +@pytest.mark.parametrize('tp', [1]) +def test_quark_fp8_w_per_channel_a_per_token(vllm_runner, tp): + model_path = "amd/Qwen2.5-1.5B-Instruct-ptpc-Quark-ts" + with vllm_runner(model_path, tensor_parallel_size=tp) as llm: + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + + assert isinstance(qkv_proj.quant_method, QuarkLinearMethod) + assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8) + + if isinstance(qkv_proj.scheme, QuarkW8A8Fp8): + assert qkv_proj.weight.dtype is current_platform.fp8_dtype() + assert qkv_proj.weight_scale.shape[0] == qkv_proj.weight.shape[ + 1] + assert qkv_proj.weight_scale.shape[1] == 1 + + llm.apply_model(check_model) + + output = llm.generate_greedy("Hello my name is", max_tokens=20) + assert output + + 
@pytest.mark.parametrize('tp', [1]) def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp): model_path = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test" diff --git a/tests/quantization/test_torchao.py b/tests/quantization/test_torchao.py index eef3568efea12..8e68f6a2e019f 100644 --- a/tests/quantization/test_torchao.py +++ b/tests/quantization/test_torchao.py @@ -75,5 +75,25 @@ def test_qwenvl_int8wo_model_loading_with_params(vllm_runner): print(output) +@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") +@pytest.mark.skip( + reason="since torchao nightly is only compatible with torch nightly" + "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip " + "torchao tests that requires newer versions (0.14.0.dev+) for now") +def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner): + torch._dynamo.reset() + model_name = ("torchao-testing/opt-125m-AWQConfig-Int4WeightOnlyConfig-v2" + "-0.14.0.dev") + with vllm_runner(model_name=model_name, + quantization="torchao", + dtype="bfloat16", + pt_load_map_location="cuda:0") as llm: + output = llm.generate_greedy(["The capital of France is"], + max_tokens=32) + + assert output + print(output) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py b/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py index 84c615b6b8dbc..22bdb3b44eb03 100644 --- a/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py +++ b/tests/runai_model_streamer_test/test_runai_model_streamer_loader.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm import SamplingParams -from vllm.config import LoadConfig +from vllm.config.load import LoadConfig from vllm.model_executor.model_loader import get_model_loader load_format = "runai_streamer" diff --git a/tests/runai_model_streamer_test/test_runai_utils.py b/tests/runai_model_streamer_test/test_runai_utils.py new file mode 100644 index 0000000000000..bde77ff665063 --- /dev/null +++ b/tests/runai_model_streamer_test/test_runai_utils.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import glob +import os +import tempfile + +import huggingface_hub.constants + +from vllm.model_executor.model_loader.weight_utils import ( + download_weights_from_hf) +from vllm.transformers_utils.runai_utils import (is_runai_obj_uri, + list_safetensors) + + +def test_is_runai_obj_uri(): + assert is_runai_obj_uri("gs://some-gcs-bucket/path") + assert is_runai_obj_uri("s3://some-s3-bucket/path") + assert not is_runai_obj_uri("nfs://some-nfs-path") + + +def test_runai_list_safetensors_local(): + with tempfile.TemporaryDirectory() as tmpdir: + huggingface_hub.constants.HF_HUB_OFFLINE = False + download_weights_from_hf("openai-community/gpt2", + allow_patterns=["*.safetensors", "*.json"], + cache_dir=tmpdir) + safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True) + assert len(safetensors) > 0 + parentdir = [ + os.path.dirname(safetensor) for safetensor in safetensors + ][0] + files = list_safetensors(parentdir) + assert len(safetensors) == len(files) + + +if __name__ == "__main__": + test_is_runai_obj_uri() + test_runai_list_safetensors_local() diff --git a/tests/samplers/test_beam_search.py b/tests/samplers/test_beam_search.py index cc9a88a255f9f..0320a5ef31a65 100644 --- a/tests/samplers/test_beam_search.py +++ 
b/tests/samplers/test_beam_search.py @@ -82,7 +82,7 @@ def test_beam_search_with_concurrency_limit( beam_width: int, ) -> None: # example_prompts[1]&[3]&[7] fails due to unknown reason even without - # concurency limit. skip them for now. + # concurrency limit. skip them for now. example_prompts = (example_prompts[:8]) concurrency_limit = 2 assert len(example_prompts) > concurrency_limit diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py deleted file mode 100644 index 87f40b1005312..0000000000000 --- a/tests/samplers/test_logprobs.py +++ /dev/null @@ -1,182 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -import torch - -from vllm import SamplingParams - -from ..conftest import VllmRunner - -MODELS = ["distilbert/distilgpt2"] - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - This module is V0 only since it uses dtype=float, so - set VLLM_USE_V1=0 for all tests in the module. - """ - monkeypatch.setenv('VLLM_USE_V1', '0') - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", - ["float"]) # needed for comparing logprobs with HF -@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1]) -@pytest.mark.parametrize("num_top_logprobs", [0, 6]) # 32000 == vocab_size -@pytest.mark.parametrize("detokenize", [True, False]) -def test_get_prompt_logprobs( - hf_runner, - vllm_runner, - model, - dtype, - chunked_prefill_token_size: int, - num_top_logprobs: int, - detokenize: bool, - example_prompts, -): - max_num_seqs = 256 - enable_chunked_prefill = False - max_num_batched_tokens = None - if chunked_prefill_token_size != -1: - enable_chunked_prefill = True - max_num_seqs = min(chunked_prefill_token_size, max_num_seqs) - max_num_batched_tokens = chunked_prefill_token_size - - max_tokens = 5 - with hf_runner(model, dtype=dtype) as hf_model: - hf_logprobs = hf_model.generate_greedy_logprobs( - example_prompts, - max_tokens=max_tokens, - ) - - with vllm_runner( - model, - dtype=dtype, - max_logprobs=num_top_logprobs, - enable_chunked_prefill=enable_chunked_prefill, - max_num_batched_tokens=max_num_batched_tokens, - max_num_seqs=max_num_seqs, - ) as vllm_model: - vllm_sampling_params = SamplingParams(max_tokens=max_tokens, - logprobs=num_top_logprobs, - prompt_logprobs=num_top_logprobs, - temperature=0.0, - detokenize=detokenize) - vllm_results = vllm_model.llm.generate( - example_prompts, sampling_params=vllm_sampling_params) - - # Test whether logprobs are included in the results. 
- for result in vllm_results: - assert result.prompt_logprobs is not None - assert result.outputs[0].logprobs is not None - assert len(result.outputs[0].logprobs) == max_tokens - for logprobs in result.outputs[0].logprobs: - # If the output token is not included in the top X - # logprob, it can return 1 more data - assert (len(logprobs) == num_top_logprobs - or len(logprobs) == num_top_logprobs + 1) - output_text = result.outputs[0].text - output_string_from_most_likely_tokens_lst: list[str] = [] - for top_logprobs in result.outputs[0].logprobs: - top_logprob = next(iter(top_logprobs.values())) - output_string_from_most_likely_tokens_lst.append( - top_logprob.decoded_token) - - if detokenize: - output_string_from_most_likely_tokens = "".join( - output_string_from_most_likely_tokens_lst) - assert output_text == output_string_from_most_likely_tokens, ( - "The output text from the top logprob for each token position " - "should be the same as the output text in the result.") - else: - assert output_text == '' - assert output_string_from_most_likely_tokens_lst == ([None] * - max_tokens) - - # The first prompt logprob is always None - assert result.prompt_logprobs[0] is None - for prompt_logprobs in result.prompt_logprobs[1:]: - # If the prompt token is not included in the top X - # logprob, it can return 1 more data - assert (len(prompt_logprobs) == num_top_logprobs - or len(prompt_logprobs) == num_top_logprobs + 1) - - # Test whether prompt logprobs are consistent with HF - for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs): - # Check prompt logprobs - # The first prompt logprob is always None, so we compare it from 1:. - vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:] - for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs): - for token_id, logprob in vllm_prompt_logprob_dict.items(): - torch.testing.assert_close(logprob.logprob, - hf_logprob[0][i][token_id].item(), - atol=1e-2, - rtol=1e-2) - vllm_sample_logprobs = vllm_result.outputs[0].logprobs - for i, top_logprobs in enumerate(vllm_sample_logprobs): - for token_id, sample_logprob in top_logprobs.items(): - logprob = sample_logprob.logprob - torch.testing.assert_close(logprob, - hf_logprob[i][-1][token_id].item(), - atol=1e-2, - rtol=1e-2) - if detokenize: - assert isinstance(sample_logprob.decoded_token, str), ( - "The token should be decoded by the time it is returned" - " to the user.") - - # Test if prompt logprobs are correctly set. - for vllm_result in vllm_results: - token_ids = vllm_result.prompt_token_ids - prompt_logprobs = vllm_result.prompt_logprobs - - # The first token doesn't have logprob. 
- assert prompt_logprobs[0] is None - - for token_id, logprob_dict in zip(token_ids[1:], prompt_logprobs[1:]): - assert token_id in logprob_dict - - -def test_max_logprobs(): - runner = VllmRunner("facebook/opt-125m", max_logprobs=1) - vllm_sampling_params = SamplingParams(logprobs=1) - # should pass - runner.generate(["Hello world"], sampling_params=vllm_sampling_params) - - bad_sampling_params = SamplingParams(logprobs=2) - with pytest.raises(ValueError): - runner.generate(["Hello world"], sampling_params=bad_sampling_params) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1]) -@pytest.mark.parametrize("detokenize", [True, False]) -def test_none_logprobs(vllm_runner, model, chunked_prefill_token_size: int, - detokenize: bool, example_prompts): - max_num_seqs = 256 - enable_chunked_prefill = False - max_num_batched_tokens = None - if chunked_prefill_token_size != -1: - enable_chunked_prefill = True - max_num_seqs = min(chunked_prefill_token_size, max_num_seqs) - max_num_batched_tokens = chunked_prefill_token_size - max_tokens = 5 - - with vllm_runner( - model, - enable_chunked_prefill=enable_chunked_prefill, - max_num_batched_tokens=max_num_batched_tokens, - max_num_seqs=max_num_seqs, - ) as vllm_model: - sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens, - logprobs=None, - temperature=0.0, - detokenize=detokenize) - results_logprobs_none = vllm_model.llm.generate( - example_prompts, sampling_params=sampling_params_logprobs_none) - - for i in range(len(results_logprobs_none)): - assert results_logprobs_none[i].outputs[0].logprobs is None - assert results_logprobs_none[i].outputs[0].cumulative_logprob is None diff --git a/tests/tensorizer_loader/conftest.py b/tests/tensorizer_loader/conftest.py index 18aa4c88c0338..571dc2e0eb50f 100644 --- a/tests/tensorizer_loader/conftest.py +++ b/tests/tensorizer_loader/conftest.py @@ -90,6 +90,7 @@ class DummyExecutor(UniProcExecutor): distributed_init_method=distributed_init_method, is_driver_worker=is_driver_worker, ) + self.mm_receiver_cache = None self.collective_rpc("init_worker", args=([kwargs], )) self.collective_rpc("init_device") diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py index 0fb142a1b6e56..e00d7c2f80c67 100644 --- a/tests/tensorizer_loader/test_tensorizer.py +++ b/tests/tensorizer_loader/test_tensorizer.py @@ -161,11 +161,11 @@ def test_load_without_tensorizer_load_format(vllm_runner, capfd, model_ref): model = vllm_runner( model_ref, model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) + pytest.fail("Expected RuntimeError for extra config keys") except RuntimeError: out, err = capfd.readouterr() combined_output = out + err - assert ("ValueError: Model loader extra config " - "is not supported for load " + assert ("ValueError: Unexpected extra config keys for load " "format auto") in combined_output finally: del model @@ -181,11 +181,12 @@ def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd, model_ref, load_format="safetensors", model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) + pytest.fail("Expected RuntimeError for extra config keys") except RuntimeError: out, err = capfd.readouterr() combined_output = out + err - assert ("ValueError: Model loader extra config is not supported " + assert ("ValueError: Unexpected extra config keys " "for load format safetensors") in combined_output finally: del model diff --git a/tests/test_cache_block_hashing.py 
b/tests/test_cache_block_hashing.py deleted file mode 100644 index edc0849dff33f..0000000000000 --- a/tests/test_cache_block_hashing.py +++ /dev/null @@ -1,97 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Test hashing of cache blocks. - -Run `pytest tests/test_cache_block_hashing.py`. -""" -from typing import Optional - -import pytest - -from vllm.inputs import token_inputs -from vllm.lora.request import LoRARequest -from vllm.sequence import Sequence -from vllm.transformers_utils.tokenizer_group import TokenizerGroup - -# Make two prefixes with different first blocks. -prefix_start = [("You are an expert"), ("You are a")] -prefix_common = ( - " school principal, skilled in effectively managing " - "faculty and staff. Draft 10-15 questions for a potential first grade " - "Head Teacher for my K-12, all-girls', independent school that emphasizes " - "community, joyful discovery, and life-long learning. The candidate is " - "coming in for a first-round panel interview for a 8th grade Math " - "teaching role. They have 5 years of previous teaching experience " - "as an assistant teacher at a co-ed, public school with experience " - "in middle school math teaching. Based on this, fulfill " - "the following: ") -prefixes = [start + prefix_common for start in prefix_start] - -# Sample prompts. -sample_prompts = [ - "Hello, my name is", "The president of the United States is", - "The capital of France is", "The future of AI is" -] - - -# Helper function. -def flatten_2d(li): - return [lss for ls in li for lss in ls] - - -@pytest.mark.parametrize("model", ["facebook/opt-125m"]) -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.parametrize("max_num_seqs", [256]) -@pytest.mark.parametrize("concurrent_lora_int_ids", - [[None], [1], [None, 1], [None, 1, 2], [1, 2]]) -def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, - concurrent_lora_int_ids: list[Optional[int]]): - - tokenizer = TokenizerGroup( - tokenizer_id="facebook/opt-125m", - enable_lora=False, - max_num_seqs=max_num_seqs, - max_input_length=None, - ) - - hashes: list[list[list[int]]] = [] - - for prefix in prefixes: - for lora_int_id in concurrent_lora_int_ids: - lora_request = None - - if lora_int_id is not None: - lora_request = LoRARequest( - f"example_lora_{lora_int_id}", - lora_int_id, - f"example/path/to/lora_{lora_int_id}", - ) - - hashes.append([]) - prompts = [prefix + prompt for prompt in sample_prompts] - for seq_id, prompt in enumerate(prompts): - hashes[-1].append([]) - prompt_token_ids = tokenizer.encode(prompt) - seq = Sequence(seq_id, - inputs=token_inputs(prompt_token_ids, - prompt=prompt), - block_size=block_size, - eos_token_id=tokenizer.tokenizer.eos_token_id, - lora_request=lora_request) - - num_blocks = len(prompt_token_ids) // block_size - for idx in range(num_blocks): - hashes[-1][-1].append(seq.hash_of_block(idx)) - - # Check that hashes made with two prefixes with different first blocks are - # different everywhere. - for hash0, hash1 in zip(flatten_2d(hashes[0]), flatten_2d(hashes[1])): - assert (hash0 != hash1) - - # Check that hashes of different prompts made with the same prefix are the - # same until the hashes that contain the prompt. 
- for hash_pref in hashes: - same_hashes = [tuple(h[:-1]) for h in hash_pref] - different_hashes = [h[-1] for h in hash_pref] - assert (len(set(same_hashes)) == 1) - assert (len(set(different_hashes)) == len(different_hashes)) diff --git a/tests/test_config.py b/tests/test_config.py index 957771a4226bc..6e37bdbee59eb 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -6,8 +6,9 @@ from dataclasses import MISSING, Field, asdict, dataclass, field import pytest from vllm.compilation.backends import VllmBackend -from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig, - get_field, update_config) +from vllm.config import ModelConfig, PoolerConfig, VllmConfig, update_config +from vllm.config.load import LoadConfig +from vllm.config.utils import get_field from vllm.model_executor.layers.pooler import PoolingType from vllm.platforms import current_platform @@ -298,9 +299,8 @@ def test_rope_customization(): reason="Encoder Decoder models not supported on ROCm.") @pytest.mark.parametrize(("model_id", "is_encoder_decoder"), [ ("facebook/opt-125m", False), - ("facebook/bart-base", True), + ("openai/whisper-tiny", True), ("meta-llama/Llama-3.2-1B-Instruct", False), - ("meta-llama/Llama-3.2-11B-Vision", True), ]) def test_is_encoder_decoder(model_id, is_encoder_decoder): config = ModelConfig(model_id) diff --git a/tests/test_sampling_params.py b/tests/test_sampling_params.py deleted file mode 100644 index 7330f61e67689..0000000000000 --- a/tests/test_sampling_params.py +++ /dev/null @@ -1,84 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Tests for the SamplingParams class. -""" - -import pytest - -from vllm import SamplingParams -from vllm.config import ModelConfig -from vllm.entrypoints.openai.protocol import ChatCompletionRequest - -MODEL_NAME = "Qwen/Qwen1.5-7B" - - -def test_max_tokens_none(): - """max_tokens=None should be allowed""" - SamplingParams(temperature=0.01, top_p=0.1, max_tokens=None) - - -@pytest.fixture(scope="module") -def model_config(): - return ModelConfig( - MODEL_NAME, - seed=0, - dtype="float16", - ) - - -@pytest.fixture(scope="module") -def default_max_tokens(): - return 4096 - - -def test_sampling_params_from_request_with_no_guided_decoding_backend( - model_config, default_max_tokens): - # guided_decoding_backend is not present at request level - request = ChatCompletionRequest.model_validate({ - 'messages': [{ - 'role': 'user', - 'content': 'Hello' - }], - 'model': - MODEL_NAME, - 'response_format': { - 'type': 'json_object', - }, - }) - - sampling_params = request.to_sampling_params( - default_max_tokens, - model_config.logits_processor_pattern, - ) - # we do not expect any backend to be present and the default - # guided_decoding_backend at engine level will be used. 
- assert sampling_params.guided_decoding.backend is None - - -@pytest.mark.parametrize("request_level_guided_decoding_backend,expected", - [("xgrammar", "xgrammar"), ("guidance", "guidance"), - ("outlines", "outlines")]) -def test_sampling_params_from_request_with_guided_decoding_backend( - request_level_guided_decoding_backend: str, expected: str, - model_config, default_max_tokens): - - request = ChatCompletionRequest.model_validate({ - 'messages': [{ - 'role': 'user', - 'content': 'Hello' - }], - 'model': - MODEL_NAME, - 'response_format': { - 'type': 'json_object', - }, - 'guided_decoding_backend': - request_level_guided_decoding_backend, - }) - - sampling_params = request.to_sampling_params( - default_max_tokens, - model_config.logits_processor_pattern, - ) - # backend correctly identified in resulting sampling_params - assert sampling_params.guided_decoding.backend == expected diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 1b019be9e56dc..da9826ff05058 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,104 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest import torch -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, - SequenceData, SequenceOutput) - -from .core.utils import create_dummy_prompt - - -@pytest.fixture -def sample_outputs(): - return [ - CompletionSequenceGroupOutput(samples=[ - SequenceOutput(parent_seq_id=0, output_token=i, logprobs={}) - ], - prompt_logprobs=None) for i in range(5) - ] - - -@pytest.fixture -def sampler_output(sample_outputs): - return SamplerOutput(outputs=sample_outputs) - - -def test_sampler_output_initialization(sampler_output, sample_outputs): - assert len(sampler_output) == len(sample_outputs) - assert sampler_output.sampled_token_probs is None - assert sampler_output.sampled_token_ids is None - - -def test_sampler_output_getitem(sampler_output, sample_outputs): - assert sampler_output[2] == sample_outputs[2] - - -def test_sampler_output_setitem(sampler_output): - new_output = CompletionSequenceGroupOutput(samples=[ - SequenceOutput(parent_seq_id=0, output_token=99, logprobs={}) - ], - prompt_logprobs=None) - sampler_output[2] = new_output - assert sampler_output[2] == new_output - - -def test_sampler_output_len(sampler_output, sample_outputs): - assert len(sampler_output) == len(sample_outputs) - - -def test_sampler_output_eq(sample_outputs): - sampler_output1 = SamplerOutput(outputs=sample_outputs) - sampler_output2 = SamplerOutput(outputs=sample_outputs.copy()) - sampler_output3 = SamplerOutput(outputs=sample_outputs[:-1]) - assert sampler_output1 == sampler_output2 - assert sampler_output1 != sampler_output3 - - -def test_sequence_data_prefill(): - seq_data = SequenceData.from_seqs([1, 2, 3, 4]) - assert seq_data.get_num_uncomputed_tokens() == 4 - assert seq_data.get_num_computed_tokens() == 0 - # advance by 2 - seq_data.update_num_computed_tokens(2) - assert seq_data.get_num_uncomputed_tokens() == 2 - assert seq_data.get_num_computed_tokens() == 2 - - # advance by 1 - seq_data.update_num_computed_tokens(1) - assert seq_data.get_num_uncomputed_tokens() == 1 - assert seq_data.get_num_computed_tokens() == 3 - - # append tokens and reset, simulating recompute - seq_data.append_token_id(1, logprob=0.0) - seq_data.reset_state_for_recompute() - assert seq_data.get_num_uncomputed_tokens() == 5 - assert seq_data.get_num_computed_tokens() == 0 
- - -def test_sequence_group_stage(): - _, seq_group = create_dummy_prompt("1", 12) - assert seq_group.is_prefill() is True - seq_group.update_num_computed_tokens(6) - assert seq_group.is_prefill() is True - seq_group.update_num_computed_tokens(5) - assert seq_group.is_prefill() is True - seq_group.update_num_computed_tokens(1) - assert seq_group.is_prefill() is False - seqs = seq_group.get_seqs() - assert len(seqs) == 1 - seqs[0].data.append_token_id(1, logprob=0.0) - for seq in seq_group.get_seqs(): - seq.reset_state_for_recompute() - assert seq_group.is_prefill() is True - seq_group.update_num_computed_tokens(5) - assert seq_group.is_prefill() is True - seq_group.update_num_computed_tokens(7) - assert seq_group.is_prefill() is True - seq_group.update_num_computed_tokens(1) - assert seq_group.is_prefill() is False +from vllm.sequence import IntermediateTensors def test_sequence_intermediate_tensors_equal(): diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index ea7ccfbb2b456..15ea55afe963b 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -11,7 +11,7 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer, from vllm.inputs import token_inputs from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer, @@ -61,14 +61,14 @@ def _run_incremental_decode(tokenizer, skip_special_tokens=skip_special_tokens, spaces_between_special_tokens=spaces_between_special_tokens, ) - request = EngineCoreRequest("", - prompt_token_ids, - None, - params, - None, - None, - 0.0, - None, + request = EngineCoreRequest(request_id="", + prompt_token_ids=prompt_token_ids, + mm_features=None, + sampling_params=params, + pooling_params=None, + eos_token_id=None, + arrival_time=0.0, + lora_request=None, cache_salt=None, data_parallel_rank=None) @@ -221,17 +221,14 @@ def test_oov_decode(tokenizer, fast): @pytest.fixture def detokenizer(tokenizer_name: str) -> Detokenizer: - tokenizer_group = TokenizerGroup( - tokenizer_id=tokenizer_name, - enable_lora=False, - max_num_seqs=100, - max_input_length=None, + tokenizer = get_tokenizer( + tokenizer_name, tokenizer_mode="mistral" if "mistral" in tokenizer_name else "auto", trust_remote_code=False, revision=None, ) - return Detokenizer(tokenizer_group) + return Detokenizer(tokenizer) @pytest.fixture(name="complete_sequence_token_ids") @@ -312,8 +309,7 @@ def test_decode_prompt_logprobs(complete_sequence: str, # don't support that. if complete_sequence not in SPECIAL_TOKS_TRUTH: skip_special_tokens = True - elif not isinstance(detokenizer.tokenizer_group.get_lora_tokenizer(None), - MistralTokenizer): + elif not isinstance(detokenizer.tokenizer, MistralTokenizer): skip_special_tokens = False else: pytest.skip("MistralTokenizers don't support " @@ -339,7 +335,7 @@ def test_decode_prompt_logprobs(complete_sequence: str, # decoded_prompt_logprobs doesn't contain the first token. 
token_ids = complete_sequence_token_ids - tokenizer = detokenizer.get_tokenizer_for_seq(seq) + tokenizer = detokenizer.tokenizer text_full = tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) text_first = tokenizer.decode(token_ids[0], diff --git a/tests/tokenization/test_tokenizer_group.py b/tests/tokenization/test_tokenizer_group.py deleted file mode 100644 index 0570c1525e111..0000000000000 --- a/tests/tokenization/test_tokenizer_group.py +++ /dev/null @@ -1,27 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -from transformers import AutoTokenizer, PreTrainedTokenizerBase - -from vllm.transformers_utils.tokenizer_group import TokenizerGroup - - -@pytest.mark.asyncio -async def test_tokenizer_group(): - reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") - tokenizer_group = TokenizerGroup( - tokenizer_id="gpt2", - enable_lora=False, - max_num_seqs=1, - max_input_length=None, - ) - assert reference_tokenizer.encode("prompt") == tokenizer_group.encode( - prompt="prompt", lora_request=None) - assert reference_tokenizer.encode( - "prompt") == await tokenizer_group.encode_async(prompt="prompt", - lora_request=None) - assert isinstance(tokenizer_group.get_lora_tokenizer(None), - PreTrainedTokenizerBase) - assert tokenizer_group.get_lora_tokenizer( - None) == await tokenizer_group.get_lora_tokenizer_async(None) diff --git a/tests/tokenization/test_tokenizer_registry.py b/tests/tokenization/test_tokenizer_registry.py index 5abb101644086..68d4b416b4c99 100644 --- a/tests/tokenization/test_tokenizer_registry.py +++ b/tests/tokenization/test_tokenizer_registry.py @@ -57,6 +57,10 @@ class TestTokenizer(TokenizerBase): def max_token_id(self) -> int: raise NotImplementedError() + @property + def truncation_side(self) -> str: + raise NotImplementedError() + def __call__( self, text: Union[str, list[str], list[int]], diff --git a/tests/tool_use/test_openai_tool_parser.py b/tests/tool_use/test_openai_tool_parser.py new file mode 100644 index 0000000000000..0192c7d2765cd --- /dev/null +++ b/tests/tool_use/test_openai_tool_parser.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json + +import pytest +from openai_harmony import (Conversation, DeveloperContent, + HarmonyEncodingName, Message, Role, SystemContent, + load_harmony_encoding) + +from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall +from vllm.entrypoints.openai.tool_parsers import OpenAIToolParser +from vllm.transformers_utils.tokenizer import get_tokenizer + +MODEL = "gpt2" + + +@pytest.fixture(scope="module") +def openai_tokenizer(): + # The parser does not use the tokenizer, but the constructor requires it. 
+ return get_tokenizer(MODEL) + + +@pytest.fixture +def openai_tool_parser(openai_tokenizer): + return OpenAIToolParser(openai_tokenizer) + + +@pytest.fixture(scope="module") +def harmony_encoding(): + return load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) + + +def assert_tool_calls( + actual_tool_calls: list[ToolCall], + expected_tool_calls: list[ToolCall], +): + assert len(actual_tool_calls) == len(expected_tool_calls) + + for actual_tool_call, expected_tool_call in zip(actual_tool_calls, + expected_tool_calls): + assert isinstance(actual_tool_call.id, str) + assert len(actual_tool_call.id) > 16 # Default from protocol.py + assert actual_tool_call.type == "function" + assert actual_tool_call.function == expected_tool_call.function + + +def test_extract_tool_calls_no_tools(openai_tool_parser, harmony_encoding): + convo = Conversation.from_messages([ + Message.from_role_and_content( + Role.SYSTEM, + SystemContent.new(), + ), + Message.from_role_and_content( + Role.DEVELOPER, + DeveloperContent.new().with_instructions("Talk like a pirate!")), + Message.from_role_and_content(Role.USER, "Arrr, how be you?"), + Message.from_role_and_content(Role.ASSISTANT, + "This is a test").with_channel("final") + ]) + token_ids = harmony_encoding.render_conversation_for_completion( + convo, Role.ASSISTANT) + extracted_info = openai_tool_parser.extract_tool_calls( + "", + request=None, + token_ids=token_ids, + ) + assert not extracted_info.tools_called + assert extracted_info.tool_calls == [] + assert extracted_info.content == "This is a test" + + +def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding): + convo = Conversation.from_messages([ + Message.from_role_and_content(Role.USER, + "What is the weather in Tokyo?"), + Message.from_role_and_content( + Role.ASSISTANT, + 'User asks: "What is the weather in Tokyo?" We need to use get_current_weather tool.', # noqa: E501 + ).with_channel("analysis"), + Message.from_role_and_content( + Role.ASSISTANT, + '{"location": "Tokyo"}').with_channel("commentary").with_recipient( + "functions.get_current_weather").with_content_type("json"), + ]) + token_ids = harmony_encoding.render_conversation_for_completion( + convo, Role.ASSISTANT) + + extracted_info = openai_tool_parser.extract_tool_calls( + "", + request=None, + token_ids=token_ids, + ) + assert extracted_info.tools_called + expected_tool_calls = [ + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({"location": "Tokyo"}), + )) + ] + assert_tool_calls(extracted_info.tool_calls, expected_tool_calls) + assert extracted_info.content is None + + +def test_extract_tool_calls_multiple_tools( + openai_tool_parser, + harmony_encoding, +): + convo = Conversation.from_messages([ + Message.from_role_and_content( + Role.USER, "What is the weather in Tokyo based on where I'm at?"), + Message.from_role_and_content( + Role.ASSISTANT, + 'User asks: "What is the weather in Tokyo?" based on their location. 
We need to use get_current_weather tool and get_user_location tool.', # noqa: E501 + ).with_channel("analysis"), + Message.from_role_and_content( + Role.ASSISTANT, + '{"location": "Tokyo"}').with_channel("commentary").with_recipient( + "functions.get_current_weather").with_content_type("json"), + Message.from_role_and_content( + Role.ASSISTANT, + '{"location": "Tokyo"}').with_channel("commentary").with_recipient( + "functions.get_user_location").with_content_type("json"), + ]) + token_ids = harmony_encoding.render_conversation_for_completion( + convo, + Role.ASSISTANT, + ) + + extracted_info = openai_tool_parser.extract_tool_calls( + "", + request=None, + token_ids=token_ids, + ) + assert extracted_info.tools_called + expected_tool_calls = [ + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({"location": "Tokyo"}), + )), + ToolCall(function=FunctionCall( + name="get_user_location", + arguments=json.dumps({"location": "Tokyo"}), + )) + ] + assert_tool_calls(extracted_info.tool_calls, expected_tool_calls) + assert extracted_info.content is None diff --git a/tests/tool_use/test_tool_choice_required.py b/tests/tool_use/test_tool_choice_required.py index e0ed221a93e12..130e9547bdccb 100644 --- a/tests/tool_use/test_tool_choice_required.py +++ b/tests/tool_use/test_tool_choice_required.py @@ -68,7 +68,7 @@ EXAMPLE_TOOLS = [ def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output, should_match: bool): self = MagicMock(tool_choice="required", tools=tools) - schema = ChatCompletionRequest._get_guided_json_from_tool(self) + schema = ChatCompletionRequest._get_json_schema_from_tool(self) assert isinstance(schema, dict) # use build_regex_from_schema used in JSONLogitsProcessor to create Guide @@ -218,7 +218,7 @@ VALID_TOOLS = [t[0] for t in VALID_TOOL_OUTPUTS] } }, {}], False), ]) -def test_guided_json(sample_output, should_match): +def test_structured_outputs_json(sample_output, should_match): _compile_and_check(tools=TypeAdapter( list[ChatCompletionToolsParam]).validate_python(EXAMPLE_TOOLS), sample_output=sample_output, @@ -273,8 +273,9 @@ def update_parameters_empty_dict( @pytest.mark.parametrize( "update_parameters", [update_parameters_none, update_parameters_empty_dict]) -def test_guided_json_without_parameters(sample_output, should_match, - update_parameters): +def test_structured_outputs_json_without_parameters(sample_output, + should_match, + update_parameters): updated_tools = [deepcopy(EXAMPLE_TOOLS[0])] tools = TypeAdapter( list[ChatCompletionToolsParam]).validate_python(updated_tools) @@ -334,4 +335,4 @@ def test_streaming_output_valid(output, empty_params, delta_len): combined_messages += message.tool_calls[0].function.arguments combined_messages += "}]" assert json.loads(combined_messages) == output - assert json.dumps(json.loads(combined_messages)) == output_json \ No newline at end of file + assert json.dumps(json.loads(combined_messages)) == output_json diff --git a/tests/tpu/test_quantization_accuracy.py b/tests/tpu/test_quantization_accuracy.py index 6cefbae4bdd18..8d9fbd280317c 100644 --- a/tests/tpu/test_quantization_accuracy.py +++ b/tests/tpu/test_quantization_accuracy.py @@ -28,7 +28,7 @@ ACCURACY_CONFIGS = [ expected_value=0.76), # no bias # NOTE(rob): We cannot re-initialize vLLM in the same process for TPU, # so only one of these tests can run in a single call to pytest. As - # a follow up, move this into the LM-EVAL section of the CI. + # a follow-up, move this into the LM-EVAL section of the CI. 
# GSM8KAccuracyTestConfig( # model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8", # expected_value=0.66), # bias in QKV layers diff --git a/tests/tracing/test_tracing.py b/tests/tracing/test_tracing.py deleted file mode 100644 index 4dbae7c15de3a..0000000000000 --- a/tests/tracing/test_tracing.py +++ /dev/null @@ -1,237 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# ruff: noqa -# type: ignore -from __future__ import annotations - -import threading -from collections.abc import Iterable -from concurrent import futures -from typing import Callable, Generator, Literal - -import grpc -import pytest -from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ( - ExportTraceServiceResponse) -from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import ( - TraceServiceServicer, add_TraceServiceServicer_to_server) -from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue -from opentelemetry.sdk.environment_variables import ( - OTEL_EXPORTER_OTLP_TRACES_INSECURE) - -from vllm import LLM, SamplingParams -from vllm.tracing import SpanAttributes - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch: pytest.MonkeyPatch): - """ - Since this module is V0 only, set VLLM_USE_V1=0 for - all tests in the module. - """ - with monkeypatch.context() as m: - m.setenv('VLLM_USE_V1', '0') - yield - - -FAKE_TRACE_SERVER_ADDRESS = "localhost:4317" - -FieldName = Literal['bool_value', 'string_value', 'int_value', 'double_value', - 'array_value'] - - -def decode_value(value: AnyValue): - field_decoders: dict[FieldName, Callable] = { - "bool_value": (lambda v: v.bool_value), - "string_value": (lambda v: v.string_value), - "int_value": (lambda v: v.int_value), - "double_value": (lambda v: v.double_value), - "array_value": - (lambda v: [decode_value(item) for item in v.array_value.values]), - } - for field, decoder in field_decoders.items(): - if value.HasField(field): - return decoder(value) - raise ValueError(f"Couldn't decode value: {value}") - - -def decode_attributes(attributes: Iterable[KeyValue]): - return {kv.key: decode_value(kv.value) for kv in attributes} - - -class FakeTraceService(TraceServiceServicer): - - def __init__(self): - self.request = None - self.evt = threading.Event() - - def Export(self, request, context): - self.request = request - self.evt.set() - return ExportTraceServiceResponse() - - -@pytest.fixture -def trace_service() -> Generator[FakeTraceService, None, None]: - """Fixture to set up a fake gRPC trace service""" - server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) - service = FakeTraceService() - add_TraceServiceServicer_to_server(service, server) - server.add_insecure_port(FAKE_TRACE_SERVER_ADDRESS) - server.start() - - yield service - - server.stop(None) - - -def test_traces( - monkeypatch: pytest.MonkeyPatch, - trace_service: FakeTraceService, -): - with monkeypatch.context() as m: - m.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true") - - sampling_params = SamplingParams( - temperature=0.01, - top_p=0.1, - max_tokens=256, - ) - model = "facebook/opt-125m" - llm = LLM( - model=model, - otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS, - ) - prompts = ["This is a short prompt"] - outputs = llm.generate(prompts, sampling_params=sampling_params) - - timeout = 5 - if not trace_service.evt.wait(timeout): - raise TimeoutError( - f"The fake trace service didn't receive a trace within " - f"the {timeout} seconds timeout") - - request = 
trace_service.request - assert len(request.resource_spans) == 1, ( - f"Expected 1 resource span, " - f"but got {len(request.resource_spans)}") - assert len(request.resource_spans[0].scope_spans) == 1, ( - f"Expected 1 scope span, " - f"but got {len(request.resource_spans[0].scope_spans)}") - assert len(request.resource_spans[0].scope_spans[0].spans) == 1, ( - f"Expected 1 span, " - f"but got {len(request.resource_spans[0].scope_spans[0].spans)}") - - attributes = decode_attributes( - request.resource_spans[0].scope_spans[0].spans[0].attributes) - assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model - assert attributes.get( - SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id - assert attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE - ) == sampling_params.temperature - assert attributes.get( - SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p - assert attributes.get(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS - ) == sampling_params.max_tokens - assert attributes.get( - SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n - assert attributes.get( - SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len( - outputs[0].prompt_token_ids) - completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs) - assert attributes.get( - SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens - metrics = outputs[0].metrics - assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE - ) == metrics.time_in_queue - ttft = metrics.first_token_time - metrics.arrival_time - assert attributes.get( - SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) == ttft - e2e_time = metrics.finished_time - metrics.arrival_time - assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) == e2e_time - assert metrics.scheduler_time > 0 - assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER - ) == metrics.scheduler_time - # Model forward and model execute should be none, since detailed traces is - # not enabled. 
- assert metrics.model_forward_time is None - assert metrics.model_execute_time is None - - -def test_traces_with_detailed_steps( - monkeypatch: pytest.MonkeyPatch, - trace_service: FakeTraceService, -): - with monkeypatch.context() as m: - m.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true") - - sampling_params = SamplingParams( - temperature=0.01, - top_p=0.1, - max_tokens=256, - ) - model = "facebook/opt-125m" - llm = LLM( - model=model, - otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS, - collect_detailed_traces=["all"], - ) - prompts = ["This is a short prompt"] - outputs = llm.generate(prompts, sampling_params=sampling_params) - - timeout = 5 - if not trace_service.evt.wait(timeout): - raise TimeoutError( - f"The fake trace service didn't receive a trace within " - f"the {timeout} seconds timeout") - - request = trace_service.request - assert len(request.resource_spans) == 1, ( - f"Expected 1 resource span, " - f"but got {len(request.resource_spans)}") - assert len(request.resource_spans[0].scope_spans) == 1, ( - f"Expected 1 scope span, " - f"but got {len(request.resource_spans[0].scope_spans)}") - assert len(request.resource_spans[0].scope_spans[0].spans) == 1, ( - f"Expected 1 span, " - f"but got {len(request.resource_spans[0].scope_spans[0].spans)}") - - attributes = decode_attributes( - request.resource_spans[0].scope_spans[0].spans[0].attributes) - assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model - assert attributes.get( - SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id - assert attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE - ) == sampling_params.temperature - assert attributes.get( - SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p - assert attributes.get(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS - ) == sampling_params.max_tokens - assert attributes.get( - SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n - assert attributes.get( - SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len( - outputs[0].prompt_token_ids) - completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs) - assert attributes.get( - SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens - metrics = outputs[0].metrics - assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE - ) == metrics.time_in_queue - ttft = metrics.first_token_time - metrics.arrival_time - assert attributes.get( - SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) == ttft - e2e_time = metrics.finished_time - metrics.arrival_time - assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) == e2e_time - assert metrics.scheduler_time > 0 - assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER - ) == metrics.scheduler_time - assert metrics.model_forward_time > 0 - assert attributes.get( - SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD - ) == pytest.approx(metrics.model_forward_time / 1000) - assert metrics.model_execute_time > 0 - assert attributes.get( - SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE - ) == metrics.model_execute_time - assert metrics.model_forward_time < 1000 * metrics.model_execute_time diff --git a/tests/mq_llm_engine/__init__.py b/tests/transformers_utils/__init__.py similarity index 100% rename from tests/mq_llm_engine/__init__.py rename to tests/transformers_utils/__init__.py diff --git a/tests/transformers_utils/test_config_parser_registry.py b/tests/transformers_utils/test_config_parser_registry.py new file mode 100644 index 0000000000000..13c654e05d2ac --- /dev/null +++ 
b/tests/transformers_utils/test_config_parser_registry.py @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from pathlib import Path +from typing import Optional, Union + +import pytest +from transformers import PretrainedConfig + +from vllm.transformers_utils.config import (get_config_parser, + register_config_parser) +from vllm.transformers_utils.config_parser_base import ConfigParserBase + + +@register_config_parser("custom_config_parser") +class CustomConfigParser(ConfigParserBase): + + def parse(self, + model: Union[str, Path], + trust_remote_code: bool, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + **kwargs) -> tuple[dict, PretrainedConfig]: + raise NotImplementedError + + +def test_register_config_parser(): + assert isinstance(get_config_parser("custom_config_parser"), + CustomConfigParser) + + +def test_invalid_config_parser(): + with pytest.raises(ValueError): + + @register_config_parser("invalid_config_parser") + class InvalidConfigParser: + pass diff --git a/tests/utils.py b/tests/utils.py index 9d2073f3c1036..9a27c3de4533d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,21 +2,24 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio +import contextlib import copy import functools import importlib import json import os +import random import signal import subprocess import sys import tempfile import time import warnings -from contextlib import contextmanager, suppress +from contextlib import ExitStack, contextmanager, suppress from multiprocessing import Process from pathlib import Path from typing import Any, Callable, Literal, Optional, Union +from unittest.mock import patch import cloudpickle import httpx @@ -799,43 +802,106 @@ _P = ParamSpec("_P") def fork_new_process_for_each_test( - f: Callable[_P, None]) -> Callable[_P, None]: + func: Callable[_P, None]) -> Callable[_P, None]: """Decorator to fork a new process for each test function. See https://github.com/vllm-project/vllm/issues/7053 for more details. """ - @functools.wraps(f) + @functools.wraps(func) def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: # Make the process the leader of its own process group # to avoid sending SIGTERM to the parent process os.setpgrp() from _pytest.outcomes import Skipped - pid = os.fork() - print(f"Fork a new process to run a test {pid}") - if pid == 0: - try: - f(*args, **kwargs) - except Skipped as e: - # convert Skipped to exit code 0 - print(str(e)) - os._exit(0) - except Exception: - import traceback - traceback.print_exc() - os._exit(1) + + # Create a unique temporary file to store exception info from child + # process. Use test function name and process ID to avoid collisions. + with tempfile.NamedTemporaryFile( + delete=False, + mode='w+b', + prefix=f"vllm_test_{func.__name__}_{os.getpid()}_", + suffix=".exc") as exc_file, ExitStack() as delete_after: + exc_file_path = exc_file.name + delete_after.callback(os.remove, exc_file_path) + + pid = os.fork() + print(f"Fork a new process to run a test {pid}") + if pid == 0: + # Parent process responsible for deleting, don't delete + # in child. 
+ delete_after.pop_all() + try: + func(*args, **kwargs) + except Skipped as e: + # convert Skipped to exit code 0 + print(str(e)) + os._exit(0) + except Exception as e: + import traceback + tb_string = traceback.format_exc() + + # Try to serialize the exception object first + exc_to_serialize: dict[str, Any] + try: + # First, try to pickle the actual exception with + # its traceback. + exc_to_serialize = {'pickled_exception': e} + # Test if it can be pickled + cloudpickle.dumps(exc_to_serialize) + except (Exception, KeyboardInterrupt): + # Fall back to string-based approach. + exc_to_serialize = { + 'exception_type': type(e).__name__, + 'exception_msg': str(e), + 'traceback': tb_string, + } + try: + with open(exc_file_path, 'wb') as f: + cloudpickle.dump(exc_to_serialize, f) + except Exception: + # Fallback: just print the traceback. + print(tb_string) + os._exit(1) + else: + os._exit(0) else: - os._exit(0) - else: - pgid = os.getpgid(pid) - _pid, _exitcode = os.waitpid(pid, 0) - # ignore SIGTERM signal itself - old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN) - # kill all child processes - os.killpg(pgid, signal.SIGTERM) - # restore the signal handler - signal.signal(signal.SIGTERM, old_signal_handler) - assert _exitcode == 0, (f"function {f} failed when called with" - f" args {args} and kwargs {kwargs}") + pgid = os.getpgid(pid) + _pid, _exitcode = os.waitpid(pid, 0) + # ignore SIGTERM signal itself + old_signal_handler = signal.signal(signal.SIGTERM, + signal.SIG_IGN) + # kill all child processes + os.killpg(pgid, signal.SIGTERM) + # restore the signal handler + signal.signal(signal.SIGTERM, old_signal_handler) + if _exitcode != 0: + # Try to read the exception from the child process + exc_info = {} + if os.path.exists(exc_file_path): + with contextlib.suppress(Exception), \ + open(exc_file_path, 'rb') as f: + exc_info = cloudpickle.load(f) + + if (original_exception := + exc_info.get('pickled_exception')) is not None: + # Re-raise the actual exception object if it was + # successfully pickled. + assert isinstance(original_exception, Exception) + raise original_exception + + if (original_tb := exc_info.get("traceback")) is not None: + # Use string-based traceback for fallback case + raise AssertionError( + f"Test {func.__name__} failed when called with" + f" args {args} and kwargs {kwargs}" + f" (exit code: {_exitcode}):\n{original_tb}" + ) from None + + # Fallback to the original generic error + raise AssertionError( + f"function {func.__name__} failed when called with" + f" args {args} and kwargs {kwargs}" + f" (exit code: {_exitcode})") from None return wrapper @@ -1077,3 +1143,57 @@ def get_attn_backend_list_based_on_platform() -> list[str]: return attn_backend_list else: raise ValueError("Unsupported platform") + + +@contextmanager +def override_cutlass_fp8_supported(value: bool): + with patch( + "vllm.model_executor.layers.quantization.utils.w8a8_utils.cutlass_fp8_supported", + return_value=value): + yield + + +def prep_prompts(batch_size: int, ln_range: tuple[int, int] = (800, 1100)): + """ + Generate prompts with a bunch of assignments, + then ask for the value of one of them. + The prompt is just under 10k tokens; sliding window is 4k + so the answer is outside sliding window, but should still be correct.
+ Args: + batch_size: number of prompts to generate + ln_range: an argument to control the length of the prompt + """ + prompts: list[str] = [] + answer: list[int] = [] + indices: list[int] = [] + random.seed(1) + for _ in range(batch_size): + idx = random.randint(30, 90) + indices.append(idx) + prompt = "```python\n# We set a number of variables, " + \ + f"x{idx} will be important later\n" + ln = random.randint(*ln_range) + for k in range(30, ln): + v = random.randint(10, 99) + if k == idx: + answer.append(v) + prompt += f"x{k} = {v}\n" + prompt += f"# Now, we check the value of x{idx}:\n" + prompt += f"assert x{idx} == " + prompts.append(prompt) + return prompts, answer, indices + + +def check_answers(indices: list[int], + answer: list[int], + outputs: list[str], + accept_rate: float = 0.7): + answer2 = [int(text[0:2].strip()) for text in outputs] + print(list(zip(indices, zip(answer, answer2)))) + numok = 0 + for a1, a2 in zip(answer, answer2): + if a1 == a2: + numok += 1 + frac_ok = numok / len(answer) + print(f"Num OK: {numok}/{len(answer)} {frac_ok}") + assert frac_ok >= accept_rate diff --git a/tests/utils_/test_utils.py b/tests/utils_/test_utils.py index 66124dd854ee0..608f517f69145 100644 --- a/tests/utils_/test_utils.py +++ b/tests/utils_/test_utils.py @@ -501,34 +501,6 @@ def test_bind_kv_cache_non_attention(): assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[1] -def test_bind_kv_cache_encoder_decoder(monkeypatch: pytest.MonkeyPatch): - # V1 TESTS: ENCODER_DECODER is not supported on V1 yet. - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "0") - - from vllm.attention import Attention, AttentionType - - # example from bart - ctx = { - 'encoder.layers.0.self_attn.attn': - Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER), - 'decoder.layers.0.encoder_attn.attn': - Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER_DECODER), - 'decoder.layers.0.self_attn.attn': - Attention(32, 128, 0.1, attn_type=AttentionType.DECODER), - } - - kv_cache = [ - torch.zeros((1, )), - ] - encoder_kv_cache = ctx['encoder.layers.0.self_attn.attn'].kv_cache - - bind_kv_cache(ctx, [kv_cache]) - assert ctx['encoder.layers.0.self_attn.attn'].kv_cache is encoder_kv_cache - assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache[0] is kv_cache[0] - assert ctx['decoder.layers.0.self_attn.attn'].kv_cache[0] is kv_cache[0] - - def test_bind_kv_cache_pp(): with patch("vllm.utils.cuda_device_count_stateless", lambda: 2): # this test runs with 1 GPU, but we simulate 2 GPUs @@ -835,22 +807,20 @@ def test_model_specification(parser_with_config, cli_config_file, @pytest.mark.parametrize("input", [(), ("abc", ), (None, ), (None, bool, [1, 2, 3])]) -@pytest.mark.parametrize("output", [0, 1, 2]) -def test_sha256(input: tuple, output: int): - hash = sha256(input) - assert hash is not None - assert isinstance(hash, int) - assert hash != 0 +def test_sha256(input: tuple): + digest = sha256(input) + assert digest is not None + assert isinstance(digest, bytes) + assert digest != b"" - bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) - assert hash == int.from_bytes(hashlib.sha256(bytes).digest(), - byteorder="big") + input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) + assert digest == hashlib.sha256(input_bytes).digest() # hashing again, returns the same value - assert hash == sha256(input) + assert digest == sha256(input) # hashing different input, returns different value - assert hash != sha256(input + (1, )) + assert digest != sha256(input + (1, )) 
@pytest.mark.parametrize( diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index e4c07aae0ebed..0b7e103beca63 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -70,22 +70,6 @@ BATCH_SPECS = { } -def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec, - device: torch.device, - num_blocks: int = 100) -> torch.Tensor: - """Create a dummy KV cache tensor for testing.""" - kv_cache = torch.randn( - 2, # K and V - num_blocks, - kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, - kv_cache_spec.head_size, - dtype=_convert_dtype_to_torch(kv_cache_spec.dtype), - device=device, - ) - return kv_cache - - def create_and_prepopulate_kv_cache( k_contexts: list[torch.Tensor], v_contexts: list[torch.Tensor], @@ -194,6 +178,7 @@ class MockAttentionLayer: self._k_scale = torch.tensor(1.0, device=device) self._v_scale = torch.tensor(1.0, device=device) # Add float versions for flashinfer + self._q_scale_float = 1.0 self._k_scale_float = 1.0 self._v_scale_float = 1.0 diff --git a/tests/v1/attention/test_attention_splitting.py b/tests/v1/attention/test_attention_splitting.py index 3fc1011d5042e..c74dbb3ebb17e 100644 --- a/tests/v1/attention/test_attention_splitting.py +++ b/tests/v1/attention/test_attention_splitting.py @@ -6,7 +6,7 @@ import torch from tests.v1.attention.test_attention_backends import BATCH_SPECS from tests.v1.attention.utils import create_common_attn_metadata -from vllm.v1.attention.backends.utils import (UbatchSlice, +from vllm.v1.attention.backends.utils import (UBatchSlice, _make_metadata_with_slice, slice_query_start_locs, split_attn_metadata) @@ -106,7 +106,7 @@ def mixed_small_metadata(): def test_make_metadata_with_slice_decode_batch(small_decode_metadata): """Test slicing decode batch metadata""" # Split first request only - ubatch_slice = UbatchSlice(slice(0, 1), slice(0, 1)) + ubatch_slice = UBatchSlice(slice(0, 1), slice(0, 1)) result = _make_metadata_with_slice(ubatch_slice, small_decode_metadata) @@ -120,7 +120,7 @@ def test_make_metadata_with_slice_decode_batch(small_decode_metadata): def test_make_metadata_with_slice_mixed_batch(mixed_small_metadata): """Test slicing mixed batch metadata""" - ubatch_slice = UbatchSlice(slice(1, 3), + ubatch_slice = UBatchSlice(slice(1, 3), slice(1, 7)) # Requests 1-3, tokens 1-7 result = _make_metadata_with_slice(ubatch_slice, mixed_small_metadata) @@ -137,8 +137,8 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata): num_tokens = large_decode_metadata.num_reqs mid_point = num_tokens // 2 ubatch_slices = [ - UbatchSlice(slice(0, mid_point), slice(0, mid_point)), - UbatchSlice(slice(mid_point, num_tokens), slice(mid_point, + UBatchSlice(slice(0, mid_point), slice(0, mid_point)), + UBatchSlice(slice(mid_point, num_tokens), slice(mid_point, num_tokens)), ] diff --git a/tests/v1/attention/test_chunked_local_attention.py b/tests/v1/attention/test_chunked_local_attention.py index 8c5a63653db9f..be77256a0d2f0 100644 --- a/tests/v1/attention/test_chunked_local_attention.py +++ b/tests/v1/attention/test_chunked_local_attention.py @@ -160,7 +160,7 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData): # Use torch.arange instead of torch.randint so we can assert on # block table tensor values. 
The block table will have shape # (num_batches, cdiv(max_seq_len, block_size)) and the values will be - # aranged from 0 to cdiv(max_seq_len, block_size)-1 + # arranged from 0 to cdiv(max_seq_len, block_size)-1 arange_block_indices=True, ) diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index 24070358799ef..a62993950affe 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -15,7 +15,7 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import FullAttentionSpec BACKENDS_TO_TEST = [ - _Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1, + _Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1, _Backend.FLASH_ATTN_MLA, _Backend.TRITON_MLA_VLLM_V1 ] @@ -69,25 +69,10 @@ BATCH_SPECS = { } -def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec, - device: torch.device, - num_blocks: int = 100) -> torch.Tensor: - """Create a dummy KV cache tensor for testing.""" - kv_cache = torch.randn( - num_blocks, - kv_cache_spec.block_size, - kv_cache_spec.head_size, # latent dimension - dtype=_convert_dtype_to_torch(kv_cache_spec.dtype), - device=device, - ) - return kv_cache - - def create_and_prepopulate_kv_cache( kv_c_contexts: list[torch.Tensor], k_pe_contexts: list[torch.Tensor], block_size: int, - num_kv_heads: int, head_size: int, dtype: torch.dtype, device: torch.device, @@ -101,7 +86,6 @@ def create_and_prepopulate_kv_cache( k_pe_contexts: List of key positional embedding context tensors for each sequence block_size: Size of each block - num_kv_heads: Number of KV heads (should be 1 for MLA) head_size: Size of each head (latent dimension) dtype: Data type for the cache device: Device to create the cache on @@ -299,8 +283,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): query_lens = batch_spec.query_lens num_q_heads = vllm_config.model_config.get_num_attention_heads( vllm_config.parallel_config) - num_kv_heads = vllm_config.model_config.get_num_kv_heads( - vllm_config.parallel_config) head_size = vllm_config.model_config.get_head_size() dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) block_size = vllm_config.cache_config.block_size @@ -315,7 +297,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): # 2. 
Generate data and compute SDPA reference output for MLA all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], [] - all_sdpa_outputs = [] + all_sdpa_outputs: list[list[torch.Tensor]] = [] kv_c_contexts, k_pe_contexts = [], [] # Create shared MLA weight matrices for consistency across all sequences @@ -331,6 +313,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): device=device) kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1) + for i, backend in enumerate(BACKENDS_TO_TEST): + all_sdpa_outputs.append([]) + for i in range(batch_size): s_len = seq_lens[i] q_len = query_lens[i] @@ -358,85 +343,93 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): dtype=dtype, device=device) - # Determine if this is decode (single token) - # or prefill (multiple tokens) - is_decode = q_len == 1 + # Determine if this is decode or prefill + is_decode = [] + for i, backend in enumerate(BACKENDS_TO_TEST): + builder_cls, _ = get_attention_backend(backend) + is_decode.append(q_len <= builder_cls.reorder_batch_threshold) # Split q into nope and rope components q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) - if is_decode: - # Decode path: MQA-style attention in latent space - # Transform q_nope to latent space: q_nope @ W_UK - # q_nope: [1, num_heads, qk_nope_head_dim] - # W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim] - ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, - W_UK) # [1, num_heads, kv_lora_rank] + ####################################################### + # Decode path: MQA-style attention in latent space + # Transform q_nope to latent space: q_nope @ W_UK + # q_nope: [1, num_heads, qk_nope_head_dim] + # W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim] + ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, + W_UK) # [1, num_heads, kv_lora_rank] - # Build MQA attention inputs - # Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim] - q_mqa = torch.cat([ql_nope, q_pe], dim=-1) - # K: [s_len, kv_lora_rank + qk_rope_head_dim] - # (broadcasted to all heads) - k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1) - k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1) - # V: [s_len, kv_lora_rank] (broadcasted to all heads) - v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1) + # Build MQA attention inputs + # Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim] + q_mqa = torch.cat([ql_nope, q_pe], dim=-1) + # K: [s_len, kv_lora_rank + qk_rope_head_dim] + # (broadcasted to all heads) + k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1) + k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1) + # V: [s_len, kv_lora_rank] (broadcasted to all heads) + v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1) - # SDPA expects (N, H, L, D) - q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2) - k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2) - v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2) + # Create custom attention mask for decode path: + # - Query tokens can attend to all context tokens + # - Query tokens can only attend to query tokens up to their position + attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device) + # Apply causal mask only to the query portion (context_len onwards) + causal_mask = torch.tril(torch.ones(q_len, q_len, device=device)) + attn_mask[:, context_len:] = causal_mask - sdpa_out_i = torch.nn.functional.scaled_dot_product_attention( - q_sdpa_in, k_sdpa_in, v_sdpa_in, is_causal=False, scale=scale) - sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze( - 0) # [1, num_heads, kv_lora_rank] + # SDPA 
expects (N, H, L, D) + q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2) + k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2) + v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2) - # Project back to output space: sdpa_out @ W_UV - sdpa_out_i = torch.einsum("qnl,lnv->qnv", sdpa_out_i, W_UV) - sdpa_out_i = sdpa_out_i.flatten(start_dim=-2) - else: - # Prefill path: MHA-style attention with full sequence - # Apply kv_b_proj to the full kv_c tensor - kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full, - kv_b_proj_weight) - k_nope_full, v_full = kv_nope_full.split( - [qk_nope_head_dim, v_head_dim], dim=-1) + sdpa_out_i_decode = torch.nn.functional.scaled_dot_product_attention( + q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale) + sdpa_out_i_decode = sdpa_out_i_decode.transpose(1, 2).squeeze( + 0) # [1, num_heads, kv_lora_rank] - # Build attention inputs for full sequence - q_mha = torch.cat([q_nope, q_pe], - dim=-1) # [q_len, num_heads, total_dim] - k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1) - k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1) + # Project back to output space: sdpa_out @ W_UV + sdpa_out_i_decode = torch.einsum("qnl,lnv->qnv", sdpa_out_i_decode, + W_UV) + sdpa_out_i_decode = sdpa_out_i_decode.flatten(start_dim=-2) - # Create custom attention mask: - # - Query tokens can attend to all context tokens - # - Query tokens can only attend to query tokens up to their pos - attn_mask = torch.ones(q_len, - s_len, - dtype=torch.bool, - device=device) - # Apply causal mask only to the query portion (context_len onwards) - causal_mask = torch.tril(torch.ones(q_len, q_len, device=device)) - attn_mask[:, context_len:] = causal_mask + ####################################################### + # Prefill path: MHA-style attention with full sequence + # Apply kv_b_proj to the full kv_c tensor + kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full, kv_b_proj_weight) + k_nope_full, v_full = kv_nope_full.split( + [qk_nope_head_dim, v_head_dim], dim=-1) - # SDPA expects (N, H, L, D) - q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2) - k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2) - v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2) + # Build attention inputs for full sequence + q_mha = torch.cat([q_nope, q_pe], + dim=-1) # [q_len, num_heads, total_dim] + k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1) + k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1) - # Single attention call with custom mask - sdpa_out_i = torch.nn.functional.scaled_dot_product_attention( - q_sdpa_in, - k_sdpa_in, - v_sdpa_in, - attn_mask=attn_mask, - scale=scale) - sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze(0) - sdpa_out_i = sdpa_out_i.flatten(start_dim=-2) + # Create custom attention mask: + # - Query tokens can attend to all context tokens + # - Query tokens can only attend to query tokens up to their pos + attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device) + # Apply causal mask only to the query portion (context_len onwards) + causal_mask = torch.tril(torch.ones(q_len, q_len, device=device)) + attn_mask[:, context_len:] = causal_mask - all_sdpa_outputs.append(sdpa_out_i) + # SDPA expects (N, H, L, D) + q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2) + k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2) + v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2) + + # Single attention call with custom mask + sdpa_out_i_prefill = torch.nn.functional.scaled_dot_product_attention( + q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale) + sdpa_out_i_prefill 
= sdpa_out_i_prefill.transpose(1, 2).squeeze(0) + sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2) + + for i, backend in enumerate(BACKENDS_TO_TEST): + if is_decode[i]: + all_sdpa_outputs[i].append(sdpa_out_i_decode) + else: + all_sdpa_outputs[i].append(sdpa_out_i_prefill) # Inputs for vLLM MLA backends are just the new tokens all_q_vllm.append(q_c) @@ -451,7 +444,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): query_vllm = torch.cat(all_q_vllm, dim=0) kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0) k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0) - sdpa_output = torch.cat(all_sdpa_outputs, dim=0) + sdpa_outputs = [] + for i, backend in enumerate(BACKENDS_TO_TEST): + sdpa_outputs.append(torch.cat(all_sdpa_outputs[i], dim=0)) # Create mock kv_b_proj using the same weights as reference implementation from vllm.model_executor.layers.linear import ColumnParallelLinear @@ -477,7 +472,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): kv_c_contexts=kv_c_contexts, k_pe_contexts=k_pe_contexts, block_size=block_size, - num_kv_heads=num_kv_heads, head_size=head_size, dtype=dtype, device=device, @@ -486,7 +480,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): randomize_blocks=True) # 4. Run vLLM backends and compare - for backend_name in BACKENDS_TO_TEST: + for i, backend_name in enumerate(BACKENDS_TO_TEST): backend_output = run_attention_backend( backend_name, kv_cache_spec, ["placeholder"], vllm_config, device, common_attn_metadata, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache, @@ -494,12 +488,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): mock_kv_b_proj) # Check shape and dtype consistency - assert backend_output.shape == sdpa_output.shape, ( + assert backend_output.shape == sdpa_outputs[i].shape, ( f"[{backend_name}] shape {backend_output.shape} != " - f"SDPA shape {sdpa_output.shape}") - assert backend_output.dtype == sdpa_output.dtype, ( + f"SDPA shape {sdpa_outputs[i].shape}") + assert backend_output.dtype == sdpa_outputs[i].dtype, ( f"[{backend_name}] dtype {backend_output.dtype} != " - f"SDPA dtype {sdpa_output.dtype}") + f"SDPA dtype {sdpa_outputs[i].dtype}") assert torch.isfinite(backend_output).all(), ( f"[{backend_name}] produced non-finite values") @@ -508,12 +502,13 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): rtol = 1e-2 atol = 5e-1 - max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item() + max_diff = torch.max(torch.abs(backend_output - + sdpa_outputs[i])).item() max_rel_diff = torch.max( - torch.abs(backend_output - sdpa_output) / - torch.abs(sdpa_output)).item() + torch.abs(backend_output - sdpa_outputs[i]) / + torch.abs(sdpa_outputs[i])).item() all_close = torch.allclose(backend_output, - sdpa_output, + sdpa_outputs[i], rtol=rtol, atol=atol) diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index 6a08cdc56f736..f07c6eb0ea4da 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -139,6 +139,10 @@ def get_attention_backend(backend_name: _Backend): "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend", _Backend.FLASHMLA_VLLM_V1: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend", + _Backend.FLASH_ATTN_MLA: + "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend", + _Backend.FLASHINFER_MLA: + "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend", _Backend.TRITON_MLA_VLLM_V1: 
"vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", } diff --git a/tests/v1/core/test_encoder_cache_manager.py b/tests/v1/core/test_encoder_cache_manager.py index ae5b751f45a4b..4e3cace86be6a 100644 --- a/tests/v1/core/test_encoder_cache_manager.py +++ b/tests/v1/core/test_encoder_cache_manager.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange from vllm.v1.core.encoder_cache_manager import EncoderCacheManager @@ -9,8 +10,17 @@ class MockRequest: def __init__(self, request_id, mm_hashes, token_counts): self.request_id = request_id - self.mm_hashes = mm_hashes self._token_counts = token_counts + self.mm_features = [] + for i, mm_hash in enumerate(mm_hashes): + feature = MultiModalFeatureSpec( + data=None, + modality="image", + identifier=mm_hash, + mm_position=PlaceholderRange(offset=0, + length=self._token_counts[i]), + ) + self.mm_features.append(feature) def get_num_encoder_tokens(self, input_id: int) -> int: return self._token_counts[input_id] diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 4d0a26f76e98e..319e6e84fba1e 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -6,29 +6,40 @@ from typing import Callable, Optional import pytest import torch +import vllm.v1.core.kv_cache_utils as kv_cache_utils from vllm.config import ModelConfig, SchedulerConfig, VllmConfig from vllm.multimodal.inputs import (MultiModalFeatureSpec, MultiModalKwargsItem, PlaceholderRange) from vllm.sampling_params import SamplingParams -from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit +from vllm.utils import GiB_bytes, sha256, sha256_cbor from vllm.v1.core.kv_cache_manager import KVCacheManager # disable yapf here as it formats differently than isort such that both fail # yapf: disable from vllm.v1.core.kv_cache_utils import ( - FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics, + BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics, estimate_max_model_len, generate_block_hash_extra_keys, - get_kv_cache_config, get_max_concurrency_for_kv_cache_config, + get_kv_cache_configs, get_max_concurrency_for_kv_cache_config, get_request_block_hasher, hash_block_tokens, init_none_hash, - is_kv_cache_type_uniform, unify_kv_cache_configs) + is_kv_cache_type_uniform, make_block_hash_with_group_id) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheTensor, - SlidingWindowSpec) + KVCacheGroupSpec, KVCacheSpec, + KVCacheTensor, SlidingWindowSpec) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request # yapf: enable +@pytest.fixture(autouse=True) +def _auto_init_hash_fn(request): + hash_fn: Callable + if "hash_fn" in request.fixturenames: + hash_fn = init_none_hash(request.getfixturevalue("hash_fn")) + else: + hash_fn = sha256 + init_none_hash(hash_fn) + + def make_request( request_id: str, prompt_token_ids: list[int], @@ -88,7 +99,7 @@ def new_sliding_window_spec(block_size=16, sliding_window=sliding_window) -@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_none_hash(monkeypatch, hash_fn): import vllm.v1.core.kv_cache_utils @@ -98,8 +109,8 @@ def test_none_hash(monkeypatch, hash_fn): reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils) 
reloaded_kv_cache_utils.init_none_hash(hash_fn) assert reloaded_kv_cache_utils.NONE_HASH is not None - assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int) - assert reloaded_kv_cache_utils.NONE_HASH != 0 + assert isinstance(reloaded_kv_cache_utils.NONE_HASH, bytes) + assert reloaded_kv_cache_utils.NONE_HASH != b"" # case 2: PYTHONHASHSEED is set, use the seed and hash_fn with monkeypatch.context() as m: @@ -107,12 +118,11 @@ reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils) reloaded_kv_cache_utils.init_none_hash(hash_fn) assert reloaded_kv_cache_utils.NONE_HASH is not None - assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int) + assert isinstance(reloaded_kv_cache_utils.NONE_HASH, bytes) assert hash_fn('python hash seed') == reloaded_kv_cache_utils.NONE_HASH def test_kv_cache_block(): - import vllm.v1.core.kv_cache_utils # Test KVCacheBlock initialization block = KVCacheBlock(block_id=0) @@ -127,8 +137,7 @@ assert block.ref_cnt == 0 # Test block hash setting and resetting - block_hash = vllm.v1.core.kv_cache_utils.BlockHash(hash_value=123, - token_ids=(1, 2, 3)) + block_hash = make_block_hash_with_group_id(BlockHash(b"abc"), 0) block.block_hash = block_hash assert block.block_hash == block_hash @@ -244,6 +253,18 @@ def test_free_kv_cache_block_queue_append_n(): assert blocks[3].next_free_block is queue.fake_free_list_tail assert queue.fake_free_list_tail.prev_free_block is blocks[3] + # Create an empty FreeKVCacheBlockQueue + invalid_queue = FreeKVCacheBlockQueue([]) + # set prev_free_block to None; this will trigger the assertion in append_n + invalid_queue.fake_free_list_tail.prev_free_block = None + with pytest.raises(AssertionError): + # Append 1 block + # fake_head->fake_tail + invalid_queue.append_n(blocks[0:1]) + assert invalid_queue.num_free_blocks == 0 + assert (invalid_queue.fake_free_list_head.next_free_block == + invalid_queue.fake_free_list_tail) + def test_free_kv_cache_block_queue_popleft_n(): blocks = [KVCacheBlock(block_id=i) for i in range(6)] @@ -269,9 +290,11 @@ # Pop 0 block # fake_head->b1->b3->b5->b4->b0->b2->fake_tail assert len(queue.popleft_n(0)) == 0 + assert queue.num_free_blocks == 6 # Pop 1 block # fake_head->b3->b5->b4->b0->b2->fake_tail result_blocks = queue.popleft_n(1) + assert queue.num_free_blocks == 5 assert len(result_blocks) == 1 assert result_blocks[0] is blocks[1] for block in result_blocks: @@ -281,6 +304,7 @@ # fake_head->b4->b0->b2->fake_tail result_blocks = queue.popleft_n(2) assert len(result_blocks) == 2 + assert queue.num_free_blocks == 3 assert result_blocks[0] is blocks[3] assert result_blocks[1] is blocks[5] for block in result_blocks: @@ -290,6 +314,7 @@ # fake_head->fake_tail result_blocks = queue.popleft_n(3) assert len(result_blocks) == 3 + assert queue.num_free_blocks == 0 assert result_blocks[0] is blocks[4] assert result_blocks[1] is blocks[0] assert result_blocks[2] is blocks[2] @@ -407,27 +432,20 @@ def test_generate_block_hash_extra_keys_cache_salt(): assert next_mm_idx == 1 -@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_hash_block_tokens(hash_fn): - import vllm.v1.core.kv_cache_utils - init_none_hash(hash_fn) - parent_block_hash = 123 + parent_block_hash = BlockHash(b"123") curr_block_token_ids = 
(1, 2, 3) extra_keys = ("key1", "key2") block_hash = hash_block_tokens(hash_fn, parent_block_hash, curr_block_token_ids, extra_keys) - assert isinstance(block_hash, vllm.v1.core.kv_cache_utils.BlockHash) - assert block_hash.hash_value == hash_fn( - (parent_block_hash, curr_block_token_ids, extra_keys)) - assert block_hash.token_ids == curr_block_token_ids - assert block_hash.extra_keys == extra_keys + expected = hash_fn((parent_block_hash, curr_block_token_ids, extra_keys)) + assert block_hash == expected -@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_request_block_hasher(hash_fn): - import vllm.v1.core.kv_cache_utils - init_none_hash(hash_fn) request = make_request( request_id="0", prompt_token_ids=[_ for _ in range(6)], @@ -442,22 +460,14 @@ def test_request_block_hasher(hash_fn): block_hashes = request.block_hashes assert len(block_hashes) == 2 - assert isinstance(block_hashes[0], vllm.v1.core.kv_cache_utils.BlockHash) - assert isinstance(block_hashes[1], vllm.v1.core.kv_cache_utils.BlockHash) - - # Check the first block - assert block_hashes[0].token_ids == (0, 1, 2) - assert block_hashes[0].extra_keys == ("hash1", ) - - # Check the second block - assert block_hashes[1].token_ids == (3, 4, 5) - assert block_hashes[1].extra_keys == ("hash2", ) + assert block_hashes[0] == hash_fn( + (kv_cache_utils.NONE_HASH, (0, 1, 2), ("hash1", ))) + assert block_hashes[1] == hash_fn( + (block_hashes[0], (3, 4, 5), ("hash2", ))) -@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_hash_tokens_different_mm_input(hash_fn): - init_none_hash(hash_fn) - request1 = make_request( request_id="0", prompt_token_ids=[_ for _ in range(6)], @@ -484,10 +494,8 @@ def test_hash_tokens_different_mm_input(hash_fn): assert block_hashes1[1] != block_hashes2[1] -@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_hash_request_tokens_no_mm_inputs(hash_fn): - init_none_hash(hash_fn) - request = make_request( request_id="0", prompt_token_ids=[_ for _ in range(6)], @@ -500,10 +508,9 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn): block_hashes = request.block_hashes assert len(block_hashes) == 2 - assert block_hashes[0].token_ids == (0, 1, 2) - assert block_hashes[0].extra_keys is None - assert block_hashes[1].token_ids == (3, 4, 5) - assert block_hashes[1].extra_keys is None + assert block_hashes[0] == hash_fn( + (kv_cache_utils.NONE_HASH, (0, 1, 2), None)) + assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), None)) def test_metrics(): @@ -542,102 +549,288 @@ def test_metrics(): assert not metrics.query_queue -def test_unify_kv_cache_configs(): - same_kv_cache_config = [ - KVCacheConfig( - num_blocks=10, - kv_cache_tensors=[ - KVCacheTensor(size=100, shared_by=["layer1"]), - KVCacheTensor(size=100, shared_by=["layer2"]), - ], - kv_cache_groups=[ - KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), - KVCacheGroupSpec(["layer2"], - new_kv_cache_spec(num_kv_heads=4)), - ], - ), - KVCacheConfig( - num_blocks=20, - kv_cache_tensors=[ - KVCacheTensor(size=100, shared_by=["layer1"]), - KVCacheTensor(size=100, shared_by=["layer2"]), - ], - kv_cache_groups=[ - KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), - KVCacheGroupSpec(["layer2"], - new_kv_cache_spec(num_kv_heads=4)), - ], - ), - ] - unify_kv_cache_configs(same_kv_cache_config) - assert 
same_kv_cache_config[0].num_blocks == 10 - assert same_kv_cache_config[1].num_blocks == 10 +def test_get_kv_cache_configs_multiple_workers(): + model_config = ModelConfig(max_model_len=16) + vllm_config = VllmConfig(model_config=model_config) - need_sort_kv_cache_config = [ + ref_kv_cache_spec = new_kv_cache_spec() + same_kv_cache_specs = [{ + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }, { + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }] + + # Basic case. All things are the same. + kv_cache_configs = get_kv_cache_configs(vllm_config, same_kv_cache_specs, [ + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 10 + ]) + assert kv_cache_configs == [ KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=100, shared_by=["layer1"]), - KVCacheTensor(size=100, shared_by=["layer2"]), + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer1"]), + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer2"]), ], kv_cache_groups=[ - KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), - KVCacheGroupSpec(["layer2"], - new_kv_cache_spec(num_kv_heads=4)), + KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), ], ), KVCacheConfig( - num_blocks=20, + num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=100, shared_by=["layer1"]), - KVCacheTensor(size=100, shared_by=["layer2"]), + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer1"]), + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer2"]), ], kv_cache_groups=[ - KVCacheGroupSpec(["layer2"], - new_kv_cache_spec(num_kv_heads=4)), - KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), + KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), ], ), ] - unify_kv_cache_configs(need_sort_kv_cache_config) - sorted_kv_cache_groups = [ - KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), - KVCacheGroupSpec(["layer2"], new_kv_cache_spec(num_kv_heads=4)), - ] - assert ( - need_sort_kv_cache_config[0].kv_cache_groups == sorted_kv_cache_groups) - assert ( - need_sort_kv_cache_config[1].kv_cache_groups == sorted_kv_cache_groups) - - diff_kv_cache_config = [ + # Different available memory. This is the case for TP. + # Use the smallest memory available. 
+ kv_cache_configs = get_kv_cache_configs(vllm_config, same_kv_cache_specs, [ + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 20 + ]) + assert kv_cache_configs == [ KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=100, shared_by=["layer1"]), - KVCacheTensor(size=100, shared_by=["layer2"]), + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer1"]), + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer2"]), ], kv_cache_groups=[ - KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), - KVCacheGroupSpec(["layer2"], - new_kv_cache_spec(num_kv_heads=4)), + KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), ], ), KVCacheConfig( - num_blocks=20, + num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=100, shared_by=["layer1"]), - KVCacheTensor(size=100, shared_by=["layer2"]), + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20, + shared_by=["layer1"]), + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20, + shared_by=["layer2"]), ], kv_cache_groups=[ - KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), - KVCacheGroupSpec(["layer2"], - new_kv_cache_spec(num_kv_heads=8)), + KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), ], ), ] + + # Different KV cache specs. This is the case for PP. + different_layer_specs = [{ + "layer1": new_kv_cache_spec(), + }, { + "layer2": new_kv_cache_spec(), + "layer3": new_kv_cache_spec(), + }] + + # Different workers have different layers. + kv_cache_configs = get_kv_cache_configs( + vllm_config, different_layer_specs, [ + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 10 + ]) + assert kv_cache_configs == [ + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20, + shared_by=["layer1"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), + ], + ), + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer2"]), + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer3"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer2", "layer3"], new_kv_cache_spec()), + ], + ), + ] + + # Some layers are the same, some are different. 
This is the case for TP+PP + tp_pp_kv_cache_specs = [{ + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }, { + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }, { + "layer3": new_kv_cache_spec(), + }, { + "layer3": new_kv_cache_spec(), + }] + + kv_cache_configs = get_kv_cache_configs( + vllm_config, tp_pp_kv_cache_specs, [ + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ]) + assert kv_cache_configs == [ + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer1"]), + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer2"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), + ], + ), + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer1"]), + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer2"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), + ], + ), + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20, + shared_by=["layer3"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer3"], ref_kv_cache_spec), + ], + ), + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20, + shared_by=["layer3"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer3"], ref_kv_cache_spec), + ], + ), + ] + + # Different workers have different types of layers. This is the case for + # hybrid models + PP. + different_type_layer_specs = [{ + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }, { + "layer3": new_sliding_window_spec(), + "layer4": new_sliding_window_spec(), + }] + kv_cache_configs = get_kv_cache_configs( + vllm_config, different_type_layer_specs, [ + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ]) + assert kv_cache_configs == [ + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer1"]), + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer2"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), + KVCacheGroupSpec([], new_sliding_window_spec()), + ], + ), + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer3"]), + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer4"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec([], ref_kv_cache_spec), + KVCacheGroupSpec(["layer3", "layer4"], + new_sliding_window_spec()), + ], + ), + ] + + # When divided into multiple KVCacheGroups, need to ensure the number of + # layers per group is similar. 
+ different_type_layer_specs = [{ + "layer1": new_kv_cache_spec(), + "layer2": new_sliding_window_spec(), + "layer3": new_sliding_window_spec(), + }, { + "layer4": new_kv_cache_spec(), + "layer5": new_sliding_window_spec(), + "layer6": new_sliding_window_spec(), + }] + kv_cache_configs = get_kv_cache_configs( + vllm_config, different_type_layer_specs, [ + ref_kv_cache_spec.page_size_bytes * 10, + ref_kv_cache_spec.page_size_bytes * 10, + ]) + assert kv_cache_configs == [ + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer1", "layer2", "layer3"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer1"], ref_kv_cache_spec), + KVCacheGroupSpec(["layer2"], new_sliding_window_spec()), + KVCacheGroupSpec(["layer3"], new_sliding_window_spec()), + ], + ), + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer4", "layer5", "layer6"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer4"], ref_kv_cache_spec), + KVCacheGroupSpec(["layer5"], new_sliding_window_spec()), + KVCacheGroupSpec(["layer6"], new_sliding_window_spec()), + ], + ), + ] + + # Have conflicting layers. Need to raise an error. + conflicting_layer_specs = [{ + "layer1": new_kv_cache_spec(), + }, { + "layer1": new_sliding_window_spec(), + }] with pytest.raises(AssertionError): - unify_kv_cache_configs(diff_kv_cache_config) + get_kv_cache_configs(vllm_config, conflicting_layer_specs, [ + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ]) def test_merge_kv_cache_spec(): @@ -901,7 +1094,7 @@ def test_allocate_with_lookahead(): assert len(blocks.get_block_ids()[0]) == 2 -def test_get_kv_cache_config(): +def test_get_kv_cache_config_one_worker(): # pass max_model_len to pass check_enough_kv_cache_memory model_config = ModelConfig(max_model_len=16) vllm_config = VllmConfig(model_config=model_config) @@ -912,8 +1105,10 @@ def test_get_kv_cache_config(): 'layer_1': new_kv_cache_spec(), 'layer_2': new_kv_cache_spec(), } - kv_cache_config_full = get_kv_cache_config( - vllm_config, kv_cache_specs_full, mem_per_block_per_layer * 2 * 32) + kv_cache_config_full = get_kv_cache_configs( + vllm_config, [kv_cache_specs_full], + [mem_per_block_per_layer * 2 * 32])[0] + print(kv_cache_config_full) assert kv_cache_config_full == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ @@ -931,8 +1126,9 @@ def test_get_kv_cache_config(): 'layer_1': new_sliding_window_spec(), 'layer_2': new_sliding_window_spec(), } - kv_cache_config_sliding = get_kv_cache_config( - vllm_config, kv_cache_specs_sliding, mem_per_block_per_layer * 2 * 32) + kv_cache_config_sliding = get_kv_cache_configs( + vllm_config, [kv_cache_specs_sliding], + [mem_per_block_per_layer * 2 * 32])[0] assert kv_cache_config_sliding == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ @@ -951,8 +1147,9 @@ def test_get_kv_cache_config(): 'layer_1': new_kv_cache_spec(), 'layer_2': new_sliding_window_spec(), } - kv_cache_config_hybrid = get_kv_cache_config( - vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32) + kv_cache_config_hybrid = get_kv_cache_configs( + vllm_config, [kv_cache_specs_hybrid], + [mem_per_block_per_layer * 2 * 32])[0] assert kv_cache_config_hybrid == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ @@ -973,8 +1170,9 @@ def test_get_kv_cache_config(): 'layer_1': new_kv_cache_spec(), 'layer_2': new_sliding_window_spec(), } - kv_cache_config_hybrid = 
get_kv_cache_config( - vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32) + kv_cache_config_hybrid = get_kv_cache_configs( + vllm_config, [kv_cache_specs_hybrid], + [mem_per_block_per_layer * 2 * 32])[0] assert kv_cache_config_hybrid == KVCacheConfig( num_blocks=64, kv_cache_tensors=[ @@ -996,21 +1194,22 @@ def test_get_kv_cache_config(): 'layer_5': new_sliding_window_spec(), 'layer_6': new_sliding_window_spec(), } - kv_cache_config_hybrid = get_kv_cache_config( - vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32) + kv_cache_config_hybrid = get_kv_cache_configs( + vllm_config, [kv_cache_specs_hybrid], + [mem_per_block_per_layer * 2 * 32])[0] assert kv_cache_config_hybrid == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_1", "layer_3", "layer_5"]), + shared_by=["layer_1", "layer_3", "layer_4"]), KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_2", "layer_4", "layer_6"]), + shared_by=["layer_2", "layer_5", "layer_6"]), ], kv_cache_groups=[ KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()), - KVCacheGroupSpec(["layer_3", "layer_4"], + KVCacheGroupSpec(["layer_3", "layer_5"], new_sliding_window_spec()), - KVCacheGroupSpec(["layer_5", "layer_6"], + KVCacheGroupSpec(["layer_4", "layer_6"], new_sliding_window_spec()), ], ) @@ -1028,27 +1227,30 @@ def test_get_kv_cache_config(): 'layer_9': new_sliding_window_spec(), 'layer_10': new_sliding_window_spec(), } - kv_cache_config_hybrid = get_kv_cache_config( - vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 3 * 32) + kv_cache_config_hybrid = get_kv_cache_configs( + vllm_config, [kv_cache_specs_hybrid], + [mem_per_block_per_layer * 3 * 32])[0] assert kv_cache_config_hybrid == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ KVCacheTensor( size=mem_per_block_per_layer * 32, - shared_by=["layer_1", "layer_4", "layer_7", "layer_10"]), + shared_by=["layer_1", "layer_4", "layer_5", "layer_6"]), + KVCacheTensor( + size=mem_per_block_per_layer * 32, + shared_by=["layer_2", "layer_7", "layer_8", "layer_9"]), KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_2", "layer_5", "layer_8"]), - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_3", "layer_6", "layer_9"]), + shared_by=["layer_3", "layer_10"]), ], kv_cache_groups=[ KVCacheGroupSpec(["layer_1", "layer_2", "layer_3"], new_kv_cache_spec()), - KVCacheGroupSpec(["layer_4", "layer_5", "layer_6"], + KVCacheGroupSpec(["layer_4", "layer_7", "layer_10"], new_sliding_window_spec()), - KVCacheGroupSpec(["layer_7", "layer_8", "layer_9"], + KVCacheGroupSpec(["layer_5", "layer_8"], + new_sliding_window_spec()), + KVCacheGroupSpec(["layer_6", "layer_9"], new_sliding_window_spec()), - KVCacheGroupSpec(["layer_10"], new_sliding_window_spec()), ], ) @@ -1058,13 +1260,14 @@ def test_get_kv_cache_config(): 'layer_2': new_kv_cache_spec(), } with pytest.raises(NotImplementedError): - get_kv_cache_config(vllm_config, kv_cache_specs_hybrid, - mem_per_block_per_layer * 2 * 32) + get_kv_cache_configs(vllm_config, [kv_cache_specs_hybrid], + [mem_per_block_per_layer * 2 * 32])[0] # Test num_gpu_blocks_override vllm_config.cache_config.num_gpu_blocks_override = 16 - kv_cache_config_override_blocks = get_kv_cache_config( - vllm_config, kv_cache_specs_full, mem_per_block_per_layer * 2 * 32) + kv_cache_config_override_blocks = get_kv_cache_configs( + vllm_config, [kv_cache_specs_full], + [mem_per_block_per_layer * 2 * 32])[0] assert 
kv_cache_config_override_blocks == KVCacheConfig( num_blocks=16, kv_cache_tensors=[ @@ -1076,3 +1279,16 @@ def test_get_kv_cache_config(): kv_cache_groups=[ KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()) ]) + + +def test_get_kv_cache_configs_attention_free(): + kv_cache_specs: dict[str, KVCacheSpec] = {} + vllm_config = VllmConfig(model_config=ModelConfig(max_model_len=16)) + kv_cache_configs = get_kv_cache_configs(vllm_config, [kv_cache_specs], [0]) + assert kv_cache_configs == [ + KVCacheConfig( + num_blocks=1, + kv_cache_tensors=[], + kv_cache_groups=[], + ) + ] diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index e7a8f63702b30..3cf9d93696767 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -8,21 +8,33 @@ from typing import Callable, Optional import pytest import torch +import vllm.v1.core.kv_cache_utils as kv_cache_utils from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved from vllm.multimodal.inputs import (MultiModalFeatureSpec, MultiModalKwargsItem, PlaceholderRange) from vllm.sampling_params import SamplingParams -from vllm.utils import sha256, sha256_cbor_64bit +from vllm.utils import sha256, sha256_cbor from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_manager import KVCacheManager, Request -from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, - KVCacheBlock, +from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, + get_block_hash, get_group_id, get_request_block_hasher, - hash_block_tokens, init_none_hash) + hash_block_tokens, init_none_hash, + make_block_hash_with_group_id) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, SlidingWindowSpec) +@pytest.fixture(autouse=True) +def _auto_init_hash_fn(request): + hash_fn: Callable + if "hash_fn" in request.fixturenames: + hash_fn = request.getfixturevalue("hash_fn") + else: + hash_fn = sha256 + init_none_hash(hash_fn) + + def make_request( request_id: str, prompt_token_ids: list[int], @@ -101,8 +113,9 @@ def make_kv_cache_config_hybrid_model(block_size: int, ) -@pytest.mark.parametrize("hash_algo", ["sha256", "sha256_cbor_64bit", "hash"]) -def test_prefill(hash_algo): +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) +def test_prefill(hash_fn): + block_size = 16 manager = KVCacheManager( make_kv_cache_config(block_size, 11), @@ -110,10 +123,6 @@ def test_prefill(hash_algo): enable_caching=True, ) - # choose the hash function according to the parameter - hash_fn = (sha256_cbor_64bit if hash_algo == "sha256_cbor_64bit" else - sha256 if hash_algo == "sha256" else hash) - # Complete 3 blocks (48 tokens) common_token_ids = [i for i in range(3) for _ in range(16)] @@ -137,10 +146,12 @@ def test_prefill(hash_algo): block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16]) block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) - assert manager.block_pool.blocks[ - block_id].block_hash.block_hash == block_hash + blk_hash = manager.block_pool.blocks[block_id].block_hash + assert blk_hash is not None + assert get_block_hash(blk_hash) == block_hash + assert get_group_id(blk_hash) == 0 assert manager.block_pool.blocks[block_id].ref_cnt == 1 - parent_block_hash = block_hash.hash_value + parent_block_hash = block_hash # Check partial block metadata for block_id in (4, ): @@ -233,7 +244,7 @@ def test_prefill_hybrid_model(): enable_caching=True, ) - hash_fn = hash + hash_fn =
sha256 # Complete 3 blocks (48 tokens) common_token_ids = [i for i in range(3) for _ in range(block_size)] @@ -260,11 +271,13 @@ def test_prefill_hybrid_model(): block_tokens = tuple(all_token_ids[(length - 1) * 16:length * 16]) block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) - for block_id in block_ids: - assert manager.block_pool.blocks[ - block_id].block_hash.block_hash == block_hash + for group_id, block_id in enumerate(block_ids): + blk_hash = manager.block_pool.blocks[block_id].block_hash + assert blk_hash is not None + assert get_block_hash(blk_hash) == block_hash + assert get_group_id(blk_hash) == group_id assert manager.block_pool.blocks[block_id].ref_cnt == 1 - parent_block_hash = block_hash.hash_value + parent_block_hash = block_hash # Check partial block metadata for block_id in (4, 8, 12): @@ -298,11 +311,10 @@ def test_prefill_hybrid_model(): cached_block_hash_to_block_bak = copy.copy( manager.block_pool.cached_block_hash_to_block) - def test_partial_request_hit(request_id: str, - hash_to_evict: list[BlockHashWithGroupId], + def test_partial_request_hit(request_id: str, hash_to_evict: list[bytes], expect_hit_length: int): req = make_request(request_id, common_token_ids + unique_token_ids, - block_size, hash) + block_size, sha256) for hash_with_group_id in hash_to_evict: manager.block_pool.cached_block_hash_to_block.pop( hash_with_group_id) @@ -319,33 +331,32 @@ def test_prefill_hybrid_model(): # Evict the blocks outside sliding window, does not affect the hit length. test_partial_request_hit("2", [ - BlockHashWithGroupId(block_hashes[0], 1), - BlockHashWithGroupId(block_hashes[0], 2) + make_block_hash_with_group_id(block_hashes[0], 1), + make_block_hash_with_group_id(block_hashes[0], 2) ], 3) # Evict the first block of full attention, makes total cache miss. - test_partial_request_hit("3", [ - BlockHashWithGroupId(block_hashes[0], 0), - ], 0) + test_partial_request_hit( + "3", [make_block_hash_with_group_id(block_hashes[0], 0)], 0) # Evict the last block of all layers, reduces the hit length to 2. test_partial_request_hit("4", [ - BlockHashWithGroupId(block_hashes[2], 0), - BlockHashWithGroupId(block_hashes[2], 1), - BlockHashWithGroupId(block_hashes[2], 2), + make_block_hash_with_group_id(block_hashes[2], 0), + make_block_hash_with_group_id(block_hashes[2], 1), + make_block_hash_with_group_id(block_hashes[2], 2), ], 2) # Evict the last block of full attention, reduces the hit length to 2. - test_partial_request_hit("5", [BlockHashWithGroupId(block_hashes[2], 0)], - 2) + test_partial_request_hit( + "5", [make_block_hash_with_group_id(block_hashes[2], 0)], 2) # Evict the last block of sliding window, reduces the hit length to 2. - test_partial_request_hit("6", [BlockHashWithGroupId(block_hashes[2], 1)], - 2) + test_partial_request_hit( + "6", [make_block_hash_with_group_id(block_hashes[2], 1)], 2) # Evict the last block of sliding window, reduces the hit length to 2. - test_partial_request_hit("7", [BlockHashWithGroupId(block_hashes[2], 2)], - 2) + test_partial_request_hit( + "7", [make_block_hash_with_group_id(block_hashes[2], 2)], 2) # Evict different set of blocks for full attention and sliding window makes # total cache miss. @@ -353,9 +364,9 @@ def test_prefill_hybrid_model(): # The cache hit length of sliding window is 2 * block_size. # Then it is cache miss as the two type of layers have different hit length. 
test_partial_request_hit("8", [ - BlockHashWithGroupId(block_hashes[2], 0), - BlockHashWithGroupId(block_hashes[0], 1), - BlockHashWithGroupId(block_hashes[0], 2), + make_block_hash_with_group_id(block_hashes[2], 0), + make_block_hash_with_group_id(block_hashes[0], 1), + make_block_hash_with_group_id(block_hashes[0], 2), ], 0) @@ -372,8 +383,8 @@ def test_prefill_plp(): max_model_len=8192, enable_caching=True, ) - # the default hash function is hash - hash_fn = hash + # the default hash function is sha256 + hash_fn = sha256 # Complete 3 blocks (48 tokens) common_token_ids = [i for i in range(3) for _ in range(16)] @@ -404,10 +415,12 @@ def test_prefill_plp(): block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16]) block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) - assert manager.block_pool.blocks[ - block_id].block_hash.block_hash == block_hash + blk_hash = (manager.block_pool.blocks[block_id].block_hash) + assert blk_hash is not None + assert get_block_hash(blk_hash) == block_hash + assert get_group_id(blk_hash) == 0 assert manager.block_pool.blocks[block_id].ref_cnt == 1 - parent_block_hash = block_hash.hash_value + parent_block_hash = block_hash # Check partial block metadata for block_id in (4, ): @@ -493,7 +506,7 @@ def test_decode(): # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 req0 = make_request("0", common_token_ids + unique_token_ids, block_size, - hash) + sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -538,7 +551,7 @@ def test_evict(): ) last_token_id = 5 * 16 + 7 - req0 = make_request("0", list(range(last_token_id)), block_size, hash) + req0 = make_request("0", list(range(last_token_id)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -550,7 +563,7 @@ def test_evict(): # 3 blocks. req1 = make_request("1", list(range(last_token_id, last_token_id + 3 * 16)), block_size, - hash) + sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -572,7 +585,7 @@ def test_evict(): ] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7] # Touch the first 2 blocks. - req2 = make_request("2", list(range(2 * 16 + 3)), block_size, hash) + req2 = make_request("2", list(range(2 * 16 + 3)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert computed_blocks.get_block_ids() == ([1, 2], ) assert num_computed_tokens == 2 * 16 @@ -597,7 +610,7 @@ def test_hash_block_correct_reuse(): # Allocate 1 block and cache it. num_tokens = block_size * 1 - req = make_request("0", list(range(num_tokens)), block_size, hash) + req = make_request("0", list(range(num_tokens)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -611,7 +624,7 @@ def test_hash_block_correct_reuse(): # Allocate a new block that's not full, make sure hash info on the # block is cleared. 
- req = make_request("1", list(range(num_tokens - 1)), block_size, hash) + req = make_request("1", list(range(num_tokens - 1)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -638,7 +651,7 @@ def test_computed_blocks_not_evicted(): # Allocate a block and cache it. num_tokens = block_size * 1 - req0 = make_request("0", list(range(num_tokens)), block_size, hash) + req0 = make_request("0", list(range(num_tokens)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -650,7 +663,7 @@ def test_computed_blocks_not_evicted(): # Allocate another block. req1 = make_request("1", list(range(num_tokens, num_tokens * 2)), - block_size, hash) + block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -666,7 +679,7 @@ def test_computed_blocks_not_evicted(): # Now if we have a cache hit on the first block, we should evict the second # cached block rather than the first one. - req2 = make_request("2", list(range(num_tokens * 2)), block_size, hash) + req2 = make_request("2", list(range(num_tokens * 2)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 1 assert computed_blocks.blocks[0][0].block_id == 1 @@ -691,7 +704,7 @@ def test_basic_prefix_caching_disabled(): ) req1 = make_request("1", list(range(10)), block_size, - hash) # 2 blocks and some more + sha256) # 2 blocks and some more computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] @@ -706,7 +719,7 @@ def test_basic_prefix_caching_disabled(): # No caching. req2 = make_request("2", list(range(16)), block_size, - hash) # shared prefix + sha256) # shared prefix computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -716,7 +729,7 @@ def test_basic_prefix_caching_disabled(): assert len(blocks.blocks[0]) == 4 # New requests should not have any blocks. - req3 = make_request("3", list(range(4)), block_size, hash) + req3 = make_request("3", list(range(4)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -726,13 +739,12 @@ def test_basic_prefix_caching_disabled(): assert not blocks -@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_cache_blocks(hash_fn): """ This is a unit test that tests the correctness of the _cache_full_blocks function of KVCacheManager. """ - init_none_hash(hash_fn) block_size = 4 block_pool = BlockPool( @@ -787,7 +799,7 @@ def test_cache_blocks_multi_group(): # Block 1/5: [4, 5, 6, 7] # Block 2/6: [8, 9, 10, 11] # Block 3/7: [12, 13] - req = make_request("0", list(range(14)), block_size, hash) + req = make_request("0", list(range(14)), block_size, sha256) # Cache the blocks for group 0. blocks = [KVCacheBlock(block_id=i) for i in range(2)] @@ -845,6 +857,7 @@ def test_mm_prefix_caching(): """ This tests that the multi-modal prefix caching is correct. 
""" + block_size = 16 manager = KVCacheManager( make_kv_cache_config(block_size, 11), @@ -874,23 +887,30 @@ def test_mm_prefix_caching(): req0 = make_request("0", all_token_ids, block_size, - hash, + sha256, mm_positions=mm_positions, mm_hashes=mm_hashes) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - # Completed block should have hashes with extra keys. + # Completed block should have hashes assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 block_hashes = req0.block_hashes assert len(block_hashes) == 3 - assert block_hashes[0].extra_keys == ("aaa", ) - assert block_hashes[1].extra_keys == ("aaa", "bbb") - assert block_hashes[2].extra_keys == ("bbb", ) + assert block_hashes[0] == sha256( + (kv_cache_utils.NONE_HASH, tuple(all_token_ids[:block_size]), + ("aaa", ))) + assert block_hashes[1] == sha256( + (block_hashes[0], tuple(all_token_ids[block_size:block_size * 2]), + ("aaa", "bbb"))) + assert block_hashes[2] == sha256( + (block_hashes[1], tuple(all_token_ids[block_size * 2:block_size * 3]), + ("bbb", ))) blocks = manager.allocate_slots(req0, 59, len(computed_blocks.blocks[0]) * 16, computed_blocks) + assert blocks is not None assert blocks.get_block_ids() == ([1, 2, 3, 4], ) req0.num_computed_tokens = 59 @@ -901,10 +921,10 @@ def test_mm_prefix_caching(): len(computed_blocks.blocks[0]) * 16, computed_blocks) assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 - - # The just completed block should have hashes with extra keys. assert len(block_hashes) == 4 - assert block_hashes[3].extra_keys == ("ccc", ) + assert block_hashes[3] == sha256( + (block_hashes[2], tuple(all_token_ids[3 * block_size:] + [8] * 5), + ("ccc", ))) # Cache hit. unique_token_ids = [-1] * 7 + [200] * 5 @@ -916,7 +936,7 @@ def test_mm_prefix_caching(): req1 = make_request("1", all_token_ids, block_size, - hash, + sha256, mm_positions=mm_positions, mm_hashes=mm_hashes) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) @@ -939,21 +959,26 @@ def test_cache_key_salting(): # 3 complete blocks and an incomplete block with 11 tokens. common_token_ids = [i for i in range(3) for _ in range(block_size)] token_ids = common_token_ids + [3] * 11 - req0 = make_request("0", token_ids, block_size, hash, cache_salt="salt1") + req0 = make_request("0", token_ids, block_size, sha256, cache_salt="salt1") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - # Completed block should have hashes with extra keys. 
+ # Completed block should have hashes assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 block_hashes = req0.block_hashes assert len(block_hashes) == 3 - assert block_hashes[0].extra_keys == ("salt1", ) - assert block_hashes[1].extra_keys is None - assert block_hashes[2].extra_keys is None + assert block_hashes[0] == sha256( + (kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt1", ))) + assert block_hashes[1] == sha256( + (block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None)) + assert block_hashes[2] == sha256( + (block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]), + None)) blocks = manager.allocate_slots(req0, 59, len(computed_blocks.blocks[0]) * 16, computed_blocks) + assert blocks is not None assert blocks.get_block_ids() == ([1, 2, 3, 4], ) req0.num_computed_tokens = 59 @@ -964,14 +989,13 @@ def test_cache_key_salting(): len(computed_blocks.blocks[0]) * 16, computed_blocks) assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 - - # Now one more block that should not have extra keys. assert len(block_hashes) == 4 - assert block_hashes[3].extra_keys is None + assert block_hashes[3] == sha256( + (block_hashes[2], tuple(token_ids[3 * block_size:] + [8] * 5), None)) # Test cache hit with a new request that has the same salt. token_ids = common_token_ids + [4] * 11 - req1 = make_request("1", token_ids, block_size, hash, cache_salt="salt1") + req1 = make_request("1", token_ids, block_size, sha256, cache_salt="salt1") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) # Should match only a prefix of 3 blocks. assert len(computed_blocks.blocks[0]) == 3 @@ -979,13 +1003,19 @@ def test_cache_key_salting(): # Test cache miss with same content but different salt. token_ids = common_token_ids + [4] * 11 - req2 = make_request("2", token_ids, block_size, hash, cache_salt="salt2") + req2 = make_request("2", token_ids, block_size, sha256, cache_salt="salt2") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 0 assert num_computed_tokens == 0 block_hashes = req2.block_hashes assert len(block_hashes) == 3 - assert block_hashes[0].extra_keys == ("salt2", ) + assert block_hashes[0] == sha256( + (kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt2", ))) + assert block_hashes[1] == sha256( + (block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None)) + assert block_hashes[2] == sha256( + (block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]), + None)) def test_prefill_not_enough_free_blocks_with_computed_blocks(): @@ -1004,7 +1034,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # Complete 3 blocks (48 tokens) # | Common-0 | Common-1 | Common-2 | ... | common_token_ids = [i for i in range(3) for _ in range(16)] - req0 = make_request("0", common_token_ids, block_size, hash) + req0 = make_request("0", common_token_ids, block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -1015,7 +1045,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): req0.request_id] # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... 
| - req1 = make_request("1", common_token_ids * 2, block_size, hash) + req1 = make_request("1", common_token_ids * 2, block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert computed_blocks.blocks[0] == block_part0 assert num_computed_tokens == 3 * 16 @@ -1032,7 +1062,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| Req2-0 | Req2-1 | ... | - req2 = make_request("2", [7] * block_size * 2, block_size, hash) + req2 = make_request("2", [7] * block_size * 2, block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -1044,7 +1074,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # but it cannot be allocated due to insufficient free blocks (2). # In this case, the ref_cnt of the computed blocks should not be changed. assert manager.block_pool.free_block_queue.num_free_blocks == 5 - req3 = make_request("3", common_token_ids * 3, block_size, hash) + req3 = make_request("3", common_token_ids * 3, block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert computed_blocks.blocks[0] == block_part1 assert num_computed_tokens == 6 * 16 @@ -1069,13 +1099,13 @@ def test_reset_prefix_cache(): full_block_token_ids = [i for i in range(3) for _ in range(16)] unique_token_ids = [3] * 7 all_token_ids = full_block_token_ids + unique_token_ids - req0 = make_request("0", all_token_ids, block_size, hash) + req0 = make_request("0", all_token_ids, block_size, sha256) blocks = manager.allocate_slots(req0, 55) assert blocks.get_block_ids() == ([1, 2, 3, 4], ) unique_token_ids = [4] * 7 all_token_ids = full_block_token_ids + unique_token_ids - req1 = make_request("1", all_token_ids, block_size, hash) + req1 = make_request("1", all_token_ids, block_size, sha256) computed_blocks, _ = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == 3 assert len(computed_blocks.blocks[0]) == 3 @@ -1109,7 +1139,7 @@ def test_prefix_cache_stats_disabled(): assert manager.prefix_cache_stats is None # Call all functions that check whether log_stats is disabled. 
- req = make_request("0", list(range(16)), block_size, hash) + req = make_request("0", list(range(16)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -1124,15 +1154,9 @@ def test_prefix_cache_stats_disabled(): def test_maybe_evict_cached_block(): pool = BlockPool(num_gpu_blocks=4, enable_caching=True) - block_hash0 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=10, - token_ids=(100, )), - group_id=1000) - block_hash1 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=20, - token_ids=(200, )), - group_id=2000) - block_hash2 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=30, - token_ids=(300, )), - group_id=3000) + block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000) + block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000) + block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000) block_hashes = [ block_hash0, block_hash1, @@ -1206,7 +1230,7 @@ def test_kv_cache_events(blocks_to_cache: int): ) num_tokens = block_size * blocks_to_cache - req0 = make_request("0", list(range(num_tokens)), block_size, hash) + req0 = make_request("0", list(range(num_tokens)), block_size, sha256) _ = manager.allocate_slots(req0, num_tokens) events = manager.take_events() @@ -1222,7 +1246,7 @@ def test_kv_cache_events(blocks_to_cache: int): # Should see block_to_cache number of removed block events and a new block # stored event manager.free(req0) - req1 = make_request("1", list(range(num_tokens)), block_size, hash) + req1 = make_request("1", list(range(num_tokens)), block_size, sha256) _ = manager.allocate_slots(req1, num_tokens) events = manager.take_events() @@ -1256,7 +1280,7 @@ def test_eagle_enabled_removes_last_block(): # Request with 3 full blocks (48 tokens) token_ids = [0] * (3 * block_size) - req = make_request("divisible_request", token_ids, block_size, hash) + req = make_request("divisible_request", token_ids, block_size, sha256) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) @@ -1266,7 +1290,7 @@ def test_eagle_enabled_removes_last_block(): manager.free(req) # New request with same tokens + Eagle enabled - req_eagle = make_request("eagle_divisible", token_ids, block_size, hash) + req_eagle = make_request("eagle_divisible", token_ids, block_size, sha256) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Should retain 1 block: @@ -1287,7 +1311,7 @@ def test_eagle_with_partial_blocks(): ) # 2 full blocks + 5 tokens (non-divisible length) token_ids = [0] * (2 * block_size + 5) - req = make_request("partial_block_test", token_ids, block_size, hash) + req = make_request("partial_block_test", token_ids, block_size, sha256) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) @@ -1297,7 +1321,7 @@ def test_eagle_with_partial_blocks(): manager.free(req) # New request with Eagle enabled - req_eagle = make_request("partial_eagle", token_ids, block_size, hash) + req_eagle = make_request("partial_eagle", token_ids, block_size, sha256) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining assert len(computed_blocks.blocks[0]) == 1 @@ -1328,7 +1352,7 @@ def test_eagle_with_sliding_window(): # 2 full blocks + 5 tokens (non-divisible length) token_ids = [0] * (2 * block_size + 5) - req = make_request("partial_block_test", token_ids, block_size, hash) + req = make_request("partial_block_test", token_ids, 
block_size, sha256) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) @@ -1341,7 +1365,7 @@ def test_eagle_with_sliding_window(): manager.free(req) # New request with Eagle enabled - req_eagle = make_request("partial_eagle", token_ids, block_size, hash) + req_eagle = make_request("partial_eagle", token_ids, block_size, sha256) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining assert len(computed_blocks.blocks[0]) == 1 @@ -1351,11 +1375,11 @@ def test_eagle_with_sliding_window(): assert manager.block_pool.get_cached_block( block_hash_first_block, kv_cache_group_ids=[0]) is not None manager.block_pool.cached_block_hash_to_block.pop( - BlockHashWithGroupId(block_hash_first_block, 0)) + make_block_hash_with_group_id(block_hash_first_block, 0)) # New request req_after_evict = make_request("partial_eagle_after_evict", token_ids, - block_size, hash) + block_size, sha256) computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict) # Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is # not considered. But after dropping the last matched block due to eagle, diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 572d6c9c889f6..f6fc1e6d37d14 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig) from vllm.multimodal.inputs import (MultiModalFeatureSpec, MultiModalKwargsItem, PlaceholderRange) -from vllm.sampling_params import GuidedDecodingParams, SamplingParams +from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, @@ -1796,11 +1796,11 @@ def test_schedule_skip_tokenizer_init(): def test_schedule_skip_tokenizer_init_structured_output_request(): scheduler = create_scheduler(skip_tokenizer_init=True) - guided_params = GuidedDecodingParams(regex="[0-9]+") + structured_outputs_params = StructuredOutputsParams(regex="[0-9]+") sampling_params = SamplingParams( ignore_eos=False, max_tokens=16, - guided_decoding=guided_params, + structured_outputs=structured_outputs_params, ) request = Request( request_id="0", diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py index 7dcebba491fab..b70850a9bcff9 100644 --- a/tests/v1/core/test_single_type_kv_cache_manager.py +++ b/tests/v1/core/test_single_type_kv_cache_manager.py @@ -6,8 +6,8 @@ import random import torch from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, - KVCacheBlock) +from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, + make_block_hash_with_group_id) from vllm.v1.core.single_type_kv_cache_manager import ( ChunkedLocalAttentionManager, SlidingWindowManager) from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, @@ -44,7 +44,7 @@ def test_chunked_local_attention_possible_cached_prefix(): def run_one_case(block_is_cached, tail_token, expect_length): block_hash_list = [ - BlockHash(i, ()) for i in range(len(block_is_cached)) + BlockHash(str(i).encode()) for i in range(len(block_is_cached)) ] block_pool.cached_block_hash_to_block.clear() @@ -53,8 +53,8 @@ def 
test_chunked_local_attention_possible_cached_prefix(): for i, (block_hash, is_cached) in enumerate(zip(block_hash_list, block_is_cached)): if is_cached: - block_pool.cached_block_hash_to_block[BlockHashWithGroupId( - block_hash, 0)] = { + block_pool.cached_block_hash_to_block[ + make_block_hash_with_group_id(block_hash, 0)] = { i: block_pool.blocks[i + 10], } @@ -109,7 +109,7 @@ def test_sliding_window_possible_cached_prefix(): def run_one_case(block_is_cached, expect_length): block_hash_list = [ - BlockHash(i, ()) for i in range(len(block_is_cached)) + BlockHash(str(i).encode()) for i in range(len(block_is_cached)) ] block_pool.cached_block_hash_to_block.clear() @@ -118,8 +118,8 @@ def test_sliding_window_possible_cached_prefix(): for i, (block_hash, is_cached) in enumerate(zip(block_hash_list, block_is_cached)): if is_cached: - block_pool.cached_block_hash_to_block[BlockHashWithGroupId( - block_hash, 0)] = { + block_pool.cached_block_hash_to_block[ + make_block_hash_with_group_id(block_hash, 0)] = { i: block_pool.blocks[i + 10], } diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py index e392c2c336e9b..d343141cdf4cb 100644 --- a/tests/v1/core/utils.py +++ b/tests/v1/core/utils.py @@ -9,6 +9,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, from vllm.multimodal.inputs import (MultiModalFeatureSpec, MultiModalKwargsItem, PlaceholderRange) from vllm.sampling_params import SamplingParams +from vllm.utils import sha256 from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, init_none_hash) from vllm.v1.core.sched.async_scheduler import AsyncScheduler @@ -130,10 +131,10 @@ def create_requests( ) -> list[Request]: global _none_hash_initialized if not _none_hash_initialized: - init_none_hash(hash) + init_none_hash(sha256) _none_hash_initialized = True - block_hasher = get_request_block_hasher(block_size, hash) + block_hasher = get_request_block_hasher(block_size, sha256) sampling_params = SamplingParams(ignore_eos=False, max_tokens=max_tokens, stop_token_ids=stop_token_ids, diff --git a/tests/v1/cudagraph/test_cudagraph_mode.py b/tests/v1/cudagraph/test_cudagraph_mode.py index 81655e4175006..25e01806f4956 100644 --- a/tests/v1/cudagraph/test_cudagraph_mode.py +++ b/tests/v1/cudagraph/test_cudagraph_mode.py @@ -62,6 +62,16 @@ backend_configs = { "cudagraph_mode": "FULL_AND_PIECEWISE", }, specific_gpu_arch=(9, 0)), + # FlashAttention MLA on Hopper + "FlashAttentionMLA": + BackendConfig(name="FlashAttentionMLA", + env_vars={ + "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA", + }, + comp_config={ + "cudagraph_mode": "FULL_DECODE_ONLY", + }, + specific_gpu_arch=(9, 0)), # FA2 "FA2": BackendConfig(name="FA2", diff --git a/tests/v1/e2e/test_correctness_sliding_window.py b/tests/v1/e2e/test_correctness_sliding_window.py index 4dfe1d3bb33fa..5b0c154722510 100644 --- a/tests/v1/e2e/test_correctness_sliding_window.py +++ b/tests/v1/e2e/test_correctness_sliding_window.py @@ -6,8 +6,7 @@ import pytest from vllm import LLM, SamplingParams -from ...core.block.e2e.test_correctness_sliding_window import (check_answers, - prep_prompts) +from ...utils import check_answers, prep_prompts @dataclass diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index bd0fa6b80781a..bf90f50b10828 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -83,7 +83,7 @@ def test_ngram_correctness( model_name: str, ): ''' - Compare the outputs of a original LLM and a speculative LLM + Compare the outputs of an original LLM and a 
speculative LLM should be the same when using ngram speculative decoding. ''' with monkeypatch.context() as m: @@ -117,45 +117,38 @@ def test_ngram_correctness( print(f"ref_output: {ref_output.outputs[0].text}") print(f"spec_output: {spec_output.outputs[0].text}") - # Heuristic: expect at least 70% of the prompts to match exactly + # Heuristic: expect at least 66% of the prompts to match exactly # Upon failure, inspect the outputs to check for inaccuracy. - assert matches > int(0.7 * len(ref_outputs)) + assert matches >= int(0.66 * len(ref_outputs)) del spec_llm torch.cuda.empty_cache() cleanup_dist_env_and_memory() -@pytest.mark.parametrize( - ["model_setup", "mm_enabled"], - [ - # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 - # (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False), - (("eagle", "meta-llama/Llama-3.1-8B-Instruct", - "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False), - (("eagle3", "meta-llama/Llama-3.1-8B-Instruct", - "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False), - pytest.param( - ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", - "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), - False, - marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), - pytest.param( - ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", - "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), - True, - marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), - (("eagle", "eagle618/deepseek-v3-random", - "eagle618/eagle-deepseek-v3-random", 1), False), - ], - ids=[ - # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 - # "qwen3_eagle3", - "llama3_eagle", - "llama3_eagle3", - "llama4_eagle", - "llama4_eagle_mm", - "deepseek_eagle" - ]) +@pytest.mark.parametrize(["model_setup", "mm_enabled"], [ + (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False), + (("eagle", "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False), + (("eagle3", "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False), + pytest.param( + ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), + False, + marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), + pytest.param( + ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), + True, + marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), + (("eagle", "eagle618/deepseek-v3-random", + "eagle618/eagle-deepseek-v3-random", 1), False), +], + ids=[ + "qwen3_eagle3", "llama3_eagle", "llama3_eagle3", + "llama4_eagle", "llama4_eagle_mm", + "deepseek_eagle" + ]) @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) def test_eagle_correctness( @@ -169,7 +162,7 @@ def test_eagle_correctness( # TODO: Fix this flaky test pytest.skip( "TREE_ATTN is flaky in the test disable for now until it can be " - "reolved (see https://github.com/vllm-project/vllm/issues/22922)") + "resolved (see https://github.com/vllm-project/vllm/issues/22922)") # Generate test prompts inside the function instead of using fixture test_prompts = get_test_prompts(mm_enabled) diff --git a/tests/v1/engine/conftest.py b/tests/v1/engine/conftest.py index d7722142b207f..a73a9a6999f72 100644 --- a/tests/v1/engine/conftest.py +++ b/tests/v1/engine/conftest.py @@ -12,7 +12,6 @@ from tests.v1.engine.utils import 
(NUM_PROMPT_LOGPROBS_UNDER_TEST, generate_dummy_prompt_logprobs_tensors, generate_dummy_sample_logprobs) from vllm.engine.arg_utils import EngineArgs -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from ...distributed.conftest import publisher_config, random_port # noqa: F401 @@ -24,7 +23,7 @@ EngineCorePromptLogprobsType = tuple[torch.Tensor, torch.Tensor] def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors: """Generate output processor dummy test vectors, without logprobs - + Returns: DummyOutputProcessorTestVectors instance with no logprobs """ @@ -48,9 +47,6 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors: ] return DummyOutputProcessorTestVectors( tokenizer=tokenizer, - tokenizer_group=init_tokenizer_from_configs( - vllm_config.model_config, vllm_config.scheduler_config, - vllm_config.lora_config), vllm_config=vllm_config, full_tokens=[tokenizer(text).input_ids for text in FULL_STRINGS], prompt_tokens=prompt_tokens, @@ -68,7 +64,7 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors: @pytest.fixture def dummy_test_vectors() -> DummyOutputProcessorTestVectors: """Generate output processor dummy test vectors, with logprobs - + Returns: DummyOutputProcessorTestVectors instance with logprobs """ diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index df04a14af70ce..aca546600d0b5 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -393,7 +393,7 @@ class MockLoggingStatLogger(LoggingStatLogger): async def test_customize_loggers(monkeypatch): """Test that we can customize the loggers. If a customized logger is provided at the init, it should - be used directly. + be added to the default loggers. 
""" with monkeypatch.context() as m, ExitStack() as after: @@ -410,7 +410,8 @@ async def test_customize_loggers(monkeypatch): stat_loggers = engine.logger_manager.per_engine_logger_dict assert len(stat_loggers) == 1 - assert len(stat_loggers[0]) == 1 + assert len( + stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger stat_loggers[0][0].log.assert_called_once() diff --git a/tests/v1/engine/test_engine_args.py b/tests/v1/engine/test_engine_args.py index f70a3ce147ff2..23ec3673b10b4 100644 --- a/tests/v1/engine/test_engine_args.py +++ b/tests/v1/engine/test_engine_args.py @@ -36,18 +36,19 @@ def test_prefix_caching_from_cli(): assert vllm_config.cache_config.enable_prefix_caching # default hash algorithm is "builtin" - assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin" + assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256" + + # set hash algorithm to sha256_cbor + args = parser.parse_args(["--prefix-caching-hash-algo", "sha256_cbor"]) + vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config() + assert vllm_config.cache_config.prefix_caching_hash_algo == \ + "sha256_cbor" # set hash algorithm to sha256 args = parser.parse_args(["--prefix-caching-hash-algo", "sha256"]) vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config() assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256" - # set hash algorithm to builtin - args = parser.parse_args(["--prefix-caching-hash-algo", "builtin"]) - vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config() - assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin" - # an invalid hash algorithm raises an error parser.exit_on_error = False with pytest.raises(ArgumentError): diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 98265c6349578..17b136aa42731 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -257,9 +257,13 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): def execute_model( self, scheduler_output, + non_block=False, ) -> Future[ModelRunnerOutput]: """Make execute_model non-blocking.""" + # DummyExecutor used only for testing async case. 
+ assert non_block + def _execute(): output = self.collective_rpc("execute_model", args=(scheduler_output, )) diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py index 2848420c22085..7529c3780ec25 100644 --- a/tests/v1/engine/test_llm_engine.py +++ b/tests/v1/engine/test_llm_engine.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional import pytest from vllm import LLM -from vllm.sampling_params import GuidedDecodingParams, SamplingParams +from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Metric, Vector if TYPE_CHECKING: @@ -97,7 +97,7 @@ def _get_test_sampling_params( top_p=0.95, n=n, seed=seed, - guided_decoding=GuidedDecodingParams( + structured_outputs=StructuredOutputsParams( regex="[0-9]+") if structured_outputs else None, ) for n in n_list ], n_list diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py index 6544e8b017e70..a9632ce54eac8 100644 --- a/tests/v1/engine/test_output_processor.py +++ b/tests/v1/engine/test_output_processor.py @@ -43,7 +43,7 @@ def _ref_convert_id_to_token( [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) def test_incremental_detokenization(request_output_kind: RequestOutputKind, dummy_test_vectors): - output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, + output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False) engine_core = MockEngineCore( tokens_list=dummy_test_vectors.generation_tokens) @@ -382,7 +382,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, num_sample_logprobs: Optional[int], num_prompt_logprobs: Optional[int], dummy_test_vectors): - output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, + output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False) engine_core = MockEngineCore( tokens_list=dummy_test_vectors.generation_tokens, @@ -535,7 +535,7 @@ def test_stop_token(include_stop_str_in_output: bool, ) # '<|end_of_text|>' stop_token_ids = [128009] if not is_eos_test else None # '<|eot_id|>' - output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, + output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False) # Dummy engine core outputs, with control tokens suffixed to test stops suffix_token = ([eos_token_id] if is_eos_test else stop_token_ids) @@ -642,7 +642,7 @@ def test_stop_token(include_stop_str_in_output: bool, [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST]) def test_stop_string(include_stop_str_in_output: bool, num_sample_logprobs: Optional[int], dummy_test_vectors): - output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, + output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False) engine_core = MockEngineCore( tokens_list=dummy_test_vectors.generation_tokens, @@ -763,7 +763,7 @@ def test_stop_string(include_stop_str_in_output: bool, def test_iteration_stats(dummy_test_vectors): - output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group, + output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True) engine_core = MockEngineCore(dummy_test_vectors.generation_tokens) engine_core_timestamp = time.monotonic() diff --git a/tests/v1/engine/test_processor_multi_modal_uuids.py b/tests/v1/engine/test_processor_multi_modal_uuids.py index 970a59eca8ece..bdd41eece2317 100644 --- a/tests/v1/engine/test_processor_multi_modal_uuids.py +++ 
b/tests/v1/engine/test_processor_multi_modal_uuids.py @@ -31,7 +31,7 @@ def _mk_processor(monkeypatch, raising=True) monkeypatch.setattr(ModelConfig, "__post_init__", - lambda self: None, + lambda self, *args: None, raising=True) monkeypatch.setattr(UnspecifiedPlatform, "is_async_output_supported", @@ -152,8 +152,8 @@ def test_multi_modal_uuids_accepts_none_and_passes_through( *, tokenization_kwargs=None, lora_request=None, - mm_hash_overrides=None): - captured["mm_hash_overrides"] = mm_hash_overrides + mm_uuids=None): + captured["mm_uuids"] = mm_uuids # Minimal processed inputs for decoder-only flow return {"type": "token", "prompt_token_ids": [1]} @@ -180,7 +180,7 @@ def test_multi_modal_uuids_accepts_none_and_passes_through( params=SamplingParams(), ) - assert captured["mm_hash_overrides"] == mm_uuids + assert captured["mm_uuids"] == mm_uuids def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch): @@ -196,8 +196,8 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch): *, tokenization_kwargs=None, lora_request=None, - mm_hash_overrides=None): - captured["mm_hash_overrides"] = mm_hash_overrides + mm_uuids=None): + captured["mm_uuids"] = mm_uuids return {"type": "token", "prompt_token_ids": [1]} monkeypatch.setattr(processor.input_preprocessor, @@ -223,7 +223,7 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch): ) # Expect request-id-based overrides are passed through - assert captured["mm_hash_overrides"] == { + assert captured["mm_uuids"] == { "image": [f"{request_id}-image-0", f"{request_id}-image-1"], "video": [f"{request_id}-video-0"], } diff --git a/tests/v1/engine/utils.py b/tests/v1/engine/utils.py index b58bc75fc9565..689b2c95f927e 100644 --- a/tests/v1/engine/utils.py +++ b/tests/v1/engine/utils.py @@ -9,7 +9,6 @@ import torch from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.engine.arg_utils import EngineArgs -from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreOutput, FinishReason from vllm.v1.outputs import LogprobsLists, LogprobsTensors @@ -39,7 +38,7 @@ def _create_random_top_logprob_test_vector( upper: float, ) -> torch.Tensor: """Create a random vector of top logprob float values. - + Use to create fake sample logprobs for testing. Note that a real production scenario would require @@ -63,7 +62,7 @@ def _create_random_top_logprob_test_matrix( upper: float, ) -> torch.Tensor: """Create a random matrix of top logprob float values. - + Use to create fake prompt logprobs for testing. 
Note that a real production scenario would require @@ -296,7 +295,6 @@ def generate_dummy_prompt_logprobs_tensors( class DummyOutputProcessorTestVectors: """Dummy test vectors for output processor tests""" tokenizer: GeneralTokenizerType - tokenizer_group: TokenizerGroup vllm_config: EngineArgs full_tokens: list[list[int]] # Prompt + generated tokens prompt_tokens: list[list[int]] diff --git a/tests/v1/entrypoints/conftest.py b/tests/v1/entrypoints/conftest.py index ffe0612124660..46b953fe37433 100644 --- a/tests/v1/entrypoints/conftest.py +++ b/tests/v1/entrypoints/conftest.py @@ -151,7 +151,7 @@ def sample_definition_json_schema(): @pytest.fixture -def sample_guided_choice(): +def sample_structured_outputs_choices(): return [ "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", "Ruby", "Swift", "Kotlin" ] diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index cd82eb2ac4199..4b0f3b2d9967e 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -15,12 +15,13 @@ import torch from pydantic import BaseModel from tests.reasoning.utils import run_reasoning_extraction +from vllm.config import StructuredOutputsConfig from vllm.distributed import cleanup_dist_env_and_memory from vllm.entrypoints.llm import LLM from vllm.outputs import RequestOutput from vllm.platforms import current_platform from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager -from vllm.sampling_params import GuidedDecodingParams, SamplingParams +from vllm.sampling_params import SamplingParams, StructuredOutputsParams if TYPE_CHECKING: from vllm.config import TokenizerMode @@ -46,12 +47,12 @@ PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [ ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None), ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None), ("Qwen/Qwen2.5-1.5B-Instruct", "lm-format-enforcer", "auto", None), - ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None), - ("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None), + #FIXME: These tests are flaky on CI thus disabled.
Tracking in Issue #24402 + # ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None), + # ("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None), + #("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"), ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", NGRAM_SPEC_CONFIG), - #FIXME: This test is flaky on CI thus disabled - #("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"), ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", NGRAM_SPEC_CONFIG), ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", NGRAM_SPEC_CONFIG), @@ -90,7 +91,7 @@ def _load_json(s: str, backend: str) -> str: @pytest.mark.skip_global_cleanup @pytest.mark.parametrize( - "model_name, guided_decoding_backend, tokenizer_mode, speculative_config", + "model_name, backend, tokenizer_mode, speculative_config", PARAMS_MODELS_BACKENDS_TOKENIZER_MODE) def test_structured_output( monkeypatch: pytest.MonkeyPatch, @@ -99,8 +100,8 @@ def test_structured_output( sample_sql_ebnf: str, sample_sql_lark: str, sample_regex: str, - sample_guided_choice: str, - guided_decoding_backend: str, + sample_structured_outputs_choices: str, + backend: str, tokenizer_mode: str, model_name: str, speculative_config: dict[str, Any], @@ -115,15 +116,15 @@ def test_structured_output( enforce_eager = bool(not current_platform.is_tpu()) # Use a single LLM instance for several scenarios to # speed up the test suite. - llm = LLM( - model=model_name, - enforce_eager=enforce_eager, - max_model_len=1024, - guided_decoding_backend=guided_decoding_backend, - guided_decoding_disable_any_whitespace=(guided_decoding_backend - in {"xgrammar", "guidance"}), - tokenizer_mode=tokenizer_mode, - speculative_config=speculative_config) + llm = LLM(model=model_name, + enforce_eager=enforce_eager, + max_model_len=1024, + structured_outputs_config=dict(backend=backend, + disable_any_whitespace=backend + in {"xgrammar", "guidance"}), + seed=120, + tokenizer_mode=tokenizer_mode, + speculative_config=speculative_config) # # Test 1: Generate JSON output based on a provided schema @@ -131,7 +132,7 @@ def test_structured_output( sampling_params = SamplingParams( temperature=1.0, max_tokens=4096, - guided_decoding=GuidedDecodingParams(json=sample_json_schema)) + structured_outputs=StructuredOutputsParams(json=sample_json_schema)) prompt = ("Give an example JSON for an employee profile that fits this " "schema. Make the response as short as possible. 
Schema: " @@ -151,7 +152,7 @@ def test_structured_output( generated_text = output.outputs[0].text assert generated_text is not None - if guided_decoding_backend != 'lm-format-enforcer': + if backend != 'lm-format-enforcer': assert "\n" not in generated_text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") output_json = json.loads(generated_text) @@ -160,12 +161,12 @@ def test_structured_output( # # Test 2: Generate JSON object without a schema # - if guided_decoding_backend != "outlines": + if backend != "outlines": sampling_params = SamplingParams( temperature=1.0, max_tokens=4096, n=2, - guided_decoding=GuidedDecodingParams(json_object=True)) + structured_outputs=StructuredOutputsParams(json_object=True)) outputs = llm.generate(prompts=( "Generate a JSON object with curly braces for a person with " @@ -194,8 +195,9 @@ def test_structured_output( sampling_params = SamplingParams( temperature=1.0, max_tokens=4096, - guided_decoding=GuidedDecodingParams(json=unsupported_json_schema)) - if guided_decoding_backend.startswith("xgrammar"): + structured_outputs=StructuredOutputsParams( + json=unsupported_json_schema)) + if backend.startswith("xgrammar"): with pytest.raises(ValueError, match="The provided JSON schema contains features " "not supported by xgrammar."): @@ -229,7 +231,7 @@ def test_structured_output( parsed_json = json.loads(generated_text) assert isinstance(parsed_json, dict) - if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]: + if backend not in ["outlines", "lm-format-enforcer"]: # # Test 4: Generate SQL statement using EBNF grammar # @@ -237,7 +239,8 @@ def test_structured_output( temperature=0.8, top_p=0.95, max_tokens=1000, - guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf)) + structured_outputs=StructuredOutputsParams( + grammar=sample_sql_ebnf)) outputs = llm.generate( ("Generate a sql statement that selects col_1 from " "table_1 where it is equal to 1. Make the response as short as " @@ -270,7 +273,8 @@ def test_structured_output( temperature=0.8, top_p=0.95, max_tokens=1000, - guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark)) + structured_outputs=StructuredOutputsParams( + grammar=sample_sql_lark)) outputs = llm.generate( ("Generate a sql statement that selects col_1 from " "table_1 where it is equal to 1. Make the response as short as " @@ -308,7 +312,8 @@ def test_structured_output( temperature=0.8, top_p=0.95, max_tokens=1000, - guided_decoding=GuidedDecodingParams(grammar="not a grammar")) + structured_outputs=StructuredOutputsParams( + grammar="not a grammar")) with pytest.raises(ValueError, match="Failed to convert the grammar "): llm.generate( ("Generate a sql statement that selects col_1 from " @@ -324,7 +329,7 @@ def test_structured_output( sampling_params = SamplingParams( temperature=0.8, top_p=0.95, - guided_decoding=GuidedDecodingParams(regex=sample_regex)) + structured_outputs=StructuredOutputsParams(regex=sample_regex)) prompt = (f"Give an example IPv4 address with this regex: {sample_regex}. 
" f"Make the response as short as possible.") @@ -351,7 +356,8 @@ def test_structured_output( sampling_params = SamplingParams( temperature=0.8, top_p=0.95, - guided_decoding=GuidedDecodingParams(choice=sample_guided_choice)) + structured_outputs=StructuredOutputsParams( + choice=sample_structured_outputs_choices)) outputs = llm.generate( ("The best language for type-safe systems programming is " @@ -367,7 +373,7 @@ def test_structured_output( generated_text = output.outputs[0].text print(generated_text) assert generated_text is not None - assert generated_text in sample_guided_choice + assert generated_text in sample_structured_outputs_choices print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") # @@ -377,7 +383,7 @@ def test_structured_output( sampling_params = SamplingParams( temperature=1.0, max_tokens=1000, - guided_decoding=GuidedDecodingParams(json=json_schema)) + structured_outputs=StructuredOutputsParams(json=json_schema)) outputs = llm.generate( ("Generate a JSON with the brand, model and car_type of the most " @@ -421,7 +427,7 @@ def test_structured_output( sampling_params = SamplingParams( temperature=1.0, max_tokens=4096, - guided_decoding=GuidedDecodingParams(json=json_schema)) + structured_outputs=StructuredOutputsParams(json=json_schema)) outputs = llm.generate( ("Generate a description of a frog using 50 characters. " @@ -443,7 +449,7 @@ def test_structured_output( output_json = json.loads(generated_text) jsonschema.validate(instance=output_json, schema=json_schema) - if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]: + if backend not in ["outlines", "lm-format-enforcer"]: # # Test 11: Generate structured output using structural_tag format # @@ -469,7 +475,7 @@ def test_structured_output( sampling_params = SamplingParams( temperature=0.0, max_tokens=4096, - guided_decoding=GuidedDecodingParams( + structured_outputs=StructuredOutputsParams( structural_tag=json.dumps(structural_tag_config))) prompt = """ @@ -546,7 +552,7 @@ Make the response as short as possible. @pytest.mark.skip_global_cleanup @pytest.mark.parametrize( - "model_name, guided_decoding_backend, tokenizer_mode, reasoning_parser, speculative_config", # noqa: E501 + "model_name, backend, tokenizer_mode, reasoning_parser, speculative_config", # noqa: E501 [ ("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto", "deepseek_r1", NGRAM_SPEC_CONFIG), @@ -555,7 +561,7 @@ Make the response as short as possible. 
) def test_structured_output_with_reasoning_matrices( monkeypatch: pytest.MonkeyPatch, - guided_decoding_backend: str, + backend: str, tokenizer_mode: TokenizerMode, reasoning_parser: str, model_name: str, @@ -575,13 +581,14 @@ def test_structured_output_with_reasoning_matrices( enforce_eager=bool(not current_platform.is_tpu()), max_model_len=1024, max_num_seqs=16, - guided_decoding_backend=guided_decoding_backend, - guided_decoding_disable_any_whitespace=True, + structured_outputs_config=dict(backend=backend, + disable_any_whitespace=backend + in {"xgrammar", "guidance"}, + reasoning_parser=reasoning_parser), tokenizer_mode=tokenizer_mode, - reasoning_parser=reasoning_parser, speculative_config=speculative_config, ) - tokenizer = llm.get_tokenizer(None) + tokenizer = llm.get_tokenizer() reasoner = ReasoningParserManager.get_reasoning_parser(reasoning_parser)( tokenizer=tokenizer) @@ -602,7 +609,7 @@ def test_structured_output_with_reasoning_matrices( sampling_params = SamplingParams( temperature=0.1, max_tokens=8192, - guided_decoding=GuidedDecodingParams(json=reasoning_schema), + structured_outputs=StructuredOutputsParams(json=reasoning_schema), ) outputs = llm.generate( [reasoning_prompt], @@ -639,13 +646,14 @@ def test_structured_output_auto_mode( llm = LLM(model=model_name, max_model_len=1024, - guided_decoding_backend="auto", + structured_outputs_config=dict(backend="auto"), tokenizer_mode=tokenizer_mode) sampling_params = SamplingParams( temperature=1.0, max_tokens=1000, - guided_decoding=GuidedDecodingParams(json=unsupported_json_schema)) + structured_outputs=StructuredOutputsParams( + json=unsupported_json_schema)) prompts = ( "Give an example JSON object for a grade " @@ -680,9 +688,10 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", max_model_len=1024, - guided_decoding_backend="guidance", - guided_decoding_disable_any_whitespace=True, - guided_decoding_disable_additional_properties=True) + structured_outputs_config=dict( + backend="guidance", + disable_any_whitespace=True, + disable_additional_properties=True)) schema = { 'type': 'object', @@ -708,14 +717,15 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): "<|im_end|>\n<|im_start|>assistant\n") def generate_with_backend(backend): - guided_params = GuidedDecodingParams( + structured_outputs_params = StructuredOutputsParams( json=schema, backend=backend, disable_any_whitespace=True, disable_additional_properties=True) - sampling_params = SamplingParams(temperature=0, - max_tokens=256, - guided_decoding=guided_params) + sampling_params = SamplingParams( + temperature=0, + max_tokens=256, + structured_outputs=structured_outputs_params) outputs = llm.generate(prompt, sampling_params=sampling_params) assert outputs is not None @@ -735,12 +745,11 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): assert "a6" not in generated -@pytest.mark.parametrize("guided_decoding_backend", - ["guidance", "xgrammar", "outlines"]) -def test_structured_output_batched_with_non_guided_requests( +@pytest.mark.parametrize("backend", ["guidance", "xgrammar", "outlines"]) +def test_structured_output_batched_with_non_structured_outputs_requests( monkeypatch: pytest.MonkeyPatch, sample_json_schema: dict[str, Any], - guided_decoding_backend: str, + backend: str, ): monkeypatch.setenv("VLLM_USE_V1", "1") @@ -752,24 +761,25 @@ def test_structured_output_batched_with_non_guided_requests( 
model="meta-llama/Meta-Llama-3.1-8B-Instruct", enforce_eager=enforce_eager, max_model_len=1024, - guided_decoding_backend=guided_decoding_backend, - guided_decoding_disable_any_whitespace=(guided_decoding_backend - in {"xgrammar", "guidance"}), + structured_outputs_config=StructuredOutputsConfig( + backend=backend, + disable_any_whitespace=backend in {"xgrammar", "guidance"}, + ), ) - guided_prompt = ( + structured_outputs_prompt = ( "Give an example JSON for an employee profile that fits this " "schema. Make the response as short as possible. Schema: " f"{sample_json_schema}") - non_guided_prompt = "The diameter of the Earth in kilometers is " + non_structured_outputs_prompt = "The diameter of the Earth in kilometers is " - prompts = [guided_prompt, non_guided_prompt] + prompts = [structured_outputs_prompt, non_structured_outputs_prompt] sampling_params = [ - SamplingParams( - temperature=1.0, - max_tokens=400, - guided_decoding=GuidedDecodingParams(json=sample_json_schema)), + SamplingParams(temperature=1.0, + max_tokens=400, + structured_outputs=StructuredOutputsParams( + json=sample_json_schema)), # No max tokens, temp=0 to assert on contents SamplingParams( seed=42, @@ -800,16 +810,16 @@ def test_structured_output_batched_with_non_guided_requests( print(f"Prompt:\n{prompt!r}\nGenerated text:\n{generated_text!r}") if index == 0: - # First prompt is guided, expect valid JSON + # First prompt is structured outputs, expect valid JSON assert "\n" not in generated_text output_json = json.loads(generated_text) jsonschema.validate(instance=output_json, schema=sample_json_schema) else: - # Second prompt is not guided, expect valid output + # Second prompt is not structured outputs, expect valid output # Cannot assert on exact output, but we can expect it to be factual assert "12,742" in generated_text - # non-guided requests should not return a valid JSON here + # non-structured outputs requests should not return a valid JSON here with pytest.raises(ValueError): output_json = json.loads(generated_text) diff --git a/tests/v1/entrypoints/openai/responses/test_basic.py b/tests/v1/entrypoints/openai/responses/test_basic.py index 7a0baa5767cba..2ee1004493a16 100644 --- a/tests/v1/entrypoints/openai/responses/test_basic.py +++ b/tests/v1/entrypoints/openai/responses/test_basic.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import openai # use the official client for correctness check +import openai.types.responses as openai_responses_types import pytest @@ -86,3 +87,18 @@ async def test_logprobs(client: openai.AsyncOpenAI): outputs = response.output assert outputs[-1].content[-1].logprobs assert len(outputs[-1].content[-1].logprobs[0].top_logprobs) == 5 + + +@pytest.mark.asyncio +async def test_streaming(client: openai.AsyncOpenAI): + stream = await client.responses.create( + input="What is 13 * 24?", + stream=True, + ) + events = [event async for event in stream] + assert isinstance(events[0], openai_responses_types.ResponseCreatedEvent) + assert any( + isinstance(event, openai_responses_types.ResponseTextDeltaEvent) + for event in events) + assert isinstance(events[-1], + openai_responses_types.ResponseCompletedEvent) diff --git a/tests/v1/entrypoints/openai/responses/test_image.py b/tests/v1/entrypoints/openai/responses/test_image.py index c8d09fd39fb13..3ed36ca678c0c 100644 --- a/tests/v1/entrypoints/openai/responses/test_image.py +++ b/tests/v1/entrypoints/openai/responses/test_image.py @@ -8,17 +8,17 @@ import pytest import pytest_asyncio from 
tests.utils import RemoteOpenAIServer -from vllm.multimodal.utils import encode_image_base64, fetch_image +from vllm.multimodal.utils import encode_image_base64 # Use a small vision model for testing MODEL_NAME = "Qwen/Qwen2.5-VL-3B-Instruct" MAXIMUM_IMAGES = 2 # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) -TEST_IMAGE_URLS = [ - "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", - "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", - "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", - "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", +TEST_IMAGE_ASSETS = [ + "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + "Grayscale_8bits_palette_sample_image.png", # "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", + "1280px-Venn_diagram_rgb.svg.png", # "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", + "RGBA_comp.png", # "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", ] @@ -52,16 +52,17 @@ async def client(image_server): @pytest.fixture(scope="session") -def base64_encoded_image() -> dict[str, str]: +def base64_encoded_image(local_asset_server) -> dict[str, str]: return { - image_url: encode_image_base64(fetch_image(image_url)) - for image_url in TEST_IMAGE_URLS + image_url: + encode_image_base64(local_asset_server.get_image_asset(image_url)) + for image_url in TEST_IMAGE_ASSETS } @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) async def test_single_chat_session_image(client: openai.AsyncOpenAI, model_name: str, image_url: str): content_text = "What's in this image?" @@ -91,11 +92,11 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("raw_image_url", TEST_IMAGE_ASSETS) async def test_single_chat_session_image_base64encoded( client: openai.AsyncOpenAI, model_name: str, - image_url: str, + raw_image_url: str, base64_encoded_image: dict[str, str], ): content_text = "What's in this image?" 
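(A minimal editorial sketch, not part of the patch: the `base64_encoded_image` fixture above maps each local asset name to a base64 string, and the hunk that follows splices that string into a `data:` URL for the "input_image" field. Using only the standard library, the round trip amounts to the following; the helper name, file path, and MIME type here are illustrative assumptions, not vLLM API.)

import base64

def image_to_data_url(path: str, mime: str = "image/png") -> str:
    # base64-encode the raw image bytes and wrap them in a data URL,
    # the form accepted by the OpenAI-style image_url fields.
    with open(path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    return f"data:{mime};base64,{encoded}"

# e.g. image_to_data_url("RGBA_comp.png")  # hypothetical local file
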
@@ -106,7 +107,7 @@ async def test_single_chat_session_image_base64encoded( { "type": "input_image", "image_url": - f"data:image/jpeg;base64,{base64_encoded_image[image_url]}", + f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}", "detail": "auto", }, { @@ -127,7 +128,8 @@ async def test_single_chat_session_image_base64encoded( @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize( "image_urls", - [TEST_IMAGE_URLS[:i] for i in range(2, len(TEST_IMAGE_URLS))]) + [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))], + indirect=True) async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str, image_urls: list[str]): messages = [{ diff --git a/tests/v1/entrypoints/openai/test_chat_completion.py b/tests/v1/entrypoints/openai/test_chat_completion.py index dffb32846c05e..9aa285aa9b18d 100644 --- a/tests/v1/entrypoints/openai/test_chat_completion.py +++ b/tests/v1/entrypoints/openai/test_chat_completion.py @@ -77,7 +77,9 @@ async def test_invalid_json_schema(client: openai.AsyncOpenAI, "role": "user", "content": prompt, }], - extra_body={"guided_json": invalid_json_schema}, + extra_body={"structured_outputs": { + "json": invalid_json_schema + }}, ) @@ -99,7 +101,9 @@ async def test_invalid_regex(client: openai.AsyncOpenAI, model_name: str): "content": prompt, }], extra_body={ - "guided_regex": r"[.*", + "structured_outputs": { + "regex": r"[.*" + }, "stop": ["\n"] }, ) @@ -134,5 +138,9 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str): "role": "user", "content": prompt, }], - extra_body={"guided_grammar": invalid_simplified_sql_grammar}, + extra_body={ + "structured_outputs": { + "grammar": invalid_simplified_sql_grammar + } + }, ) diff --git a/tests/v1/entrypoints/openai/test_completion.py b/tests/v1/entrypoints/openai/test_completion.py index 3a65583fab8d3..9090beb4bbd2a 100644 --- a/tests/v1/entrypoints/openai/test_completion.py +++ b/tests/v1/entrypoints/openai/test_completion.py @@ -627,7 +627,9 @@ async def test_invalid_json_schema(client: openai.AsyncOpenAI, await client.completions.create( model=model_name, prompt=prompt, - extra_body={"guided_json": invalid_json_schema}, + extra_body={"structured_outputs": { + "json": invalid_json_schema + }}, ) @@ -646,7 +648,9 @@ async def test_invalid_regex(client: openai.AsyncOpenAI, model_name: str): model=model_name, prompt=prompt, extra_body={ - "guided_regex": r"[.*", + "structured_outputs": { + "regex": r"[.*" + }, "stop": ["\n"] }, ) @@ -678,7 +682,11 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str): await client.completions.create( model=model_name, prompt=prompt, - extra_body={"guided_grammar": invalid_simplified_sql_grammar}, + extra_body={ + "structured_outputs": { + "grammar": invalid_simplified_sql_grammar + } + }, ) @@ -686,7 +694,7 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str): async def test_completion_with_empty_prompt_embeds( client: openai.AsyncOpenAI) -> None: """Test completion with empty prompt embeds.""" - payload: dict[str, list] = {"prompt_embeds": []} + payload: dict[str, object] = {"prompt": "Hello", "prompt_embeds": []} headers: dict[str, str] = {"Content-Type": "application/json"} # base_url = http://localhost:8000/v1/completions response = requests.post(f"{client.base_url}completions", diff --git a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py index d8c56ac42f718..380e72a156336 100644 --- 
a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py @@ -42,7 +42,7 @@ def test_basic_lifecycle(): engine_core_outputs = scheduler.update_from_output(scheduler_output, model_runner_output) - # Ensure the request is finished after 1 tokens. + # Ensure the request is finished after 1 token. assert request.is_finished() assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED output = engine_core_outputs[0].outputs[0] @@ -141,7 +141,7 @@ def test_short_prompt_lifecycle(): def test_prefix_cache_lifecycle(): - """Test that remote decode params still works with a prefix cache hit.""" + """Test that remote decode params still work with a prefix cache hit.""" vllm_config = create_vllm_config() scheduler = create_scheduler(vllm_config) diff --git a/tests/v1/kv_connector/unit/test_shared_storage_connector.py b/tests/v1/kv_connector/unit/test_shared_storage_connector.py index db203b81f15fc..6be261e45cb00 100644 --- a/tests/v1/kv_connector/unit/test_shared_storage_connector.py +++ b/tests/v1/kv_connector/unit/test_shared_storage_connector.py @@ -33,7 +33,7 @@ def _check_path_len(path): def _list_path(path): - """Return the list of foldername (hashes generatd) under the path""" + """Return the list of foldername (hashes generated) under the path""" return list(path.iterdir()) @@ -41,7 +41,7 @@ def run_test(tmp_path, processor, llm: LLM, question: str, image_urls: list[Image], expected_len: int, info: str): """ One individual test to process the prompt and output base on 1 set of input - Then check if the length in the strorage path matches the expected length + Then check if the length in the storage path matches the expected length `info` introduces details or purpose of the individual test """ print(f"***info: {info}***") @@ -115,7 +115,7 @@ def test_shared_storage_connector_hashes(tmp_path): """ Tests that SharedStorageConnector saves KV to the storage locations with proper hashes; that are unique for inputs with identical text but - differnt images (same size), or same multiple images but different orders. + different images (same size), or same multiple images but different orders. """ # Using tmp_path as the storage path to store KV print(f"KV storage path at: {str(tmp_path)}") @@ -171,12 +171,12 @@ def test_shared_storage_connector_hashes(tmp_path): img=[image_1], expected_len=2, info=("image_1 single input the 2nd time. " - "It should not form aother new hash.")), + "It should not form another new hash.")), InputCase(text=TEXT_PROMPTS[0], img=[image_2], expected_len=2, info=("image_2 single input the 2nd time. " - "It should not form aother new hash.")), + "It should not form another new hash.")), InputCase(text=TEXT_PROMPTS[0], img=[image_1, image_2], expected_len=3, @@ -189,12 +189,12 @@ def test_shared_storage_connector_hashes(tmp_path): img=[image_1, image_2], expected_len=4, info=("[image_1, image_2] input the 2nd time. " - "It should not form aother new hash.")), + "It should not form another new hash.")), InputCase(text=TEXT_PROMPTS[0], img=[image_2, image_1], expected_len=4, info=("[image_2, image_1] input the 2nd time. 
" - "It should not form aother new hash.")), + "It should not form another new hash.")), InputCase(text=TEXT_PROMPTS[0], img=[], expected_len=5, diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 3f068d5e8c7eb..0cae1c7bc0518 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -13,6 +13,7 @@ from vllm.distributed.kv_transfer.kv_connector.factory import ( KVConnectorFactory) from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa SharedStorageConnector) +from vllm.utils import sha256 from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, init_none_hash) @@ -127,11 +128,11 @@ def create_request(request_id: int, use_all_1s_for_prompt_tokens: bool = False, num_remote_blocks: int = 3, block_size: int = 16, - hash_fn: Callable = hash) -> Request: + hash_fn: Callable = sha256) -> Request: """Make dummy request for testing.""" global _none_hash_initialized if not _none_hash_initialized: - init_none_hash(hash) + init_none_hash(hash_fn) _none_hash_initialized = True kv_transfer_params: Optional[dict[str, Any]] = None diff --git a/tests/v1/logits_processors/test_custom_offline.py b/tests/v1/logits_processors/test_custom_offline.py index a7fde1990f7ed..891f55a14633b 100644 --- a/tests/v1/logits_processors/test_custom_offline.py +++ b/tests/v1/logits_processors/test_custom_offline.py @@ -15,6 +15,7 @@ from tests.v1.logits_processors.utils import (DUMMY_LOGITPROC_ARG, POOLING_MODEL_NAME, TEMP_GREEDY, CustomLogitprocSource, DummyLogitsProcessor, + WrappedPerReqLogitsProcessor, dummy_module) from tests.v1.logits_processors.utils import entry_points as fake_entry_points from tests.v1.logits_processors.utils import prompts @@ -80,7 +81,7 @@ def _run_test(kwargs: dict, logitproc_loaded: bool) -> None: target_token = params.extra_args[DUMMY_LOGITPROC_ARG] if not all(x == target_token for x in lp_toks): raise AssertionError( - f"Request {bdx} generated {lp_toks}, shoud all be " + f"Request {bdx} generated {lp_toks}, should all be " f"{target_token}") else: # This request does not exercise custom logitproc (or custom @@ -161,6 +162,38 @@ def test_custom_logitsprocs(monkeypatch, _run_test(kwargs, logitproc_loaded=True) +@create_new_process_for_each_test() +def test_custom_logitsprocs_req(monkeypatch): + """Test passing request-level logits processor to offline Python interface + + Wrap a request-level logits processor to create a batch level logits + processor that has a well-defined behavior (mask out all tokens except one + `target_token`) + + Construct an `LLM` instance which loads the wrapped logits processor. Pass + the custom logitproc as a class object. + + Construct a reference `LLM` instance with no custom logitproc + + Pass in a batch of requests, 50% of which pass a `target_token` value + in through `SamplingParams.extra_args`, 50% of which do not. 
+ + Validate that + * Requests which do not activate the custom logitproc, yield the same + results for both `LLM` instances + * Requests which activate the custom logitproc, only output `target_token` + + Args: + monkeypatch: for setting env vars + """ + + # Test that logitproc info is passed to workers + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1") + random.seed(40) + _run_test({"logits_processors": [WrappedPerReqLogitsProcessor]}, + logitproc_loaded=True) + + @create_new_process_for_each_test() @pytest.mark.parametrize("logitproc_source", [ CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT, diff --git a/tests/v1/logits_processors/utils.py b/tests/v1/logits_processors/utils.py index c36f1bd021c70..d3b7f314da099 100644 --- a/tests/v1/logits_processors/utils.py +++ b/tests/v1/logits_processors/utils.py @@ -3,15 +3,21 @@ import types from enum import Enum, auto -from typing import Optional +from typing import Any, Optional import torch from vllm.config import VllmConfig -from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP, BatchUpdate, - LogitsProcessor) +from vllm.logger import init_logger +from vllm.sampling_params import SamplingParams +from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP, + AdapterLogitsProcessor, + BatchUpdate, LogitsProcessor, + RequestLogitsProcessor) from vllm.v1.sample.logits_processor.builtin import process_dict_updates +logger = init_logger(__name__) + MODEL_NAME = "facebook/opt-125m" POOLING_MODEL_NAME = "BAAI/bge-base-en-v1.5" DUMMY_LOGITPROC_ARG = "target_token" @@ -63,11 +69,12 @@ class DummyLogitsProcessor(LogitsProcessor): return logits # Save target values before modification - rows_list = list(self.req_info.keys()) - cols = torch.tensor([self.req_info[i] for i in rows_list], + cols = torch.tensor(list(self.req_info.values()), + dtype=torch.long, + device=logits.device) + rows = torch.tensor(list(self.req_info.keys()), dtype=torch.long, device=logits.device) - rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device) values_to_keep = logits[rows, cols].clone() # Mask all but target tokens @@ -104,5 +111,60 @@ class EntryPoints(list): self.names = [ep.name for ep in eps] +class DummyPerReqLogitsProcessor: + """The request-level logits processor masks out all logits except the + token id identified by `target_token`""" + + def __init__(self, target_token: int) -> None: + """Specify `target_token`""" + self.target_token = target_token + + def __call__( + self, + output_ids: list[int], + logits: torch.Tensor, + ) -> torch.Tensor: + val_to_keep = logits[self.target_token].item() + logits[:] = float("-inf") + logits[self.target_token] = val_to_keep + return logits + + +class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor): + """Example of wrapping a fake request-level logit processor to create a + batch-level logits processor""" + + def is_argmax_invariant(self) -> bool: + return False + + def new_req_logits_processor( + self, + params: SamplingParams, + ) -> Optional[RequestLogitsProcessor]: + """This method returns a new request-level logits processor, customized + to the `target_token` value associated with a particular request. + + Returns None if the logits processor should not be applied to the + particular request. To use the logits processor the request must have + a "target_token" custom argument with an integer value. 
+ + Args: + params: per-request sampling params + + Returns: + `Callable` request logits processor, or None + """ + target_token: Optional[ + Any] = params.extra_args and params.extra_args.get("target_token") + if target_token is None: + return None + if not isinstance(target_token, int): + logger.warning( + "target_token value %s is not int; not applying logits" + " processor to request.", target_token) + return None + return DummyPerReqLogitsProcessor(target_token) + + """Fake version of importlib.metadata.entry_points""" entry_points = lambda group: EntryPoints(group) diff --git a/tests/v1/metrics/test_engine_logger_apis.py b/tests/v1/metrics/test_engine_logger_apis.py new file mode 100644 index 0000000000000..e6a4d0a2a2e8b --- /dev/null +++ b/tests/v1/metrics/test_engine_logger_apis.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy + +import pytest + +from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM +from vllm.v1.metrics.ray_wrappers import RayPrometheusStatLogger + + +class DummyStatLogger: + """ + A dummy stat logger for testing purposes. + Implements the minimal interface expected by StatLoggerManager. + """ + + def __init__(self, vllm_config, engine_idx): + self.vllm_config = vllm_config + self.engine_idx = engine_idx + self.recorded = [] + self.logged = False + self.engine_initialized = False + + def record(self, scheduler_stats, iteration_stats, engine_idx): + self.recorded.append((scheduler_stats, iteration_stats, engine_idx)) + + def log(self): + self.logged = True + + def log_engine_initialized(self): + self.engine_initialized = True + + +@pytest.fixture +def log_stats_enabled_engine_args(): + """ + Shared fixture providing common AsyncEngineArgs configuration + used across multiple tests. + """ + return AsyncEngineArgs( + model="distilbert/distilgpt2", + dtype="half", + disable_log_stats=False, + enforce_eager=True, + ) + + +@pytest.mark.asyncio +async def test_async_llm_replace_default_loggers( + log_stats_enabled_engine_args): + """ + RayPrometheusStatLogger should replace the default PrometheusStatLogger + """ + + engine = AsyncLLM.from_engine_args(log_stats_enabled_engine_args, + stat_loggers=[RayPrometheusStatLogger]) + assert isinstance(engine.logger_manager.prometheus_logger, + RayPrometheusStatLogger) + engine.shutdown() + + +@pytest.mark.asyncio +async def test_async_llm_add_to_default_loggers(log_stats_enabled_engine_args): + """ + It's still possible to use custom stat loggers exclusively by passing + disable_log_stats=True in addition to a list of custom stat loggers. 
+ """ + # Create engine_args with disable_log_stats=True for this test + disabled_log_engine_args = copy.deepcopy(log_stats_enabled_engine_args) + disabled_log_engine_args.disable_log_stats = True + + # Disable default loggers; pass custom stat logger to the constructor + engine = AsyncLLM.from_engine_args(disabled_log_engine_args, + stat_loggers=[DummyStatLogger]) + + assert len(engine.logger_manager.per_engine_logger_dict[0]) == 1 + assert isinstance(engine.logger_manager.per_engine_logger_dict[0][0], + DummyStatLogger) + + # log_stats is still True, since custom stat loggers are used + assert engine.log_stats + + engine.shutdown() diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index e835c029634ce..570e330208a39 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -430,7 +430,7 @@ def test_zero_logprobs(vllm_model, example_prompts, def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch): - """Engine should return all vocabulary logprobs + """Engine should return all vocabulary logprobs and prompt logprobs Args: example_prompts: list of example prompts (test fixture) @@ -444,16 +444,24 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch): # 2 other llms alive during whole session gpu_memory_utilization=0.15, max_model_len=256) + sampling_params_logprobs_all = SamplingParams(max_tokens=5, - logprobs=-1) + logprobs=-1, + prompt_logprobs=-1) results_logprobs_all = runner.llm.generate( example_prompts, sampling_params=sampling_params_logprobs_all) vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size() + for i in range(len(results_logprobs_all)): logprobs = results_logprobs_all[i].outputs[0].logprobs + prompt_logprobs = results_logprobs_all[i].prompt_logprobs assert logprobs is not None for logprob in logprobs: assert len(logprob) == vocab_size + assert prompt_logprobs is not None + assert prompt_logprobs[0] is None + for prompt_logprob in prompt_logprobs[1:]: + assert len(prompt_logprob) == vocab_size @pytest.mark.parametrize("logprobs_mode", list(LogprobsMode)) diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 46e3a611c6d26..e7f6b68fc3f77 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -12,12 +12,15 @@ from tests.v1.attention.utils import (BatchSpec, _Backend, create_common_attn_metadata, create_standard_kv_cache_spec, get_attention_backend) -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig) +from vllm.config.load import LoadConfig from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.platforms import current_platform from vllm.v1.spec_decode.eagle import EagleProposer +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch model_dir = "meta-llama/Llama-3.1-8B-Instruct" eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" @@ -63,6 +66,86 @@ def _create_proposer( device=current_platform.device_type) +def test_prepare_next_token_ids(): + """ + Test for prepare_next_token_ids_cpu and prepare_next_token_ids_padded. + Each will produce a device tensor of next_token_ids, taking as input + either the GPU tensor of sampled_token_ids with -1 for rejected tokens, + or the CPU python list[list[int]] with the rejected tokens removed. 
+ """ + device = torch.device(current_platform.device_type) + + num_requests = 4 + num_speculative_tokens = 4 + batch_spec = BatchSpec( + seq_lens=[num_speculative_tokens + 1] * num_requests, + query_lens=[num_speculative_tokens + 1] * num_requests, + ) + + req_ids = [f"req_{i+1}" for i in range(num_requests)] + mock_input_batch = mock.MagicMock(spec=InputBatch) + mock_input_batch.req_ids = req_ids + mock_input_batch.num_reqs = num_requests + mock_input_batch.vocab_size = 100 + + mock_num_scheduled_tokens = {req_id: 0 for req_id in req_ids} + mock_requests = {} + for req_id in req_ids: + mock_request = mock.MagicMock(spec=CachedRequestState) + # Each request will have a backup next token id of 10, 20, 30, 40 + mock_request.get_token_id.return_value = int(req_id.split("_")[1]) * 10 + mock_request.num_computed_tokens = 0 + mock_requests[req_id] = mock_request + + sampled_token_ids = [ + [0, 1, -1, -1, -1], # 1 accepted, 3 rejected, "1" sampled + [0, 1, 2, 3, 4], # all accepted, "4" sampled + [-1, -1, -1, -1, -1], # sampling skipped, use backup token "30" + [-1, -1, -1, -1, -1] # this request will be discarded + ] + sampled_token_ids_tensor = torch.tensor(sampled_token_ids, + dtype=torch.int32, + device=device) + sampled_token_ids_cpu = [[i for i in seq if i != -1] + for seq in sampled_token_ids] + + expected_next_token_ids_cpu = [1, 4, 30, 40] + expected_next_token_ids_tensor = torch.tensor(expected_next_token_ids_cpu, + dtype=torch.int32, + device=device) + + proposer = _create_proposer("eagle", num_speculative_tokens) + + next_token_ids_from_cpu = proposer.prepare_next_token_ids_cpu( + sampled_token_ids_cpu, mock_requests, mock_input_batch, + mock_num_scheduled_tokens) + + assert torch.equal(next_token_ids_from_cpu, expected_next_token_ids_tensor) + + common_attn_metadata = create_common_attn_metadata( + batch_spec, + block_size=16, + device=device, + ) + + discarded_req_indices = torch.tensor([3], dtype=torch.int64, device=device) + num_discarded_reqs = 1 + + expected_valid_sampled_tokens_count = torch.tensor([2, 5, 0, 0], + dtype=torch.int32, + device=device) + + next_token_ids_from_padded, valid_sampled_tokens_count = \ + proposer.prepare_next_token_ids_padded( + common_attn_metadata, sampled_token_ids_tensor, mock_requests, + mock_input_batch, discarded_req_indices, num_discarded_reqs) + + assert torch.equal(next_token_ids_from_padded, + expected_next_token_ids_tensor) + assert torch.equal(valid_sampled_tokens_count, + expected_valid_sampled_tokens_count) + + def test_prepare_inputs(): """ cu_target_query_lens: [0, a, a + b, a + b + c] @@ -89,10 +172,24 @@ def test_prepare_inputs(): device=device, ) - # Rejected tokens per request: [1, 3, 2] - num_rejected_tokens = torch.tensor([1, 3, 2], - dtype=torch.int32, - device=device) + # If there are `k` sampled tokens, then `k-1` tokens are draft tokens + # from the previous iteration, and the last token is the bonus token sampled + # from the base model. 
+ num_draft_tokens = [3, 6, 4] # one less than query_lens + # num rejected tokens is [1, 3, 2] + ACCEPT_TOKEN = 0 + BONUS_TOKEN = 1 + REJECT_TOKEN = -1 + sampled_token_ids = [ + [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, BONUS_TOKEN], + [ + ACCEPT_TOKEN, ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, + REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN + ], + [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN] + ] + sampled_token_ids = [[i for i in seq if i != REJECT_TOKEN] + for seq in sampled_token_ids] # Expected calculations: # query_len_per_req = [4, 7, 5] @@ -124,7 +221,7 @@ def test_prepare_inputs(): proposer = _create_proposer("eagle", 1) updated_metadata, token_indices = proposer.prepare_inputs( - common_attn_metadata, num_rejected_tokens.cpu()) + common_attn_metadata, sampled_token_ids, num_draft_tokens) assert torch.equal(updated_metadata.query_start_loc, expected_cu_num_tokens) @@ -132,6 +229,77 @@ def test_prepare_inputs(): assert torch.equal(token_indices, expected_token_indices) +def test_prepare_inputs_padded(): + """ + Input scenario is 3 requests with num_speculative_tokens == 2 and: + - Request 1: query_len = 3, rejected = 1 + - Request 2: query_len = 3, rejected = 0 + - Request 3: query_len = 3, rejected = 2 + + Expected outputs: + token_indices: [0, 1, 2, + 3, 4, 5, + 6, 7, 8] + Reason: Deferred computation should not disturb the original indices. + + token_indices_to_sample: [1, 5, 6] + Reason: After accounting for rejections, these are the valid token positions + from the original indices to sample from. + """ + + device = torch.device(current_platform.device_type) + + expected_token_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8], + dtype=torch.int32, + device=device) + expected_token_indices_to_sample = torch.tensor([1, 5, 6], + dtype=torch.int32, + device=device) + + num_speculative_tokens = 2 + batch_spec = BatchSpec( + seq_lens=[3, 3, 3], + query_lens=[3, 3, 3], + ) + + common_attn_metadata = create_common_attn_metadata( + batch_spec, + block_size=16, + device=device, + ) + + # Needed for cu_num_draft_tokens, which is expected to be [3, 6, 9] + expected_query_start_loc = torch.tensor([0, 3, 6, 9], + dtype=torch.int32, + device=device) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + draft_token_ids=[[0] * num_speculative_tokens] * 3, + device=device, + ) + + # num_rejected_tokens = [1, 0, 2] + # num_draft_tokens = [2, 2, 2] + # valid_sampled_tokens_count = num_draft_tokens + 1 - num_rejected_tokens + valid_sampled_tokens_count = torch.tensor([2, 3, 1], + dtype=torch.int32, + device=device) + + proposer = _create_proposer("eagle", num_speculative_tokens) + + output_metadata, token_indices, token_indices_to_sample = \ + proposer.prepare_inputs_padded( + common_attn_metadata, + spec_decode_metadata, + valid_sampled_tokens_count) + + assert output_metadata.max_query_len == 3 + assert torch.equal(output_metadata.query_start_loc, + expected_query_start_loc) + assert torch.equal(token_indices, expected_token_indices) + assert torch.equal(token_indices_to_sample, + expected_token_indices_to_sample) + + @pytest.mark.parametrize("method", ["eagle", "eagle3"]) @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) @@ -364,12 +532,15 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): # Mock runner for attention metadata building proposer.runner = mock.MagicMock() proposer.runner.attn_groups.append([mock.MagicMock()]) - proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder + 
proposer.runner.attn_groups[0][0].metadata_builders = [ + attn_metadata_builder + ] result = proposer.propose(target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, next_token_ids=next_token_ids, + last_token_indices=None, common_attn_metadata=common_attn_metadata, sampling_metadata=sampling_metadata) @@ -488,7 +659,9 @@ def test_propose_tree(spec_token_tree): # Mock runner for attention metadata building. proposer.runner = mock.MagicMock() proposer.runner.attn_groups.append([mock.MagicMock()]) - proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder + proposer.runner.attn_groups[0][0].metadata_builders = [ + attn_metadata_builder + ] # Setup inputs for the proposer. target_token_ids = torch.randint(0, @@ -521,6 +694,7 @@ def test_propose_tree(spec_token_tree): target_positions=target_positions, target_hidden_states=target_hidden_states, next_token_ids=next_token_ids, + last_token_indices=None, common_attn_metadata=common_attn_metadata, sampling_metadata=sampling_metadata) assert result.shape == (batch_size, num_speculative_tokens) diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py index 6317817408661..eacb2ad584baf 100644 --- a/tests/v1/spec_decode/test_tree_attention.py +++ b/tests/v1/spec_decode/test_tree_attention.py @@ -187,7 +187,7 @@ def test_tree_attn_correctness() -> None: dtype=torch.bfloat16, ) - # Setup the block table and KV cache for paged KV. + # Set up the block table and KV cache for paged KV. assert max_sequence_length % block_size == 0 max_blocks_per_batch = max_sequence_length // block_size kv_cache = torch.randn( @@ -222,7 +222,7 @@ def test_tree_attn_correctness() -> None: num_alloc_blocks_per_batch] = block_ids.view( -1, num_alloc_blocks_per_batch) - # Setup the slot mapping for the input KVs. + # Set up the slot mapping for the input KVs. tree_positions = sequence_position + torch.arange( 0, tree_size_q, diff --git a/tests/v1/test_oracle.py b/tests/v1/test_oracle.py index 1f16e92f657e0..794c1f68f1471 100644 --- a/tests/v1/test_oracle.py +++ b/tests/v1/test_oracle.py @@ -9,25 +9,9 @@ from vllm import LLM from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine -UNSUPPORTED_MODELS_V1 = [ - "openai/whisper-large-v3", # transcription - "facebook/bart-large-cnn", # encoder decoder -] - MODEL = "meta-llama/Llama-3.2-1B-Instruct" -@pytest.mark.parametrize("model", UNSUPPORTED_MODELS_V1) -def test_reject_unsupported_models(monkeypatch, model): - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - args = AsyncEngineArgs(model=model) - - with pytest.raises(NotImplementedError): - _ = args.create_engine_config() - m.delenv("VLLM_USE_V1") - - def test_reject_bad_config(monkeypatch): with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "0") @@ -78,12 +62,6 @@ def test_enable_by_default_fallback(monkeypatch): assert envs.VLLM_USE_V1 m.delenv("VLLM_USE_V1") - # Should fall back to V0 for supported model. 
- _ = AsyncEngineArgs( - model=UNSUPPORTED_MODELS_V1[0]).create_engine_config() - assert not envs.VLLM_USE_V1 - m.delenv("VLLM_USE_V1") - def test_v1_llm_by_default(monkeypatch): with monkeypatch.context() as m: diff --git a/tests/v1/tpu/test_multimodal.py b/tests/v1/tpu/test_multimodal.py index bcc2993028dd6..9947fcbe73135 100644 --- a/tests/v1/tpu/test_multimodal.py +++ b/tests/v1/tpu/test_multimodal.py @@ -4,18 +4,19 @@ import openai import pytest -from vllm.multimodal.utils import encode_image_base64, fetch_image +from vllm.multimodal.utils import encode_image_base64 from vllm.platforms import current_platform -from ...entrypoints.openai.test_vision import TEST_IMAGE_URLS +from ...entrypoints.openai.test_vision import TEST_IMAGE_ASSETS from ...utils import RemoteOpenAIServer @pytest.fixture(scope="session") -def base64_encoded_image() -> dict[str, str]: +def base64_encoded_image(local_asset_server) -> dict[str, str]: return { - image_url: encode_image_base64(fetch_image(image_url)) - for image_url in TEST_IMAGE_URLS + image_asset: + encode_image_base64(local_asset_server.get_image_asset(image_asset)) + for image_asset in TEST_IMAGE_ASSETS } @@ -66,7 +67,7 @@ async def test_basic_vision(model_name: str, base64_encoded_image: dict[str, client: openai.AsyncOpenAI = remote_server.get_async_client() # Other requests now should be much faster - for image_url in TEST_IMAGE_URLS: + for image_url in TEST_IMAGE_ASSETS: image_base64 = base64_encoded_image[image_url] chat_completion_from_base64 = await client.chat.completions\ .create( diff --git a/tests/v1/tpu/test_pallas.py b/tests/v1/tpu/test_pallas.py index bfba3af57f715..1bc8dff317a74 100644 --- a/tests/v1/tpu/test_pallas.py +++ b/tests/v1/tpu/test_pallas.py @@ -33,10 +33,12 @@ def test_ragged_paged_attention(): ) class FakeAttentionLayer: + _q_scale_float: float _k_scale_float: float _v_scale_float: float layer = FakeAttentionLayer() + layer._q_scale_float = 1.0 layer._k_scale_float = 1.0 layer._v_scale_float = 1.0 diff --git a/tests/v1/tpu/test_topk_topp_sampler.py b/tests/v1/tpu/test_topk_topp_sampler.py index ca5c067b364e0..05751badc7619 100644 --- a/tests/v1/tpu/test_topk_topp_sampler.py +++ b/tests/v1/tpu/test_topk_topp_sampler.py @@ -6,8 +6,12 @@ import pytest import torch from vllm.platforms import current_platform -from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p, - apply_top_k_top_p_tpu) +from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p + +# isort: off +from vllm.v1.sample.tpu.sampler import (apply_top_k_top_p as + apply_top_k_top_p_tpu) +# isort: on if not current_platform.is_tpu(): pytest.skip("This test needs a TPU.", allow_module_level=True) diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py index 941aa0a77692c..4f4a9c7db88a3 100644 --- a/tests/v1/tpu/worker/test_tpu_model_runner.py +++ b/tests/v1/tpu/worker/test_tpu_model_runner.py @@ -10,7 +10,7 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.utils import GiB_bytes from vllm.v1.core.kv_cache_utils import (estimate_max_model_len, - get_kv_cache_config) + get_kv_cache_configs) from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) from vllm.v1.worker.tpu_model_runner import ( @@ -64,9 +64,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: NewRequestData( req_id=req_id, prompt_token_ids=[1, 2, 3], - mm_kwargs=[], - mm_hashes=[], - mm_positions=[], + mm_features=[], 
sampling_params=SamplingParams(), pooling_params=PoolingParams(), block_ids=([0], ), # block_ids should be tuple[list[int]] @@ -127,7 +125,7 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool: return False num_blocks = block_table.num_blocks_per_row[req_index] - block_table_values = block_table.block_table_np[req_index, :num_blocks] + block_table_values = block_table.block_table.np[req_index, :num_blocks] return (block_table_values == req_block_ids).all() @@ -479,8 +477,8 @@ def test_init_kv_cache_without_kv_sharing(): # 2 (non-MLA) * 8 (num_heads) * 128 (head_dim) # * 2 (bfloat16, kv_cache dtype) * 128 (block_size) = 512KB num_expected_blocks = 20480 # 20GB / 512KB / 2 (num layers) - kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, - available_memory) + kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec], + [available_memory])[0] assert kv_cache_config.num_blocks == num_expected_blocks assert len(kv_cache_config.kv_cache_tensors) == 2 assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2 @@ -552,8 +550,8 @@ def test_init_kv_cache_with_kv_sharing_valid(): # with KV sharing, we can allocate (available_mem//page_size//1) blocks # which is twice as many as without KV sharing num_expected_blocks = 2 * 20480 # 20GB / 512KB - kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, - available_memory) + kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec], + [available_memory])[0] assert kv_cache_config.num_blocks == num_expected_blocks assert len(kv_cache_config.kv_cache_tensors) == 1 # Each layer now has twice the available memory for KV cache diff --git a/tests/tracing/__init__.py b/tests/v1/tracing/__init__.py similarity index 100% rename from tests/tracing/__init__.py rename to tests/v1/tracing/__init__.py diff --git a/tests/v1/tracing/test_tracing.py b/tests/v1/tracing/test_tracing.py new file mode 100644 index 0000000000000..da8655f95e195 --- /dev/null +++ b/tests/v1/tracing/test_tracing.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa +# type: ignore +from __future__ import annotations + +import threading +from collections.abc import Iterable +from concurrent import futures +from typing import Callable, Generator, Literal + +import grpc +import pytest +from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ( + ExportTraceServiceResponse) +from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import ( + TraceServiceServicer, add_TraceServiceServicer_to_server) +from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue +from opentelemetry.sdk.environment_variables import ( + OTEL_EXPORTER_OTLP_TRACES_INSECURE) + +from vllm import LLM, SamplingParams +from vllm.tracing import SpanAttributes + +FAKE_TRACE_SERVER_ADDRESS = "localhost:4317" + +FieldName = Literal['bool_value', 'string_value', 'int_value', 'double_value', + 'array_value'] + + +def decode_value(value: AnyValue): + field_decoders: dict[FieldName, Callable] = { + "bool_value": (lambda v: v.bool_value), + "string_value": (lambda v: v.string_value), + "int_value": (lambda v: v.int_value), + "double_value": (lambda v: v.double_value), + "array_value": + (lambda v: [decode_value(item) for item in v.array_value.values]), + } + for field, decoder in field_decoders.items(): + if value.HasField(field): + return decoder(value) + raise ValueError(f"Couldn't decode value: {value}") + + +def 
decode_attributes(attributes: Iterable[KeyValue]): + return {kv.key: decode_value(kv.value) for kv in attributes} + + +class FakeTraceService(TraceServiceServicer): + + def __init__(self): + self.request = None + self.evt = threading.Event() + + def Export(self, request, context): + self.request = request + self.evt.set() + return ExportTraceServiceResponse() + + +@pytest.fixture +def trace_service() -> Generator[FakeTraceService, None, None]: + """Fixture to set up a fake gRPC trace service""" + server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) + service = FakeTraceService() + add_TraceServiceServicer_to_server(service, server) + server.add_insecure_port(FAKE_TRACE_SERVER_ADDRESS) + server.start() + + yield service + + server.stop(None) + + +def test_traces( + monkeypatch: pytest.MonkeyPatch, + trace_service: FakeTraceService, +): + with monkeypatch.context() as m: + m.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true") + m.setenv("VLLM_USE_V1", "1") + sampling_params = SamplingParams( + temperature=0.01, + top_p=0.1, + max_tokens=256, + ) + model = "facebook/opt-125m" + llm = LLM(model=model, + otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS, + gpu_memory_utilization=0.3, + disable_log_stats=False) + prompts = ["This is a short prompt"] + outputs = llm.generate(prompts, sampling_params=sampling_params) + print(f"test_traces outputs is : {outputs}") + + timeout = 10 + if not trace_service.evt.wait(timeout): + raise TimeoutError( + f"The fake trace service didn't receive a trace within " + f"the {timeout} seconds timeout") + + request = trace_service.request + assert len(request.resource_spans) == 1, ( + f"Expected 1 resource span, " + f"but got {len(request.resource_spans)}") + assert len(request.resource_spans[0].scope_spans) == 1, ( + f"Expected 1 scope span, " + f"but got {len(request.resource_spans[0].scope_spans)}") + assert len(request.resource_spans[0].scope_spans[0].spans) == 1, ( + f"Expected 1 span, " + f"but got {len(request.resource_spans[0].scope_spans[0].spans)}") + + attributes = decode_attributes( + request.resource_spans[0].scope_spans[0].spans[0].attributes) + # assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model + assert attributes.get( + SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id + assert attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE + ) == sampling_params.temperature + assert attributes.get( + SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p + assert attributes.get(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS + ) == sampling_params.max_tokens + assert attributes.get( + SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n + assert attributes.get( + SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len( + outputs[0].prompt_token_ids) + completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs) + assert attributes.get( + SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens + + assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE) > 0 + assert attributes.get( + SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) > 0 + assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) > 0 diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 7031859078264..98700ff73fd1f 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -15,6 +15,7 @@ from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor 
import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -45,7 +46,7 @@ def _compare_objs(obj1, is_same = False if isinstance(a, torch.Tensor): - if (a.numel() == 0 or b.numel() == 0): + if a.numel() == 0 or b.numel() == 0: is_same = (a.numel() == 0 and b.numel() == 0) elif torch.allclose(a, b): is_same = True @@ -61,6 +62,8 @@ def _compare_objs(obj1, is_same = True # if we make it here must be same elif a == b: is_same = True + elif isinstance(a, CpuGpuBuffer): + is_same = np.allclose(a.np, b.np) and torch.allclose(a.gpu, b.gpu) assert is_same, f"Attribute {attr_name} is different"\ f" in {obj1} and {obj2}: {a} != {b}" @@ -203,9 +206,7 @@ def _construct_cached_request_state(req_id_suffix: int): prompt_token_ids=prompt_token_ids, sampling_params=_create_sampling_params(), pooling_params=None, - mm_kwargs=[], - mm_positions=[], - mm_hashes=[], + mm_features=[], block_ids=([], ), generator=None, num_computed_tokens=len(output_token_ids), diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 6d99029e404ef..8b571f95c5ecf 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import random - import numpy as np import pytest import torch @@ -17,7 +15,7 @@ from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams from vllm.utils import GiB_bytes, update_environment_variables from vllm.v1.core.kv_cache_utils import (estimate_max_model_len, - get_kv_cache_config) + get_kv_cache_configs) from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, @@ -120,9 +118,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: NewRequestData( req_id=req_id, prompt_token_ids=[1, 2, 3], - mm_kwargs=[], - mm_hashes=[], - mm_positions=[], + mm_features=[], sampling_params=SamplingParams(), pooling_params=None, block_ids=([0], ), @@ -169,7 +165,7 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool: req_state.block_ids[0]): return False num_blocks = block_table.num_blocks_per_row[req_index] - return (block_table.block_table_np[req_index, :num_blocks] == + return (block_table.block_table.np[req_index, :num_blocks] == req_state.block_ids[0]).all() @@ -409,29 +405,30 @@ def test_kv_cache_stride_order(monkeypatch, model_runner): model_runner.model_config.get_head_size() ] # TODO mla test - default_stride = list(range(5)) + default_stride = tuple(range(5)) # Permutation that gets you back to expected kv shape - rnd_stride = tuple(random.sample(default_stride, len(default_stride))) + for test_stride in ((1, 4, 0, 2, 3), (0, 1, 2, 3, 4)): - def rnd_stride_order(): - return rnd_stride + def rnd_stride_order(test_stride=test_stride): + return test_stride - # Patch the attention backend class and re-trigger the KV cache creation. 
- for attn_group in model_runner._attn_group_iterator(): - attn_backend = attn_group.backend - monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order", - rnd_stride_order) + # Patch the attention backend class and re-trigger the KV cache creation + for attn_group in model_runner._attn_group_iterator(): + attn_backend = attn_group.backend + monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order", + rnd_stride_order) - model_runner.attn_groups = [] - model_runner.initialize_kv_cache(model_runner.kv_cache_config) + model_runner.attn_groups = [] + model_runner.kv_caches = [] + model_runner.initialize_kv_cache(model_runner.kv_cache_config) - # Shape is unchanged, but layout may differ - kv_cache_shape = model_runner.kv_caches[0].shape - assert list(kv_cache_shape) == expected_kv_cache_shape - if default_stride == rnd_stride: - assert all(kv.is_contiguous() for kv in model_runner.kv_caches) - else: - assert all(not kv.is_contiguous() for kv in model_runner.kv_caches) + # Shape is unchanged, but layout may differ + kv_cache_shape = model_runner.kv_caches[0].shape + assert list(kv_cache_shape) == expected_kv_cache_shape + if default_stride == test_stride: + assert all(kv.is_contiguous() for kv in model_runner.kv_caches) + else: + assert all(not kv.is_contiguous() for kv in model_runner.kv_caches) def test_update_config(model_runner): @@ -588,8 +585,8 @@ def test_init_kv_cache_without_kv_sharing(): available_memory = 20 * GiB_bytes # page size for layer 0's kv_cache_spec is 32KB num_expected_blocks = 327680 # 20GB / 32KB / 2 (num layers) - kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, - available_memory) + kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec], + [available_memory])[0] assert kv_cache_config.num_blocks == num_expected_blocks assert len(kv_cache_config.kv_cache_tensors) == 2 assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2 @@ -660,8 +657,8 @@ def test_init_kv_cache_with_kv_sharing_valid(): # with KV sharing, we can allocate (available_mem//page_size//1) blocks # which is twice as many as without KV sharing num_expected_blocks = 655360 # 20GB / 32KB - kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, - available_memory) + kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec], + [available_memory])[0] assert kv_cache_config.num_blocks == num_expected_blocks assert len(kv_cache_config.kv_cache_tensors) == 1 # Each layer now has twice the available memory for KV cache @@ -791,8 +788,8 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): kv_cache_spec = runner.get_kv_cache_spec() available_memory = 5 * GiB_bytes - kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, - available_memory) + kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec], + [available_memory])[0] runner.initialize_kv_cache(kv_cache_config) # random partition of blocks diff --git a/tests/worker/__init__.py b/tests/worker/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/tests/worker/conftest.py b/tests/worker/conftest.py deleted file mode 100644 index 3f202d4dbe948..0000000000000 --- a/tests/worker/conftest.py +++ /dev/null @@ -1,11 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - This module tests V0 internals, so set VLLM_USE_V1=0. 
- """ - monkeypatch.setenv('VLLM_USE_V1', '0') \ No newline at end of file diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py deleted file mode 100644 index 35ac90b38e840..0000000000000 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ /dev/null @@ -1,648 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import itertools - -import pytest -import torch - -from vllm.engine.arg_utils import EngineArgs -from vllm.platforms import current_platform -from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.utils import make_tensor_with_pad -from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner - -BATCH_SIZES = [1, 4, 16, 64, 256] - - -def _create_model_runner(model: str, *args, - **kwargs) -> EncoderDecoderModelRunner: - engine_args = EngineArgs(model, *args, **kwargs) - engine_config = engine_args.create_engine_config() - model_runner = EncoderDecoderModelRunner( - vllm_config=engine_config, - is_driver_worker=True, - ) - return model_runner - - -@pytest.mark.skipif(condition=current_platform.is_cpu(), - reason="CPU backend is currently " - "unsupported for encoder/ " - "decoder models") -def test_empty_seq_group(): - """Verify prepare prompt and decode returns empty output - for empty seq group list""" - - model_runner = _create_model_runner( - "facebook/bart-base", - seed=0, - dtype="float16", - max_num_batched_tokens=100000, - max_num_seqs=100000, - enable_chunked_prefill=False, - enforce_eager=True, - ) - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - model_input = model_runner._prepare_model_input_tensors( - seq_group_metadata_list) - ( - input_tokens, - input_positions, - encoder_input_tokens, - encoder_input_positions, - attn_metadata, - return_seq_lens, - ) = ( - model_input.input_tokens, - model_input.input_positions, - model_input.encoder_input_tokens, - model_input.encoder_input_positions, - model_input.attn_metadata, - model_input.seq_lens, - ) - assert input_tokens is None - assert input_positions is None - assert encoder_input_tokens is None - assert encoder_input_positions is None - assert attn_metadata is None - assert return_seq_lens is None - - -@pytest.mark.skipif(condition=current_platform.is_cpu(), - reason="CPU backend is currently " - "unsupported for encoder/ " - "decoder models") -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -def test_prepare_prompt(batch_size): - ''' - Test the ability of the encoder/decoder model runner subclass to - produce prefill-phase model inputs & attention metadata. - - Test behavior: - - * Instantiate BART base model & enc/dec model runner - * Construct sequence-group metadata for dummy prompts - * Test that encoder attention, decoder self-attention, - and encoder/decoder cross-attention inputs are correct - - Arguments: - - * batch_size - * backend_name: The attention backend under test - * enforce_eager: Enforce eager mode if True (i.e. 
no CUDAGraph) - ''' - - model_runner = _create_model_runner( - "facebook/bart-base", - seed=0, - dtype="float16", - max_num_batched_tokens=100000, - max_num_seqs=100000, - enable_chunked_prefill=False, - enforce_eager=True, - ) - - seq_lens: list[int] = [] - encoder_seq_lens: list[int] = [] - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - block_tables = {0: [1]} - cross_block_table = [2] - for i in range(batch_size): - # make sure all tokens fit into one block - seq_len = i % (model_runner.block_size - 1) + 1 - seq_lens.append(seq_len) - seq_data = SequenceData.from_seqs(range(seq_len)) - encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - encoder_seq_lens.append(encoder_seq_len) - encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len)) - seq_group_metadata = SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: seq_data}, - sampling_params=SamplingParams(temperature=0), - block_tables=block_tables, - encoder_seq_data=encoder_seq_data, - cross_block_table=cross_block_table, - ) - assert seq_group_metadata.token_chunk_size == seq_data.get_len() - seq_group_metadata_list.append(seq_group_metadata) - - # Build - # * Decoder model inputs - # * Decoder self-attention KV caching data structures - # * Encoder model inputs - # * Encoder/decoder cross-attention KV caching data structures - model_input = model_runner.prepare_model_input(seq_group_metadata_list) - - input_tokens = model_input.input_tokens - input_positions = model_input.input_positions - attn_metadata = model_input.attn_metadata - return_seq_lens = model_input.seq_lens - slot_mapping = attn_metadata.slot_mapping - encoder_input_tokens = model_input.encoder_input_tokens - encoder_input_positions = model_input.encoder_input_positions - cross_slot_mapping = attn_metadata.cross_slot_mapping - assert return_seq_lens == seq_lens - assert len(slot_mapping) == len(input_tokens) - assert len(cross_slot_mapping) == len(encoder_input_tokens) - - # Verify input metadata is correct for prompts. - # - Decoder attention metadata - device = model_runner.device - assert attn_metadata.num_prefills > 0 - assert attn_metadata.num_decode_tokens == 0 - assert torch.equal(attn_metadata.seq_lens_tensor, - torch.tensor(seq_lens, device=device, dtype=torch.int)) - assert attn_metadata.seq_lens == seq_lens - assert attn_metadata.max_prefill_seq_len == max(seq_lens) - assert attn_metadata.max_decode_seq_len == 0 - # - Encoder attention metadata - assert attn_metadata.encoder_seq_lens == encoder_seq_lens - assert torch.equal( - attn_metadata.encoder_seq_lens_tensor, - torch.tensor(encoder_seq_lens, device=device, dtype=torch.int)) - assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens) - assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens) - - # Test decoder subquery start locs. 
- start_idx = 0 - start_loc = [start_idx] - for seq_len in seq_lens: - start_idx += seq_len - start_loc.append(start_idx) - assert torch.equal( - attn_metadata.query_start_loc, - torch.tensor(start_loc, dtype=torch.int32, device=device), - ) - - # Test decoder seq start locs & context lengths - - assert torch.equal( - attn_metadata.seq_start_loc, - torch.tensor(start_loc, dtype=torch.int32, device=device), - ) - assert torch.equal( - attn_metadata.context_lens_tensor, - torch.zeros(attn_metadata.context_lens_tensor.shape[0], - dtype=torch.int, - device=device), - ) - - # Verify block tables are correct for prompts - # - Decoder self-attention - expected = torch.tensor( - [[] for _ in range(len(seq_group_metadata_list))], - dtype=torch.int32, - device=model_runner.device, - ) - assert torch.equal( - attn_metadata.block_tables, - expected, - ) - # - Encoder/decoder cross-attention - assert torch.equal( - attn_metadata.cross_block_tables, - expected, - ) - - # Cuda graph should not be used for prefill. - assert attn_metadata.use_cuda_graph is False - - # Verify the lengths of input tokens & positions - # - Decoder - assert len(input_tokens) == sum(seq_lens) - assert len(input_positions) == sum(seq_lens) - # -- An indirect check that model_input.input_tokens - # and model_input.input_positions are correct - - # by design of the test, the input tokens are - # equal to the input position values, so if - # the model_input data structure has the correct - # values then these two should be equal - assert torch.equal( - input_tokens, - input_positions, - ) - # - Encoder - assert len(encoder_input_tokens) == sum(encoder_seq_lens) - # -- An indirect check that model_input.encoder_input_tokens - # and model_input.encoder_input_positions are correct - - # by design of the test, the input tokens are - # equal to the input position values, so if - # the model_input data structure has the correct - # values then these two should be equal - assert torch.equal( - encoder_input_tokens, - encoder_input_positions, - ) - - # Test that vLLM sampling infrastructure chooses the correct - # sequence positions at which to sample (i.e. the end of - # each sequence) in the prefill phase - - expected_selected_token_indices = [] - selected_token_start_idx = 0 - for seq_len in seq_lens: - # Compute the index offset of the final token in each - # prompt (recall that the prompts are concatenated) - expected_selected_token_indices.append(selected_token_start_idx + - seq_len - 1) - selected_token_start_idx += seq_len - - sampling_metadata = model_input.sampling_metadata - actual = sampling_metadata.selected_token_indices - expected = torch.tensor( - expected_selected_token_indices, - device=actual.device, - dtype=actual.dtype, - ) - assert torch.equal(actual, expected) - - -@pytest.mark.skipif(condition=current_platform.is_cpu(), - reason="CPU backend is currently " - "unsupported for encoder/ " - "decoder models") -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False]) -def test_prepare_decode(batch_size, multiple_seqs_per_seq_group): - ''' - Test the ability of the encoder/decoder model runner subclass to - produce decode-phase model inputs & attention metadata. 
- - Test behavior: - - * Instantiate BART base model & enc/dec model runner - * Construct sequence-group metadata for dummy prompts - * Test that encoder attention, decoder self-attention, - and encoder/decoder cross-attention inputs are correct - - Arguments: - - * batch_size - * multiple_seqs_per_seq_group - * backend_name: The attention backend under test - * enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph) - ''' - - model_runner = _create_model_runner( - "facebook/bart-base", - seed=0, - dtype="float16", - max_num_batched_tokens=100000, - max_num_seqs=100000, - enable_chunked_prefill=False, - enforce_eager=True, - ) - - seq_lens: list[int] = [] - encoder_seq_lens: list[int] = [] - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - block_tables = { - 0: [1], - 1: [3] - } if multiple_seqs_per_seq_group else { - 0: [1] - } - cross_block_table = [2] - for i in range(batch_size): - # make sure all tokens fit into one block - seq_len = i % (model_runner.block_size - 1) + 1 - seq_data = SequenceData.from_seqs(range(seq_len)) - encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len)) - - seq_group_metadata = SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=False, - seq_data={ - 0: seq_data, - 1: seq_data - } if multiple_seqs_per_seq_group else {0: seq_data}, - sampling_params=SamplingParams(temperature=0), - block_tables=block_tables, - encoder_seq_data=encoder_seq_data, - cross_block_table=cross_block_table, - ) - assert seq_group_metadata.token_chunk_size == 1 - seq_group_metadata_list.append(seq_group_metadata) - seq_lens.extend( - [seq_len for _ in range(len(seq_group_metadata.seq_data))]) - encoder_seq_lens.extend( - [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))]) - - # Build - # * Decoder model inputs - # * Decoder self-attention KV caching data structures - # * Encoder model inputs - # * Encoder/decoder cross-attention KV caching data structures - model_input = model_runner.prepare_model_input(seq_group_metadata_list) - input_tokens = model_input.input_tokens - input_positions = model_input.input_positions - attn_metadata = model_input.attn_metadata - return_seq_lens = model_input.seq_lens - slot_mapping = attn_metadata.slot_mapping - encoder_input_tokens = model_input.encoder_input_tokens - encoder_input_positions = model_input.encoder_input_positions - cross_slot_mapping = attn_metadata.cross_slot_mapping - assert return_seq_lens == seq_lens - assert len(slot_mapping) == len(input_tokens) - assert len(cross_slot_mapping) == len(encoder_input_tokens) - - # Verify input metadata is correct for decode phase. - # - Decoder attention metadata - device = model_runner.device - assert attn_metadata.num_prefills == 0 - assert attn_metadata.num_decode_tokens > 0 - assert torch.equal(attn_metadata.seq_lens_tensor, - torch.tensor(seq_lens, device=device, dtype=torch.int)) - assert attn_metadata.seq_lens == seq_lens - assert attn_metadata.max_prefill_seq_len == 0 - assert attn_metadata.max_decode_seq_len == max(seq_lens) - # - Encoder attention metadata - assert attn_metadata.encoder_seq_lens == encoder_seq_lens - assert torch.equal( - attn_metadata.encoder_seq_lens_tensor, - torch.tensor(encoder_seq_lens, device=device, dtype=torch.int)) - assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens) - assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens) - - # Test decoder subquery start locs. 
- start_idx = 0 - start_loc = [start_idx] - for seq_len in seq_lens: - start_idx += 1 - start_loc.append(start_idx) - assert torch.equal( - attn_metadata.query_start_loc, - torch.tensor(start_loc, dtype=torch.int32, device=device), - ) - - # Test decoder seq start locs. Note that for normal prefill it is - # equivalent to query_start_loc. - start_idx = 0 - seq_start_loc = [start_idx] - for seq_len in seq_lens: - start_idx += seq_len - seq_start_loc.append(start_idx) - - # Test seq_start_loc and context lengths - - assert torch.equal( - attn_metadata.seq_start_loc, - torch.tensor(seq_start_loc, dtype=torch.int32, device=device), - ) - assert torch.equal( - attn_metadata.context_lens_tensor, - torch.tensor([seq_len - 1 for seq_len in seq_lens], - dtype=torch.int, - device=device)) - - # Verify block tables are correct for prompts - # - Decoder self-attention - flattened_block_tables = [ - block_table for block_table in block_tables.values() - ] - expected = torch.tensor(flattened_block_tables * - len(seq_group_metadata_list), - dtype=torch.int32, - device=model_runner.device) - assert torch.equal( - attn_metadata.block_tables, - expected, - ) - # - Encoder/decoder cross-attention - expected = torch.tensor([ - cross_block_table for seq_group_metadata in seq_group_metadata_list - for _ in range(len(seq_group_metadata.seq_data)) - ], - dtype=torch.int32, - device=model_runner.device) - assert torch.equal( - attn_metadata.cross_block_tables, - expected, - ) - - # Model runner's CUDAGraph setting should be propagated to attention - # metadata. - assert attn_metadata.use_cuda_graph is False - - # Verify the lengths of input tokens & positions - # - Decoder - assert len(input_tokens) == len(seq_lens) - assert len(input_positions) == len(seq_lens) - # -- An indirect check that model_input.input_tokens - # and model_input.input_positions are correct - - # by design of the test, the input tokens are - # equal to the input position values, so if - # the model_input data structure has the correct - # values then these two should be equal - assert torch.equal( - input_tokens, - input_positions, - ) - # - Encoder - assert len(encoder_input_tokens) == 0 - assert len(encoder_input_tokens) == 0 - # -- An indirect check that model_input.encoder_input_tokens - # and model_input.encoder_input_positions are correct - - # by design of the test, the input tokens are - # equal to the input position values, so if - # the model_input data structure has the correct - # values then these two should be equal - assert torch.equal( - encoder_input_tokens, - encoder_input_positions, - ) - - # Test that vLLM sampling infrastructure chooses the correct - # sequence positions at which to sample (i.e. the end of - # each sequence) in the decode phase - - expected_selected_token_indices = [] - for selected_token_start_idx, seq_len in enumerate(seq_lens): - # Compute the index offset of the final token in each - # sequence's decoded outputs; since a single token is - # decoded per iteration per sequence, then the length - # of the decoded tokens for a given sequence is 1 and - # the final index offset into a given sequence's - # generated tokens is 0 (i.e. 
the expected sampling index - # for a given sequence is just `selected_token_start_idx`) - expected_selected_token_indices.append(selected_token_start_idx) - - sampling_metadata = model_input.sampling_metadata - actual = sampling_metadata.selected_token_indices - expected = torch.tensor( - expected_selected_token_indices, - device=actual.device, - dtype=actual.dtype, - ) - assert torch.equal(actual, expected) - - -@pytest.mark.parametrize("batch_size", list(range(1, 257))) -@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False]) -def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group): - """ - Tests that for encoder-decoder models with CUDA Graph capture and replay - enabled, the tensors used during the decode phase are correctly padded - for varying input batch sizes. - """ - model_runner = _create_model_runner( - "facebook/bart-base", - seed=0, - dtype="float16", - max_num_batched_tokens=100000, - max_num_seqs=100000, - enable_chunked_prefill=False, - enforce_eager=False, - ) - block_tables = { - 0: [1], - 1: [3] - } if multiple_seqs_per_seq_group else { - 0: [1] - } - seq_lens: list[int] = [] - encoder_seq_lens: list[int] = [] - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - - cross_block_table = [2] - expanded_batch_size = 0 - for i in range(batch_size): - # make sure all tokens fit into one block - seq_len = i % (model_runner.block_size - 1) + 1 - seq_data = SequenceData.from_seqs(range(seq_len)) - encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len)) - seq_group_metadata = SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=False, - seq_data={ - 0: seq_data, - 1: seq_data - } if multiple_seqs_per_seq_group else {0: seq_data}, - sampling_params=SamplingParams(temperature=0), - block_tables=block_tables, - encoder_seq_data=encoder_seq_data, - cross_block_table=cross_block_table, - ) - assert seq_group_metadata.token_chunk_size == 1 - seq_lens.extend( - [seq_len for _ in range(len(seq_group_metadata.seq_data))]) - encoder_seq_lens.extend( - [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))]) - expanded_batch_size = expanded_batch_size + len( - seq_group_metadata.seq_data) - seq_group_metadata_list.append(seq_group_metadata) - - model_input = model_runner.prepare_model_input(seq_group_metadata_list) - input_tokens = model_input.input_tokens - input_positions = model_input.input_positions - attn_metadata = model_input.attn_metadata - return_seq_lens = model_input.seq_lens - slot_mapping = attn_metadata.slot_mapping - encoder_input_tokens = model_input.encoder_input_tokens - encoder_input_positions = model_input.encoder_input_positions - cross_slot_mapping = attn_metadata.cross_slot_mapping - - # With CUDA Graph capture and replay enabled, the decoder and encoder - # input sequences will be padded. Create the expected padded tensors - # accordingly. 
- graph_batch_size = model_runner.vllm_config.pad_for_cudagraph( - expanded_batch_size) - cuda_graph_pad_size = graph_batch_size - expanded_batch_size - padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size)) - padded_encoder_seq_lens = encoder_seq_lens + list( - itertools.repeat(1, cuda_graph_pad_size)) - - assert return_seq_lens == padded_seq_lens - assert len(slot_mapping) == len(input_tokens) - assert len(cross_slot_mapping) == len(encoder_input_tokens) - - # Verify attention metadata - device = model_runner.device - assert attn_metadata.num_prefills == 0 - assert attn_metadata.num_decode_tokens > 0 - assert torch.equal( - attn_metadata.seq_lens_tensor, - torch.tensor(padded_seq_lens, device=device, dtype=torch.int)) - assert attn_metadata.seq_lens == padded_seq_lens - assert attn_metadata.max_prefill_seq_len == 0 - assert attn_metadata.max_decode_seq_len == max(seq_lens) - # - Encoder attention metadata - assert attn_metadata.encoder_seq_lens == padded_encoder_seq_lens - assert torch.equal( - attn_metadata.encoder_seq_lens_tensor, - torch.tensor(padded_encoder_seq_lens, device=device, dtype=torch.int)) - assert attn_metadata.max_encoder_seq_len == max(padded_encoder_seq_lens) - assert attn_metadata.num_encoder_tokens == sum(padded_encoder_seq_lens) - - # Verify block tables are correct for prompts - # - Decoder self-attention. Pad the block tables as expected. - flattened_block_tables = [ - block_table for _ in range(len(seq_group_metadata_list)) - for block_table in block_tables.values() - ] - flattened_block_tables.extend([[] for _ in range(cuda_graph_pad_size)]) - expected = make_tensor_with_pad( - flattened_block_tables, - max_len=64, - pad=0, - dtype=torch.int32, - device=model_runner.device, - ) - assert torch.equal( - attn_metadata.block_tables, - expected, - ) - # - Encoder/decoder cross-attention. Pad the cross-attention block tables - # as expected. - expected = [ - cross_block_table for seq_group_metadata in seq_group_metadata_list - for _ in range(len(seq_group_metadata.seq_data)) - ] - expected.extend([[] for _ in range(cuda_graph_pad_size)]) - expected = make_tensor_with_pad( - expected, - max_len=64, - pad=0, - dtype=torch.int32, - device=model_runner.device, - ) - assert torch.equal( - attn_metadata.cross_block_tables, - expected, - ) - - # Model runner's CUDAGraph setting should be propagated to attention - # metadata. 
- assert attn_metadata.use_cuda_graph is True - - # Verify the lengths of input tokens & positions - # - Decoder - assert len(input_tokens) == len(padded_seq_lens) - assert len(input_positions) == len(padded_seq_lens) - # -- An indirect check that model_input.input_tokens - # and model_input.input_positions are correct - - # by design of the test, the input tokens are - # equal to the input position values, so if - # the model_input data structure has the correct - # values then these two should be equal - assert torch.equal( - input_tokens, - input_positions, - ) - # - Encoder - assert len(encoder_input_tokens) == 0 - assert len(encoder_input_tokens) == 0 - # -- An indirect check that model_input.encoder_input_tokens - # and model_input.encoder_input_positions are correct - - # by design of the test, the input tokens are - # equal to the input position values, so if - # the model_input data structure has the correct - # values then these two should be equal - assert torch.equal( - encoder_input_tokens, - encoder_input_positions, - ) diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py deleted file mode 100644 index 0f28ef2ba857b..0000000000000 --- a/tests/worker/test_model_input.py +++ /dev/null @@ -1,113 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import dataclasses - -import torch - -from vllm.attention import AttentionMetadata, AttentionMetadataBuilder -from vllm.attention.backends.abstract import AttentionBackend -from vllm.attention.backends.utils import CommonAttentionState -from vllm.model_executor import SamplingMetadata -from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata - - -class MockAttentionBackend(AttentionBackend): - - @staticmethod - def get_name() -> str: - raise NotImplementedError - - @staticmethod - def get_impl_cls(): - raise NotImplementedError - - @staticmethod - def get_metadata_cls() -> type["AttentionMetadata"]: - return AttentionMetadata - - @staticmethod - def get_builder_cls() -> type["AttentionMetadataBuilder"]: - return AttentionMetadataBuilder - - @staticmethod - def get_state_cls() -> type["CommonAttentionState"]: - return CommonAttentionState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> tuple[int, ...]: - raise NotImplementedError - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - pass - - @staticmethod - def copy_blocks( - kv_caches: list[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - pass - - -def test_model_runner_input(): - sampling_metadata = SamplingMetadata( - ["seq_group"], - "selected_token_indices", - "categorized_sample_indices", - "num_prompts", - ) - attn_metadata = AttentionMetadata( - num_prefills=1, - num_prefill_tokens=2, - num_decode_tokens=3, - slot_mapping=torch.zeros(1), - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - ) - model_input = ModelInputForGPUWithSamplingMetadata( - input_tokens=torch.ones(10), - input_positions=torch.ones(10), - sampling_metadata=sampling_metadata, - attn_metadata=attn_metadata) - - assert isinstance(model_input, ModelInputForGPUWithSamplingMetadata) - - # Test round trip serialization. 
- tensor_dict = model_input.as_broadcastable_tensor_dict() - attn_backend = MockAttentionBackend() - received_model_input = ( - ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict( - tensor_dict, attn_backend=attn_backend)) - # Check that received copy has correct values. - assert isinstance(received_model_input, - ModelInputForGPUWithSamplingMetadata) - assert received_model_input.input_tokens is not None - assert ( - received_model_input.input_tokens == model_input.input_tokens).all() - assert received_model_input.input_positions is not None - assert (received_model_input.input_positions == model_input.input_positions - ).all() - assert received_model_input.multi_modal_kwargs is None - assert (received_model_input.multi_modal_kwargs == - model_input.multi_modal_kwargs) - assert received_model_input.lora_requests is None - assert received_model_input.lora_requests == model_input.lora_requests - assert received_model_input.lora_mapping is None - assert received_model_input.lora_mapping == model_input.lora_mapping - for field in dataclasses.fields(AttentionMetadata): - assert getattr(received_model_input.attn_metadata, field.name, - None) == getattr(attn_metadata, field.name, None) - # For sampling metadata, only selected_token_indices is copied. - assert (received_model_input.sampling_metadata.selected_token_indices == - sampling_metadata.selected_token_indices) - assert received_model_input.sampling_metadata.seq_groups is None diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py deleted file mode 100644 index 0be25aa2fc35d..0000000000000 --- a/tests/worker/test_model_runner.py +++ /dev/null @@ -1,462 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -import torch - -from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, - init_distributed_environment) -from vllm.engine.arg_utils import EngineArgs -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.utils import get_open_port -from vllm.worker.model_runner import ModelRunner - - -def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner: - engine_args = EngineArgs(model, *args, **kwargs) - engine_config = engine_args.create_engine_config() - model_runner = ModelRunner( - vllm_config=engine_config, - is_driver_worker=True, - ) - return model_runner - - -def test_deepseek_mla_attn_backend_module(): - model_runner = _create_model_runner( - "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct", - trust_remote_code=True, - enable_chunked_prefill=False, - ) - assert model_runner.attn_backend.__name__ == "TritonMLABackend" - - -@pytest.mark.parametrize("batch_size", list(range(1, 257, 3))) -@pytest.mark.parametrize("use_prompt_embeds", [True, False]) -def test_prepare_prompt(batch_size, use_prompt_embeds, monkeypatch): - if use_prompt_embeds: - # Prompt Embeddings is only currently supported on V0 - monkeypatch.setenv("VLLM_USE_V1", "0") - - model_runner = _create_model_runner( - "facebook/opt-125m", - max_num_batched_tokens=100000, - max_num_seqs=100000, - enable_chunked_prefill=False, - enable_prompt_embeds=True, - ) - - seq_lens: list[int] = [] - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - block_tables = {0: [1]} - expected_input_embeds_len = 0 - for i in range(batch_size): - # make sure all tokens fit into one block - seq_len = i % (model_runner.block_size - 1) + 1 - 
seq_lens.append(seq_len) - if use_prompt_embeds: - seq_data = SequenceData.from_seqs( - prompt_token_ids=[0] * seq_len, - prompt_embeds=torch.rand(seq_len, 10), - ) - expected_input_embeds_len += seq_len - else: - seq_data = SequenceData.from_seqs(prompt_token_ids=range(seq_len)) - - seq_group_metadata = SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: seq_data}, - sampling_params=SamplingParams(temperature=0), - block_tables=block_tables, - ) - assert seq_group_metadata.token_chunk_size == seq_data.get_len() - seq_group_metadata_list.append(seq_group_metadata) - - expected_selected_token_indices = [] - selected_token_start_idx = 0 - for seq_len in seq_lens: - expected_selected_token_indices.append(selected_token_start_idx + - seq_len - 1) - selected_token_start_idx += seq_len - model_input = model_runner._prepare_model_input_tensors( - seq_group_metadata_list) - input_tokens = model_input.input_tokens - input_positions = model_input.input_positions - input_embeds = model_input.inputs_embeds - attn_metadata = model_input.attn_metadata - return_seq_lens = model_input.seq_lens - slot_mapping = attn_metadata.slot_mapping - assert return_seq_lens == seq_lens - assert len(slot_mapping) == len(input_tokens) - - # Verify input metadata is correct for prompts. - device = model_runner.device - assert attn_metadata.num_prefills > 0 - assert attn_metadata.num_decode_tokens == 0 - torch.testing.assert_close( - attn_metadata.seq_lens_tensor, - torch.tensor(seq_lens, device=device, dtype=torch.int)) - assert attn_metadata.seq_lens == seq_lens - assert attn_metadata.max_prefill_seq_len == max(seq_lens) - assert attn_metadata.max_decode_seq_len == 0 - - # Test subquery start locs. - start_idx = 0 - start_loc = [start_idx] - for seq_len in seq_lens: - start_idx += seq_len - start_loc.append(start_idx) - torch.testing.assert_close( - attn_metadata.query_start_loc, - torch.tensor(start_loc, dtype=torch.int32, device=device)) - - # Test seq start locs. Note that for normal prefill it is - # equivalent to query_start_loc. - start_idx = 0 - seq_start_loc = [start_idx] - for seq_len in seq_lens: - start_idx += seq_len - seq_start_loc.append(start_idx) - - torch.testing.assert_close( - attn_metadata.seq_start_loc, - torch.tensor(start_loc, dtype=torch.int32, device=device)) - torch.testing.assert_close( - attn_metadata.context_lens_tensor, - torch.zeros(attn_metadata.context_lens_tensor.shape[0], - dtype=torch.int, - device=device)) - - expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))], - dtype=torch.int32, - device=model_runner.device) - torch.testing.assert_close(attn_metadata.block_tables, expected) - # Cuda graph should not be used for prerill. 
- assert attn_metadata.use_cuda_graph is False - - assert len(input_tokens) == sum(seq_lens) - assert len(input_positions) == sum(seq_lens) - if expected_input_embeds_len == 0: - torch.testing.assert_close(input_tokens, input_positions) - assert input_embeds is None - else: - assert len(input_embeds) == expected_input_embeds_len - - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - query_lens=seq_lens, - device=model_runner.device, - pin_memory=model_runner.pin_memory) - assert len(input_tokens) == sum(seq_lens) - assert len(input_positions) == sum(seq_lens) - actual = sampling_metadata.selected_token_indices - expected = torch.tensor(expected_selected_token_indices, - device=actual.device, - dtype=actual.dtype) - torch.testing.assert_close(actual, expected) - torch.allclose(input_tokens, input_positions) - - actual = sampling_metadata.selected_token_indices - expected = torch.tensor(expected_selected_token_indices, - device=actual.device, - dtype=actual.dtype) - torch.testing.assert_close(actual, expected) - - -@pytest.mark.parametrize("batch_size", list(range(1, 257, 3))) -@pytest.mark.parametrize("use_prompt_embeds", [True, False]) -def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch): - if use_prompt_embeds: - # Prompt Embeddings is only currently supported on V0 - monkeypatch.setenv("VLLM_USE_V1", "0") - - model_runner = _create_model_runner( - "facebook/opt-125m", - seed=0, - dtype="float16", - enforce_eager=False, - max_num_batched_tokens=100000, - max_num_seqs=100000, - enable_chunked_prefill=False, - enable_prompt_embeds=True, - ) - - context_lens: list[int] = [] - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - # Assume each seq group finishes prefill. - for i in range(batch_size): - # make sure all tokens fit into one block - context_len = i % (model_runner.block_size - 1) + 1 - context_lens.append(context_len) - if use_prompt_embeds: - seq_data = SequenceData.from_seqs( - prompt_token_ids=[0] * context_len, - prompt_embeds=torch.rand(context_len, 10), - ) - output_embed = torch.rand(10) - else: - seq_data = SequenceData.from_seqs( - prompt_token_ids=range(context_len)) - output_embed = None - seq_data.update_num_computed_tokens(context_len) - # Append one token ID since prefill is finished. - seq_data.append_token_id(1, 0, output_embed) - seq_group_metadata = SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=False, - seq_data={0: seq_data}, - sampling_params=SamplingParams(temperature=0), - block_tables={0: [1]}, - ) - assert seq_group_metadata.token_chunk_size == 1 - seq_group_metadata_list.append(seq_group_metadata) - - model_input = model_runner._prepare_model_input_tensors( - seq_group_metadata_list) - input_tokens = model_input.input_tokens - input_positions = model_input.input_positions - input_embeds = model_input.inputs_embeds - attn_metadata = model_input.attn_metadata - slot_mapping = attn_metadata.slot_mapping - - assert len(slot_mapping) == len(input_tokens) - - expected_bs = model_runner.vllm_config.pad_for_cudagraph( - len(seq_group_metadata_list)) - # Verify input metadata is correct for prompts. 
- device = model_runner.device - assert attn_metadata.num_prefills == 0 - assert attn_metadata.num_prefill_tokens == 0 - seq_lens = [context_len + 1 for context_len in context_lens] - # seq_lens are padded to expected_bs - for _ in range(expected_bs - len(seq_lens)): - seq_lens.append(1) - assert attn_metadata.seq_lens == seq_lens - assert attn_metadata.num_decode_tokens == len(seq_lens) - start_idx = 0 - start_loc = [start_idx] - for _ in context_lens: - # decode has only 1 token for query. - start_idx += 1 - start_loc.append(start_idx) - torch.testing.assert_close( - attn_metadata.query_start_loc, - torch.tensor(start_loc, dtype=torch.int32, device=device)) - - start_idx = 0 - seq_start_loc = [start_idx] - for seq_len in seq_lens: - start_idx += seq_len - seq_start_loc.append(start_idx) - torch.testing.assert_close( - attn_metadata.seq_start_loc, - torch.tensor(seq_start_loc, dtype=torch.int32, device=device)) - - torch.testing.assert_close( - attn_metadata.context_lens_tensor, - torch.tensor(context_lens, dtype=torch.int, device=device)) - assert attn_metadata.max_decode_seq_len == max(seq_lens) - torch.testing.assert_close( - attn_metadata.seq_lens_tensor[:len(seq_lens)], - torch.tensor(seq_lens, dtype=torch.int, device=device)) - - # block table's first index corresponds to each batch, meaning in - # decoding it is each token. - assert attn_metadata.block_tables.shape[0] == len(input_tokens) - # Block table's second dim corresponds to each token's block number. - # It is padded up to - assert attn_metadata.block_tables.shape[1] == ( - model_runner.get_max_block_per_batch()) - assert attn_metadata.use_cuda_graph is True - - assert len(input_tokens) == expected_bs - assert len(input_positions) == expected_bs - if use_prompt_embeds: - expected_input_embeds_length = start_loc[-1] - assert len(input_embeds) == expected_input_embeds_length - assert expected_input_embeds_length <= expected_bs - else: - assert input_embeds is None - - # Verify Sampling - expected_selected_token_indices = [] - for selected_token_start_idx, _ in enumerate(context_lens): - expected_selected_token_indices.append(selected_token_start_idx) - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - # query lens is all 1 for decode. 
- query_lens=[1 for _ in range(len(context_lens))], - device=model_runner.device, - pin_memory=model_runner.pin_memory) - actual = sampling_metadata.selected_token_indices - expected = torch.tensor(expected_selected_token_indices, - device=actual.device, - dtype=actual.dtype) - torch.testing.assert_close(actual, expected) - - -def test_empty_seq_group(): - """Verify prepare prompt and decode returns empty output.""" - model_runner = _create_model_runner( - "facebook/opt-125m", - seed=0, - dtype="float16", - enforce_eager=False, - ) - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - model_input = model_runner._prepare_model_input_tensors( - seq_group_metadata_list) - - input_tokens = model_input.input_tokens - input_positions = model_input.input_positions - attn_metadata = model_input.attn_metadata - - assert input_tokens is None - assert input_positions is None - assert attn_metadata is None - - model_input = model_runner._prepare_model_input_tensors( - seq_group_metadata_list) - - input_tokens = model_input.input_tokens - input_positions = model_input.input_positions - input_embeds = model_input.inputs_embeds - attn_metadata = model_input.attn_metadata - return_seq_lens = model_input.seq_lens - - assert input_tokens is None - assert input_positions is None - assert input_embeds is None - assert attn_metadata is None - assert return_seq_lens is None - - -@pytest.fixture -def distributed_init(): - init_distributed_environment( - world_size=1, - rank=0, - distributed_init_method=f"tcp://127.0.0.1:{get_open_port()}", - local_rank=0) - ensure_model_parallel_initialized(1, 1) - - -@pytest.mark.parametrize("batch_size", list(range(2, 128, 3))) -@pytest.mark.parametrize("enforce_eager", [True, False]) -@pytest.mark.parametrize('use_prompt_embeds', [True, False]) -def test_hybrid_batches(batch_size, enforce_eager, use_prompt_embeds, - distributed_init, monkeypatch): - if use_prompt_embeds: - # Prompt Embeddings is only currently supported on V0 - monkeypatch.setenv("VLLM_USE_V1", "0") - - model_runner = _create_model_runner( - "facebook/opt-125m", - seed=0, - dtype="float16", - enforce_eager=enforce_eager, - max_num_batched_tokens=100000, - max_num_seqs=100000, - enable_chunked_prefill=True, - enable_prompt_embeds=True, - ) - - # Add prefill requests. 
- seq_lens: list[int] = [] - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - prefill_metadata_list: list[SequenceGroupMetadata] = [] - decode_metadata_list: list[SequenceGroupMetadata] = [] - block_tables = {0: [1]} - prefill_batch_size = batch_size // 2 - decode_batch_size = batch_size - prefill_batch_size - expected_input_embeds_len = 0 - for i in range(prefill_batch_size): - # make sure all tokens fit into one block - seq_len = i % (model_runner.block_size - 1) + 1 - seq_lens.append(seq_len) - if use_prompt_embeds: - seq_data = SequenceData.from_seqs( - prompt_token_ids=[0] * seq_len, - prompt_embeds=torch.rand(seq_len, 10), - ) - expected_input_embeds_len += seq_len - else: - seq_data = SequenceData.from_seqs( - prompt_token_ids=range(seq_len), ) - seq_group_metadata = SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: seq_data}, - sampling_params=SamplingParams(temperature=0), - block_tables=block_tables, - ) - assert seq_group_metadata.token_chunk_size == seq_data.get_len() - seq_group_metadata_list.append(seq_group_metadata) - prefill_metadata_list.append(seq_group_metadata) - - # Add decode requests - for i in range(prefill_batch_size, batch_size): - # make sure all tokens fit into one block - context_len = i % (model_runner.block_size - 1) + 1 - if use_prompt_embeds: - seq_data = SequenceData.from_seqs( - prompt_token_ids=[0] * context_len, - prompt_embeds=torch.rand(context_len, 10), - ) - output_embed = torch.rand(10) - # This also iterates the expected input_embeds, because the model - # needs both the input and output embeddings passed into together - expected_input_embeds_len += 1 - else: - seq_data = SequenceData.from_seqs( - prompt_token_ids=range(context_len), ) - output_embed = None - assert len(seq_data.prompt_token_ids) == context_len - seq_data.append_token_id(1, 0, output_embed) - seq_data.update_num_computed_tokens(context_len) - seq_group_metadata = SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=False, - seq_data={0: seq_data}, - sampling_params=SamplingParams(temperature=0), - block_tables={0: [1]}, - ) - assert seq_group_metadata.token_chunk_size == 1 - seq_group_metadata_list.append(seq_group_metadata) - decode_metadata_list.append(seq_group_metadata) - - model_input = model_runner.prepare_model_input(seq_group_metadata_list) - - input_tokens = model_input.input_tokens - input_positions = model_input.input_positions - input_embeds = model_input.inputs_embeds - attn_metadata = model_input.attn_metadata - - prefill_meta_actual = attn_metadata.prefill_metadata - decode_meta_actual = attn_metadata.decode_metadata - - assert len(attn_metadata.slot_mapping) == len(input_tokens) - assert len(input_positions) == len(input_tokens) - assert attn_metadata.num_prefills == prefill_batch_size - assert attn_metadata.num_decode_tokens == decode_batch_size - assert attn_metadata.num_prefill_tokens == sum(seq_lens) - if expected_input_embeds_len == 0: - assert input_embeds is None - else: - assert len(input_embeds) == expected_input_embeds_len - - # Verify attn metadata is consistent. We don't need to test individual - # values here because they are tested above. 
- attn_metadata = model_runner._prepare_model_input_tensors( - seq_group_metadata_list).attn_metadata - - for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata), - vars(prefill_meta_actual)): - assert attr_expected[1] == attr_actual[1] - for attr_expected, attr_actual in zip(vars(attn_metadata.decode_metadata), - vars(decode_meta_actual)): - assert attr_expected[1] == attr_actual[1] diff --git a/tests/worker/test_profile.py b/tests/worker/test_profile.py deleted file mode 100644 index d8767f700b576..0000000000000 --- a/tests/worker/test_profile.py +++ /dev/null @@ -1,68 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -from vllm.engine.arg_utils import EngineArgs -from vllm.utils import get_distributed_init_method, get_ip, get_open_port -from vllm.worker.cache_engine import CacheEngine -from vllm.worker.worker import Worker - - -def test_gpu_memory_profiling(): - # Tests the gpu profiling that happens in order to determine the number of - # KV cache blocks that we can allocate on the GPU. - # This test mocks the maximum available gpu memory so that it can run on - # any gpu setup. - - # Set up engine args to build a worker. - engine_args = EngineArgs(model="facebook/opt-125m", - dtype="half", - load_format="dummy") - engine_config = engine_args.create_engine_config() - engine_config.cache_config.num_gpu_blocks = 1000 - engine_config.cache_config.num_cpu_blocks = 1000 - - # Create the worker. - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - worker = Worker( - vllm_config=engine_config, - local_rank=0, - rank=0, - distributed_init_method=distributed_init_method, - is_driver_worker=True, - ) - - # Set 10GiB as the total gpu ram to be device-agnostic - def mock_mem_info(): - current_usage = torch.cuda.memory_stats( - )["allocated_bytes.all.current"] - mock_total_bytes = 10 * 1024**3 - free = mock_total_bytes - current_usage - - return (free, mock_total_bytes) - - from unittest.mock import patch - with patch("torch.cuda.mem_get_info", side_effect=mock_mem_info): - # Load the model so we can profile it - worker.init_device() - worker.load_model() - gpu_blocks, _ = worker.determine_num_available_blocks() - - # Peak vram usage by torch should be 0.47 GiB - # Model weights take 0.25 GiB - # No memory should be allocated outside of torch - # 9.0 GiB should be the utilization target - # 8.28 GiB should be available for the KV cache - block_size = CacheEngine.get_cache_block_size( - engine_config.cache_config, engine_config.model_config, - engine_config.parallel_config) - - expected_blocks = (8.28 * 1024**3) // block_size - - # Check within a small tolerance for portability - # Hardware, kernel, or dependency changes could all affect memory - # utilization. - # A 100 block tolerance here should be about 60MB of wiggle room. - assert abs(gpu_blocks - expected_blocks) < 100 diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py deleted file mode 100644 index 6d9f404ac207b..0000000000000 --- a/tests/worker/test_swap.py +++ /dev/null @@ -1,87 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -from vllm.engine.arg_utils import EngineArgs -from vllm.sequence import ExecuteModelRequest -from vllm.utils import get_distributed_init_method, get_ip, get_open_port -from vllm.worker.worker import Worker - - -def test_swap() -> None: - # Configure the engine. 
- engine_args = EngineArgs(model="distilbert/distilgpt2", - dtype="half", - load_format="dummy") - engine_config = engine_args.create_engine_config() - engine_config.cache_config.num_gpu_blocks = 1000 - engine_config.cache_config.num_cpu_blocks = 1000 - - # Create the worker. - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - worker = Worker( - vllm_config=engine_config, - local_rank=0, - rank=0, - distributed_init_method=distributed_init_method, - is_driver_worker=True, - ) - - # Initialize the worker. - worker.init_device() - worker.load_model() - worker.initialize_cache( - num_gpu_blocks=engine_config.cache_config.num_gpu_blocks, - num_cpu_blocks=engine_config.cache_config.num_cpu_blocks) - - # Randomly initialize the cache. - gpu_cache = worker.cache_engine[0].gpu_cache - cpu_cache = worker.cache_engine[0].cpu_cache - num_layers = len(gpu_cache) - for i in range(num_layers): - gpu_key_cache, gpu_value_cache = gpu_cache[i] - gpu_key_cache.random_() - gpu_value_cache.random_() - cpu_key_cache, cpu_value_cache = cpu_cache[i] - cpu_key_cache.random_() - cpu_value_cache.random_() - - allclose = lambda a, b: torch.allclose( - a.cuda(), b.cuda(), rtol=0.0, atol=0.0) - - # Test swap out. - blocks_to_swap_out = [(3, 72), (56, 35), (84, 34)] - execute_model_req = ExecuteModelRequest( - seq_group_metadata_list=[], - blocks_to_swap_in=[], - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=[], - ) - worker.execute_model(execute_model_req=execute_model_req) - - for i in range(num_layers): - gpu_key_cache, gpu_value_cache = gpu_cache[i] - cpu_key_cache, cpu_value_cache = cpu_cache[i] - for src, dst in blocks_to_swap_out: - assert allclose(gpu_key_cache[src], cpu_key_cache[dst]) - assert allclose(gpu_value_cache[src], cpu_value_cache[dst]) - - # Test swap in. 
- execute_model_req.blocks_to_swap_out = [] - execute_model_req.blocks_to_swap_in = [ - (19, 45), - (67, 23), - (12, 78), - (40, 99), - (1, 71), - ] - worker.execute_model(execute_model_req=execute_model_req) - - for i in range(num_layers): - gpu_key_cache, gpu_value_cache = gpu_cache[i] - cpu_key_cache, cpu_value_cache = cpu_cache[i] - for src, dst in execute_model_req.blocks_to_swap_in: - assert allclose(gpu_key_cache[dst], cpu_key_cache[src]) - assert allclose(gpu_value_cache[dst], cpu_value_cache[src]) diff --git a/tools/check_pickle_imports.py b/tools/check_pickle_imports.py index ad0ae45d1d465..fe717121db40d 100644 --- a/tools/check_pickle_imports.py +++ b/tools/check_pickle_imports.py @@ -39,6 +39,7 @@ ALLOWED_FILES = set([ 'vllm/engine/multiprocessing/client.py', 'vllm/distributed/device_communicators/all_reduce_utils.py', 'vllm/distributed/device_communicators/shm_broadcast.py', + 'vllm/distributed/device_communicators/shm_object_storage.py', 'vllm/engine/multiprocessing/engine.py', 'benchmarks/kernels/graph_machete_bench.py', 'benchmarks/kernels/benchmark_lora.py', diff --git a/tools/generate_cmake_presets.py b/tools/generate_cmake_presets.py index 5f92f2f5848fa..4869a71307e4e 100644 --- a/tools/generate_cmake_presets.py +++ b/tools/generate_cmake_presets.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse import json import multiprocessing import os @@ -26,7 +27,8 @@ def get_cpu_cores(): return multiprocessing.cpu_count() -def generate_presets(output_path="CMakeUserPresets.json"): +def generate_presets(output_path="CMakeUserPresets.json", + force_overwrite=False): """Generates the CMakeUserPresets.json file.""" print("Attempting to detect your system configuration...") @@ -143,12 +145,15 @@ def generate_presets(output_path="CMakeUserPresets.json"): output_file_path = os.path.join(project_root, output_path) if os.path.exists(output_file_path): - overwrite = input( - f"'{output_file_path}' already exists. Overwrite? (y/N): ").strip( - ).lower() - if overwrite != 'y': - print("Generation cancelled.") - return + if force_overwrite: + print(f"Overwriting existing file '{output_file_path}'") + else: + overwrite = input( + f"'{output_file_path}' already exists. Overwrite? 
(y/N): " + ).strip().lower() + if overwrite != 'y': + print("Generation cancelled.") + return try: with open(output_file_path, "w") as f: @@ -166,4 +171,12 @@ def generate_presets(output_path="CMakeUserPresets.json"): if __name__ == "__main__": - generate_presets() + parser = argparse.ArgumentParser() + parser.add_argument( + "--force-overwrite", + action="store_true", + help="Force overwrite existing CMakeUserPresets.json without prompting" + ) + + args = parser.parse_args() + generate_presets(force_overwrite=args.force_overwrite) diff --git a/tools/install_deepgemm.sh b/tools/install_deepgemm.sh index b125cda96f179..98427f1835ec2 100755 --- a/tools/install_deepgemm.sh +++ b/tools/install_deepgemm.sh @@ -105,4 +105,4 @@ fi popd -echo "✅ DeepGEMM installation completed successfully" \ No newline at end of file +echo "✅ DeepGEMM installation completed successfully" diff --git a/tools/install_gdrcopy.sh b/tools/install_gdrcopy.sh new file mode 100755 index 0000000000000..481723320c63b --- /dev/null +++ b/tools/install_gdrcopy.sh @@ -0,0 +1,57 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Usage: install_gdrcopy.sh <GDRCOPY_OS_VERSION> <GDRCOPY_CUDA_VERSION> <uuarch> +# uuarch must be "x64" or "aarch64" +# Optional: set GDRCOPY_VERSION to override the libgdrapi package version (default: 2.5.1-1) +# Requires: curl, apt-get, root privileges +if [[ $(id -u) -ne 0 ]]; then + echo "Must be run as root" >&2 + + exit 1 +fi +if [[ $# -ne 3 ]]; then + echo "Usage: $0 <GDRCOPY_OS_VERSION> <GDRCOPY_CUDA_VERSION> <uuarch(x64|aarch64)>" >&2 + exit 1 +fi + +OS_VER="$1" +CUDA_VER="$2" +UUARCH_RAW="$3" + +# Normalize/validate arch +case "${UUARCH_RAW,,}" in + aarch64|arm64) + URL_ARCH="aarch64" + DEB_ARCH="arm64" + ;; + x64|x86_64|amd64) + URL_ARCH="x64" + DEB_ARCH="amd64" + ;; + *) + echo "Unsupported uuarch: ${UUARCH_RAW}. Use 'x64' or 'aarch64'." >&2 + exit 1 + ;; +esac + +OS_VER_LOWER="$(tr '[:upper:]' '[:lower:]' <<<"$OS_VER")" +GDRCOPY_PKG_VER="${GDRCOPY_VERSION:-2.5.1-1}" + +DEB_NAME="libgdrapi_${GDRCOPY_PKG_VER}_${DEB_ARCH}.${OS_VER}.deb" +BASE_URL="https://developer.download.nvidia.com/compute/redist/gdrcopy" +URL="${BASE_URL}/CUDA%20${CUDA_VER}/${OS_VER_LOWER}/${URL_ARCH}/${DEB_NAME}" + +echo "Downloading: ${URL}" +TMPDIR="$(mktemp -d)" +trap 'rm -rf "${TMPDIR}"' EXIT + +curl -fSL "${URL}" -o "${TMPDIR}/${DEB_NAME}" + +export DEBIAN_FRONTEND=noninteractive +apt-get update +apt-get install -y "${TMPDIR}/${DEB_NAME}" +apt-get clean +rm -rf /var/lib/apt/lists/* + +echo "Installed ${DEB_NAME}" diff --git a/tools/install_nixl.sh b/tools/install_nixl.sh deleted file mode 100644 index 56717cfb77f7b..0000000000000 --- a/tools/install_nixl.sh +++ /dev/null @@ -1,109 +0,0 @@ -#!/bin/bash -# Usage: ./install_nixl.sh [--force] - -FORCE=false -if [ "$1" == "--force" ]; then - FORCE=true -fi - -SUDO=false -if command -v sudo >/dev/null 2>&1 && sudo -n true 2>/dev/null; then - SUDO=true -fi - -ARCH=$(uname -m) - -ROOT_DIR="/usr/local" -mkdir -p "$ROOT_DIR" -GDR_HOME="$ROOT_DIR/gdrcopy" -UCX_HOME="$ROOT_DIR/ucx" -NIXL_HOME="$ROOT_DIR/nixl" -CUDA_HOME=/usr/local/cuda - -export PATH="$GDR_HOME/bin:$UCX_HOME/bin:$NIXL_HOME/bin:$PATH" -export LD_LIBRARY_PATH="$GDR_HOME/lib:$UCX_HOME/lib:$NIXL_HOME/lib/$ARCH-linux-gnu:$LD_LIBRARY_PATH" - -TEMP_DIR="nixl_installer" -mkdir -p "$TEMP_DIR" -cd "$TEMP_DIR" - -pip install meson ninja pybind11 - -if [ ! 
-e "/dev/gdrdrv" ] || [ "$FORCE" = true ]; then - echo "Installing gdrcopy\n" - wget https://github.com/NVIDIA/gdrcopy/archive/refs/tags/v2.5.tar.gz - tar xzf v2.5.tar.gz; rm v2.5.tar.gz - cd gdrcopy-2.5 - make prefix=$GDR_HOME CUDA=$CUDA_HOME all install - - if $SUDO; then - echo "Running insmod.sh with sudo" - sudo ./insmod.sh - else - echo "Skipping insmod.sh - sudo not available" - echo "Please run 'sudo ./gdrcopy-2.5/insmod.sh' manually if needed" - fi - - cd .. -else - echo "Found /dev/gdrdrv. Skipping gdrcopy installation" -fi - -if ! command -v ucx_info &> /dev/null || [ "$FORCE" = true ]; then - echo "Installing UCX" - wget https://github.com/openucx/ucx/releases/download/v1.18.0/ucx-1.18.0.tar.gz - tar xzf ucx-1.18.0.tar.gz; rm ucx-1.18.0.tar.gz - cd ucx-1.18.0 - - # Checking Mellanox NICs - MLX_OPTS="" - if lspci | grep -i mellanox > /dev/null || command -v ibstat > /dev/null; then - echo "Mellanox NIC detected, adding Mellanox-specific options" - MLX_OPTS="--with-rdmacm \ - --with-mlx5-dv \ - --with-ib-hw-tm" - fi - - ./configure --prefix=$UCX_HOME \ - --enable-shared \ - --disable-static \ - --disable-doxygen-doc \ - --enable-optimizations \ - --enable-cma \ - --enable-devel-headers \ - --with-cuda=$CUDA_HOME \ - --with-dm \ - --with-gdrcopy=$GDR_HOME \ - --with-verbs \ - --enable-mt \ - $MLX_OPTS - make -j - make -j install-strip - - if $SUDO; then - echo "Running ldconfig with sudo" - sudo ldconfig - else - echo "Skipping ldconfig - sudo not available" - echo "Please run 'sudo ldconfig' manually if needed" - fi - - cd .. -else - echo "Found existing UCX. Skipping UCX installation" -fi - -if ! command -v nixl_test &> /dev/null || [ "$FORCE" = true ]; then - echo "Installing NIXL" - wget https://github.com/ai-dynamo/nixl/archive/refs/tags/0.2.0.tar.gz - tar xzf 0.2.0.tar.gz; rm 0.2.0.tar.gz - cd nixl-0.2.0 - meson setup build --prefix=$NIXL_HOME -Ducx_path=$UCX_HOME - cd build - ninja - ninja install - - cd ../.. -else - echo "Found existing NIXL. Skipping NIXL installation" -fi diff --git a/tools/mypy.sh b/tools/mypy.sh index 781d8fc02884b..63e3b9a916634 100755 --- a/tools/mypy.sh +++ b/tools/mypy.sh @@ -29,7 +29,7 @@ run_mypy vllm/engine run_mypy vllm/executor run_mypy vllm/inputs run_mypy vllm/lora -run_mypy vllm/model_executor +run_mypy --exclude 'vllm/model_executor/layers/fla/ops' vllm/model_executor run_mypy vllm/plugins run_mypy vllm/worker run_mypy vllm/v1 diff --git a/tools/profiler/visualize_layerwise_profile.py b/tools/profiler/visualize_layerwise_profile.py index 038d3c44f043a..30d6547073d38 100644 --- a/tools/profiler/visualize_layerwise_profile.py +++ b/tools/profiler/visualize_layerwise_profile.py @@ -119,7 +119,7 @@ def attempt_to_make_names_unique(entries_and_traces): if not all_the_same(trace_eles)), None) if first_trace_difference is None: - # can't create a unique name, leave them names as the + # can't create a unique name, leave the names as they # are they will get aggregated by the pivot_table call continue diff --git a/tools/validate_config.py b/tools/validate_config.py index 8b1e955c653d7..f6439fa9ada5f 100644 --- a/tools/validate_config.py +++ b/tools/validate_config.py @@ -9,6 +9,8 @@ import ast import inspect import sys +import regex as re + def get_attr_docs(cls_node: ast.ClassDef) -> dict[str, str]: """ @@ -88,11 +90,12 @@ def validate_class(class_node: ast.ClassDef): for stmt in class_node.body: # A field is defined as a class variable that has a type annotation. 
if isinstance(stmt, ast.AnnAssign): - # Skip ClassVar + # Skip ClassVar and InitVar # see https://docs.python.org/3/library/dataclasses.html#class-variables - if isinstance(stmt.annotation, ast.Subscript) and isinstance( - stmt.annotation.value, - ast.Name) and stmt.annotation.value.id == "ClassVar": + # and https://docs.python.org/3/library/dataclasses.html#init-only-variables + if (isinstance(stmt.annotation, ast.Subscript) + and isinstance(stmt.annotation.value, ast.Name) + and stmt.annotation.value.id in {"ClassVar", "InitVar"}): continue if isinstance(stmt.target, ast.Name): @@ -132,7 +135,7 @@ def validate_ast(tree: ast.stmt): def validate_file(file_path: str): try: - print(f"validating {file_path} config dataclasses ", end="") + print(f"Validating {file_path} config dataclasses ", end="") with open(file_path, encoding="utf-8") as f: source = f.read() @@ -140,7 +143,7 @@ def validate_file(file_path: str): validate_ast(tree) except ValueError as e: print(e) - SystemExit(2) + raise SystemExit(1) from e else: print("✅") @@ -151,7 +154,13 @@ def fail(message: str, node: ast.stmt): def main(): for filename in sys.argv[1:]: - validate_file(filename) + # Only run for Python files in vllm/ or tests/ + if not re.match(r"^(vllm|tests)/.*\.py$", filename): + continue + # Only run if the file contains @config + with open(filename, encoding="utf-8") as f: + if "@config" in f.read(): + validate_file(filename) if __name__ == "__main__": diff --git a/use_existing_torch.py b/use_existing_torch.py index a9f79e16981c4..76480f3e58fee 100644 --- a/use_existing_torch.py +++ b/use_existing_torch.py @@ -18,4 +18,4 @@ for file in requires_files: else: print(line.strip()) print(f"<<< done cleaning {file}") - print() + print() \ No newline at end of file diff --git a/vllm/__init__.py b/vllm/__init__.py index 7b90fd3a241bd..3a5c1b1ce0daf 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -14,6 +14,8 @@ import typing import vllm.env_override # noqa: F401 MODULE_ATTRS = { + "bc_linter_skip": "._bc_linter:bc_linter_skip", + "bc_linter_include": "._bc_linter:bc_linter_include", "AsyncEngineArgs": ".engine.arg_utils:AsyncEngineArgs", "EngineArgs": ".engine.arg_utils:EngineArgs", "AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine", @@ -54,6 +56,8 @@ if typing.TYPE_CHECKING: ScoringRequestOutput) from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams + + from ._bc_linter import bc_linter_include, bc_linter_skip else: def __getattr__(name: str) -> typing.Any: @@ -70,6 +74,8 @@ else: __all__ = [ "__version__", + "bc_linter_skip", + "bc_linter_include", "__version_tuple__", "LLM", "ModelRegistry", diff --git a/vllm/_bc_linter.py b/vllm/_bc_linter.py new file mode 100644 index 0000000000000..52a95dbee1866 --- /dev/null +++ b/vllm/_bc_linter.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# vllm/_bc_linter.py +from __future__ import annotations + +from typing import Any, Callable, TypeVar, overload + +T = TypeVar("T") + + +@overload +def bc_linter_skip(obj: T) -> T: + ... + + +@overload +def bc_linter_skip(*, reason: str | None = ...) -> Callable[[T], T]: + ... + + +def bc_linter_skip(obj: Any = None, *, reason: str | None = None): + """ + No-op decorator to mark symbols/files for BC-linter suppression. + + Usage: + @bc_linter_skip + def legacy_api(...): ... 
+ """ + + def _wrap(x: T) -> T: + return x + + return _wrap if obj is None else obj + + +@overload +def bc_linter_include(obj: T) -> T: + ... + + +@overload +def bc_linter_include(*, reason: str | None = ...) -> Callable[[T], T]: + ... + + +def bc_linter_include(obj: Any = None, *, reason: str | None = None): + """ + Usage: + @bc_linter_include + def public_api(...): ... + """ + + def _wrap(x: T) -> T: + return x + + return _wrap if obj is None else obj + + +__all__ = ["bc_linter_skip", "bc_linter_include"] diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 340d6e1164e4f..712295aa92886 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -117,13 +117,14 @@ def paged_attention_rocm( k_scale: torch.Tensor, v_scale: torch.Tensor, fp8_out_scale: Optional[torch.Tensor] = None, + mfma_type: str = "fp8" if envs.VLLM_ROCM_FP8_MFMA_PAGE_ATTN else "f16", ) -> None: torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, query_start_loc, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, - v_scale, fp8_out_scale) + v_scale, fp8_out_scale, mfma_type) def mla_decode_kvcache_cpu( @@ -257,16 +258,6 @@ def rotary_embedding( cos_sin_cache, is_neox) -def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, - key: Optional[torch.Tensor], head_size: int, - cos_sin_cache: torch.Tensor, is_neox: bool, - rot_dim: int, - cos_sin_cache_offsets: torch.Tensor) -> None: - torch.ops._C.batched_rotary_embedding(positions, query, key, head_size, - cos_sin_cache, is_neox, rot_dim, - cos_sin_cache_offsets) - - # layer norm ops def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float) -> None: @@ -280,6 +271,13 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) +def poly_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, + bias: torch.Tensor, epsilon: float) -> None: + # TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input + input_contiguous = input.contiguous() + torch.ops._C.poly_norm(out, input_contiguous, weight, bias, epsilon) + + def apply_repetition_penalties_torch( logits: torch.Tensor, prompt_mask: torch.Tensor, output_mask: torch.Tensor, repetition_penalties: torch.Tensor) -> None: @@ -710,6 +708,7 @@ def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool: def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool: return torch.ops._C.cutlass_group_gemm_supported(cuda_device_capability) + def cutlass_sparse_compress(a: torch.Tensor) \ -> tuple[torch.Tensor, torch.Tensor]: """ @@ -1625,20 +1624,6 @@ def concat_and_cache_mla( scale) -def cp_fused_concat_and_cache_mla( - kv_c: torch.Tensor, - k_pe: torch.Tensor, - cp_local_token_select_indices: torch.Tensor, - kv_cache: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache_dtype: str, - scale: torch.Tensor, -) -> None: - torch.ops._C_cache_ops.cp_fused_concat_and_cache_mla( - kv_c, k_pe, cp_local_token_select_indices, kv_cache, slot_mapping, - kv_cache_dtype, scale) - - def copy_blocks(key_caches: list[torch.Tensor], value_caches: list[torch.Tensor], block_mapping: torch.Tensor) -> None: @@ -1838,22 +1823,13 @@ def flash_mla_with_kvcache( return out, softmax_lse -def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor, - q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, - seq_lens: torch.Tensor, page_table: 
torch.Tensor, - scale: float) -> torch.Tensor: - torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache, - seq_lens, page_table, scale) - return out - - -def sm100_cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor, - q_pe: torch.Tensor, +def sm100_cutlass_mla_decode(out: torch.Tensor, lse: torch.Tensor, + q_nope: torch.Tensor, q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, seq_lens: torch.Tensor, page_table: torch.Tensor, workspace: torch.Tensor, scale: float, num_kv_splits: int) -> torch.Tensor: - torch.ops._C.sm100_cutlass_mla_decode(out, q_nope, q_pe, + torch.ops._C.sm100_cutlass_mla_decode(out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, scale, num_kv_splits) @@ -1928,6 +1904,35 @@ class CPUDNNLGEMMHandler: torch.ops._C.release_dnnl_matmul_handler(self.handler) +if hasattr(torch.ops._C, "create_onednn_mm_handler"): + _supports_onednn = True +else: + _supports_onednn = False + + +def create_onednn_mm( + weight: torch.Tensor, # [K, N] + primitive_cache_size: int = 128, +) -> CPUDNNLGEMMHandler: + handler = CPUDNNLGEMMHandler() + handler.k, handler.n = weight.size() + handler.handler = torch.ops._C.create_onednn_mm_handler( + weight, primitive_cache_size) + return handler + + +def onednn_mm( + dnnl_handler: CPUDNNLGEMMHandler, + x: torch.Tensor, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + output = torch.empty((*x.shape[0:-1], dnnl_handler.n), dtype=x.dtype) + torch.ops._C.onednn_mm(output, x.reshape(-1, dnnl_handler.k), bias, + dnnl_handler.handler) + + return output + + def create_onednn_scaled_mm( weight: torch.Tensor, # [K, N] weight_scales: torch.Tensor, @@ -1997,3 +2002,27 @@ def onednn_scaled_mm( input_zp_adj, bias, dnnl_handler.handler) return output + + +def hadacore_transform(x: torch.Tensor, inplace: bool = True) -> torch.Tensor: + """ + Perform Hadamard transforms using [Hadacore](https://arxiv.org/abs/2412.08832) + kernels. Note that these kernels exploit the recursive properties of + Sylvester Hadamards, and therefore do not require transform weight data + + Note that sylvester hadamard transforms are also symmetric, which means that + this function is also applies the (transpose <=> inverse) transform. 
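    As a rough illustration of that recursion, the Sylvester construction
    builds each Hadamard matrix from the previous one,
        H_2  = [[1,  1],
                [1, -1]]
        H_2n = [[H_n,  H_n],
                [H_n, -H_n]]
    (up to normalization), so the transform is fully determined by the input
    size and needs no weight tensor.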
+ + :param x: value to be transformed inplace + :param inplace: modify value in place + :return: value after transformation + """ + return torch.ops._C.hadacore_transform(x, inplace) + + +if hasattr(torch.ops._C, "hadacore_transform"): + + @register_fake("_C::hadacore_transform") + def _hadacore_transform_fake(x: torch.Tensor, + inplace: bool) -> torch.Tensor: + return torch.empty_like(x) if not inplace else x diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py index 19f6c4e3060ce..9d2eda482fcf8 100644 --- a/vllm/_ipex_ops.py +++ b/vllm/_ipex_ops.py @@ -13,7 +13,7 @@ logger = init_logger(__name__) try: import intel_extension_for_pytorch as ipex except ImportError as e: - logger.warning("Import error msg: %s", e.msg) + logger.debug("Import error msg: %s", e.msg) class ipex_ops: @@ -148,17 +148,6 @@ class ipex_ops: head_size, cos_sin_cache, is_neox, rot_dim) - @staticmethod - def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, - key: torch.Tensor, head_size: int, - cos_sin_cache: torch.Tensor, is_neox: bool, - rot_dim: int, - cos_sin_cache_offsets: torch.Tensor) -> None: - ipex.llm.functional.rotary_embedding_batched(positions, query, key, - head_size, cos_sin_cache, - is_neox, rot_dim, - cos_sin_cache_offsets) - @staticmethod def rms_norm(input: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor: @@ -242,10 +231,9 @@ class ipex_ops: k_scale_float: float = 1.0, v_scale_float: float = 1.0, ) -> None: - assert kv_cache_dtype == "auto" - # TODO: support FP8 kv cache. ipex.llm.modules.PagedAttention.reshape_and_cache_flash( - key, value, key_cache, value_cache, slot_mapping) + key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, + k_scale_float, v_scale_float) @staticmethod def flash_attn_varlen_func( diff --git a/vllm/adapter_commons/__init__.py b/vllm/adapter_commons/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/vllm/adapter_commons/layers.py b/vllm/adapter_commons/layers.py deleted file mode 100644 index 9753a08806565..0000000000000 --- a/vllm/adapter_commons/layers.py +++ /dev/null @@ -1,16 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from dataclasses import dataclass - - -@dataclass -class AdapterMapping: - # Per every token in input_ids: - index_mapping: tuple[int, ...] - # Per sampled token: - prompt_mapping: tuple[int, ...] 
- - def __post_init__(self): - self.index_mapping = tuple(self.index_mapping) - self.prompt_mapping = tuple(self.prompt_mapping) \ No newline at end of file diff --git a/vllm/adapter_commons/models.py b/vllm/adapter_commons/models.py deleted file mode 100644 index 7b685880a9e6c..0000000000000 --- a/vllm/adapter_commons/models.py +++ /dev/null @@ -1,106 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from abc import ABC, abstractmethod -from typing import Any, Callable, Optional, TypeVar - -from torch import nn - -from vllm.logger import init_logger -from vllm.utils import LRUCache - -logger = init_logger(__name__) - - -class AdapterModel(ABC): - - def __init__(self, model_id=None): - self.id = model_id - - @abstractmethod - def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs): - # Common initialization code - # Load weights or embeddings from local checkpoint - raise NotImplementedError("Subclasses must implement this method.") - - -T = TypeVar('T') - - -class AdapterLRUCache(LRUCache[int, T]): - - def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]): - super().__init__(capacity) - self.deactivate_fn = deactivate_fn - - def _on_remove(self, key: int, value: Optional[T]): - logger.debug("Removing adapter int id: %d", key) - self.deactivate_fn(key) - return super()._on_remove(key, value) - - -class AdapterModelManager(ABC): - - def __init__( - self, - model: nn.Module, - ): - """Create a AdapterModelManager and adapter for a given model. - Args: - model: the model to be adapted. - """ - self.model: nn.Module = model - self._registered_adapters: dict[int, Any] = {} - # Dict instead of a Set for compatibility with LRUCache. - self._active_adapters: dict[int, None] = {} - self.adapter_type = 'Adapter' - self._last_mapping = None - - def __len__(self) -> int: - return len(self._registered_adapters) - - @property - @abstractmethod - def adapter_slots(self) -> int: - raise NotImplementedError - - @property - @abstractmethod - def capacity(self) -> int: - raise NotImplementedError - - @abstractmethod - def activate_adapter(self, adapter_id: int) -> bool: - raise NotImplementedError - - @abstractmethod - def deactivate_adapter(self, adapter_id: int) -> bool: - raise NotImplementedError - - @abstractmethod - def add_adapter(self, adapter: Any) -> bool: - raise NotImplementedError - - @abstractmethod - def set_adapter_mapping(self, mapping: Any) -> None: - raise NotImplementedError - - @abstractmethod - def remove_adapter(self, adapter_id: int) -> bool: - raise NotImplementedError - - @abstractmethod - def remove_all_adapters(self) -> None: - raise NotImplementedError - - @abstractmethod - def get_adapter(self, adapter_id: int) -> Optional[Any]: - raise NotImplementedError - - @abstractmethod - def list_adapters(self) -> dict[int, Any]: - raise NotImplementedError - - @abstractmethod - def pin_adapter(self, adapter_id: int) -> bool: - raise NotImplementedError diff --git a/vllm/adapter_commons/request.py b/vllm/adapter_commons/request.py deleted file mode 100644 index 8135b54ba19f6..0000000000000 --- a/vllm/adapter_commons/request.py +++ /dev/null @@ -1,26 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from abc import ABC, abstractmethod - - -class AdapterRequest(ABC): - """ - Base class for adapter requests. 
- """ - - @property - @abstractmethod - def adapter_id(self) -> int: - raise NotImplementedError - - def __post_init__(self) -> None: - if self.adapter_id < 1: - raise ValueError(f"id must be > 0, got {self.adapter_id}") - - def __eq__(self, value: object) -> bool: - return isinstance( - value, self.__class__) and self.adapter_id == value.adapter_id - - def __hash__(self) -> int: - return hash(self.adapter_id) diff --git a/vllm/adapter_commons/utils.py b/vllm/adapter_commons/utils.py deleted file mode 100644 index a1a56b6bbd4ba..0000000000000 --- a/vllm/adapter_commons/utils.py +++ /dev/null @@ -1,93 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Any, Callable, Optional - - -## model functions -def deactivate_adapter(adapter_id: int, active_adapters: dict[int, None], - deactivate_func: Callable) -> bool: - if adapter_id in active_adapters: - deactivate_func(adapter_id) - active_adapters.pop(adapter_id) - return True - return False - - -def add_adapter(adapter: Any, registered_adapters: dict[int, Any], - capacity: int, add_func: Callable) -> bool: - if adapter.id not in registered_adapters: - if len(registered_adapters) >= capacity: - raise RuntimeError('No free adapter slots.') - add_func(adapter) - registered_adapters[adapter.id] = adapter - return True - return False - - -def set_adapter_mapping(mapping: Any, last_mapping: Any, - set_mapping_func: Callable) -> Any: - if last_mapping != mapping: - set_mapping_func(mapping) - return mapping - return last_mapping - - -def remove_adapter(adapter_id: int, registered_adapters: dict[int, Any], - deactivate_func: Callable) -> bool: - deactivate_func(adapter_id) - return bool(registered_adapters.pop(adapter_id, None)) - - -def list_adapters(registered_adapters: dict[int, Any]) -> dict[int, Any]: - return dict(registered_adapters) - - -def get_adapter(adapter_id: int, - registered_adapters: dict[int, Any]) -> Optional[Any]: - return registered_adapters.get(adapter_id) - - -## worker functions -def set_active_adapters_worker(requests: set[Any], mapping: Optional[Any], - apply_adapters_func, - set_adapter_mapping_func) -> None: - apply_adapters_func(requests) - set_adapter_mapping_func(mapping) - - -def add_adapter_worker(adapter_request: Any, list_adapters_func, - load_adapter_func, add_adapter_func, - activate_adapter_func) -> bool: - if adapter_request.adapter_id in list_adapters_func(): - return False - loaded_adapter = load_adapter_func(adapter_request) - loaded = add_adapter_func(loaded_adapter) - activate_adapter_func(loaded_adapter.id) - return loaded - - -def apply_adapters_worker(adapter_requests: set[Any], list_adapters_func, - adapter_slots: int, remove_adapter_func, - add_adapter_func) -> None: - models_that_exist = list_adapters_func() - models_map = { - adapter_request.adapter_id: adapter_request - for adapter_request in adapter_requests if adapter_request - } - if len(models_map) > adapter_slots: - raise RuntimeError( - f"Number of requested models ({len(models_map)}) is greater " - f"than the number of GPU model slots " - f"({adapter_slots}).") - new_models = set(models_map) - models_to_add = new_models - models_that_exist - models_to_remove = models_that_exist - new_models - for adapter_id in models_to_remove: - remove_adapter_func(adapter_id) - for adapter_id in models_to_add: - add_adapter_func(models_map[adapter_id]) - - -def list_adapters_worker(adapter_manager_list_adapters_func) -> set[int]: - return 
set(adapter_manager_list_adapters_func()) diff --git a/vllm/adapter_commons/worker_manager.py b/vllm/adapter_commons/worker_manager.py deleted file mode 100644 index 07e85d138ac50..0000000000000 --- a/vllm/adapter_commons/worker_manager.py +++ /dev/null @@ -1,39 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from abc import ABC, abstractmethod -from typing import Any, Optional - -import torch - - -class AbstractWorkerManager(ABC): - - def __init__(self, device: torch.device): - self.device = device - - @property - @abstractmethod - def is_enabled(self) -> bool: - raise NotImplementedError - - @abstractmethod - def set_active_adapters(self, requests: set[Any], - mapping: Optional[Any]) -> None: - raise NotImplementedError - - @abstractmethod - def add_adapter(self, adapter_request: Any) -> bool: - raise NotImplementedError - - @abstractmethod - def remove_adapter(self, adapter_id: int) -> bool: - raise NotImplementedError - - @abstractmethod - def remove_all_adapters(self) -> None: - raise NotImplementedError - - @abstractmethod - def list_adapters(self) -> set[int]: - raise NotImplementedError diff --git a/vllm/assets/image.py b/vllm/assets/image.py index c8f8d43a98355..4639a11187d03 100644 --- a/vllm/assets/image.py +++ b/vllm/assets/image.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass +from pathlib import Path from typing import Literal import torch @@ -11,17 +12,29 @@ from .base import get_vllm_public_assets VLM_IMAGES_DIR = "vision_model_images" -ImageAssetName = Literal["stop_sign", "cherry_blossom", "hato"] +ImageAssetName = Literal["stop_sign", "cherry_blossom", "hato", + "2560px-Gfp-wisconsin-madison-the-nature-boardwalk", + "Grayscale_8bits_palette_sample_image", + "1280px-Venn_diagram_rgb", "RGBA_comp", "237-400x300", + "231-200x300", "27-500x500", "17-150x600", + "handelsblatt-preview", "paper-11"] @dataclass(frozen=True) class ImageAsset: name: ImageAssetName + def get_path(self, ext: str) -> Path: + """ + Return s3 path for given image. + """ + return get_vllm_public_assets(filename=f"{self.name}.{ext}", + s3_prefix=VLM_IMAGES_DIR) + @property - def pil_image(self) -> Image.Image: - image_path = get_vllm_public_assets(filename=f"{self.name}.jpg", - s3_prefix=VLM_IMAGES_DIR) + def pil_image(self, ext="jpg") -> Image.Image: + + image_path = self.get_path(ext) return Image.open(image_path) @property @@ -29,6 +42,9 @@ class ImageAsset: """ Image embeddings, only used for testing purposes with llava 1.5. 
""" - image_path = get_vllm_public_assets(filename=f"{self.name}.pt", - s3_prefix=VLM_IMAGES_DIR) + image_path = self.get_path('pt') return torch.load(image_path, map_location="cpu", weights_only=True) + + def read_bytes(self, ext: str) -> bytes: + p = Path(self.get_path(ext)) + return p.read_bytes() diff --git a/vllm/assets/video.py b/vllm/assets/video.py index 8ab0e9760be87..5c9e403c4b91f 100644 --- a/vllm/assets/video.py +++ b/vllm/assets/video.py @@ -76,7 +76,7 @@ def video_to_pil_images_list(path: str, return [Image.fromarray(frame) for frame in frames] -def video_get_metadata(path: str) -> dict[str, Any]: +def video_get_metadata(path: str, num_frames: int = -1) -> dict[str, Any]: cap = cv2.VideoCapture(path) if not cap.isOpened(): raise ValueError(f"Could not open video file {path}") @@ -85,11 +85,18 @@ def video_get_metadata(path: str) -> dict[str, Any]: fps = cap.get(cv2.CAP_PROP_FPS) duration = total_frames / fps if fps > 0 else 0 + if num_frames == -1 or num_frames > total_frames: + num_frames = total_frames + metadata = { - "total_num_frames": total_frames, + "total_num_frames": num_frames, "fps": fps, "duration": duration, - "video_backend": "opencv" + "video_backend": "opencv", + "frames_indices": list(range(num_frames)), + # extra field used to control hf processor's video + # sampling behavior + "do_sample_frames": num_frames == total_frames, } return metadata @@ -110,22 +117,23 @@ class VideoAsset: def filename(self) -> str: return self._NAME_TO_FILE[self.name] + @property + def video_path(self) -> str: + return download_video_asset(self.filename) + @property def pil_images(self) -> list[Image.Image]: - video_path = download_video_asset(self.filename) - ret = video_to_pil_images_list(video_path, self.num_frames) + ret = video_to_pil_images_list(self.video_path, self.num_frames) return ret @property def np_ndarrays(self) -> npt.NDArray: - video_path = download_video_asset(self.filename) - ret = video_to_ndarrays(video_path, self.num_frames) + ret = video_to_ndarrays(self.video_path, self.num_frames) return ret @property def metadata(self) -> dict[str, Any]: - video_path = download_video_asset(self.filename) - ret = video_get_metadata(video_path) + ret = video_get_metadata(self.video_path, self.num_frames) return ret def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray: @@ -134,5 +142,4 @@ class VideoAsset: See also: examples/offline_inference/qwen2_5_omni/only_thinker.py """ - video_path = download_video_asset(self.filename) - return librosa.load(video_path, sr=sampling_rate)[0] + return librosa.load(self.video_path, sr=sampling_rate)[0] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 0b9c625533cb7..75bcdc4bbcf0d 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -240,6 +240,7 @@ class AttentionLayer(Protocol): _q_scale: torch.Tensor _k_scale: torch.Tensor _v_scale: torch.Tensor + _q_scale_float: float _k_scale_float: float _v_scale_float: float _prob_scale: torch.Tensor @@ -257,6 +258,32 @@ class AttentionLayer(Protocol): class AttentionImpl(ABC, Generic[T]): + # Whether the attention impl can return the softmax lse for decode. + # Some features like decode context parallelism require the softmax lse. 
+ can_return_lse_for_decode: bool = False + + # some attention backends might not always want to return lse + # even if they can return lse (for efficiency reasons) + need_to_return_lse_for_decode: bool = False + + dcp_world_size: int + dcp_rank: int + + def __new__(cls, *args, **kwargs): + # use __new__ so that all subclasses will call this + self = super().__new__(cls) + try: + from vllm.distributed.parallel_state import get_dcp_group + self.dcp_world_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + except AssertionError: + # DCP might not be initialized in testing + self.dcp_world_size = 1 + self.dcp_rank = 0 + self.need_to_return_lse_for_decode = self.dcp_world_size > 1 \ + and self.can_return_lse_for_decode + return self + @abstractmethod def __init__( self, diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py index caa02530d2fd6..a7d0e3afb517f 100644 --- a/vllm/attention/backends/differential_flash_attn.py +++ b/vllm/attention/backends/differential_flash_attn.py @@ -734,6 +734,7 @@ class DifferentialFlashAttentionImpl(AttentionImpl): window_size=self.sliding_window, alibi_slopes=self.alibi_slopes, softcap=self.logits_soft_cap, + fa_version=self.vllm_flash_attn_version, ) assert prefill_output.shape == output[: num_prefill_tokens].shape @@ -755,6 +756,7 @@ class DifferentialFlashAttentionImpl(AttentionImpl): window_size=self.sliding_window, alibi_slopes=self.alibi_slopes, softcap=self.logits_soft_cap, + fa_version=self.vllm_flash_attn_version, ).squeeze(1) except Exception as e: logger.error("Error in PagedAttention.forward_decode: %s", @@ -787,6 +789,7 @@ class DifferentialFlashAttentionImpl(AttentionImpl): window_size=self.sliding_window, alibi_slopes=self.alibi_slopes, softcap=self.logits_soft_cap, + fa_version=self.vllm_flash_attn_version, ).squeeze(1) return output diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index d8cb208c4f2ea..78c768f92d3c2 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -901,7 +901,7 @@ def _get_query_key_seq_metadata( attn_metadata.encoder_seq_start_loc, attn_metadata.max_encoder_seq_len) elif attn_type == AttentionType.ENCODER: - # For encoder attention both the query and the key are same i.e the + # For encoder attention both the query and the key are same i.e. the # encoder sequence. 
return (attn_metadata.encoder_seq_start_loc, attn_metadata.max_encoder_seq_len, diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py index f23c096952ce0..aeaa0ab631cfb 100644 --- a/vllm/attention/backends/flashmla.py +++ b/vllm/attention/backends/flashmla.py @@ -178,8 +178,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): logits_soft_cap, attn_type, kv_sharing_target_layer_name, **mla_args) - assert is_flashmla_supported(), \ - "FlashMLA is not supported on this device" + is_supported, reason = is_flashmla_supported() + assert is_supported, reason unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index c5ed4c6e40326..789393eb39a73 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -822,7 +822,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): and context_lens_tensor is not None \ and context_lens_tensor[:self.num_prefills].max() > 0: - # NOTE: it is recommend you read the `Chunked Prefill` section in + # NOTE: it is recommended you read the `Chunked Prefill` section in # the comment at the top of the file before trying to understand # the following code @@ -1052,7 +1052,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): return layer.weight # we currently do not have quantized bmm's which are needed for - # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform + # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform # the bmm's in 16-bit, the extra memory overhead of this is fairly low kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T assert kv_b_proj_weight.shape == ( diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 237802afccde9..15c0ce33e9659 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -23,8 +23,9 @@ from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.platforms import _Backend, current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils import GiB_bytes, direct_register_custom_op logger = init_logger(__name__) USE_XFORMERS_OPS = None @@ -55,6 +56,14 @@ def check_xformers_availability(): return USE_XFORMERS_OPS +def check_upstream_fa_availability(dtype: torch.dtype): + if dtype in (torch.float16, torch.bfloat16) and current_platform.is_cuda( + ) and current_platform.has_device_capability(80): + from transformers.utils import is_flash_attn_2_available + return is_flash_attn_2_available() + return False + + class Attention(nn.Module, AttentionLayerBase): """Attention layer. 
@@ -216,9 +225,26 @@ class Attention(nn.Module, AttentionLayerBase): ).parallel_config.pipeline_parallel_size) ] - self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) - self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) - self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) + try: + self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, + dtype=torch.float32) + self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, + dtype=torch.float32) + self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, + dtype=torch.float32) + except torch.cuda.OutOfMemoryError as e: + logger.error( + "Failed to initialize attention q/k/v range constants: %s", e) + if torch.cuda.is_available(): + logger.debug("CUDA device: %s", torch.cuda.current_device()) + logger.debug("Allocated: %.2f GiB", + torch.cuda.memory_allocated() / GiB_bytes) + logger.debug("Reserved: %.2f GiB", + torch.cuda.memory_reserved() / GiB_bytes) + raise RuntimeError( + "Failed to initialize q/k/v range constants. " + "This may be caused by insufficient memory to allocate " + "kv cache.") from e def forward( self, @@ -349,29 +375,55 @@ class MultiHeadAttention(nn.Module): f"divisible by num_kv_heads ({self.num_kv_heads})" self.num_queries_per_kv = self.num_heads // self.num_kv_heads + # During model initialization, the default dtype is set as the model + # weight and activation dtype. dtype = torch.get_default_dtype() - attn_backend = get_attn_backend(head_size, - dtype, - kv_cache_dtype=None, - block_size=16, - is_attention_free=False) - backend = backend_name_to_enum(attn_backend.get_name()) - if current_platform.is_rocm(): - # currently, only torch_sdpa is supported on rocm + + # Determine the attention backend + backend = get_vit_attn_backend(head_size=head_size, dtype=dtype) + + # Some auto-selected backends can be upgraded + # to upstream flash attention if available. + # If vllm native fa is selected, we use it directly. 
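        # For example (illustrative): if the auto-selected ViT backend is
        # TORCH_SDPA but the GPU is sm80+, the dtype is fp16/bf16, and the
        # upstream flash_attn package is importable, the backend is upgraded
        # to FLASH_ATTN and attention is routed through the upstream
        # flash_attn_varlen_func rather than vLLM's bundled kernel.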
+ use_upstream_fa = False + if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( + dtype): + backend = _Backend.FLASH_ATTN + use_upstream_fa = True + + if current_platform.is_rocm() or current_platform.is_xpu(): + # currently, only torch_sdpa is supported on rocm/xpu self.attn_backend = _Backend.TORCH_SDPA else: - if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1, - _Backend.FLEX_ATTENTION): - backend = _Backend.XFORMERS self.attn_backend = backend if backend in { - _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1 + _Backend.TORCH_SDPA, + _Backend.TORCH_SDPA_VLLM_V1, + _Backend.XFORMERS, + _Backend.PALLAS_VLLM_V1, + _Backend.ROCM_AITER_FA, + _Backend.FLASH_ATTN, + _Backend.FLASH_ATTN_VLLM_V1, } else _Backend.TORCH_SDPA if (self.attn_backend == _Backend.XFORMERS and not check_xformers_availability()): self.attn_backend = _Backend.TORCH_SDPA + if self.attn_backend in { + _Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1 + }: + if use_upstream_fa: + from flash_attn import flash_attn_varlen_func + self._flash_attn_varlen_func = flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func + self._flash_attn_varlen_func = flash_attn_varlen_func + + logger.info_once( + f"MultiHeadAttention attn_backend: {self.attn_backend}, " + f"use_upstream_fa: {use_upstream_fa}") + def forward( self, query: torch.Tensor, @@ -392,14 +444,39 @@ class MultiHeadAttention(nn.Module): key = torch.repeat_interleave(key, num_repeat, dim=2) value = torch.repeat_interleave(value, num_repeat, dim=2) - if self.attn_backend == _Backend.XFORMERS: + if self.attn_backend in { + _Backend.FLASH_ATTN, + _Backend.FLASH_ATTN_VLLM_V1, + }: + + cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, + step=q_len, + dtype=torch.int32, + device=query.device) + cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len, + step=kv_len, + dtype=torch.int32, + device=key.device) + + out = self._flash_attn_varlen_func( + query.flatten(0, 1), + key.flatten(0, 1), + value.flatten(0, 1), + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=q_len, + max_seqlen_k=kv_len, + softmax_scale=self.scale, + ) + elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops out = xops.memory_efficient_attention_forward(query, key, value, scale=self.scale) - elif self.attn_backend == _Backend.TORCH_SDPA: + elif (self.attn_backend == _Backend.TORCH_SDPA + or self.attn_backend == _Backend.TORCH_SDPA_VLLM_V1): query, key, value = (x.transpose(1, 2) for x in (query, key, value)) out = F.scaled_dot_product_attention(query, @@ -413,6 +490,19 @@ class MultiHeadAttention(nn.Module): from torch_xla.experimental.custom_kernel import flash_attention out = flash_attention(query, key, value, sm_scale=self.scale) out = out.transpose(1, 2) + elif self.attn_backend == _Backend.ROCM_AITER_FA: + from aiter import flash_attn_varlen_func + + # ROCm Flash Attention expects (batch, seq, heads, head_dim) + out = flash_attn_varlen_func(query, + key, + value, + softmax_scale=self.scale) + else: + # ViT attention hasn't supported this backend yet + raise NotImplementedError( + f"ViT attention hasn't supported {self.attn_backend} " + f"backend yet.") return out.reshape(bsz, q_len, -1) diff --git a/vllm/attention/layers/cross_attention.py b/vllm/attention/layers/cross_attention.py new file mode 100644 index 0000000000000..9400c5bffa380 --- /dev/null +++ b/vllm/attention/layers/cross_attention.py @@ -0,0 +1,162 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright 
contributors to the vLLM project +import functools +from copy import copy +from typing import Optional + +import numpy as np +import torch + +from vllm import envs +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, AttentionType) +from vllm.attention.layer import Attention +from vllm.attention.selector import get_attn_backend +from vllm.config import CacheConfig, VllmConfig +from vllm.logger import init_logger +from vllm.utils import cdiv +from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, + subclass_attention_backend) +from vllm.v1.kv_cache_interface import CrossAttentionSpec + +logger = init_logger(__name__) + + +def _get_max_encoder_len(vllm_config: "VllmConfig") -> int: + """Gets the max number of encoder input tokens from the config. + """ + sc = vllm_config.scheduler_config + assert sc and isinstance(sc.max_num_encoder_input_tokens, int), \ + "max_num_encoder_input_tokens must be int for enc-dec models" + return sc.max_num_encoder_input_tokens + + +def _get_cross_slot_mapping(encoder_seq_lens: np.ndarray, + block_table_tensor: torch.Tensor, + kv_cache_spec: CrossAttentionSpec, + device: torch.device) -> torch.Tensor: + """Get cross-attention slot mappings.""" + + block_size = kv_cache_spec.block_size + slot_mappings = [] + + # Find indices with non-zero encoder sequence lengths + # The majority of parallel requests will be running the + # decoder, so this list should be relatively small. + active_indices = np.nonzero(encoder_seq_lens)[0] + + for req_index in active_indices: + encoder_seq_len = encoder_seq_lens[req_index].item() + + # Calculate the number of blocks needed for this request + num_blocks_needed = cdiv(encoder_seq_len, block_size) + + # Get the block IDs for this request from the tensor + req_block_ids = block_table_tensor[req_index] + + # Get only the blocks we need (first num_blocks_needed blocks) + needed_block_ids = req_block_ids[:num_blocks_needed] + + # All needed blocks are allocated + i_values = torch.arange(encoder_seq_len, + dtype=torch.int64, + device=device) + block_indices = i_values // block_size + block_offsets = i_values % block_size + block_numbers = needed_block_ids[block_indices] + slot_mapping = block_numbers * block_size + block_offsets + + slot_mappings.append(slot_mapping) + + if slot_mappings: + return torch.cat(slot_mappings) + else: + return torch.empty(0, dtype=torch.int64, device=device) + + +@functools.lru_cache +def create_cross_attention_backend( + underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]: + prefix = "CrossAttention_" + underlying_builder = underlying_attn_backend.get_builder_cls() + + class CrossAttentionBuilder(underlying_builder): # type: ignore + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> AttentionMetadata: + new_metadata = copy(common_attn_metadata) + new_metadata.causal = False + max_encoder_len = _get_max_encoder_len(self.vllm_config) + new_metadata.max_seq_len = max_encoder_len + + new_metadata.seq_lens = torch.full( + (new_metadata.num_reqs, ), + max_encoder_len, + dtype=torch.int32, + device=self.device, + ) + new_metadata.seq_lens_cpu = torch.full( + (new_metadata.num_reqs, ), + max_encoder_len, + dtype=torch.int32, + device="cpu", + ) + new_metadata.slot_mapping = _get_cross_slot_mapping( + new_metadata.encoder_seq_lens, new_metadata.block_table_tensor, + self.kv_cache_spec, self.device) + return super().build(common_prefix_len, new_metadata, fast_build) + + 
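    # Worked example (illustrative numbers) for _get_cross_slot_mapping above:
    # with block_size = 16 and a 20-token encoder sequence, cdiv(20, 16) = 2
    # blocks are taken from the request's block table, say [7, 3]; token i is
    # then written to slot block_ids[i // 16] * 16 + i % 16, i.e. tokens
    # 0..15 land in slots 112..127 and tokens 16..19 in slots 48..51.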
attn_backend = subclass_attention_backend( + name_prefix=prefix, + attention_backend_cls=underlying_attn_backend, + builder_cls=CrossAttentionBuilder) + + return attn_backend + + +class CrossAttention(Attention): + """ + Cross-attention for encoder-decoder models. + Handles attention between decoder queries and encoder keys/values. + """ + + def __init__(self, + num_heads: int, + head_size: int, + scale: float, + cache_config: Optional[CacheConfig] = None, + attn_type: Optional[str] = None, + **kwargs): + dtype = torch.get_default_dtype() + + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + else: + kv_cache_dtype = "auto" + block_size = 16 + + if envs.VLLM_USE_V1: + underlying_attn_backend = get_attn_backend(head_size, dtype, + kv_cache_dtype, + block_size) + + attn_backend = create_cross_attention_backend( + underlying_attn_backend) + else: + # in v0 cross attention is handled inside the backends + attn_backend = None + + if attn_type is not None: + assert attn_type == AttentionType.ENCODER_DECODER, ( + "CrossAttention only supports AttentionType.ENCODER_DECODER") + + super().__init__(num_heads=num_heads, + head_size=head_size, + scale=scale, + cache_config=cache_config, + attn_backend=attn_backend, + attn_type=AttentionType.ENCODER_DECODER, + **kwargs) diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py index e5b90a8b27558..bf4b06512a3c1 100644 --- a/vllm/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/attention/ops/chunked_prefill_paged_decode.py @@ -15,6 +15,8 @@ from vllm.triton_utils import tl, triton from .prefix_prefill import context_attention_fwd +float8_info = torch.finfo(current_platform.fp8_dtype()) + @triton.jit def cdiv_fn(x, y): @@ -34,6 +36,7 @@ def kernel_paged_attention_2d( scale, # float32 k_scale, # float32 v_scale, # float32 + out_scale_inv, num_query_heads: tl.constexpr, # int num_queries_per_kv: tl.constexpr, # int num_queries_per_kv_padded: tl.constexpr, # int @@ -60,7 +63,9 @@ def kernel_paged_attention_2d( filter_by_query_len: tl.constexpr, # bool query_start_len_ptr, # [num_seqs+1] USE_SINKS: tl.constexpr, # bool -): + USE_FP8: tl.constexpr, + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max): seq_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) @@ -204,6 +209,9 @@ def kernel_paged_attention_2d( # epilogue acc = acc / L[:, None] + if USE_FP8: + acc = acc * tl.load(out_scale_inv) + acc = tl.clamp(acc, FP8_MIN, FP8_MAX) output_offset = (cur_batch_in_all_start_index * output_stride_0 + query_head_idx * output_stride_1) @@ -234,6 +242,7 @@ def chunked_prefill_paged_decode( alibi_slopes=None, sliding_window=None, sm_scale=None, + output_scale=None, # Optional tensor for sinks sinks=None, ): @@ -266,6 +275,7 @@ def chunked_prefill_paged_decode( sliding_window=sliding_window, sm_scale=sm_scale, skip_decode=True, + fp8_out_scale=output_scale, sinks=sinks, ) @@ -316,7 +326,7 @@ def chunked_prefill_paged_decode( tmp_output = torch.empty( size=(total_num_seq, num_query_heads, max_num_partitions, head_size), - dtype=output.dtype, + dtype=query.dtype, device=output.device, ) exp_sums = torch.empty( @@ -345,6 +355,7 @@ def chunked_prefill_paged_decode( kv_cache_dtype=kv_cache_dtype, k_scale=k_scale, v_scale=v_scale, + fp8_out_scale=output_scale, ) else: kernel_paged_attention_2d[( @@ -362,6 +373,8 @@ def chunked_prefill_paged_decode( scale=sm_scale, k_scale=k_scale, v_scale=v_scale, + 
out_scale_inv=1.0 / + output_scale if output_scale is not None else 1.0, num_query_heads=num_query_heads, num_queries_per_kv=num_queries_per_kv, num_queries_per_kv_padded=num_queries_per_kv_padded, @@ -388,4 +401,5 @@ def chunked_prefill_paged_decode( filter_by_query_len=True, query_start_len_ptr=query_start_loc, USE_SINKS=sinks is not None, + USE_FP8=output_scale is not None, ) diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py new file mode 100644 index 0000000000000..189b57e8e8b82 --- /dev/null +++ b/vllm/attention/ops/common.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm.distributed.parallel_state import GroupCoordinator +from vllm.triton_utils import tl, triton + + +@triton.jit +def _correct_attn_cp_out_kernel(outputs_ptr, new_output_ptr, lses_ptr, + vlse_ptr, outputs_stride_B, outputs_stride_H, + outputs_stride_D, lses_stride_N, lses_stride_B, + lses_stride_H, lse_idx, HEAD_DIM: tl.constexpr, + N_ROUNDED: tl.constexpr): + """ + Apply the all-gathered lses to correct each local rank's attention + output. we still need perform a cross-rank reduction to obtain the + final attention output. + + Args: + output: [ B, H, D ] + lses : [ N, B, H ] + cp, batch, q_heads, v_head_dim + Return: + output: [ B, H, D ] + lse : [ B, H ] + """ + batch_idx = tl.program_id(axis=0).to(tl.int64) + head_idx = tl.program_id(axis=1).to(tl.int64) + d_offsets = tl.arange(0, HEAD_DIM) + num_n_offsets = tl.arange(0, N_ROUNDED) + + # shape = [N] + lse_offsets = num_n_offsets * lses_stride_N + batch_idx * \ + lses_stride_B + head_idx * lses_stride_H + + # calc final lse + lse = tl.load(lses_ptr + lse_offsets) + lse = tl.where((lse != lse) | (lse == float('inf')), -float('inf'), lse) + lse_max = tl.max(lse, axis=0) + lse -= lse_max + lse_exp = tl.exp(lse) + lse_acc = tl.sum(lse_exp, axis=0) + lse = tl.log(lse_acc) + lse += lse_max + + lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H + tl.store(vlse_ptr + lse_offsets, lse) + + # shape = [D] + output_offsets = batch_idx * outputs_stride_B + \ + head_idx * outputs_stride_H + \ + d_offsets * outputs_stride_D + + # correct output + lse_offset = lse_idx * lses_stride_N + batch_idx * \ + lses_stride_B + head_idx * lses_stride_H + lse_tmp = tl.load(lses_ptr + lse_offset) + lse_finally = lse_tmp - lse + lse_finally = tl.where( + (lse_finally != lse_finally) | (lse_finally == float('inf')), + -float('inf'), lse_finally) + factor = tl.exp(lse_finally) + output = tl.load(outputs_ptr + output_offsets) + output = output * factor + + tl.store(new_output_ptr + output_offsets, output) + + +class CPTritonContext: + """ The CPTritonContext is used to avoid recompilation of the Triton JIT. + """ + + def __init__(self): + self.inner_kernel = None + + def call_kernel(self, kernel, grid, *regular_args, **const_args): + if self.inner_kernel is None: + self.inner_kernel = kernel[grid](*regular_args, **const_args) + else: + self.inner_kernel[grid](*regular_args) + + +def correct_attn_out(out: torch.Tensor, lses: torch.Tensor, cp_rank: int, + ctx: CPTritonContext): + """ + Apply the all-gathered lses to correct each local rank's attention + output. we still need perform a cross-rank reduction to obtain the + final attention output. 
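    The rescaling is the usual log-sum-exp merge over the N ranks:

        lse   = max_i(lse_i) + log(sum_i exp(lse_i - max_i(lse_i)))
        out_r = exp(lse_r - lse) * out_r

    so summing the rescaled per-rank outputs afterwards recovers the full
    softmax(QK^T)V result.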
+ + Args: + output: [ B, H, D ] + lses : [ N, B, H ] + Return: + output: [ B, H, D ] + lse : [ B, H ] + """ + if ctx is None: + ctx = CPTritonContext() + + lse = torch.empty_like(lses[0]) + + grid = (out.shape[0], out.shape[1], 1) + regular_args = (out, out, lses, lse, *out.stride(), *lses.stride(), + cp_rank) + const_args = { + "HEAD_DIM": out.shape[-1], + "N_ROUNDED": lses.shape[0], + } + + ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, + **const_args) + return out, lse + + +def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor, + cp_attn_lse: torch.Tensor, + cp_group: GroupCoordinator, + ctx: CPTritonContext = None): + """ + cp_attn_out: [ B, H, D ] + cp_attn_lse: [ B, H ] + """ + if cp_group.world_size == 1: + return cp_attn_out + + if ctx is None: + ctx = CPTritonContext() + + lses = torch.empty((cp_group.world_size, ) + cp_attn_lse.shape, + dtype=cp_attn_lse.dtype, + device=cp_attn_lse.device) + + cp_attn_lse = cp_attn_lse.contiguous() + lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses) + out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx) + assert out.is_contiguous() + out = cp_group.reduce_scatter(out, dim=1) + return out diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index 564042cf8eb12..2c3e8c42400ce 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -105,7 +105,9 @@ def flash_mla_with_kvcache( descale_q, descale_k, ) - return out, softmax_lse + + # Note(hc): need revisit when we support DCP with decode query_len > 1. + return out.squeeze(1), softmax_lse.squeeze(-1) # diff --git a/vllm/attention/ops/nki_flash_attn.py b/vllm/attention/ops/nki_flash_attn.py deleted file mode 100644 index 29fa432017616..0000000000000 --- a/vllm/attention/ops/nki_flash_attn.py +++ /dev/null @@ -1,903 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import neuronxcc.nki.isa as nisa -import neuronxcc.nki.language as nl -import numpy as np -import torch -from neuronxcc import nki -from neuronxcc.nki.language import par_dim - -from vllm.utils import cdiv - - -def is_power_of_2(x): - return x > 0 and (x & (x - 1)) == 0 - - -@nki.jit -def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile): - """ - Load block tables from HBM into SRAM - - `block_tables_hbm` has shape `(num_tiles * num_blocks_per_tile, )`. - In case `num_tiles > B_P_SIZE`, we need further tile `num_tile` dimension. - """ - B_P_SIZE = 128 - - # reshape as `(num_tiles, num_blocks_per_tile)` - assert len(block_tables_hbm.shape) == 1 - (num_total_blocks, ) = block_tables_hbm.shape - assert num_blocks_per_tile * num_tiles == num_total_blocks - block_tables_hbm = block_tables_hbm.reshape( - (num_tiles, num_blocks_per_tile)) - - block_tables_sbuf = nl.zeros( - (cdiv(num_tiles, B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile), - dtype=nl.int32, - ) - for i in nl.affine_range(cdiv(num_tiles, B_P_SIZE)): - i_p = nl.arange(B_P_SIZE)[:, None] - i_f = nl.arange(num_blocks_per_tile)[None, :] - block_tables_sbuf[i, i_p, i_f] = nl.load( - block_tables_hbm[i_p + i * B_P_SIZE, i_f], - dtype=nl.int32, - mask=(i_p + i * B_P_SIZE < num_tiles), - ) - return block_tables_sbuf - - -@nki.jit -def transform_block_tables_for_indirect_load( - block_tables, - block_size_tiling_factor, - num_head, - head_id, -): - """ - This function does two things: - 1. 
calculate new `block_tables` for a `head_id` after flattening - `num_block`, `num_head`, and `block_size_tiling_factor` dimensions - 2. transpose the result so that `block_table` for each tile is mapped to - SBUF Partition dimension for vectorized DMA - - Tiling trick to further improve DMA performance: - Given KV cache shape `(num_block, num_head, block_size, D)`, when loading M - blocks of a given `head_id` from HBM, the load `cache[block_tables, - head_id]` has shape `(M, block_size, D)`. If M < B_P_SIZE = 128, DMA may not - fully utilize hardware parallelization. The solution is to tile `block_size` - into `(block_size_tiling_factor, tiled_block_size)` s.t. `M * - block_size_tiling_factor = B_P_SIZE`. After tiling, KV cache has shape - `(num_block, num_head, block_size_tiling_factor, tiled_block_size, D)`. - - Note: - We don't further tile D dimension as small DMA size also hurts performance. - """ - B_P_SIZE = 128 - num_partitions, num_tiles_per_partition, num_blocks_per_tile = ( - block_tables.shape) - assert num_tiles_per_partition == B_P_SIZE - assert is_power_of_2( - num_blocks_per_tile), f"{num_blocks_per_tile=} is not power of 2" - - num_loads = cdiv(num_blocks_per_tile, B_P_SIZE) - block_tables_transposed = nl.ndarray( - ( - num_loads, - par_dim(B_P_SIZE), - num_partitions * num_tiles_per_partition, - ), - dtype=nl.int32, - ) - - # prepare iota ahead of time to avoid repeatedly using Gpsimd - if num_head > 1: - head_id = nisa.iota(head_id, dtype=nl.int32).reshape((1, 1)) - head_id = nl.transpose( - head_id.broadcast_to((1, num_tiles_per_partition))) - if num_blocks_per_tile > 1: - head_id = head_id.broadcast_to( - (num_tiles_per_partition, num_blocks_per_tile)) - - if block_size_tiling_factor > 1: - broadcast_shape = ( - num_tiles_per_partition, - num_blocks_per_tile, - block_size_tiling_factor, - ) - offset = nisa.iota(nl.arange(block_size_tiling_factor)[None, None, :], - dtype=nl.int32).broadcast_to(broadcast_shape) - - for partition_id in nl.affine_range(num_partitions): - block_tables_partition = block_tables[partition_id] - if num_head > 1: - # fuse num_block and num_head dimension - block_tables_partition = block_tables_partition * num_head + head_id - - if block_size_tiling_factor > 1: - # need to apply block size tiling trick - assert num_blocks_per_tile * block_size_tiling_factor == B_P_SIZE - block_tables_partition = ((block_tables_partition * - block_size_tiling_factor).reshape( - (num_tiles_per_partition, - num_blocks_per_tile, - 1)).broadcast_to(broadcast_shape)) - new_block_tables = block_tables_partition + offset - new_block_tables = new_block_tables.reshape( - (num_tiles_per_partition, B_P_SIZE)) - else: - new_block_tables = block_tables_partition - - # transpose the block table so that it can be used by vector DGE - for i in nl.affine_range(num_loads): - i_p = nl.arange(B_P_SIZE)[:, None] - i_f = (partition_id * num_tiles_per_partition + - nl.arange(num_tiles_per_partition)[None, :]) - block_tables_transposed[i, i_p, i_f] = nl.transpose( - new_block_tables[:, nl.ds(i * B_P_SIZE, B_P_SIZE)]) - return block_tables_transposed - - -@nki.jit -def load_kv_tile_from_cache( - cur_k_tile, - cur_v_tile, - kv_cache, - block_tables, - large_k_tile_idx, - num_blocks_per_large_tile, - tiled_block_size, - B_P_SIZE, - B_D_SIZE, -): - """ - Load KV cache and transform Key and Value into layout required by Matmul - - Vectorized DMA Load layout: - Key and Value: (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE) - - Layout used by attention matmuls: - Key: (par_dim(B_D_SIZE), 
seqlen_kv) - Value: (seqlen_kv // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE) - equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE) - """ - # load key cache - num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE) - for load_idx in nl.affine_range(num_loads): - i_p = nl.arange(B_P_SIZE)[:, None] - i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :] - loaded = nl.load(kv_cache[0, block_tables[load_idx, i_p, - large_k_tile_idx], i_f]) - if cur_k_tile.dtype != loaded.dtype: - loaded = nl.copy(loaded, dtype=cur_k_tile.dtype) - # Transpose SBUF tensor using PE - for tb_i in nl.affine_range(tiled_block_size): - cur_k_tile[ - :, - nl.ds( - load_idx * B_P_SIZE * tiled_block_size + tb_i * B_P_SIZE, - B_P_SIZE, - ), - ] = nl.transpose(loaded[:, nl.ds(tb_i * B_D_SIZE, B_D_SIZE)]) - - # load value cache - for load_idx in nl.affine_range(num_loads): - loaded = nl.load(kv_cache[1, block_tables[load_idx, i_p, - large_k_tile_idx], i_f]) - if cur_v_tile.dtype != loaded.dtype: - loaded = nl.copy(loaded, dtype=cur_v_tile.dtype) - i_p = nl.arange(B_P_SIZE)[:, None] - i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :] - cur_v_tile[ - :, - nl.ds( - load_idx * tiled_block_size * B_D_SIZE, - tiled_block_size * B_D_SIZE, - ), - ] = loaded - - -@nki.jit -def transpose_p_local(p_local_transposed, - p_local, - LARGE_TILE_SZ, - B_F_SIZE=512): - for i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE): - if nisa.get_nc_version() == nisa.nc_version.gen3: - p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE), - buffer=nl.sbuf, - dtype=p_local.dtype) - else: - p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE), - buffer=nl.psum, - dtype=np.float32) - - for j in nl.affine_range(B_F_SIZE // 128): - j_128_slice = nl.ds(j * 128, 128) - i_j_128_slice = nl.ds(i * B_F_SIZE + j * 128, 128) - - if nisa.get_nc_version() == nisa.nc_version.gen3: - p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose( - p_local[:, i_j_128_slice]) - else: - p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose( - p_local[:, i_j_128_slice]) - - p_local_transposed[:, nl.ds(i * B_F_SIZE, B_F_SIZE)] = nl.copy( - p_local_t_tmp, dtype=p_local_transposed.dtype) - - -@nki.jit -def _flash_attention_core( - q_local_tile, - k, - v, - o_buffer, - l_buffer, - m_buffer, - kernel_dtype, - acc_type, - tile_mask, - use_causal_mask, - q_tile_idx=None, - initialize=False, - LARGE_TILE_SZ=2048, - B_P_SIZE=128, - B_F_SIZE=512, - B_D_SIZE=128, - qk_res_buffer=None, -): - """ - The flash attention core function to calculate self attention between a tile - of q and a block of K and V. - The q_local_tile has (B_P_SIZE, B_D_SIZE) - The K and V have shape (B_D_SIZE, LARGE_TILE_SZ), whose free dimension will - be split into size B_F_SIZE tiles - - The results are stored in the following three buffers - o_buffer: (B_P_SIZE, d) - l_buffer: (B_P_SIZE, 1) - m_buffer: (B_P_SIZE, 1) - - All IO buffers are in SBUF. 
- """ - num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE - - qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), - buffer=nl.sbuf, - dtype=acc_type) - max_local = nl.ndarray((par_dim(B_P_SIZE), num_k_tile_per_large_tile), - dtype=acc_type) - for k_i in nl.affine_range(num_k_tile_per_large_tile): - k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE) - - if use_causal_mask: - # mask are used to only apply computation to the lower half of the - # matrix, which reduce the arithmetic intensity by up to 50% - multiplication_required_selection = (q_tile_idx * B_P_SIZE - >= k_i * B_F_SIZE) - else: - multiplication_required_selection = True - - if multiplication_required_selection: - qk_psum = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE), - dtype=np.float32, - buffer=nl.psum) # (128, 512) - qk_psum[:, :] = nl.matmul(q_local_tile, - k[:, k_i_b_f_slice], - transpose_x=True) # (p(128), 512) - qk_res_buf[:, k_i_b_f_slice] = nl.where( - tile_mask[:, k_i_b_f_slice], - qk_psum[:, nl.ds(0, B_F_SIZE)], - -9984.0, - dtype=acc_type, - ) - else: - qk_res_buf[:, k_i_b_f_slice] = -9984.0 - - # Calculate max of the current tile - max_local[:, k_i] = nisa.tensor_reduce( - np.max, - qk_res_buf[:, k_i_b_f_slice], - axis=(1, ), - dtype=acc_type, - negate=False, - ) - - if qk_res_buffer is not None: - qk_res_buffer[:, :] = nl.copy(qk_res_buf[:, :]) - - max_ = nisa.tensor_reduce( - np.max, - max_local[:, :], - axis=(1, ), - dtype=acc_type, - negate=False, - ) - - o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE), - dtype=o_buffer.dtype) - - if initialize: - m_buffer[:, 0] = nl.copy(max_) - m_current = max_ - else: - m_previous = nl.copy(m_buffer[:, 0]) - m_buffer[:, 0] = nl.maximum(m_previous, max_) # (128,1) - - m_current = m_buffer[:, 0] - # Compute scaling factor - alpha = nisa.activation( - np.exp, - m_previous, - bias=-1 * m_current, - scale=1.0, - ) - o_previous_scaled[...] = nl.multiply(o_buffer[:, :], alpha) - - p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), - dtype=kernel_dtype) - REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2) - - p_partial_sum = nl.ndarray( - (par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE), - dtype=acc_type, - ) - - for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE): - k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE) - - # compute exp(qk - max) - # Compute partial row - tile sum of exp(qk - max)) - # FIXME : Use activation accumulate to accumulate over k_r_i loop ? 
- p_local[:, k_r_i_reduce_slice] = nisa.activation_reduce( - np.exp, - qk_res_buf[:, k_r_i_reduce_slice], - bias=-1 * m_current, - scale=1.0, - reduce_op=nl.add, - reduce_res=p_partial_sum[:, k_r_i], - dtype=kernel_dtype, - ) - - ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type) - - p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), - dtype=kernel_dtype) - transpose_p_local( - p_local_transposed=p_local_transposed, - p_local=p_local, - LARGE_TILE_SZ=LARGE_TILE_SZ, - B_F_SIZE=B_F_SIZE, - ) - - pv_psum = nl.zeros( - (par_dim(B_P_SIZE), B_D_SIZE), - dtype=np.float32, - buffer=nl.psum, - ) - for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE): - pv_psum[:, :] += nl.matmul( - p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)], - v[:, nl.ds(k_i * B_D_SIZE, B_D_SIZE)], - transpose_x=True, - ) # (128, 128) (p(Br), d) - - if initialize: - o_buffer[:, :] = nl.copy(pv_psum[:, :]) - l_buffer[:, 0] = nl.add(nl.log(ps), max_) - else: - o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum) - - l_prev = l_buffer[:, 0] - l_exp = nl.add( - nl.exp(nl.subtract(l_prev, m_current)), - ps, - ) - l_buffer[:, 0] = nl.add(m_current, nl.log(l_exp)) - - -@nki.jit -def load_v_tile(v_hbm_tile, cur_v_tile, large_tile_idx, v_i, LARGE_TILE_SZ): - B_P_SIZE = 128 - B_D_SIZE = v_hbm_tile.shape[-1] - loaded = nl.load(v_hbm_tile[ - nl.ds(large_tile_idx * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE), - :, - ]) - if cur_v_tile.dtype != loaded.dtype: - loaded = nl.copy(loaded, dtype=cur_v_tile.dtype) - cur_v_tile[:, nl.ds(v_i * B_D_SIZE, B_D_SIZE)] = loaded - - -@nki.jit -def flash_paged_attention( - query, - key, - value, - kv_cache, - block_tables, - mask, - softmax_scale=None, - mixed_precision=True, - LARGE_TILE_SZ=2048, - return_debug_tensors=False, -): - """ - Flash PagedAttention Forward Kernel. - - IO tensor layouts: - - query: shape (1, n_heads, d, seq_q) - - key: shape (1, n_kv_heads, d, seq_k) - - value: shape (1, n_kv_heads, seq_v, d) - - kv_cache: (2, num_blocks, n_kv_heads, block_size, d) - - block_tables: (num_active_blocks, ) - - mask: (seq_q, num_active_blocks * block_size + seq_q) - - o: shape (1, n_heads, seq_q, d) - - - This kernel requires seq_k == seq_v - - We use continuous batching by default, so the batch dimension is - always 1, and different requests are concatenated along sequence - dimension. - - We use paged cache blocks (kv_cache) to store KV cache. - - IO tensor dtypes: - - This kernel assumes all IO tensors have the same dtype except for - block_tables (int32) and mask (int32) - - If mixed_precision is True, then all Tensor Engine operation will be - performed in bfloat16 and accumulation will be performed in float32. - Otherwise the intermediates will be in the same type as the inputs. 
- - Compile-time Constants: - - softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)` - - mixed_precision: flag to set non-matmul ops in fp32 precision, default - is set to `true`, if false, we use same precision as input types - - LARGE_TILE_SZ: `default=2048`, size of the kv tile size for attention - computation reduction - - GQA support Notes: - the spmd kernel for launching kernel should be on kv_heads instead of - nheads - - Example usage: - MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d] - usage: `flash_fwd[b, h](q, k, v, ...)` - GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d] - usage: `flash_fwd[b, kv_h](q, k, v, ...)` - """ - B_F_SIZE = 512 - B_P_SIZE = 128 - b, h, d, seqlen_q = query.shape - B_D_SIZE = d - n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine - _, num_blocks, k_h, block_size, _ = kv_cache.shape - q_h_per_k_h = h // k_h - assert b == 1, f"invalid batch size {b=}" - assert d <= 128, f" we do not support head_dim > 128, got head dim {d=}" - cache_shape = (2, num_blocks, k_h, block_size, d) - assert (tuple(kv_cache.shape) == cache_shape - ), f"{kv_cache.shape=} mismatch, expect {cache_shape}" - assert key is None or tuple(key.shape) == ( - 1, - k_h, - d, - seqlen_q, - ), f"key shape {key.shape} mismatch!" - assert value is None or tuple(value.shape) == ( - 1, - k_h, - seqlen_q, - d, - ), f"value shape {value.shape} mismatch!" - - assert ( - nl.program_ndim() == 2 - ), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!" - batch_id = nl.program_id(axis=0) - head_id = nl.program_id(axis=1) - - (num_active_blocks, ) = block_tables.shape - context_kv_len = num_active_blocks * block_size - assert ( - LARGE_TILE_SZ % B_F_SIZE == 0 - ), f"Need {LARGE_TILE_SZ=} to be divisible by {B_F_SIZE=} in transpose_p" - assert (context_kv_len % LARGE_TILE_SZ == 0 - ), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}" - - num_blocks_per_large_tile = LARGE_TILE_SZ // block_size - assert is_power_of_2( - num_blocks_per_large_tile - ), f"{num_blocks_per_large_tile=} is expected of be power of 2" - if seqlen_q > B_F_SIZE: - MAX_REDUCTION_TILE = 2048 - if seqlen_q // 2 > MAX_REDUCTION_TILE: - assert ( - seqlen_q % MAX_REDUCTION_TILE == 0 - ), f"{seqlen_q=} should be divisible by {MAX_REDUCTION_TILE=}" - else: - assert (seqlen_q % B_F_SIZE == 0 - ), f"{seqlen_q=} should be divisible by {B_F_SIZE=})" - - kernel_dtype = nl.bfloat16 if mixed_precision else query.dtype - acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype - softmax_scale = softmax_scale or (1.0 / (d**0.5)) - num_large_k_tile = context_kv_len // LARGE_TILE_SZ - - o = nl.ndarray((b, h, seqlen_q, d), - dtype=query.dtype, - buffer=nl.shared_hbm) - hbm_l_buffer, hbm_m_buffer, hbm_qk_res, qk_res_buffer = ( - None, - None, - None, - None, - ) - if return_debug_tensors: - hbm_l_buffer = nl.ndarray((b, h, seqlen_q), - dtype=acc_type, - buffer=nl.shared_hbm) - hbm_m_buffer = nl.ndarray((b, h, seqlen_q), - dtype=acc_type, - buffer=nl.shared_hbm) - hbm_qk_res = nl.ndarray((b, h, B_P_SIZE, seqlen_q), - dtype=acc_type, - buffer=nl.shared_hbm) - qk_res_buffer = nl.zeros( - (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), seqlen_q), - dtype=acc_type, - buffer=nl.sbuf, - lazy_initialization=True, - ) - block_tables_sbuf = load_block_tables( - block_tables_hbm=block_tables, - num_tiles=num_large_k_tile, - num_blocks_per_tile=num_blocks_per_large_tile, - ) - - # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient - if 
num_blocks_per_large_tile < B_P_SIZE: - # we checked num_blocks_per_tile is a power of 2 - assert B_P_SIZE % num_blocks_per_large_tile == 0 - block_size_tiling_factor = B_P_SIZE // num_blocks_per_large_tile - # We assume block_size >= block_size_tiling_factor - assert block_size % block_size_tiling_factor == 0 - else: - block_size_tiling_factor = 1 - tiled_block_size = block_size // block_size_tiling_factor - - # Indirect DMA load must be placed along Partition Dimension - block_tables_sbuf = transform_block_tables_for_indirect_load( - block_tables_sbuf, - block_size_tiling_factor=block_size_tiling_factor, - num_head=k_h, - head_id=head_id, - ) - - # Flatten KV cache to be 3D for loading into SBUF - new_cache_shape = ( - 2, - num_blocks * k_h * block_size_tiling_factor, - tiled_block_size * d, - ) - kv_cache = kv_cache.reshape(new_cache_shape) - - # Global Flash Attention accumulators - o_buffer = nl.zeros( - (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), d), - dtype=acc_type, - buffer=nl.sbuf, - lazy_initialization=True, - ) - l_buffer = nl.zeros( - (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1), - dtype=acc_type, - buffer=nl.sbuf, - lazy_initialization=True, - ) - m_buffer = nl.zeros( - (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1), - dtype=acc_type, - buffer=nl.sbuf, - lazy_initialization=True, - ) - - for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile): - num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE) - cur_k_tile = nl.ndarray( - (par_dim(B_D_SIZE), LARGE_TILE_SZ), - dtype=kernel_dtype, - ) - cur_v_tile = nl.ndarray( - (par_dim(B_P_SIZE), num_loads * tiled_block_size * B_D_SIZE), - dtype=kernel_dtype, - ) - load_kv_tile_from_cache( - cur_k_tile=cur_k_tile, - cur_v_tile=cur_v_tile, - kv_cache=kv_cache, - block_tables=block_tables_sbuf, - large_k_tile_idx=large_k_tile_idx, - num_blocks_per_large_tile=num_blocks_per_large_tile, - tiled_block_size=tiled_block_size, - B_P_SIZE=B_P_SIZE, - B_D_SIZE=B_D_SIZE, - ) - - for i in nl.affine_range(n_tile_q): - cur_mask = nl.load(mask[ - nl.ds(i * B_P_SIZE, B_P_SIZE), - nl.ds(large_k_tile_idx * LARGE_TILE_SZ, LARGE_TILE_SZ), - ]) - for i_q_h in nl.affine_range(q_h_per_k_h): - q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype) - q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h] - q_sbuf_tile = nl.load(q_hbm_tile[:, - nl.ds(i * - B_P_SIZE, B_P_SIZE)]) - if q_sbuf_tile.dtype != kernel_dtype: - q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype) - q_tile[:, :] = q_sbuf_tile * softmax_scale - - _flash_attention_core( - q_local_tile=q_tile, - k=cur_k_tile, - v=cur_v_tile, - o_buffer=o_buffer[i, i_q_h], - l_buffer=l_buffer[i, i_q_h], - m_buffer=m_buffer[i, i_q_h], - kernel_dtype=kernel_dtype, - acc_type=acc_type, - tile_mask=cur_mask, - use_causal_mask=False, - q_tile_idx=i, - initialize=large_k_tile_idx == 0, - LARGE_TILE_SZ=LARGE_TILE_SZ, - B_P_SIZE=B_P_SIZE, - B_F_SIZE=B_F_SIZE, - B_D_SIZE=B_D_SIZE, - ) - - # compute attention between input query, key and value - if key is not None and value is not None: - B_F_SIZE = min(seqlen_q, B_F_SIZE) - LARGE_TILE_SZ = seqlen_q - - cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ), - dtype=kernel_dtype) - cur_v_tile = nl.ndarray( - (par_dim(B_P_SIZE), LARGE_TILE_SZ // B_P_SIZE * B_D_SIZE), - dtype=kernel_dtype, - ) - - loaded = nl.load(key[batch_id, head_id, :, :]) - if loaded.dtype != kernel_dtype: - loaded = nl.copy(loaded, dtype=kernel_dtype) - cur_k_tile[:, :] = loaded - - v_hbm_tile = value[batch_id, head_id] - for v_i in nl.affine_range(LARGE_TILE_SZ // 
B_P_SIZE): - load_v_tile( - v_hbm_tile=v_hbm_tile, - cur_v_tile=cur_v_tile, - large_tile_idx=0, - v_i=v_i, - LARGE_TILE_SZ=LARGE_TILE_SZ, - ) - - for i in nl.affine_range(n_tile_q): - cur_mask = nl.load(mask[ - nl.ds(i * B_P_SIZE, B_P_SIZE), - nl.ds(context_kv_len, LARGE_TILE_SZ), - ]) - for i_q_h in nl.affine_range(q_h_per_k_h): - - q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype) - q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h] - q_sbuf_tile = nl.load(q_hbm_tile[:, - nl.ds(i * - B_P_SIZE, B_P_SIZE)]) - if q_sbuf_tile.dtype != kernel_dtype: - q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype) - q_tile[:, :] = q_sbuf_tile * softmax_scale - _flash_attention_core( - q_local_tile=q_tile, - k=cur_k_tile, - v=cur_v_tile, - o_buffer=o_buffer[i, i_q_h], - l_buffer=l_buffer[i, i_q_h], - m_buffer=m_buffer[i, i_q_h], - kernel_dtype=kernel_dtype, - acc_type=acc_type, - tile_mask=cur_mask, - use_causal_mask=True, - q_tile_idx=i, - initialize=False, - LARGE_TILE_SZ=LARGE_TILE_SZ, - B_P_SIZE=B_P_SIZE, - B_F_SIZE=B_F_SIZE, - B_D_SIZE=B_D_SIZE, - qk_res_buffer=(qk_res_buffer[i, i_q_h] - if qk_res_buffer is not None else None), - ) - - # -- -- -- -- write output to buffer on HBM -- -- -- -- -- -- # - for i_q_h in nl.affine_range(q_h_per_k_h): - for i in nl.affine_range(n_tile_q): - out = nl.multiply( - o_buffer[i, i_q_h], - nl.exp(m_buffer[i, i_q_h] - l_buffer[i, i_q_h]), - dtype=kernel_dtype, - ) - - nl.store( - o[ - batch_id, - head_id * q_h_per_k_h + i_q_h, - nl.ds(i * B_P_SIZE, B_P_SIZE), - :, - ], - out, - ) - # maximum and summation statistics - if return_debug_tensors: - nl.store( - hbm_m_buffer[ - batch_id, - head_id * q_h_per_k_h + i_q_h, - nl.ds(i * B_P_SIZE, B_P_SIZE), - ], - m_buffer[i, i_q_h, :, :], - ) - nl.store( - hbm_l_buffer[ - batch_id, - head_id * q_h_per_k_h + i_q_h, - nl.ds(i * B_P_SIZE, B_P_SIZE), - ], - l_buffer[i, i_q_h], - ) - nl.store( - hbm_qk_res[batch_id, head_id * q_h_per_k_h + i_q_h, :, :], - qk_res_buffer[batch_id, i_q_h, :, :], - ) - - if return_debug_tensors: - return o, hbm_m_buffer, hbm_l_buffer, hbm_qk_res - return o - - -def reorder_context_mask(mask, LARGE_TILE_SZ, block_size): - """ - Reorder the mask to make it compatible with the flash attention kernel. - - We vectorize KV cache read to improve DMA utilization. However, the layout - that maximizes DMA bandwidth changes the order tokens are consumed. - - The token layout (inner 2 dimensions) after vectorized load is (B_P_SIZE, - tiled_block_size) in a tile of `B_P_SIZE * tiled_block_size` tokens. And - each step the engine consumes a column (rather than a row) of B_P_SIZE - tokens. Therefore, the tokens are visited in a strided way. - - To make sure mask matches the order tokens are consumed, we need to properly - transpose mask. 
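A toy NumPy illustration (sizes made up, far smaller than the real B_P_SIZE = 128) of the strided visit order described above, and why transposing the inner (B_P_SIZE, tiled_block_size) tile makes the mask match it:

import numpy as np

B_P_SIZE, tiled_block_size = 4, 3      # toy sizes for illustration only
tokens = np.arange(B_P_SIZE * tiled_block_size)
tile = tokens.reshape(B_P_SIZE, tiled_block_size)   # layout after the vectorized load
# one column of B_P_SIZE tokens is consumed per step, so the visit order is strided:
print(tile.T.reshape(-1))              # [0 3 6 9 1 4 7 10 2 5 8 11]
# applying the same transpose to the mask (as reorder_context_mask does below)
# keeps mask entries aligned with the order in which tokens are actually consumed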
- """ - total_query_len, total_seq_len = mask.shape - context_kv_len = total_seq_len - total_query_len - - B_P_SIZE = 128 - assert (LARGE_TILE_SZ - >= B_P_SIZE), f"{LARGE_TILE_SZ=} must be larger than {B_P_SIZE=}" - num_tiled_blocks = max(B_P_SIZE, LARGE_TILE_SZ // block_size) - tiled_block_size = LARGE_TILE_SZ // num_tiled_blocks - if tiled_block_size > 1: - # Mask reordering is needed when tiled_block_size > 1 - device = mask.device - mask = mask.cpu() - context_mask = mask[:, :context_kv_len] - context_mask = context_mask.view( - total_query_len, - context_kv_len // LARGE_TILE_SZ, - num_tiled_blocks // B_P_SIZE, - B_P_SIZE, - tiled_block_size, - ) - context_mask = context_mask.transpose(3, 4).reshape( - total_query_len, context_kv_len) - new_mask = mask[:, context_kv_len:] - return torch.concat([context_mask, new_mask], dim=1).to(device) - else: - return mask - - -def flash_attn_varlen_nkifunc( - query, - key, - value, - kv_cache, - block_table, - attn_mask, - n_kv_head=None, - head_size=None, - LARGE_TILE_SZ=2048, - mixed_precision=True, -): - """ - Compute flash paged attention for variable length sequences. - - This function is a wrapper around the flash attention NKI kernel. It takes - in the following arguments: - - query: (1, n_heads, d, seq_q) - - key: (1, n_kv_heads, d, seq_k) - - value: (1, n_kv_heads, seq_v, d) - - kv_cache: (2, n_blocks, n_kv_heads, block_size, d) - - block_tables: (n_active_blocks, ) - - attn_mask: (seq_q, n_active_blocks * block_size + seq_q) - - Notes: - - attn_mask must be reordered outside using `reorder_context_mask` - - Key/value cache layout must be (n_blocks, n_kv_heads, block_size, d) - for better DMA throughput - """ - if n_kv_head is None: - n_kv_head = kv_cache.shape[2] - assert kv_cache.shape[0] == 2 - assert kv_cache.shape[2] == n_kv_head - if head_size is None: - head_size = kv_cache.shape[-1] - - kwargs = dict( - query=query, - key=key, - value=value, - kv_cache=kv_cache, - block_tables=block_table, - mask=attn_mask, - softmax_scale=1.0 / (head_size**0.5), - mixed_precision=mixed_precision, - LARGE_TILE_SZ=LARGE_TILE_SZ, - ) - - o = flash_paged_attention[1, n_kv_head](**kwargs) - return o - - -def reshape_and_cache( - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - slot_mapping: torch.Tensor, -) -> None: - """ - Writes key-value pairs to the KV cache at specified positions. 
- - Args: - key (torch.Tensor): Key tensor with shape - (num_tokens, n_kv_head, d_head) - value (torch.Tensor): Value tensor with shape - (num_tokens, n_kv_head, d_head) - kv_cache (torch.Tensor): Key/value cache tensor with shape - (2, num_blocks, n_kv_head, block_size, d_head) - slot_mapping (torch.Tensor): Mapping tensor indicating cache positions - with shape (num_tokens) - - Returns: - None: Updates the kv_cache tensor in-place - """ - block_size = kv_cache.size(3) - n_kv_head = key.size(1) - - # Calculate indices with explicit floor division - block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor") - block_offsets = slot_mapping % block_size - - # Create the head indices tensor - head_indices = torch.arange(n_kv_head, device=key.device) - - # Update caches using index_put_ - kv_cache.index_put_( - (torch.tensor([0], device=key.device), block_indices[:, None], - head_indices[None, :], block_offsets[:, None]), key) - - kv_cache.index_put_( - (torch.tensor([1], device=key.device), block_indices[:, None], - head_indices[None, :], block_offsets[:, None]), value) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index c6d1501e27578..4d870a45e5800 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -6,9 +6,14 @@ from typing import List, Optional, Tuple import torch -from vllm import _custom_ops as ops +from vllm.platforms import current_platform from vllm.triton_utils import HAS_TRITON +if current_platform.is_cuda_alike(): + from vllm import _custom_ops as ops +elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops as ops + if HAS_TRITON: from vllm.attention.ops.prefix_prefill import context_attention_fwd diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index e1d41930f6231..7e5c2b6c62e9b 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -15,6 +15,7 @@ NUM_WARPS = 4 if current_platform.is_rocm() else 8 # To check compatibility IS_TURING = current_platform.get_device_capability() == (7, 5) +float8_info = torch.finfo(current_platform.fp8_dtype()) # Here's an example autotuner config for this kernel. 
This config does provide @@ -43,6 +44,7 @@ def _fwd_kernel(Q, sm_scale, k_scale, v_scale, + out_scale_inv, B_Start_Loc, B_Seqlen, x: tl.constexpr, @@ -82,8 +84,11 @@ def _fwd_kernel(Q, num_unroll_request: tl.constexpr, SKIP_DECODE: tl.constexpr, USE_SINKS: tl.constexpr, + USE_FP8: tl.constexpr, MAX_Q_LEN: tl.constexpr = 0, - MAX_CTX_LEN: tl.constexpr = 0): + MAX_CTX_LEN: tl.constexpr = 0, + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -146,7 +151,7 @@ def _fwd_kernel(Q, start_n = tl.multiple_of(start_n, BLOCK_SIZE) # -- compute qk ---- bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - (start_n // BLOCK_SIZE) * stride_b_loc_s) + (start_n // BLOCK_SIZE) * stride_b_loc_s).to(tl.int64) # [D,BLOCK_SIZE] off_k = ( bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + @@ -284,6 +289,9 @@ def _fwd_kernel(Q, off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od) out_ptrs = Out + off_o + if USE_FP8: + acc = acc * tl.load(out_scale_inv) + acc = tl.clamp(acc, FP8_MIN, FP8_MAX) tl.store(out_ptrs, acc, mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len)) @@ -367,7 +375,7 @@ def _fwd_kernel_flash_attn_v2( bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + ((start_n + offs_n) // block_size) * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) + other=0).to(tl.int64) off_k = ( bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + (offs_d[:, None] // x) * stride_k_cache_d + @@ -575,7 +583,7 @@ def _fwd_kernel_alibi( bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + ((start_n + offs_n) // block_size) * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) + other=0).to(tl.int64) off_k = ( bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + (offs_d[:, None] // x) * stride_k_cache_d + @@ -743,6 +751,7 @@ def context_attention_fwd(q, sliding_window=None, sm_scale=None, skip_decode=False, + fp8_out_scale=None, sinks=None): q_dtype_is_f32 = q.dtype is torch.float32 @@ -793,6 +802,7 @@ def context_attention_fwd(q, if alibi_slopes is not None: assert sinks is None, "Sinks arg is not supported with alibi" + assert fp8_out_scale is None, "FP8 output not supported with alibi" # need to reduce num. 
blocks when using fp32 # due to increased use of GPU shared memory # if q.dtype is torch.float32: @@ -870,6 +880,7 @@ def context_attention_fwd(q, sm_scale, k_scale, v_scale, + 1.0 / fp8_out_scale if fp8_out_scale is not None else 1.0, b_start_loc, b_seq_len, k_cache.shape[4], @@ -905,6 +916,7 @@ def context_attention_fwd(q, BLOCK_DMODEL_PADDED=Lk_padded, SLIDING_WINDOW=sliding_window, SKIP_DECODE=skip_decode, + USE_FP8=fp8_out_scale is not None, BLOCK_M=128, BLOCK_N=64, num_unroll_cache=4, diff --git a/vllm/attention/ops/rocm_aiter_paged_attn.py b/vllm/attention/ops/rocm_aiter_paged_attn.py index ad97152e208b8..2a0336de8cf72 100644 --- a/vllm/attention/ops/rocm_aiter_paged_attn.py +++ b/vllm/attention/ops/rocm_aiter_paged_attn.py @@ -81,8 +81,8 @@ class AITERPagedAttention(PagedAttention): blocksparse_head_sliding_step=blocksparse_head_sliding_step) if "fp8" in kv_cache_dtype: - key_cache = key_cache.view(torch.float8_e4m3fnuz) - value_cache = value_cache.view(torch.float8_e4m3fnuz) + key_cache = key_cache.view(current_platform.fp8_dtype()) + value_cache = value_cache.view(current_platform.fp8_dtype()) if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1: # use blocksparse paged attention diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 56ebed0f52448..591b68bfa6468 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -10,9 +10,11 @@ import torch from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton logger = init_logger(__name__) +float8_info = torch.finfo(current_platform.fp8_dtype()) @triton.jit @@ -48,47 +50,52 @@ def find_seq_idx(query_start_len_ptr, target_idx, num_seqs, @triton.jit def kernel_unified_attention_2d( - output_ptr, # [num_tokens, num_query_heads, head_size] - query_ptr, # [num_tokens, num_query_heads, head_size] - key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] - value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] - sink_ptr, # [num_query_heads] - block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] - seq_lens_ptr, # [num_seqs] - alibi_slopes_ptr, # [num_query_heads] - qq_bias_ptr, # [num_query_tokens, num_query_tokens] - scale, # float32 - k_scale, # float32 - v_scale, # float32 - softcap, # float32 - num_query_heads: tl.constexpr, # int - num_queries_per_kv: tl.constexpr, # int - block_table_stride: tl.int64, # int - query_stride_0: tl.int64, # int - query_stride_1: tl.int64, # int, should be equal to head_size - output_stride_0: tl.int64, # int - output_stride_1: tl.int64, # int, should be equal to head_size - qq_bias_stride_0: tl.int64, # int - BLOCK_SIZE: tl.constexpr, # int - HEAD_SIZE: tl.constexpr, # int - HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 - USE_ALIBI_SLOPES: tl.constexpr, # bool - USE_QQ_BIAS: tl.constexpr, # bool - USE_SOFTCAP: tl.constexpr, # bool - USE_SINKS: tl.constexpr, # bool - SLIDING_WINDOW: tl.constexpr, # int - stride_k_cache_0: tl.int64, # int - stride_k_cache_1: tl.int64, # int - stride_k_cache_2: tl.int64, # int - stride_k_cache_3: tl.constexpr, # int - stride_v_cache_0: tl.int64, # int - stride_v_cache_1: tl.int64, # int - stride_v_cache_2: tl.int64, # int - stride_v_cache_3: tl.constexpr, # int - query_start_len_ptr, # [num_seqs+1] - BLOCK_Q: tl.constexpr, # int - num_seqs: tl.int32, - BLOCK_M: tl.constexpr, # int + output_ptr, # [num_tokens, num_query_heads, head_size] + query_ptr, 
# [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + sink_ptr, # [num_query_heads] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + qq_bias_ptr, # [num_query_tokens, num_query_tokens] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + out_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + qq_bias_stride_0: tl.int64, # int + BLOCK_SIZE: tl.constexpr, # int + TILE_SIZE: tl.constexpr, # int must be power of 2 + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_QQ_BIAS: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + USE_SINKS: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int32, + BLOCK_M: tl.constexpr, # int + USE_FP8: tl.constexpr, # bool + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max, ): q_block_global_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) @@ -112,6 +119,7 @@ def kernel_unified_attention_2d( offs_m = tl.arange(0, BLOCK_M) offs_d = tl.arange(0, HEAD_SIZE_PADDED) + offs_t = tl.arange(0, TILE_SIZE) query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv query_offset_0 = cur_batch_in_all_start_index + query_pos @@ -171,31 +179,32 @@ def kernel_unified_attention_2d( # actual sequence length max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) - # calculate the number of tiles (blocks) that need to be processed to - # cover the longest sequence prefix (due to causal masking, blocks beyond + # calculate the number of tiles that need to be processed to + # cover the longest sequence prefix (due to causal masking, tiles beyond # this prefix can be skipped) - num_blocks = cdiv_fn(max_seq_prefix_len, BLOCK_SIZE) + num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE) # iterate through tiles - for j in range(0, num_blocks): + for j in range(0, num_tiles): + seq_offset = j * TILE_SIZE + offs_t + tile_mask = seq_offset < max_seq_prefix_len - physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + + seq_offset // BLOCK_SIZE).to(tl.int64) - offs_n = tl.arange(0, BLOCK_SIZE) - - v_offset = (physical_block_idx * stride_v_cache_0 + + v_offset = (physical_block_idx[:, None] * stride_v_cache_0 + kv_head_idx * stride_v_cache_2 + offs_d[None, :] * stride_v_cache_3 + - offs_n[:, None] * stride_v_cache_1) + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1) - k_offset = (physical_block_idx * stride_k_cache_0 + + k_offset = (physical_block_idx[None, :] * stride_k_cache_0 + kv_head_idx * stride_k_cache_2 + offs_d[:, None] * stride_k_cache_3 + - offs_n[None, :] 
* stride_k_cache_1) + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1) - # K : (HEAD_SIZE, BLOCK_SIZE) + # K : (HEAD_SIZE, TILE_SIZE) K_load = tl.load(key_cache_ptr + k_offset, - mask=dim_mask[:, None], + mask=dim_mask[:, None] & tile_mask[None, :], other=0.0) if K_load.dtype.is_fp8(): @@ -206,9 +215,9 @@ def kernel_unified_attention_2d( else: K = K_load - # V : (BLOCK_SIZE, HEAD_SIZE) + # V : (TILE_SIZE, HEAD_SIZE) V_load = tl.load(value_cache_ptr + v_offset, - mask=dim_mask[None, :], + mask=dim_mask[None, :] & tile_mask[:, None], other=0.0) if V_load.dtype.is_fp8(): @@ -219,12 +228,10 @@ def kernel_unified_attention_2d( else: V = V_load - seq_offset = j * BLOCK_SIZE + offs_n - seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 - # S : (BLOCK_M, BLOCK_SIZE) - S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32) + # S : (BLOCK_M, TILE_SIZE) + S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32) S += scale * tl.dot(Q, K) @@ -256,11 +263,12 @@ def kernel_unified_attention_2d( # compute running maximum # m_j : (BLOCK_M,) m_j = tl.maximum(M, tl.max(S, axis=1)) + # For sliding window there's a chance the max is -inf due to masking of # the entire row. In this case we need to set m_j 0 to avoid NaN m_j = tl.where(m_j > float("-inf"), m_j, 0.0) - # P : (BLOCK_M, BLOCK_SIZE) + # P : (BLOCK_M, TILE_SIZE) P = tl.exp(S - m_j[:, None]) # l_j : (BLOCK_M,) @@ -281,6 +289,9 @@ def kernel_unified_attention_2d( # epilogue acc = acc / L[:, None] + if USE_FP8: + acc = acc * tl.load(out_scale) + acc = tl.clamp(acc, FP8_MIN, FP8_MAX) output_offset = (query_offset_0[:, None] * output_stride_0 + query_offset_1[:, None] * output_stride_1 + @@ -318,6 +329,7 @@ def kernel_unified_attention_3d( query_stride_1: tl.int64, # int, should be equal to head_size qq_bias_stride_0: tl.int64, # int BLOCK_SIZE: tl.constexpr, # int + TILE_SIZE: tl.constexpr, # int, must be power of 2 HEAD_SIZE: tl.constexpr, # int HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 USE_ALIBI_SLOPES: tl.constexpr, # bool @@ -365,20 +377,19 @@ def kernel_unified_attention_3d( # number of segments for this particular sequence num_segments = NUM_SEGMENTS_PER_SEQ - blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE) + tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE) - if segm_idx * blocks_per_segment * BLOCK_SIZE >= seq_len: + if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len: return offs_m = tl.arange(0, BLOCK_M) offs_d = tl.arange(0, HEAD_SIZE_PADDED) - + offs_t = tl.arange(0, TILE_SIZE) query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv query_offset_0 = cur_batch_in_all_start_index + query_pos query_offset_1 = kv_head_idx * num_queries_per_kv + \ offs_m % num_queries_per_kv - query_offset = (query_offset_0[:, None] * query_stride_0 + query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) @@ -424,30 +435,44 @@ def kernel_unified_attention_3d( qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0 ) # shape: [BLOCK_M] - num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) + # compute the length of the longest sequence prefix spanned by any + # query token in the current q_block (q_block_local_idx) + max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + ( + BLOCK_M - 1) // num_queries_per_kv + 1 + + # adjust for potential padding in the last q_block by considering the + # actual sequence length + max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) + + # calculate the number of tiles that need to be processed to + # cover the 
longest sequence prefix (due to causal masking, tiles beyond + # this prefix can be skipped) + num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE) # iterate through tiles within current segment for j in range( - segm_idx * blocks_per_segment, - min((segm_idx + 1) * blocks_per_segment, num_blocks), + segm_idx * tiles_per_segment, + min((segm_idx + 1) * tiles_per_segment, num_tiles), ): - physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) + seq_offset = j * TILE_SIZE + offs_t + tile_mask = seq_offset < max_seq_prefix_len - offs_n = tl.arange(0, BLOCK_SIZE) + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + + seq_offset // BLOCK_SIZE).to(tl.int64) - v_offset = (physical_block_idx * stride_v_cache_0 + + v_offset = (physical_block_idx[:, None] * stride_v_cache_0 + kv_head_idx * stride_v_cache_2 + offs_d[None, :] * stride_v_cache_3 + - offs_n[:, None] * stride_v_cache_1) + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1) - k_offset = (physical_block_idx * stride_k_cache_0 + + k_offset = (physical_block_idx[None, :] * stride_k_cache_0 + kv_head_idx * stride_k_cache_2 + offs_d[:, None] * stride_k_cache_3 + - offs_n[None, :] * stride_k_cache_1) + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1) - # K : (HEAD_SIZE, BLOCK_SIZE) + # K : (HEAD_SIZE, TILE_SIZE) K_load = tl.load(key_cache_ptr + k_offset, - mask=dim_mask[:, None], + mask=dim_mask[:, None] & tile_mask[None, :], other=0.0) if K_load.dtype.is_fp8(): @@ -458,9 +483,9 @@ def kernel_unified_attention_3d( else: K = K_load - # V : (BLOCK_SIZE, HEAD_SIZE) + # V : (TILE_SIZE, HEAD_SIZE) V_load = tl.load(value_cache_ptr + v_offset, - mask=dim_mask[None, :], + mask=dim_mask[None, :] & tile_mask[:, None], other=0.0) if V_load.dtype.is_fp8(): @@ -471,13 +496,10 @@ def kernel_unified_attention_3d( else: V = V_load - seq_offset = j * BLOCK_SIZE + offs_n - seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 - # S : (BLOCK_M, BLOCK_SIZE) - S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32) - + # S : (BLOCK_M, TILE_SIZE) + S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32) S += scale * tl.dot(Q, K) if USE_SOFTCAP: @@ -508,11 +530,12 @@ def kernel_unified_attention_3d( # compute running maximum # m_j : (BLOCK_M,) m_j = tl.maximum(M, tl.max(S, axis=1)) + # For sliding window there's a chance the max is -inf due to masking of # the entire row. 
In this case we need to set m_j 0 to avoid NaN m_j = tl.where(m_j > float("-inf"), m_j, 0.0) - # P : (BLOCK_M, BLOCK_SIZE,) + # P : (BLOCK_M, TILE_SIZE,) P = tl.exp(S - m_j[:, None]) # l_j : (BLOCK_M,) @@ -552,23 +575,27 @@ def kernel_unified_attention_3d( @triton.jit def reduce_segments( - output_ptr, # [num_tokens, num_query_heads, head_size] - segm_output_ptr, - #[num_tokens, num_query_heads, max_num_segments, head_size] - segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments] - segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments] - seq_lens_ptr, # [num_seqs] - num_seqs, # int - num_query_heads: tl.constexpr, # int - output_stride_0: tl.int64, # int - output_stride_1: tl.int64, # int, should be equal to head_size - block_table_stride: tl.int64, # int - BLOCK_SIZE: tl.constexpr, # int - HEAD_SIZE: tl.constexpr, # int, must be power of 2 - HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 - query_start_len_ptr, # [num_seqs+1] - BLOCK_Q: tl.constexpr, # int - NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int + output_ptr, # [num_tokens, num_query_heads, head_size] + segm_output_ptr, + #[num_tokens, num_query_heads, max_num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments] + seq_lens_ptr, # [num_seqs] + num_seqs, # int + num_query_heads: tl.constexpr, # int + out_scale_inv, # float32 + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + block_table_stride: tl.int64, # int + TILE_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int, must be power of 2 + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int + USE_FP8: tl.constexpr, # bool + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max, ): query_token_idx = tl.program_id(0) query_head_idx = tl.program_id(1) @@ -581,10 +608,10 @@ def reduce_segments( # number of segments for this particular sequence num_segments = NUM_SEGMENTS_PER_SEQ - blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE) + tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE) # create masks for subsequent loads - act_num_segments = cdiv_fn(seq_len, blocks_per_segment * BLOCK_SIZE) + act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE) segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full( [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32) dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, @@ -624,6 +651,10 @@ def reduce_segments( # safely divide by overall_expsum, returning 0.0 if overall_expsum is 0 acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum) + if USE_FP8: + acc = acc * tl.load(out_scale_inv) + acc = tl.clamp(acc, FP8_MIN, FP8_MAX) + # write result output_offset = (query_token_idx * output_stride_0 + query_head_idx * output_stride_1 + @@ -649,17 +680,15 @@ def unified_attention( k_descale, v_descale, alibi_slopes=None, + output_scale=None, qq_bias=None, # Optional tensor for sinks sinks=None, ): + assert causal, "Only causal attention is supported" assert q_descale is None, "Q scales not supported" - block_size = v.shape[1] - assert q.element_size() >= 2 or block_size >= 32, \ - "Block size must be at least 32 for fp8" - if sinks is not None: assert sinks.shape[0] == q.shape[1], \ "Sinks must be num_query_heads size" @@ -674,7 +703,8 @@ def unified_attention( 
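As an aside on the USE_FP8 branches added above: before the result is written to an fp8 output buffer, the accumulator is multiplied by the inverse output scale and clamped to the finite fp8 range. A short torch sketch of that pattern, assuming an e4m3 fp8 dtype and a made-up per-tensor scale (the kernels obtain the platform's fp8 dtype via current_platform.fp8_dtype()):

import torch

fp8_dtype = torch.float8_e4m3fn                 # assumption for this sketch
info = torch.finfo(fp8_dtype)
acc = torch.randn(4, 8, dtype=torch.float32) * 600.0   # fp32 attention output
output_scale = 2.0                              # assumed per-tensor output scale

out_fp8 = (acc * (1.0 / output_scale)).clamp(info.min, info.max).to(fp8_dtype)
approx = out_fp8.to(torch.float32) * output_scale       # roughly recovers acc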
num_queries_per_kv = num_query_heads // num_kv_heads head_size = q.shape[2] - BLOCK_M = 16 + BLOCK_M = 16 if num_queries_per_kv <= 16 else triton.next_power_of_2( + num_queries_per_kv) BLOCK_Q = BLOCK_M // num_queries_per_kv # Ideally we would launch with kernel with: @@ -688,6 +718,12 @@ def unified_attention( # = floor(q.shape[0] / BLOCK_Q) + num_seqs total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs + # Assigning default tile sizes for prefill and decode. + # Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1) + # and at least 16 for all other data types. + TILE_SIZE_PREFILL = 32 + TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32 + # if batch contains a prefill if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128: kernel_unified_attention_2d[( @@ -706,6 +742,7 @@ def unified_attention( scale=softmax_scale, k_scale=k_descale, v_scale=v_descale, + out_scale=1 / output_scale if output_scale is not None else 1.0, softcap=softcap, num_query_heads=num_query_heads, num_queries_per_kv=num_queries_per_kv, @@ -716,6 +753,7 @@ def unified_attention( output_stride_1=out.stride(1), qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_PREFILL, HEAD_SIZE=head_size, HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), USE_ALIBI_SLOPES=use_alibi_slopes, @@ -735,6 +773,7 @@ def unified_attention( BLOCK_Q=BLOCK_Q, num_seqs=num_seqs, BLOCK_M=BLOCK_M, + USE_FP8=output_scale is not None, ) else: # for initial version, NUM_SEGMENTS = 16 is chosen as a default @@ -788,6 +827,7 @@ def unified_attention( query_stride_1=q.stride(1), qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_DECODE, HEAD_SIZE=head_size, HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), USE_ALIBI_SLOPES=use_alibi_slopes, @@ -809,7 +849,6 @@ def unified_attention( BLOCK_M=BLOCK_M, NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, ) - reduce_segments[(q.shape[0], num_query_heads)]( output_ptr=out, segm_output_ptr=segm_output, @@ -818,13 +857,16 @@ def unified_attention( seq_lens_ptr=seqused_k, num_seqs=num_seqs, num_query_heads=num_query_heads, + out_scale_inv=1 / + output_scale if output_scale is not None else 1.0, output_stride_0=out.stride(0), output_stride_1=out.stride(1), block_table_stride=block_table.stride(0), - BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_DECODE, HEAD_SIZE=head_size, HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), query_start_len_ptr=cu_seqlens_q, BLOCK_Q=BLOCK_Q, NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + USE_FP8=output_scale is not None, ) diff --git a/vllm/attention/utils/fa_utils.py b/vllm/attention/utils/fa_utils.py index f8b00565f0517..dc0af7e28e3e2 100644 --- a/vllm/attention/utils/fa_utils.py +++ b/vllm/attention/utils/fa_utils.py @@ -68,5 +68,18 @@ def flash_attn_supports_fp8() -> bool: current_platform.get_device_capability().major == 9 +def flash_attn_supports_mla(): + from vllm.platforms import current_platform + if current_platform.is_cuda(): + try: + from vllm.vllm_flash_attn.flash_attn_interface import ( + is_fa_version_supported) + return is_fa_version_supported(3) \ + and current_platform.get_device_capability()[0] == 9 + except (ImportError, AssertionError): + pass + return False + + def is_flash_attn_varlen_func_available() -> bool: return current_platform.is_cuda() or current_platform.is_xpu() diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index 93519b5ba1523..68a937d5750ec 100644 --- a/vllm/benchmarks/datasets.py +++ 
b/vllm/benchmarks/datasets.py @@ -11,6 +11,7 @@ generation. Supported dataset types include: - HuggingFace - VisionArena """ +import argparse import ast import base64 import io @@ -36,7 +37,7 @@ from vllm.lora.request import LoRARequest from vllm.lora.utils import get_adapter_absolute_path from vllm.multimodal import MultiModalDataDict from vllm.multimodal.image import convert_image_mode -from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import PlaceholderModule try: @@ -99,13 +100,13 @@ class BenchmarkDataset(ABC): ) -> None: """ Initialize the BenchmarkDataset with an optional dataset path and random - seed. - + seed. + Args: dataset_path (Optional[str]): Path to the dataset. If None, it - indicates that a default or random dataset might be used. + indicates that a default or random dataset might be used. random_seed (int): Seed value for reproducible shuffling or - sampling. Defaults to DEFAULT_SEED. + sampling. Defaults to DEFAULT_SEED. """ self.dataset_path = dataset_path # Set the random seed, ensuring that a None value is replaced with the @@ -132,10 +133,10 @@ class BenchmarkDataset(ABC): elif isinstance(mm_content, dict): content.append(mm_content) else: - raise TypeError( + raise TypeError( "Could not process multimodal content of type: " + - f"{type(mm_content)}" - ) + f"{type(mm_content)}" + ) return [{"role": "user", "content": content}] def load_data(self) -> None: @@ -154,34 +155,27 @@ class BenchmarkDataset(ABC): def get_random_lora_request( self, - tokenizer: PreTrainedTokenizerBase, max_loras: Optional[int] = None, lora_path: Optional[str] = None, - ) -> tuple[Optional[LoRARequest], AnyTokenizer]: + ) -> Optional[LoRARequest]: """ - Optionally select a random LoRA request and return its associated - tokenizer. + Optionally select a random LoRA request. This method is used when LoRA parameters are provided. It randomly - selects a LoRA based on max_loras and retrieves a cached tokenizer for - that LoRA if available. Otherwise, it returns the base tokenizer. + selects a LoRA based on max_loras. Args: - tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no - LoRA is selected. max_loras (Optional[int]): The maximum number of LoRAs available. If `None`, LoRA is not used. lora_path (Optional[str]): Path to the LoRA parameters on disk. If `None`, LoRA is not used. Returns: - A tuple with the following elements: - - A new [LoRARequest][] (or `None` if not applicable). - - The tokenizer associated with the LoRA request - (or the base tokenizer). + A new [`LoRARequest`][vllm.lora.request.LoRARequest] + (or `None` if not applicable). """ if max_loras is None or lora_path is None: - return None, tokenizer + return None # Generate a random LoRA ID in the range [1, max_loras]. 
lora_id = random.randint(1, max_loras) @@ -190,16 +184,13 @@ class BenchmarkDataset(ABC): lora_int_id=lora_id, lora_path=lora_path_on_disk(lora_path), ) - if lora_id not in lora_tokenizer_cache: - lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request) - # Return lora_request and the cached tokenizer if available; otherwise, - # return the base tokenizer - return lora_request, lora_tokenizer_cache[lora_id] or tokenizer + return lora_request @abstractmethod def sample(self, tokenizer: PreTrainedTokenizerBase, - num_requests: int, - request_id_prefix: str = "") -> list[SampleRequest]: + num_requests: int, + request_id_prefix: str = "", + no_oversample: bool = False) -> list[SampleRequest]: """ Abstract method to generate sample requests from the dataset. @@ -210,8 +201,7 @@ class BenchmarkDataset(ABC): tokenizer (PreTrainedTokenizerBase): The tokenizer to be used for processing the dataset's text. num_requests (int): The number of sample requests to generate. - request_id_prefix (str) The prefix of request_id. - + request_id_prefix (str): The prefix of request_id. Returns: list[SampleRequest]: A list of sample requests generated from the @@ -224,6 +214,7 @@ class BenchmarkDataset(ABC): requests: list[SampleRequest], num_requests: int, request_id_prefix: str = "", + no_oversample: bool = False, ) -> None: """ Oversamples the list of requests if its size is less than the desired @@ -233,9 +224,15 @@ class BenchmarkDataset(ABC): requests (List[SampleRequest]): The current list of sampled requests. num_requests (int): The target number of requests. - request_id_prefix (str) The prefix of the request ids. + request_id_prefix (str): The prefix applied to generated request + identifiers. """ + if no_oversample: + logger.info("Skipping oversampling. " \ + "Total samples: %d.", len(requests)) + return + if len(requests) < num_requests: random.seed(self.random_seed) additional = deepcopy( @@ -327,7 +324,7 @@ def process_image(image: Any) -> Mapping[str, Any]: if isinstance(image, str): image_url = (image if image.startswith( - ("http://", "file://")) else f"file://{image}") + ("http://", "https://", "file://")) else f"file://{image}") return {"type": "image_url", "image_url": {"url": image_url}} raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image" @@ -362,7 +359,7 @@ def process_video(video: Any) -> Mapping[str, Any]: if isinstance(video, str): video_url = (video if video.startswith( - ("http://", "file://")) else f"file://{video}") + ("http://", "https://", "file://")) else f"file://{video}") return {"type": "video_url", "video_url": {"url": video_url}} raise ValueError( @@ -405,6 +402,7 @@ class RandomDataset(BenchmarkDataset): tokenizer: PreTrainedTokenizerBase, num_requests: int, request_id_prefix: str = "", + no_oversample: bool = False, prefix_len: int = DEFAULT_PREFIX_LEN, range_ratio: float = DEFAULT_RANGE_RATIO, input_len: int = DEFAULT_INPUT_LEN, @@ -518,7 +516,7 @@ class RandomDataset(BenchmarkDataset): size=num_requests) output_lens = self._rng.integers(output_low, output_high + 1, size=num_requests) - offsets = self._rng.integers(0, tokenizer.vocab_size, + offsets = self._rng.integers(0, tokenizer.vocab_size, size=num_requests) return input_lens, output_lens, offsets @@ -543,10 +541,10 @@ class RandomDataset(BenchmarkDataset): [6880, 6881] -> ['Ġcalls', 'here'] -> [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] To avoid uncontrolled change of the prompt length, - the encoded sequence is truncated before being decode again. 
+ the encoded sequence is truncated before being decoded again. """ # Build the inner sequence by sampling sequentially from the vocab - inner_seq = ((offset + index + np.arange(input_len)) + inner_seq = ((offset + index + np.arange(input_len)) % vocab_size).tolist() token_sequence = prefix_token_ids + inner_seq @@ -581,9 +579,9 @@ class RandomMultiModalDataset(RandomDataset): `num_mm_items_range_ratio` in [0, 1]. r=0 keeps it fixed; r=1 allows 0. The maximum is further clamped to the sum of per-modality limits. 2) Each item’s modality and shape is sampled from `bucket_config`, a dict - mapping (height, width, num_frames) → probability. We treat - `num_frames`=1 as image and and `num_frames` > 1 as video. - Entries with zero probability are removed and the rest are renormalized + mapping (height, width, num_frames) → probability. We treat + `num_frames`=1 as image and and `num_frames` > 1 as video. + Entries with zero probability are removed and the rest are renormalized to sum to 1. 3) Per-modality hard caps are enforced via `limit_mm_per_prompt`. When a modality reaches its cap, all of its buckets are excluded and the @@ -591,8 +589,8 @@ class RandomMultiModalDataset(RandomDataset): Example bucket configuration: {(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.1} - - Two image buckets (`num_frames`=1) and one video bucket - (`num_frames`=16). + - Two image buckets (`num_frames`=1) and one video bucket + (`num_frames`=16). OBS.: Only image sampling is supported for now. """ @@ -615,9 +613,9 @@ class RandomMultiModalDataset(RandomDataset): def generate_synthetic_image(self, width: int, height: int) -> Image.Image: """Generate synthetic PIL image with random RGB values. - - NOTE: iid pixel sampling results in worst-case compression - (good for stressing I/O), but very unlike real photos. + + NOTE: iid pixel sampling results in worst-case compression + (good for stressing I/O), but very unlike real photos. We could consider a “low-freq” mode (e.g., noise blur) to emulate network realism instead of max stress. """ @@ -629,11 +627,11 @@ class RandomMultiModalDataset(RandomDataset): ) return Image.fromarray(random_pixels) - def generate_synthetic_video(self, width: int, - height: int, + def generate_synthetic_video(self, width: int, + height: int, num_frames: int) -> Any: """Generate synthetic video with random values. - + TODO: Finish this method. """ raise NotImplementedError("Video sampling is WIP.") @@ -647,7 +645,7 @@ class RandomMultiModalDataset(RandomDataset): else: raise ValueError(f"Invalid multimodal item configuration: {config}") - def normalize_bucket_config(self, bucket_config: dict[tuple[int, int, int], + def normalize_bucket_config(self, bucket_config: dict[tuple[int, int, int], float]) -> dict[tuple[int, int, int], float]: """ Remove zero probability entries @@ -667,24 +665,24 @@ class RandomMultiModalDataset(RandomDataset): return {k: v / total for k, v in bucket_config.items()} - def generate_mm_item(self, + def generate_mm_item(self, mm_item_config: tuple[int, int, int], ) -> Mapping[str, Any]: """ - Create synthetic images and videos and + Create synthetic images and videos and apply process_image/process_video respectively. 
This follows the OpenAI API chat completions https://github.com/openai/openai-python """ - + if self.map_config_to_modality(mm_item_config) == "image": return process_image(self.generate_synthetic_image( mm_item_config[1], mm_item_config[0])) elif self.map_config_to_modality(mm_item_config) == "video": return process_video(self.generate_synthetic_video( - mm_item_config[1], - mm_item_config[0], + mm_item_config[1], + mm_item_config[0], mm_item_config[2])) else: raise ValueError(f"Invalid multimodal item configuration: " @@ -714,17 +712,17 @@ class RandomMultiModalDataset(RandomDataset): f"limit_mm_per_prompt: " f"{limit_mm_per_prompt.keys()}") - # Remove zero probability entries + # Remove zero probability entries # and normalize bucket config to sum to 1 bucket_config = self.normalize_bucket_config(bucket_config) logger.info( "Normalized bucket config: %s", bucket_config, ) # Only consider limit per prompt for modalities in bucket config - allowed_modalities = {self.map_config_to_modality(cfg) + allowed_modalities = {self.map_config_to_modality(cfg) for cfg in bucket_config} limit_mm_per_prompt = { - k: v for k, v in limit_mm_per_prompt.items() + k: v for k, v in limit_mm_per_prompt.items() if k in allowed_modalities} if not limit_mm_per_prompt: raise ValueError("No valid limits for modalities present in " @@ -737,19 +735,19 @@ class RandomMultiModalDataset(RandomDataset): # Get max and min num mm items and ensure # it is at most the sum of limit_mm_per_prompt for all modalities max_num_mm_items = min( - sum(limit_mm_per_prompt.values()), + sum(limit_mm_per_prompt.values()), math.ceil(base_items_per_request * (1 + num_mm_items_range_ratio)) ) # Ensure min num mm items is at least 0 min_num_mm_items = max( - 0, + 0, math.floor(base_items_per_request * (1 - num_mm_items_range_ratio)) ) # Raise error if min num mm items is greater than max num mm items if min_num_mm_items > max_num_mm_items: raise ValueError(f"Min num mm items is greater than max mm items: " f"{min_num_mm_items} > {max_num_mm_items}") - + logger.info( "Sampling number of multimodal items from [%s, %s]", min_num_mm_items, max_num_mm_items, @@ -774,8 +772,8 @@ class RandomMultiModalDataset(RandomDataset): whose size is between min_num_mm_items and max_num_mm_items. Loop over the bucket config and sample a multimodal item. - Loop until the number of multimodal items sampled is equal to - request_num_mm_items or limit of multimodal items per prompt + Loop until the number of multimodal items sampled is equal to + request_num_mm_items or limit of multimodal items per prompt for all modalities is reached. 
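The sampling loop described above can be sketched in a few lines; the bucket probabilities, per-modality caps, and item counts below are illustrative, not the dataset's actual defaults:

import numpy as np

rng = np.random.default_rng(0)
buckets = {("image", 256, 256): 0.6, ("image", 720, 1280): 0.3,
           ("video", 720, 1280): 0.1}           # config -> probability (toy values)
limits = {"image": 2, "video": 1}               # per-prompt caps (toy values)
want = 3                                        # items requested for this prompt

counts = {m: 0 for m in limits}
chosen = []
while len(chosen) < want and buckets:
    keys = list(buckets)
    probs = np.array([buckets[k] for k in keys])
    cfg = keys[rng.choice(len(keys), p=probs / probs.sum())]
    modality = cfg[0]
    if counts[modality] >= limits[modality]:
        # drop every bucket of the saturated modality and re-normalize the rest
        buckets = {k: v for k, v in buckets.items() if k[0] != modality}
        continue
    counts[modality] += 1
    chosen.append(cfg)
print(chosen)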
Note: @@ -787,19 +785,19 @@ class RandomMultiModalDataset(RandomDataset): # Get the number of multimodal items to sample request_num_mm_items = int( self._rng.integers(min_num_mm_items, max_num_mm_items + 1) - ) + ) # If request_num_mm_items is 0, yield an empty iterator if request_num_mm_items == 0: return # Initialize modality counters - modality_counter = {self.map_config_to_modality(k): 0 + modality_counter = {self.map_config_to_modality(k): 0 for k in bucket_config} # Copy the bucket config to avoid modifying the original bucket_config_copy = bucket_config.copy() # Loop over the number of multimodal items to sample while sum(modality_counter.values()) < request_num_mm_items: # Sample a multimodal item config - mm_item_config = self._rng.choice(list(bucket_config_copy.keys()), + mm_item_config = self._rng.choice(list(bucket_config_copy.keys()), p=list(bucket_config_copy.values())) modality = self.map_config_to_modality(mm_item_config) # Check that modality count is less than limit per prompt @@ -832,6 +830,7 @@ class RandomMultiModalDataset(RandomDataset): tokenizer: PreTrainedTokenizerBase, num_requests: int, request_id_prefix: str = "", + no_oversample: bool = False, prefix_len: int = RandomDataset.DEFAULT_PREFIX_LEN, range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO, input_len: int = RandomDataset.DEFAULT_INPUT_LEN, @@ -839,7 +838,7 @@ class RandomMultiModalDataset(RandomDataset): limit_mm_per_prompt: dict[str, int] = DEFAULT_LIMIT_MM_PER_PROMPT, base_items_per_request: int = DEFAULT_BASE_ITEMS_PER_REQUEST, num_mm_items_range_ratio: float = DEFAULT_NUM_MM_ITEMS_RANGE_RATIO, - bucket_config: dict[tuple[int, int, int], float] = + bucket_config: dict[tuple[int, int, int], float] = DEFAULT_MM_ITEM_BUCKET_CONFIG, enable_multimodal_chat: bool = DEFAULT_ENABLE_MULTIMODAL_CHAT, **kwargs, @@ -847,7 +846,7 @@ class RandomMultiModalDataset(RandomDataset): # NOTE: Video sampling is WIP. Raise error if video is in bucket config # and probability is non-zero. - if any(self.map_config_to_modality(cfg) == "video" and p > 0 + if any(self.map_config_to_modality(cfg) == "video" and p > 0 for cfg, p in bucket_config.items()): raise NotImplementedError("Video sampling not implemented; " "set its probability to 0.") @@ -898,7 +897,7 @@ class RandomMultiModalDataset(RandomDataset): ]) if enable_multimodal_chat: - # NOTE: For now this option is only provided for completeness + # NOTE: For now this option is only provided for completeness # given that the serve.py benchmark currently does not use it. 
mm_chat_prompt: Any = prompt mm_chat_prompt = self.apply_multimodal_chat_transformation( @@ -959,6 +958,7 @@ class ShareGPTDataset(BenchmarkDataset): output_len: Optional[int] = None, enable_multimodal_chat: bool = False, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs, ) -> list: samples: list = [] @@ -971,8 +971,8 @@ class ShareGPTDataset(BenchmarkDataset): entry["conversations"][1]["value"], ) - lora_request, tokenizer = self.get_random_lora_request( - tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) + lora_request = self.get_random_lora_request( + max_loras=max_loras, lora_path=lora_path) prompt_ids = tokenizer(prompt).input_ids completion_ids = tokenizer(completion).input_ids prompt_len = len(prompt_ids) @@ -983,11 +983,11 @@ class ShareGPTDataset(BenchmarkDataset): skip_min_output_len_check=output_len is not None): continue - if image_path := entry.get("image"): - mm_content = process_image(image_path) - elif video_path := entry.get("video"): + if image_path := entry.get("image"): + mm_content = process_image(image_path) + elif video_path := entry.get("video"): mm_content = process_video(video_path) - else: + else: mm_content = None if enable_multimodal_chat: prompt = self.apply_multimodal_chat_transformation( @@ -1002,10 +1002,32 @@ class ShareGPTDataset(BenchmarkDataset): request_id=request_id_prefix + str(ind), )) ind += 1 - self.maybe_oversample_requests(samples, num_requests, request_id_prefix) + self.maybe_oversample_requests(samples, + num_requests, + request_id_prefix, + no_oversample) return samples +class _ValidateDatasetArgs(argparse.Action): + """Argparse action to validate dataset name and path compatibility.""" + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, values) + + # Get current values of both dataset_name and dataset_path + dataset_name = getattr(namespace, 'dataset_name', 'random') + dataset_path = getattr(namespace, 'dataset_path', None) + + # Validate the combination + if dataset_name == "random" and dataset_path is not None: + parser.error( + "Cannot use 'random' dataset with --dataset-path. " + "Please specify the appropriate --dataset-name (e.g., " + "'sharegpt', 'custom', 'sonnet') for your dataset file: " + f"{dataset_path}" + ) + + def add_dataset_parser(parser: FlexibleArgumentParser): parser.add_argument("--seed", type=int, default=0) parser.add_argument( @@ -1018,9 +1040,10 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "--dataset-name", type=str, default="random", + action=_ValidateDatasetArgs, choices=[ - "sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf", - "custom", "prefix_repetition" + "sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf", + "custom", "prefix_repetition", "spec_bench" ], help="Name of the dataset to benchmark on.", ) @@ -1033,9 +1056,16 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "--dataset-path", type=str, default=None, + action=_ValidateDatasetArgs, help="Path to the sharegpt/sonnet dataset. 
" "Or the huggingface dataset ID if using HF dataset.", ) + parser.add_argument( + "--no-oversample", + action="store_true", + help="Do not oversample if the dataset has " \ + "fewer samples than num-prompts.", + ) # group for dataset specific arguments custom_group = parser.add_argument_group("custom dataset options") @@ -1053,6 +1083,22 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "Skip applying chat template to prompt, used only for custom dataset.", ) + spec_bench_group = parser.add_argument_group("spec bench dataset options") + spec_bench_group.add_argument( + "--spec-bench-output-len", + type=int, + default=256, + help= + "Num of output tokens per request, used only for spec bench dataset.", + ) + spec_bench_group.add_argument( + "--spec-bench-category", + type=str, + default=None, + help= + "Category for spec bench dataset. If None, use all categories.", + ) + sonnet_group = parser.add_argument_group("sonnet dataset options") sonnet_group.add_argument( "--sonnet-input-len", @@ -1085,6 +1131,22 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "from the ShareGPT dataset.", ) + blazedit_group = parser.add_argument_group("blazedit dataset options") + blazedit_group.add_argument( + "--blazedit-min-distance", + type=float, + default=0.0, + help= + "Minimum distance for blazedit dataset. Min: 0, Max: 1.0", + ) + blazedit_group.add_argument( + "--blazedit-max-distance", + type=float, + default=1.0, + help= + "Maximum distance for blazedit dataset. Min: 0, Max: 1.0", + ) + random_group = parser.add_argument_group("random dataset options") random_group.add_argument( "--random-input-len", @@ -1227,6 +1289,16 @@ def add_dataset_parser(parser: FlexibleArgumentParser): type=str, default=None, help="Split of the HF dataset.") + hf_group.add_argument( + "--hf-name", + type=str, + default=None, + help=( + "Name of the dataset on HuggingFace " + "(e.g., 'lmarena-ai/VisionArena-Chat'). " + "Specify this if your dataset-path is a local path." + ), + ) hf_group.add_argument( "--hf-output-len", type=int, @@ -1268,6 +1340,10 @@ def add_dataset_parser(parser: FlexibleArgumentParser): def get_samples(args, tokenizer) -> list[SampleRequest]: + + if not hasattr(args, "request_id_prefix"): + args.request_id_prefix = "" + if args.dataset_name == "custom": dataset = CustomDataset(dataset_path=args.dataset_path) input_requests = dataset.sample( @@ -1276,12 +1352,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: output_len=args.custom_output_len, skip_chat_template=args.custom_skip_chat_template, request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, ) elif args.dataset_name == "sonnet": dataset = SonnetDataset(dataset_path=args.dataset_path) # For the "sonnet" dataset, formatting depends on the backend. 
- if args.endpoint_type == "openai-chat": + if args.backend == "openai-chat": input_requests = dataset.sample( num_requests=args.num_prompts, input_len=args.sonnet_input_len, @@ -1290,6 +1367,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: tokenizer=tokenizer, return_prompt_formatted=False, request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, ) else: assert tokenizer.chat_template or tokenizer.default_chat_template, ( @@ -1302,33 +1380,74 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: tokenizer=tokenizer, return_prompt_formatted=True, request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, ) elif args.dataset_name == "hf": # all following datasets are implemented from the # HuggingFaceDataset base class - if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: + hf_kwargs = {} + if ( + args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in VisionArenaDataset.SUPPORTED_DATASET_PATHS + ): dataset_class = VisionArenaDataset args.hf_split = "train" args.hf_subset = None - elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: + elif ( + args.dataset_path in MMVUDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in MMVUDataset.SUPPORTED_DATASET_PATHS + ): + dataset_class = MMVUDataset + args.hf_split = "validation" + args.hf_subset = None + elif ( + args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in InstructCoderDataset.SUPPORTED_DATASET_PATHS + ): dataset_class = InstructCoderDataset args.hf_split = "train" - elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS: + elif ( + args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in MTBenchDataset.SUPPORTED_DATASET_PATHS + ): dataset_class = MTBenchDataset args.hf_split = "train" - elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: + elif ( + args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in ConversationDataset.SUPPORTED_DATASET_PATHS + ): dataset_class = ConversationDataset - elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: + elif ( + args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS + or args.hf_name in AIMODataset.SUPPORTED_DATASET_PATHS + ): dataset_class = AIMODataset args.hf_split = "train" - elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS: # noqa: E501 + elif ( + args.dataset_path + in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS # noqa: E501 + or args.hf_name in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS + ): dataset_class = NextEditPredictionDataset args.hf_split = "train" - elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS: + elif ( + args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in ASRDataset.SUPPORTED_DATASET_PATHS + ): dataset_class = ASRDataset args.hf_split = "train" - elif args.dataset_path in MLPerfDataset.SUPPORTED_DATASET_PATHS: + elif args.dataset_path in BlazeditDataset.SUPPORTED_DATASET_PATHS: + dataset_class = BlazeditDataset + args.hf_split = "train" + hf_kwargs = { + "min_distance": args.blazedit_min_distance, + "max_distance": args.blazedit_max_distance, + } + elif ( + args.dataset_path in MLPerfDataset.SUPPORTED_DATASET_PATHS + or args.hf_name in MLPerfDataset.SUPPORTED_DATASET_PATHS + ): dataset_class = MLPerfDataset args.hf_split = "train" else: @@ -1343,7 +1462,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: "Please consider contributing if you would 
" "like to add support for additional dataset formats.") - if dataset_class.IS_MULTIMODAL and args.endpoint_type not in [ + if dataset_class.IS_MULTIMODAL and args.backend not in [ "openai-chat", "openai-audio", ]: @@ -1351,23 +1470,35 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: # endpoint-type. raise ValueError( "Multi-modal content is only supported on 'openai-chat' and " - "'openai-audio' endpoint-type.") + "'openai-audio' backends.") input_requests = dataset_class( dataset_path=args.dataset_path, dataset_subset=args.hf_subset, dataset_split=args.hf_split, random_seed=args.seed, no_stream=args.no_stream, + hf_name=args.hf_name, ).sample( num_requests=args.num_prompts, tokenizer=tokenizer, output_len=args.hf_output_len, request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, + **hf_kwargs ) else: # For datasets that follow a similar structure, use a mapping. dataset_mapping = { + "spec_bench": + lambda: SpecBench(dataset_path=args.dataset_path, + category=args.spec_bench_category).sample( + num_requests=args.num_prompts, + tokenizer=tokenizer, + output_len=args.spec_bench_output_len, + request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, + ), "sharegpt": lambda: ShareGPTDataset( random_seed=args.seed, dataset_path=args.dataset_path ).sample( @@ -1375,6 +1506,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: num_requests=args.num_prompts, output_len=args.sharegpt_output_len, request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, ), "burstgpt": lambda: BurstGPTDataset( random_seed=args.seed, dataset_path=args.dataset_path @@ -1382,6 +1514,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: tokenizer=tokenizer, num_requests=args.num_prompts, request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, ), "random": lambda: RandomDataset( random_seed=args.seed, dataset_path=args.dataset_path @@ -1394,6 +1527,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: range_ratio=args.random_range_ratio, request_id_prefix=args.request_id_prefix, batchsize=args.random_batch_size, + no_oversample=args.no_oversample, ), "random-mm": lambda: RandomMultiModalDataset( @@ -1410,6 +1544,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: num_mm_items_range_ratio=args.random_mm_num_mm_items_range_ratio, bucket_config=args.random_mm_bucket_config, request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, ), "prefix_repetition": lambda: PrefixRepetitionRandomDataset( @@ -1422,12 +1557,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: num_prefixes=args.prefix_repetition_num_prefixes, output_len=args.prefix_repetition_output_len, request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, ), } try: # Enforce endpoint compatibility for multimodal datasets. 
- if args.dataset_name == "random-mm" and args.endpoint_type not in [ + if args.dataset_name == "random-mm" and args.backend not in [ "openai-chat"]: raise ValueError( "Multi-modal content (images) is only supported on " @@ -1503,8 +1639,17 @@ class CustomDataset(BenchmarkDataset): enable_multimodal_chat: bool = False, skip_chat_template: bool = False, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs, ) -> list: + # load all data if needed + self.num_available_samples = len(self.data) + if num_requests <= 0: + num_requests = self.num_available_samples + logger.info("num_requests is set to 0 or negative, " + "so using all available samples: %d", + num_requests) + sampled_requests = [] for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: @@ -1530,12 +1675,58 @@ class CustomDataset(BenchmarkDataset): expected_output_len=output_len, request_id=request_id_prefix + str(i), )) - self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix, no_oversample) return sampled_requests +# ----------------------------------------------------------------------------- +# Spec Bench Dataset Implementation +# ----------------------------------------------------------------------------- + + +class SpecBench(CustomDataset): + """ + Implements the SpecBench dataset: https://github.com/hemingkx/Spec-Bench + Download the dataset using: + wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl + """ # noqa: E501 + + def __init__(self, **kwargs) -> None: + self.category = kwargs.pop("category", None) + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + self.data = [] + + # Load the JSONL file + jsonl_data = pd.read_json(path_or_buf=self.dataset_path, + lines=True) + + # check if the JSONL file has a 'turns' column + if "turns" not in jsonl_data.columns: + raise ValueError("JSONL file must contain a 'turns' column.") + + for _, row in jsonl_data.iterrows(): + # sample only from a specific category if specified + if (not self.category) or (self.category == row['category']): + prompt = row["turns"][0] + self.data.append({"prompt": prompt}) + + random.seed(self.random_seed) + random.shuffle(self.data) + + def sample(self, **kwargs) -> list: + # leverage CustomDataset sample + kwargs["skip_chat_template"] = False + return super().sample(**kwargs) + + # ----------------------------------------------------------------------------- # Sonnet Dataset Implementation # ----------------------------------------------------------------------------- @@ -1576,6 +1767,7 @@ class SonnetDataset(BenchmarkDataset): output_len: int = DEFAULT_OUTPUT_LEN, return_prompt_formatted: bool = False, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs, ) -> list: # Calculate average token length for a poem line. 
@@ -1671,6 +1863,7 @@ class BurstGPTDataset(BenchmarkDataset): max_loras: Optional[int] = None, lora_path: Optional[str] = None, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs, ) -> list[SampleRequest]: samples = [] @@ -1678,8 +1871,8 @@ class BurstGPTDataset(BenchmarkDataset): for i in range(num_requests): input_len = int(data[i][2]) output_len = int(data[i][3]) - lora_req, tokenizer = self.get_random_lora_request( - tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) + lora_req = self.get_random_lora_request( + max_loras=max_loras, lora_path=lora_path) vocab_size = tokenizer.vocab_size # Generate a synthetic prompt: a list of token IDs computed as (i + # j) modulo vocab_size. @@ -1710,6 +1903,7 @@ class HuggingFaceDataset(BenchmarkDataset): dataset_split: str, no_stream: bool = False, dataset_subset: Optional[str] = None, + hf_name: Optional[str] = None, **kwargs, ) -> None: super().__init__(dataset_path=dataset_path, **kwargs) @@ -1717,6 +1911,7 @@ class HuggingFaceDataset(BenchmarkDataset): self.dataset_split = dataset_split self.dataset_subset = dataset_subset self.load_stream = not no_stream + self.hf_name = hf_name or dataset_path self.load_data() def load_data(self) -> None: @@ -1748,6 +1943,7 @@ class ConversationDataset(HuggingFaceDataset): output_len: Optional[int] = None, enable_multimodal_chat: bool = False, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs) -> list: # Filter examples with at least 2 conversations filtered_data = self.data.filter( @@ -1788,8 +1984,8 @@ class ConversationDataset(HuggingFaceDataset): request_id=request_id_prefix + str(ind), )) ind += 1 - self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix, no_oversample) return sampled_requests @@ -1819,6 +2015,7 @@ class VisionArenaDataset(HuggingFaceDataset): output_len: Optional[int] = None, enable_multimodal_chat: bool = False, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs, ) -> list: output_len = (output_len @@ -1827,10 +2024,9 @@ class VisionArenaDataset(HuggingFaceDataset): for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break - parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) + parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.hf_name) if parser_fn is None: - raise ValueError( - f"Unsupported dataset path: {self.dataset_path}") + raise ValueError(f"Unsupported dataset path: {self.hf_name}") prompt = parser_fn(item) mm_content = process_image(item["images"][0]) prompt_len = len(tokenizer(prompt).input_ids) @@ -1848,8 +2044,63 @@ class VisionArenaDataset(HuggingFaceDataset): multi_modal_data=mm_content, request_id=request_id_prefix + str(i), )) - self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix, no_oversample) + return sampled_requests + + +class MMVUDataset(HuggingFaceDataset): + """ + MMVU Dataset. 
+ https://huggingface.co/datasets/yale-nlp/MMVU + """ + + DEFAULT_OUTPUT_LEN = 128 + SUPPORTED_DATASET_PATHS = { + "yale-nlp/MMVU": + lambda x: x["question"] + " " + ( + " ".join(f"{k}.{v}" for k, v in x["choices"].items()) + ), + } + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + request_id_prefix: str = "", + no_oversample: bool = False, + **kwargs, + ) -> list: + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + sampled_requests = [] + for i, item in enumerate(self.data): + if len(sampled_requests) >= num_requests: + break + parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.hf_name) + if parser_fn is None: + raise ValueError(f"Unsupported dataset path: {self.hf_name}") + prompt = parser_fn(item) + mm_content = process_video(item["video"]) + prompt_len = len(tokenizer(prompt).input_ids) + if enable_multimodal_chat: + # Note: when chat is enabled the request prompt_len is no longer + # accurate and we will be using request output to count the + # actual prompt len + prompt = self.apply_multimodal_chat_transformation( + prompt, mm_content) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=mm_content, + request_id=request_id_prefix + str(i), + )) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix, no_oversample) return sampled_requests @@ -1879,6 +2130,7 @@ class InstructCoderDataset(HuggingFaceDataset): output_len: Optional[int] = None, enable_multimodal_chat: bool = False, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs) -> list: output_len = (output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN) @@ -1909,8 +2161,8 @@ class InstructCoderDataset(HuggingFaceDataset): expected_output_len=output_len, request_id=request_id_prefix + str(i), )) - self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix, no_oversample) return sampled_requests @@ -1941,6 +2193,7 @@ class MTBenchDataset(HuggingFaceDataset): output_len: Optional[int] = None, enable_multimodal_chat: bool = False, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs, ) -> list: output_len = (output_len @@ -1970,8 +2223,97 @@ class MTBenchDataset(HuggingFaceDataset): expected_output_len=output_len, request_id=request_id_prefix + str(i), )) - self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix, no_oversample) + return sampled_requests + + +# ----------------------------------------------------------------------------- +# Blazedit Dataset Implementation +# ----------------------------------------------------------------------------- + + +class BlazeditDataset(HuggingFaceDataset): + """ + Blazedit Dataset. 
+ https://github.com/ise-uiuc/blazedit + + 5k char version: vdaita/edit_5k_char + 10k char version: vdaita/edit_10k_char + """ # noqa: E501 + + # 5k char version will have output as ~5k chars + # 10k char version will have output as ~10k chars + # Assuming 3 char per token, 10k chars will be 3333 tokens + # We set default to 4000 to be safe + DEFAULT_OUTPUT_LEN = 4000 + SUPPORTED_DATASET_PATHS = { + "vdaita/edit_5k_char", + "vdaita/edit_10k_char", + } + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + request_id_prefix: str = "", + no_oversample: bool = False, + min_distance: float = 0.0, + max_distance: float = 1.0, + **kwargs, + ) -> list: + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + sampled_requests = [] + + for i, item in enumerate(self.data): + if len(sampled_requests) >= num_requests: + break + code = item["code"] + change_request = item["change_request"] + norm_distance = item["norm_distance"] + + # compare the levenshtein distance normalized by code length + if norm_distance < min_distance or norm_distance > max_distance: + continue + + # template copied from + # https://github.com/ise-uiuc/blazedit/blob/7765137e656fd62de877422d2e4cf8de51228054/dataset/create_refined_dataset.py#L94-L105 # noqa: E501 + instruction = f"""Given a code file, please apply the change requests and generate the new file. + +Original file: +```python +{code} +``` + +Change request: +{change_request} + +Please generate the new code file in the "New file" section below.""" # noqa: E501 + + # apply template + prompt = tokenizer.apply_chat_template( + [{ + "role": "user", + "content": instruction + }], + add_generation_prompt=True, + tokenize=False, + ) + + prompt_len = len(tokenizer(prompt).input_ids) + + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + request_id=request_id_prefix + str(i), + )) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix, no_oversample) + return sampled_requests @@ -1994,6 +2336,7 @@ class AIMODataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs) -> list: sampled_requests = [] ind = 0 @@ -2022,11 +2365,10 @@ class AIMODataset(HuggingFaceDataset): expected_output_len=output_len, multi_modal_data=None, request_id=request_id_prefix + str(ind), - )) ind += 1 self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix) + request_id_prefix, no_oversample) return sampled_requests @@ -2098,11 +2440,11 @@ class NextEditPredictionDataset(HuggingFaceDataset): def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs): - formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get( - self.dataset_path) + formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.hf_name) if formatting_prompt_func is None: - raise ValueError(f"Unsupported dataset path: {self.dataset_path}") + raise ValueError(f"Unsupported dataset path: {self.hf_name}") samples = [] for i, sample in enumerate(self.data): sample = formatting_prompt_func(sample) @@ -2116,7 +2458,10 @@ class NextEditPredictionDataset(HuggingFaceDataset): )) if len(samples) >= num_requests: break - self.maybe_oversample_requests(samples, num_requests, request_id_prefix) + self.maybe_oversample_requests(samples, + num_requests, + 
request_id_prefix, + no_oversample) return samples @@ -2167,6 +2512,7 @@ class ASRDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs, ) -> list: output_len = (output_len @@ -2204,8 +2550,8 @@ class ASRDataset(HuggingFaceDataset): " what Whisper supports.", skipped, ) - self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix, no_oversample) return sampled_requests @@ -2243,6 +2589,7 @@ class MLPerfDataset(HuggingFaceDataset): num_requests: int, output_len: Optional[int] = None, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs, ) -> list[SampleRequest]: # Force dynamic output length based on reference completion. @@ -2288,8 +2635,8 @@ class MLPerfDataset(HuggingFaceDataset): ) ind += 1 - self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix) + self.maybe_oversample_requests(sampled_requests, num_requests, + request_id_prefix, no_oversample) return sampled_requests @@ -2299,7 +2646,7 @@ class MLPerfDataset(HuggingFaceDataset): class PrefixRepetitionRandomDataset(BenchmarkDataset): - # Default values copied from benchmark_serving.py for the repeated prefix + # Default values copied from benchmark_serving.py for the repeated prefix # dataset. DEFAULT_PREFIX_LEN = 256 DEFAULT_SUFFIX_LEN = 256 @@ -2323,6 +2670,7 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset): num_prefixes: int = DEFAULT_NUM_PREFIXES, output_len: int = DEFAULT_OUTPUT_LEN, request_id_prefix: str = "", + no_oversample: bool = False, **kwargs, ) -> list[SampleRequest]: vocab_size = tokenizer.vocab_size diff --git a/vllm/benchmarks/lib/endpoint_request_func.py b/vllm/benchmarks/lib/endpoint_request_func.py index 6bb2a497119e9..066b8fe834380 100644 --- a/vllm/benchmarks/lib/endpoint_request_func.py +++ b/vllm/benchmarks/lib/endpoint_request_func.py @@ -17,6 +17,47 @@ from tqdm.asyncio import tqdm AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) +class StreamedResponseHandler: + """Handles streaming HTTP responses by accumulating chunks until complete + messages are available.""" + + def __init__(self): + self.buffer = "" + + def add_chunk(self, chunk_bytes: bytes) -> list[str]: + """Add a chunk of bytes to the buffer and return any complete + messages.""" + chunk_str = chunk_bytes.decode("utf-8") + self.buffer += chunk_str + + messages = [] + + # Split by double newlines (SSE message separator) + while "\n\n" in self.buffer: + message, self.buffer = self.buffer.split("\n\n", 1) + message = message.strip() + if message: + messages.append(message) + + # if self.buffer is not empty, check if it is a complete message + # by removing data: prefix and check if it is a valid JSON + if self.buffer.startswith("data: "): + message_content = self.buffer.removeprefix("data: ").strip() + if message_content == "[DONE]": + messages.append(self.buffer.strip()) + self.buffer = "" + elif message_content: + try: + json.loads(message_content) + messages.append(self.buffer.strip()) + self.buffer = "" + except json.JSONDecodeError: + # Incomplete JSON, wait for more chunks. 
+ pass + + return messages + + @dataclass class RequestFuncInput: """The input for the request function.""" @@ -27,6 +68,7 @@ class RequestFuncInput: model: str model_name: Optional[str] = None logprobs: Optional[int] = None + extra_headers: Optional[dict] = None extra_body: Optional[dict] = None multi_modal_content: Optional[Union[dict, list[dict]]] = None ignore_eos: bool = False @@ -47,6 +89,7 @@ class RequestFuncOutput: tpot: float = 0.0 # avg next-token latencies prompt_len: int = 0 error: str = "" + start_time: float = 0.0 async def async_request_openai_completions( @@ -88,6 +131,8 @@ async def async_request_openai_completions( headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" } + if request_func_input.extra_headers: + headers |= request_func_input.extra_headers if request_func_input.request_id: headers["x-request-id"] = request_func_input.request_id @@ -96,52 +141,57 @@ async def async_request_openai_completions( generated_text = "" st = time.perf_counter() + output.start_time = st most_recent_timestamp = st try: async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: first_chunk_received = False - async for chunk_bytes in response.content: + handler = StreamedResponseHandler() + + async for chunk_bytes in response.content.iter_any(): chunk_bytes = chunk_bytes.strip() if not chunk_bytes: continue - chunk_bytes = chunk_bytes.decode("utf-8") - # NOTE: SSE comments (often used as pings) start with - # a colon. These are not JSON data payload and should - # be skipped. - if chunk_bytes.startswith(":"): - continue - chunk = chunk_bytes.removeprefix("data: ") + messages = handler.add_chunk(chunk_bytes) + for message in messages: + # NOTE: SSE comments (often used as pings) start with + # a colon. These are not JSON data payload and should + # be skipped. + if message.startswith(":"): + continue - if chunk != "[DONE]": - data = json.loads(chunk) + chunk = message.removeprefix("data: ") - # NOTE: Some completion API might have a last - # usage summary response without a token so we - # want to check a token was generated - if choices := data.get("choices"): - # Note that text could be empty here - # e.g. for special tokens - text = choices[0].get("text") - timestamp = time.perf_counter() - # First token - if not first_chunk_received: - first_chunk_received = True - ttft = time.perf_counter() - st - output.ttft = ttft + if chunk != "[DONE]": + data = json.loads(chunk) - # Decoding phase - else: - output.itl.append(timestamp - - most_recent_timestamp) + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # want to check a token was generated + if choices := data.get("choices"): + # Note that text could be empty here + # e.g. 
for special tokens + text = choices[0].get("text") + timestamp = time.perf_counter() + # First token + if not first_chunk_received: + first_chunk_received = True + ttft = time.perf_counter() - st + output.ttft = ttft - most_recent_timestamp = timestamp - generated_text += text or "" - elif usage := data.get("usage"): - output.output_tokens = usage.get( - "completion_tokens") + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text += text or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") if first_chunk_received: output.success = True else: @@ -213,6 +263,8 @@ async def async_request_openai_chat_completions( "Content-Type": "application/json", "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", } + if request_func_input.extra_headers: + headers |= request_func_input.extra_headers if request_func_input.request_id: headers["x-request-id"] = request_func_input.request_id @@ -222,46 +274,50 @@ async def async_request_openai_chat_completions( generated_text = "" ttft = 0.0 st = time.perf_counter() + output.start_time = st most_recent_timestamp = st try: async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: - async for chunk_bytes in response.content: + handler = StreamedResponseHandler() + async for chunk_bytes in response.content.iter_any(): chunk_bytes = chunk_bytes.strip() if not chunk_bytes: continue - chunk_bytes = chunk_bytes.decode("utf-8") - # NOTE: SSE comments (often used as pings) start with - # a colon. These are not JSON data payload and should - # be skipped. - if chunk_bytes.startswith(":"): - continue - chunk = chunk_bytes.removeprefix("data: ") + messages = handler.add_chunk(chunk_bytes) + for message in messages: + # NOTE: SSE comments (often used as pings) start with + # a colon. These are not JSON data payload and should + # be skipped. 
+ if message.startswith(":"): + continue - if chunk != "[DONE]": - timestamp = time.perf_counter() - data = json.loads(chunk) + chunk = message.removeprefix("data: ") - if choices := data.get("choices"): - content = choices[0]["delta"].get("content") - # First token - if ttft == 0.0: - ttft = timestamp - st - output.ttft = ttft + if chunk != "[DONE]": + timestamp = time.perf_counter() + data = json.loads(chunk) - # Decoding phase - else: - output.itl.append(timestamp - - most_recent_timestamp) + if choices := data.get("choices"): + content = choices[0]["delta"].get("content") + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft - generated_text += content or "" - elif usage := data.get("usage"): - output.output_tokens = usage.get( - "completion_tokens") + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) - most_recent_timestamp = timestamp + generated_text += content or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") + + most_recent_timestamp = timestamp output.generated_text = generated_text output.success = True @@ -316,6 +372,8 @@ async def async_request_openai_audio( headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", } + if request_func_input.extra_headers: + headers |= request_func_input.extra_headers if request_func_input.request_id: headers["x-request-id"] = request_func_input.request_id @@ -341,42 +399,47 @@ async def async_request_openai_audio( generated_text = "" ttft = 0.0 st = time.perf_counter() + output.start_time = st most_recent_timestamp = st try: async with session.post(url=api_url, data=form, headers=headers) as response: if response.status == 200: - async for chunk_bytes in response.content: + handler = StreamedResponseHandler() + + async for chunk_bytes in response.content.iter_any(): chunk_bytes = chunk_bytes.strip() if not chunk_bytes: continue - chunk = chunk_bytes.decode("utf-8").removeprefix( - "data: ") - if chunk != "[DONE]": - timestamp = time.perf_counter() - data = json.loads(chunk) + messages = handler.add_chunk(chunk_bytes) + for message in messages: + chunk = message.removeprefix( + "data: ") + if chunk != "[DONE]": + timestamp = time.perf_counter() + data = json.loads(chunk) - if choices := data.get("choices"): - content = choices[0]["delta"].get( - "content") - # First token - if ttft == 0.0: - ttft = timestamp - st - output.ttft = ttft + if choices := data.get("choices"): + content = choices[0]["delta"].get( + "content") + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft - # Decoding phase - else: - output.itl.append( - timestamp - most_recent_timestamp) + # Decoding phase + else: + output.itl.append( + timestamp - most_recent_timestamp) - generated_text += content or "" - elif usage := data.get("usage"): - output.output_tokens = usage.get( - "completion_tokens") + generated_text += content or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") - most_recent_timestamp = timestamp + most_recent_timestamp = timestamp output.generated_text = generated_text output.success = True @@ -416,6 +479,7 @@ async def async_request_openai_embeddings( output = RequestFuncOutput() st = time.perf_counter() + output.start_time = st try: async with session.post( url=api_url, diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index abb838316cd31..7382782f11655 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -8,8 +8,8 @@ 
to launch the vLLM OpenAI API server: On the client side, run: vllm bench serve \ - --endpoint-type <endpoint_type. Default 'openai'> \ - --label <benchmark result label. Default using endpoint_type> \ + --backend <backend or endpoint type. Default 'openai'> \ + --label <benchmark result label. Default using backend> \ --model <your_model> \ --dataset-name <dataset_name. Default 'random'> \ --request-rate <request_rate. Default inf> \ @@ -18,9 +18,11 @@ On the client side, run: import argparse import asyncio import gc +import importlib.util import json import os import random +import shutil import time import warnings from collections.abc import AsyncGenerator, Iterable @@ -46,6 +48,24 @@ from vllm.transformers_utils.tokenizer import get_tokenizer MILLISECONDS_TO_SECONDS_CONVERSION = 1000 +TERM_PLOTLIB_AVAILABLE = ((importlib.util.find_spec("termplotlib") is not None) + and (shutil.which("gnuplot") is not None)) + + +# TODO: Remove this in v0.11.0 +class DeprecatedEndpointTypeAction(argparse.Action): + """Argparse action for the deprecated --endpoint-type flag. + """ + + def __call__(self, _, namespace, values, option_string=None): + warnings.warn( + "'--endpoint-type' is deprecated and will be removed in v0.11.0. " + "Please use '--backend' instead or remove this argument if you " + "have already set it.", + stacklevel=1, + ) + setattr(namespace, self.dest, values) + class TaskType(Enum): GENERATION = "generation" @@ -80,18 +100,23 @@ class BenchmarkMetrics: median_e2el_ms: float std_e2el_ms: float percentiles_e2el_ms: list[tuple[float, float]] + # Max output tokens per second and concurrent requests at that peak + max_output_tokens_per_s: float + max_concurrent_requests: int + @dataclass class EmbedBenchmarkMetrics: completed: int total_input: int request_throughput: float - total_token_throughput :float + total_token_throughput: float mean_e2el_ms: float std_e2el_ms: float median_e2el_ms: float percentiles_e2el_ms: float + def _get_current_request_rate( ramp_up_strategy: Optional[Literal["linear", "exponential"]], ramp_up_start_rps: Optional[int], @@ -139,7 +164,7 @@ async def get_request( A lower burstiness value (0 < burstiness < 1) results in more bursty requests, while a higher burstiness value (burstiness > 1) results in a more uniform arrival of requests. - ramp_up_strategy (optional): + ramp_up_strategy (optional): The ramp-up strategy. Can be "linear" or "exponential". If None, uses constant request rate (specified by request_rate). 
ramp_up_start_rps (optional): @@ -150,8 +175,8 @@ async def get_request( assert burstiness > 0, ( f"A positive burstiness factor is expected, but given {burstiness}.") # Convert to list to get length for ramp-up calculations - if isinstance(input_requests, Iterable) and not isinstance( - input_requests, list): + if isinstance(input_requests, + Iterable) and not isinstance(input_requests, list): input_requests = list(input_requests) total_requests = len(input_requests) @@ -161,12 +186,9 @@ async def get_request( request_rates = [] delay_ts = [] for request_index, request in enumerate(input_requests): - current_request_rate = _get_current_request_rate(ramp_up_strategy, - ramp_up_start_rps, - ramp_up_end_rps, - request_index, - total_requests, - request_rate) + current_request_rate = _get_current_request_rate( + ramp_up_strategy, ramp_up_start_rps, ramp_up_end_rps, + request_index, total_requests, request_rate) request_rates.append(current_request_rate) if current_request_rate == float("inf"): delay_ts.append(0) @@ -189,7 +211,7 @@ async def get_request( # NOTE: If we simply accumulate the random delta values # from the gamma distribution, their sum would have 1-2% gap # from target_total_delay_s. The purpose of the following logic is to - # close the gap for stablizing the throughput data + # close the gap for stabilizing the throughput data # from different random seeds. target_total_delay_s = total_requests / request_rate normalize_factor = target_total_delay_s / delay_ts[-1] @@ -206,10 +228,8 @@ async def get_request( def calculate_metrics_for_embeddings( - outputs: list[RequestFuncOutput], - dur_s: float, - selected_percentiles: list[float] -) -> EmbedBenchmarkMetrics: + outputs: list[RequestFuncOutput], dur_s: float, + selected_percentiles: list[float]) -> EmbedBenchmarkMetrics: """Calculate the metrics for the embedding requests. Args: @@ -242,10 +262,8 @@ def calculate_metrics_for_embeddings( mean_e2el_ms=np.mean(e2els or 0) * 1000, std_e2el_ms=np.std(e2els or 0) * 1000, median_e2el_ms=np.median(e2els or 0) * 1000, - percentiles_e2el_ms=[ - (p, np.percentile(e2els or 0, p) * 1000) - for p in selected_percentiles - ], + percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) + for p in selected_percentiles], ) return metrics @@ -336,6 +354,67 @@ def calculate_metrics( "All requests failed. 
This is likely due to a misconfiguration " "on the benchmark arguments.", stacklevel=2) + + # Calculate max output tokens per second metric + max_output_tokens_per_s = 0.0 + max_concurrent_requests = 0 + + # Find the time range across all successful requests + successful_outputs = [output for output in outputs if output.success] + if successful_outputs: + min_start_time = min(output.start_time + for output in successful_outputs) + max_end_time = max(output.start_time + output.latency + for output in successful_outputs) + + # Create second buckets (ceiling to ensure we capture all time) + duration_seconds = int(np.ceil(max_end_time - min_start_time)) + 1 + tokens_per_second = np.zeros(duration_seconds) + concurrent_requests_per_second = np.zeros(duration_seconds) + + for i, output in enumerate(successful_outputs): + # Calculate token generation timestamp using + # start_time, ttft, and itl + token_times = [output.start_time + output.ttft] + current_time = token_times[0] + for itl_value in output.itl: + current_time += itl_value + token_times.append(current_time) + + # Add tokens to second buckets + for token_time in token_times: + second_bucket = int(token_time - min_start_time) + if 0 <= second_bucket < duration_seconds: + tokens_per_second[second_bucket] += 1 + + # Track concurrent requests for each second this request was active + request_start_second = int(output.start_time - min_start_time) + request_end_second = int((output.start_time + output.latency) - + min_start_time) + + for second in range(request_start_second, request_end_second + 1): + concurrent_requests_per_second[second] += 1 + + # Find the maximum tokens per second and corresponding + # concurrent requests + if len(tokens_per_second) > 0: + max_output_tokens_per_s = float(np.max(tokens_per_second)) + max_concurrent_requests = int( + np.max(concurrent_requests_per_second)) + + if TERM_PLOTLIB_AVAILABLE: + import termplotlib as tpl + fig = tpl.figure() + fig.plot(np.arange(len(tokens_per_second)), + tokens_per_second, + title="Output tokens per second") + fig.plot(np.arange(len(concurrent_requests_per_second)), + concurrent_requests_per_second, + title="Concurrent requests per second") + fig.show() + else: + print("tip: install termplotlib and gnuplot to plot the metrics") + metrics = BenchmarkMetrics( completed=completed, total_input=total_input, @@ -365,6 +444,8 @@ def calculate_metrics( median_e2el_ms=np.median(e2els or 0) * 1000, percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles], + max_output_tokens_per_s=max_output_tokens_per_s, + max_concurrent_requests=max_concurrent_requests, ) return metrics, actual_output_lens @@ -389,24 +470,22 @@ async def benchmark( goodput_config_dict: dict[str, float], max_concurrency: Optional[int], lora_modules: Optional[Iterable[str]], + extra_headers: Optional[dict], extra_body: Optional[dict], ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None, ramp_up_start_rps: Optional[int] = None, ramp_up_end_rps: Optional[int] = None, ready_check_timeout_sec: int = 600, ): - task_type = ( - TaskType.EMBEDDING - if api_url.endswith("/v1/embeddings") - else TaskType.GENERATION - ) + task_type = (TaskType.EMBEDDING if api_url.endswith("/v1/embeddings") else + TaskType.GENERATION) if endpoint_type in ASYNC_REQUEST_FUNCS: if task_type == TaskType.EMBEDDING: request_func = ASYNC_REQUEST_FUNCS["openai-embeddings"] else: request_func = ASYNC_REQUEST_FUNCS[endpoint_type] else: - raise ValueError(f"Unknown endpoint_type: {endpoint_type}") + raise 
ValueError(f"Unknown backend: {endpoint_type}") # Reuses connections across requests to reduce TLS handshake overhead. connector = aiohttp.TCPConnector( @@ -434,14 +513,10 @@ async def benchmark( input_requests[0].multi_modal_data, ) - assert ( - test_mm_content is None - or isinstance(test_mm_content, dict) - or ( - isinstance(test_mm_content, list) - and all(isinstance(item, dict) for item in test_mm_content) - ) - ), "multi_modal_data must be a dict or list[dict]" + assert (test_mm_content is None or isinstance(test_mm_content, dict) + or (isinstance(test_mm_content, list) + and all(isinstance(item, dict) for item in test_mm_content)) + ), "multi_modal_data must be a dict or list[dict]" test_input = RequestFuncInput( model=model_id, model_name=model_name, @@ -452,6 +527,7 @@ async def benchmark( logprobs=logprobs, multi_modal_content=test_mm_content, ignore_eos=ignore_eos, + extra_headers=extra_headers, extra_body=extra_body, ) @@ -484,14 +560,15 @@ async def benchmark( logprobs=logprobs, multi_modal_content=test_mm_content, ignore_eos=ignore_eos, + extra_headers=extra_headers, extra_body=extra_body) - profile_output = await request_func( - request_func_input=profile_input, session=session) + profile_output = await request_func(request_func_input=profile_input, + session=session) if profile_output.success: print("Profiler started") - distribution = ("Poisson process" if burstiness == 1.0 - else "Gamma distribution") + distribution = ("Poisson process" + if burstiness == 1.0 else "Gamma distribution") if ramp_up_strategy is not None: print(f"Traffic ramp-up strategy: {ramp_up_strategy}.") @@ -559,17 +636,20 @@ async def benchmark( req_lora_module = next(lora_modules) req_model_id, req_model_name = req_lora_module, req_lora_module - request_func_input = RequestFuncInput(model=req_model_id, - model_name=req_model_name, - prompt=prompt, - api_url=api_url, - prompt_len=prompt_len, - output_len=output_len, - logprobs=logprobs, - multi_modal_content=mm_content, - ignore_eos=ignore_eos, - extra_body=extra_body, - request_id=request_id,) + request_func_input = RequestFuncInput( + model=req_model_id, + model_name=req_model_name, + prompt=prompt, + api_url=api_url, + prompt_len=prompt_len, + output_len=output_len, + logprobs=logprobs, + multi_modal_content=mm_content, + ignore_eos=ignore_eos, + extra_headers=extra_headers, + extra_body=extra_body, + request_id=request_id, + ) tasks.append( asyncio.create_task( limited_request_func(request_func_input=request_func_input, @@ -611,19 +691,21 @@ async def benchmark( benchmark_duration)) print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) if isinstance(metrics, BenchmarkMetrics): - print("{:<40} {:<10}".format( - "Total generated tokens:", metrics.total_output)) + print("{:<40} {:<10}".format("Total generated tokens:", + metrics.total_output)) print("{:<40} {:<10.2f}".format("Request throughput (req/s):", metrics.request_throughput)) if goodput_config_dict: print("{:<40} {:<10.2f}".format("Request goodput (req/s):", metrics.request_goodput)) if isinstance(metrics, BenchmarkMetrics): - print( - "{:<40} {:<10.2f}".format( - "Output token throughput (tok/s):", metrics.output_throughput - ) - ) + print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", + metrics.output_throughput)) + print("{:<40} {:<10.2f}".format( + "Peak output token throughput (tok/s):", + metrics.max_output_tokens_per_s)) + print("{:<40} {:<10.2f}".format("Peak concurrent requests:", + metrics.max_concurrent_requests)) print("{:<40} 
{:<10.2f}".format("Total Token throughput (tok/s):", metrics.total_token_throughput)) @@ -644,6 +726,8 @@ async def benchmark( "itls": [output.itl for output in outputs], "generated_texts": [output.generated_text for output in outputs], "errors": [output.error for output in outputs], + "max_output_tokens_per_s": metrics.max_output_tokens_per_s, + "max_concurrent_requests": metrics.max_concurrent_requests, } else: result = { @@ -693,8 +777,8 @@ async def benchmark( if task_type == TaskType.GENERATION: process_one_metric("ttft", "TTFT", "Time to First Token") - process_one_metric( - "tpot", "TPOT", "Time per Output Token (excl. 1st token)") + process_one_metric("tpot", "TPOT", + "Time per Output Token (excl. 1st token)") process_one_metric("itl", "ITL", "Inter-token Latency") process_one_metric("e2el", "E2EL", "End-to-end Latency") @@ -710,8 +794,8 @@ async def benchmark( output_len=test_output_len, logprobs=logprobs, ) - profile_output = await request_func( - request_func_input=profile_input, session=session) + profile_output = await request_func(request_func_input=profile_input, + session=session) if profile_output.success: print("Profiler stopped") @@ -781,24 +865,28 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace, def add_cli_args(parser: argparse.ArgumentParser): add_dataset_parser(parser) - parser.add_argument( - "--endpoint-type", - type=str, - default="openai", - choices=list(ASYNC_REQUEST_FUNCS.keys()), - ) parser.add_argument( "--label", type=str, default=None, help="The label (prefix) of the benchmark results. If not specified, " - "the endpoint type will be used as the label.", + "the value of '--backend' will be used as the label.", ) parser.add_argument( "--backend", type=str, - default="vllm", + default="openai", choices=list(ASYNC_REQUEST_FUNCS.keys()), + help="The type of backend or endpoint to use for the benchmark." + ) + parser.add_argument( + "--endpoint-type", + type=str, + default=None, + choices=list(ASYNC_REQUEST_FUNCS.keys()), + action=DeprecatedEndpointTypeAction, + help="'--endpoint-type' is deprecated and will be removed in v0.11.0. " + "Please use '--backend' instead.", ) parser.add_argument( "--base-url", @@ -815,6 +903,15 @@ def add_cli_args(parser: argparse.ArgumentParser): default="/v1/completions", help="API endpoint.", ) + parser.add_argument( + "--header", + metavar="KEY=VALUE", + nargs="*", + help="Key-value pairs (e.g, --header x-additional-info=0.3.3) " + "for headers to be passed with each request. These headers override " \ + "per backend constants and values set via environment variable, and " \ + "will be overriden by other arguments (such as request ids)." + ) parser.add_argument( "--max-concurrency", type=int, @@ -838,7 +935,8 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--tokenizer", type=str, - help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + help= + "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 ) parser.add_argument("--use-beam-search", action="store_true") parser.add_argument( @@ -969,7 +1067,6 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Specify the prefix of request id.", ) - sampling_group = parser.add_argument_group("sampling parameters") sampling_group.add_argument( "--top-p", @@ -1034,8 +1131,7 @@ def add_cli_args(parser: argparse.ArgumentParser): help="The ramp-up strategy. 
This would be used to " "ramp up the request rate from initial RPS to final " "RPS rate (specified by --ramp-up-start-rps and " - "--ramp-up-end-rps.) over the duration of the benchmark." - ) + "--ramp-up-end-rps.) over the duration of the benchmark.") parser.add_argument( "--ramp-up-start-rps", type=int, @@ -1074,13 +1170,11 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: raise ValueError( "When using ramp-up, do not specify --request-rate. " "The request rate will be controlled by ramp-up parameters. " - "Please remove the --request-rate argument." - ) + "Please remove the --request-rate argument.") if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None: raise ValueError( "When using --ramp-up-strategy, both --ramp-up-start-rps and " - "--ramp-up-end-rps must be specified" - ) + "--ramp-up-end-rps must be specified") if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0: raise ValueError("Ramp-up start and end RPS must be non-negative") if args.ramp_up_start_rps > args.ramp_up_end_rps: @@ -1090,7 +1184,6 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: raise ValueError( "For exponential ramp-up, the start RPS cannot be 0.") - endpoint_type = args.endpoint_type label = args.label model_id = args.model model_name = args.served_model_name @@ -1104,6 +1197,18 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: api_url = f"http://{args.host}:{args.port}{args.endpoint}" base_url = f"http://{args.host}:{args.port}" + # Headers + headers = None + if args.header: + headers = {} + for item in args.header: + if "=" in item: + kvstring = item.split("=", 1) + headers[kvstring[0].strip()] = kvstring[1].strip() + else: + raise ValueError( + "Invalid header format. Please use KEY=VALUE format.") + tokenizer = get_tokenizer(tokenizer_id, tokenizer_mode=tokenizer_mode, trust_remote_code=args.trust_remote_code) @@ -1141,7 +1246,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: gc.freeze() benchmark_result = await benchmark( - endpoint_type=args.endpoint_type, + endpoint_type=args.backend, api_url=api_url, base_url=base_url, model_id=model_id, @@ -1161,6 +1266,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: goodput_config_dict=goodput_config_dict, max_concurrency=args.max_concurrency, lora_modules=args.lora_modules, + extra_headers=headers, extra_body=sampling_params, ramp_up_strategy=args.ramp_up_strategy, ramp_up_start_rps=args.ramp_up_start_rps, @@ -1174,7 +1280,8 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: # Setup current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") result_json["date"] = current_dt - result_json["endpoint_type"] = args.endpoint_type + result_json["endpoint_type"] = args.backend # for backward compatibility + result_json["backend"] = args.backend result_json["label"] = label result_json["model_id"] = model_id result_json["tokenizer_id"] = tokenizer_id @@ -1184,12 +1291,11 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: if args.metadata: for item in args.metadata: if "=" in item: - kvstring = item.split("=") + kvstring = item.split("=", 1) result_json[kvstring[0].strip()] = kvstring[1].strip() else: raise ValueError( - "Invalid metadata format. Please use KEY=VALUE format." - ) + "Invalid metadata format. 
Please use KEY=VALUE format.") # Traffic result_json["request_rate"] = (args.request_rate if args.request_rate @@ -1225,7 +1331,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: base_model_id = model_id.split("/")[-1] max_concurrency_str = (f"-concurrency{args.max_concurrency}" if args.max_concurrency is not None else "") - label = label or endpoint_type + label = label or args.backend if args.ramp_up_strategy is not None: file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa else: diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py index f022a55e625f5..96e39fd92eba0 100644 --- a/vllm/benchmarks/throughput.py +++ b/vllm/benchmarks/throughput.py @@ -37,6 +37,7 @@ def run_vllm( requests: list[SampleRequest], n: int, engine_args: EngineArgs, + do_profile: bool, disable_detokenize: bool = False, ) -> tuple[float, Optional[list[RequestOutput]]]: from vllm import LLM, SamplingParams @@ -75,10 +76,14 @@ def run_vllm( outputs = None if not use_beam_search: start = time.perf_counter() + if do_profile: + llm.start_profile() outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests, use_tqdm=True) + if do_profile: + llm.stop_profile() end = time.perf_counter() else: assert lora_requests is None, "BeamSearch API does not support LoRA" @@ -88,6 +93,8 @@ def run_vllm( for request in requests: assert request.expected_output_len == output_len start = time.perf_counter() + if do_profile: + llm.start_profile() llm.beam_search( prompts, BeamSearchParams( @@ -95,6 +102,8 @@ def run_vllm( max_tokens=output_len, ignore_eos=True, )) + if do_profile: + llm.stop_profile() end = time.perf_counter() return end - start, outputs @@ -103,6 +112,7 @@ def run_vllm_chat( requests: list[SampleRequest], n: int, engine_args: EngineArgs, + do_profile: bool, disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]: """ Run vLLM chat benchmark. This function is recommended ONLY for benchmarking @@ -133,7 +143,11 @@ def run_vllm_chat( detokenize=not disable_detokenize, )) start = time.perf_counter() + if do_profile: + llm.start_profile() outputs = llm.chat(prompts, sampling_params, use_tqdm=True) + if do_profile: + llm.stop_profile() end = time.perf_counter() return end - start, outputs @@ -142,6 +156,7 @@ async def run_vllm_async( requests: list[SampleRequest], n: int, engine_args: AsyncEngineArgs, + do_profile: bool, disable_frontend_multiprocessing: bool = False, disable_detokenize: bool = False, ) -> float: @@ -185,6 +200,8 @@ async def run_vllm_async( generators = [] start = time.perf_counter() + if do_profile: + await llm.start_profile() for i, (prompt, sp, lr) in enumerate(zip(prompts, sampling_params, lora_requests)): generator = llm.generate(prompt, @@ -195,6 +212,8 @@ async def run_vllm_async( all_gens = merge_async_iterators(*generators) async for i, res in all_gens: pass + if do_profile: + await llm.stop_profile() end = time.perf_counter() return end - start @@ -543,6 +562,12 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, default=None, help="Split of the HF dataset.") + parser.add_argument( + "--profile", + action="store_true", + default=False, + help="Use Torch Profiler. 
The env variable " + "VLLM_TORCH_PROFILER_DIR must be set to enable profiler.") # prefix repetition dataset prefix_repetition_group = parser.add_argument_group( @@ -600,22 +625,27 @@ def main(args: argparse.Namespace): requests, args.n, AsyncEngineArgs.from_cli_args(args), - args.disable_frontend_multiprocessing, - args.disable_detokenize, + disable_frontend_multiprocessing=args.disable_frontend_multiprocessing, + disable_detokenize=args.disable_detokenize, + do_profile=args.profile, )) else: elapsed_time, request_outputs = run_vllm( requests, args.n, EngineArgs.from_cli_args(args), - args.disable_detokenize) + disable_detokenize=args.disable_detokenize, + do_profile=args.profile) elif args.backend == "hf": assert args.tensor_parallel_size == 1 + if args.profile: + raise NotImplementedError( + "Profiling not implemented yet for backend='hf'.") elapsed_time = run_hf(requests, args.model, tokenizer, args.n, args.hf_max_batch_size, args.trust_remote_code, args.disable_detokenize) elif args.backend == "vllm-chat": elapsed_time, request_outputs = run_vllm_chat( requests, args.n, EngineArgs.from_cli_args(args), - args.disable_detokenize) + disable_detokenize=args.disable_detokenize, do_profile=args.profile) else: raise ValueError(f"Unknown backend: {args.backend}") diff --git a/vllm/collect_env.py b/vllm/collect_env.py index ee43ad12e8a5e..fb9d3657790cf 100644 --- a/vllm/collect_env.py +++ b/vllm/collect_env.py @@ -54,7 +54,6 @@ SystemEnv = namedtuple( 'is_xnnpack_available', 'cpu_info', 'rocm_version', # vllm specific field - 'neuron_sdk_version', # vllm specific field 'vllm_version', # vllm specific field 'vllm_build_flags', # vllm specific field 'gpu_topo', # vllm specific field @@ -75,6 +74,7 @@ DEFAULT_CONDA_PATTERNS = { "zmq", "nvidia", "pynvml", + "flashinfer-python", } DEFAULT_PIP_PATTERNS = { @@ -90,6 +90,7 @@ DEFAULT_PIP_PATTERNS = { "zmq", "nvidia", "pynvml", + "flashinfer-python", } @@ -275,15 +276,6 @@ def get_rocm_version(run_lambda): r'HIP version: (\S+)') -def get_neuron_sdk_version(run_lambda): - # Adapted from your install script - try: - result = run_lambda(["neuron-ls"]) - return result if result[0] == 0 else 'N/A' - except Exception: - return 'N/A' - - def get_vllm_version(): from vllm import __version__, __version_tuple__ @@ -306,10 +298,9 @@ def get_vllm_version(): def summarize_vllm_build_flags(): # This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc. - return 'CUDA Archs: {}; ROCm: {}; Neuron: {}'.format( + return 'CUDA Archs: {}; ROCm: {}'.format( os.environ.get('TORCH_CUDA_ARCH_LIST', 'Not Set'), 'Enabled' if os.environ.get('ROCM_HOME') else 'Disabled', - 'Enabled' if os.environ.get('NEURON_CORES') else 'Disabled', ) @@ -498,6 +489,16 @@ def get_libc_version(): return '-'.join(platform.libc_ver()) +def is_uv_venv(): + if os.environ.get("UV"): + return True + pyvenv_cfg_path = os.path.join(sys.prefix, 'pyvenv.cfg') + if os.path.exists(pyvenv_cfg_path): + with open(pyvenv_cfg_path, 'r') as f: + return any(line.startswith('uv = ') for line in f) + return False + + def get_pip_packages(run_lambda, patterns=None): """Return `pip list` output. 
Note: will also find conda-installed pytorch and numpy packages.""" if patterns is None: @@ -513,7 +514,7 @@ def get_pip_packages(run_lambda, patterns=None): if pip_available: cmd = [sys.executable, '-mpip', 'list', '--format=freeze'] - elif os.environ.get("UV") is not None: + elif is_uv_venv(): print("uv is set") cmd = ["uv", "pip", "list", "--format=freeze"] else: @@ -601,7 +602,6 @@ def get_env_info(): conda_packages = get_conda_packages(run_lambda) rocm_version = get_rocm_version(run_lambda) - neuron_sdk_version = get_neuron_sdk_version(run_lambda) vllm_version = get_vllm_version() vllm_build_flags = summarize_vllm_build_flags() gpu_topo = get_gpu_topo(run_lambda) @@ -635,7 +635,6 @@ def get_env_info(): is_xnnpack_available=is_xnnpack_available(), cpu_info=get_cpu_info(run_lambda), rocm_version=rocm_version, - neuron_sdk_version=neuron_sdk_version, vllm_version=vllm_version, vllm_build_flags=vllm_build_flags, gpu_topo=gpu_topo, @@ -702,7 +701,6 @@ env_info_fmt += """ vLLM Info ============================== ROCM Version : {rocm_version} -Neuron SDK Version : {neuron_sdk_version} vLLM Version : {vllm_version} vLLM Build Flags: {vllm_build_flags} diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 3361b65a9b885..3cc0fc3106f5a 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -454,11 +454,12 @@ class VllmBackend: inductor_config = config.inductor_compile_config PASS_KEY = "post_grad_custom_post_pass" if PASS_KEY in inductor_config: - # Config should automatically wrap all inductor passes if isinstance(inductor_config[PASS_KEY], PostGradPassManager): + # PassManager already added to config, make sure it's correct assert (inductor_config[PASS_KEY].uuid() == self.post_grad_pass_manager.uuid()) else: + # Config should automatically wrap all inductor passes assert isinstance(inductor_config[PASS_KEY], InductorPass) self.post_grad_pass_manager.add(inductor_config[PASS_KEY]) inductor_config[PASS_KEY] = self.post_grad_pass_manager diff --git a/vllm/compilation/base_static_graph.py b/vllm/compilation/base_static_graph.py index 161d066ce9fb8..6ee82e74963d9 100644 --- a/vllm/compilation/base_static_graph.py +++ b/vllm/compilation/base_static_graph.py @@ -12,8 +12,13 @@ class AbstractStaticGraphWrapper(Protocol): to be captured as a static graph. """ - def __init__(self, runnable: Callable, vllm_config: VllmConfig, - runtime_mode: CUDAGraphMode, **kwargs): + def __init__( + self, + runnable: Callable[..., Any], + vllm_config: VllmConfig, + runtime_mode: CUDAGraphMode, + **kwargs: Any, + ) -> None: """ Initializes the StaticGraphWrapper class with graph capturing and execution-related configurations. @@ -31,7 +36,7 @@ class AbstractStaticGraphWrapper(Protocol): """ raise NotImplementedError - def __call__(self, *args, **kwargs) -> Any: + def __call__(self, *args: Any, **kwargs: Any) -> Any: """ Executes the wrapped callable. 
diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 7a99aaff707dc..71274420c3426 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -513,7 +513,7 @@ if flashinfer_comm is not None: torch.ops._C.static_scaled_fp8_quant( quant_out, norm_out, scale_factor) if scale_factor is None or norm_out is not None: - # we need to return allreduce outpput + # we need to return allreduce output # in cases of non quant fused AR + RMS norm # and fused AR + RMS norm + quant without fused add allreduce_in.copy_(allreduce_out) diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index 3095f17110fde..e3677b3dd62d8 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -39,6 +39,7 @@ class AttentionQuantPattern(ABC): self, layer: Attention, quant_key: QuantKey, + dtype: torch.dtype, ): self.layer = layer self.layer_name = layer.layer_name @@ -46,11 +47,16 @@ class AttentionQuantPattern(ABC): self.head_size = layer.head_size self.quant_key = quant_key self.quant_dtype = quant_key.dtype + self.dtype = dtype assert self.quant_key in QUANT_OPS, \ f"unsupported quantization scheme {self.quant_key}" self.QUANT_OP = QUANT_OPS[self.quant_key] + def empty(self, *args, **kwargs): + kwargs = {'dtype': self.dtype, 'device': "cuda", **kwargs} + return torch.empty(*args, **kwargs) + def empty_quant(self, *args, **kwargs): kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs} return torch.empty(*args, **kwargs) @@ -91,12 +97,13 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): def __init__( self, layer: Attention, + dtype: torch.dtype, symmetric: bool = True, ): quant_key = QuantKey(dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric) - super().__init__(layer, quant_key) + super().__init__(layer, quant_key, dtype) def _register(self, pm_pass: PatternMatcherPass): @@ -139,10 +146,14 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) inputs = [ - empty_bf16(5, self.num_heads, self.head_size), # q - empty_bf16(5, self.num_heads, self.head_size), # k - empty_bf16(5, self.num_heads, self.head_size), # v - empty_bf16(5, self.num_heads, self.head_size), # attn_output + self.empty(5, self.num_heads, self.head_size, + dtype=self.dtype), # q + self.empty(5, self.num_heads, self.head_size, + dtype=self.dtype), # k + self.empty(5, self.num_heads, self.head_size, + dtype=self.dtype), # v + self.empty(5, self.num_heads, self.head_size, + dtype=self.dtype), # attn_output self.empty_quant(5, self.num_heads * self.head_size), # quant_output empty_fp32(1, 1) # scale @@ -165,8 +176,8 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): will be passed into Attention op as the `output_scale` argument. 
""" - def __init__(self, layer: Attention): - super().__init__(layer, kNvfp4Quant) + def __init__(self, layer: Attention, dtype: torch.dtype): + super().__init__(layer, kNvfp4Quant, dtype) def _register(self, pm_pass: PatternMatcherPass): @@ -255,11 +266,15 @@ class AttnFusionPass(VllmInductorPass): attn_layers = get_layers_from_vllm_config(config, Attention) for layer_name, layer in attn_layers.items(): - pattern_fp8 = AttentionFp8StaticQuantPattern(layer) + pattern_fp8 = AttentionFp8StaticQuantPattern( + layer, config.model_config.dtype) pattern_fp8.register_if_supported(self.patterns) - pattern_nvfp4 = AttentionNvfp4QuantPattern(layer) - pattern_nvfp4.register_if_supported(self.patterns) + if current_platform.is_cuda() and hasattr(torch.ops._C, + "scaled_fp4_quant"): + pattern_nvfp4 = AttentionNvfp4QuantPattern( + layer, config.model_config.dtype) + pattern_nvfp4.register_if_supported(self.patterns) if len(attn_layers) == 0: logger.warning( diff --git a/vllm/compilation/noop_elimination.py b/vllm/compilation/noop_elimination.py index 4888d4d1298e3..17e85e70218da 100644 --- a/vllm/compilation/noop_elimination.py +++ b/vllm/compilation/noop_elimination.py @@ -62,9 +62,6 @@ class NoOpEliminationPass(VllmInductorPass): scaled_mm: "f16[s0, 4096]" = ... at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...) out: "f16[s0, 4096]" = at[1] - - TODO(luka): This is currently tested in test_fusion, - but separate tests could be good. """ def __call__(self, graph: torch.fx.Graph): @@ -96,17 +93,19 @@ class NoOpEliminationPass(VllmInductorPass): # Invalid reshape args, skip continue - if self.all_dims_equivalent(shape, input_shape): + if self.reshape_all_dims_equivalent(shape, input_shape): node.replace_all_uses_with(input) graph.erase_node(node) count += 1 elif is_func(node, torch.ops.aten.slice.Tensor): + # python slicing semantics are different from reshape + # Don't treat -1 as inferred dimension input, dim_index, start, end = node.args[:4] input_shape = input.meta["val"].shape - i_dim = input_shape[dim_index] + output_shape = node.meta["val"].shape - if start == 0 and self.dims_equivalent(end, i_dim): + if output_shape == input_shape: node.replace_all_uses_with(input) graph.erase_node(node) count += 1 @@ -116,14 +115,7 @@ class NoOpEliminationPass(VllmInductorPass): base_shape = base.meta["val"].shape view_shape = view.meta["val"].shape - view_dim = view_shape[dim_index] - - # Check that view fully covers base and the full view is used - # (if the view fully covered the base after slicing but was not - # fully used, we could replace slice_scatter with a simple slice - # but that's a niche case). - if (base_shape == view_shape and start == 0 - and self.dims_equivalent(end, view_dim)): + if base_shape == view_shape: node.replace_all_uses_with(view) graph.erase_node(node) count += 1 @@ -132,13 +124,9 @@ class NoOpEliminationPass(VllmInductorPass): self.dump_graph(graph, "after_noop_elimination") self.end_and_log() - def all_dims_equivalent(self, dims: Iterable[Union[int, torch.fx.Node]], - i_dims: Iterable[Union[int, SymInt]]): - return all( - self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims)) - - def dims_equivalent(self, dim: Union[int, torch.fx.Node], - i_dim: Union[int, SymInt]) -> bool: + # ---------------------- Reshape helpers ---------------------- + def reshape_dims_equivalent(self, dim: Union[int, torch.fx.Node], + i_dim: Union[int, SymInt]) -> bool: """ This function checks if two dimensions are equivalent. 
:param dim: The dimension arg to reshape/slice @@ -156,10 +144,18 @@ class NoOpEliminationPass(VllmInductorPass): In case 3, the reshape dimension is a torch.fx.Node, and its value is a SymInt. That value is equal to the input dimension. - """ # Case 1 and 2 if dim == i_dim or dim == -1: return True # Case 3 return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim + + def reshape_all_dims_equivalent( + self, + dims: Iterable[Union[int, torch.fx.Node]], + i_dims: Iterable[Union[int, SymInt]], + ) -> bool: + return all( + self.reshape_dims_equivalent(s, i_s) + for s, i_s in zip(dims, i_dims)) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index f53e8b0308853..25daca00c02d9 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -8,16 +8,15 @@ import enum import hashlib import inspect import json +import os import textwrap -import uuid import warnings -from collections.abc import Mapping from contextlib import contextmanager -from dataclasses import MISSING, Field, field, fields, is_dataclass, replace +from dataclasses import InitVar, field, fields, is_dataclass, replace from functools import cached_property, lru_cache from importlib.util import find_spec -from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional, - Protocol, TypeVar, Union, cast, get_args) +from typing import (TYPE_CHECKING, Any, Callable, Literal, Optional, Protocol, + TypeVar, Union, cast, get_args) import regex as re import torch @@ -25,7 +24,7 @@ from pydantic import (ConfigDict, SkipValidation, field_validator, model_validator) from pydantic.dataclasses import dataclass from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE -from typing_extensions import Self, assert_never, runtime_checkable +from typing_extensions import assert_never, runtime_checkable import vllm.envs as envs from vllm import version @@ -33,12 +32,21 @@ from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType, PrefixCachingHashAlgo) from vllm.config.compilation import (CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig) +from vllm.config.kv_events import KVEventsConfig +from vllm.config.kv_transfer import KVTransferConfig +from vllm.config.load import LoadConfig +from vllm.config.lora import LoRAConfig +from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode, + MultiModalConfig) from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig, ParallelConfig) from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy +from vllm.config.speculative import SpeculativeConfig +from vllm.config.structured_outputs import StructuredOutputsConfig from vllm.config.utils import ConfigType, config from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.platforms import current_platform from vllm.transformers_utils.config import ( ConfigFormat, get_config, get_hf_image_processor_config, @@ -47,9 +55,11 @@ from vllm.transformers_utils.config import ( is_interleaved, maybe_override_with_speculators_target_model, try_get_generation_config, try_get_safetensors_metadata, try_get_tokenizer_config, uses_mrope) -from vllm.transformers_utils.s3_utils import S3Model -from vllm.transformers_utils.utils import is_s3, maybe_model_redirect -from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, LayerBlockType, +from vllm.transformers_utils.runai_utils import (ObjectStorageModel, + is_runai_obj_uri) +from vllm.transformers_utils.utils 
import maybe_model_redirect +from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, + STR_DUAL_CHUNK_FLASH_ATTN_VAL, LayerBlockType, LazyLoader, common_broadcastable_dtype, random_uuid) if TYPE_CHECKING: @@ -61,8 +71,6 @@ if TYPE_CHECKING: from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) - from vllm.model_executor.model_loader import LoadFormats - from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.v1.sample.logits_processor import LogitsProcessor HfOverrides = Union[dict, Callable[[type], type]] @@ -72,8 +80,6 @@ else: QuantizationConfig = Any QuantizationMethods = Any BaseModelLoader = Any - LoadFormats = Any - TensorizerConfig = Any LogitsProcessor = Any HfOverrides = Union[dict[str, Any], Callable[[type], type]] @@ -170,6 +176,7 @@ class ModelImpl(str, enum.Enum): AUTO = "auto" VLLM = "vllm" TRANSFORMERS = "transformers" + TERRATORCH = "terratorch" def get_attr_docs(cls: type[Any]) -> dict[str, str]: @@ -234,30 +241,12 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]: return out -def get_field(cls: ConfigType, name: str) -> Field: - """Get the default factory field of a dataclass by name. Used for getting - default factory fields in `EngineArgs`.""" - if not is_dataclass(cls): - raise TypeError("The given class is not a dataclass.") - cls_fields = {f.name: f for f in fields(cls)} - if name not in cls_fields: - raise ValueError(f"Field '{name}' not found in {cls.__name__}.") - named_field: Field = cls_fields[name] - if (default_factory := named_field.default_factory) is not MISSING: - return field(default_factory=default_factory) - if (default := named_field.default) is not MISSING: - return field(default=default) - raise ValueError( - f"{cls.__name__}.{name} must have a default value or default factory.") - - def is_init_field(cls: ConfigType, name: str) -> bool: return next(f for f in fields(cls) if f.name == name).init TokenizerMode = Literal["auto", "slow", "mistral", "custom"] ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] -MMEncoderTPMode = Literal["weights", "data"] class LogprobsMode(enum.Enum): @@ -402,23 +391,9 @@ class ModelConfig: that this name(s) will also be used in `model_name` tag content of prometheus metrics, if multiple names provided, metrics tag will take the first one.""" - limit_mm_per_prompt: dict[str, int] = field(default_factory=dict) - """Maximum number of data items per modality per prompt. Only applicable - for multimodal models.""" - interleave_mm_strings: bool = False - """Enable fully interleaved support for multimodal prompts, while using - --chat-template-content-format=string. Defaults to False.""" - skip_mm_profiling: bool = False - """When enabled, skips multimodal memory profiling and only profiles with - language backbone model during engine initialization. - """ - media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) - """Additional args passed to process media inputs, keyed by modalities. 
- For example, to set num_frames for video, set - `--media-io-kwargs '{"video": {"num_frames": 40} }'` """ use_async_output_proc: bool = True """Whether to use async output processor.""" - config_format: Union[str, ConfigFormat] = ConfigFormat.AUTO.value + config_format: Union[str, ConfigFormat] = "auto" """The format of the model config to load:\n - "auto" will try to load the config in hf format if available else it will try to load in mistral format.\n @@ -431,39 +406,6 @@ class ModelConfig: hf_overrides: HfOverrides = field(default_factory=dict) """If a dictionary, contains arguments to be forwarded to the Hugging Face config. If a callable, it is called to update the HuggingFace config.""" - mm_processor_kwargs: Optional[dict[str, Any]] = None - """Arguments to be forwarded to the model's processor for multi-modal data, - e.g., image processor. Overrides for the multi-modal processor obtained - from `AutoProcessor.from_pretrained`. The available overrides depend on the - model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`. - """ - mm_processor_cache_gb: float = 4 - """The size (in GiB) of the multi-modal processor cache, which is used to - avoid re-processing past multi-modal inputs. - - This cache is duplicated for each API process and engine core process, - resulting in a total memory usage of - `mm_processor_cache_gb * (api_server_count + data_parallel_size)`. - - Set to `0` to disable this cache completely (not recommended).""" - mm_encoder_tp_mode: MMEncoderTPMode = "weights" - """Indicates how to optimize multi-modal encoder inference using - tensor parallelism (TP). - - - `"weights"`: Within the same vLLM engine, split the weights of - each layer across TP ranks. (default TP behavior) - - `"data"`: Within the same vLLM engine, split the batched input data - across TP ranks to process the data in parallel, while hosting - the full weights on each TP rank. - This batch-level DP is not to be confused with API request-level - DP (which is controlled by `--data-parallel-size`). - This is only supported on a per-model basis and falls back to - `"weights"` if the encoder does not support DP.""" - override_neuron_config: dict[str, Any] = field(default_factory=dict) - """Initialize non-default neuron config or override default neuron config - that are specific to Neuron devices, this argument will be used to - configure the neuron config that can not be gathered from the vllm - arguments. e.g. `{"cast_logits_dtype": "bfloat16"}`.""" pooler_config: Optional["PoolerConfig"] = field(init=False) """Pooler config which controls the behaviour of output pooling in pooling models.""" @@ -495,7 +437,9 @@ class ModelConfig: back to the Transformers implementation if no vLLM implementation is available.\n - "vllm" will use the vLLM model implementation.\n - - "transformers" will use the Transformers model implementation.""" + - "transformers" will use the Transformers model implementation.\n + - "terratorch" will use the TerraTorch model implementation. + """ override_attention_dtype: Optional[str] = None """Override dtype for attention""" logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None @@ -504,6 +448,20 @@ class ModelConfig: io_processor_plugin: Optional[str] = None """IOProcessor plugin name to load at model startup""" + # Multimodal config and init vars + multimodal_config: Optional[MultiModalConfig] = None + """Configuration for multimodal model. 
If `None`, this will be inferred + from the architecture of `self.model`.""" + limit_mm_per_prompt: InitVar[Optional[dict[str, int]]] = None + media_io_kwargs: InitVar[Optional[dict[str, dict[str, Any]]]] = None + mm_processor_kwargs: InitVar[Optional[dict[str, Any]]] = None + mm_processor_cache_gb: InitVar[Optional[float]] = None + mm_processor_cache_type: InitVar[Optional[MMCacheType]] = None + mm_shm_cache_max_object_size_mb: InitVar[Optional[int]] = None + mm_encoder_tp_mode: InitVar[Optional[MMEncoderTPMode]] = None + interleave_mm_strings: InitVar[Optional[bool]] = None + skip_mm_profiling: InitVar[Optional[bool]] = None + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -537,7 +495,18 @@ class ModelConfig: assert_hashable(str_factors) return hashlib.sha256(str(factors).encode()).hexdigest() - def __post_init__(self) -> None: + def __post_init__( + self, + # Multimodal config init vars + limit_mm_per_prompt: Optional[dict[str, int]], + media_io_kwargs: Optional[dict[str, dict[str, Any]]], + mm_processor_kwargs: Optional[dict[str, Any]], + mm_processor_cache_gb: Optional[float], + mm_processor_cache_type: Optional[MMCacheType], + mm_shm_cache_max_object_size_mb: Optional[int], + mm_encoder_tp_mode: Optional[MMEncoderTPMode], + interleave_mm_strings: Optional[bool], + skip_mm_profiling: Optional[bool]) -> None: # Set the default seed to 0 in V1. # NOTE(woosuk): In V0, we set the default seed to None because the # driver worker shares the same process as the user process, and thus @@ -556,15 +525,6 @@ class ModelConfig: "affect the random state of the Python process that " "launched vLLM.", self.seed) - if self.runner != "draft": - # If we're not running the draft model, check for speculators config - # If speculators config, set model / tokenizer to be target model - self.model, self.tokenizer = maybe_override_with_speculators_target_model( # noqa: E501 - model=self.model, - tokenizer=self.tokenizer, - revision=self.revision, - trust_remote_code=self.trust_remote_code) - # Keep set served_model_name before maybe_model_redirect(self.model) self.served_model_name = get_served_model_name(self.model, self.served_model_name) @@ -603,7 +563,16 @@ class ModelConfig: f"'Please instead use `--hf-overrides '{hf_overrides_str}'`") warnings.warn(DeprecationWarning(msg), stacklevel=2) - self.maybe_pull_model_tokenizer_for_s3(self.model, self.tokenizer) + self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer) + + if self.runner != "draft": + # If we're not running the draft model, check for speculators config + # If speculators config, set model / tokenizer to be target model + self.model, self.tokenizer = maybe_override_with_speculators_target_model( # noqa: E501 + model=self.model, + tokenizer=self.tokenizer, + revision=self.revision, + trust_remote_code=self.trust_remote_code) if (backend := envs.VLLM_ATTENTION_BACKEND ) and backend == "FLASHINFER" and find_spec("flashinfer") is None: @@ -626,9 +595,6 @@ class ModelConfig: raise ValueError( "Sleep mode is not supported on current platform.") - if isinstance(self.config_format, str): - self.config_format = ConfigFormat(self.config_format) - hf_config = get_config(self.hf_config_path or self.model, self.trust_remote_code, self.revision, @@ -745,7 +711,7 @@ class ModelConfig: self.pooler_config = self._init_pooler_config() - self.dtype = _get_and_verify_dtype( + self.dtype: torch.dtype = _get_and_verify_dtype( self.model, self.hf_config, self.dtype, @@ -771,7 +737,33 @@ class ModelConfig: 
self.original_max_model_len = self.max_model_len self.max_model_len = self.get_and_verify_max_len(self.max_model_len) - self.multimodal_config = self._init_multimodal_config() + # Init multimodal config if needed + if self._model_info.supports_multimodal: + if (mm_encoder_tp_mode == "data" and + not self._model_info.supports_multimodal_encoder_tp_data): + logger.warning_once( + "This model does not support `--mm-encoder-tp-mode data`. " + "Falling back to `--mm-encoder-tp-mode weights`.") + mm_encoder_tp_mode = "weights" + + mm_config_kwargs = dict( + limit_per_prompt=limit_mm_per_prompt, + media_io_kwargs=media_io_kwargs, + mm_processor_kwargs=mm_processor_kwargs, + mm_processor_cache_gb=mm_processor_cache_gb, + mm_processor_cache_type=mm_processor_cache_type, + mm_shm_cache_max_object_size_mb=mm_shm_cache_max_object_size_mb, + mm_encoder_tp_mode=mm_encoder_tp_mode, + interleave_mm_strings=interleave_mm_strings, + skip_mm_profiling=skip_mm_profiling, + ) + + mm_config_kwargs = { + k: v + for k, v in mm_config_kwargs.items() if v is not None + } + + self.multimodal_config = MultiModalConfig(**mm_config_kwargs) if self.disable_sliding_window: # Set after get_and_verify_max_len to ensure that max_model_len @@ -781,10 +773,6 @@ class ModelConfig: if not self.skip_tokenizer_init: self._verify_tokenizer_mode() - if (not current_platform.is_neuron() and self.override_neuron_config): - raise ValueError( - "`override_neuron_config` is only supported on Neuron.") - # Avoid running try_verify_and_update_config multiple times self.config_updated = False @@ -836,62 +824,46 @@ class ModelConfig: """The architecture vllm actually used.""" return self._architecture - def maybe_pull_model_tokenizer_for_s3(self, model: str, - tokenizer: str) -> None: - """Pull model/tokenizer from S3 to temporary directory when needed. + def maybe_pull_model_tokenizer_for_runai(self, model: str, + tokenizer: str) -> None: + """Pull model/tokenizer from Object Storage to temporary + directory when needed. Args: model: Model name or path tokenizer: Tokenizer name or path """ - if not (is_s3(model) or is_s3(tokenizer)): + if not (is_runai_obj_uri(model) or is_runai_obj_uri(tokenizer)): return - if is_s3(model): - s3_model = S3Model() - s3_model.pull_files(model, - allow_pattern=["*.model", "*.py", "*.json"]) + if is_runai_obj_uri(model): + object_storage_model = ObjectStorageModel() + object_storage_model.pull_files( + model, allow_pattern=["*.model", "*.py", "*.json"]) self.model_weights = model - self.model = s3_model.dir + self.model = object_storage_model.dir # If tokenizer is same as model, download to same directory if model == tokenizer: - s3_model.pull_files(model, - ignore_pattern=[ - "*.pt", "*.safetensors", "*.bin", - "*.tensors" - ]) - self.tokenizer = s3_model.dir + object_storage_model.pull_files(model, + ignore_pattern=[ + "*.pt", "*.safetensors", + "*.bin", "*.tensors", + "*.pth" + ]) + self.tokenizer = object_storage_model.dir return # Only download tokenizer if needed and not already handled - if is_s3(tokenizer): - s3_tokenizer = S3Model() - s3_tokenizer.pull_files( - model, - ignore_pattern=["*.pt", "*.safetensors", "*.bin", "*.tensors"]) - self.tokenizer = s3_tokenizer.dir - - def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: - if self._model_info.supports_multimodal: - if (self.mm_encoder_tp_mode == "data" and - not self._model_info.supports_multimodal_encoder_tp_data): - logger.warning_once( - "This model does not support `--mm-encoder-tp-mode data`. 
" - "Falling back to `--mm-encoder-tp-mode weights`.") - self.mm_encoder_tp_mode = "weights" - - return MultiModalConfig( - limit_per_prompt=self.limit_mm_per_prompt, - media_io_kwargs=self.media_io_kwargs, - mm_processor_kwargs=self.mm_processor_kwargs, - mm_processor_cache_gb=self.mm_processor_cache_gb, - mm_encoder_tp_mode=self.mm_encoder_tp_mode, - interleave_mm_strings=self.interleave_mm_strings, - skip_mm_profiling=self.skip_mm_profiling, - ) - - return None + if is_runai_obj_uri(tokenizer): + object_storage_tokenizer = ObjectStorageModel() + object_storage_tokenizer.pull_files(model, + ignore_pattern=[ + "*.pt", "*.safetensors", + "*.bin", "*.tensors", + "*.pth" + ]) + self.tokenizer = object_storage_tokenizer.dir def _get_encoder_config(self): return get_sentence_transformer_tokenizer_config( @@ -1097,11 +1069,11 @@ class ModelConfig: assert_never(runner_type) - def _parse_quant_hf_config(self): - quant_cfg = getattr(self.hf_config, "quantization_config", None) + def _parse_quant_hf_config(self, hf_config: PretrainedConfig): + quant_cfg = getattr(hf_config, "quantization_config", None) if quant_cfg is None: # compressed-tensors uses a "compression_config" key - quant_cfg = getattr(self.hf_config, "compression_config", None) + quant_cfg = getattr(hf_config, "compression_config", None) else: # Set quant_method for ModelOpt models. @@ -1121,28 +1093,16 @@ class ModelConfig: def _verify_quantization(self) -> None: supported_quantization = me_quant.QUANTIZATION_METHODS - optimized_quantization_methods = [ - "fp8", - "modelopt", - "gptq_marlin_24", - "gptq_marlin", - "awq_marlin", - "fbgemm_fp8", - "compressed-tensors", - "experts_int8", - "quark", - "modelopt_fp4", - "bitblas", - "gptq_bitblas", - "inc", - "petit_nvfp4", - ] if self.quantization is not None: self.quantization = cast(me_quant.QuantizationMethods, self.quantization) # Parse quantization method from the HF model config, if available. - quant_cfg = self._parse_quant_hf_config() + quant_cfg = self._parse_quant_hf_config(self.hf_config) + if quant_cfg is None and (text_config := getattr( + self.hf_config, "text_config", None)): + # Check the text config as well for multi-modal models. + quant_cfg = self._parse_quant_hf_config(text_config) if quant_cfg is not None: # Use the community standard 'quant_method' @@ -1174,7 +1134,7 @@ class ModelConfig: ] # Any custom overrides will be in quantization_methods so we place # them at the start of the list so custom overrides have preference - # over the built in ones. + # over the built-in ones. quantization_methods = quantization_methods + overrides # Detect which checkpoint is it @@ -1214,11 +1174,6 @@ class ModelConfig: f"be one of {supported_quantization}.") from vllm.platforms import current_platform current_platform.verify_quantization(self.quantization) - if self.quantization not in optimized_quantization_methods: - logger.warning( - "%s quantization is not fully " - "optimized yet. 
The speed can be slower than " - "non-quantized models.", self.quantization) def _verify_cuda_graph(self) -> None: # The `max_seq_len_to_capture` was incorrectly @@ -1233,11 +1188,8 @@ class ModelConfig: getattr(self.hf_config, "max_source_positions", 0)) self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, effective_max_seq_len) - # CUDAGraph capture not supported for enc-dec models and mllama on ROCm - ROCM_UNSUPPORTED_MODELS = ['mllama'] - unsupported_rocm = (self.hf_config.model_type - in ROCM_UNSUPPORTED_MODELS - or self.is_encoder_decoder) + # CUDAGraph capture not supported for encoder-decoder models on ROCm + unsupported_rocm = self.is_encoder_decoder if (unsupported_rocm and not self.enforce_eager and current_platform.is_rocm()): @@ -1304,6 +1256,10 @@ class ModelConfig: self.hf_config.dual_chunk_attention_config[ "sparse_attention_enabled"] = True + if envs.VLLM_ATTENTION_BACKEND != STR_DUAL_CHUNK_FLASH_ATTN_VAL: + raise ValueError("please set VLLM_ATTENTION_BACKEND to " + f"{STR_DUAL_CHUNK_FLASH_ATTN_VAL}") + def verify_async_output_proc(self, parallel_config, speculative_config, device_config) -> None: if not self.use_async_output_proc: @@ -1418,6 +1374,11 @@ class ModelConfig: if getattr(self.hf_text_config, "head_dim", None) is not None: return self.hf_text_config.head_dim + # NOTE: Some models (such as PLaMo2.1) use `hidden_size_per_head` + if getattr(self.hf_text_config, "hidden_size_per_head", + None) is not None: + return self.hf_text_config.hidden_size_per_head + # FIXME(woosuk): This may not be true for all models. return (self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads) @@ -1501,7 +1462,8 @@ class ModelConfig: if (self.hf_text_config.model_type == "deepseek_mtp" or self.hf_config.model_type == "mimo_mtp" or self.hf_config.model_type == "glm4_moe_mtp" - or self.hf_config.model_type == "ernie_mtp"): + or self.hf_config.model_type == "ernie_mtp" + or self.hf_config.model_type == "qwen3_next_mtp"): total_num_hidden_layers = getattr(self.hf_text_config, "num_nextn_predict_layers", 0) else: @@ -1545,7 +1507,7 @@ class ModelConfig: for bc in block_configs[start:end]) else: # Hybrid model Jamba - layers_block_type_value = getattr(self.hf_config, + layers_block_type_value = getattr(self.hf_text_config, "layers_block_type", None) if layers_block_type_value is not None: if hasattr(self.hf_text_config, @@ -1564,15 +1526,28 @@ class ModelConfig: if attn_type_list: return sum(t == 1 for t in attn_type_list[start:end]) - if layers_block_type_value is None and attn_type_list is None: + # Hybrid model Qwen3Next + layer_types_value = getattr(self.hf_config, "layer_types", None) + if layer_types_value is not None: + if getattr(block_type, "value", block_type) == "attention": + return sum(t == "full_attention" + for t in layer_types_value[start:end]) + elif getattr(block_type, "value", + block_type) == "linear_attention": + return sum(t == "linear_attention" + for t in layer_types_value[start:end]) + else: + return sum(t == getattr(block_type, "value", block_type) + for t in layer_types_value[start:end]) + + if (layers_block_type_value is None and attn_type_list is None + and layer_types_value is None): raise ValueError( "The model is an hybrid without a" - "layers_block_type or an attn_type_list in the hf_config," - "cannot determine the num of " + "layers_block_type or an attn_type_list, or a layer_types " + "in the hf_config, cannot determine the num of " f"{block_type.value} layers") - return sum(t == 1 for t in attn_type_list[start:end]) - def 
get_mamba_chunk_size(self) -> Optional[int]: """ Returns the mamba chunk size if it exists @@ -1680,16 +1655,6 @@ class ModelConfig: @property def is_encoder_decoder(self) -> bool: """Extract the HF encoder/decoder model flag.""" - """ - For Mllama, VLLM overrides HF's is_encoder_decoder flag and sets it to - True to enable cross-attention - Neuron needs all multimodal data to be in the decoder and does not - need to explicitly enable cross-attention - """ - if (current_platform.is_neuron() - and self.hf_config.model_type == "mllama"): - return False - return is_encoder_decoder(self.hf_config) @property @@ -1752,6 +1717,37 @@ class ModelConfig: # `llm as reranker` models defaults to not using pad_token. return getattr(self.hf_config, "use_pad_token", True) + @property + def head_dtype(self) -> torch.dtype: + """ + "head" refers to the last Linear layer(s) of an LLM, + such as the lm_head in a generation model, + or the score or classifier in a classification model. + + `head_dtype` currently only supports pooling models.\n + - The pooling model defaults to using fp32 head, + you can use --hf-overrides '{"head_dtype": "model"}' to disable it. + """ + + head_dtype = _get_head_dtype(config=self.hf_config, + dtype=self.dtype, + runner_type=self.runner_type) + + if self.runner_type != "pooling" and head_dtype != self.dtype: + logger.warning_once( + "`head_dtype` currently only supports pooling models." + "fallback to model dtype [%s].", self.dtype) + return self.dtype + + if head_dtype not in current_platform.supported_dtypes: + logger.warning_once( + "The current platform does not support [%s] head dtype, " + "fallback to model dtype [%s].", head_dtype, self.dtype) + return self.dtype + + logger.debug_once("head dtype: %s", head_dtype) + return head_dtype + def get_and_verify_max_len(self, max_model_len: int): # Consider max_model_len in tokenizer_config only when # pooling models use absolute position_embedding. @@ -1774,91 +1770,7 @@ class ModelConfig: return max_model_len -@config -@dataclass -class LoadConfig: - """Configuration for loading the model weights.""" - - load_format: Union[str, LoadFormats] = "auto" - """The format of the model weights to load:\n - - "auto" will try to load the weights in the safetensors format and fall - back to the pytorch bin format if safetensors format is not available.\n - - "pt" will load the weights in the pytorch bin format.\n - - "safetensors" will load the weights in the safetensors format.\n - - "npcache" will load the weights in pytorch format and store a numpy cache - to speed up the loading.\n - - "dummy" will initialize the weights with random values, which is mainly - for profiling.\n - - "tensorizer" will use CoreWeave's tensorizer library for fast weight - loading. See the Tensorize vLLM Model script in the Examples section for - more information.\n - - "runai_streamer" will load the Safetensors weights using Run:ai Model - Streamer.\n - - "bitsandbytes" will load the weights using bitsandbytes quantization.\n - - "sharded_state" will load weights from pre-sharded checkpoint files, - supporting efficient loading of tensor-parallel models.\n - - "gguf" will load weights from GGUF format files (details specified in - https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n - - "mistral" will load weights from consolidated safetensors files used by - Mistral models. 
- - Other custom values can be supported via plugins.""" - download_dir: Optional[str] = None - """Directory to download and load the weights, default to the default - cache directory of Hugging Face.""" - model_loader_extra_config: Union[dict, TensorizerConfig] = field( - default_factory=dict) - """Extra config for model loader. This will be passed to the model loader - corresponding to the chosen load_format.""" - device: Optional[str] = None - """Device to which model weights will be loaded, default to - device_config.device""" - ignore_patterns: Optional[Union[list[str], str]] = None - """The list of patterns to ignore when loading the model. Default to - "original/**/*" to avoid repeated loading of llama's checkpoints.""" - use_tqdm_on_load: bool = True - """Whether to enable tqdm for showing progress bar when loading model - weights.""" - pt_load_map_location: Union[str, dict[str, str]] = "cpu" - """ - pt_load_map_location: the map location for loading pytorch checkpoint, to - support loading checkpoints can only be loaded on certain devices like - "cuda", this is equivalent to {"": "cuda"}. Another supported format is - mapping from different devices like from GPU 1 to GPU 0: - {"cuda:1": "cuda:0"}. Note that when passed from command line, the strings - in dictionary needs to be double quoted for json parsing. For more details, - see original doc for `map_location` in https://pytorch.org/docs/stable/generated/torch.load.html - """ - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - # no factors to consider. - # this config will not affect the computation graph. - factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - def __post_init__(self): - self.load_format = self.load_format.lower() - if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: - logger.info( - "Ignoring the following patterns when downloading weights: %s", - self.ignore_patterns) - else: - self.ignore_patterns = ["original/**/*"] - - -Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu"] +Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"] @config @@ -1914,729 +1826,13 @@ class DeviceConfig: self.device_type = self.device.type # Some device types require processing inputs on CPU - if self.device_type in ["neuron"]: - self.device = torch.device("cpu") - elif self.device_type in ["tpu"]: + if self.device_type in ["tpu"]: self.device = None else: # Set device with device type self.device = torch.device(self.device_type) -SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa", - "mlp_speculator", "draft_model", "deepseek_mtp", - "ernie_mtp"] - - -@config -@dataclass -class SpeculativeConfig: - """Configuration for speculative decoding.""" - - # General speculative decoding control - num_speculative_tokens: SkipValidation[int] = None # type: ignore - """The number of speculative tokens, if provided. 
It will default to the - number in the draft model config if present, otherwise, it is required.""" - model: Optional[str] = None - """The name of the draft model, eagle head, or additional weights, if - provided.""" - method: Optional[SpeculativeMethod] = None - """The name of the speculative method to use. If users provide and set the - `model` param, the speculative method type will be detected automatically - if possible, if `model` param is not provided, the method name must be - provided. - - If using `ngram` method, the related configuration `prompt_lookup_max` and - `prompt_lookup_min` should be considered.""" - draft_tensor_parallel_size: Optional[int] = None - """The degree of the tensor parallelism for the draft model. Can only be 1 - or the same as the target model's tensor parallel size.""" - disable_logprobs: bool = True - """If set to True, token log probabilities are not returned during - speculative decoding. If set to False, token log probabilities are returned - according to the log probability settings in SamplingParams.""" - - # Draft model configuration - quantization: Optional[me_quant.QuantizationMethods] = None - """Quantization method that was used to quantize the draft model weights. - If `None`, we assume the model weights are not quantized. Note that it only - takes effect when using the draft model-based speculative method.""" - max_model_len: Optional[int] = None - """The maximum model length of the draft model. Used when testing the - ability to skip speculation for some sequences.""" - revision: Optional[str] = None - """The specific model version to use for the draft model. It can be a - branch name, a tag name, or a commit id. If unspecified, will use the - default version.""" - code_revision: Optional[str] = None - """The specific revision to use for the draft model code on Hugging Face - Hub. It can be a branch name, a tag name, or a commit id. If unspecified, - will use the default version.""" - - # Advanced control - disable_by_batch_size: Optional[int] = None - """Disable speculative decoding for new incoming requests when the number - of enqueued requests is larger than this value, if provided.""" - - # Ngram proposer configuration - prompt_lookup_max: Optional[int] = None - """Maximum size of ngram token window when using Ngram proposer, required - when method is set to ngram.""" - prompt_lookup_min: Optional[int] = None - """Minimum size of ngram token window when using Ngram proposer, if - provided. Defaults to 1.""" - - speculative_token_tree: Optional[str] = None - """Specifies the tree structure for speculative token generation. - """ - # required configuration params passed from engine - target_model_config: SkipValidation[ModelConfig] = None # type: ignore - """The configuration of the target model.""" - target_parallel_config: SkipValidation[ - ParallelConfig] = None # type: ignore - """The parallel configuration for the target model.""" - enable_chunked_prefill: SkipValidation[bool] = None # type: ignore - """Whether vLLM is configured to use chunked prefill or not. 
Used for - raising an error since it's not yet compatible with speculative decode.""" - disable_log_stats: SkipValidation[bool] = None # type: ignore - """Whether to disable the periodic printing of stage times in speculative - decoding.""" - - # params generated in the post-init stage - draft_model_config: SkipValidation[ModelConfig] = None # type: ignore - """The configuration of the draft model initialized internal.""" - draft_parallel_config: SkipValidation[ - ParallelConfig] = None # type: ignore - """The parallel configuration for the draft model initialized internal.""" - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - factors: list[Any] = [] - # Eagle3 affects the computation graph because it returns intermediate - # hidden states in addition to the final hidden state. - factors.append(self.method == "eagle3") - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - @staticmethod - def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: - if hf_config.model_type == "deepseek_v3": - hf_config.model_type = "deepseek_mtp" - if hf_config.model_type == "deepseek_mtp": - n_predict = getattr(hf_config, "num_nextn_predict_layers", None) - hf_config.update({ - "n_predict": n_predict, - "architectures": ["DeepSeekMTPModel"] - }) - - if hf_config.architectures[0] == "MiMoForCausalLM": - hf_config.model_type = "mimo_mtp" - n_predict = getattr(hf_config, "num_nextn_predict_layers", None) - hf_config.update({ - "num_hidden_layers": 0, - "n_predict": n_predict, - "architectures": ["MiMoMTPModel"] - }) - - if hf_config.architectures[0] == "Glm4MoeForCausalLM": - hf_config.model_type = "glm4_moe_mtp" - n_predict = getattr(hf_config, "num_nextn_predict_layers", None) - hf_config.update({ - "num_hidden_layers": 0, - "n_predict": n_predict, - "architectures": ["Glm4MoeMTPModel"] - }) - - if hf_config.model_type == "ernie4_5_moe": - hf_config.model_type = "ernie_mtp" - if hf_config.model_type == "ernie_mtp": - n_predict = getattr(hf_config, "num_nextn_predict_layers", None) - hf_config.update({ - "n_predict": n_predict, - "architectures": ["ErnieMTPModel"] - }) - return hf_config - - return hf_config - - def __post_init__(self): - - # Note: "method" is a new parameter that helps to extend the - # configuration of non-model-based proposers, and the "model" parameter - # will be used to set the draft model, eagle head, or additional weight - # when needed. If users do not specify "method", the speculative method - # will be detected automatically if possible. If the speculative method - # can not be detected, it will be considered as the "draft_model" by - # default. 
- - if self.model is None and self.num_speculative_tokens is not None: - # TODO(Shangming): Refactor mtp configuration logic when supporting - # mtp acceleration for more models besides deepseek_v3 - if self.target_model_config and \ - (self.target_model_config.hf_text_config.model_type \ - == "deepseek_v3" or - self.target_model_config.hf_text_config.model_type in - ("mimo","ernie4_5_moe")): - # use the draft model from the same model: - self.model = self.target_model_config.model - elif self.method in ("ngram", "[ngram]"): - self.model = "ngram" - else: - raise ValueError("num_speculative_tokens was provided without " - "speculative model.") - - # Automatically configure the method for ngram when "model" is used - # instead of "method" - if self.method is None and (self.model is not None - and self.model in ("ngram", "[ngram]")): - self.method = "ngram" - - if self.method in ("ngram", "[ngram]"): - # Unified to "ngram" internally - self.method = "ngram" - # Set default values if not provided - if (self.prompt_lookup_min is None - and self.prompt_lookup_max is None): - # TODO(woosuk): Tune these values. They are arbitrarily chosen. - self.prompt_lookup_min = 5 - self.prompt_lookup_max = 5 - elif self.prompt_lookup_min is None: - assert self.prompt_lookup_max is not None - self.prompt_lookup_min = self.prompt_lookup_max - elif self.prompt_lookup_max is None: - assert self.prompt_lookup_min is not None - self.prompt_lookup_max = self.prompt_lookup_min - - # Validate values - if self.prompt_lookup_min < 1: - raise ValueError( - f"prompt_lookup_min={self.prompt_lookup_min} must be > 0") - if self.prompt_lookup_max < 1: - raise ValueError( - f"prompt_lookup_max={self.prompt_lookup_max} must be > 0") - if self.prompt_lookup_min > self.prompt_lookup_max: - raise ValueError( - f"prompt_lookup_min={self.prompt_lookup_min} must " - f"be <= prompt_lookup_max={self.prompt_lookup_max}") - - # TODO: current we still need extract vocab_size from target model - # config, in future, we may try refactor it out, and set - # draft related config as None here. - self.draft_model_config = self.target_model_config - self.draft_parallel_config = self.target_parallel_config - else: - self.prompt_lookup_max = 0 - self.prompt_lookup_min = 0 - - if self.model is not None: - self.draft_model_config = ModelConfig( - model=self.model, - runner="draft", - tokenizer=self.target_model_config.tokenizer, - tokenizer_mode=self.target_model_config.tokenizer_mode, - trust_remote_code=self.target_model_config. - trust_remote_code, - allowed_local_media_path=self.target_model_config. - allowed_local_media_path, - dtype=self.target_model_config.dtype, - seed=self.target_model_config.seed, - revision=self.revision, - code_revision=self.code_revision, - tokenizer_revision=self.target_model_config. - tokenizer_revision, - spec_target_max_model_len=self.target_model_config. - max_model_len, - quantization=self.quantization, - enforce_eager=self.target_model_config.enforce_eager, - max_seq_len_to_capture=self.target_model_config. 
- max_seq_len_to_capture, - max_logprobs=self.target_model_config.max_logprobs, - hf_overrides=SpeculativeConfig.hf_config_override, - ) - - # Automatically detect the method - if self.method in ('eagle', 'eagle3'): - pass - elif "eagle-" in self.draft_model_config.model.lower() or \ - "eagle3-" in self.draft_model_config.model.lower(): - self.method = "eagle" - elif self.draft_model_config.hf_config.model_type == "medusa": - self.method = "medusa" - elif (self.draft_model_config.hf_config.model_type == - "mlp_speculator"): - self.method = "mlp_speculator" - elif (self.draft_model_config.hf_config.model_type - in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")): - self.method = "deepseek_mtp" - if self.num_speculative_tokens > 1: - logger.warning( - "All Deepseek MTP models only have " \ - "one layer. Might need some code changes " \ - "to support multiple layers." - ) - elif (self.draft_model_config.hf_config.model_type == - "ernie_mtp"): - self.method = "ernie_mtp" - if self.num_speculative_tokens > 1: - logger.warning( - "All Ernie MTP models only have " \ - "one layer. Might need some code changes " \ - "to support multiple layers." - ) - else: - self.method = "draft_model" - raise NotImplementedError( - "Speculative decoding with draft model is not " - "supported yet. Please consider using other " - "speculative decoding methods such as ngram, medusa, " - "eagle, or deepseek_mtp.") - - # Replace hf_config for EAGLE draft_model - if self.method in ("eagle", "eagle3"): - if self.enable_chunked_prefill and not envs.VLLM_USE_V1: - raise ValueError( - "Chunked prefill and EAGLE are not compatible " - "when using V0.") - - from vllm.transformers_utils.configs import ( - SpeculatorsConfig) - from vllm.transformers_utils.configs.eagle import ( - EAGLEConfig) - - if isinstance(self.draft_model_config.hf_config, - (EAGLEConfig, SpeculatorsConfig)): - pass - else: - eagle_config = EAGLEConfig( - self.draft_model_config.hf_config, - method=self.method, - model_type="eagle") - self.draft_model_config.hf_config = eagle_config - - if (self.num_speculative_tokens is not None - and hasattr(self.draft_model_config.hf_config, - "num_lookahead_tokens")): - self.draft_model_config.hf_config.num_lookahead_tokens = \ - self.num_speculative_tokens - - n_predict = getattr(self.draft_model_config.hf_config, - "n_predict", None) - if n_predict is not None: - if self.num_speculative_tokens is None: - # Default to max value defined in draft model config. - self.num_speculative_tokens = n_predict - elif self.num_speculative_tokens > n_predict and \ - self.num_speculative_tokens % n_predict != 0: - # Ensure divisibility for MTP module reuse. - raise ValueError( - f"num_speculative_tokens:{self.num_speculative_tokens}" - f" must be divisible by {n_predict=}") - - if self.speculative_token_tree is None: - # Generate chain of tokens. - self.speculative_token_tree = str([ - (i + 1) * (0, ) - for i in range(self.num_speculative_tokens) - ]) - else: - # Sort the token tree breadth-first. 
- tree_choices = ast.literal_eval( - self.speculative_token_tree) - self.speculative_token_tree = str( - sorted(tree_choices, key=lambda t: (len(t), t))) - - self.draft_tensor_parallel_size = \ - SpeculativeConfig._verify_and_get_draft_tp( - self.target_parallel_config, - self.draft_tensor_parallel_size, - self.draft_model_config.hf_config - ) - - self.draft_model_config.max_model_len = ( - SpeculativeConfig._maybe_override_draft_max_model_len( - self.max_model_len, - self.draft_model_config.max_model_len, - self.target_model_config.max_model_len, - )) - - self.draft_parallel_config = ( - SpeculativeConfig.create_draft_parallel_config( - self.target_parallel_config, - self.draft_tensor_parallel_size)) - - @staticmethod - def _maybe_override_draft_max_model_len( - speculative_max_model_len: Optional[int], - draft_max_model_len: int, - target_max_model_len: int, - ) -> int: - """Determine the max sequence len for the draft model. This is usually - the draft_max_model_len, but may be the target_max_model_len if it is - less than the draft_max_model_len, or may be speculative_max_model_len - if it is specified. - - This is necessary so that sequences do not exceed the capacity of the - draft model or the target model. - - speculative_max_model_len is mainly used for testing that sequences can - skip speculation. - """ - - if speculative_max_model_len is not None: - - if speculative_max_model_len > draft_max_model_len: - raise ValueError(f"{speculative_max_model_len=} cannot be " - f"larger than {draft_max_model_len=}") - - if speculative_max_model_len > target_max_model_len: - raise ValueError(f"{speculative_max_model_len=} cannot be " - f"larger than {target_max_model_len=}") - - return speculative_max_model_len - - return min( - draft_max_model_len, - target_max_model_len, - ) - - @staticmethod - def _verify_and_get_draft_tp( - target_parallel_config: ParallelConfig, - speculative_draft_tensor_parallel_size: Optional[int], - draft_hf_config: PretrainedConfig) -> int: - """ - Verifies and adjusts the tensor parallel size for a draft model - specified using speculative_draft_tensor_parallel_size. - """ - # If speculative_draft_tensor_parallel_size is unset then set it - # appropriately else verify that it is set correctly. - if speculative_draft_tensor_parallel_size is None: - if draft_hf_config.model_type == "mlp_speculator": - speculative_draft_tensor_parallel_size = 1 - if target_parallel_config.tensor_parallel_size > 1: - logger.warning( - "%s cannot currently be run with tp>1; " - "setting speculative_draft_tensor_parallel_size=1", - draft_hf_config.model_type) - else: - speculative_draft_tensor_parallel_size = \ - target_parallel_config.tensor_parallel_size - elif speculative_draft_tensor_parallel_size not in ( - 1, target_parallel_config.tensor_parallel_size): - raise ValueError( - f"{speculative_draft_tensor_parallel_size=} cannot be " - f"other value than 1 or target model tensor_parallel_size") - return speculative_draft_tensor_parallel_size - - @staticmethod - def create_draft_parallel_config( - target_parallel_config: ParallelConfig, - speculative_draft_tensor_parallel_size: int, - ) -> ParallelConfig: - """Create a parallel config for use by the draft worker. - - This is mostly a copy of the target parallel config, except the tp_size. - """ - draft_parallel_config = ParallelConfig( - pipeline_parallel_size=target_parallel_config. - pipeline_parallel_size, - tensor_parallel_size=speculative_draft_tensor_parallel_size, - distributed_executor_backend=target_parallel_config. 
- distributed_executor_backend, - max_parallel_loading_workers=target_parallel_config. - max_parallel_loading_workers, - disable_custom_all_reduce=target_parallel_config. - disable_custom_all_reduce, - ray_workers_use_nsight=target_parallel_config. - ray_workers_use_nsight, - placement_group=target_parallel_config.placement_group, - ) - - return draft_parallel_config - - @model_validator(mode='after') - def _verify_args(self) -> Self: - if self.num_speculative_tokens is None: - raise ValueError( - "num_speculative_tokens must be provided with " - "speculative model unless the draft model config contains an " - "n_predict parameter.") - - if self.num_speculative_tokens <= 0: - raise ValueError("Expected num_speculative_tokens to be greater " - f"than zero ({self.num_speculative_tokens}).") - - if self.draft_model_config: - self.draft_model_config.verify_with_parallel_config( - self.draft_parallel_config) - - if (self.disable_by_batch_size is not None - and self.disable_by_batch_size < 2): - raise ValueError("Expect the batch size threshold of disabling " - "speculative decoding is > 1, but got " - f"{self.disable_by_batch_size=}") - - eagle3_target_supported = ["llama", "qwen"] - if self.method == "eagle3" and self.target_model_config and not any( - supported_model in - self.target_model_config.hf_text_config.model_type - for supported_model in eagle3_target_supported): - raise ValueError( - f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501 - f"Got {self.target_model_config.hf_text_config.model_type=}") - - return self - - @property - def num_lookahead_slots(self) -> int: - """The number of additional slots the scheduler should allocate per - step, in addition to the slots allocated for each known token. - - This is equal to the number of speculative tokens, as each speculative - token must be scored. - """ - return self.num_speculative_tokens - - def use_eagle(self) -> bool: - return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp") - - def __repr__(self) -> str: - method = self.method - model = None if method == "ngram" else self.draft_model_config.model - num_spec_tokens = self.num_speculative_tokens - return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})" - - -LoRADType = Literal["auto", "float16", "bfloat16"] - - -@config -@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) -class LoRAConfig: - """Configuration for LoRA.""" - - max_lora_rank: int = 16 - """Max LoRA rank.""" - max_loras: int = 1 - """Max number of LoRAs in a single batch.""" - fully_sharded_loras: bool = False - """By default, only half of the LoRA computation is sharded with tensor - parallelism. Enabling this will use the fully sharded layers. At high - sequence length, max rank or tensor parallel size, this is likely faster. - """ - max_cpu_loras: Optional[int] = None - """Maximum number of LoRAs to store in CPU memory. Must be >= than - `max_loras`.""" - lora_dtype: Union[torch.dtype, LoRADType] = "auto" - """Data type for LoRA. If auto, will default to base model dtype.""" - lora_extra_vocab_size: int = 256 - """(Deprecated) Maximum size of extra vocabulary that can be present in a - LoRA adapter. 
Will be removed in v0.12.0.""" - lora_vocab_padding_size: ClassVar[int] = current_platform\ - .get_lora_vocab_padding_size() - - default_mm_loras: Optional[dict[str, str]] = None - """Dictionary mapping specific modalities to LoRA model paths; this field - is only applicable to multimodal models and should be leveraged when a - model always expects a LoRA to be active when a given modality is present. - Note that currently, if a request provides multiple additional - modalities, each of which have their own LoRA, we do NOT apply - default_mm_loras because we currently only support one lora adapter - per prompt. When run in offline mode, the lora IDs for n modalities - will be automatically assigned to 1-n with the names of the modalities - in alphabetic order.""" - bias_enabled: bool = False - """Enable bias for LoRA adapters.""" - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - factors: list[Any] = [] - factors.append(self.max_lora_rank) - factors.append(self.max_loras) - factors.append(self.fully_sharded_loras) - factors.append(self.lora_dtype) - factors.append(self.lora_extra_vocab_size) - factors.append(self.lora_vocab_padding_size) - factors.append(self.bias_enabled) - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - def __post_init__(self): - # Deprecation warning for lora_extra_vocab_size - logger.warning( - "`lora_extra_vocab_size` is deprecated and will be removed " - "in v0.12.0. Additional vocabulary support for " - "LoRA adapters is being phased out.") - - # Setting the maximum rank to 512 should be able to satisfy the vast - # majority of applications. 
- possible_max_ranks = (8, 16, 32, 64, 128, 256, 320, 512) - possible_lora_extra_vocab_size = (256, 512) - if self.max_lora_rank not in possible_max_ranks: - raise ValueError( - f"max_lora_rank ({self.max_lora_rank}) must be one of " - f"{possible_max_ranks}.") - if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size: - raise ValueError( - f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) " - f"must be one of {possible_lora_extra_vocab_size}.") - if self.max_loras < 1: - raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.") - if self.max_cpu_loras is None: - self.max_cpu_loras = self.max_loras - elif self.max_cpu_loras < self.max_loras: - raise ValueError( - f"max_cpu_loras ({self.max_cpu_loras}) must be >= " - f"max_loras ({self.max_loras})") - - def verify_with_cache_config(self, cache_config: CacheConfig): - if cache_config.cpu_offload_gb > 0 and not envs.VLLM_USE_V1: - raise ValueError( - "V0 LoRA does not support CPU offload, please use V1.") - - def verify_with_model_config(self, model_config: ModelConfig): - if self.lora_dtype in (None, "auto"): - self.lora_dtype = model_config.dtype - elif isinstance(self.lora_dtype, str): - self.lora_dtype = getattr(torch, self.lora_dtype) - - -@config -@dataclass -class MultiModalConfig: - """Controls the behavior of multimodal models.""" - - limit_per_prompt: dict[str, int] = \ - cast(dict[str, int], get_field(ModelConfig, "limit_mm_per_prompt")) - """ - The maximum number of input items allowed per prompt for each modality. - Defaults to 1 (V0) or 999 (V1) for each modality. - - For example, to allow up to 16 images and 2 videos per prompt: - `{"image": 16, "video": 2}` - """ - - media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) - """Additional args passed to process media inputs, keyed by modalities. - For example, to set num_frames for video, set - `--media-io-kwargs '{"video": {"num_frames": 40} }'` """ - - mm_processor_kwargs: Optional[dict[str, object]] = None - """ - Overrides for the multi-modal processor obtained from - `transformers.AutoProcessor.from_pretrained`. - - The available overrides depend on the model that is being run. - - For example, for Phi-3-Vision: - `{"num_crops": 4}`. - """ - - mm_processor_cache_gb: float = 4 - """ - The size (in GiB) of the multi-modal processor cache, which is used to - - This cache is duplicated for each API process and engine core process, - resulting in a total memory usage of - `mm_processor_cache_gb * (api_server_count + data_parallel_size)`. - - Set to `0` to disable this cache completely (not recommended). - """ - - mm_encoder_tp_mode: MMEncoderTPMode = "weights" - """ - Indicates how to optimize multi-modal encoder inference using - tensor parallelism (TP). - - - `"weights"`: Within the same vLLM engine, split the weights of - each layer across TP ranks. (default TP behavior) - - `"data"`: Within the same vLLM engine, split the batched input data - across TP ranks to process the data in parallel, while hosting - the full weights on each TP rank. - This batch-level DP is not to be confused with API request-level - DP (which is controlled by `--data-parallel-size`). - This is only supported on a per-model basis and falls back to - `"weights"` if the encoder does not support DP. - """ - - interleave_mm_strings: bool = False - """ - Enable fully interleaved support for multimodal prompts. 
- """ - - skip_mm_profiling: bool = False - """ - When enabled, skips multimodal memory profiling and only profiles with - language backbone model during engine initialization. - - This reduces engine startup time but shifts the responsibility to users for - estimating the peak memory usage of the activation of multimodal encoder and - embedding cache. - """ - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - # no factors to consider. - # this config will not affect the computation graph. - factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - def get_limit_per_prompt(self, modality: str) -> int: - """ - Get the maximum number of input items allowed per prompt - for the given modality. - """ - return self.limit_per_prompt.get( - modality, - 999 if envs.VLLM_USE_V1 else 1, - ) - - def merge_mm_processor_kwargs( - self, - inference_kwargs: Mapping[str, object], - ) -> dict[str, object]: - """ - Get the keyword arguments to pass to the multi-modal processor - according to the extra arguments passed during inference. - """ - kwargs = self.mm_processor_kwargs or {} - return kwargs | dict(inference_kwargs) - - @config @dataclass class PoolerConfig: @@ -2651,24 +1847,46 @@ class PoolerConfig: ## for embeddings models normalize: Optional[bool] = None """ - Whether to normalize the embeddings outputs. + Whether to normalize the embeddings outputs. Defaults to True. """ dimensions: Optional[int] = None """ Reduce the dimensions of embeddings if model - support matryoshka representation. + support matryoshka representation. Defaults to None. + """ + enable_chunked_processing: Optional[bool] = None + """ + Whether to enable chunked processing for long inputs that exceed the model's + maximum position embeddings. When enabled, long inputs will be split into + chunks, processed separately, and then aggregated using weighted averaging. + This allows embedding models to handle arbitrarily long text without CUDA + errors. Defaults to False. + """ + max_embed_len: Optional[int] = None + """ + Maximum input length allowed for embedding generation. When set, allows + inputs longer than max_embed_len to be accepted for embedding models. + When an input exceeds max_embed_len, it will be handled according to + the original max_model_len validation logic. + Defaults to None (i.e. set to max_model_len). """ ## for classification models activation: Optional[bool] = None """ Whether to apply activation function to the classification outputs. + Defaults to True. + """ + logit_bias: Optional[float] = None + """ + If provided, apply classification logit biases. Defaults to None. """ ## for reward models softmax: Optional[bool] = None """ Whether to apply softmax to the reward outputs. + Defaults to True. """ step_tag_id: Optional[int] = None """ @@ -2683,25 +1901,6 @@ class PoolerConfig: ``math-shepherd-mistral-7b-prm`` model. """ - enable_chunked_processing: Optional[bool] = None - """ - Whether to enable chunked processing for long inputs that exceed the model's - maximum position embeddings. 
When enabled, long inputs will be split into - chunks, processed separately, and then aggregated using weighted averaging. - This allows embedding models to handle arbitrarily long text without CUDA - errors. Defaults to False. - """ - - max_embed_len: Optional[int] = None - """ - Maximum input length allowed for embedding generation. When set, allows - inputs longer than max_embed_len to be accepted for embedding models. - This parameter enables accepting long inputs without requiring - VLLM_ALLOW_LONG_MAX_MODEL_LEN environment variable. When an input exceeds - max_embed_len, it will be handled according to the original max_model_len - validation logic. Defaults to None (i.e. set to max_model_len). - """ - def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -2734,6 +1933,8 @@ _STR_DTYPE_TO_TORCH_DTYPE = { _FLOAT16_NOT_SUPPORTED_MODELS = { "gemma2": "Numerical instability. Please use bfloat16 or float32 instead.", "gemma3": "Numerical instability. Please use bfloat16 or float32 instead.", + "gemma3_text": + "Numerical instability. Please use bfloat16 or float32 instead.", "plamo2": "Numerical instability. Please use bfloat16 or float32 instead.", "glm4": "Numerical instability. Please use bfloat16 or float32 instead.", } @@ -2886,6 +2087,31 @@ def _get_and_verify_dtype( return torch_dtype +def _get_head_dtype(config: PretrainedConfig, dtype: torch.dtype, + runner_type: str) -> torch.dtype: + head_dtype: Optional[Union[str, + torch.dtype]] = getattr(config, "head_dtype", + None) + + if head_dtype == "model": + return dtype + elif isinstance(head_dtype, str): + head_dtype = head_dtype.lower() + if head_dtype not in _STR_DTYPE_TO_TORCH_DTYPE: + raise ValueError(f"Unknown dtype: {head_dtype!r}") + return _STR_DTYPE_TO_TORCH_DTYPE[head_dtype] + elif isinstance(head_dtype, torch.dtype): + return head_dtype + elif head_dtype is None: + if torch.float32 not in current_platform.supported_dtypes: + return dtype + if runner_type == "pooling": + return torch.float32 + return dtype + else: + raise ValueError(f"Unknown dtype: {head_dtype}") + + def _get_and_verify_max_len( hf_config: PretrainedConfig, tokenizer_config: Optional[dict], @@ -3054,66 +2280,6 @@ def get_served_model_name(model: str, return served_model_name -GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines", - "lm-format-enforcer"] - - -@config -@dataclass -class DecodingConfig: - """Dataclass which contains the decoding strategy of the engine.""" - - backend: GuidedDecodingBackend = "auto" - """Which engine will be used for guided decoding (JSON schema / regex etc) - by default. With "auto", we will make opinionated choices based on request - contents and what the backend libraries currently support, so the behavior - is subject to change in each release.""" - - disable_fallback: bool = False - """If `True`, vLLM will not fallback to a different backend on error.""" - - disable_any_whitespace: bool = False - """If `True`, the model will not generate any whitespace during guided - decoding. This is only supported for xgrammar and guidance backends.""" - - disable_additional_properties: bool = False - """If `True`, the `guidance` backend will not use `additionalProperties` - in the JSON schema. This is only supported for the `guidance` backend and - is used to better align its behaviour with `outlines` and `xgrammar`.""" - - reasoning_backend: str = "" - """Select the reasoning parser depending on the model that you're using. 
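Note: the `PoolerConfig` hunk above documents the new `enable_chunked_processing`, `max_embed_len`, and `logit_bias` fields and their defaults. A minimal construction sketch is below; it assumes `PoolerConfig` is importable from `vllm.config` (true for the module this diff touches) and uses illustrative values, since how the config is wired into a deployment (e.g. a pooler-config override flag) is not part of this diff.

```python
from vllm.config import PoolerConfig

# Hypothetical values for an embedding model; field names and defaults
# follow the PoolerConfig docstrings in the hunk above.
pooler_config = PoolerConfig(
    normalize=True,                  # embeddings: defaults to True when unset
    enable_chunked_processing=True,  # split inputs longer than the model's max positions
    max_embed_len=32768,             # accept embedding inputs up to this many tokens
)
```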
- This is used to parse the reasoning content into OpenAI API format.""" - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - # no factors to consider. - # this config will not affect the computation graph. - factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - def __post_init__(self): - if (self.disable_any_whitespace - and self.backend not in ("xgrammar", "guidance")): - raise ValueError("disable_any_whitespace is only supported for " - "xgrammar and guidance backends.") - if (self.disable_additional_properties and self.backend != "guidance"): - raise ValueError("disable_additional_properties is only supported " - "for the guidance backend.") - - DetailedTraceModules = Literal["model", "worker", "all"] @@ -3203,149 +2369,6 @@ class ObservabilityConfig: self.collect_detailed_traces[0].split(",")) -KVProducer = Literal["kv_producer", "kv_both"] -KVConsumer = Literal["kv_consumer", "kv_both"] -KVRole = Literal[KVProducer, KVConsumer] - - -@config -@dataclass -class KVTransferConfig: - """Configuration for distributed KV cache transfer.""" - - kv_connector: Optional[str] = None - """The KV connector for vLLM to transmit KV caches between vLLM instances. - """ - - engine_id: Optional[str] = None - """The engine id for KV transfers.""" - - kv_buffer_device: Optional[str] = "cuda" - """The device used by kv connector to buffer the KV cache. - Currently only support 'cuda'.""" - - kv_buffer_size: float = 1e9 - """The buffer size for TorchDistributedConnector. Measured in number of - bytes. Recommended value: 1e9 (about 1GB).""" - - kv_role: Optional[KVRole] = None - """Whether this vLLM instance produces, consumes KV cache, or both. Choices - are 'kv_producer', 'kv_consumer', and 'kv_both'.""" - - kv_rank: Optional[int] = None - """The rank of this vLLM instance in the KV cache transfer. Typical value: - 0 for prefill instance, 1 for decode instance. - Currently only 1P1D is supported.""" - - kv_parallel_size: int = 1 - """The number of parallel instances for KV cache transfer. For - PyNcclConnector, this should be 2.""" - - kv_ip: str = "127.0.0.1" - """The KV connector ip, used to build distributed connection.""" - - kv_port: int = 14579 - """The KV connector port, used to build distributed connection.""" - - kv_connector_extra_config: dict[str, Any] = field(default_factory=dict) - """any extra config that the connector may need.""" - - kv_connector_module_path: Optional[str] = None - """The Python module path to dynamically load the KV connector from. - Only supported in V1.""" - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - # no factors to consider. - # this config will not affect the computation graph. 
- factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - def __post_init__(self) -> None: - if self.engine_id is None: - self.engine_id = str(uuid.uuid4()) - - if self.kv_role is not None and self.kv_role not in get_args(KVRole): - raise ValueError(f"Unsupported kv_role: {self.kv_role}. " - f"Supported roles are {get_args(KVRole)}") - - if self.kv_connector is not None and self.kv_role is None: - raise ValueError("Please specify kv_disagg_role when kv_connector " - f"is set, supported roles are {get_args(KVRole)}") - - @property - def is_kv_transfer_instance(self) -> bool: - return self.kv_connector is not None and \ - self.kv_role in get_args(KVRole) - - @property - def is_kv_producer(self) -> bool: - return self.kv_connector is not None and \ - self.kv_role in get_args(KVProducer) - - @property - def is_kv_consumer(self) -> bool: - return self.kv_connector is not None and \ - self.kv_role in get_args(KVConsumer) - - def get_from_extra_config(self, key, default) -> Any: - return self.kv_connector_extra_config.get(key, default) - - -@config -@dataclass -class KVEventsConfig: - """Configuration for KV event publishing.""" - - enable_kv_cache_events: bool = False - """If True, enable KV cache events for tracking block storage and removal. - Events can be published externally by zmq using the event publisher config. - """ - - publisher: str = "null" - """The publisher to use for publishing kv events. Can be "null", "zmq". - """ - - endpoint: str = "tcp://*:5557" - """The zmq endpoint to use for publishing kv events. - """ - - replay_endpoint: Optional[str] = None - """The zmq endpoint to use for replaying kv events. - """ - - buffer_steps: int = 10_000 - """The number of steps to cache for replay endpoint. Will only save - events from the last N steps for the replay endpoint. - """ - - hwm: int = 100_000 - """The zmq high water mark for the event publisher. After queueing N events, - events will start dropping if the consumer is not keeping up. - """ - - max_queue_size: int = 100_000 - """The maximum number of events to queue while waiting for publishing. - """ - - topic: str = "" - """The topic to use for the event publisher. Consumers can subscribe to - this topic to receive events. 
- """ - - @config @dataclass(config=ConfigDict(arbitrary_types_allowed=True)) class VllmConfig: @@ -3371,8 +2394,9 @@ class VllmConfig: """LoRA configuration.""" speculative_config: Optional[SpeculativeConfig] = None """Speculative decoding configuration.""" - decoding_config: DecodingConfig = field(default_factory=DecodingConfig) - """Decoding configuration.""" + structured_outputs_config: StructuredOutputsConfig = field( + default_factory=StructuredOutputsConfig) + """Structured outputs configuration.""" observability_config: Optional[ObservabilityConfig] = None """Observability configuration.""" quant_config: Optional[QuantizationConfig] = None @@ -3463,8 +2487,8 @@ class VllmConfig: vllm_factors.append(self.speculative_config.compute_hash()) else: vllm_factors.append("None") - if self.decoding_config: - vllm_factors.append(self.decoding_config.compute_hash()) + if self.structured_outputs_config: + vllm_factors.append(self.structured_outputs_config.compute_hash()) else: vllm_factors.append("None") if self.observability_config: @@ -3649,6 +2673,24 @@ class VllmConfig: " Disabling `torch.compile`.") self.compilation_config.level = CompilationLevel.NO_COMPILATION + if self.cache_config.kv_sharing_fast_prefill: + if not envs.VLLM_USE_V1: + raise NotImplementedError( + "Fast prefill optimization for KV sharing is not supported " + "in V0 currently.") + + if self.speculative_config is not None and \ + self.speculative_config.use_eagle(): + raise NotImplementedError( + "Fast prefill optimization for KV sharing is not " + "compatible with EAGLE as EAGLE requires correct logits " + "for all tokens while fast prefill gives incorrect logits " + "for prompt tokens.") + + logger.warning_once( + "--kv-sharing-fast-prefill requires changes on model side for " + "correctness and to realize prefill savings. 
") + if ((not envs.VLLM_USE_V1) and self.lora_config is not None and self.compilation_config.level != CompilationLevel.NO_COMPILATION): @@ -3659,16 +2701,37 @@ class VllmConfig: disable_chunked_prefill_reasons: list[str] = [] - if self.model_config and self.model_config.pooler_config: - pooling_type = self.model_config.pooler_config.pooling_type - if pooling_type is None or pooling_type.lower() != "last": + if self.model_config: + if self.model_config.pooler_config: + pooling_type = self.model_config.pooler_config.pooling_type + if pooling_type is None or pooling_type.lower() != "last": + disable_chunked_prefill_reasons.append( + "Only \"last\" pooling supports chunked " + "prefill and prefix caching; disabling both.") + if not getattr(self.model_config.hf_config, "is_causal", True): + disable_chunked_prefill_reasons.append( + "Only models using causal attention supports chunked " + "prefill and prefix caching; disabling both.") + elif self.model_config.is_encoder_decoder: + self.scheduler_config.max_num_encoder_input_tokens = \ + MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config) + logger.debug( + "Encoder-decoder model detected: setting " + "`max_num_encoder_input_tokens` to encoder length (%s)", + self.scheduler_config.max_num_encoder_input_tokens) + self.scheduler_config.disable_chunked_mm_input = True disable_chunked_prefill_reasons.append( - "Only \"last\" pooling supports chunked " - "prefill and prefix caching; disabling both.") - elif not getattr(self.model_config.hf_config, "is_causal", True): - disable_chunked_prefill_reasons.append( - "Only models using causal attention supports chunked " - "prefill and prefix caching; disabling both.") + "Encoder-decoder models do not support chunked prefill nor" + " prefix caching; disabling both.") + if (self.model_config.architecture + == "WhisperForConditionalGeneration" + and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") + != "spawn"): + logger.warning( + "Whisper is known to have issues with " + "forked workers. If startup is hanging, " + "try setting 'VLLM_WORKER_MULTIPROC_METHOD' " + "to 'spawn'.") if disable_chunked_prefill_reasons: for reason in disable_chunked_prefill_reasons: @@ -3712,6 +2775,14 @@ class VllmConfig: "when cudagraph_mode piecewise cudagraphs is used, "\ f"cudagraph_mode={self.compilation_config.cudagraph_mode}" + if self.parallel_config.enable_dbo: + a2a_backend = envs.VLLM_ALL2ALL_BACKEND + assert a2a_backend == "deepep_low_latency", \ + "Microbatching currently only supports the deepep_low_latency "\ + f"all2all backend. {a2a_backend} is not supported. To fix set "\ + "the VLLM_ALL2ALL_BACKEND environment variable to "\ + "deepep_low_latency and install the DeepEP kerenls." + if not self.instance_id: self.instance_id = random_uuid()[:5] @@ -3725,7 +2796,7 @@ class VllmConfig: # logger should only print warning message for hybrid models. As we # can't know whether the model is hybrid or not now, so we don't log # warning message here and will log it later. - if not (current_platform.is_cuda() or current_platform.is_rocm()): + if not current_platform.support_hybrid_kv_cache(): # Hybrid KV cache manager is not supported on non-GPU platforms. 
self.scheduler_config.disable_hybrid_kv_cache_manager = True if self.kv_transfer_config is not None: @@ -3775,30 +2846,40 @@ class VllmConfig: def _set_cudagraph_sizes(self): """ - cudagraph batchsize padding logic: + vLLM defines the default candidate list of batch sizes for CUDA graph + capture as: - `[1, 2, 4] + [8 * i for i in range(1, 1025)]` is a list of all possible - batch sizes that cudagraph will capture. - - Depending on the engine's configuration of `max_num_seqs`, the - candidate batch sizes to capture cudagraph will shrink to the subset - which just cover the range of `[1, max_num_seqs]`. In the common case, - `max_num_seqs` is 256, and the cudagraph batch sizes will be - `[1, 2, 4, 8, 16, 24, 32, 40, ..., 256]`. - - However, if users specify the cudagraph capture sizes through - compilation config, we will use the specified sizes instead. + ```python + max_graph_size = min(max_num_seqs * 2, 512) + # 1, 2, 4, then multiples of 8 up to max_graph_size + cuda_graph_sizes = [1, 2, 4, 8, 16, 24, 32, 40, ..., max_graph_size] In the end, `vllm_config.compilation_config.cudagraph_capture_sizes` will be the final sizes to capture cudagraph (in descending order). - During runtime, if batchsize is larger than - `vllm_config.compilation_config.cudagraph_capture_sizes`, - no cudagraph will be used. - If the batch size is no larger than - `vllm_config.compilation_config.cudagraph_capture_sizes`, - we can quickly find the padded graph size for a given batch size by - looking up `vllm_config.compilation_config.bs_to_padded_graph_size`. + These sizes are used to capture and reuse CUDA graphs for + performance-critical paths (e.g., decoding). Capturing enables + significantly faster kernel dispatch by avoiding Python overhead. The + list is then filtered based on `max_num_batched_tokens` (e.g., 8192 on + most GPUs), which controls the total allowed number of tokens in a + batch. Since each sequence may have a variable number of tokens, the + maximum usable batch size will depend on actual sequence lengths. + + Example: + With `max_num_batched_tokens = 8192`, and typical sequences + averaging ~32 tokens, most practical batch sizes fall below 256. + However, the system will still allow capture sizes up to 512 if + shape and memory permit. + + Note: + If users explicitly specify cudagraph capture sizes in the + compilation config, those will override this default logic. + At runtime: + + - If batch size <= one of the `cudagraph_capture_sizes`, the closest + padded CUDA graph will be used. + - If batch size > largest `cudagraph_capture_sizes`, cudagraph will + not be used. """ # calculate the default `batch_size_capture_list` @@ -3892,6 +2973,18 @@ class VllmConfig: SequenceClassificationConfig) SequenceClassificationConfig.verify_and_update_config(self) + if hasattr(self.model_config, "model_weights") and is_runai_obj_uri( + self.model_config.model_weights): + if self.load_config.load_format == "auto": + logger.info("Detected Run:ai model config. " + "Overriding `load_format` to 'runai_streamer'") + self.load_config.load_format = "runai_streamer" + elif self.load_config.load_format != "runai_streamer": + raise ValueError(f"To load a model from S3, 'load_format' " + f"must be 'runai_streamer', " + f"but got '{self.load_config.load_format}'. 
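Note: the new Whisper warning in the `VllmConfig` hunk above points at fork-related worker hangs. A hedged sketch of the workaround it suggests; the environment variable name comes from that check, while the model name and offline-`LLM` usage are illustrative.

```python
import os

# Must be set before vLLM creates its worker processes.
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

from vllm import LLM  # imported after the env var is set

llm = LLM(model="openai/whisper-large-v3")  # illustrative Whisper checkpoint
```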
" + f"Model: {self.model_config.model}") + def __str__(self): return ( f"model={self.model_config.model!r}, " @@ -3900,7 +2993,6 @@ class VllmConfig: f"skip_tokenizer_init={self.model_config.skip_tokenizer_init}, " f"tokenizer_mode={self.model_config.tokenizer_mode}, " f"revision={self.model_config.revision}, " - f"override_neuron_config={self.model_config.override_neuron_config}, " # noqa f"tokenizer_revision={self.model_config.tokenizer_revision}, " f"trust_remote_code={self.model_config.trust_remote_code}, " f"dtype={self.model_config.dtype}, " @@ -3909,12 +3001,13 @@ class VllmConfig: f"load_format={self.load_config.load_format}, " f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}, " # noqa f"pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, " # noqa + f"data_parallel_size={self.parallel_config.data_parallel_size}, " # noqa f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa f"quantization={self.model_config.quantization}, " f"enforce_eager={self.model_config.enforce_eager}, " f"kv_cache_dtype={self.cache_config.cache_dtype}, " f"device_config={self.device_config.device}, " - f"decoding_config={self.decoding_config!r}, " + f"structured_outputs_config={self.structured_outputs_config!r}, " f"observability_config={self.observability_config!r}, " f"seed={self.model_config.seed}, " f"served_model_name={self.model_config.served_model_name}, " @@ -4007,7 +3100,7 @@ def contains_object_print(text): Check if the text looks like a printed Python object, e.g. contains any substring matching the pattern: "at 0xFFFFFFF>" We match against 0x followed by 2-16 hex chars (there's - a max of 16 on a 64 bit system). + a max of 16 on a 64-bit system). Args: text (str): The text to check diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 79761e7844859..4c4e39c37ee50 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -24,7 +24,7 @@ logger = init_logger(__name__) BlockSize = Literal[1, 8, 16, 32, 64, 128] CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] MambaDType = Literal["auto", "float32"] -PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"] +PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"] @config @@ -33,9 +33,8 @@ class CacheConfig: """Configuration for the KV cache.""" block_size: SkipValidation[BlockSize] = None # type: ignore - """Size of a contiguous cache block in number of tokens. This is ignored on - neuron devices and set to `--max-model-len`. On CUDA devices, only block - sizes up to 32 are supported. On HPU devices, block size defaults to 128. + """Size of a contiguous cache block in number of tokens. On CUDA devices, + only block sizes up to 32 are supported. This config has no static default. If left unspecified by the user, it will be set in `Platform.check_and_update_config()` based on the current @@ -64,17 +63,12 @@ class CacheConfig: """Sliding window size for the KV cache. This is primarily set in `ModelConfig` and that value should be manually duplicated here.""" enable_prefix_caching: Optional[bool] = None - """Whether to enable prefix caching. Disabled by default for V0. Enabled by - default for V1.""" - prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin" + """Whether to enable prefix caching. 
Enabled by default for V1.""" + prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256" """Set the hash algorithm for prefix caching:\n - - "builtin" is Python's built-in hash.\n - - "sha256" is collision resistant but with certain overheads. - This option uses Pickle for object serialization before hashing.\n - - "sha256_cbor_64bit" provides a reproducible, cross-language compatible - hash. It serializes objects using canonical CBOR and hashes them with - SHA-256. The resulting hash consists of the lower 64 bits of the SHA-256 - digest.""" + - "sha256" uses Pickle for object serialization before hashing.\n + - "sha256_cbor" provides a reproducible, cross-language compatible hash. It + serializes objects using canonical CBOR and hashes them with SHA-256.""" cpu_offload_gb: float = 0 """The space in GiB to offload to CPU, per GPU. Default is 0, which means no offloading. Intuitively, this argument can be seen as a virtual way to @@ -119,6 +113,15 @@ class CacheConfig: necessary for implementing this optimization in some models (e.g. Gemma3n) """ + kv_cache_memory_bytes: Optional[int] = None + """Size of KV Cache per GPU in bytes. By default, this is set to None + and vllm can automatically infer the kv cache size based on + gpu_memory_utilization. However, users may want to manually specify + the kv cache memory size. kv_cache_memory_bytes allows more fine-grain + control of how much memory gets used when compared with using + gpu_memory_memory_utilization. Note that kv_cache_memory_bytes + (when not-None) ignores gpu_memory_utilization""" + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -145,19 +148,12 @@ class CacheConfig: self._verify_cache_dtype() self._verify_prefix_caching() - self._verify_kv_sharing_fast_prefill() def metrics_info(self): # convert cache_config to dict(key: str, value: str) for prometheus # metrics info return {key: str(value) for key, value in self.__dict__.items()} - def _verify_kv_sharing_fast_prefill(self) -> None: - if self.kv_sharing_fast_prefill and not envs.VLLM_USE_V1: - raise NotImplementedError( - "Fast prefill optimization for KV sharing is not supported " - "in V0 currently.") - @model_validator(mode='after') def _verify_args(self) -> Self: if self.cpu_offload_gb < 0: diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 5c3b220016360..f8ccc20222615 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -234,7 +234,7 @@ class CompilationConfig: - FULL_AND_PIECEWISE. PIECEWISE mode build piecewise cudagraph only, keeping the cudagraph - incompatiable ops (i.e. some attention ops) outside the cudagraph + incompatible ops (i.e. some attention ops) outside the cudagraph for general flexibility. This is the default mode. @@ -340,6 +340,8 @@ class CompilationConfig: "vllm.mamba_mixer", "vllm.short_conv", "vllm.linear_attention", + "vllm.plamo2_mamba_mixer", + "vllm.gdn_attention", ] def compute_hash(self) -> str: @@ -545,7 +547,8 @@ class CompilationConfig: # full cudagraph outside the fx graph. This reduces some cpu # overhead when the runtime batch_size is not cudagraph captured. # see https://github.com/vllm-project/vllm/pull/20059 for details. - self.splitting_ops = self._attention_ops + # make a copy to avoid mutating the class-level list via reference. 
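Note: the rewritten `_set_cudagraph_sizes` docstring earlier in this diff describes the default candidate list of CUDA graph capture sizes. A small worked example of that arithmetic (not the actual implementation, which also filters by `max_num_batched_tokens`):

```python
# Default CUDA graph capture candidates for max_num_seqs = 256, per the
# updated _set_cudagraph_sizes docstring: 1, 2, 4, then multiples of 8.
max_num_seqs = 256
max_graph_size = min(max_num_seqs * 2, 512)                     # 512
candidates = [1, 2, 4] + [8 * i for i in range(1, max_graph_size // 8 + 1)]
print(candidates[:8])   # [1, 2, 4, 8, 16, 24, 32, 40]
print(candidates[-1])   # 512
```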
+ self.splitting_ops = list(self._attention_ops) elif len(self.splitting_ops) == 0: logger.warning_once("Using piecewise compilation with empty " "splitting_ops.") @@ -560,6 +563,18 @@ class CompilationConfig: self.cudagraph_mode = CUDAGraphMode.FULL self.splitting_ops = [] + if envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput": + # exclude MoE dispatch/combine from capture by ensuring + # piecewise splitting includes them, so communication remains + # outside CUDA graphs while compute can still be graphed. + moe_ops = [ + "vllm.moe_forward", + "vllm.moe_forward_shared", + ] + for op in moe_ops: + if op not in self.splitting_ops: + self.splitting_ops.append(op) + def splitting_ops_contain_attention(self) -> bool: return self.splitting_ops is not None and all( op in self.splitting_ops for op in self._attention_ops) diff --git a/vllm/config/kv_events.py b/vllm/config/kv_events.py new file mode 100644 index 0000000000000..1c6bdffa1281d --- /dev/null +++ b/vllm/config/kv_events.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config + + +@config +@dataclass +class KVEventsConfig: + """Configuration for KV event publishing.""" + + enable_kv_cache_events: bool = False + """If True, enable KV cache events for tracking block storage and removal. + Events can be published externally by zmq using the event publisher config. + """ + + publisher: str = "null" + """The publisher to use for publishing kv events. Can be "null", "zmq". + """ + + endpoint: str = "tcp://*:5557" + """The zmq endpoint to use for publishing kv events. + """ + + replay_endpoint: Optional[str] = None + """The zmq endpoint to use for replaying kv events. + """ + + buffer_steps: int = 10_000 + """The number of steps to cache for replay endpoint. Will only save + events from the last N steps for the replay endpoint. + """ + + hwm: int = 100_000 + """The zmq high water mark for the event publisher. After queueing N events, + events will start dropping if the consumer is not keeping up. + """ + + max_queue_size: int = 100_000 + """The maximum number of events to queue while waiting for publishing. + """ + + topic: str = "" + """The topic to use for the event publisher. Consumers can subscribe to + this topic to receive events. + """ diff --git a/vllm/config/kv_transfer.py b/vllm/config/kv_transfer.py new file mode 100644 index 0000000000000..9abf4acacfe81 --- /dev/null +++ b/vllm/config/kv_transfer.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +import uuid +from dataclasses import field +from typing import Any, Literal, Optional, get_args + +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config + +KVProducer = Literal["kv_producer", "kv_both"] +KVConsumer = Literal["kv_consumer", "kv_both"] +KVRole = Literal[KVProducer, KVConsumer] + + +@config +@dataclass +class KVTransferConfig: + """Configuration for distributed KV cache transfer.""" + + kv_connector: Optional[str] = None + """The KV connector for vLLM to transmit KV caches between vLLM instances. + """ + + engine_id: Optional[str] = None + """The engine id for KV transfers.""" + + kv_buffer_device: Optional[str] = "cuda" + """The device used by kv connector to buffer the KV cache. 
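Note: `KVEventsConfig` above is moved verbatim into `vllm/config/kv_events.py`. A minimal sketch of a ZMQ publisher setup; the endpoint, replay endpoint, and topic values are illustrative, the field names come from the dataclass above.

```python
from vllm.config.kv_events import KVEventsConfig

events_config = KVEventsConfig(
    enable_kv_cache_events=True,
    publisher="zmq",                 # "null" (the default) disables publishing
    endpoint="tcp://*:5557",
    replay_endpoint="tcp://*:5558",  # optional replay socket for late subscribers
    topic="kv-events",
)
```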
+ Currently only support 'cuda'.""" + + kv_buffer_size: float = 1e9 + """The buffer size for TorchDistributedConnector. Measured in number of + bytes. Recommended value: 1e9 (about 1GB).""" + + kv_role: Optional[KVRole] = None + """Whether this vLLM instance produces, consumes KV cache, or both. Choices + are 'kv_producer', 'kv_consumer', and 'kv_both'.""" + + kv_rank: Optional[int] = None + """The rank of this vLLM instance in the KV cache transfer. Typical value: + 0 for prefill instance, 1 for decode instance. + Currently only 1P1D is supported.""" + + kv_parallel_size: int = 1 + """The number of parallel instances for KV cache transfer. For + P2pNcclConnector, this should be 2.""" + + kv_ip: str = "127.0.0.1" + """The KV connector ip, used to build distributed connection.""" + + kv_port: int = 14579 + """The KV connector port, used to build distributed connection.""" + + kv_connector_extra_config: dict[str, Any] = field(default_factory=dict) + """any extra config that the connector may need.""" + + kv_connector_module_path: Optional[str] = None + """The Python module path to dynamically load the KV connector from. + Only supported in V1.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self) -> None: + if self.engine_id is None: + self.engine_id = str(uuid.uuid4()) + + if self.kv_role is not None and self.kv_role not in get_args(KVRole): + raise ValueError(f"Unsupported kv_role: {self.kv_role}. 
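Note: `KVTransferConfig`, relocated above to `vllm/config/kv_transfer.py`, keeps the 1P1D-oriented fields. A hedged sketch of a producer/consumer pair; the connector name is taken from the `kv_parallel_size` docstring above, and the addresses and ports are illustrative.

```python
from vllm.config.kv_transfer import KVTransferConfig

# Prefill instance: produces KV cache.
prefill_kv = KVTransferConfig(
    kv_connector="P2pNcclConnector",
    kv_role="kv_producer",
    kv_rank=0,                 # 0 for the prefill instance
    kv_parallel_size=2,
    kv_ip="127.0.0.1",
    kv_port=14579,
)

# Decode instance: consumes KV cache.
decode_kv = KVTransferConfig(
    kv_connector="P2pNcclConnector",
    kv_role="kv_consumer",
    kv_rank=1,                 # 1 for the decode instance
    kv_parallel_size=2,
)

assert prefill_kv.is_kv_producer and decode_kv.is_kv_consumer
```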
" + f"Supported roles are {get_args(KVRole)}") + + if self.kv_connector is not None and self.kv_role is None: + raise ValueError("Please specify kv_disagg_role when kv_connector " + f"is set, supported roles are {get_args(KVRole)}") + + @property + def is_kv_transfer_instance(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in get_args(KVRole) + + @property + def is_kv_producer(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in get_args(KVProducer) + + @property + def is_kv_consumer(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in get_args(KVConsumer) + + def get_from_extra_config(self, key, default) -> Any: + return self.kv_connector_extra_config.get(key, default) diff --git a/vllm/config/load.py b/vllm/config/load.py new file mode 100644 index 0000000000000..26ffec23ad5c6 --- /dev/null +++ b/vllm/config/load.py @@ -0,0 +1,113 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +from dataclasses import field +from typing import TYPE_CHECKING, Any, Optional, Union + +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config +from vllm.logger import init_logger + +if TYPE_CHECKING: + from vllm.model_executor.model_loader import LoadFormats + from vllm.model_executor.model_loader.tensorizer import TensorizerConfig +else: + LoadFormats = Any + TensorizerConfig = Any + +logger = init_logger(__name__) + + +@config +@dataclass +class LoadConfig: + """Configuration for loading the model weights.""" + + load_format: Union[str, LoadFormats] = "auto" + """The format of the model weights to load:\n + - "auto" will try to load the weights in the safetensors format and fall + back to the pytorch bin format if safetensors format is not available.\n + - "pt" will load the weights in the pytorch bin format.\n + - "safetensors" will load the weights in the safetensors format.\n + - "npcache" will load the weights in pytorch format and store a numpy cache + to speed up the loading.\n + - "dummy" will initialize the weights with random values, which is mainly + for profiling.\n + - "tensorizer" will use CoreWeave's tensorizer library for fast weight + loading. See the Tensorize vLLM Model script in the Examples section for + more information.\n + - "runai_streamer" will load the Safetensors weights using Run:ai Model + Streamer.\n + - "bitsandbytes" will load the weights using bitsandbytes quantization.\n + - "sharded_state" will load weights from pre-sharded checkpoint files, + supporting efficient loading of tensor-parallel models.\n + - "gguf" will load weights from GGUF format files (details specified in + https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n + - "mistral" will load weights from consolidated safetensors files used by + Mistral models. + - Other custom values can be supported via plugins.""" + download_dir: Optional[str] = None + """Directory to download and load the weights, default to the default + cache directory of Hugging Face.""" + safetensors_load_strategy: str = "lazy" + """Specifies the loading strategy for safetensors weights. + - "lazy" (default): Weights are memory-mapped from the file. This enables + on-demand loading and is highly efficient for models on local storage. + - "eager": The entire file is read into CPU memory upfront before loading. 
+ This is recommended for models on network filesystems (e.g., Lustre, NFS) + as it avoids inefficient random reads, significantly speeding up model + initialization. However, it uses more CPU RAM. + """ + model_loader_extra_config: Union[dict, TensorizerConfig] = field( + default_factory=dict) + """Extra config for model loader. This will be passed to the model loader + corresponding to the chosen load_format.""" + device: Optional[str] = None + """Device to which model weights will be loaded, default to + device_config.device""" + ignore_patterns: Optional[Union[list[str], str]] = None + """The list of patterns to ignore when loading the model. Default to + "original/**/*" to avoid repeated loading of llama's checkpoints.""" + use_tqdm_on_load: bool = True + """Whether to enable tqdm for showing progress bar when loading model + weights.""" + pt_load_map_location: Union[str, dict[str, str]] = "cpu" + """ + pt_load_map_location: the map location for loading pytorch checkpoint, to + support loading checkpoints can only be loaded on certain devices like + "cuda", this is equivalent to {"": "cuda"}. Another supported format is + mapping from different devices like from GPU 1 to GPU 0: + {"cuda:1": "cuda:0"}. Note that when passed from command line, the strings + in dictionary needs to be double quoted for json parsing. For more details, + see original doc for `map_location` in https://pytorch.org/docs/stable/generated/torch.load.html + """ + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. 
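Note: the new `vllm/config/load.py` carries over `LoadConfig` with its documented load formats and the `pt_load_map_location` device remapping. A small sketch with placeholder paths; the remapping example mirrors the one in the docstring above.

```python
from vllm.config.load import LoadConfig

# Illustrative: a PyTorch checkpoint saved on cuda:1, remapped onto cuda:0.
load_config = LoadConfig(
    load_format="pt",
    download_dir="/models/cache",               # placeholder directory
    safetensors_load_strategy="lazy",           # default; "eager" for network filesystems
    pt_load_map_location={"cuda:1": "cuda:0"},
)
```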
+ factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + self.load_format = self.load_format.lower() + if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: + logger.info( + "Ignoring the following patterns when downloading weights: %s", + self.ignore_patterns) + else: + self.ignore_patterns = ["original/**/*"] diff --git a/vllm/config/lora.py b/vllm/config/lora.py new file mode 100644 index 0000000000000..3fe28f5dad4fa --- /dev/null +++ b/vllm/config/lora.py @@ -0,0 +1,132 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, Union + +import torch +from pydantic import ConfigDict +from pydantic.dataclasses import dataclass + +import vllm.envs as envs +from vllm.config.utils import config +from vllm.logger import init_logger +from vllm.platforms import current_platform + +if TYPE_CHECKING: + from vllm.config import ModelConfig + from vllm.config.cache import CacheConfig +else: + ModelConfig = Any + CacheConfig = Any + +logger = init_logger(__name__) + +LoRADType = Literal["auto", "float16", "bfloat16"] + + +@config +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +class LoRAConfig: + """Configuration for LoRA.""" + + max_lora_rank: int = 16 + """Max LoRA rank.""" + max_loras: int = 1 + """Max number of LoRAs in a single batch.""" + fully_sharded_loras: bool = False + """By default, only half of the LoRA computation is sharded with tensor + parallelism. Enabling this will use the fully sharded layers. At high + sequence length, max rank or tensor parallel size, this is likely faster. + """ + max_cpu_loras: Optional[int] = None + """Maximum number of LoRAs to store in CPU memory. Must be >= than + `max_loras`.""" + lora_dtype: Union[torch.dtype, LoRADType] = "auto" + """Data type for LoRA. If auto, will default to base model dtype.""" + lora_extra_vocab_size: int = 256 + """(Deprecated) Maximum size of extra vocabulary that can be present in a + LoRA adapter. Will be removed in v0.12.0.""" + lora_vocab_padding_size: ClassVar[int] = current_platform\ + .get_lora_vocab_padding_size() + default_mm_loras: Optional[dict[str, str]] = None + """Dictionary mapping specific modalities to LoRA model paths; this field + is only applicable to multimodal models and should be leveraged when a + model always expects a LoRA to be active when a given modality is present. + Note that currently, if a request provides multiple additional + modalities, each of which have their own LoRA, we do NOT apply + default_mm_loras because we currently only support one lora adapter + per prompt. When run in offline mode, the lora IDs for n modalities + will be automatically assigned to 1-n with the names of the modalities + in alphabetic order.""" + bias_enabled: bool = False + """[DEPRECATED] Enable bias for LoRA adapters. This option will be + removed in v0.12.0.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. 
+ """ + factors: list[Any] = [] + factors.append(self.max_lora_rank) + factors.append(self.max_loras) + factors.append(self.fully_sharded_loras) + factors.append(self.lora_dtype) + factors.append(self.lora_extra_vocab_size) + factors.append(self.lora_vocab_padding_size) + factors.append(self.bias_enabled) + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + # Deprecation warning for lora_extra_vocab_size + logger.warning( + "`lora_extra_vocab_size` is deprecated and will be removed " + "in v0.12.0. Additional vocabulary support for " + "LoRA adapters is being phased out.") + + # Deprecation warning for enable_lora_bias + if self.bias_enabled: + logger.warning("`enable_lora_bias` is deprecated " + "and will be removed in v0.12.0.") + + # Setting the maximum rank to 512 should be able to satisfy the vast + # majority of applications. + possible_max_ranks = (8, 16, 32, 64, 128, 256, 320, 512) + possible_lora_extra_vocab_size = (256, 512) + if self.max_lora_rank not in possible_max_ranks: + raise ValueError( + f"max_lora_rank ({self.max_lora_rank}) must be one of " + f"{possible_max_ranks}.") + if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size: + raise ValueError( + f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) " + f"must be one of {possible_lora_extra_vocab_size}.") + if self.max_loras < 1: + raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.") + if self.max_cpu_loras is None: + self.max_cpu_loras = self.max_loras + elif self.max_cpu_loras < self.max_loras: + raise ValueError( + f"max_cpu_loras ({self.max_cpu_loras}) must be >= " + f"max_loras ({self.max_loras})") + + def verify_with_cache_config(self, cache_config: CacheConfig): + if cache_config.cpu_offload_gb > 0 and not envs.VLLM_USE_V1: + raise ValueError( + "V0 LoRA does not support CPU offload, please use V1.") + + def verify_with_model_config(self, model_config: ModelConfig): + if self.lora_dtype in (None, "auto"): + self.lora_dtype = model_config.dtype + elif isinstance(self.lora_dtype, str): + self.lora_dtype = getattr(torch, self.lora_dtype) diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py new file mode 100644 index 0000000000000..1b93b520f33f9 --- /dev/null +++ b/vllm/config/multimodal.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +from collections.abc import Mapping +from dataclasses import field +from typing import Any, Literal, Optional + +from pydantic.dataclasses import dataclass + +import vllm.envs as envs +from vllm.config.utils import config + +MMEncoderTPMode = Literal["weights", "data"] +MMCacheType = Literal["shm", "lru"] + + +@config +@dataclass +class MultiModalConfig: + """Controls the behavior of multimodal models.""" + + limit_per_prompt: dict[str, int] = field(default_factory=dict) + """The maximum number of input items allowed per prompt for each modality. + Defaults to 1 (V0) or 999 (V1) for each modality. + + For example, to allow up to 16 images and 2 videos per prompt: + `{"image": 16, "video": 2}`""" + media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) + """Additional args passed to process media inputs, keyed by modalities. 
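Note: the relocated `LoRAConfig` above enforces that `max_lora_rank` is one of a fixed set and that `max_cpu_loras >= max_loras`. A minimal construction sketch with illustrative values that satisfy those checks:

```python
from vllm.config.lora import LoRAConfig

# Illustrative values; max_lora_rank must be one of the ranks allowed in
# __post_init__ above, and max_cpu_loras must be >= max_loras.
lora_config = LoRAConfig(
    max_lora_rank=64,
    max_loras=4,            # adapters active in a single batch
    max_cpu_loras=8,        # adapters kept resident in CPU memory
    lora_dtype="bfloat16",  # resolved against the base model dtype later
)
```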
+ For example, to set num_frames for video, set + `--media-io-kwargs '{"video": {"num_frames": 40} }'`""" + mm_processor_kwargs: Optional[dict[str, object]] = None + """Arguments to be forwarded to the model's processor for multi-modal data, + e.g., image processor. Overrides for the multi-modal processor obtained + from `transformers.AutoProcessor.from_pretrained`. + + The available overrides depend on the model that is being run. + + For example, for Phi-3-Vision: + `{"num_crops": 4}`.""" + mm_processor_cache_gb: float = 4 + """The size (in GiB) of the multi-modal processor cache, which is used to + avoid re-processing past multi-modal inputs. + + This cache is duplicated for each API process and engine core process, + resulting in a total memory usage of + `mm_processor_cache_gb * (api_server_count + data_parallel_size)`. + + Set to `0` to disable this cache completely (not recommended).""" + mm_processor_cache_type: MMCacheType = "lru" + """Type of cache to use for the multi-modal preprocessor/mapper. If `shm`, + use shared memory FIFO cache. If `lru`, use mirrored LRU cache.""" + mm_shm_cache_max_object_size_mb: int = 128 + """Size limit (in MiB) for each object stored in the multi-modal processor + shared memory cache. Only effective when `mm_processor_cache_type` is + `"shm"`.""" + mm_encoder_tp_mode: MMEncoderTPMode = "weights" + """Indicates how to optimize multi-modal encoder inference using tensor + parallelism (TP). + + - `"weights"`: Within the same vLLM engine, split the weights of + each layer across TP ranks. (default TP behavior)\n + - `"data"`: Within the same vLLM engine, split the batched input data + across TP ranks to process the data in parallel, while hosting + the full weights on each TP rank. + This batch-level DP is not to be confused with API request-level + DP (which is controlled by `--data-parallel-size`). + This is only supported on a per-model basis and falls back to + `"weights"` if the encoder does not support DP.""" + interleave_mm_strings: bool = False + """Enable fully interleaved support for multimodal prompts, while using + --chat-template-content-format=string.""" + skip_mm_profiling: bool = False + """When enabled, skips multimodal memory profiling and only profiles with + language backbone model during engine initialization. + + This reduces engine startup time but shifts the responsibility to users for + estimating the peak memory usage of the activation of multimodal encoder and + embedding cache.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def get_limit_per_prompt(self, modality: str) -> int: + """ + Get the maximum number of input items allowed per prompt + for the given modality. 
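Note: the `MultiModalConfig` docstrings above already carry worked values (image/video limits, `num_frames`, `num_crops`); the sketch below gathers them into one construction using the new module path from this diff.

```python
from vllm.config.multimodal import MultiModalConfig

# Values mirror the examples given in the docstrings above.
mm_config = MultiModalConfig(
    limit_per_prompt={"image": 16, "video": 2},
    media_io_kwargs={"video": {"num_frames": 40}},
    mm_processor_kwargs={"num_crops": 4},   # e.g. a Phi-3-Vision processor override
    mm_encoder_tp_mode="data",              # batch-level DP for the encoder, if supported
)
print(mm_config.get_limit_per_prompt("image"))  # 16
```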
+ """ + return self.limit_per_prompt.get( + modality, + 999 if envs.VLLM_USE_V1 else 1, + ) + + def merge_mm_processor_kwargs( + self, + inference_kwargs: Mapping[str, object], + ) -> dict[str, object]: + """ + Get the keyword arguments to pass to the multi-modal processor + according to the extra arguments passed during inference. + """ + kwargs = self.mm_processor_kwargs or {} + return kwargs | dict(inference_kwargs) diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 9ea883d4a03cd..8e92e54a96780 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -29,6 +29,7 @@ else: logger = init_logger(__name__) +ExpertPlacementStrategy = Literal["linear", "round_robin"] DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"] @@ -87,7 +88,7 @@ class ParallelConfig: data_parallel_external_lb: bool = False """Whether to use "external" DP LB mode. Applies only to online serving and when data_parallel_size > 0. This is useful for a "one-pod-per-rank" - wide-EP setup in Kuberentes. Set implicitly when --data-parallel-rank + wide-EP setup in Kubernetes. Set implicitly when --data-parallel-rank is provided explicitly to vllm serve.""" data_parallel_hybrid_lb: bool = False """Whether to use "hybrid" DP LB mode. Applies only to online serving @@ -102,6 +103,15 @@ class ParallelConfig: """Enable expert parallelism load balancing for MoE layers.""" eplb_config: EPLBConfig = field(default_factory=EPLBConfig) """Expert parallelism configuration.""" + expert_placement_strategy: ExpertPlacementStrategy = "linear" + """The expert placement strategy for MoE layers:\n + - "linear": Experts are placed in a contiguous manner. For example, with 4 + experts and 2 ranks, rank 0 will have experts [0, 1] and rank 1 will have + experts [2, 3].\n + - "round_robin": Experts are placed in a round-robin manner. For example, + with 4 experts and 2 ranks, rank 0 will have experts [0, 2] and rank 1 + will have experts [1, 3]. This strategy can help improve load balancing + for grouped expert models with no redundant experts.""" num_redundant_experts: Optional[int] = None """`num_redundant_experts` is deprecated and has been replaced with `eplb_config.num_redundant_experts`. This will be removed in v0.12.0. @@ -127,6 +137,14 @@ class ParallelConfig: disable_custom_all_reduce: bool = False """Disable the custom all-reduce kernel and fall back to NCCL.""" + enable_dbo: bool = False + """Enable microbatching for the model executor.""" + + dbo_decode_token_threshold: int = 32 + """The threshold for microbatching. If the number of tokens in the + request is greater than this threshold, microbatching will be used. + Otherwise, the request will be processed in a single batch.""" + ray_workers_use_nsight: bool = False """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.""" @@ -170,6 +188,11 @@ class ParallelConfig: Set to be private as it's not intended to be configured by users. """ + decode_context_parallel_size: int = 1 + """Number of decode context parallel groups, because the world size does + not change by dcp, it simply reuse the GPUs of TP group, and tp_size + needs to be divisible by dcp_size.""" + @property def world_size_across_dp(self) -> int: """world_size_across_dp is TPxPPxDP, it is the size of the world @@ -363,8 +386,10 @@ class ParallelConfig: else: if self.eplb_config.num_redundant_experts != 0: raise ValueError( - "num_redundant_experts should be used with EPLB." 
- f"{self.eplb_config.num_redundant_experts}.") + "num_redundant_experts is set to " + f"{self.eplb_config.num_redundant_experts} but EPLB is not " + "enabled. Either enable EPLB or unset " + "num_redundant_experts.") if self.distributed_executor_backend is None and self.world_size > 1: # We use multiprocessing by default if world_size fits on the # current node and we aren't in a ray placement group. @@ -372,10 +397,7 @@ class ParallelConfig: from vllm.executor import ray_utils backend: DistributedExecutorBackend = "mp" ray_found = ray_utils.ray_is_available() - if current_platform.is_neuron(): - # neuron uses single process to control multiple devices - backend = "uni" - elif current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD: + if current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD: backend = "uni" elif (current_platform.is_cuda() and cuda_device_count_stateless() < self.world_size): diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py new file mode 100644 index 0000000000000..2c861723c3966 --- /dev/null +++ b/vllm/config/speculative.py @@ -0,0 +1,559 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import ast +import hashlib +from typing import TYPE_CHECKING, Any, Literal, Optional + +from pydantic import SkipValidation, model_validator +from pydantic.dataclasses import dataclass +from typing_extensions import Self + +import vllm.envs as envs +from vllm.config.parallel import ParallelConfig +from vllm.config.utils import config +from vllm.logger import init_logger +from vllm.utils import LazyLoader + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + import vllm.model_executor.layers.quantization as me_quant + from vllm.config import ModelConfig +else: + PretrainedConfig = Any + ModelConfig = Any + + me_quant = LazyLoader("model_executor", globals(), + "vllm.model_executor.layers.quantization") + +logger = init_logger(__name__) + +SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa", + "mlp_speculator", "draft_model", "deepseek_mtp", + "ernie_mtp", "qwen3_next_mtp", "mimo_mtp"] + + +@config +@dataclass +class SpeculativeConfig: + """Configuration for speculative decoding.""" + + # General speculative decoding control + num_speculative_tokens: SkipValidation[int] = None # type: ignore + """The number of speculative tokens, if provided. It will default to the + number in the draft model config if present, otherwise, it is required.""" + model: Optional[str] = None + """The name of the draft model, eagle head, or additional weights, if + provided.""" + method: Optional[SpeculativeMethod] = None + """The name of the speculative method to use. If users provide and set the + `model` param, the speculative method type will be detected automatically + if possible, if `model` param is not provided, the method name must be + provided. + + If using `ngram` method, the related configuration `prompt_lookup_max` and + `prompt_lookup_min` should be considered.""" + draft_tensor_parallel_size: Optional[int] = None + """The degree of the tensor parallelism for the draft model. Can only be 1 + or the same as the target model's tensor parallel size.""" + disable_logprobs: bool = True + """If set to True, token log probabilities are not returned during + speculative decoding. 
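Note: the `expert_placement_strategy` docstring earlier in this hunk contrasts "linear" and "round_robin" with a 4-expert/2-rank example. A throwaway sketch of the two layouts for 8 experts on 4 ranks (not vLLM's placement code, just the arithmetic the docstring describes):

```python
num_experts, num_ranks = 8, 4
per_rank = num_experts // num_ranks

linear = [[e for e in range(num_experts) if e // per_rank == r]
          for r in range(num_ranks)]
round_robin = [[e for e in range(num_experts) if e % num_ranks == r]
               for r in range(num_ranks)]

print(linear)       # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(round_robin)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```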
If set to False, token log probabilities are returned + according to the log probability settings in SamplingParams.""" + + # Draft model configuration + quantization: Optional[me_quant.QuantizationMethods] = None + """Quantization method that was used to quantize the draft model weights. + If `None`, we assume the model weights are not quantized. Note that it only + takes effect when using the draft model-based speculative method.""" + max_model_len: Optional[int] = None + """The maximum model length of the draft model. Used when testing the + ability to skip speculation for some sequences.""" + revision: Optional[str] = None + """The specific model version to use for the draft model. It can be a + branch name, a tag name, or a commit id. If unspecified, will use the + default version.""" + code_revision: Optional[str] = None + """The specific revision to use for the draft model code on Hugging Face + Hub. It can be a branch name, a tag name, or a commit id. If unspecified, + will use the default version.""" + + # Advanced control + disable_by_batch_size: Optional[int] = None + """Disable speculative decoding for new incoming requests when the number + of enqueued requests is larger than this value, if provided.""" + disable_padded_drafter_batch: bool = False + """Disable input padding for speculative decoding. If set to True, + speculative input batches can contain sequences of different lengths, + which may only be supported by certain attention backends. This currently + only affects the EAGLE method of speculation.""" + + # Ngram proposer configuration + prompt_lookup_max: Optional[int] = None + """Maximum size of ngram token window when using Ngram proposer, required + when method is set to ngram.""" + prompt_lookup_min: Optional[int] = None + """Minimum size of ngram token window when using Ngram proposer, if + provided. Defaults to 1.""" + + speculative_token_tree: Optional[str] = None + """Specifies the tree structure for speculative token generation. + """ + # required configuration params passed from engine + target_model_config: SkipValidation[ModelConfig] = None # type: ignore + """The configuration of the target model.""" + target_parallel_config: SkipValidation[ + ParallelConfig] = None # type: ignore + """The parallel configuration for the target model.""" + enable_chunked_prefill: SkipValidation[bool] = None # type: ignore + """Whether vLLM is configured to use chunked prefill or not. Used for + raising an error since it's not yet compatible with speculative decode.""" + disable_log_stats: SkipValidation[bool] = None # type: ignore + """Whether to disable the periodic printing of stage times in speculative + decoding.""" + + # params generated in the post-init stage + draft_model_config: SkipValidation[ModelConfig] = None # type: ignore + """The configuration of the draft model initialized internal.""" + draft_parallel_config: SkipValidation[ + ParallelConfig] = None # type: ignore + """The parallel configuration for the draft model initialized internal.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. 
+ """ + factors: list[Any] = [] + # Eagle3 affects the computation graph because it returns intermediate + # hidden states in addition to the final hidden state. + factors.append(self.method == "eagle3") + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + @staticmethod + def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: + if hf_config.model_type == "deepseek_v3": + hf_config.model_type = "deepseek_mtp" + if hf_config.model_type == "deepseek_mtp": + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "n_predict": n_predict, + "architectures": ["DeepSeekMTPModel"] + }) + + if hf_config.architectures[0] == "MiMoForCausalLM": + hf_config.model_type = "mimo_mtp" + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "num_hidden_layers": 0, + "n_predict": n_predict, + "architectures": ["MiMoMTPModel"] + }) + + if hf_config.architectures[0] == "Glm4MoeForCausalLM": + hf_config.model_type = "glm4_moe_mtp" + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "num_hidden_layers": 0, + "n_predict": n_predict, + "architectures": ["Glm4MoeMTPModel"] + }) + + if hf_config.model_type == "ernie4_5_moe": + hf_config.model_type = "ernie_mtp" + if hf_config.model_type == "ernie_mtp": + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "n_predict": n_predict, + "architectures": ["ErnieMTPModel"] + }) + + if hf_config.model_type == "qwen3_next": + hf_config.model_type = "qwen3_next_mtp" + if hf_config.model_type == "qwen3_next_mtp": + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "n_predict": n_predict, + "architectures": ["Qwen3NextMTP"] + }) + + return hf_config + + def __post_init__(self): + + # Note: "method" is a new parameter that helps to extend the + # configuration of non-model-based proposers, and the "model" parameter + # will be used to set the draft model, eagle head, or additional weight + # when needed. If users do not specify "method", the speculative method + # will be detected automatically if possible. If the speculative method + # can not be detected, it will be considered as the "draft_model" by + # default. + + if self.model is None and self.num_speculative_tokens is not None: + # TODO(Shangming): Refactor mtp configuration logic when supporting + # mtp acceleration for more models besides deepseek_v3 + if self.target_model_config and \ + (self.target_model_config.hf_text_config.model_type \ + == "deepseek_v3" or + self.target_model_config.hf_text_config.model_type in + ("mimo","ernie4_5_moe", "qwen3_next")): + # use the draft model from the same model: + self.model = self.target_model_config.model + # Align the quantization of draft model for cases such as + # --quantization fp8 with a bf16 checkpoint. 
+ if not self.quantization: + self.quantization = self.target_model_config.quantization + elif self.method in ("ngram", "[ngram]"): + self.model = "ngram" + else: + raise ValueError("num_speculative_tokens was provided without " + "speculative model.") + + # Automatically configure the method for ngram when "model" is used + # instead of "method" + if self.method is None and (self.model is not None + and self.model in ("ngram", "[ngram]")): + self.method = "ngram" + + if self.method in ("ngram", "[ngram]"): + # Unified to "ngram" internally + self.method = "ngram" + # Set default values if not provided + if (self.prompt_lookup_min is None + and self.prompt_lookup_max is None): + # TODO(woosuk): Tune these values. They are arbitrarily chosen. + self.prompt_lookup_min = 5 + self.prompt_lookup_max = 5 + elif self.prompt_lookup_min is None: + assert self.prompt_lookup_max is not None + self.prompt_lookup_min = self.prompt_lookup_max + elif self.prompt_lookup_max is None: + assert self.prompt_lookup_min is not None + self.prompt_lookup_max = self.prompt_lookup_min + + # Validate values + if self.prompt_lookup_min < 1: + raise ValueError( + f"prompt_lookup_min={self.prompt_lookup_min} must be > 0") + if self.prompt_lookup_max < 1: + raise ValueError( + f"prompt_lookup_max={self.prompt_lookup_max} must be > 0") + if self.prompt_lookup_min > self.prompt_lookup_max: + raise ValueError( + f"prompt_lookup_min={self.prompt_lookup_min} must " + f"be <= prompt_lookup_max={self.prompt_lookup_max}") + + # TODO: current we still need extract vocab_size from target model + # config, in future, we may try refactor it out, and set + # draft related config as None here. + self.draft_model_config = self.target_model_config + self.draft_parallel_config = self.target_parallel_config + else: + self.prompt_lookup_max = 0 + self.prompt_lookup_min = 0 + + if self.model is not None: + # TODO: Move this import to the top once `ModelConfig` + # lives in `vllm.config.model`. + from vllm.config import ModelConfig + self.draft_model_config = ModelConfig( + model=self.model, + runner="draft", + tokenizer=self.target_model_config.tokenizer, + tokenizer_mode=self.target_model_config.tokenizer_mode, + trust_remote_code=self.target_model_config. + trust_remote_code, + allowed_local_media_path=self.target_model_config. + allowed_local_media_path, + dtype=self.target_model_config.dtype, + seed=self.target_model_config.seed, + revision=self.revision, + code_revision=self.code_revision, + tokenizer_revision=self.target_model_config. + tokenizer_revision, + spec_target_max_model_len=self.target_model_config. + max_model_len, + quantization=self.quantization, + enforce_eager=self.target_model_config.enforce_eager, + max_seq_len_to_capture=self.target_model_config. 
+ max_seq_len_to_capture, + max_logprobs=self.target_model_config.max_logprobs, + hf_overrides=SpeculativeConfig.hf_config_override, + ) + + # Automatically detect the method + if self.method in ('eagle', 'eagle3'): + pass + # examples: + # yuhuili/EAGLE-LLaMA3-Instruct-8B + # yuhuili/EAGLE3-LLaMA3.1-Instruct-8B + # AngelSlim/Qwen3-8B_eagle3 + elif "eagle-" in self.draft_model_config.model.lower(): + self.method = "eagle" + elif "eagle3" in self.draft_model_config.model.lower(): + self.method = "eagle3" + elif self.draft_model_config.hf_config.model_type == "medusa": + self.method = "medusa" + elif (self.draft_model_config.hf_config.model_type == + "mlp_speculator"): + self.method = "mlp_speculator" + elif (self.draft_model_config.hf_config.model_type + in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")): + self.method = "deepseek_mtp" + if self.num_speculative_tokens > 1: + logger.warning( + "All Deepseek MTP models only have " \ + "one layer. Might need some code changes " \ + "to support multiple layers." + ) + elif (self.draft_model_config.hf_config.model_type == + "ernie_mtp"): + self.method = "ernie_mtp" + if self.num_speculative_tokens > 1: + logger.warning( + "All Ernie MTP models only have " \ + "one layer. Might need some code changes " \ + "to support multiple layers." + ) + elif (self.draft_model_config.hf_config.model_type == + "qwen3_next_mtp"): + self.method = "qwen3_next_mtp" + if self.num_speculative_tokens > 1: + logger.warning( + "All Qwen3Next MTP models only have " \ + "one layer. Might need some code changes " \ + "to support multiple layers." + ) + else: + self.method = "draft_model" + raise NotImplementedError( + "Speculative decoding with draft model is not " + "supported yet. Please consider using other " + "speculative decoding methods such as ngram, medusa, " + "eagle, or deepseek_mtp.") + + # Replace hf_config for EAGLE draft_model + if self.method in ("eagle", "eagle3"): + if self.enable_chunked_prefill and not envs.VLLM_USE_V1: + raise ValueError( + "Chunked prefill and EAGLE are not compatible " + "when using V0.") + + from vllm.transformers_utils.configs import ( + SpeculatorsConfig) + from vllm.transformers_utils.configs.eagle import ( + EAGLEConfig) + + if isinstance(self.draft_model_config.hf_config, + (EAGLEConfig, SpeculatorsConfig)): + pass + else: + eagle_config = EAGLEConfig( + self.draft_model_config.hf_config, + method=self.method, + model_type="eagle") + self.draft_model_config.hf_config = eagle_config + + if (self.num_speculative_tokens is not None + and hasattr(self.draft_model_config.hf_config, + "num_lookahead_tokens")): + self.draft_model_config.hf_config.num_lookahead_tokens = \ + self.num_speculative_tokens + + n_predict = getattr(self.draft_model_config.hf_config, + "n_predict", None) + if n_predict is not None: + if self.num_speculative_tokens is None: + # Default to max value defined in draft model config. + self.num_speculative_tokens = n_predict + elif self.num_speculative_tokens > n_predict and \ + self.num_speculative_tokens % n_predict != 0: + # Ensure divisibility for MTP module reuse. + raise ValueError( + f"num_speculative_tokens:{self.num_speculative_tokens}" + f" must be divisible by {n_predict=}") + + if self.speculative_token_tree is None: + # Generate chain of tokens. + self.speculative_token_tree = str([ + (i + 1) * (0, ) + for i in range(self.num_speculative_tokens) + ]) + else: + # Sort the token tree breadth-first. 
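The ngram branch above only sets and validates `prompt_lookup_min`/`prompt_lookup_max`. As a rough illustration of what those knobs control, here is a minimal, self-contained sketch of prompt-lookup speculation: match the most recent n-gram earlier in the context and propose the tokens that followed it. This is not vLLM's proposer, just a model of the idea; `propose_ngram` and the toy history are made up.

```python
# Minimal sketch of the prompt-lookup idea behind method="ngram": try to match
# the most recent n tokens (n from prompt_lookup_max down to prompt_lookup_min)
# earlier in the context, and propose the tokens that followed that earlier
# occurrence. Illustration only, not vLLM's actual proposer.
def propose_ngram(token_ids: list[int],
                  num_speculative_tokens: int,
                  prompt_lookup_min: int = 5,
                  prompt_lookup_max: int = 5) -> list[int]:
    for n in range(prompt_lookup_max, prompt_lookup_min - 1, -1):
        if n <= 0 or len(token_ids) < n:
            continue
        suffix = token_ids[-n:]
        # Scan for an earlier occurrence of the suffix.
        for start in range(len(token_ids) - n):
            if token_ids[start:start + n] == suffix:
                follow = token_ids[start + n:start + n +
                                   num_speculative_tokens]
                if follow:
                    return follow
    return []  # no match -> no speculation this step

history = [1, 2, 3, 4, 5, 6, 2, 3, 4, 5, 6]
assert propose_ngram(history, num_speculative_tokens=3) == [2, 3, 4]
```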
+ tree_choices = ast.literal_eval( + self.speculative_token_tree) + self.speculative_token_tree = str( + sorted(tree_choices, key=lambda t: (len(t), t))) + + self.draft_tensor_parallel_size = \ + SpeculativeConfig._verify_and_get_draft_tp( + self.target_parallel_config, + self.draft_tensor_parallel_size, + self.draft_model_config.hf_config + ) + + self.draft_model_config.max_model_len = ( + SpeculativeConfig._maybe_override_draft_max_model_len( + self.max_model_len, + self.draft_model_config.max_model_len, + self.target_model_config.max_model_len, + )) + + self.draft_parallel_config = ( + SpeculativeConfig.create_draft_parallel_config( + self.target_parallel_config, + self.draft_tensor_parallel_size)) + + @staticmethod + def _maybe_override_draft_max_model_len( + speculative_max_model_len: Optional[int], + draft_max_model_len: int, + target_max_model_len: int, + ) -> int: + """Determine the max sequence len for the draft model. This is usually + the draft_max_model_len, but may be the target_max_model_len if it is + less than the draft_max_model_len, or may be speculative_max_model_len + if it is specified. + + This is necessary so that sequences do not exceed the capacity of the + draft model or the target model. + + speculative_max_model_len is mainly used for testing that sequences can + skip speculation. + """ + + if speculative_max_model_len is not None: + + if speculative_max_model_len > draft_max_model_len: + raise ValueError(f"{speculative_max_model_len=} cannot be " + f"larger than {draft_max_model_len=}") + + if speculative_max_model_len > target_max_model_len: + raise ValueError(f"{speculative_max_model_len=} cannot be " + f"larger than {target_max_model_len=}") + + return speculative_max_model_len + + return min( + draft_max_model_len, + target_max_model_len, + ) + + @staticmethod + def _verify_and_get_draft_tp( + target_parallel_config: ParallelConfig, + speculative_draft_tensor_parallel_size: Optional[int], + draft_hf_config: PretrainedConfig) -> int: + """ + Verifies and adjusts the tensor parallel size for a draft model + specified using speculative_draft_tensor_parallel_size. + """ + # If speculative_draft_tensor_parallel_size is unset then set it + # appropriately else verify that it is set correctly. + if speculative_draft_tensor_parallel_size is None: + if draft_hf_config.model_type == "mlp_speculator": + speculative_draft_tensor_parallel_size = 1 + if target_parallel_config.tensor_parallel_size > 1: + logger.warning( + "%s cannot currently be run with tp>1; " + "setting speculative_draft_tensor_parallel_size=1", + draft_hf_config.model_type) + else: + speculative_draft_tensor_parallel_size = \ + target_parallel_config.tensor_parallel_size + elif speculative_draft_tensor_parallel_size not in ( + 1, target_parallel_config.tensor_parallel_size): + raise ValueError( + f"{speculative_draft_tensor_parallel_size=} cannot be " + f"other value than 1 or target model tensor_parallel_size") + return speculative_draft_tensor_parallel_size + + @staticmethod + def create_draft_parallel_config( + target_parallel_config: ParallelConfig, + speculative_draft_tensor_parallel_size: int, + ) -> ParallelConfig: + """Create a parallel config for use by the draft worker. + + This is mostly a copy of the target parallel config, except the tp_size. + """ + draft_parallel_config = ParallelConfig( + pipeline_parallel_size=target_parallel_config. + pipeline_parallel_size, + tensor_parallel_size=speculative_draft_tensor_parallel_size, + distributed_executor_backend=target_parallel_config. 
+ distributed_executor_backend, + max_parallel_loading_workers=target_parallel_config. + max_parallel_loading_workers, + disable_custom_all_reduce=target_parallel_config. + disable_custom_all_reduce, + ray_workers_use_nsight=target_parallel_config. + ray_workers_use_nsight, + placement_group=target_parallel_config.placement_group, + ) + + return draft_parallel_config + + @model_validator(mode='after') + def _verify_args(self) -> Self: + if self.num_speculative_tokens is None: + raise ValueError( + "num_speculative_tokens must be provided with " + "speculative model unless the draft model config contains an " + "n_predict parameter.") + + if self.num_speculative_tokens <= 0: + raise ValueError("Expected num_speculative_tokens to be greater " + f"than zero ({self.num_speculative_tokens}).") + + if self.draft_model_config: + self.draft_model_config.verify_with_parallel_config( + self.draft_parallel_config) + + if (self.disable_by_batch_size is not None + and self.disable_by_batch_size < 2): + raise ValueError("Expect the batch size threshold of disabling " + "speculative decoding is > 1, but got " + f"{self.disable_by_batch_size=}") + + eagle3_target_supported = ["llama", "qwen"] + if self.method == "eagle3" and self.target_model_config and not any( + supported_model in + self.target_model_config.hf_text_config.model_type + for supported_model in eagle3_target_supported): + raise ValueError( + f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501 + f"Got {self.target_model_config.hf_text_config.model_type=}") + + return self + + @property + def num_lookahead_slots(self) -> int: + """The number of additional slots the scheduler should allocate per + step, in addition to the slots allocated for each known token. + + This is equal to the number of speculative tokens, as each speculative + token must be scored. + """ + return self.num_speculative_tokens + + def use_eagle(self) -> bool: + return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp", + "qwen3_next_mtp") + + def __repr__(self) -> str: + method = self.method + model = None if method == "ngram" else self.draft_model_config.model + num_spec_tokens = self.num_speculative_tokens + return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})" diff --git a/vllm/config/structured_outputs.py b/vllm/config/structured_outputs.py new file mode 100644 index 0000000000000..b1f14294510f8 --- /dev/null +++ b/vllm/config/structured_outputs.py @@ -0,0 +1,64 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +from typing import Any, Literal + +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config + +StructuredOutputsBackend = Literal["auto", "xgrammar", "guidance", "outlines", + "lm-format-enforcer"] + + +@config +@dataclass +class StructuredOutputsConfig: + """Dataclass which contains structured outputs config for the engine.""" + + backend: StructuredOutputsBackend = "auto" + """Which engine will be used for structured outputs (e.g. JSON schema, + regex, etc) by default. With "auto", we will make opinionated choices + based on request contents and what the backend libraries currently support, + so the behavior is subject to change in each release.""" + disable_fallback: bool = False + """If `True`, vLLM will not fallback to a different backend on error.""" + disable_any_whitespace: bool = False + """If `True`, the model will not generate any whitespace during structured + outputs. 
This is only supported for xgrammar and guidance backends.""" + disable_additional_properties: bool = False + """If `True`, the `guidance` backend will not use `additionalProperties` + in the JSON schema. This is only supported for the `guidance` backend and + is used to better align its behaviour with `outlines` and `xgrammar`.""" + reasoning_parser: str = "" + """Select the reasoning parser depending on the model that you're using. + This is used to parse the reasoning content into OpenAI API format.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + if (self.disable_any_whitespace + and self.backend not in ("xgrammar", "guidance")): + raise ValueError("disable_any_whitespace is only supported for " + "xgrammar and guidance backends.") + if (self.disable_additional_properties and self.backend != "guidance"): + raise ValueError("disable_additional_properties is only supported " + "for the guidance backend.") diff --git a/vllm/config/utils.py b/vllm/config/utils.py index 98fbeb1fa86aa..db8c05ef8be4a 100644 --- a/vllm/config/utils.py +++ b/vllm/config/utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import MISSING, Field, field, fields, is_dataclass from typing import TYPE_CHECKING, TypeVar if TYPE_CHECKING: @@ -27,3 +28,20 @@ def config(cls: ConfigT) -> ConfigT: script, which is invoked during the pre-commit checks. """ return cls + + +def get_field(cls: ConfigType, name: str) -> Field: + """Get the default factory field of a dataclass by name. Used for getting + default factory fields in `EngineArgs`.""" + if not is_dataclass(cls): + raise TypeError("The given class is not a dataclass.") + cls_fields = {f.name: f for f in fields(cls)} + if name not in cls_fields: + raise ValueError(f"Field '{name}' not found in {cls.__name__}.") + named_field: Field = cls_fields[name] + if (default_factory := named_field.default_factory) is not MISSING: + return field(default_factory=default_factory) + if (default := named_field.default) is not MISSING: + return field(default=default) + raise ValueError( + f"{cls.__name__}.{name} must have a default value or default factory.") diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index 7d9b32cd4b674..ae876d131eb66 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -182,7 +182,7 @@ class NaiveBlockAllocator(BlockAllocator): # Increment refcount for each block. 
assert block.block_id is not None refcount = self._refcounter.incr(block.block_id) - assert refcount != 1, "can't fork free'd block" + assert refcount != 1, "can't fork freed block" forked_block = self._block_pool.init_block( prev_block=prev_block, diff --git a/vllm/core/evictor.py b/vllm/core/evictor.py index 7a4a836ee348e..85ff6bc9ca610 100644 --- a/vllm/core/evictor.py +++ b/vllm/core/evictor.py @@ -58,7 +58,7 @@ class Evictor(ABC): class BlockMetaData: """Data structure for storing key data describe cached block, so that - evitor could use to make its decision which one to choose for eviction + evictor could use to make its decision which one to choose for eviction Here we use physical block id as the dict key, as there maybe several blocks with the same content hash, but their physical id is unique. diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index d7864293e9647..92ebad778ea4b 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -11,7 +11,8 @@ from typing import Callable, Deque, Dict, Iterable, List, Optional from typing import Sequence as GenericSequence from typing import Set, Tuple, Union -from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.config import CacheConfig, SchedulerConfig +from vllm.config.lora import LoRAConfig from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.logger import init_logger from vllm.lora.request import LoRARequest diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 7963fb15c4191..af7ca6be1fca8 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -16,8 +16,11 @@ from typing import Any, Callable, Optional, Union import torch +from vllm.logger import init_logger from vllm.utils import is_pin_memory_available +logger = init_logger(__name__) + def find_loaded_library(lib_name) -> Optional[str]: """ @@ -165,6 +168,9 @@ class CuMemAllocator: py_d_mem = allocation_handle[2] self.pointer_to_data[py_d_mem] = AllocationData( allocation_handle, self.current_tag) + logger.debug( + "Allocated %s bytes for %s with address %s from cumem allocator", + allocation_handle[1], self.current_tag, py_d_mem) return def _python_free_callback(self, ptr: int) -> HandleType: @@ -174,6 +180,9 @@ class CuMemAllocator: data = self.pointer_to_data.pop(ptr) if data.cpu_backup_tensor is not None: data.cpu_backup_tensor = None + logger.debug( + "Freed %s bytes for %s with address %s from cumem allocator", + data.handle[1], data.tag, ptr) return data.handle def sleep( @@ -197,9 +206,14 @@ class CuMemAllocator: assert isinstance(offload_tags, tuple) + total_bytes = 0 + backup_bytes = 0 + for ptr, data in self.pointer_to_data.items(): handle = data.handle + total_bytes += handle[1] if data.tag in offload_tags: + backup_bytes += handle[1] size_in_bytes = handle[1] cpu_backup_tensor = torch.empty( size_in_bytes, @@ -211,6 +225,12 @@ class CuMemAllocator: data.cpu_backup_tensor = cpu_backup_tensor unmap_and_release(handle) + logger.info( + "CuMemAllocator: sleep freed %.2f GiB memory in total, of which " + "%.2f GiB is backed up in CPU and the rest %.2f GiB is discarded " + "directly.", total_bytes / 1024**3, backup_bytes / 1024**3, + (total_bytes - backup_bytes) / 1024**3) + gc.collect() torch.cuda.empty_cache() @@ -267,12 +287,17 @@ class CuMemAllocator: # when using pluggable allocator, see # https://github.com/pytorch/pytorch/issues/145168 . # if we have some memory allocated and then freed, - # the memory will not be released. 
- # right now it is fine, because we only use this allocator - # during weight loading and kv cache creation, where we only - # allocate memory. - # TODO: we need to find a way to release the memory, - # i.e. calling torch.cuda.empty_cache() + # the memory will not be released, e.g. in online quantization, + # where the model is created in higher precision, and then + # quantized in lower precision. + # Find all unused allocations and manually release them. + # TODO: we should expose `empty_cache` method in the memory pool. + # TODO: ask for help from PyTorch team to expose this method. + allocations = data[0].snapshot() + for allocation in allocations: + if allocation["allocated_size"] == 0: + handle = self._python_free_callback(allocation["address"]) + unmap_and_release(handle) self.current_tag = old_tag def get_current_usage(self) -> int: diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 85f87cb21edcd..149df73d8667b 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Any +from typing import Any import torch import torch.distributed as dist +from vllm.distributed import get_dp_group from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.utils import has_deep_ep, has_pplx @@ -13,11 +14,6 @@ from .base_device_communicator import All2AllManagerBase, Cache logger = init_logger(__name__) -if TYPE_CHECKING: - from vllm.model_executor.layers.fused_moe.layer import FusedMoE -else: - FusedMoE = None - class NaiveAll2AllManager(All2AllManagerBase): """ @@ -74,6 +70,44 @@ class NaiveAll2AllManager(All2AllManagerBase): pass +class AgRsAll2AllManager(All2AllManagerBase): + """ + An implementation of all2all communication based on + all-gather (dispatch) and reduce-scatter (combine). + """ + + def __init__(self, cpu_group): + super().__init__(cpu_group) + + def dispatch(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + """ + Gather hidden_states and router_logits from all dp ranks. + """ + sizes = get_forward_context( + ).dp_metadata.get_chunk_sizes_across_dp_rank() + hidden_states, router_logits = get_dp_group().all_gatherv( + [hidden_states, router_logits], + dim=0, + sizes=sizes, + ) + return hidden_states, router_logits + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Reduce-scatter hidden_states across all dp ranks. + """ + sizes = get_forward_context( + ).dp_metadata.get_chunk_sizes_across_dp_rank() + hidden_states = get_dp_group().reduce_scatterv(hidden_states, + dim=0, + sizes=sizes) + return hidden_states + + def destroy(self): + pass + + class PPLXAll2AllManager(All2AllManagerBase): """ All2All communication based on PPLX kernels. @@ -256,9 +290,4 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): logger.debug("DeepEP all2all args %s", buffer_kwargs) handle: deep_ep.Buffer = self.handle_cache.get_or_create( buffer_kwargs, deep_ep.Buffer) - # It is dangerous to set num sms outside this function. num_sms is not - # a part of the hash-key that identifies this object. If we are in a - # situation where we make objects with different num_sms, the hash key - # in get_or_create must be updated. 
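The new `AgRsAll2AllManager` above describes its dispatch/combine in terms of `all_gatherv` and `reduce_scatterv` over the DP group. The following single-process sketch (plain tensors, no process groups; an illustration of the data movement only, not the distributed code) shows what that amounts to for two DP ranks with uneven token chunks.

```python
# Single-process illustration of the data movement in AgRsAll2AllManager:
# dispatch() all-gathers each DP rank's token chunk so every rank sees all
# tokens, and combine() sum-reduce-scatters the partial expert outputs back
# to the rank that owns each chunk. Conceptual sketch only.
import torch

dp_size = 2
# Uneven chunk sizes across DP ranks, as allowed by all_gatherv/reduce_scatterv.
chunks = [torch.randn(3, 4), torch.randn(5, 4)]  # rank 0: 3 tokens, rank 1: 5
sizes = [c.shape[0] for c in chunks]             # [3, 5]

# dispatch(): every rank ends up with the concatenation of all chunks.
gathered = torch.cat(chunks, dim=0)              # shape [8, 4]

# Each rank produces a partial output for all 8 tokens (its local experts only).
partial_outputs = [torch.randn_like(gathered) for _ in range(dp_size)]

# combine(): sum the partials, then scatter row ranges back by chunk size.
reduced = torch.stack(partial_outputs).sum(dim=0)
combined = list(torch.split(reduced, sizes, dim=0))
assert [c.shape[0] for c in combined] == sizes   # rank i gets its own tokens back
```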
- handle.set_num_sms(self.num_sms) return handle diff --git a/vllm/distributed/device_communicators/all_reduce_utils.py b/vllm/distributed/device_communicators/all_reduce_utils.py index 5c64e7d5c4ba3..805a88854b77c 100644 --- a/vllm/distributed/device_communicators/all_reduce_utils.py +++ b/vllm/distributed/device_communicators/all_reduce_utils.py @@ -36,8 +36,8 @@ CUSTOM_ALL_REDUCE_MAX_SIZES = { "10.0": { 2: 2 * MiB, # 2 MB 4: 2 * MiB, # 2 MB - 6: 2 * MiB, # 2 MB - 8: 2 * MiB, # 2 MB + 6: 1 * MiB, # 1 MB + 8: 1 * MiB, # 1 MB } } diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 9131582eef754..01f59b44a0e69 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -252,7 +252,10 @@ class DeviceCommunicatorBase: moe_modules = [ module for module in model.modules() - if module.__class__.__name__ == "FusedMoE" + # TODO(bnell): Should use isinstance but can't. Maybe search for + # presence of quant_method.init_prepare_finalize? + if (module.__class__.__name__ == "FusedMoE" + or module.__class__.__name__ == "SharedFusedMoE") ] for module in moe_modules: module.quant_method.init_prepare_finalize(module) diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index eef3f9f75f9f1..b2bf3bc3cc2ed 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -57,11 +57,19 @@ class CudaCommunicator(DeviceCommunicatorBase): self.ca_comm: Optional[CustomAllreduce] = None self.qr_comm: Optional[QuickAllReduce] = None self.symm_mem_comm: Optional[SymmMemCommunicator] = None + if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda(): + self.symm_mem_comm = SymmMemCommunicator( + group=self.cpu_group, + device=self.device, + ) + if use_custom_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. self.ca_comm = CustomAllreduce( group=self.cpu_group, device=self.device, + symm_mem_enabled=(self.symm_mem_comm is not None + and not self.symm_mem_comm.disabled), ) if current_platform.is_rocm(): @@ -72,11 +80,6 @@ class CudaCommunicator(DeviceCommunicatorBase): # currently be an MI300 series. 
self.qr_comm = QuickAllReduce(group=self.cpu_group, device=self.device) - if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda(): - self.symm_mem_comm = SymmMemCommunicator( - group=self.cpu_group, - device=self.device, - ) if self.use_all2all: all2all_backend = envs.VLLM_ALL2ALL_BACKEND @@ -84,6 +87,10 @@ class CudaCommunicator(DeviceCommunicatorBase): from .all2all import NaiveAll2AllManager self.all2all_manager = NaiveAll2AllManager(self.cpu_group) logger.info("Using naive all2all manager.") + elif all2all_backend == "allgather_reducescatter": + from .all2all import AgRsAll2AllManager + self.all2all_manager = AgRsAll2AllManager(self.cpu_group) + logger.info("Using AllGather-ReduceScatter all2all manager.") elif all2all_backend == "pplx": from .all2all import PPLXAll2AllManager self.all2all_manager = PPLXAll2AllManager(self.cpu_group) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 80aca81234eb0..3cc4bbb258244 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -54,13 +54,14 @@ class CustomAllreduce: def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device], - max_size=8192 * 1024) -> None: + max_size=8192 * 1024, + symm_mem_enabled=False) -> None: """ Args: group: the process group to work on. If None, it will use the default process group. device: the device to bind the CustomAllreduce to. If None, - it will be bind to f"cuda:{local_rank}". + it will be bound to f"cuda:{local_rank}". It is the caller's responsibility to make sure each communicator is bind to a unique device, and all communicators in this group are in the same node. @@ -111,7 +112,7 @@ class CustomAllreduce: self.device = device device_capability = current_platform.get_device_capability( ).as_version_str() - if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM + if (current_platform.is_cuda() and symm_mem_enabled and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES): max_size = min( CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size], @@ -158,7 +159,7 @@ class CustomAllreduce: self.disabled = False # Buffers memory are owned by this Python class and passed to C++. - # Meta data composes of two parts: meta data for synchronization and a + # Metadata composes of two parts: metadata for synchronization and a # temporary buffer for storing intermediate allreduce results. self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size, group=group, diff --git a/vllm/distributed/device_communicators/neuron_communicator.py b/vllm/distributed/device_communicators/neuron_communicator.py deleted file mode 100644 index 5b61a1687a016..0000000000000 --- a/vllm/distributed/device_communicators/neuron_communicator.py +++ /dev/null @@ -1,20 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import torch - -from vllm.distributed.device_communicators.base_device_communicator import ( - DeviceCommunicatorBase) -from vllm.platforms import current_platform - -if current_platform.is_neuron(): - import torch_xla.core.xla_model as xm - - -class NeuronCommunicator(DeviceCommunicatorBase): - - def all_reduce(self, x: torch.Tensor) -> torch.Tensor: - return xm.all_reduce(xm.REDUCE_SUM, x) - - def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor: - assert dim == -1, "Neuron only supports dim=-1 for all-gather." 
- return xm.all_gather(x, dim=dim) diff --git a/vllm/distributed/device_communicators/shm_object_storage.py b/vllm/distributed/device_communicators/shm_object_storage.py new file mode 100644 index 0000000000000..0310fc14da256 --- /dev/null +++ b/vllm/distributed/device_communicators/shm_object_storage.py @@ -0,0 +1,635 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pickle +from abc import ABC, abstractmethod +from collections.abc import Iterable +from contextlib import contextmanager +from dataclasses import dataclass +from itertools import chain +from multiprocessing import shared_memory +from multiprocessing.synchronize import Lock as LockType +from typing import Any, Callable, Optional, Union +from unittest.mock import patch + +import torch + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class SingleWriterShmRingBuffer: + """ + A single-writer, multiple-reader ring buffer implementation using shared + memory. This class provides a thread-safe ring buffer where one process + can write data while multiple processes/threads can read from it. + + Architecture: + - Uses shared memory for cross-process communication + - Maintains metadata for each allocated buffer chunk in the writer process + - Supports custom "is_free_fn" functions to determine when buffers can be + reused + - Each buffer chunk contains: `[4-byte id][4-byte size][actual_data]` + + Key Concepts: + - monotonic_id_start/end: Track the range of active buffer IDs + - data_buffer_start/end: Track the physical memory range in use + - Automatic wraparound when reaching buffer end + - Lazy garbage collection based on is_free_fn checks + + Example Usage Scenarios: + + Scenario 1: Simple Linear Allocation + ``` + Buffer size: 100 bytes + Initial state: [................................................. ] + ^start=end(0) + + After allocating 20 bytes (id=0): + [id:0|size:20|data........][...................................] + ^start(0) ^end(28) + + After allocating 30 bytes (id=1): + [id:0|size:20|data........][id:1|size:30|data..............][..] + ^start(0) ^end(66) + ``` + + Scenario 2: Memory Reclamation + ``` + Before freeing (both buffers still in use): + [id:0|size:20|data........][id:1|size:30|data..............][..] + ^start(0) ^end(66) + + After id:0 is marked free by readers: + [FREED.................... ][id:1|size:30|data..............][..] + ^start(28) ^end(66) + + After both are freed: + [FREED..............................................][..] + ^start=end(66) + ``` + + Scenario 3: Wraparound Allocation (continuing from Scenario 2) + ``` + Starting from after memory reclamation in Scenario 2: + [FREED..............................................][..] + ^start=end(66) + + Allocate 40 bytes (id=2) - only 34 bytes available at end, so wraparound: + [id:2|size:40|data........................][FREED.............][..] + ^end(148) ^start(66) + ``` + + Scenario 4: Error Handling - Out of Space + ``` + Starting from after wraparound allocation in Scenario 3: + [id:2|size:40|data........................][FREED.............][..] 
+ ^end(148) ^start(66) + + Trying to allocate 20 more bytes: + occupied_size_new = end + size - start = 148 + 28 - 66 > buffer_size(100) + -> Raises MemoryError: "Not enough space in the data buffer" + ``` + + Thread Safety: + - Single writer: Only one process/thread should write (allocate_buf) + - Multiple readers: Multiple processes/threads can read (access_buf) + - Reader synchronization handled by is_free_fn callback + - Writer handles garbage collection (free_buf) based on reader feedback + + Memory Layout per Buffer Chunk: + `[4-byte monotonic_id][4-byte chunk_size][actual_data...]` + ^metadata_start ^data_start + + The monotonic_id ensures data integrity - readers can verify they're + accessing the correct data even after buffer wraparound or reuse. + """ + + def __init__( + self, + data_buffer_size: int, + name: Optional[str] = None, + create: bool = False, + ): + self.data_buffer_size = data_buffer_size + self.is_writer = create + + self.ID_NBYTES = 4 + self.ID_MAX = 2**31 # exclusive, so 2**31 - 1 is the max value + self.SIZE_NBYTES = 4 + # 4 bytes for id, 4 bytes for buffer size + self.MD_SIZE = self.ID_NBYTES + self.SIZE_NBYTES + self.monotonic_id_end = 0 + self.monotonic_id_start = 0 + self.data_buffer_start = 0 + self.data_buffer_end = 0 + + if create: + # we are creating a buffer + self.metadata = { + self.monotonic_id_end: self.data_buffer_end + } # monotonic_id -> start address + self.shared_memory = shared_memory.SharedMemory( + create=True, size=self.data_buffer_size, name=name) + else: + # we are opening an existing buffer + # fix to https://stackoverflow.com/q/62748654/9191338 + # Python incorrectly tracks shared memory even if it is not + # created by the process. The following patch is a workaround. + with patch( + "multiprocessing.resource_tracker.register", + lambda *args, **kwargs: None, + ): + self.shared_memory = shared_memory.SharedMemory(name=name) + # See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa + # Some platforms allocate memory based on page size, + # so the shared memory block size may be larger or equal + # to the requested size. The size parameter is ignored + # when attaching to an existing block. + assert self.shared_memory.size >= self.data_buffer_size + + logger.debug("Shared memory created/opened with name: %s, size: %d", + self.shared_memory.name, self.data_buffer_size) + + def handle(self): + return ( + self.data_buffer_size, + self.shared_memory.name, + ) + + def clear(self) -> None: + """Clear the ring buffer.""" + assert self.is_writer, "Only the writer can clear the buffer." + self.metadata.clear() + self.monotonic_id_end = 0 + self.monotonic_id_start = 0 + self.data_buffer_start = 0 + self.data_buffer_end = 0 + + def __del__(self): + if hasattr(self, "shared_memory"): + self.shared_memory.close() + if self.is_writer: + self.shared_memory.unlink() + + def int2byte(self, integer: int) -> bytes: + """Convert an integer to bytes.""" + return integer.to_bytes(self.ID_NBYTES, "little", signed=True) + + def byte2int(self, byte_data: bytes) -> int: + """Convert bytes back to an integer.""" + return int.from_bytes(byte_data, "little", signed=True) + + def allocate_buf(self, size: int) -> tuple[int, int]: + ''' + Allocate a buffer `MD_SIZE` + `size` bytes in the shared memory. + Memory layout: + `[4-byte monotonic_id][4-byte size][buffer data...]` + ''' + assert self.is_writer, "Only the writer can allocate buffers." 
+ assert size > 0, "Size must be greater than 0" + size += self.MD_SIZE # add metadata size to the buffer size + # reset to beginning if the buffer does not have enough contiguous space + buffer_end_reset = self.data_buffer_end % self.data_buffer_size + if buffer_end_reset + size > self.data_buffer_size: + buffer_end_reset = (self.data_buffer_end // self.data_buffer_size + + 1) * self.data_buffer_size + else: # no reset needed + buffer_end_reset = self.data_buffer_end + + # check if we have enough space in the data buffer + # i.e. if the new end (self.data_buffer_end + size) + # exceeds the start of the data buffer + occupied_size_new = buffer_end_reset + size - self.data_buffer_start + if occupied_size_new > self.data_buffer_size: + raise MemoryError("Not enough space in the data buffer, " + "try calling free_buf() to free up space") + self.data_buffer_end = buffer_end_reset + + # first 4 bytes as the monotonic id + buf_idx = self.data_buffer_end % self.data_buffer_size + self.shared_memory.buf[buf_idx:buf_idx + self.ID_NBYTES] = \ + self.int2byte(self.monotonic_id_end) + # next 4 bytes as the size of the data buffer + self.shared_memory.buf[buf_idx + self.ID_NBYTES: \ + buf_idx + self.MD_SIZE] = self.int2byte(size) + + # record metadata + self.metadata[self.monotonic_id_end % + self.ID_MAX] = self.data_buffer_end + # update buffer and monotonic id indices + current_buffer_end = self.data_buffer_end + current_id_end = self.monotonic_id_end + self.data_buffer_end += size + self.monotonic_id_end = (self.monotonic_id_end + 1) % self.ID_MAX + return current_buffer_end, current_id_end + + @contextmanager + def access_buf(self, address: int): + buf_idx = address % self.data_buffer_size + + # read metadata + metadata_buff = self.shared_memory.buf[buf_idx:buf_idx + self.MD_SIZE] + id = self.byte2int(metadata_buff[:self.ID_NBYTES]) + size = self.byte2int(metadata_buff[self.ID_NBYTES:self.MD_SIZE]) + + # yield the data buffer and metadata + data_buff = self.shared_memory.buf[buf_idx + self.MD_SIZE:buf_idx + + size] + with (memoryview(data_buff) as data_view, ): + yield data_view, (id, size) + + def free_buf(self, + is_free_fn: Callable[[int, memoryview], bool], + nbytes: Optional[int] = None) -> Iterable[int]: + ''' + Free a buffer of the given size. This is a no-op in shared memory, + but we need to keep track of the metadata. + + If freed memory spreads across the end and start of the ring buffer, + the actual freed memory will be in two segments. In this case there + still might not be a contiguous space of `nbytes` available. + + Args: + nbytes (int, optional): The size of the buffer to free. If None, + frees the maximum size of the ring buffer. + ''' + + assert self.is_writer, "Only the writer can free buffers."
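The chunk layout that `allocate_buf` and `access_buf` implement above can be reproduced in isolation. The sketch below (a plain bytearray standing in for the shared memory segment, with invented helper names) packs and reads the `[4-byte id][4-byte size][data]` header using the same little-endian signed encoding as `int2byte`/`byte2int`, and lands on the same offsets as Scenario 1 in the class docstring (end at 28, then 66).

```python
# Standalone sketch of the per-chunk layout used by SingleWriterShmRingBuffer:
# [4-byte monotonic_id][4-byte size][data], integers stored little-endian and
# signed. A bytearray stands in for multiprocessing.shared_memory.
ID_NBYTES = 4
MD_SIZE = 8  # 4-byte id + 4-byte size

def write_chunk(buf: bytearray, offset: int, monotonic_id: int,
                data: bytes) -> int:
    size = MD_SIZE + len(data)  # stored size includes the metadata header
    buf[offset:offset + ID_NBYTES] = monotonic_id.to_bytes(4, "little",
                                                           signed=True)
    buf[offset + ID_NBYTES:offset + MD_SIZE] = size.to_bytes(4, "little",
                                                             signed=True)
    buf[offset + MD_SIZE:offset + size] = data
    return offset + size  # next free offset

def read_chunk(buf: bytearray, offset: int) -> tuple[int, bytes]:
    monotonic_id = int.from_bytes(buf[offset:offset + ID_NBYTES], "little",
                                  signed=True)
    size = int.from_bytes(buf[offset + ID_NBYTES:offset + MD_SIZE], "little",
                          signed=True)
    return monotonic_id, bytes(buf[offset + MD_SIZE:offset + size])

ring = bytearray(100)
end = write_chunk(ring, 0, 0, b"x" * 20)    # id 0, 20 bytes -> end at 28
end = write_chunk(ring, end, 1, b"y" * 30)  # id 1, 30 bytes -> end at 66
assert end == 66 and read_chunk(ring, 28) == (1, b"y" * 30)
```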
+ logger.debug( + "Freeing up space in the ring buffer, " + "monotonic_id_start: %d, monotonic_id_end: %d", + self.monotonic_id_start, self.monotonic_id_end) + monotonic_id_before = self.monotonic_id_start + # if nbytes is None, free up the maximum size of the ring buffer + if nbytes is None: + nbytes = self.data_buffer_size + freed_bytes = 0 + while self.monotonic_id_start in self.metadata and freed_bytes < nbytes: + address = self.metadata[self.monotonic_id_start] + with self.access_buf(address) as (data_buff, metadata): + if is_free_fn(self.monotonic_id_start, data_buff): + # check passed, we can free the buffer + del self.metadata[self.monotonic_id_start] + self.monotonic_id_start = ((self.monotonic_id_start + 1) % + self.ID_MAX) + self.data_buffer_start = address + freed_bytes += metadata[1] + else: + # there are still readers, we cannot free the buffer + break + + logger.debug( + "Freed %d bytes from the ring buffer, " + "monotonic_id_start: %d, monotonic_id_end: %d", freed_bytes, + self.monotonic_id_start, self.monotonic_id_end) + + # buffer wrap around + if self.data_buffer_start >= self.data_buffer_size: + self.data_buffer_start -= self.data_buffer_size + self.data_buffer_end -= self.data_buffer_size + + monotonic_id_after = self.monotonic_id_start + # id wrap around + if monotonic_id_after >= monotonic_id_before: + return range(monotonic_id_before, monotonic_id_after) + else: + return chain(range(monotonic_id_before, self.ID_MAX), + range(0, monotonic_id_after)) + + +class ObjectSerde(ABC): + + @abstractmethod + def serialize(self, value: Any) -> tuple[Any, int, bytes, int]: + """Serialize an object to bytes.""" + raise NotImplementedError + + @abstractmethod + def deserialize(self, data: memoryview) -> Any: + """Deserialize bytes back to an object.""" + raise NotImplementedError + + +class MsgpackSerde(ObjectSerde): + + def __init__(self): + # Delayed import to avoid circular dependency + from vllm.multimodal.inputs import MultiModalKwargsItem + from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder + + self.encoder = MsgpackEncoder() + self.tensor_decoder = MsgpackDecoder(torch.Tensor) + self.mm_decoder = MsgpackDecoder(MultiModalKwargsItem) + self._mm_kwargs_item_cls = MultiModalKwargsItem + + def serialize( + self, + value: Any) -> tuple[Union[bytes, list[bytes]], int, bytes, int]: + len_arr = None + if isinstance(value, (torch.Tensor, self._mm_kwargs_item_cls)): + type_name = type(value).__name__ + value = self.encoder.encode(value) + len_arr = [len(s) for s in value] + nbytes = sum(len_arr) + else: + value = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL) + type_name = type(value).__name__ + nbytes = len(value) + + object_metadata = (type_name, nbytes, len_arr) + serialized_metadata = pickle.dumps(object_metadata, + protocol=pickle.HIGHEST_PROTOCOL) + return value, nbytes, serialized_metadata, len(serialized_metadata) + + def deserialize(self, data_view: memoryview) -> Any: + # pickle.loads do not read past the end of a pickled object + # within a large buffer, so we can skip storing the metadata size + type_name, nbytes, len_arr = pickle.loads(data_view) + serialized_data = bytearray(data_view[-nbytes:]) + + if type_name == torch.Tensor.__name__: + obj = [] + start_idx = 0 + for length in len_arr: + item_bytes = serialized_data[start_idx:start_idx + length] + obj.append(item_bytes) + start_idx += length + obj = self.tensor_decoder.decode(obj) + elif type_name == self._mm_kwargs_item_cls.__name__: + obj = [] + start_idx = 0 + for length in len_arr: + 
item_bytes = serialized_data[start_idx:start_idx + length] + obj.append(item_bytes) + start_idx += length + obj = self.mm_decoder.decode(obj) + elif type_name == bytes.__name__: + obj = pickle.loads(serialized_data) + else: + raise ValueError( + f"Unsupported object type '{type_name}' in metadata") + + return obj + + +@dataclass +class ShmObjectStorageHandle: + max_object_size: int + n_readers: int + ring_buffer_handle: tuple[int, str] + serde_class: type[ObjectSerde] + reader_lock: Optional[LockType] + + +class SingleWriterShmObjectStorage: + """ + A single-writer, multiple-reader object storage system built on top of a + shared memory ring buffer. Provides key-value storage with automatic memory + management and cross-process serialization support. + + This storage system follows a FIFO (First-In-First-Out) eviction policy + where the oldest objects are automatically freed when memory runs low. + Memory is reclaimed based on reader reference counting - objects are only + freed when all readers have finished accessing them. + + Architecture: + - Single writer process can put(key, value) objects + - Multiple reader processes can get(address, monotonic_id) objects + - Built on SingleWriterShmRingBuffer for efficient shared memory management + - Thread-safe operations with reader synchronization via locks + + Key Features: + - FIFO Eviction: Oldest objects are evicted first when memory is full + - Reference Counting: Objects are only freed when no readers are + accessing them + - Duplicate Key Handling: Existing keys are not overwritten, just + re-referenced + - Customized Serialization: By default uses Msgpack for efficient + serialization of Python objects, but can be extended for custom types + - Cross-Process Safety: Uses shared memory with proper synchronization + - Automatic Cleanup: Garbage collection happens transparently during + allocation + + Memory Layout per Object: + `[4-byte reference_count][metadata_size][serialized_object_data]` + + Thread Safety: + - Writer operations (put, clear) are single-threaded by design + - Reader operations (get) are thread-safe with lock-based reference + counting + - Memory reclamation is handled exclusively by the writer process + """ + + def __init__( + self, + max_object_size: int, + n_readers: int, + ring_buffer: SingleWriterShmRingBuffer, + serde_class: type[ObjectSerde] = MsgpackSerde, + reader_lock: Optional[LockType] = None, + ): + """ + Initialize the object storage. + + Args: + max_object_size: Maximum size for a single object in bytes. + n_readers: Number of reader processes that can access the storage. + ring_buffer: The shared memory ring buffer for storing objects. + serde_class: Serializer/deserializer for objects. + reader_lock: Optional lock for synchronizing reader access. + Raises: + ValueError: If reader_lock is None for readers. 
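`MsgpackSerde` above stores a pickled `(type_name, nbytes, len_arr)` header immediately followed by the payload bytes, relying on `pickle.loads()` stopping at the end of the pickled header. A minimal standalone sketch of that framing (the `pack`/`unpack` names are invented, and only the plain-bytes case is shown):

```python
# Minimal sketch of the metadata-then-payload layout used by MsgpackSerde:
# a pickled (type_name, nbytes, len_arr) header followed by the raw payload.
# pickle.loads() stops at the end of the pickled header, so the header length
# never needs to be stored separately.
import pickle

def pack(payload_parts: list[bytes]) -> bytes:
    len_arr = [len(p) for p in payload_parts]
    header = pickle.dumps(("bytes", sum(len_arr), len_arr),
                          protocol=pickle.HIGHEST_PROTOCOL)
    return header + b"".join(payload_parts)

def unpack(buf: bytes) -> list[bytes]:
    type_name, nbytes, len_arr = pickle.loads(buf)  # ignores trailing payload
    payload = buf[-nbytes:]
    parts, start = [], 0
    for length in len_arr:
        parts.append(payload[start:start + length])
        start += length
    return parts

parts = [b"hello", b"world!!"]
assert unpack(pack(parts)) == parts
```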
+ """ + + self.max_object_size = max_object_size + self.n_readers = n_readers + self.serde_class = serde_class + self.ser_de = serde_class() + self.ring_buffer = ring_buffer + self.is_writer = self.ring_buffer.is_writer + + self.flag_bytes = 4 # for in-use flag + + if self.is_writer: + # Key-value mapping: key -> (address, monotonic_id) + self.key_index: dict[str, tuple[int, int]] = {} + # Reverse mapping: monotonic_id -> key + self.id_index: dict[int, str] = {} + # Writer flag to track in-use status: monotonic_id -> count + self.writer_flag: dict[int, int] = {} + else: + if reader_lock is None: + raise ValueError("Lock must be provided for readers.") + + self._reader_lock = reader_lock + + def clear(self) -> None: + """Clear the object storage.""" + if self.is_writer: + self.ring_buffer.clear() + self.key_index.clear() + self.id_index.clear() + self.writer_flag.clear() + logger.debug("Object storage cleared and reinitialized.") + + def copy_to_buffer( + self, + data: Union[bytes, list[bytes]], + data_bytes: int, + metadata: bytes, + md_bytes: int, + data_view: memoryview, + ) -> None: + data_view[self.flag_bytes:self.flag_bytes + md_bytes] = metadata + if isinstance(data, bytes): + data_view[-data_bytes:] = data + elif isinstance(data, list): + start_idx = self.flag_bytes + md_bytes + for item_bytes in data: + item_size = len(item_bytes) + data_view[start_idx:start_idx + item_size] = item_bytes + start_idx += item_size + else: + raise ValueError( + f"Unsupported data type for serialization: {type(data)}") + + def increment_writer_flag(self, id: int) -> None: + """Set the in-use flag for the writer.""" + self.writer_flag[id] = self.writer_flag.get(id, 0) + 1 + + def increment_reader_flag(self, data_view: memoryview) -> None: + """Set the in-use flag for the reader.""" + # >0 for in-use flag + reader_count = self.ring_buffer.byte2int(data_view) + data_view[:] = self.ring_buffer.int2byte(reader_count + 1) + + def free_unused(self) -> None: + """Free unused buffers in the ring buffer.""" + # try to free up 2*max_object_size bytes of space in the ring buffer, + # since the buffer might be fragmented + freed_ids = self.ring_buffer.free_buf(self.default_is_free_check, + 2 * self.max_object_size) + # update the metadata after freeing up space + for freed_id in freed_ids: + key_to_free = self.id_index[freed_id] + del self.key_index[key_to_free] + del self.id_index[freed_id] + del self.writer_flag[freed_id] + + def is_cached(self, key: str) -> bool: + """ + Check if the object with the given key is cached. + """ + return key in self.key_index + + def get_cached(self, key: str) -> tuple[int, int]: + """ + Get the cached object by key if it exists. + """ + address, monotonic_id = self.key_index[key] + self.increment_writer_flag(monotonic_id) + return address, monotonic_id + + def put(self, key: str, value: Any) -> tuple[int, int]: + """ + Store a key-value pair in the object storage. + Attempts to free max_object_size bytes using FIFO order + when the ring buffer runs out of space during a put() operation. 
+ + Args: + key: String key to identify the object + value: Any serializable Python object + + Raises: + MemoryError: If there's not enough space in the buffer + ValueError: If the serialized object is too large + ValueError: If the key already exists in the storage + """ + if key in self.key_index: + raise ValueError(f"Key '{key}' already exists in the storage.") + + object_data, data_bytes, object_metadata, md_bytes = \ + self.ser_de.serialize(value) + buffer_size = self.flag_bytes + data_bytes + md_bytes + + # Sanity checks + if buffer_size > self.max_object_size: + raise ValueError( + f"Serialized object size ({buffer_size} bytes) exceeds " + f"max object size ({self.max_object_size} bytes)") + + # Allocate new buffer + try: + address, monotonic_id = self.ring_buffer.allocate_buf(buffer_size) + except MemoryError: + self.free_unused() + # try again after freeing up space + address, monotonic_id = self.ring_buffer.allocate_buf(buffer_size) + + # Write data to buffer + with self.ring_buffer.access_buf(address) as (data_view, metadata): + data_view[:self.flag_bytes] = self.ring_buffer.int2byte(0) + self.copy_to_buffer(object_data, data_bytes, object_metadata, + md_bytes, data_view) + self.increment_writer_flag(monotonic_id) + + # Update key index + self.key_index[key] = (address, monotonic_id) + self.id_index[monotonic_id] = key + return address, monotonic_id + + def get(self, address: int, monotonic_id: int) -> Any: + # Read data from buffer + with self.ring_buffer.access_buf(address) as (data_view, buf_metadata): + # check id from metadata + if buf_metadata[0] != monotonic_id: + raise ValueError( + f"Data for address:id '{address}:{monotonic_id}'" + " has been modified or is invalid.") + + obj = self.ser_de.deserialize(data_view[self.flag_bytes:]) + + # decrease the in-use flag for reader reads + if self._reader_lock is not None: + with self._reader_lock: + self.increment_reader_flag(data_view[:self.flag_bytes]) + else: + # if self._reader_lock is None, it means we are the writer + # in this case, we do not need to decrease the reader count + assert self.is_writer + + return obj + + def handle(self): + """Get handle for sharing across processes.""" + return ShmObjectStorageHandle( + max_object_size=self.max_object_size, + n_readers=self.n_readers, + ring_buffer_handle=self.ring_buffer.handle(), + serde_class=self.serde_class, + reader_lock=self._reader_lock, + ) + + @staticmethod + def create_from_handle( + handle: ShmObjectStorageHandle) -> "SingleWriterShmObjectStorage": + logger.debug("Creating storage from handle: %s", handle) + ring_buffer = SingleWriterShmRingBuffer(*handle.ring_buffer_handle) + return SingleWriterShmObjectStorage( + max_object_size=handle.max_object_size, + n_readers=handle.n_readers, + ring_buffer=ring_buffer, + serde_class=handle.serde_class, + reader_lock=handle.reader_lock, + ) + + def default_is_free_check(self, id: int, buf: memoryview) -> bool: + """ + Default is_free function that checks if the first 4 bytes are zero. + This indicates that the buffer is free. 
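For orientation, here is a hypothetical single-process usage of the storage defined above (it assumes a vLLM build that includes this patch; the key and value are arbitrary). The writer creates the ring buffer, stores an object, and reads it back via the `(address, monotonic_id)` pair returned by `put()`; readers in other processes would instead attach through `handle()`/`create_from_handle()` with a reader lock.

```python
# Hypothetical single-process (writer-side) usage of the new shared-memory
# object storage; assumes a vLLM build containing this patch.
from vllm.distributed.device_communicators.shm_object_storage import (
    SingleWriterShmObjectStorage, SingleWriterShmRingBuffer)

ring_buffer = SingleWriterShmRingBuffer(data_buffer_size=1 << 20, create=True)
storage = SingleWriterShmObjectStorage(
    max_object_size=1 << 16,  # a single object must fit in 64 KiB here
    n_readers=1,
    ring_buffer=ring_buffer,
)

address, monotonic_id = storage.put("mm_item:0", {"tokens": [1, 2, 3]})
assert storage.is_cached("mm_item:0")
assert storage.get(address, monotonic_id) == {"tokens": [1, 2, 3]}
```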
+ """ + reader_count = int.from_bytes(buf[0:4], "little", signed=True) + writer_count = self.writer_flag[id] + return reader_count >= writer_count * self.n_readers diff --git a/vllm/distributed/device_communicators/symm_mem.py b/vllm/distributed/device_communicators/symm_mem.py index d907e1b833d04..09012d16978d9 100644 --- a/vllm/distributed/device_communicators/symm_mem.py +++ b/vllm/distributed/device_communicators/symm_mem.py @@ -27,8 +27,13 @@ class SymmMemCommunicator: "10.0": [6, 8], } - def __init__(self, group: ProcessGroup, device: Union[int, str, - torch.device]): + def __init__( + self, + group: ProcessGroup, + device: Union[int, str, torch.device], + # add options for testing + force_multimem: Optional[bool] = None, + max_size_override: Optional[int] = None): self.disabled = True if not symm_mem_available: @@ -64,8 +69,17 @@ class SymmMemCommunicator: self.world_size, ) return - self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][ - self.world_size] + # Use override max_size if provided, otherwise use default + if max_size_override is not None: + self.max_size = max_size_override + logger.info( + "SymmMemCommunicator: Using override max_size: %s bytes", + self.max_size, + ) + else: + self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[ + self.device_capability][self.world_size] + self.buffer = torch_symm_mem.empty( self.max_size // self.dtype.itemsize, device=self.device, @@ -76,6 +90,7 @@ class SymmMemCommunicator: logger.warning("SymmMemCommunicator: symmetric memory " "multicast operations are not supported.") return + self.force_multimem = force_multimem self.disabled = False def should_use_symm_mem(self, inp: torch.Tensor): @@ -98,8 +113,18 @@ class SymmMemCommunicator: if out is None: out = torch.empty_like(inp) self.buffer[:inp.numel()].copy_(inp.view(-1)) - if self.world_size in self._WORLD_SIZES_MULTIMEM[ - self.device_capability]: + + # Determine which algorithm to use + use_multimem = False + if self.force_multimem is not None: + # Test override: use forced setting + use_multimem = self.force_multimem + else: + # Normal logic: use multimem for supported world sizes + use_multimem = self.world_size in self._WORLD_SIZES_MULTIMEM[ + self.device_capability] + + if use_multimem: torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()], "sum", self.group.group_name) diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index d5ab61473ab01..3e318d7848326 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -337,11 +337,12 @@ class EplbState: Args: model (MixtureOfExperts): The MoE model. is_dummy (bool): If `True`, this is a dummy step and the load - metrics recorded in this forward pass will not count. Defaults - to `False`. + metrics recorded in this forward pass will not count. + Defaults to `False`. is_profile (bool): If `True`, perform a dummy rearrangement - with maximum communication cost. This is used in `profile_run` - to reserve enough memory for the communication buffer. + with maximum communication cost. This is used in + `profile_run` to reserve enough memory + for the communication buffer. log_stats (bool): If `True`, log the expert load metrics. 
# Stats diff --git a/vllm/distributed/eplb/rebalance_algo.py b/vllm/distributed/eplb/rebalance_algo.py index 879b5b9f18240..fc43dbe3b6533 100644 --- a/vllm/distributed/eplb/rebalance_algo.py +++ b/vllm/distributed/eplb/rebalance_algo.py @@ -102,20 +102,23 @@ def rebalance_experts_hierarchical( num_groups: int, num_nodes: int, num_gpus: int, -): +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Parameters: weight: [num_moe_layers, num_logical_experts] num_physical_experts: number of physical experts after replication num_groups: number of expert groups - num_nodes: number of server nodes, where the intra-node network - (e.g, NVLink) is faster + num_nodes: number of server nodes, where the intra-node network + (e.g., NVLink) is faster num_gpus: number of GPUs, must be a multiple of `num_nodes` Returns: - physical_to_logical_map: [num_moe_layers, num_physical_experts] - logical_to_physical_map: [num_moe_layers, num_logical_experts, X] - logical_count: [num_moe_layers, num_logical_experts] + physical_to_logical_map (torch.Tensor): + [num_moe_layers, num_physical_experts] + logical_to_physical_map (torch.Tensor): + [num_moe_layers, num_logical_experts, X] + logical_count (torch.Tensor): + [num_moe_layers, num_logical_experts] """ num_layers, num_logical_experts = weight.shape assert num_logical_experts % num_groups == 0 @@ -197,11 +200,13 @@ def rebalance_experts( num_gpus: number of GPUs, must be a multiple of `num_nodes` Returns: - physical_to_logical_map: [layers, num_replicas], the expert index of - each replica - logical_to_physical_map: [layers, num_logical_experts, X], the replica - indices for each expert - expert_count: [layers, num_logical_experts], number of physical + physical_to_logical_map: + [layers, num_replicas], the expert index of each replica + logical_to_physical_map: + [layers, num_logical_experts, X], the replica indices for each + expert + expert_count: + [layers, num_logical_experts], number of physical replicas for each logical expert """ num_layers, num_logical_experts = weight.shape diff --git a/vllm/distributed/kv_events.py b/vllm/distributed/kv_events.py index 37f8f72fa9056..46f0cd9289b23 100644 --- a/vllm/distributed/kv_events.py +++ b/vllm/distributed/kv_events.py @@ -14,8 +14,9 @@ from typing import Any, Callable, Optional, Union import msgspec import zmq -from vllm.config import KVEventsConfig +from vllm.config.kv_events import KVEventsConfig from vllm.logger import init_logger +from vllm.v1.core.kv_cache_utils import ExternalBlockHash logger = init_logger(__name__) @@ -44,8 +45,8 @@ MEDIUM_GPU = "GPU" class BlockStored(KVCacheEvent): - block_hashes: list[int] - parent_block_hash: Optional[int] + block_hashes: list[ExternalBlockHash] + parent_block_hash: Optional[ExternalBlockHash] token_ids: list[int] block_size: int lora_id: Optional[int] @@ -53,7 +54,7 @@ class BlockStored(KVCacheEvent): class BlockRemoved(KVCacheEvent): - block_hashes: list[int] + block_hashes: list[ExternalBlockHash] medium: Optional[str] diff --git a/vllm/distributed/kv_transfer/__init__.py b/vllm/distributed/kv_transfer/__init__.py index fa9b7e4f14c02..cf58e7914972c 100644 --- a/vllm/distributed/kv_transfer/__init__.py +++ b/vllm/distributed/kv_transfer/__init__.py @@ -2,11 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.distributed.kv_transfer.kv_transfer_state import ( - KVConnectorBaseType, ensure_kv_transfer_initialized, get_kv_transfer_group, - has_kv_transfer_group, is_v1_kv_transfer_group) + KVConnectorBaseType, 
ensure_kv_transfer_initialized, + ensure_kv_transfer_shutdown, get_kv_transfer_group, has_kv_transfer_group, + is_v1_kv_transfer_group) __all__ = [ "get_kv_transfer_group", "has_kv_transfer_group", "is_v1_kv_transfer_group", "ensure_kv_transfer_initialized", - "KVConnectorBaseType" + "ensure_kv_transfer_shutdown", "KVConnectorBaseType" ] diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 584fc1d655951..670f9c26b2104 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -14,7 +14,8 @@ from vllm.logger import init_logger # yapf: enable if TYPE_CHECKING: - from vllm.config import KVTransferConfig, VllmConfig + from vllm.config import VllmConfig + from vllm.config.kv_transfer import KVTransferConfig logger = init_logger(__name__) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 2364400b3d350..f4dc248a12794 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -6,7 +6,7 @@ KV cache helper for store. from collections import defaultdict from collections.abc import Sequence from concurrent.futures import CancelledError, Future -from typing import Optional, cast +from typing import Literal, Optional, Union, cast import torch @@ -196,3 +196,51 @@ class KVOutputAggregator: output_future.add_done_callback(make_callback(i)) return result_future + + +def _make_src_and_dst_indices( + src_block_ids: list[int], + dst_block_ids: list[int], + src_device: Union[torch.device, str], + dst_device: Union[torch.device, str], +) -> tuple[torch.Tensor, torch.Tensor]: + src_indices = torch.tensor(src_block_ids, + device=src_device, + dtype=torch.int64) + dst_indices = torch.tensor(dst_block_ids, + device=dst_device, + dtype=torch.int64) + return src_indices, dst_indices + + +def copy_kv_blocks( + src_kv_caches: dict[str, torch.Tensor], + dst_kv_caches: dict[str, torch.Tensor], + src_block_ids: list[int], + dst_block_ids: list[int], + direction: Literal["h2d", "d2h"], +) -> None: + """Copy kv blocks between different buffers.""" + if not src_kv_caches or not dst_kv_caches or \ + not src_block_ids or not dst_block_ids or \ + len(src_block_ids) != len(dst_block_ids): + return + + src_device = next(iter(src_kv_caches.values())).device + dst_device = next(iter(dst_kv_caches.values())).device + + src_indices, dst_indices = _make_src_and_dst_indices( + src_block_ids=src_block_ids, + dst_block_ids=dst_block_ids, + src_device=src_device, + dst_device=dst_device) + + from vllm.platforms import current_platform + if direction == "h2d": + copy_fn = current_platform.insert_blocks_to_device + else: + copy_fn = current_platform.swap_out_blocks_to_host + for layer_name in src_kv_caches: + src_tensor = src_kv_caches[layer_name] + dst_tensor = dst_kv_caches[layer_name] + copy_fn(src_tensor, dst_tensor, src_indices, dst_indices) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 2804003f5a708..70c07eac6304b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -149,7 +149,7 @@ class KVConnectorBase_V1(ABC): @abstractmethod def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: + **kwargs: Any) -> None: """ Start loading the KV cache from the connector to vLLM's paged KV 
buffer. This is called from the forward context before the @@ -182,7 +182,8 @@ class KVConnectorBase_V1(ABC): @abstractmethod def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: + attn_metadata: "AttentionMetadata", + **kwargs: Any) -> None: """ Start saving a layer of KV cache from vLLM's paged buffer to the connector. This is called from within attention layer to @@ -226,6 +227,14 @@ class KVConnectorBase_V1(ABC): """ return None, None + def shutdown(self): + """ + Shutdown the connector. This is called when the worker process + is shutting down to ensure that all the async operations are + completed and the connector is cleaned up properly. + """ + return None + # ============================== # Scheduler-side methods # ============================== @@ -235,7 +244,7 @@ class KVConnectorBase_V1(ABC): self, request: "Request", num_computed_tokens: int, - ) -> tuple[int, bool]: + ) -> tuple[Optional[int], bool]: """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. @@ -247,8 +256,11 @@ class KVConnectorBase_V1(ABC): Returns: A tuple with the following elements: - - The number of tokens that can be loaded from the - external KV cache beyond what is already computed. + - An optional number of tokens that can be loaded from the + external KV cache beyond what is already computed. + If None, it means that the connector needs more time to + determine the number of matched tokens, and the scheduler + should query for this request again later. - `True` if external KV cache tokens will be loaded asynchronously (between scheduler steps). Must be 'False' if the first element is 0. @@ -343,3 +355,14 @@ class KVConnectorBase_V1(ABC): raise TypeError("get_required_kvcache_layout should not be called " "on the abstract base class") return None + + def get_finished_count(self) -> Optional[int]: + """ + Get the count of requests expected to complete send/receive operations + via this connector. + + Returns: + int: expected sending or receiving completion count. + """ + + return None \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index e838ac2499c04..2b0abe983fbb3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -30,7 +30,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1): # Worker-side methods # ============================== def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: + **kwargs: Any) -> None: """ Start loading the KV cache from the connector to vLLM's paged KV buffer. This is called from the forward context before the @@ -61,7 +61,8 @@ class LMCacheConnectorV1(KVConnectorBase_V1): self._lmcache_engine.wait_for_layer_load(layer_name) def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: + attn_metadata: "AttentionMetadata", + **kwargs: Any) -> None: """ Start saving the a layer of KV cache from vLLM's paged buffer to the connector. 
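# Illustrative sketch of the Optional[int] contract described above for
# get_num_new_matched_tokens() on a KVConnectorBase_V1 subclass: returning
# None tells the scheduler the connector is still looking up matches and
# should query this request again on a later step. The `self._lookups`
# future table is a hypothetical stand-in for a connector's async lookup.
def get_num_new_matched_tokens(self, request, num_computed_tokens):
    lookup = self._lookups.get(request.request_id)
    if lookup is not None and not lookup.done():
        # Lookup still in flight -- defer the decision to a later step.
        return None, False
    num_external_tokens = 0
    if lookup is not None:
        num_external_tokens = max(lookup.result() - num_computed_tokens, 0)
    # load_async must be False whenever no external tokens will be loaded.
    return num_external_tokens, num_external_tokens > 0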
This is called from within attention layer to @@ -110,7 +111,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1): self, request: "Request", num_computed_tokens: int, - ) -> tuple[int, bool]: + ) -> tuple[Optional[int], bool]: """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 65bcb4d93b1e1..616d158d67670 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -7,7 +7,8 @@ from typing import TYPE_CHECKING, Any, Optional import torch -from vllm.config import KVTransferConfig, VllmConfig +from vllm.config import VllmConfig +from vllm.config.kv_transfer import KVTransferConfig from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_transfer.kv_connector.factory import ( KVConnectorFactory) @@ -87,6 +88,18 @@ class MultiConnector(KVConnectorBase_V1): for c in self._connectors: c.clear_connector_metadata() + def shutdown(self): + exception: Optional[Exception] = None + for c in self._connectors: + try: + c.shutdown() + except Exception as e: + logger.exception("Exception during connector %s shutdown.", + c.__class__.__name__) + exception = e + if exception: + raise exception + # ============================== # Worker-side methods # ============================== @@ -142,11 +155,15 @@ class MultiConnector(KVConnectorBase_V1): self, request: "Request", num_computed_tokens: int, - ) -> tuple[int, bool]: + ) -> tuple[Optional[int], bool]: to_return = (0, False) for i, c in enumerate(self._connectors): toks, load_async = c.get_num_new_matched_tokens( request, num_computed_tokens) + # If there is a connector still looking up the matches, + # we return None to indicate that we are not done yet. + if toks is None: + return (None, False) # The first connector that has new matched tokens will be assigned # to this request. if to_return[0] == 0 and toks > 0: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 6608d2a4a9e09..1ff1407aeb99b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -14,6 +14,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional import msgspec +import numpy as np import torch import zmq @@ -55,11 +56,12 @@ except ImportError: logger.warning("NIXL is not available") NixlWrapper = None -# Supported xPUs and types of kv transfer buffer. -# {xPU: tuple of supported kv buffer types} -_NIXL_SUPPORTED_XPUS = { +# Supported platforms and types of kv transfer buffer. 
+# {device: tuple of supported kv buffer types} +_NIXL_SUPPORTED_DEVICE = { "cuda": ("cuda", ), "tpu": ("cpu", ), + "xpu": ("cpu", ), } @@ -160,7 +162,7 @@ class NixlConnector(KVConnectorBase_V1): def get_num_new_matched_tokens( self, request: "Request", - num_computed_tokens: int) -> tuple[int, bool]: + num_computed_tokens: int) -> tuple[Optional[int], bool]: assert self.connector_scheduler is not None return self.connector_scheduler.get_num_new_matched_tokens( request, num_computed_tokens) @@ -456,9 +458,9 @@ class NixlConnectorWorker: self.device_type = current_platform.device_type self.kv_buffer_device: str = \ vllm_config.kv_transfer_config.kv_buffer_device - if self.device_type not in _NIXL_SUPPORTED_XPUS: + if self.device_type not in _NIXL_SUPPORTED_DEVICE: raise RuntimeError(f"{self.device_type} is not supported.") - elif self.kv_buffer_device not in _NIXL_SUPPORTED_XPUS[ + elif self.kv_buffer_device not in _NIXL_SUPPORTED_DEVICE[ self.device_type]: raise RuntimeError( f"{self.device_type} with {self.kv_buffer_device} kv_buffer " @@ -466,7 +468,7 @@ class NixlConnectorWorker: self.device_kv_caches: dict[str, torch.Tensor] = {} # cpu kv buffer for xfer - # used when xPU memory can not be registered under nixl + # used when device memory can not be registered under nixl self.host_xfer_buffers: dict[str, torch.Tensor] = {} self.use_host_buffer = self.kv_buffer_device == "cpu" if self.kv_buffer_device == "cuda": @@ -706,8 +708,6 @@ class NixlConnectorWorker: caches_data = [] # With hybrid allocator, layers can share a kv cache tensor seen_base_addresses = [] - xfer_buffers = (self.host_xfer_buffers - if self.use_host_buffer else kv_caches) # Note(tms): I modified this from the original region setup code. # K and V are now in different regions. Advantage is that we can @@ -715,7 +715,7 @@ class NixlConnectorWorker: # are non-contiguous (it's not locally guaranteed that they will be) # Disadvantage is that the encoded NixlAgentMetadata is now larger # (roughly 8KB vs 5KB). - # Conversely for FlashInfer, K and V are transferred in the same tensor + # Conversely for FlashInfer, K and V are registered in the same region # to better exploit the memory layout (ie num_blocks is the first dim). split_k_and_v = not (self.use_mla or self._use_pallas_v1 or self._use_flashinfer) @@ -758,12 +758,21 @@ class NixlConnectorWorker: assert tensor_size_bytes % self.num_blocks == 0 self.block_len = tensor_size_bytes // self.num_blocks self.slot_size_bytes = self.block_len // self.block_size + self.device_kv_caches = kv_caches + self.dst_num_blocks[self.engine_id] = self.num_blocks if self._use_flashinfer: assert self.slot_size_bytes % 2 == 0 self.slot_size_bytes /= 2 - self.device_kv_caches = kv_caches - self.dst_num_blocks[self.engine_id] = self.num_blocks + # NOTE (NickLucche) When FlashInfer is used, memory is registered + # with joint KV for each block. This minimizes the overhead in + # registerMem allowing faster descs queries. In order to be able to + # split on kv_heads dim as required by heterogeneous TP, one must + # be able to index K/V separately. Hence we double the number + # of 'virtual' regions here and halve `block_len` below. + self.num_regions *= 2 + + kv_block_len = self.get_backend_aware_kv_block_len() # Register local/src descr for NIXL xfer. blocks_data = [] for base_addr in seen_base_addresses: @@ -776,8 +785,18 @@ class NixlConnectorWorker: block_offset = block_id * self.block_len addr = base_addr + block_offset # (addr, len, device id) - # TODO: does device_id matter to DRAM? 
- blocks_data.append((addr, self.block_len, self.tp_rank)) + blocks_data.append((addr, kv_block_len, self.tp_rank)) + + if self._use_flashinfer: + # Separate and interleave K/V regions to maintain the same + # descs ordering. This is needed for selecting contiguous heads + # when split across TP ranks. + for block_id in range(self.num_blocks): + block_offset = block_id * self.block_len + addr = base_addr + block_offset + # Register addresses for V cache (K registered first). + v_addr = addr + kv_block_len + blocks_data.append((v_addr, kv_block_len, self.tp_rank)) logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.tp_rank) @@ -787,7 +806,7 @@ class NixlConnectorWorker: self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( "NIXL_INIT_AGENT", descs) - # TODO(mgoin): Hybrid memory allocator is currently diabled for + # TODO(mgoin): Hybrid memory allocator is currently disabled for # models with local attention (Llama 4). Can remove this once enabled. if self.vllm_config.model_config.hf_config.model_type == "llama4": from transformers import Llama4TextConfig @@ -903,11 +922,14 @@ class NixlConnectorWorker: remote_block_size = nixl_agent_meta.block_len // ( self.slot_size_bytes * tp_ratio) if self._use_flashinfer: - # Account for joint KV in FlashInfer. + # With flashinfer, KV are sent in the same message. remote_block_size //= 2 if tp_ratio > 1: # Heterogeneous TP expects same kv_cache_layout. assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout + if self.device_type == "xpu": + raise ValueError( + "Heterogeneous TP is not supported on XPU") assert nixl_agent_meta.block_len == self.block_len * tp_ratio, ( "Remote P worker KV layer cache must be of shape [2, N, " @@ -929,10 +951,10 @@ class NixlConnectorWorker: # rank. With heterogeneous TP, prepare the descriptors by splitting the # P KV cache along kv_head dim, of D worker's kv_head size (D>P). # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. - # Only register the remote's descriptors if current rank pulls from it. self.kv_caches_base_addr[ engine_id] = nixl_agent_meta.kv_caches_base_addr - rank_offset = self.tp_rank % tp_ratio * self.block_len \ + kv_block_len = self.get_backend_aware_kv_block_len() + rank_offset = self.tp_rank % tp_ratio * kv_block_len \ if not (self.use_mla or is_kv_replicated) else 0 # Register all remote blocks, but only the corresponding kv heads. for base_addr in nixl_agent_meta.kv_caches_base_addr: @@ -943,7 +965,16 @@ class NixlConnectorWorker: # self.block_len == remote_block_len//tp_ratio bytes. addr = base_addr + block_offset + rank_offset # (addr, len, device id) - blocks_data.append((addr, self.block_len, remote_tp_rank)) + blocks_data.append((addr, kv_block_len, remote_tp_rank)) + + if self._use_flashinfer: + # With FlashInfer index V separately to allow head splitting. + for block_id in range(nixl_agent_meta.num_blocks): + block_offset = block_id * nixl_agent_meta.block_len + addr = base_addr + block_offset + rank_offset + v_addr = addr + nixl_agent_meta.block_len // 2 + blocks_data.append((v_addr, kv_block_len, remote_tp_rank)) + logger.debug( "Created %s blocks for dst engine %s with remote rank %s and " "local rank %s", len(blocks_data), engine_id, remote_tp_rank, @@ -1163,8 +1194,8 @@ class NixlConnectorWorker: # workers will issue xfers to parts of the P worker remote kv caches. # Get descs ids. 
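# Worked example (sketch) of the remote-descriptor math above. With
# FlashInfer, K and V of a block share one registered region, so only half a
# block is addressed at a time and the V half of block i sits at
# base_addr + i * block_len + kv_block_len. For heterogeneous TP, each decode
# rank additionally offsets into the larger remote block to select its own kv
# heads, e.g. prefill TP=1 and decode TP=2 (tp_ratio == 2), no MLA:
local_block_len = 4096                   # bytes per block on this D worker
remote_block_len = local_block_len * 2   # nixl_agent_meta.block_len on the P worker
kv_block_len = local_block_len           # get_backend_aware_kv_block_len(), non-FlashInfer
rank_offsets = [tp_rank % 2 * kv_block_len for tp_rank in (0, 1)]
assert rank_offsets == [0, 4096]
# D rank 0 reads bytes [0, 4096) of every remote block and D rank 1 reads
# bytes [4096, 8192): each rank pulls only its slice of the kv-head dim.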
- local_block_descs_ids: list[int] = [] - remote_block_descs_ids: list[int] = [] + local_block_descs_ids: np.ndarray + remote_block_descs_ids: np.ndarray if not self.block_window_per_layer: # Default case: assume global attention remote_block_descs_ids = self._get_block_descs_ids( @@ -1174,6 +1205,8 @@ class NixlConnectorWorker: else: # TODO(mgoin): remove this once we have hybrid memory allocator # Optimization for models with local attention (Llama 4) + local_descs_list = [] + remote_descs_list = [] for layer_idx, block_window in enumerate( self.block_window_per_layer): # For each layer: @@ -1193,8 +1226,11 @@ class NixlConnectorWorker: layer_remote_desc_ids = self._get_block_descs_ids( dst_engine_id, layer_remote_block_ids, layer_idx) - local_block_descs_ids.extend(layer_local_desc_ids) - remote_block_descs_ids.extend(layer_remote_desc_ids) + local_descs_list.append(layer_local_desc_ids) + remote_descs_list.append(layer_remote_desc_ids) + + local_block_descs_ids = np.concatenate(local_descs_list) + remote_block_descs_ids = np.concatenate(remote_descs_list) assert len(local_block_descs_ids) == len(remote_block_descs_ids) @@ -1219,14 +1255,14 @@ class NixlConnectorWorker: def _get_block_descs_ids(self, engine_id: str, block_ids: list[int], - layer_idx: Optional[int] = None) -> list[int]: + layer_idx: Optional[int] = None) -> np.ndarray: """ Get the descs ids for a set of block ids. If layer_idx is provided, we use the region_ids for the given layer. Otherwise, we use all regions. """ if layer_idx is None: - region_ids = range(self.num_regions) + region_ids = np.arange(self.num_regions) else: assert layer_idx < self.num_layers if self.num_layers < self.num_regions: @@ -1234,20 +1270,35 @@ class NixlConnectorWorker: # the regions are organized as [K0, V0, K1, V1, ...] # and we select K_i and V_i assert 2 * self.num_layers == self.num_regions - region_ids = range(2 * layer_idx, 2 * layer_idx + 2) + region_ids = np.arange(2 * layer_idx, 2 * layer_idx + 2) else: # Otherwise, we assume we have MLA and select i-th layer assert self.num_layers == self.num_regions - region_ids = range(layer_idx, layer_idx + 1) + region_ids = np.arange(layer_idx, layer_idx + 1) num_blocks = self.dst_num_blocks[engine_id] # Compute the desc ids for each block. - descs_ids: list[int] = [] - for reg_id in region_ids: - for block_id in block_ids: - descs_ids.append(reg_id * num_blocks + block_id) - return descs_ids + region_ids = region_ids[:, None] + block_ids = np.array(block_ids)[None, :] + descs_ids = region_ids * num_blocks + block_ids + return descs_ids.flatten() + + def get_backend_aware_kv_block_len(self): + """ + Get the block length for one K/V element (K and V have the same size). + + For FA and other backends, this is equal to the length of the whole + block, as K and V are in separate regions. + For FlashInfer, this is half the length of the whole block, as K and V + share the same region. + """ + if self._use_flashinfer: + # For indexing only half (either just the K or V part). 
+ block_len = self.block_len // 2 + else: + block_len = self.block_len + return block_len @contextlib.contextmanager diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py index 2485c57d86ecc..ec72905a0d3ec 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py @@ -91,7 +91,7 @@ class P2pNcclConnector(KVConnectorBase_V1): # ============================== def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: + **kwargs: Any) -> None: """Start loading the KV cache from the connector buffer to vLLM's paged KV buffer. @@ -212,7 +212,8 @@ class P2pNcclConnector(KVConnectorBase_V1): return def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: + attn_metadata: "AttentionMetadata", + **kwargs: Any) -> None: """Start saving the KV cache of the layer from vLLM's paged buffer to the connector. @@ -278,7 +279,7 @@ class P2pNcclConnector(KVConnectorBase_V1): def get_finished( self, finished_req_ids: set[str], - **kwargs) -> tuple[Optional[set[str]], Optional[set[str]]]: + **kwargs: Any) -> tuple[Optional[set[str]], Optional[set[str]]]: """ Notifies worker-side connector ids of requests that have finished generating tokens. diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py index dfd95548c4632..fa7cc66ab654d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py @@ -15,7 +15,7 @@ import msgpack import torch import zmq -from vllm.config import KVTransferConfig +from vllm.config.kv_transfer import KVTransferConfig from vllm.distributed.device_communicators.pynccl_wrapper import ( NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum) from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501 diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py index b775276d4a846..26070488bad89 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py @@ -218,8 +218,9 @@ class TensorMemoryPool: return addr - def load_tensor(self, addr: int, dtype: torch.dtype, - shape: tuple[int, ...], device) -> torch.Tensor: + def load_tensor(self, addr: int, dtype: torch.dtype, shape: tuple[int, + ...], + device: torch.device) -> torch.Tensor: """Loads a tensor from pinned host memory to the specified device. 
Args: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index fd79387269d56..48fa1a82c6775 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -3,7 +3,7 @@ import hashlib import os from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Optional import safetensors import torch @@ -90,7 +90,7 @@ class SharedStorageConnector(KVConnectorBase_V1): logger.info("Shared storage path is %s", self._storage_path) def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: + **kwargs: Any) -> None: """Start loading the KV cache from the connector buffer to vLLM's paged KV buffer. @@ -191,7 +191,8 @@ class SharedStorageConnector(KVConnectorBase_V1): return def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: + attn_metadata: "AttentionMetadata", + **kwargs: Any) -> None: """Start saving the KV cache of the layer from vLLM's paged buffer to the connector. @@ -238,7 +239,7 @@ class SharedStorageConnector(KVConnectorBase_V1): self, request: "Request", num_computed_tokens: int, - ) -> tuple[int, bool]: + ) -> tuple[Optional[int], bool]: """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. @@ -300,11 +301,12 @@ class SharedStorageConnector(KVConnectorBase_V1): total_need_load = 0 for new_req in scheduler_output.scheduled_new_reqs: if new_req.req_id in self._requests_need_load: - meta.add_request(token_ids=new_req.prompt_token_ids, - block_ids=new_req.block_ids[0], - block_size=self._block_size, - is_store=False, - mm_hashes=new_req.mm_hashes) + meta.add_request( + token_ids=new_req.prompt_token_ids, + block_ids=new_req.block_ids[0], + block_size=self._block_size, + is_store=False, + mm_hashes=[f.identifier for f in new_req.mm_features]) total_need_load += 1 else: # NOTE: here, we set the store and load being exclusive, @@ -312,11 +314,12 @@ class SharedStorageConnector(KVConnectorBase_V1): # NOTE(rob): for this debug implementation, we only cache # the original prompt tokens. if not self._found_match_for_request(new_req): - meta.add_request(token_ids=new_req.prompt_token_ids, - block_ids=new_req.block_ids[0], - block_size=self._block_size, - is_store=True, - mm_hashes=new_req.mm_hashes) + meta.add_request( + token_ids=new_req.prompt_token_ids, + block_ids=new_req.block_ids[0], + block_size=self._block_size, + is_store=True, + mm_hashes=[f.identifier for f in new_req.mm_features]) cached_reqs = scheduler_output.scheduled_cached_reqs for i, req_id in enumerate(cached_reqs.req_ids): @@ -341,11 +344,12 @@ class SharedStorageConnector(KVConnectorBase_V1): # of the block_ids for the request. 
block_ids = new_block_ids[0] - meta.add_request(token_ids=token_ids, - block_ids=block_ids, - block_size=self._block_size, - is_store=False, - mm_hashes=request.mm_hashes) + meta.add_request( + token_ids=token_ids, + block_ids=block_ids, + block_size=self._block_size, + is_store=False, + mm_hashes=[f.identifier for f in request.mm_features]) total_need_load += 1 assert total_need_load == len(self._requests_need_load) @@ -364,10 +368,10 @@ class SharedStorageConnector(KVConnectorBase_V1): """ num_tokens_to_check = align_to_block_size( len(request.prompt_token_ids) - 1, self._block_size) - foldername = self._generate_foldername_debug(torch.tensor( - request.prompt_token_ids)[:num_tokens_to_check], - request.mm_hashes, - create_folder=False) + foldername = self._generate_foldername_debug( + torch.tensor(request.prompt_token_ids)[:num_tokens_to_check], + [f.identifier for f in request.mm_features], + create_folder=False) return os.path.exists(foldername) def _generate_foldername_debug( diff --git a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py index 0b560d1b3b3ce..2a434e280179e 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py @@ -13,7 +13,7 @@ import zmq from safetensors.torch import load as safetensors_load from safetensors.torch import save as safetensors_save -from vllm.config import KVTransferConfig +from vllm.config.kv_transfer import KVTransferConfig from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase from vllm.logger import init_logger from vllm.utils import join_host_port, make_zmq_path, split_host_port diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py index 09de0b682efca..7a79a8cc0c932 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py @@ -20,7 +20,7 @@ from typing import Callable, Optional import torch -from vllm.config import KVTransferConfig +from vllm.config.kv_transfer import KVTransferConfig from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase from vllm.distributed.utils import StatelessProcessGroup @@ -251,8 +251,8 @@ class PyNcclPipe(KVPipeBase): """ Receives a tensor and its metadata from the source rank. Blocking call. - Args: - tensor: The received tensor, or `None` if no tensor is received. + Returns: + The received tensor, or `None` if no tensor is received. 
""" if self.transport_thread is None: self.transport_thread = ThreadPoolExecutor(max_workers=1) diff --git a/vllm/distributed/kv_transfer/kv_transfer_state.py b/vllm/distributed/kv_transfer/kv_transfer_state.py index 5e0f64fca220c..d5747bed92771 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_state.py +++ b/vllm/distributed/kv_transfer/kv_transfer_state.py @@ -64,3 +64,10 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: config=vllm_config, role=KVConnectorRole.WORKER) else: raise ValueError("V0 is no longer supported") + + +def ensure_kv_transfer_shutdown() -> None: + global _KV_CONNECTOR_AGENT + if _KV_CONNECTOR_AGENT is not None: + _KV_CONNECTOR_AGENT.shutdown() + _KV_CONNECTOR_AGENT = None diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index fc96c2ac926b0..12571afaa4c13 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -29,6 +29,7 @@ import weakref from collections import namedtuple from contextlib import contextmanager, nullcontext from dataclasses import dataclass +from datetime import timedelta from multiprocessing import shared_memory from typing import Any, Callable, Optional, Union from unittest.mock import patch @@ -662,14 +663,29 @@ class GroupCoordinator: tensor_dict: dict[str, Union[torch.Tensor, Any]], dst: Optional[int] = None, all_gather_group: Optional["GroupCoordinator"] = None, + all_gather_tensors: Optional[dict[str, bool]] = None, ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: """Send the input tensor dictionary. NOTE: `dst` is the local rank of the source rank. + + all_gather_group: The group for the all-gather operation. If provided, + an optimization is enabled where each rank in the group sends a + slice of a tensor and the receiver reconstructs it using an + all-gather, which can improve performance. This is typically the + tensor-parallel group. + all_gather_tensors: A dictionary to specify which tensors should use + the all-gather optimization, which is only effective when + `all_gather_group` is provided. By default, this optimization is + on for any tensor whose size is divisible by the + `all_gather_group`'s world size. However, it should be disabled + for tensors that are not fully replicated across the group (e.g., + the residual tensor when sequence parallelism is enabled). This + dictionary allows overriding the default behavior on a per-tensor + basis. """ # Bypass the function if we are using only 1 GPU. if not torch.distributed.is_initialized() or self.world_size == 1: return tensor_dict - all_gather_size = (1 if all_gather_group is None else all_gather_group.world_size) all_gather_rank = (0 if all_gather_group is None else @@ -698,14 +714,23 @@ class GroupCoordinator: # `send_object_list` has serialization & deserialization, # all happening on CPU. Therefore, we can use the CPU group. self.send_object(metadata_list, dst=dst) - for tensor in tensor_list: + + tensor_keys = [ + k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor) + ] + assert len(tensor_keys) == len(tensor_list) + + for key, tensor in zip(tensor_keys, tensor_list): if tensor.numel() == 0: # Skip sending empty tensors. continue # send-allgather: send only a slice, then do allgather. 
- if (all_gather_group is not None - and tensor.numel() % all_gather_size == 0): + use_all_gather = (all_gather_group is not None + and tensor.numel() % all_gather_size == 0) + use_all_gather = all_gather_tensors.get(key, use_all_gather) \ + if all_gather_tensors else use_all_gather + if use_all_gather: tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank] if tensor.is_cpu: @@ -724,14 +749,29 @@ class GroupCoordinator: self, src: Optional[int] = None, all_gather_group: Optional["GroupCoordinator"] = None, + all_gather_tensors: Optional[dict[str, bool]] = None, ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: """Recv the input tensor dictionary. NOTE: `src` is the local rank of the source rank. + + all_gather_group: The group for the all-gather operation. If provided, + an optimization is enabled where each rank in the group sends a + slice of a tensor and the receiver reconstructs it using an + all-gather, which can improve performance. This is typically the + tensor-parallel group. + all_gather_tensors: A dictionary to specify which tensors should use + the all-gather optimization, which is only effective when + `all_gather_group` is provided. By default, this optimization is + on for any tensor whose size is divisible by the + `all_gather_group`'s world size. However, it should be disabled + for tensors that are not fully replicated across the group (e.g., + the residual tensor when sequence parallelism is enabled). This + dictionary allows overriding the default behavior on a per-tensor + basis. """ # Bypass the function if we are using only 1 GPU. if not torch.distributed.is_initialized() or self.world_size == 1: return None - all_gather_size = (1 if all_gather_group is None else all_gather_group.world_size) all_gather_rank = (0 if all_gather_group is None else @@ -765,6 +805,8 @@ class GroupCoordinator: # send-allgather: send only a slice, then do allgather. 
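# Example usage (sketch) of the new all_gather_tensors override documented
# above. With sequence parallelism the residual tensor is not replicated
# across the TP group, so it must opt out of the slice-plus-all-gather path
# while hidden_states keeps the default behaviour. hidden_states/residual
# stand in for the usual pipeline activations; the group helpers are assumed
# to be the standard vllm.distributed accessors.
from vllm.distributed.parallel_state import get_pp_group, get_tp_group

overrides = {"residual": False}
get_pp_group().send_tensor_dict(
    {"hidden_states": hidden_states, "residual": residual},
    all_gather_group=get_tp_group(),
    all_gather_tensors=overrides,
)
received = get_pp_group().recv_tensor_dict(
    all_gather_group=get_tp_group(),
    all_gather_tensors=overrides,
)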
use_all_gather = (all_gather_group is not None and tensor.numel() % all_gather_size == 0) + use_all_gather = all_gather_tensors.get(key, use_all_gather) \ + if all_gather_tensors else use_all_gather if use_all_gather: orig_shape = tensor.shape @@ -904,6 +946,18 @@ def get_tensor_model_parallel_group(): return get_tp_group() +_DCP: Optional[GroupCoordinator] = None + + +def get_dcp_group() -> GroupCoordinator: + assert _DCP is not None, ( + "decode context model parallel group is not initialized") + return _DCP + + +# kept for backward compatibility +get_context_model_parallel_group = get_dcp_group + _PP: Optional[GroupCoordinator] = None _DP: Optional[GroupCoordinator] = None @@ -966,13 +1020,12 @@ def set_custom_all_reduce(enable: bool): _ENABLE_CUSTOM_ALL_REDUCE = enable -def init_distributed_environment( - world_size: int = -1, - rank: int = -1, - distributed_init_method: str = "env://", - local_rank: int = -1, - backend: str = "nccl", -): +def init_distributed_environment(world_size: int = -1, + rank: int = -1, + distributed_init_method: str = "env://", + local_rank: int = -1, + backend: str = "nccl", + timeout: Optional[timedelta] = None): logger.debug( "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", world_size, rank, local_rank, @@ -1008,7 +1061,8 @@ def init_distributed_environment( backend=backend, init_method=distributed_init_method, world_size=world_size, - rank=rank) + rank=rank, + timeout=timeout) # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 @@ -1034,6 +1088,7 @@ def init_distributed_environment( def initialize_model_parallel( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, + decode_context_model_parallel_size: Optional[int] = 1, backend: Optional[str] = None, ) -> None: """ @@ -1098,6 +1153,23 @@ def initialize_model_parallel( use_message_queue_broadcaster=True, group_name="tp") + # Build the DCP model-parallel groups. + global _DCP + assert _DCP is None, ( + "decode context model parallel group is already initialized") + # Note(hc): In the current implementation of decode context parallel, + # dcp_size must not exceed tp_size, because the world size does not + # change by DCP, it simply reuses the GPUs of TP group, and split one + # TP group into tp_size//dcp_size DCP groups. + group_ranks = all_ranks.reshape( + -1, decode_context_model_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + _DCP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + use_message_queue_broadcaster=True, + group_name="dcp") + # Build the pipeline model-parallel groups. 
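# Worked example (sketch) of the DCP grouping built above: the world size is
# unchanged, each TP group is simply carved into tp_size // dcp_size groups
# of dcp_size consecutive ranks (rank layout simplified to a single node).
import torch

all_ranks = torch.arange(8)   # e.g. one node running TP=8
dcp_size = 2
dcp_groups = [x.tolist() for x in all_ranks.reshape(-1, dcp_size).unbind(0)]
assert dcp_groups == [[0, 1], [2, 3], [4, 5], [6, 7]]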
global _PP assert _PP is None, ( @@ -1141,6 +1213,7 @@ def initialize_model_parallel( def ensure_model_parallel_initialized( tensor_model_parallel_size: int, pipeline_model_parallel_size: int, + decode_context_model_parallel_size: Optional[int] = 1, backend: Optional[str] = None, ) -> None: """Helper to initialize model parallel groups if they are not initialized, @@ -1151,7 +1224,8 @@ def ensure_model_parallel_initialized( get_world_group().device_group) if not model_parallel_is_initialized(): initialize_model_parallel(tensor_model_parallel_size, - pipeline_model_parallel_size, backend) + pipeline_model_parallel_size, + decode_context_model_parallel_size, backend) return assert ( @@ -1226,6 +1300,16 @@ def get_tensor_model_parallel_rank(): return get_tp_group().rank_in_group +def get_decode_context_model_parallel_world_size(): + """Return world size for the decode context model parallel group.""" + return get_dcp_group().world_size + + +def get_decode_context_model_parallel_rank(): + """Return my rank for the decode context model parallel group.""" + return get_dcp_group().rank_in_group + + def get_node_count() -> int: """Return the total number of nodes in the distributed environment. """ assert _NODE_COUNT is not None, ( @@ -1246,6 +1330,11 @@ def destroy_model_parallel(): _PP.destroy() _PP = None + global _DCP + if _DCP: + _DCP.destroy() + _DCP = None + global _DP if _DP: _DP.destroy() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d4dd545dd43a6..fb5beab77b270 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -22,17 +22,19 @@ from typing_extensions import TypeIs, deprecated import vllm.envs as envs from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, - ConfigFormat, ConfigType, ConvertOption, - DecodingConfig, DetailedTraceModules, Device, - DeviceConfig, DistributedExecutorBackend, EPLBConfig, - GuidedDecodingBackend, HfOverrides, KVEventsConfig, + ConfigType, ConvertOption, DetailedTraceModules, + Device, DeviceConfig, DistributedExecutorBackend, + EPLBConfig, HfOverrides, KVEventsConfig, KVTransferConfig, LoadConfig, LogprobsMode, LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig, - ModelDType, ModelImpl, MultiModalConfig, - ObservabilityConfig, ParallelConfig, PoolerConfig, - PrefixCachingHashAlgo, RunnerOption, SchedulerConfig, - SchedulerPolicy, SpeculativeConfig, TaskOption, - TokenizerMode, VllmConfig, get_attr_docs, get_field) + ModelDType, ModelImpl, ObservabilityConfig, + ParallelConfig, PoolerConfig, PrefixCachingHashAlgo, + RunnerOption, SchedulerConfig, SchedulerPolicy, + SpeculativeConfig, StructuredOutputsConfig, + TaskOption, TokenizerMode, VllmConfig, get_attr_docs) +from vllm.config.multimodal import MMCacheType, MultiModalConfig +from vllm.config.parallel import ExpertPlacementStrategy +from vllm.config.utils import get_field from vllm.logger import init_logger from vllm.platforms import CpuArchEnum, current_platform from vllm.plugins import load_general_plugins @@ -227,8 +229,14 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: elif contains_type(type_hints, int): kwargs[name]["type"] = int # Special case for large integers - if name in {"max_model_len", "max_num_batched_tokens"}: + human_readable_ints = { + "max_model_len", + "max_num_batched_tokens", + "kv_cache_memory_bytes", + } + if name in human_readable_ints: kwargs[name]["type"] = human_readable_int + kwargs[name]["help"] += f"\n\n{human_readable_int.__doc__}" elif contains_type(type_hints, float): kwargs[name]["type"] = 
float elif (contains_type(type_hints, dict) @@ -289,6 +297,7 @@ class EngineArgs: trust_remote_code: bool = ModelConfig.trust_remote_code allowed_local_media_path: str = ModelConfig.allowed_local_media_path download_dir: Optional[str] = LoadConfig.download_dir + safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy load_format: Union[str, LoadFormats] = LoadConfig.load_format config_format: str = ModelConfig.config_format dtype: ModelDType = ModelConfig.dtype @@ -306,6 +315,8 @@ class EngineArgs: # number of P/D disaggregation (or other disaggregation) workers pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size tensor_parallel_size: int = ParallelConfig.tensor_parallel_size + decode_context_parallel_size: int = \ + ParallelConfig.decode_context_parallel_size data_parallel_size: int = ParallelConfig.data_parallel_size data_parallel_rank: Optional[int] = None data_parallel_start_rank: Optional[int] = None @@ -315,8 +326,13 @@ class EngineArgs: data_parallel_hybrid_lb: bool = False data_parallel_backend: str = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel + enable_dbo: bool = ParallelConfig.enable_dbo + dbo_decode_token_threshold: int = \ + ParallelConfig.dbo_decode_token_threshold eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config") enable_eplb: bool = ParallelConfig.enable_eplb + expert_placement_strategy: ExpertPlacementStrategy = \ + ParallelConfig.expert_placement_strategy num_redundant_experts: int = EPLBConfig.num_redundant_experts eplb_window_size: int = EPLBConfig.window_size eplb_step_interval: int = EPLBConfig.step_interval @@ -332,6 +348,7 @@ class EngineArgs: swap_space: float = CacheConfig.swap_space cpu_offload_gb: float = CacheConfig.cpu_offload_gb gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization + kv_cache_memory_bytes: Optional[int] = CacheConfig.kv_cache_memory_bytes max_num_batched_tokens: Optional[ int] = SchedulerConfig.max_num_batched_tokens max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills @@ -363,6 +380,10 @@ class EngineArgs: MultiModalConfig.mm_processor_kwargs disable_mm_preprocessor_cache: bool = False # DEPRECATED mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb + mm_processor_cache_type: Optional[MMCacheType] = \ + MultiModalConfig.mm_processor_cache_type + mm_shm_cache_max_object_size_mb: int = \ + MultiModalConfig.mm_shm_cache_max_object_size_mb mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode io_processor_plugin: Optional[str] = None skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling @@ -396,12 +417,15 @@ class EngineArgs: disable_hybrid_kv_cache_manager: bool = ( SchedulerConfig.disable_hybrid_kv_cache_manager) - guided_decoding_backend: GuidedDecodingBackend = DecodingConfig.backend - guided_decoding_disable_fallback: bool = DecodingConfig.disable_fallback - guided_decoding_disable_any_whitespace: bool = \ - DecodingConfig.disable_any_whitespace - guided_decoding_disable_additional_properties: bool = \ - DecodingConfig.disable_additional_properties + structured_outputs_config: StructuredOutputsConfig = get_field( + VllmConfig, "structured_outputs_config") + reasoning_parser: str = StructuredOutputsConfig.reasoning_parser + # Deprecated guided decoding fields + guided_decoding_backend: Optional[str] = None + guided_decoding_disable_fallback: Optional[bool] = None + guided_decoding_disable_any_whitespace: Optional[bool] = None + 
guided_decoding_disable_additional_properties: Optional[bool] = None + logits_processor_pattern: Optional[ str] = ModelConfig.logits_processor_pattern @@ -417,8 +441,6 @@ class EngineArgs: scheduling_policy: SchedulerPolicy = SchedulerConfig.policy scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls - override_neuron_config: dict[str, Any] = \ - get_field(ModelConfig, "override_neuron_config") override_pooler_config: Optional[Union[dict, PoolerConfig]] = \ ModelConfig.override_pooler_config compilation_config: CompilationConfig = \ @@ -442,7 +464,6 @@ class EngineArgs: additional_config: dict[str, Any] = \ get_field(VllmConfig, "additional_config") - reasoning_parser: str = DecodingConfig.reasoning_backend use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load pt_load_map_location: str = LoadConfig.pt_load_map_location @@ -547,7 +568,6 @@ class EngineArgs: help="Disable async output processing. This may result in " "lower performance.") model_group.add_argument("--config-format", - choices=[f.value for f in ConfigFormat], **model_kwargs["config_format"]) # This one is a special case because it can bool # or str. TODO: Handle this in get_kwargs @@ -559,8 +579,6 @@ class EngineArgs: help=model_kwargs["hf_token"]["help"]) model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"]) - model_group.add_argument("--override-neuron-config", - **model_kwargs["override_neuron_config"]) model_group.add_argument("--override-pooler-config", **model_kwargs["override_pooler_config"]) model_group.add_argument("--logits-processor-pattern", @@ -590,6 +608,8 @@ class EngineArgs: load_group.add_argument("--load-format", **load_kwargs["load_format"]) load_group.add_argument("--download-dir", **load_kwargs["download_dir"]) + load_group.add_argument("--safetensors-load-strategy", + **load_kwargs["safetensors_load_strategy"]) load_group.add_argument("--model-loader-extra-config", **load_kwargs["model_loader_extra_config"]) load_group.add_argument("--ignore-patterns", @@ -599,28 +619,29 @@ class EngineArgs: load_group.add_argument('--pt-load-map-location', **load_kwargs["pt_load_map_location"]) - # Guided decoding arguments - guided_decoding_kwargs = get_kwargs(DecodingConfig) - guided_decoding_group = parser.add_argument_group( - title="DecodingConfig", - description=DecodingConfig.__doc__, + # Structured outputs arguments + structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig) + structured_outputs_group = parser.add_argument_group( + title="StructuredOutputsConfig", + description=StructuredOutputsConfig.__doc__, ) - guided_decoding_group.add_argument("--guided-decoding-backend", - **guided_decoding_kwargs["backend"]) - guided_decoding_group.add_argument( - "--guided-decoding-disable-fallback", - **guided_decoding_kwargs["disable_fallback"]) - guided_decoding_group.add_argument( - "--guided-decoding-disable-any-whitespace", - **guided_decoding_kwargs["disable_any_whitespace"]) - guided_decoding_group.add_argument( - "--guided-decoding-disable-additional-properties", - **guided_decoding_kwargs["disable_additional_properties"]) - guided_decoding_group.add_argument( + structured_outputs_group.add_argument( "--reasoning-parser", # This choice is a special case because it's not static choices=list(ReasoningParserManager.reasoning_parsers), - **guided_decoding_kwargs["reasoning_backend"]) + **structured_outputs_kwargs["reasoning_parser"]) + # Deprecated guided decoding arguments + for arg, type in [ + ("--guided-decoding-backend", str), + ("--guided-decoding-disable-fallback", 
bool), + ("--guided-decoding-disable-any-whitespace", bool), + ("--guided-decoding-disable-additional-properties", bool), + ]: + structured_outputs_group.add_argument( + arg, + type=type, + help=(f"[DEPRECATED] {arg} will be removed in v0.12.0."), + deprecated=True) # Parallel arguments parallel_kwargs = get_kwargs(ParallelConfig) @@ -636,6 +657,9 @@ class EngineArgs: **parallel_kwargs["pipeline_parallel_size"]) parallel_group.add_argument("--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"]) + parallel_group.add_argument( + "--decode-context-parallel-size", "-dcp", + **parallel_kwargs["decode_context_parallel_size"]) parallel_group.add_argument("--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]) parallel_group.add_argument( @@ -676,10 +700,18 @@ class EngineArgs: parallel_group.add_argument( "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]) + parallel_group.add_argument("--enable-dbo", + **parallel_kwargs["enable_dbo"]) + parallel_group.add_argument( + "--dbo-decode-token-threshold", + **parallel_kwargs["dbo_decode_token_threshold"]) parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"]) parallel_group.add_argument("--eplb-config", **parallel_kwargs["eplb_config"]) + parallel_group.add_argument( + "--expert-placement-strategy", + **parallel_kwargs["expert_placement_strategy"]) parallel_group.add_argument( "--num-redundant-experts", type=int, @@ -731,6 +763,8 @@ class EngineArgs: cache_group.add_argument("--block-size", **cache_kwargs["block_size"]) cache_group.add_argument("--gpu-memory-utilization", **cache_kwargs["gpu_memory_utilization"]) + cache_group.add_argument("--kv-cache-memory-bytes", + **cache_kwargs["kv_cache_memory_bytes"]) cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"]) cache_group.add_argument("--kv-cache-dtype", **cache_kwargs["cache_dtype"]) @@ -770,6 +804,12 @@ class EngineArgs: multimodal_group.add_argument("--disable-mm-preprocessor-cache", action="store_true", deprecated=True) + multimodal_group.add_argument( + "--mm-processor-cache-type", + **multimodal_kwargs["mm_processor_cache_type"]) + multimodal_group.add_argument( + "--mm-shm-cache-max-object-size-mb", + **multimodal_kwargs["mm_shm_cache_max_object_size_mb"]) multimodal_group.add_argument( "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]) multimodal_group.add_argument( @@ -896,6 +936,8 @@ class EngineArgs: **vllm_kwargs["compilation_config"]) vllm_group.add_argument("--additional-config", **vllm_kwargs["additional_config"]) + vllm_group.add_argument('--structured-outputs-config', + **vllm_kwargs["structured_outputs_config"]) # Other arguments parser.add_argument('--disable-log-stats', @@ -921,7 +963,6 @@ class EngineArgs: if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3 and self.model in MODELS_ON_S3 and self.load_format == "auto"): self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}" - self.load_format = "runai_streamer" if self.disable_mm_preprocessor_cache: logger.warning( @@ -986,8 +1027,10 @@ class EngineArgs: config_format=self.config_format, mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_cache_gb=self.mm_processor_cache_gb, + mm_processor_cache_type=self.mm_processor_cache_type, + mm_shm_cache_max_object_size_mb=self. 
+ mm_shm_cache_max_object_size_mb, mm_encoder_tp_mode=self.mm_encoder_tp_mode, - override_neuron_config=self.override_neuron_config, override_pooler_config=self.override_pooler_config, logits_processor_pattern=self.logits_processor_pattern, generation_config=self.generation_config, @@ -1024,6 +1067,7 @@ class EngineArgs: return LoadConfig( load_format=self.load_format, download_dir=self.download_dir, + safetensors_load_strategy=self.safetensors_load_strategy, device="cpu" if is_online_quantization(self.quantization) else None, model_loader_extra_config=self.model_loader_extra_config, @@ -1053,9 +1097,10 @@ class EngineArgs: SpeculatorsConfig) if self.speculative_config is None: - hf_config = get_config(self.hf_config_path or self.model, - self.trust_remote_code, self.revision, - self.code_revision, self.config_format) + hf_config = get_config( + self.hf_config_path or target_model_config.model, + self.trust_remote_code, self.revision, self.code_revision, + self.config_format) # if loading a SpeculatorsConfig, load the speculative_config # details from the config directly @@ -1065,7 +1110,7 @@ class EngineArgs: self.speculative_config = {} self.speculative_config[ "num_speculative_tokens"] = hf_config.num_lookahead_tokens - self.speculative_config["model"] = self.model + self.speculative_config["model"] = target_model_config.model self.speculative_config["method"] = hf_config.method else: return None @@ -1156,9 +1201,21 @@ class EngineArgs: # global layers in interleaved sliding window models. sliding_window = model_config.get_sliding_window() + # Note(hc): In the current implementation of decode context + # parallel (DCP), tp_size needs to be divisible by dcp_size, + # because the world size does not change by dcp, it simply + # reuses the GPUs of TP group, and splits one TP group into + # tp_size//dcp_size DCP groups. + assert self.tensor_parallel_size % self.decode_context_parallel_size \ + == 0, ( + f"tp_size={self.tensor_parallel_size} must be divisible by " + f"dcp_size={self.decode_context_parallel_size}." + ) + cache_config = CacheConfig( block_size=self.block_size, gpu_memory_utilization=self.gpu_memory_utilization, + kv_cache_memory_bytes=self.kv_cache_memory_bytes, swap_space=self.swap_space, cache_dtype=self.kv_cache_dtype, is_attention_free=model_config.is_attention_free, @@ -1258,11 +1315,8 @@ class EngineArgs: # Async scheduling does not work with the uniprocess backend.
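# Example (sketch) of the new engine arguments wired up above; values are
# illustrative. decode_context_parallel_size must evenly divide
# tensor_parallel_size, and kv_cache_memory_bytes, as its name suggests,
# fixes the KV-cache budget in bytes instead of deriving it from
# gpu_memory_utilization.
from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(
    model="meta-llama/Llama-3.1-8B-Instruct",
    tensor_parallel_size=4,
    decode_context_parallel_size=2,
    kv_cache_memory_bytes=8 * 1024**3,   # 8 GiB
)
vllm_config = args.create_engine_config()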
if self.distributed_executor_backend is None: self.distributed_executor_backend = "mp" - logger.info("Using mp-based distributed executor backend " - "for async scheduling.") - if self.distributed_executor_backend == "uni": - raise ValueError("Async scheduling is not supported with " - "uni-process backend.") + logger.info("Defaulting to mp-based distributed executor " + "backend for async scheduling.") if self.pipeline_parallel_size > 1: raise ValueError("Async scheduling is not supported with " "pipeline-parallel-size > 1.") @@ -1296,8 +1350,11 @@ class EngineArgs: data_parallel_backend=self.data_parallel_backend, data_parallel_hybrid_lb=self.data_parallel_hybrid_lb, enable_expert_parallel=self.enable_expert_parallel, + enable_dbo=self.enable_dbo, + dbo_decode_token_threshold=self.dbo_decode_token_threshold, enable_eplb=self.enable_eplb, eplb_config=self.eplb_config, + expert_placement_strategy=self.expert_placement_strategy, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, ray_workers_use_nsight=self.ray_workers_use_nsight, @@ -1306,6 +1363,7 @@ class EngineArgs: distributed_executor_backend=self.distributed_executor_backend, worker_cls=self.worker_cls, worker_extension_cls=self.worker_extension_cls, + decode_context_parallel_size=self.decode_context_parallel_size, ) speculative_config = self.create_speculative_config( @@ -1367,14 +1425,25 @@ class EngineArgs: load_config = self.create_load_config() - decoding_config = DecodingConfig( - backend=self.guided_decoding_backend, - disable_fallback=self.guided_decoding_disable_fallback, - disable_any_whitespace=self.guided_decoding_disable_any_whitespace, - disable_additional_properties=\ - self.guided_decoding_disable_additional_properties, - reasoning_backend=self.reasoning_parser - ) + # Pass reasoning_parser into StructuredOutputsConfig + if self.reasoning_parser: + self.structured_outputs_config.reasoning_parser = \ + self.reasoning_parser + + # Forward the deprecated CLI args to the StructuredOutputsConfig + so_config = self.structured_outputs_config + if self.guided_decoding_backend is not None: + so_config.guided_decoding_backend = \ + self.guided_decoding_backend + if self.guided_decoding_disable_fallback is not None: + so_config.guided_decoding_disable_fallback = \ + self.guided_decoding_disable_fallback + if self.guided_decoding_disable_any_whitespace is not None: + so_config.guided_decoding_disable_any_whitespace = \ + self.guided_decoding_disable_any_whitespace + if self.guided_decoding_disable_additional_properties is not None: + so_config.guided_decoding_disable_additional_properties = \ + self.guided_decoding_disable_additional_properties observability_config = ObservabilityConfig( show_hidden_metrics_for_version=( @@ -1392,7 +1461,7 @@ class EngineArgs: lora_config=lora_config, speculative_config=speculative_config, load_config=load_config, - decoding_config=decoding_config, + structured_outputs_config=self.structured_outputs_config, observability_config=observability_config, compilation_config=self.compilation_config, kv_transfer_config=self.kv_transfer_config, @@ -1465,12 +1534,6 @@ class EngineArgs: recommend_to_remove=False) return False - # No OTLP observability so far. - if (self.otlp_traces_endpoint or self.collect_detailed_traces): - _raise_or_fallback(feature_name="--otlp-traces-endpoint", - recommend_to_remove=False) - return False - # V1 supports N-gram, Medusa, and Eagle speculative decoding. 
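# Illustrative sketch (not part of the patch): the hunk above forwards the
# deprecated --guided-decoding-* CLI flags onto the new structured-outputs
# config, overriding a field only when the old flag was explicitly given.
# The dataclass below is a hypothetical stand-in, not vLLM's
# StructuredOutputsConfig.
from dataclasses import dataclass
from typing import Optional


@dataclass
class _SketchStructuredOutputs:
    backend: str = "auto"
    disable_fallback: bool = False


def _forward_deprecated(cfg: _SketchStructuredOutputs,
                        backend: Optional[str],
                        disable_fallback: Optional[bool]) -> None:
    # None means the deprecated flag was not passed, so keep the new default.
    if backend is not None:
        cfg.backend = backend
    if disable_fallback is not None:
        cfg.disable_fallback = disable_fallback


cfg = _SketchStructuredOutputs()
_forward_deprecated(cfg, backend="xgrammar", disable_fallback=None)
assert cfg.backend == "xgrammar" and cfg.disable_fallback is False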
if (self.speculative_config is not None and self.speculative_config.get("method") == "draft_model"): @@ -1488,8 +1551,11 @@ class EngineArgs: "TRITON_MLA", "CUTLASS_MLA", "FLASHMLA", + "FLASHMLA_VLLM_V1", + "FLASH_ATTN_MLA", "FLASHINFER", "FLASHINFER_VLLM_V1", + "FLASHINFER_MLA", "ROCM_AITER_MLA", "TORCH_SDPA_VLLM_V1", "FLEX_ATTENTION", @@ -1578,20 +1644,12 @@ class EngineArgs: "in low performance due to small KV cache size. Consider " "setting --max-model-len to a smaller value.", max_model_len) - # if using prefix caching, we must set a hash algo - if self.enable_prefix_caching: - # Disable prefix caching for multimodal models for VLLM_V0. - if model_config.is_multimodal_model: - logger.warning( - "--enable-prefix-caching is not supported for multimodal " - "models in V0 and has been disabled.") - self.enable_prefix_caching = False - - # VLLM_V0 only supports builtin hash algo for prefix caching. - if self.prefix_caching_hash_algo == "sha256": - raise ValueError( - "sha256 is not supported for prefix caching in V0 engine. " - "Please use 'builtin'.") + # Disable prefix caching for multimodal models for VLLM_V0. + if self.enable_prefix_caching and model_config.is_multimodal_model: + logger.warning( + "--enable-prefix-caching is not supported for multimodal " + "models in V0 and has been disabled.") + self.enable_prefix_caching = False # Set max_num_seqs to 256 for VLLM_V0. if self.max_num_seqs is None: diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 9f9ad1854c3b6..6793041abc502 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -10,8 +10,8 @@ from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List, from weakref import ReferenceType import vllm.envs as envs -from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VllmConfig) +from vllm.config import (LoRAConfig, ModelConfig, ParallelConfig, + SchedulerConfig, VllmConfig) from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_timeout import asyncio_timeout @@ -389,11 +389,8 @@ class _AsyncLLMEngine(LLMEngine): """Stop the remote worker execution loop.""" await self.model_executor.stop_remote_worker_execution_loop_async() - async def get_tokenizer_async(self, - lora_request: Optional[LoRARequest] = None - ) -> AnyTokenizer: - return await ( - self.get_tokenizer_group().get_lora_tokenizer_async(lora_request)) + async def get_tokenizer_async(self) -> AnyTokenizer: + return self.get_tokenizer() async def add_request_async( self, @@ -434,7 +431,6 @@ class _AsyncLLMEngine(LLMEngine): processed_inputs = await self.input_preprocessor.preprocess_async( prompt, - lora_request=lora_request, tokenization_kwargs=tokenization_kwargs, ) @@ -613,11 +609,8 @@ class AsyncLLMEngine(EngineClient): async def get_input_preprocessor(self) -> InputPreprocessor: return self.engine.input_preprocessor - async def get_tokenizer( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - return await self.engine.get_tokenizer_async(lora_request) + async def get_tokenizer(self) -> AnyTokenizer: + return self.engine.get_tokenizer() def start_background_loop(self) -> None: """Start the background loop.""" @@ -717,7 +710,7 @@ class AsyncLLMEngine(EngineClient): # Stop the execute model loop in parallel workers until there # are more requests to process. 
This avoids waiting # indefinitely in torch.distributed ops which may otherwise - # timeout, and unblocks the RPC thread in the workers so that + # time out, and unblocks the RPC thread in the workers so that # they can process any other queued control plane messages, # such as add/remove lora adapters. await engine.engine.stop_remote_worker_execution_loop_async() @@ -961,10 +954,6 @@ class AsyncLLMEngine(EngineClient): """Get the parallel configuration of the vLLM engine.""" return self.engine.get_parallel_config() - async def get_decoding_config(self) -> DecodingConfig: - """Get the decoding configuration of the vLLM engine.""" - return self.engine.get_decoding_config() - async def get_scheduler_config(self) -> SchedulerConfig: """Get the scheduling configuration of the vLLM engine.""" return self.engine.get_scheduler_config() diff --git a/vllm/engine/async_timeout.py b/vllm/engine/async_timeout.py index 28a023a71ef52..3b9c055160c1b 100644 --- a/vllm/engine/async_timeout.py +++ b/vllm/engine/async_timeout.py @@ -16,19 +16,6 @@ if sys.version_info[:2] >= (3, 11): from asyncio import timeout as asyncio_timeout else: - def asyncio_timeout(delay: Optional[float]) -> "Timeout": - """timeout context manager. - Useful in cases when you want to apply timeout logic around block - of code or in cases when asyncio.wait_for is not suitable. For example: - >>> async with timeout(0.001): - ... async with aiohttp.get('https://github.com') as r: - ... await r.text() - delay - value in seconds or None to disable timeout logic - """ - loop = asyncio.get_running_loop() - deadline = loop.time() + delay if delay is not None else None - return Timeout(deadline, loop) - class _State(enum.Enum): INIT = "INIT" ENTER = "ENTER" @@ -171,3 +158,16 @@ else: self._state = _State.TIMEOUT # drop the reference early self._timeout_handler = None + + def asyncio_timeout(delay: Optional[float]) -> Timeout: + """timeout context manager. + Useful in cases when you want to apply timeout logic around block + of code or in cases when asyncio.wait_for is not suitable. For example: + >>> async with timeout(0.001): + ... async with aiohttp.get('https://github.com') as r: + ... 
await r.text() + delay - value in seconds or None to disable timeout logic + """ + loop = asyncio.get_running_loop() + deadline = loop.time() + delay if delay is not None else None + return Timeout(deadline, loop) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 10ded6f16d41c..708f3bbeeff15 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -16,9 +16,8 @@ import torch from typing_extensions import TypeVar import vllm.envs as envs -from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, - ObservabilityConfig, ParallelConfig, SchedulerConfig, - VllmConfig) +from vllm.config import (LoRAConfig, ModelConfig, ObservabilityConfig, + ParallelConfig, SchedulerConfig, VllmConfig) from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics_types import StatLoggerBase, Stats @@ -40,6 +39,7 @@ from vllm.multimodal.cache import processor_only_cache_from_config from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.outputs import (PoolingRequestOutput, RequestOutput, RequestOutputFactory) +from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup, Sequence, SequenceGroup, SequenceGroupBase, @@ -48,9 +48,8 @@ from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup, from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.transformers_utils.tokenizer_group import ( - TokenizerGroup, init_tokenizer_from_configs) +from vllm.transformers_utils.tokenizer import (AnyTokenizer, + init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) from vllm.utils import Counter, Device, resolve_obj_by_qualname, weak_bind @@ -185,7 +184,7 @@ class LLMEngine: return outputs_ - tokenizer: Optional[TokenizerGroup] + tokenizer: Optional[AnyTokenizer] def __init__( self, @@ -213,8 +212,7 @@ class LLMEngine: self.device_config = vllm_config.device_config self.speculative_config = vllm_config.speculative_config # noqa self.load_config = vllm_config.load_config - self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa - ) + self.structured_outputs_config = vllm_config.structured_outputs_config self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa ) @@ -232,18 +230,9 @@ class LLMEngine: if self.model_config.skip_tokenizer_init: self.tokenizer = None self.detokenizer = None - tokenizer_group = None else: self.tokenizer = self._init_tokenizer() self.detokenizer = Detokenizer(self.tokenizer) - tokenizer_group = self.get_tokenizer_group() - - # Ensure that the function doesn't contain a reference to self, - # to avoid engine GC issues - def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: - assert tokenizer_group, ("tokenizer_group cannot be None, " - "make sure skip_tokenizer_init is False") - return tokenizer_group.get_lora_tokenizer(sequence.lora_request) self.seq_counter = Counter() self.generation_config_fields = ( @@ -278,7 +267,8 @@ class LLMEngine: self.cache_config.block_size, "gpu_memory_utilization": self.cache_config.gpu_memory_utilization, - + "kv_cache_memory_bytes": + self.cache_config.kv_cache_memory_bytes, # 
Quantization "quantization": self.model_config.quantization, @@ -371,6 +361,13 @@ class LLMEngine: "vllm.llm_engine", self.observability_config.otlp_traces_endpoint) + # Initialize reasoning parser if reasoning backend is set. + if self.structured_outputs_config.reasoning_parser and self.tokenizer: + reasoner_class = ReasoningParserManager.get_reasoning_parser( + self.structured_outputs_config.reasoning_parser) + self.reasoner: ReasoningParser = reasoner_class( + self.tokenizer.get_lora_tokenizer()) + # Create sequence output processor, e.g. for beam search or # speculative decoding. self.output_processor = ( @@ -379,9 +376,12 @@ class LLMEngine: self.detokenizer, self.scheduler, self.seq_counter, - get_tokenizer_for_seq, - stop_checker=StopChecker(self.scheduler_config.max_model_len, - get_tokenizer_for_seq), + stop_checker=StopChecker( + self.scheduler_config.max_model_len, + self.reasoner + if self.structured_outputs_config.reasoning_parser + and self.tokenizer else None, + ), )) self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {} @@ -507,24 +507,15 @@ class LLMEngine: if model_executor := getattr(self, "model_executor", None): model_executor.shutdown() - def get_tokenizer_group(self) -> TokenizerGroup: + def get_tokenizer(self) -> AnyTokenizer: if self.tokenizer is None: raise ValueError("Unable to get tokenizer because " "skip_tokenizer_init is True") return self.tokenizer - def get_tokenizer( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - return self.get_tokenizer_group().get_lora_tokenizer(lora_request) - - def _init_tokenizer(self) -> TokenizerGroup: - return init_tokenizer_from_configs( - model_config=self.model_config, - scheduler_config=self.scheduler_config, - lora_config=self.lora_config) + def _init_tokenizer(self) -> AnyTokenizer: + return init_tokenizer_from_configs(model_config=self.model_config) def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) @@ -560,11 +551,11 @@ class LLMEngine: ) return None - self._validate_model_inputs(processed_inputs, lora_request) + self._validate_model_inputs(processed_inputs) # Create the sequences. 
block_size = self.cache_config.block_size seq_id = next(self.seq_counter) - eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) + eos_token_id = self.input_preprocessor.get_eos_token_id() encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) @@ -678,15 +669,17 @@ class LLMEngine: arrival_time = time.time() if (isinstance(prompt, dict) - and prompt.get("prompt_embeds", None) is not None - and not prompt.get("prompt_token_ids", None)): - seq_len = prompt["prompt_embeds"].shape[0] - prompt["prompt_token_ids"] = [0] * seq_len + and prompt.get("prompt_embeds", None) is not None): + if not prompt.get("prompt_token_ids", None): + seq_len = prompt["prompt_embeds"].shape[0] + prompt["prompt_token_ids"] = [0] * seq_len + if params.prompt_logprobs is not None: + raise ValueError( + "prompt_logprobs is not compatible with prompt embeds.") processed_inputs = self.input_preprocessor.preprocess( prompt, tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, ) self._add_processed_request( @@ -777,10 +770,6 @@ class LLMEngine: """Gets the parallel configuration.""" return self.parallel_config - def get_decoding_config(self) -> DecodingConfig: - """Gets the decoding configuration.""" - return self.decoding_config - def get_scheduler_config(self) -> SchedulerConfig: """Gets the scheduler configuration.""" return self.scheduler_config @@ -1414,7 +1403,7 @@ class LLMEngine: num_generation_tokens_iter = 0 num_tokens_iter = 0 time_to_first_tokens_iter: List[float] = [] - time_per_output_tokens_iter: List[float] = [] + inter_token_latencies_iter: List[float] = [] num_preemption_iter = (0 if scheduler_outputs is None else scheduler_outputs.preempted) @@ -1498,9 +1487,9 @@ class LLMEngine: num_generation_tokens_from_prefill_groups += ( seq_group.num_seqs()) else: - # TPOTs. 
+ # ITLs latency = seq_group.get_last_token_latency() - time_per_output_tokens_iter.append(latency) + inter_token_latencies_iter.append(latency) if seq_group.state.current_step == 0: # For async_output_proc, the do_log_stats() # is called following init_multi_step(), which @@ -1582,7 +1571,7 @@ class LLMEngine: num_generation_tokens_iter=num_generation_tokens_iter, num_tokens_iter=num_tokens_iter, time_to_first_tokens_iter=time_to_first_tokens_iter, - time_per_output_tokens_iter=time_per_output_tokens_iter, + inter_token_latencies_iter=inter_token_latencies_iter, num_preemption_iter=num_preemption_iter, # Request stats @@ -1725,29 +1714,22 @@ class LLMEngine: SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE, metrics.model_execute_time) - def _validate_model_inputs(self, inputs: ProcessorInputs, - lora_request: Optional[LoRARequest]): + def _validate_model_inputs(self, inputs: ProcessorInputs): encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs) if encoder_inputs is not None: - self._validate_model_input(encoder_inputs, - lora_request, - prompt_type="encoder") + self._validate_model_input(encoder_inputs, prompt_type="encoder") - self._validate_model_input(decoder_inputs, - lora_request, - prompt_type="decoder") + self._validate_model_input(decoder_inputs, prompt_type="decoder") def _validate_model_input( self, prompt_inputs: SingletonInputs, - lora_request: Optional[LoRARequest], *, prompt_type: Literal["encoder", "decoder"], ): model_config = self.model_config - tokenizer = (None if self.tokenizer is None else - self.tokenizer.get_lora_tokenizer(lora_request)) + tokenizer = self.tokenizer prompt_ids = prompt_inputs.get("prompt_token_ids", []) if not prompt_ids: @@ -1775,7 +1757,7 @@ class LLMEngine: assert isinstance(mm_processor, EncDecMultiModalProcessor) if mm_processor.pad_dummy_encoder_prompt: - return # Skip encoder length check for Whisper and Donut + return # Skip encoder length check for Whisper if model_config.is_multimodal_model: suggestion = ( @@ -1808,7 +1790,7 @@ class LLMEngine: logits_processors = [] if (sampling_params.logit_bias or sampling_params.allowed_token_ids): - tokenizer = self.get_tokenizer(lora_request=lora_request) + tokenizer = self.get_tokenizer() processors = get_openai_logits_processors( logit_bias=sampling_params.logit_bias, @@ -1821,7 +1803,7 @@ class LLMEngine: sampling_params.allowed_token_ids = None if len(sampling_params.bad_words) > 0: - tokenizer = self.get_tokenizer(lora_request) + tokenizer = self.get_tokenizer() processors = get_bad_words_logits_processors( bad_words=sampling_params.bad_words, tokenizer=tokenizer) logits_processors.extend(processors) diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index ba8dbd1fad791..2762175c430fb 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -113,9 +113,21 @@ class Metrics: 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0, 640.0, 2560.0 ]) + # Deprecated in 0.11 - Renamed as vllm:inter_token_latency_seconds + # TODO: in 0.12, only enable if show_hidden_metrics=True self.histogram_time_per_output_token = self._histogram_cls( name="vllm:time_per_output_token_seconds", - documentation="Histogram of time per output token in seconds.", + documentation=( + "Histogram of time per output token in seconds." 
+ "DEPRECATED: Use vllm:inter_token_latency_seconds instead."), + labelnames=labelnames, + buckets=[ + 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, + 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0 + ]) + self.histogram_inter_token_latency = self._histogram_cls( + name="vllm:inter_token_latency_seconds", + documentation="Histogram of inter token latency in seconds.", labelnames=labelnames, buckets=[ 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, @@ -367,7 +379,7 @@ class LoggingStatLogger(StatLoggerBase): if local_interval_elapsed(stats.now, self.last_local_log, self.local_interval): # Compute summary metrics for tracked stats (and log them - # to promethus if applicable). + # to prometheus if applicable). prompt_throughput = get_throughput(self.num_prompt_tokens, now=stats.now, last_log=self.last_local_log) @@ -420,7 +432,7 @@ class LoggingStatLogger(StatLoggerBase): class PrometheusStatLogger(StatLoggerBase): - """PrometheusStatLogger is used LLMEngine to log to Promethus.""" + """PrometheusStatLogger is used LLMEngine to log to Prometheus.""" _metrics_cls = Metrics _gauge_cls = prometheus_client.Gauge @@ -491,7 +503,9 @@ class PrometheusStatLogger(StatLoggerBase): self._log_histogram(self.metrics.histogram_time_to_first_token, stats.time_to_first_tokens_iter) self._log_histogram(self.metrics.histogram_time_per_output_token, - stats.time_per_output_tokens_iter) + stats.inter_token_latencies_iter) + self._log_histogram(self.metrics.histogram_inter_token_latency, + stats.inter_token_latencies_iter) # Request level data # Latency diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py index 3281a9121a9df..9778ab5a8c99b 100644 --- a/vllm/engine/metrics_types.py +++ b/vllm/engine/metrics_types.py @@ -43,7 +43,7 @@ class Stats: num_generation_tokens_iter: int num_tokens_iter: int time_to_first_tokens_iter: List[float] - time_per_output_tokens_iter: List[float] + inter_token_latencies_iter: List[float] num_preemption_iter: int # Request stats (should have _requests suffix) diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py deleted file mode 100644 index 9f64ee0808df2..0000000000000 --- a/vllm/engine/multiprocessing/__init__.py +++ /dev/null @@ -1,145 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import uuid -from dataclasses import dataclass, field -from enum import Enum -from typing import List, Mapping, Optional, Union - -from vllm import PoolingParams -from vllm.inputs import PromptType -from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput -from vllm.sampling_params import SamplingParams -from vllm.utils import Device - -VLLM_RPC_SUCCESS_STR = "SUCCESS" - -IPC_INPUT_EXT = "_input_socket" -IPC_OUTPUT_EXT = "_output_socket" -IPC_HEALTH_EXT = "_health_socket" -IPC_DATA_EXT = "_data_socket" - - -class MQEngineDeadError(RuntimeError): - pass - - -@dataclass -class RPCProcessRequest: - prompt: PromptType - params: Union[SamplingParams, PoolingParams] - request_id: str - lora_request: Optional[LoRARequest] = None - trace_headers: Optional[Mapping[str, str]] = None - priority: int = 0 - - def __init__( - self, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - ) -> None: - super().__init__() - - self.prompt = prompt - self.params = params - 
self.request_id = request_id - self.lora_request = lora_request - self.trace_headers = trace_headers - self.priority = priority - - -@dataclass -class RPCError: - request_id: Optional[str] - is_engine_errored: bool - exception: BaseException - - -@dataclass -class RPCAbortRequest: - request_id: str - - -class RPCStartupRequest(Enum): - IS_SERVER_READY = 1 - - -@dataclass -class RPCStartupResponse: - tracing_enabled: bool - - -class RPCUProfileRequest(Enum): - START_PROFILE = 1 - STOP_PROFILE = 2 - - -class RPCResetMultiModalCacheRequest(Enum): - RESET = 1 - - -@dataclass -class RPCResetPrefixCacheRequest: - device: Device - - -class RPCSleepRequest(Enum): - SLEEP_LEVEL_1 = 1 - SLEEP_LEVEL_2 = 2 - - -@dataclass -class RPCWakeUpRequest: - tags: Optional[list[str]] = None - - -@dataclass -class RPCIsSleepingRequest: - # Set the default value of request_id to a new UUID - request_id: str = field(default_factory=lambda: str(uuid.uuid4())) - - -@dataclass -class RPCIsSleepingResponse: - request_id: str - is_sleeping: bool - - -@dataclass -class RPCLoadAdapterRequest: - lora_request: LoRARequest - # Set the default value of request_id to a new UUID - request_id: str = field(default_factory=lambda: str(uuid.uuid4())) - - -@dataclass -class RPCAdapterLoadedResponse: - request_id: str - lora_loaded: bool - - -RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest, - RPCUProfileRequest, RPCLoadAdapterRequest, - RPCResetMultiModalCacheRequest, - RPCResetPrefixCacheRequest, RPCSleepRequest, - RPCWakeUpRequest, RPCIsSleepingRequest] - -REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse, - RPCIsSleepingResponse, RPCError] - - -def ENGINE_DEAD_ERROR( - error: Optional[BaseException] = None) -> MQEngineDeadError: - if error is None: - return MQEngineDeadError( - "Engine loop is not running. Inspect the stacktrace to " - "find the original error") - - return MQEngineDeadError( - "Engine loop is not running. 
Inspect the stacktrace to " - f"find the original error: {repr(error)}.") diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py deleted file mode 100644 index 2d3248859c940..0000000000000 --- a/vllm/engine/multiprocessing/client.py +++ /dev/null @@ -1,643 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import asyncio -import copy -import pickle -from contextlib import contextmanager, suppress -from typing import (Any, AsyncGenerator, Dict, Iterable, Iterator, List, - Mapping, Optional, Union) - -import cloudpickle -import psutil -import zmq -import zmq.asyncio -from zmq import Frame # type: ignore[attr-defined] -from zmq.asyncio import Socket - -from vllm import PoolingParams -from vllm.config import DecodingConfig, ModelConfig, VllmConfig -from vllm.core.scheduler import SchedulerOutputs -# yapf conflicts with isort for this block -# yapf: disable -from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, - IPC_HEALTH_EXT, IPC_INPUT_EXT, - IPC_OUTPUT_EXT, RPC_REQUEST_T, - VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCAdapterLoadedResponse, RPCError, - RPCIsSleepingRequest, - RPCIsSleepingResponse, - RPCLoadAdapterRequest, - RPCProcessRequest, - RPCResetMultiModalCacheRequest, - RPCResetPrefixCacheRequest, - RPCSleepRequest, RPCStartupRequest, - RPCStartupResponse, - RPCUProfileRequest, RPCWakeUpRequest) -from vllm.engine.protocol import EngineClient -# yapf: enable -from vllm.envs import VLLM_RPC_TIMEOUT -from vllm.inputs import PromptType -from vllm.inputs.preprocess import InputPreprocessor -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.outputs import PoolingRequestOutput, RequestOutput -from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs -from vllm.utils import Device - -logger = init_logger(__name__) - - -class MQClientClosedError(Exception): - """Exception class raised when the client is used post-close. - - The client can be closed, which closes the ZMQ context. This normally - happens on server shutdown. In some cases, methods like abort and - do_log_stats will still be called and then try to open a socket, which - causes a ZMQError and creates a huge stack trace. - So, we throw this error such that we can suppress it. - """ - - -class MQLLMEngineClient(EngineClient): - """A client wrapper for MQLLMEngine that conforms to the - EngineClient protocol. - - MQLLMEngine and MQLLMEngineClient are intended to run in separate - processes communicating via zeromq ipc sockets. - - The entrypoint to MQLLMEngineClient is through the generate() - method. On generate() MQLLMEngine does three things: - - Creates an asyncio output queue - - Sends a RPCGenerateRequest to the MQLLMEngine via zmq - - Pulls RequestOutputs from its queue and yields them - - MQLLMEngine runs two background loops: - - output_loop: the output loop pulls List[RequestOutput] - from the MQLLMEngine via zmq (each list is the output - of one engine_step in the LLMEngine). It then parses - the list and pushes individual request_outputs into - the corresponding output_queue such that they can be - consumed by the .generate() method. 
- - health_loop: the health loop queries the health socket - every N seconds, confirming the engine is healthy - """ - - def __init__(self, ipc_path: str, engine_config: VllmConfig, - engine_pid: int): - self.context = zmq.asyncio.Context() - self._errored_with: Optional[BaseException] = None - - # Get the configs. - self.vllm_config = engine_config - self.model_config = engine_config.model_config - self.decoding_config = engine_config.decoding_config - - if self.vllm_config.model_config.skip_tokenizer_init: - self.tokenizer = None - - else: - # Create the tokenizer group. - self.tokenizer = init_tokenizer_from_configs( - model_config=self.model_config, - scheduler_config=engine_config.scheduler_config, - lora_config=engine_config.lora_config) - - self.input_preprocessor = InputPreprocessor(self.model_config, - self.tokenizer) - - # Send RPCGenerateRequest to the MQLLMEngine. - self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) - self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}") - - # Receive streams of RequestOutput from the MQLLMEngine. - self.output_socket: Socket = self.context.socket(zmq.constants.PULL) - self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}") - - # IPC path for acking heartbeats. - self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL) - self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") - - # IPC path for the data socket. - self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" - - # Stream for each individual request. - self.output_queues: Dict[str, asyncio.Queue] = {} - - # Loop to handle output of the LLMEngine periodically. - # Started after the MQLLMEngine is ready so that we can - # build the Client in an executor to enable clean shutdown. - self.output_loop: Optional[asyncio.Task] = None - - # Loop to check health of the LLMEngine periodically. - # Started after the MQLLMEngine is ready. - self.health_loop: Optional[asyncio.Task] = None - self._engine_process = psutil.Process(engine_pid) - - @staticmethod - def is_unsupported_config(vllm_config: VllmConfig): - # Pipeline parallel not yet supported - return vllm_config.parallel_config.pipeline_parallel_size > 1 - - @contextmanager - def get_data_socket(self) -> Iterator[Socket]: - socket = self.context.socket(zmq.constants.DEALER) - try: - socket.connect(self.data_ipc_path) - yield socket - finally: - socket.close(linger=0) - - async def run_heartbeat_loop(self, timeout: int): - """Background loop that continually checks to ensure the engine process - is still alive. 
- """ - try: - while True: - # Check if the engine process is running: - if not self._engine_process.is_running() or ( - self._engine_process.status() == psutil.STATUS_ZOMBIE): - # NB: is_running() returns True for zombies - self._set_errored( - RuntimeError( - f"Engine process (pid {self._engine_process.pid}) " - "died.")) - break - - if await self.heartbeat_socket.poll(timeout=timeout): - # Heartbeat received- check the message - await self._check_success( - error_message="Heartbeat failed.", - socket=self.heartbeat_socket) - - logger.debug("Heartbeat successful.") - - except asyncio.CancelledError: - logger.debug("Shutting down MQLLMEngineClient check health loop.") - - except psutil.NoSuchProcess: - self._set_errored( - RuntimeError( - f"Engine process (pid {self._engine_process.pid}) died.")) - - except Exception as e: - self._set_errored(e) - - async def run_output_handler_loop(self): - """Get RequestOutputs from Engine and stream to Request Queues""" - - try: - while True: - # Poll, checking for ENGINE_DEAD - while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT - ) == 0: - logger.debug("Waiting for output from MQLLMEngine.") - - # If errored, alert all running requests. - if self.errored: - for queue_j in tuple(self.output_queues.values()): - queue_j.put_nowait( - ENGINE_DEAD_ERROR(self._errored_with)) - return - - message: Frame = await self.output_socket.recv(copy=False) - request_outputs = pickle.loads(message.buffer) - - is_error = isinstance(request_outputs, - (BaseException, RPCError)) - if is_error: - if isinstance(request_outputs, RPCError): - rpc_error: RPCError = request_outputs - request_id = rpc_error.request_id - exception = rpc_error.exception - is_engine_errored = rpc_error.is_engine_errored - else: - # MPLLMEngine should always return an RPCError to - # the output_socket when an issue arises. - # If we are here, we are in a bad state and - # should shut down the server. - error: BaseException = request_outputs - logger.error( - "Received Exception %s rather than RPCError from " - "MPLLMEngine. This should never happen.", error) - request_id = None - exception = error - is_engine_errored = True - - # Set to error state only on engine critical error - # (and record only the first one) - if is_engine_errored and not self._errored_with: - self._errored_with = exception - # If engine is errored, no matter the type of exception - # it will no longer be able to receive new requests, - # therefore we have to inform that the current - # processed requests failed as well. Send back a dead - # engine error give this feedback and also give a - # 'hint' to the server to shutdown next. - exception = self.dead_error - - if request_id is None: - # If request_id is None, then the engine raised an - # exception for a batch, and we may not know the - # request that caused it, neither if it was actually - # caused by any of them (e.g. CUDA OOM). Therefore we - # broadcast the same exception for all requests. - for queue_i in tuple(self.output_queues.values()): - queue_i.put_nowait(exception) - else: - queue = self.output_queues.get(request_id) - if queue is not None: - queue.put_nowait(exception) - # Put each output into the appropriate queue. 
- elif isinstance( - request_outputs, - (RPCAdapterLoadedResponse, RPCIsSleepingResponse)): - self._add_output(request_outputs) - else: - for request_output in request_outputs: - self._add_output(request_output) - - except asyncio.CancelledError: - logger.debug("Shutting down MQLLMEngineClient output handler.") - - def _add_output(self, request_output: Union[RequestOutput, - RPCAdapterLoadedResponse, - RPCIsSleepingResponse]): - queue = self.output_queues.get(request_output.request_id) - if queue is not None: - queue.put_nowait(request_output) - - async def setup(self): - """Setup the client before it starts sending server requests.""" - - # Start output_loop - if self.output_loop is None: - # only generate once to avoid multiple concurrent output_loops - # this will lead to race conditions and wrong orders of tokens - # returned by the engine - # setup will be called multiple times during the startup of - # the engine - self.output_loop = asyncio.create_task( - self.run_output_handler_loop()) - - with self.get_data_socket() as socket: - # Wait until server is ready. - response = await self._wait_for_server_rpc(socket) - - self.tracing_flag = response.tracing_enabled - - # Start health_loop. - if self.health_loop is None: - self.health_loop = asyncio.create_task( - self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT)) - - def close(self): - """Destroy the ZeroMQ Context.""" - # Close all sockets and terminate the context. - self.context.destroy(linger=0) - - # Cancel background tasks. - if self.health_loop is not None: - self.health_loop.cancel() - if self.output_loop is not None: - self.output_loop.cancel() - - def _set_errored(self, e: BaseException): - logger.exception(repr(e)) - if self._errored_with is None: - self._errored_with = e - - @staticmethod - async def _send_get_data_rpc_request(request: RPCStartupRequest, - expected_type: Any, - error_message: str, - socket: Socket) -> Any: - """Send an RPC request that is expecting data back.""" - - # Ping RPCServer with a request. - await socket.send_multipart((pickle.dumps(request), ), copy=False) - - # Make sure the server responds in time. - if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: - raise TimeoutError("RPCServer didn't reply within " - f"{VLLM_RPC_TIMEOUT} ms") - - # Await the data from the Server. 
- frame = await socket.recv(copy=False) - data = pickle.loads(frame.buffer) - - if isinstance(data, BaseException): - raise data - elif not isinstance(data, expected_type): - raise ValueError(error_message) - - return data - - @staticmethod - async def _send_one_way_rpc_request(request: RPC_REQUEST_T, - socket: Socket): - """Send one-way RPC request to trigger an action.""" - - if socket.closed: - raise MQClientClosedError() - - await socket.send_multipart((pickle.dumps(request), )) - - async def _await_ack(self, error_message: str, socket: Socket): - """Await acknowledgement that a request succeeded.""" - - if socket.closed: - raise MQClientClosedError() - - if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: - raise TimeoutError("MQLLMEngine didn't reply within " - f"{VLLM_RPC_TIMEOUT}ms") - - await self._check_success(error_message, socket) - - @staticmethod - async def _check_success(error_message: str, socket: Socket): - """Confirm that socket has a VLLM_RPC_SUCCESS_STR message""" - - if socket.closed: - raise MQClientClosedError() - - frame = await socket.recv(copy=False) - response = pickle.loads(frame.buffer) - - # Raise error if unsuccessful - if isinstance(response, BaseException): - raise response - elif (not isinstance(response, str) - or response != VLLM_RPC_SUCCESS_STR): - raise ValueError(error_message) - - async def get_input_preprocessor(self) -> InputPreprocessor: - return self.input_preprocessor - - async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None): - if self.tokenizer is None: - return None - else: - return await self.tokenizer.get_lora_tokenizer_async(lora_request) - - async def get_vllm_config(self) -> VllmConfig: - return self.vllm_config - - async def get_decoding_config(self) -> DecodingConfig: - return self.decoding_config - - async def get_model_config(self) -> ModelConfig: - return self.model_config - - async def is_tracing_enabled(self) -> bool: - return self.tracing_flag - - async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse: - """Wait for the RPCServer to start up.""" - - return await self._send_get_data_rpc_request( - request=RPCStartupRequest.IS_SERVER_READY, - expected_type=RPCStartupResponse, - error_message="Unable to start RPC Server", - socket=socket) - - async def abort(self, request_id: Union[str, Iterable[str]]): - """Send an ABORT_REQUEST signal to the RPC Server""" - - if not isinstance(request_id, str): - raise RuntimeError("Only single-request abort supported in" - " deprecated V0") - - with suppress(MQClientClosedError): - await self._send_one_way_rpc_request( - request=RPCAbortRequest(request_id), socket=self.input_socket) - - async def do_log_stats( - self, - scheduler_outputs: Optional[SchedulerOutputs] = None, - model_output: Optional[List[SamplerOutput]] = None, - ) -> None: - """ - Ignore do_log_stats (handled on MQLLMEngine polling) - """ - pass - - async def check_health(self): - """ - The check health loop probes the health status of the - Engine's health every N seconds and sets _errored_with - if the engine is unhealthy. 
- """ - if self._errored_with is not None: - raise self._errored_with - - @property - def is_running(self) -> bool: - return not self.errored - - @property - def is_stopped(self) -> bool: - return self.errored - - @property - def errored(self) -> bool: - return self._errored_with is not None - - @property - def dead_error(self) -> BaseException: - return ENGINE_DEAD_ERROR(self._errored_with) - - def generate( - self, - prompt: PromptType, - sampling_params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - ) -> AsyncGenerator[RequestOutput, None]: - """Generate outputs for a request. - - Generate outputs for a request. This method is a coroutine. It adds the - request into the waiting queue of the LLMEngine and streams the outputs - from the LLMEngine to the caller. - - Args: - prompt: The prompt to the LLM. See - [`PromptType`][vllm.inputs.PromptType] for more details about - the format of each input. - sampling_params: The sampling parameters of the request. - request_id: The unique id of the request. - lora_request: LoRA request to use for generation, if any. - trace_headers: OpenTelemetry trace headers. - priority: Priority of the request (lower means earlier handling). - Any priority other than 0 will lead to an error if the - scheduling policy is not "priority". - """ - return self._process_request(prompt, sampling_params, request_id, - lora_request, trace_headers, priority) - - def encode( - self, - prompt: PromptType, - pooling_params: PoolingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> AsyncGenerator[PoolingRequestOutput, None]: - raise NotImplementedError( - "Pooling models are not supported in vLLM V0") - - async def _process_request( - self, - prompt: PromptType, - params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - ) -> AsyncGenerator[RequestOutput, None]: - """Send an RPCGenerateRequest to the RPCServer and stream responses.""" - - # If already dead, error out. - if self._errored_with is not None: - raise ENGINE_DEAD_ERROR(self._errored_with) - - # Ensure the request id is unique among running requests - if request_id in self.output_queues: - raise ValueError(f"Request {request_id} already exists") - - # 1) Create output queue for this request. - queue: asyncio.Queue[Union[RequestOutput, - BaseException]] = asyncio.Queue() - self.output_queues[request_id] = queue - - try: - # 2) Detach logits processors so that they can be pickled - # separately (may require cloudpickle which is slower) - if params.logits_processors: - # Defensive shallow copy - params = copy.copy(params) - logits_processors = params.logits_processors - params.logits_processors = None - lp_bytes = cloudpickle.dumps(logits_processors) - else: - lp_bytes = None - - request_bytes = pickle.dumps( - RPCProcessRequest( - prompt=prompt, - params=params, - request_id=request_id, - lora_request=lora_request, - trace_headers=trace_headers, - priority=priority, - )) - - # 3) Send the RPCGenerateRequest to the MQLLMEngine. - parts = (request_bytes, - lp_bytes) if lp_bytes else (request_bytes, ) - await self.input_socket.send_multipart(parts, copy=False) - - # 4) Stream the RequestOutputs from the output queue. 
Note - # that the output_loop pushes RequestOutput objects to this - # queue after pulling them from the zmq socket. - finished = False - try: - while not finished: - request_output = await queue.get() - - if isinstance(request_output, BaseException): - raise request_output - - finished = request_output.finished - yield request_output - finally: - # Request was canceled by the client. - if not finished and not self.errored: - await self.abort(request_id) - finally: - self.output_queues.pop(request_id) - - async def start_profile(self) -> None: - """Start profiling the engine""" - - await self._send_one_way_rpc_request( - request=RPCUProfileRequest.START_PROFILE, socket=self.input_socket) - - async def stop_profile(self) -> None: - """Stop profiling the engine""" - - await self._send_one_way_rpc_request( - request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket) - - async def reset_mm_cache(self) -> None: - """Reset the multi-modal cache""" - - await self._send_one_way_rpc_request( - request=RPCResetMultiModalCacheRequest.RESET, - socket=self.input_socket) - - async def reset_prefix_cache(self, - device: Optional[Device] = None) -> None: - """Reset the prefix cache""" - - await self._send_one_way_rpc_request( - request=RPCResetPrefixCacheRequest(device), - socket=self.input_socket) - - async def sleep(self, level: int = 1) -> None: - """Sleep the engine for a given level""" - return await self._send_one_way_rpc_request( - request=RPCSleepRequest(level), socket=self.input_socket) - - async def wake_up(self, tags: Optional[list[str]] = None) -> None: - """Wake up the engine""" - return await self._send_one_way_rpc_request( - request=RPCWakeUpRequest(tags), socket=self.input_socket) - - async def is_sleeping(self) -> bool: - """Check whether the engine is sleeping""" - request = RPCIsSleepingRequest() - - queue: asyncio.Queue[Union[BaseException, - RPCIsSleepingResponse]] = asyncio.Queue() - self.output_queues[request.request_id] = queue - - request_bytes = pickle.dumps(request) - await self.input_socket.send_multipart((request_bytes, ), copy=False) - - request_output = await queue.get() - self.output_queues.pop(request.request_id) - - if isinstance(request_output, BaseException): - raise request_output - return request_output.is_sleeping - - async def add_lora(self, lora_request: LoRARequest) -> bool: - """Load a new LoRA adapter into the engine for future requests.""" - # Uses the same I/O as generate requests - request = RPCLoadAdapterRequest(lora_request) - - # Create output queue for this request. 
- queue: asyncio.Queue[Union[ - BaseException, RPCAdapterLoadedResponse]] = asyncio.Queue() - self.output_queues[request.request_id] = queue - - # Send the request - request_bytes = pickle.dumps(request) - await self.input_socket.send_multipart((request_bytes, ), copy=False) - - # Wait for the response - request_output = await queue.get() - self.output_queues.pop(request.request_id) - - # Raise on error, otherwise happily return None - if isinstance(request_output, BaseException): - raise request_output - return request_output.lora_loaded diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py deleted file mode 100644 index 343b8df7e87bd..0000000000000 --- a/vllm/engine/multiprocessing/engine.py +++ /dev/null @@ -1,470 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pickle -import signal -from contextlib import contextmanager -from typing import Iterator, List, Optional, Union - -import cloudpickle -import zmq - -from vllm import AsyncEngineArgs, SamplingParams -from vllm.config import VllmConfig -from vllm.engine.llm_engine import LLMEngine -# yapf conflicts with isort for this block -# yapf: disable -from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, - IPC_HEALTH_EXT, IPC_INPUT_EXT, - IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, - VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCAdapterLoadedResponse, RPCError, - RPCIsSleepingRequest, - RPCIsSleepingResponse, - RPCLoadAdapterRequest, - RPCProcessRequest, - RPCResetMultiModalCacheRequest, - RPCResetPrefixCacheRequest, - RPCSleepRequest, RPCStartupRequest, - RPCStartupResponse, - RPCUProfileRequest, RPCWakeUpRequest) -# yapf: enable -from vllm.logger import init_logger -from vllm.outputs import RequestOutput -from vllm.transformers_utils.config import ( - maybe_register_config_serialize_by_value) -from vllm.usage.usage_lib import UsageContext -from vllm.utils import deprecate_kwargs -from vllm.worker.model_runner_base import InputProcessingError - -logger = init_logger(__name__) - -POLLING_TIMEOUT_MS = 10000 -HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), ) - - -class MQLLMEngine: - """A multiprocessing wrapper for - [`LLMEngine`][vllm.engine.llm_engine.LLMEngine]. - - This class is used to wrap the - [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] class to enable use - in concurrnet manner. It runs a background loop and uses zeromq to - receive new requests and stream outputs incrementally via ipc. - - The [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] generate or encode - process is kicked off when a new RPCProcessRequest is received by the - input_socket. - - The self.engine_loop checks the input_socket for new requests, - adds them to the LLMEngine if there are any, calls the internal - [`LLMEngine.step()`][vllm.engine.llm_engine.LLMEngine.step], and sends - the RequestOutputs back over the output_socket. - - If use_async_sockets is set, the logic associated with reading new - requests from the socket and sending data to the socket is passed - as a callback to the llm_engine, which calls the logic asynchronously - such that the IPC can be overlapped with the GPU. - - Args: - ipc_path: Base path for zeromq interprocess messaging - use_async_sockets: Whether to make send/recv async with GPU - log_requests: Whether to log the requests. - *args: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine]. - **kwargs: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine]. 
- """ - - def __init__(self, - ipc_path: str, - use_async_sockets: bool, - *args, - log_requests: bool = True, - **kwargs) -> None: - # For MQLLMEngine, we can use cached outputs, since each new request - # output is immediately pickled and send over the socket, which frees - # the python object to be reused again. - kwargs['use_cached_outputs'] = True - - self.engine = LLMEngine(*args, **kwargs) - self.log_requests = log_requests - - self.use_async_sockets = use_async_sockets - if self.use_async_sockets: - self.engine.process_request_outputs_callback = \ - self._async_socket_engine_callback - - self.ctx = zmq.Context() # type: ignore[attr-defined] - - # Receive input from the client. - self.input_socket = self.ctx.socket(zmq.constants.PULL) - self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}") - - # Send output stream back to client. - self.output_socket = self.ctx.socket(zmq.constants.PUSH) - self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}") - - # Send heartbeats back to client. - self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH) - self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") - - # IPC path for the data socket. - self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" - - # Error state. - self._errored_with: Optional[BaseException] = None - - @property - def dead_error(self) -> BaseException: - if self._errored_with is not None: - return ENGINE_DEAD_ERROR(self._errored_with) - else: - return ENGINE_DEAD_ERROR() - - @classmethod - @deprecate_kwargs( - "disable_log_requests", - additional_message=("This argument will have no effect. " - "Use `enable_log_requests` instead."), - ) - def from_vllm_config( - cls, - vllm_config: VllmConfig, - usage_context: UsageContext, - enable_log_requests: bool, - disable_log_stats: bool, - ipc_path: str, - disable_log_requests: bool = True, # Deprecated, will be removed - ) -> "MQLLMEngine": - # Setup plugins for each process - from vllm.plugins import load_general_plugins - load_general_plugins() - - use_async_sockets = vllm_config.model_config.use_async_output_proc - - return cls( - vllm_config=vllm_config, - executor_class=LLMEngine._get_executor_cls(vllm_config), - ipc_path=ipc_path, - usage_context=usage_context, - use_async_sockets=use_async_sockets, - log_requests=enable_log_requests, - log_stats=(not disable_log_stats), - ) - - @staticmethod - def from_engine_args(engine_args: AsyncEngineArgs, - usage_context: UsageContext, ipc_path: str): - """Creates an MQLLMEngine from the engine arguments.""" - - vllm_config = engine_args.create_engine_config(usage_context) - return MQLLMEngine.from_vllm_config( - ipc_path=ipc_path, - vllm_config=vllm_config, - usage_context=usage_context, - enable_log_requests=engine_args.enable_log_requests, - disable_log_stats=engine_args.disable_log_stats, - ) - - def start(self): - try: - try: - logger.debug("Starting Startup Loop.") - self.run_startup_loop() - logger.debug("Starting Engine Loop.") - self.run_engine_loop() - except Exception as e: - logger.exception(repr(e)) - except KeyboardInterrupt: - logger.debug("Shutting down MQLLMEngine.") - finally: - logger.debug("MQLLMEngine is shut down.") - self.cleanup() - - def cleanup(self): - """Cleanup zeromq state on shutdown.""" - # Closes all sockets and destroys context. 
- self.ctx.destroy(linger=0) - del self.engine - - @contextmanager - def make_data_socket( - self) -> Iterator[zmq.Socket]: # type: ignore[name-defined] - socket = self.ctx.socket(zmq.constants.ROUTER) - try: - socket.bind(self.data_ipc_path) - yield socket - finally: - socket.close(linger=0) - - def run_startup_loop(self) -> None: - """Startup loop for sending data from Engine -> Client.""" - - with self.make_data_socket() as socket: - response: Union[RPCStartupResponse, BaseException] - try: - identity, message = socket.recv_multipart(copy=False) - request: RPCStartupRequest = pickle.loads(message.buffer) - - # Handle the query from the Client. - if request == RPCStartupRequest.IS_SERVER_READY: - tracing_enabled = self.engine.is_tracing_enabled() - response = RPCStartupResponse( - tracing_enabled=tracing_enabled) - - except Exception as e: - response = e - - socket.send_multipart((identity, pickle.dumps(response)), - copy=False) - - def run_engine_loop(self): - """Core busy loop of the LLMEngine.""" - - while True: - if not self.engine.has_unfinished_requests(): - # Poll until there is work to do. - while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: - # When there's no work, check on engine health and send - # health status back to client - self._health_check() - self.engine.do_log_stats() - logger.debug("Waiting for new requests in engine loop.") - - # Handle any input from the client. - self.handle_new_input() - - # Engine step. - request_outputs = self.engine_step() - - # Send request outputs (if async, done in engine_step callback). - if not self.use_async_sockets: - self._send_outputs(request_outputs) - - def engine_step(self) -> List[RequestOutput]: - """Engine step wrapper with error handling.""" - try: - return self.engine.step() - except SystemExit: - raise - except InputProcessingError as e: - # Special case where we handle an error preparing the inputs for - # a single request in the batch - rpc_err = RPCError(request_id=e.request_id, - is_engine_errored=False, - exception=e.__cause__) - self._send_outputs(rpc_err) - return [] - except BaseException as e: - self._set_errored(e) - rpc_err = RPCError(request_id=None, - is_engine_errored=True, - exception=e) - self._send_outputs(rpc_err) - raise e - - def handle_new_input(self): - """Handle new input from the socket""" - try: - while self.input_socket.poll(timeout=0) != 0: - frames = self.input_socket.recv_multipart(copy=False) - request = pickle.loads(frames[0].buffer) - - if isinstance(request, RPCProcessRequest): - if len(frames) > 1: - # Use cloudpickle for logits processors - assert isinstance(request.params, SamplingParams) - lprocs = cloudpickle.loads(frames[1].buffer) - request.params.logits_processors = lprocs - self._handle_process_request(request) - elif isinstance(request, RPCAbortRequest): - self._handle_abort_request(request) - elif isinstance(request, RPCUProfileRequest): - if request == RPCUProfileRequest.START_PROFILE: - self.start_profile() - else: - self.stop_profile() - elif isinstance(request, RPCLoadAdapterRequest): - self._handle_load_adapter_request(request) - elif isinstance(request, RPCResetMultiModalCacheRequest): - self.reset_mm_cache() - elif isinstance(request, RPCResetPrefixCacheRequest): - self.reset_prefix_cache() - elif isinstance(request, RPCSleepRequest): - self.sleep(request.value) - elif isinstance(request, RPCWakeUpRequest): - self.wake_up(request.tags) - elif isinstance(request, RPCIsSleepingRequest): - self._handle_is_sleeping_request(request) - else: - raise 
ValueError("Unknown RPCRequest Type: " - f"{type(request)}") - - except Exception as e: - self._set_errored(e) - self._send_unhealthy(e) - raise e from None - - def _handle_process_request(self, request: RPCProcessRequest): - """Handle RPCProcessRequest by adding it to the LLMEngine.""" - request_id = request.request_id - - if self._errored_with is not None: - rpc_err = RPCError(request_id=request_id, - is_engine_errored=True, - exception=ENGINE_DEAD_ERROR(self._errored_with)) - self._send_outputs(rpc_err) - - try: - self.engine.add_request(request_id=request_id, - prompt=request.prompt, - params=request.params, - lora_request=request.lora_request, - trace_headers=request.trace_headers, - priority=request.priority) - - if self.log_requests: - logger.info("Added request %s.", request.request_id) - - except Exception as e: - # We do not set self._errored = True here, since the error - # is due to an issue adding this request to the engine, - # rather than an issue with the engine itself. - logger.debug("Failed to add request %s to engine. %s", - request.request_id, e) - is_errored = self._errored_with is not None - rpc_err = RPCError(request_id=request_id, - is_engine_errored=is_errored, - exception=e) - self._send_outputs(rpc_err) - - # Remove request from the engine. - self.engine.abort_request(request_id) - - def _handle_abort_request(self, request: RPCAbortRequest): - self.engine.abort_request(request.request_id) - if self.log_requests: - logger.info("Aborted request %s.", request.request_id) - - def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest): - try: - lora_loaded = self.engine.add_lora(request.lora_request) - except BaseException as e: - # Send back an error if the adater fails to load - rpc_err = RPCError(request_id=request.request_id, - is_engine_errored=False, - exception=e) - self._send_outputs(rpc_err) - return - # Otherwise, send back the successful load message - self._send_outputs( - RPCAdapterLoadedResponse(request_id=request.request_id, - lora_loaded=lora_loaded)) - - def _handle_is_sleeping_request(self, request: RPCIsSleepingRequest): - is_sleeping = self.is_sleeping() - self._send_outputs( - RPCIsSleepingResponse(request_id=request.request_id, - is_sleeping=is_sleeping)) - - def _health_check(self): - # Send unhealthy if engine has already errored - if self._errored_with is not None: - self._send_unhealthy(self._errored_with) - try: - self.engine.check_health() - self._send_healthy() - except Exception as e: - self._set_errored(e) - self._send_unhealthy(e) - - def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): - """Send outputs back to the engine client. These can be: - - Exceptions - - A list of generation outputs - - A response from loading a lora adapter - """ - if outputs: - try: - from ray.exceptions import RayTaskError - - # RayTaskError might not pickelable here. We need to unpack the - # underlying exception as the real exception in the output. 
- if (isinstance(outputs, RPCError) - and isinstance(outputs.exception, RayTaskError)): - outputs.exception = outputs.exception.cause - except ImportError: - pass - - output_bytes = pickle.dumps(outputs) - self.output_socket.send_multipart((output_bytes, ), copy=False) - - def _send_healthy(self): - """Send HEALTHY message to RPCClient.""" - if not self.heartbeat_socket.closed: - self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False) - - def _send_unhealthy(self, error: BaseException): - """Send UNHEALTHY message to RPCClient.""" - if not self.heartbeat_socket.closed: - error_bytes = pickle.dumps(error) - self.heartbeat_socket.send_multipart((error_bytes, ), copy=False) - - def _async_socket_engine_callback(self, - request_outputs: REQUEST_OUTPUTS_T): - """Callback used by engine to make socket handling async with GPU.""" - self._send_outputs(request_outputs) - self.handle_new_input() - - def _set_errored(self, e: BaseException): - """Log and set errored status if this is the first issue.""" - if self._errored_with is None: - self._errored_with = e - - def start_profile(self) -> None: - self.engine.start_profile() - - def stop_profile(self) -> None: - self.engine.stop_profile() - - def reset_mm_cache(self) -> bool: - return self.engine.reset_mm_cache() - - def reset_prefix_cache(self) -> bool: - return self.engine.reset_prefix_cache() - - def sleep(self, level: int = 1) -> None: - self.engine.sleep(level) - - def wake_up(self, tags: Optional[list[str]] = None) -> None: - self.engine.wake_up(tags) - - def is_sleeping(self) -> bool: - return self.engine.is_sleeping() - - -def signal_handler(*_) -> None: - raise KeyboardInterrupt("MQLLMEngine terminated") - - -def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext, - ipc_path: str, disable_log_stats: bool, - enable_log_requests: bool, engine_alive): - try: - # Ensure we can serialize transformer config before spawning - maybe_register_config_serialize_by_value() - - engine = MQLLMEngine.from_vllm_config( - vllm_config=vllm_config, - usage_context=usage_context, - disable_log_stats=disable_log_stats, - enable_log_requests=enable_log_requests, - ipc_path=ipc_path) - - signal.signal(signal.SIGTERM, signal_handler) - - engine.start() - - except BaseException as e: - logger.exception(e) - engine_alive.value = False - raise e from None diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py index 4d75719c1719b..587a9221e32c8 100644 --- a/vllm/engine/output_processor/interfaces.py +++ b/vllm/engine/output_processor/interfaces.py @@ -2,14 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from typing import Callable, List +from typing import List from vllm.config import SchedulerConfig from vllm.core.scheduler import Scheduler from vllm.engine.output_processor.stop_checker import StopChecker -from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput +from vllm.sequence import SequenceGroup, SequenceGroupOutput from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import Counter @@ -31,7 +30,6 @@ class SequenceGroupOutputProcessor(ABC): detokenizer: Detokenizer, scheduler: List[Scheduler], seq_counter: Counter, - get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer], stop_checker: "StopChecker", ): """Create an output processor. 
diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index 3fb2f71b5e999..0916f1c918c85 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -1,12 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, List, Optional, Tuple +from typing import List, Optional, Tuple from vllm.lora.request import LoRARequest +from vllm.reasoning import ReasoningParser from vllm.sampling_params import SamplingParams from vllm.sequence import Sequence, SequenceStatus -from vllm.transformers_utils.tokenizer import AnyTokenizer class StopChecker: @@ -16,11 +16,14 @@ class StopChecker: emitted, or if we have exceeded the max model len. """ - def __init__(self, max_model_len: int, - get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]): + def __init__( + self, + max_model_len: int, + reasoner: Optional[ReasoningParser] = None, + ): # Do not use it directly, but use `self._get_max_model_len`. self._max_model_len = max_model_len - self.get_tokenizer_for_seq = get_tokenizer_for_seq + self.reasoner = reasoner def _get_max_model_len(self, lora_req: Optional[LoRARequest]): if lora_req and lora_req.long_lora_max_len: @@ -57,6 +60,11 @@ class StopChecker: seq.status = SequenceStatus.FINISHED_STOPPED return + # Skip stop string/token checks if in reasoning content generation + if self.reasoner is not None and \ + not self.reasoner.is_reasoning_end(seq.get_token_ids()): + return + # Check if a stop token was encountered. # This assumes a single token produced per step. last_token_id = seq.get_last_token_id() diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py deleted file mode 100644 index 1e127eb982425..0000000000000 --- a/vllm/engine/output_processor/util.py +++ /dev/null @@ -1,28 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import List -from typing import Sequence as GenericSequence -from typing import cast - -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import CompletionSequenceGroupOutput, SequenceGroupOutput - - -def create_output_by_sequence_group( - outputs: GenericSequence[SamplerOutput], - num_seq_groups: int) -> List[List[SequenceGroupOutput]]: - """Helper method which transforms a 2d list organized by - [step][sequence group] into [sequence group][step]. - """ - output_by_sequence_group: List[List[CompletionSequenceGroupOutput]] = [ - [] for _ in range(num_seq_groups) - ] - for step in outputs: - sequence_group_output: CompletionSequenceGroupOutput - for i, sequence_group_output in enumerate(step): - output_by_sequence_group[i].append(sequence_group_output) - - # Cast to the more generic type that CompletionSequenceGroupOutput - # inherits from. 
- return cast(List[List[SequenceGroupOutput]], output_by_sequence_group) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index b0b11a33a4443..c345f17e6614f 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from typing import Any, AsyncGenerator, Iterable, Mapping, Optional, Union from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function -from vllm.config import DecodingConfig, ModelConfig, VllmConfig +from vllm.config import ModelConfig, VllmConfig from vllm.core.scheduler import SchedulerOutputs from vllm.inputs.data import PromptType, TokensPrompt from vllm.inputs.parse import is_explicit_encoder_decoder_prompt @@ -76,8 +76,8 @@ class EngineClient(ABC): include_stop_str_in_output = params.include_stop_str_in_output preprocessor = await self.get_input_preprocessor() - tokenizer_group = preprocessor.get_tokenizer_group() - tokenizer = await tokenizer_group.get_lora_tokenizer_async() + tokenizer = preprocessor.get_tokenizer() + eos_token_id = tokenizer.eos_token_id if is_explicit_encoder_decoder_prompt(prompt): raise NotImplementedError @@ -104,7 +104,7 @@ class EngineClient(ABC): tokenized_length = len(prompt_token_ids) sort_beams_key = create_sort_beams_key_function( - tokenizer.eos_token_id, length_penalty) + eos_token_id, length_penalty) beam_search_params = SamplingParams( logprobs=2 * beam_width, @@ -154,7 +154,7 @@ class EngineClient(ABC): if result.outputs[0].logprobs is not None: logprobs = result.outputs[0].logprobs[0] for token_id, logprob_obj in logprobs.items(): - if token_id == tokenizer.eos_token_id and \ + if token_id == eos_token_id and \ not ignore_eos: completed.append( BeamSearchSequence( @@ -166,7 +166,7 @@ class EngineClient(ABC): cum_logprob=current_beam.cum_logprob + logprob_obj.logprob, finish_reason="stop", - stop_reason=tokenizer.eos_token_id)) + stop_reason=eos_token_id)) else: new_beams.append( BeamSearchSequence( @@ -189,14 +189,14 @@ class EngineClient(ABC): best_beams = sorted_completed[:beam_width] for beam in best_beams: - if (beam.tokens[-1] == tokenizer.eos_token_id and not ignore_eos): + if (beam.tokens[-1] == eos_token_id and not ignore_eos): # Skip the eos token in the text. tokens = beam.tokens[tokenized_length:-1] else: tokens = beam.tokens[tokenized_length:] beam.text = tokenizer.decode(tokens) - beam_search_output = RequestOutput( + yield RequestOutput( request_id=request_id, prompt=prompt_text, outputs=[ @@ -214,8 +214,6 @@ class EngineClient(ABC): prompt_token_ids=prompt_token_ids, prompt_logprobs=None) - yield beam_search_output - @abstractmethod def encode( self, @@ -250,22 +248,14 @@ class EngineClient(ABC): """Get the model configuration of the vLLM engine.""" ... - @abstractmethod - async def get_decoding_config(self) -> DecodingConfig: - """Get the decoding configuration of the vLLM engine.""" - ... - @abstractmethod async def get_input_preprocessor(self) -> InputPreprocessor: """Get the input processor of the vLLM engine.""" ... @abstractmethod - async def get_tokenizer( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - """Get the appropriate tokenizer for the request""" + async def get_tokenizer(self) -> AnyTokenizer: + """Get the tokenizer""" ... 
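A minimal sketch of the reasoning-aware stop check that stop_checker.py gains above, assuming a stub in place of vllm.reasoning.ReasoningParser; the end-of-think token id and the helper name below are illustrative only.

from typing import Optional, Sequence


class StubReasoner:
    """Assumed stand-in for ReasoningParser: reasoning is considered finished
    once a hypothetical end-of-think token appears in the sequence."""

    END_THINK_TOKEN_ID = 42  # illustrative id, not a real vocabulary entry

    def is_reasoning_end(self, token_ids: Sequence[int]) -> bool:
        return self.END_THINK_TOKEN_ID in token_ids


def stop_checks_apply(token_ids: Sequence[int],
                      reasoner: Optional[StubReasoner]) -> bool:
    # Mirrors the early return added to StopChecker: while the sequence is
    # still emitting reasoning content, stop strings and stop tokens are not
    # checked at all.
    return reasoner is None or reasoner.is_reasoning_end(token_ids)


assert stop_checks_apply([1, 2, 3], StubReasoner()) is False   # still reasoning
assert stop_checks_apply([1, 42, 3], StubReasoner()) is True   # reasoning ended
assert stop_checks_apply([1, 2, 3], None) is True              # no reasoner configured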
async def get_io_processor(self) -> IOProcessor: diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 1954cbcbf1edd..00ef39f134653 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -41,7 +41,8 @@ from typing_extensions import Required, TypeAlias, TypedDict from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.model_executor.models import SupportsMultiModal -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, + MultiModalUUIDDict) from vllm.multimodal.utils import MediaConnector # yapf: disable from vllm.transformers_utils.chat_templates import ( @@ -75,7 +76,7 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False): class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False): - image_embeds: Required[Union[str, dict[str, str]]] + image_embeds: Optional[Union[str, dict[str, str]]] """ The image embeddings. It can be either: - A single base64 string. @@ -83,6 +84,11 @@ class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False): """ type: Required[Literal["image_embeds"]] """The type of the content part.""" + uuid: Optional[str] + """ + User-provided UUID of a media. User must guarantee that it is properly + generated and unique for different medias. + """ class VideoURL(TypedDict, total=False): @@ -103,6 +109,7 @@ class PILImage(BaseModel): """ A PIL.Image.Image object. """ + image_pil: Image.Image model_config = ConfigDict(arbitrary_types_allowed=True) @@ -115,7 +122,13 @@ class CustomChatCompletionContentPILImageParam(TypedDict, total=False): "image_pil": ImageAsset('cherry_blossom').pil_image } """ - image_pil: Required[PILImage] + + image_pil: Optional[PILImage] + uuid: Optional[str] + """ + User-provided UUID of a media. User must guarantee that it is properly + generated and unique for different medias. + """ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False): @@ -127,7 +140,13 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False): "image_url": "https://example.com/image.jpg" } """ - image_url: Required[str] + + image_url: Optional[str] + uuid: Optional[str] + """ + User-provided UUID of a media. User must guarantee that it is properly + generated and unique for different medias. + """ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False): @@ -138,7 +157,8 @@ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False): "audio_url": "https://example.com/audio.mp3" } """ - audio_url: Required[str] + + audio_url: Optional[str] class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False): @@ -149,7 +169,13 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False): "video_url": "https://example.com/video.mp4" } """ - video_url: Required[str] + + video_url: Optional[str] + uuid: Optional[str] + """ + User-provided UUID of a media. User must guarantee that it is properly + generated and unique for different medias. 
+ """ class CustomThinkCompletionContentParam(TypedDict, total=False): @@ -174,19 +200,24 @@ class CustomThinkCompletionContentParam(TypedDict, total=False): ChatCompletionContentPartParam: TypeAlias = Union[ - OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam, + OpenAIChatCompletionContentPartParam, + ChatCompletionContentPartAudioParam, ChatCompletionContentPartInputAudioParam, - ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam, + ChatCompletionContentPartVideoParam, + ChatCompletionContentPartRefusalParam, CustomChatCompletionContentPILImageParam, CustomChatCompletionContentSimpleImageParam, ChatCompletionContentPartImageEmbedsParam, CustomChatCompletionContentSimpleAudioParam, - CustomChatCompletionContentSimpleVideoParam, str, - CustomThinkCompletionContentParam] + CustomChatCompletionContentSimpleVideoParam, + str, + CustomThinkCompletionContentParam, +] class CustomChatCompletionMessageParam(TypedDict, total=False): """Enables custom roles in the Chat Completion API.""" + role: Required[str] """The role of the message's author.""" @@ -207,9 +238,11 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): """The tool calls generated by the model, such as function calls.""" -ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, - CustomChatCompletionMessageParam, - OpenAIHarmonyMessage] +ChatCompletionMessageParam = Union[ + OpenAIChatCompletionMessageParam, + CustomChatCompletionMessageParam, + OpenAIHarmonyMessage, +] # TODO: Make fields ReadOnly once mypy supports it @@ -262,13 +295,13 @@ def _is_var_or_elems_access( key: Optional[str] = None, ) -> bool: if isinstance(node, jinja2.nodes.Filter): - return (node.node is not None - and _is_var_or_elems_access(node.node, varname, key)) + return node.node is not None and _is_var_or_elems_access( + node.node, varname, key) if isinstance(node, jinja2.nodes.Test): return _is_var_or_elems_access(node.node, varname, key) - if (isinstance(node, jinja2.nodes.Getitem) - and isinstance(node.arg, jinja2.nodes.Slice)): + if isinstance(node, jinja2.nodes.Getitem) and isinstance( + node.arg, jinja2.nodes.Slice): return _is_var_or_elems_access(node.node, varname, key) # yapf: disable @@ -373,15 +406,18 @@ def resolve_mistral_chat_template( ) -> Optional[str]: if chat_template is not None: logger.warning_once( - "'chat_template' cannot be overridden for mistral tokenizer.") + "'chat_template' cannot be overridden for mistral tokenizer." + ) if "add_generation_prompt" in kwargs: logger.warning_once( "'add_generation_prompt' is not supported for mistral tokenizer, " - "so it will be ignored.") + "so it will be ignored." + ) if "continue_final_message" in kwargs: logger.warning_once( "'continue_final_message' is not supported for mistral tokenizer, " - "so it will be ignored.") + "so it will be ignored." 
+ ) return None @@ -401,23 +437,35 @@ def resolve_hf_chat_template( try: processor = cached_get_processor( tokenizer.name_or_path, - processor_cls=(PreTrainedTokenizer, PreTrainedTokenizerFast, - ProcessorMixin), + processor_cls=( + PreTrainedTokenizer, + PreTrainedTokenizerFast, + ProcessorMixin, + ), trust_remote_code=model_config.trust_remote_code, ) - if isinstance(processor, ProcessorMixin) and \ - hasattr(processor, 'chat_template') and \ - processor.chat_template is not None: + if ( + isinstance(processor, ProcessorMixin) + and hasattr(processor, "chat_template") + and processor.chat_template is not None + ): return processor.chat_template except Exception: - logger.debug("Failed to load AutoProcessor chat template for %s", tokenizer.name_or_path, exc_info=True) # noqa: E501 + logger.debug( + "Failed to load AutoProcessor chat template for %s", + tokenizer.name_or_path, + exc_info=True, + ) # noqa: E501 # 3rd priority: AutoTokenizer chat template try: return tokenizer.get_chat_template(chat_template, tools=tools) except Exception: - logger.debug("Failed to load AutoTokenizer chat template for %s", - tokenizer.name_or_path, exc_info=True) + logger.debug( + "Failed to load AutoTokenizer chat template for %s", + tokenizer.name_or_path, + exc_info=True, + ) # 4th priority: Predefined fallbacks path = get_chat_template_fallback_path( @@ -425,12 +473,16 @@ def resolve_hf_chat_template( tokenizer_name_or_path=model_config.tokenizer, ) if path is not None: - logger.info("Loading chat template fallback for %s as there isn't one " - "defined on HF Hub.", tokenizer.name_or_path) + logger.info( + "Loading chat template fallback for %s as there isn't one " + "defined on HF Hub.", + tokenizer.name_or_path, + ) chat_template = load_chat_template(path) else: - logger.debug("There is no chat template fallback for %s", - tokenizer.name_or_path) + logger.debug( + "There is no chat template fallback for %s", tokenizer.name_or_path + ) return chat_template @@ -452,11 +504,17 @@ def _resolve_chat_template_content_format( else: hf_chat_template = None - jinja_text = (hf_chat_template if isinstance(hf_chat_template, str) - else load_chat_template(chat_template, is_literal=True)) + jinja_text = ( + hf_chat_template + if isinstance(hf_chat_template, str) + else load_chat_template(chat_template, is_literal=True) + ) - detected_format = ("string" if jinja_text is None else - _detect_content_format(jinja_text, default="string")) + detected_format = ( + "string" + if jinja_text is None + else _detect_content_format(jinja_text, default="string") + ) return detected_format @@ -512,7 +570,6 @@ def resolve_chat_template_content_format( return detected_format - ModalityStr = Literal["image", "audio", "video", "image_embeds"] _T = TypeVar("_T") @@ -530,7 +587,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): self._model_config = model_config self._tokenizer = tokenizer - self._items_by_modality = defaultdict[str, list[_T]](list) + self._items_by_modality = defaultdict[str, list[Optional[_T]]](list) + self._uuids_by_modality = defaultdict[str, list[Optional[str]]](list) @property def model_config(self) -> ModelConfig: @@ -539,6 +597,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): @cached_property def model_cls(self) -> type[SupportsMultiModal]: from vllm.model_executor.model_loader import get_model_cls + model_cls = get_model_cls(self.model_config) return cast(type[SupportsMultiModal], model_cls) @@ -554,10 +613,18 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): def mm_processor(self): return 
self.mm_registry.create_processor(self.model_config) - def add(self, modality: ModalityStr, item: _T) -> Optional[str]: + def add( + self, + modality: ModalityStr, + item: Optional[_T], + uuid: Optional[str] = None, + ) -> Optional[str]: """ Add a multi-modal item to the current prompt and returns the placeholder string to use, if any. + + An optional uuid can be added which serves as a unique identifier of the + media. """ input_modality = modality.replace("_embeds", "") num_items = len(self._items_by_modality[modality]) + 1 @@ -565,37 +632,64 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): self.mm_processor.validate_num_items(input_modality, num_items) self._items_by_modality[modality].append(item) + self._uuids_by_modality[modality].append(uuid) return self.model_cls.get_placeholder_str(modality, num_items) + def all_mm_uuids(self) -> Optional[MultiModalUUIDDict]: + if not self._items_by_modality: + return None + mm_uuids = {} + uuids_by_modality = dict(self._uuids_by_modality) + if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality: + raise ValueError( + "Mixing raw image and embedding inputs is not allowed" + ) + + if "image_embeds" in uuids_by_modality: + image_embeds_uuids = uuids_by_modality["image_embeds"] + if len(image_embeds_uuids) > 1: + raise ValueError( + "Only one message can have {'type': 'image_embeds'}" + ) + mm_uuids["image"] = uuids_by_modality["image_embeds"] + if "image" in uuids_by_modality: + mm_uuids["image"] = uuids_by_modality["image"] # UUIDs of images + if "audio" in uuids_by_modality: + mm_uuids["audio"] = uuids_by_modality["audio"] # UUIDs of audios + if "video" in uuids_by_modality: + mm_uuids["video"] = uuids_by_modality["video"] # UUIDs of videos + return mm_uuids + @abstractmethod def create_parser(self) -> "BaseMultiModalContentParser": raise NotImplementedError class MultiModalItemTracker(BaseMultiModalItemTracker[object]): - def all_mm_data(self) -> Optional[MultiModalDataDict]: if not self._items_by_modality: return None mm_inputs = {} items_by_modality = dict(self._items_by_modality) if "image" in items_by_modality and "image_embeds" in items_by_modality: - raise ValueError(\ - "Mixing raw image and embedding inputs is not allowed") + raise ValueError( + "Mixing raw image and embedding inputs is not allowed" + ) if "image_embeds" in items_by_modality: image_embeds_lst = items_by_modality["image_embeds"] if len(image_embeds_lst) > 1: - raise ValueError(\ - "Only one message can have {'type': 'image_embeds'}") + raise ValueError( + "Only one message can have {'type': 'image_embeds'}" + ) mm_inputs["image"] = image_embeds_lst[0] if "image" in items_by_modality: - mm_inputs["image"] = items_by_modality["image"] # A list of images + mm_inputs["image"] = items_by_modality["image"] # A list of images if "audio" in items_by_modality: - mm_inputs["audio"] = items_by_modality["audio"] # A list of audios + mm_inputs["audio"] = items_by_modality["audio"] # A list of audios if "video" in items_by_modality: - mm_inputs["video"] = items_by_modality["video"] # A list of videos + mm_inputs["video"] = items_by_modality["video"] # A list of videos return mm_inputs def create_parser(self) -> "BaseMultiModalContentParser": @@ -603,32 +697,38 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]): class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): - async def all_mm_data(self) -> Optional[MultiModalDataDict]: if not self._items_by_modality: return None mm_inputs = {} - items_by_modality = { - modality: await 
asyncio.gather(*items) - for modality, items in self._items_by_modality.items() - } + items_by_modality = {} + for modality, items in self._items_by_modality.items(): + coros = [] + for item in items: + if item is not None: + coros.append(item) + else: + coros.append(asyncio.sleep(0)) + items_by_modality[modality] = await asyncio.gather(*coros) if "image" in items_by_modality and "image_embeds" in items_by_modality: raise ValueError( - "Mixing raw image and embedding inputs is not allowed") + "Mixing raw image and embedding inputs is not allowed" + ) if "image_embeds" in items_by_modality: image_embeds_lst = items_by_modality["image_embeds"] if len(image_embeds_lst) > 1: raise ValueError( - "Only one message can have {'type': 'image_embeds'}") + "Only one message can have {'type': 'image_embeds'}" + ) mm_inputs["image"] = image_embeds_lst[0] if "image" in items_by_modality: - mm_inputs["image"] = items_by_modality["image"] # A list of images + mm_inputs["image"] = items_by_modality["image"] # A list of images if "audio" in items_by_modality: - mm_inputs["audio"] = items_by_modality["audio"] # A list of audios + mm_inputs["audio"] = items_by_modality["audio"] # A list of audios if "video" in items_by_modality: - mm_inputs["video"] = items_by_modality["video"] # A list of videos + mm_inputs["video"] = items_by_modality["video"] # A list of videos return mm_inputs def create_parser(self) -> "BaseMultiModalContentParser": @@ -636,7 +736,6 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): class BaseMultiModalContentParser(ABC): - def __init__(self) -> None: super().__init__() @@ -648,8 +747,9 @@ class BaseMultiModalContentParser(ABC): # } self._placeholder_storage: dict[str, list] = defaultdict(list) - def _add_placeholder(self, modality: ModalityStr, - placeholder: Optional[str]): + def _add_placeholder( + self, modality: ModalityStr, placeholder: Optional[str] + ): mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality] if placeholder: self._placeholder_storage[mod_placeholder].append(placeholder) @@ -658,108 +758,157 @@ class BaseMultiModalContentParser(ABC): return dict(self._placeholder_storage) @abstractmethod - def parse_image(self, image_url: str) -> None: + def parse_image( + self, image_url: Optional[str], uuid: Optional[str] = None) -> None: raise NotImplementedError @abstractmethod - def parse_image_embeds(self, - image_embeds: Union[str, dict[str, str]]) -> None: + def parse_image_embeds( + self, + image_embeds: Union[str, dict[str, str], None], + uuid: Optional[str] = None, + ) -> None: raise NotImplementedError @abstractmethod - def parse_image_pil(self, image_pil: Image.Image) -> None: + def parse_image_pil( + self, image_pil: Optional[Image.Image], uuid: Optional[str] = None + ) -> None: raise NotImplementedError @abstractmethod - def parse_audio(self, audio_url: str) -> None: + def parse_audio( + self, audio_url: Optional[str], uuid: Optional[str] = None + ) -> None: raise NotImplementedError @abstractmethod - def parse_input_audio(self, input_audio: InputAudio) -> None: + def parse_input_audio( + self, input_audio: Optional[InputAudio], uuid: Optional[str] = None + ) -> None: raise NotImplementedError @abstractmethod - def parse_video(self, video_url: str) -> None: + def parse_video( + self, video_url: Optional[str], uuid: Optional[str] = None + ) -> None: raise NotImplementedError class MultiModalContentParser(BaseMultiModalContentParser): - def __init__(self, tracker: MultiModalItemTracker) -> None: super().__init__() self._tracker = tracker - 
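A small sketch of the placeholder-gather pattern used in AsyncMultiModalItemTracker.all_mm_data above: UUID-only media have no coroutine to await, so a no-op asyncio.sleep(0) keeps positions aligned and resolves to None. The function and variable names below are illustrative.

import asyncio
from typing import Any, Optional, Sequence


async def gather_with_placeholders(items: Sequence[Optional[Any]]) -> list:
    # Replace None entries with asyncio.sleep(0), which resolves to None, so
    # asyncio.gather() returns a result list aligned with the input order.
    coros = [item if item is not None else asyncio.sleep(0) for item in items]
    return await asyncio.gather(*coros)


async def _demo() -> None:
    async def fetch(url: str) -> str:
        return f"<decoded {url}>"

    results = await gather_with_placeholders([fetch("a.png"), None, fetch("b.png")])
    print(results)  # ['<decoded a.png>', None, '<decoded b.png>']


asyncio.run(_demo())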
+ multimodal_config = self._tracker.model_config.multimodal_config + media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None) self._connector = MediaConnector( - media_io_kwargs=self._tracker._model_config.media_io_kwargs, + media_io_kwargs=media_io_kwargs, allowed_local_media_path=tracker.allowed_local_media_path, ) - def parse_image(self, image_url: str) -> None: - image = self._connector.fetch_image(image_url) + def parse_image( + self, image_url: Optional[str], uuid: Optional[str] = None + ) -> None: + image = self._connector.fetch_image(image_url) if image_url else None - placeholder = self._tracker.add("image", image) + placeholder = self._tracker.add("image", image, uuid) self._add_placeholder("image", placeholder) - def parse_image_embeds(self, - image_embeds: Union[str, dict[str, str]]) -> None: + def parse_image_embeds( + self, + image_embeds: Union[str, dict[str, str], None], + uuid: Optional[str] = None, + ) -> None: if isinstance(image_embeds, dict): embeds = { k: self._connector.fetch_image_embedding(v) for k, v in image_embeds.items() } - placeholder = self._tracker.add("image_embeds", embeds) + placeholder = self._tracker.add("image_embeds", embeds, uuid) if isinstance(image_embeds, str): embedding = self._connector.fetch_image_embedding(image_embeds) - placeholder = self._tracker.add("image_embeds", embedding) + placeholder = self._tracker.add("image_embeds", embedding, uuid) + + if image_embeds is None: + placeholder = self._tracker.add("image_embeds", None, uuid) self._add_placeholder("image", placeholder) - def parse_image_pil(self, image_pil: Image.Image) -> None: - placeholder = self._tracker.add("image", image_pil) + def parse_image_pil( + self, image_pil: Optional[Image.Image], uuid: Optional[str] = None + ) -> None: + placeholder = self._tracker.add("image", image_pil, uuid) self._add_placeholder("image", placeholder) - def parse_audio(self, audio_url: str) -> None: - audio = self._connector.fetch_audio(audio_url) + def parse_audio( + self, audio_url: Optional[str], uuid: Optional[str] = None + ) -> None: + audio = self._connector.fetch_audio(audio_url) if audio_url else None - placeholder = self._tracker.add("audio", audio) + placeholder = self._tracker.add("audio", audio, uuid) self._add_placeholder("audio", placeholder) - def parse_input_audio(self, input_audio: InputAudio) -> None: - audio_data = input_audio.get("data", "") - audio_format = input_audio.get("format", "") - audio_url = f"data:audio/{audio_format};base64,{audio_data}" + def parse_input_audio( + self, input_audio: Optional[InputAudio], uuid: Optional[str] = None + ) -> None: + if input_audio: + audio_data = input_audio.get("data", "") + audio_format = input_audio.get("format", "") + if audio_data: + audio_url = f"data:audio/{audio_format};base64,{audio_data}" + else: + # If a UUID is provided, audio data may be empty. 
+ audio_url = None + else: + audio_url = None - return self.parse_audio(audio_url) + return self.parse_audio(audio_url, uuid) - def parse_video(self, video_url: str) -> None: - video = self._connector.fetch_video(video_url=video_url) + def parse_video( + self, video_url: Optional[str], uuid: Optional[str] = None + ) -> None: + video = ( + self._connector.fetch_video(video_url=video_url) + if video_url + else None + ) - placeholder = self._tracker.add("video", video) + placeholder = self._tracker.add("video", video, uuid) self._add_placeholder("video", placeholder) class AsyncMultiModalContentParser(BaseMultiModalContentParser): - def __init__(self, tracker: AsyncMultiModalItemTracker) -> None: super().__init__() self._tracker = tracker + multimodal_config = self._tracker.model_config.multimodal_config + media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None) self._connector = MediaConnector( - media_io_kwargs=self._tracker._model_config.media_io_kwargs, - allowed_local_media_path=tracker.allowed_local_media_path + media_io_kwargs=media_io_kwargs, + allowed_local_media_path=tracker.allowed_local_media_path, ) - def parse_image(self, image_url: str) -> None: - image_coro = self._connector.fetch_image_async(image_url) + def parse_image( + self, image_url: Optional[str], uuid: Optional[str] = None + ) -> None: + image_coro = ( + self._connector.fetch_image_async(image_url) if image_url else None + ) - placeholder = self._tracker.add("image", image_coro) + placeholder = self._tracker.add("image", image_coro, uuid) self._add_placeholder("image", placeholder) - def parse_image_embeds(self, - image_embeds: Union[str, dict[str, str]]) -> None: - future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future() + def parse_image_embeds( + self, + image_embeds: Union[str, dict[str, str], None], + uuid: Optional[str] = None, + ) -> None: + future: asyncio.Future[Union[str, dict[str, str], None]] = ( + asyncio.Future() + ) if isinstance(image_embeds, dict): embeds = { @@ -769,37 +918,63 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser): future.set_result(embeds) if isinstance(image_embeds, str): - embedding = self._connector.\ - fetch_image_embedding(image_embeds) + embedding = self._connector.fetch_image_embedding(image_embeds) future.set_result(embedding) - placeholder = self._tracker.add("image_embeds", future) + if image_embeds is None: + future.set_result(None) + + placeholder = self._tracker.add("image_embeds", future, uuid) self._add_placeholder("image", placeholder) - def parse_image_pil(self, image_pil: Image.Image) -> None: - future: asyncio.Future[Image.Image] = asyncio.Future() - future.set_result(image_pil) + def parse_image_pil( + self, image_pil: Optional[Image.Image], uuid: Optional[str] = None + ) -> None: + future: asyncio.Future[Optional[Image.Image]] = asyncio.Future() + if image_pil: + future.set_result(image_pil) + else: + future.set_result(None) - placeholder = self._tracker.add("image", future) + placeholder = self._tracker.add("image", future, uuid) self._add_placeholder("image", placeholder) - def parse_audio(self, audio_url: str) -> None: - audio_coro = self._connector.fetch_audio_async(audio_url) + def parse_audio( + self, audio_url: Optional[str], uuid: Optional[str] = None + ) -> None: + audio_coro = ( + self._connector.fetch_audio_async(audio_url) if audio_url else None + ) - placeholder = self._tracker.add("audio", audio_coro) + placeholder = self._tracker.add("audio", audio_coro, uuid) self._add_placeholder("audio", placeholder) - def 
parse_input_audio(self, input_audio: InputAudio) -> None: - audio_data = input_audio.get("data", "") - audio_format = input_audio.get("format", "") - audio_url = f"data:audio/{audio_format};base64,{audio_data}" + def parse_input_audio( + self, input_audio: Optional[InputAudio], uuid: Optional[str] = None + ) -> None: + if input_audio: + audio_data = input_audio.get("data", "") + audio_format = input_audio.get("format", "") + if audio_data: + audio_url = f"data:audio/{audio_format};base64,{audio_data}" + else: + # If a UUID is provided, audio data may be empty. + audio_url = None + else: + audio_url = None - return self.parse_audio(audio_url) + return self.parse_audio(audio_url, uuid) - def parse_video(self, video_url: str) -> None: - video = self._connector.fetch_video_async(video_url=video_url) + def parse_video( + self, video_url: Optional[str], uuid: Optional[str] = None + ) -> None: + video = ( + self._connector.fetch_video_async(video_url=video_url) + if video_url + else None + ) - placeholder = self._tracker.add("video", video) + placeholder = self._tracker.add("video", video, uuid) self._add_placeholder("video", placeholder) @@ -809,20 +984,23 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]): return elif isinstance(chat_template, Path) and not chat_template.exists(): - raise FileNotFoundError( - "the supplied chat template path doesn't exist") + raise FileNotFoundError("the supplied chat template path doesn't exist") elif isinstance(chat_template, str): JINJA_CHARS = "{}\n" - if not any(c in chat_template - for c in JINJA_CHARS) and not Path(chat_template).exists(): + if ( + not any(c in chat_template for c in JINJA_CHARS) + and not Path(chat_template).exists() + ): raise ValueError( f"The supplied chat template string ({chat_template}) " - f"appears path-like, but doesn't exist!") + f"appears path-like, but doesn't exist!" + ) else: raise TypeError( - f"{type(chat_template)} is not a valid chat template type") + f"{type(chat_template)} is not a valid chat template type" + ) def _load_chat_template( @@ -835,8 +1013,9 @@ def _load_chat_template( if is_literal: if isinstance(chat_template, Path): - raise TypeError("chat_template is expected to be read directly " - "from its value") + raise TypeError( + "chat_template is expected to be read directly from its value" + ) return chat_template @@ -849,9 +1028,11 @@ def _load_chat_template( JINJA_CHARS = "{}\n" if not any(c in chat_template for c in JINJA_CHARS): - msg = (f"The supplied chat template ({chat_template}) " - f"looks like a file path, but it failed to be " - f"opened. Reason: {e}") + msg = ( + f"The supplied chat template ({chat_template}) " + f"looks like a file path, but it failed to be " + f"opened. 
Reason: {e}" + ) raise ValueError(msg) from e # If opening a file fails, set chat template to be args to @@ -870,8 +1051,9 @@ def load_chat_template( return _cached_load_chat_template(chat_template, is_literal=is_literal) -def _get_interleaved_text_prompt(placeholder_storage: dict[str, list], - texts: list[str]) -> str: +def _get_interleaved_text_prompt( + placeholder_storage: dict[str, list], texts: list[str] +) -> str: for idx, elem in enumerate(texts): if elem in placeholder_storage: texts[idx] = placeholder_storage[elem].pop(0) @@ -881,10 +1063,11 @@ def _get_interleaved_text_prompt(placeholder_storage: dict[str, list], # TODO: Let user specify how to insert multimodal tokens into prompt # (similar to chat template) -def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list], - texts: list[str], - interleave_strings: bool - ) -> str: +def _get_full_multimodal_text_prompt( + placeholder_storage: dict[str, list], + texts: list[str], + interleave_strings: bool, +) -> str: """Combine multimodal prompts for a multimodal language model.""" # flatten storage to make it looks like @@ -907,7 +1090,6 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list], # Look through the text prompt to check for missing placeholders missing_placeholders: list[str] = [] for placeholder in placeholder_counts: - # For any existing placeholder in the text prompt, we leave it as is placeholder_counts[placeholder] -= text_prompt.count(placeholder) @@ -916,15 +1098,18 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list], "Placeholder count is negative! " "Ensure that the 'interleave_strings' flag is disabled " "(current value: %s) " - "when manually placing image placeholders.", interleave_strings + "when manually placing image placeholders.", + interleave_strings, ) logger.debug("Input prompt: %s", text_prompt) raise ValueError( f"Found more '{placeholder}' placeholders in input prompt than " - "actual multimodal data items.") + "actual multimodal data items." + ) - missing_placeholders.extend([placeholder] * - placeholder_counts[placeholder]) + missing_placeholders.extend( + [placeholder] * placeholder_counts[placeholder] + ) # NOTE: Default behaviour: we always add missing placeholders # at the front of the prompt, if interleave_strings=False @@ -944,7 +1129,8 @@ _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python _VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python _ResponsesInputImageParser = TypeAdapter( - ResponseInputImageParam).validate_python + ResponseInputImageParam +).validate_python _ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage] # Define a mapping from part types to their corresponding parsing functions. 
@@ -952,32 +1138,35 @@ MM_PARSER_MAP: dict[ str, Callable[[ChatCompletionContentPartParam], _ContentPart], ] = { - "text": - lambda part: _TextParser(part).get("text", None), - "thinking": - lambda part: _ThinkParser(part).get("thinking", None), - "input_text": - lambda part: _TextParser(part).get("text", None), - "input_image": - lambda part: _ResponsesInputImageParser(part).get("image_url", None), - "image_url": - lambda part: _ImageParser(part).get("image_url", {}).get("url", None), - "image_embeds": - lambda part: _ImageEmbedsParser(part).get("image_embeds", None), + "text": lambda part: _TextParser(part).get("text", None), + "thinking": lambda part: _ThinkParser(part).get("thinking", None), + "input_text": lambda part: _TextParser(part).get("text", None), + "input_image": lambda part: _ResponsesInputImageParser(part).get( + "image_url", None + ), + "image_url": lambda part: _ImageParser(part) + .get("image_url", {}) + .get("url", None), + "image_embeds": lambda part: _ImageEmbedsParser(part).get( + "image_embeds", None + ), "image_pil": lambda part: _PILImageParser(part).get("image_pil", None), - "audio_url": - lambda part: _AudioParser(part).get("audio_url", {}).get("url", None), - "input_audio": - lambda part: _InputAudioParser(part).get("input_audio", None), - "refusal": - lambda part: _RefusalParser(part).get("refusal", None), - "video_url": - lambda part: _VideoParser(part).get("video_url", {}).get("url", None), + "audio_url": lambda part: _AudioParser(part) + .get("audio_url", {}) + .get("url", None), + "input_audio": lambda part: _InputAudioParser(part).get( + "input_audio", None + ), + "refusal": lambda part: _RefusalParser(part).get("refusal", None), + "video_url": lambda part: _VideoParser(part) + .get("video_url", {}) + .get("url", None), } def _parse_chat_message_content_mm_part( - part: ChatCompletionContentPartParam) -> tuple[str, _ContentPart]: + part: ChatCompletionContentPartParam, +) -> tuple[str, _ContentPart]: """ Parses a given multi-modal content part based on its type. @@ -993,38 +1182,74 @@ def _parse_chat_message_content_mm_part( ValueError: If the 'type' field is missing and no direct URL is found. """ assert isinstance( - part, dict) # This is needed to avoid mypy errors: part.get() from str + part, dict + ) # This is needed to avoid mypy errors: part.get() from str part_type = part.get("type", None) + uuid = part.get("uuid", None) - if isinstance(part_type, str) and part_type in MM_PARSER_MAP: + if isinstance(part_type, str) and part_type in MM_PARSER_MAP and uuid is None: # noqa: E501 content = MM_PARSER_MAP[part_type](part) # Special case for 'image_url.detail' # We only support 'auto', which is the default if part_type == "image_url" and part.get("detail", "auto") != "auto": - logger.warning("'image_url.detail' is currently not supported " - "and will be ignored.") + logger.warning( + "'image_url.detail' is currently not supported " + "and will be ignored." + ) return part_type, content # Handle missing 'type' but provided direct URL fields. 
# 'type' is required field by pydantic - if part_type is None: - if part.get("image_url") is not None: - image_params = cast(CustomChatCompletionContentSimpleImageParam, - part) - return "image_url", image_params.get("image_url", "") - if part.get("audio_url") is not None: - audio_params = cast(CustomChatCompletionContentSimpleAudioParam, - part) - return "audio_url", audio_params.get("audio_url", "") + if part_type is None or uuid is not None: + if "image_url" in part: + image_params = cast( + CustomChatCompletionContentSimpleImageParam, part + ) + image_url = image_params.get("image_url", None) + if isinstance(image_url, dict): + # Can potentially happen if user provides a uuid + # with url as a dict of {"url": url} + image_url = image_url.get("url", None) + return "image_url", image_url + if "image_pil" in part: + # "image_pil" could be None if UUID is provided. + image_params = cast( # type: ignore + CustomChatCompletionContentPILImageParam, part + ) + image_pil = image_params.get("image_pil", None) + return "image_pil", image_pil + if "image_embeds" in part: + # "image_embeds" could be None if UUID is provided. + image_params = cast( # type: ignore + ChatCompletionContentPartImageEmbedsParam, part + ) + image_embeds = image_params.get("image_embeds", None) + return "image_embeds", image_embeds + if "audio_url" in part: + audio_params = cast( + CustomChatCompletionContentSimpleAudioParam, part + ) + audio_url = audio_params.get("audio_url", None) + if isinstance(audio_url, dict): + # Can potentially happen if user provides a uuid + # with url as a dict of {"url": url} + audio_url = audio_url.get("url", None) + return "audio_url", audio_url if part.get("input_audio") is not None: input_audio_params = cast(dict[str, str], part) return "input_audio", input_audio_params - if part.get("video_url") is not None: - video_params = cast(CustomChatCompletionContentSimpleVideoParam, - part) - return "video_url", video_params.get("video_url", "") + if "video_url" in part: + video_params = cast( + CustomChatCompletionContentSimpleVideoParam, part + ) + video_url = video_params.get("video_url", None) + if isinstance(video_url, dict): + # Can potentially happen if user provides a uuid + # with url as a dict of {"url": url} + video_url = video_url.get("url", None) + return "video_url", video_url # Raise an error if no 'type' or direct URL is found. 
raise ValueError("Missing 'type' field in multimodal part.") @@ -1033,9 +1258,10 @@ def _parse_chat_message_content_mm_part( return part_type, "unknown part_type content" -VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url", - "image_embeds", "image_pil", - "audio_url", "input_audio", "video_url") +PART_TYPES_TO_SKIP_NONE_CONTENT = ( + "text", + "refusal", +) def _parse_chat_message_content_parts( @@ -1055,21 +1281,20 @@ def _parse_chat_message_content_parts( part, mm_parser, wrap_dicts=wrap_dicts, - interleave_strings=interleave_strings + interleave_strings=interleave_strings, ) if parse_res: content.append(parse_res) if wrap_dicts: # Parsing wraps images and texts as interleaved dictionaries - return [ConversationMessage(role=role, - content=content)] # type: ignore + return [ConversationMessage(role=role, content=content)] # type: ignore texts = cast(list[str], content) mm_placeholder_storage = mm_parser.mm_placeholder_storage() if mm_placeholder_storage: - text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_storage, - texts, - interleave_strings) + text_prompt = _get_full_multimodal_text_prompt( + mm_placeholder_storage, texts, interleave_strings + ) else: text_prompt = "\n".join(texts) @@ -1096,49 +1321,68 @@ def _parse_chat_message_content_part( part_type, content = _parse_chat_message_content_mm_part(part) # if part_type is text/refusal/image_url/audio_url/video_url/input_audio but # content is None, log a warning and skip - if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None: + if part_type in PART_TYPES_TO_SKIP_NONE_CONTENT and content is None: logger.warning( "Skipping multimodal part '%s' (type: '%s') " - "with empty / unparsable content.", part, part_type) + "with empty / unparsable content.", + part, + part_type, + ) return None if part_type in ("text", "input_text", "refusal", "thinking"): str_content = cast(str, content) if wrap_dicts: - return {'type': 'text', 'text': str_content} + return {"type": "text", "text": str_content} else: return str_content + # For media items, if a user has provided one, use it. Otherwise, insert + # a placeholder empty uuid. 
+ uuid = part.get("uuid", None) + if uuid is not None: + uuid = str(uuid) + modality = None if part_type == "image_pil": - image_content = cast(Image.Image, content) - mm_parser.parse_image_pil(image_content) + if content is not None: + image_content = cast(Image.Image, content) + else: + image_content = None + mm_parser.parse_image_pil(image_content, uuid) modality = "image" elif part_type in ("image_url", "input_image"): str_content = cast(str, content) - mm_parser.parse_image(str_content) + mm_parser.parse_image(str_content, uuid) modality = "image" elif part_type == "image_embeds": - content = cast(Union[str, dict[str, str]], content) - mm_parser.parse_image_embeds(content) + if content is not None: + content = cast(Union[str, dict[str, str]], content) + else: + content = None + mm_parser.parse_image_embeds(content, uuid) modality = "image" elif part_type == "audio_url": str_content = cast(str, content) - mm_parser.parse_audio(str_content) + mm_parser.parse_audio(str_content, uuid) modality = "audio" elif part_type == "input_audio": dict_content = cast(InputAudio, content) - mm_parser.parse_input_audio(dict_content) + mm_parser.parse_input_audio(dict_content, uuid) modality = "audio" elif part_type == "video_url": str_content = cast(str, content) - mm_parser.parse_video(str_content) + mm_parser.parse_video(str_content, uuid) modality = "video" else: raise NotImplementedError(f"Unknown part type: {part_type}") - return {'type': modality} if wrap_dicts else ( - MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None + return ( + {"type": modality} + if wrap_dicts + else ( + MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None + ) ) @@ -1171,14 +1415,16 @@ def _parse_chat_message_content( ) for result_msg in result: - if role == 'assistant': + if role == "assistant": parsed_msg = _AssistantParser(message) # The 'tool_calls' is not None check ensures compatibility. # It's needed only if downstream code doesn't strictly # follow the OpenAI spec. 
- if ("tool_calls" in parsed_msg - and parsed_msg["tool_calls"] is not None): + if ( + "tool_calls" in parsed_msg + and parsed_msg["tool_calls"] is not None + ): result_msg["tool_calls"] = list(parsed_msg["tool_calls"]) elif role == "tool": parsed_msg = _ToolParser(message) @@ -1198,12 +1444,15 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None: # so, for messages that have tool_calls, parse the string (which we get # from openAI format) to dict for message in messages: - if (message["role"] == "assistant" and "tool_calls" in message - and isinstance(message["tool_calls"], list)): - + if ( + message["role"] == "assistant" + and "tool_calls" in message + and isinstance(message["tool_calls"], list) + ): for item in message["tool_calls"]: item["function"]["arguments"] = json.loads( - item["function"]["arguments"]) + item["function"]["arguments"] + ) def parse_chat_messages( @@ -1211,7 +1460,11 @@ def parse_chat_messages( model_config: ModelConfig, tokenizer: AnyTokenizer, content_format: _ChatTemplateContentFormat, -) -> tuple[list[ConversationMessage], Optional[MultiModalDataDict]]: +) -> tuple[ + list[ConversationMessage], + Optional[MultiModalDataDict], + Optional[MultiModalUUIDDict], +]: conversation: list[ConversationMessage] = [] mm_tracker = MultiModalItemTracker(model_config, tokenizer) @@ -1224,14 +1477,14 @@ def parse_chat_messages( content_format == "string" and model_config.multimodal_config is not None and model_config.multimodal_config.interleave_mm_strings - ) + ), ) conversation.extend(sub_messages) _postprocess_messages(conversation) - return conversation, mm_tracker.all_mm_data() + return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids() def parse_chat_messages_futures( @@ -1239,7 +1492,11 @@ def parse_chat_messages_futures( model_config: ModelConfig, tokenizer: AnyTokenizer, content_format: _ChatTemplateContentFormat, -) -> tuple[list[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]: +) -> tuple[ + list[ConversationMessage], + Awaitable[Optional[MultiModalDataDict]], + Optional[MultiModalUUIDDict], +]: conversation: list[ConversationMessage] = [] mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer) @@ -1252,14 +1509,14 @@ def parse_chat_messages_futures( content_format == "string" and model_config.multimodal_config is not None and model_config.multimodal_config.interleave_mm_strings - ) + ), ) conversation.extend(sub_messages) _postprocess_messages(conversation) - return conversation, mm_tracker.all_mm_data() + return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids() def apply_hf_chat_template( @@ -1283,10 +1540,10 @@ def apply_hf_chat_template( raise ValueError( "As of transformers v4.44, default chat template is no longer " "allowed, so you must provide a chat template if the tokenizer " - "does not define one.") + "does not define one." + ) try: - return tokenizer.apply_chat_template( conversation=conversation, # type: ignore[arg-type] tools=tools, # type: ignore[arg-type] @@ -1298,13 +1555,14 @@ def apply_hf_chat_template( # External library exceptions can sometimes occur despite the framework's # internal exception management capabilities. except Exception as e: - # Log and report any library-related exceptions for further # investigation. 
logger.exception( - "An error occurred in `transformers` while applying chat template") + "An error occurred in `transformers` while applying chat template" + ) raise ValueError(str(e)) from e + def apply_mistral_chat_template( tokenizer: MistralTokenizer, messages: list[ChatCompletionMessageParam], @@ -1337,26 +1595,26 @@ def apply_mistral_chat_template( # External library exceptions can sometimes occur despite the framework's # internal exception management capabilities. except Exception as e: - # Log and report any library-related exceptions for further # investigation. logger.exception( - "An error occurred in `mistral_common` while applying chat " - "template") + "An error occurred in `mistral_common` while applying chat template" + ) raise ValueError(str(e)) from e + def get_history_tool_calls_cnt(conversation: list[ConversationMessage]): idx = 0 for msg in conversation: - if msg['role'] == 'assistant': - tool_calls = msg.get('tool_calls') - idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa + if msg["role"] == "assistant": + tool_calls = msg.get("tool_calls") + idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa return idx -def make_tool_call_id(id_type:str='random', func_name=None, idx=None): - if id_type=='kimi_k2': - return f'functions.{func_name}:{idx}' +def make_tool_call_id(id_type: str = "random", func_name=None, idx=None): + if id_type == "kimi_k2": + return f"functions.{func_name}:{idx}" else: # by default return random return f"chatcmpl-tool-{random_uuid()}" diff --git a/vllm/entrypoints/cli/openai.py b/vllm/entrypoints/cli/openai.py index 7c01de94a3436..1929d6a7f77af 100644 --- a/vllm/entrypoints/cli/openai.py +++ b/vllm/entrypoints/cli/openai.py @@ -45,6 +45,28 @@ def _interactive_cli(args: argparse.Namespace) -> tuple[str, OpenAI]: return model_name, openai_client +def _print_chat_stream(stream) -> str: + output = "" + for chunk in stream: + delta = chunk.choices[0].delta + if delta.content: + output += delta.content + print(delta.content, end="", flush=True) + print() + return output + + +def _print_completion_stream(stream) -> str: + output = "" + for chunk in stream: + text = chunk.choices[0].text + if text is not None: + output += text + print(text, end="", flush=True) + print() + return output + + def chat(system_prompt: str | None, model_name: str, client: OpenAI) -> None: conversation: list[ChatCompletionMessageParam] = [] if system_prompt is not None: @@ -58,14 +80,11 @@ def chat(system_prompt: str | None, model_name: str, client: OpenAI) -> None: break conversation.append({"role": "user", "content": input_message}) - chat_completion = client.chat.completions.create(model=model_name, - messages=conversation) - - response_message = chat_completion.choices[0].message - output = response_message.content - - conversation.append(response_message) # type: ignore - print(output) + stream = client.chat.completions.create(model=model_name, + messages=conversation, + stream=True) + output = _print_chat_stream(stream) + conversation.append({"role": "assistant", "content": output}) def _add_query_options( @@ -108,9 +127,11 @@ class ChatCommand(CLISubcommand): if args.quick: conversation.append({"role": "user", "content": args.quick}) - chat_completion = client.chat.completions.create( - model=model_name, messages=conversation) - print(chat_completion.choices[0].message.content) + stream = client.chat.completions.create(model=model_name, + messages=conversation, + stream=True) + output = _print_chat_stream(stream) + 
conversation.append({"role": "assistant", "content": output}) return print("Please enter a message for the chat model:") @@ -121,14 +142,11 @@ class ChatCommand(CLISubcommand): break conversation.append({"role": "user", "content": input_message}) - chat_completion = client.chat.completions.create( - model=model_name, messages=conversation) - - response_message = chat_completion.choices[0].message - output = response_message.content - - conversation.append(response_message) # type: ignore - print(output) + stream = client.chat.completions.create(model=model_name, + messages=conversation, + stream=True) + output = _print_chat_stream(stream) + conversation.append({"role": "assistant", "content": output}) @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: @@ -168,9 +186,10 @@ class CompleteCommand(CLISubcommand): model_name, client = _interactive_cli(args) if args.quick: - completion = client.completions.create(model=model_name, - prompt=args.quick) - print(completion.choices[0].text) + stream = client.completions.create(model=model_name, + prompt=args.quick, + stream=True) + _print_completion_stream(stream) return print("Please enter prompt to complete:") @@ -179,10 +198,10 @@ class CompleteCommand(CLISubcommand): input_prompt = input("> ") except EOFError: break - completion = client.completions.create(model=model_name, - prompt=input_prompt) - output = completion.choices[0].text - print(output) + stream = client.completions.create(model=model_name, + prompt=input_prompt, + stream=True) + _print_completion_stream(stream) @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py index 9d587e8669339..8619452f2445f 100644 --- a/vllm/entrypoints/context.py +++ b/vllm/entrypoints/context.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +import contextlib import json import logging from abc import ABC, abstractmethod -from collections.abc import Sequence from contextlib import AsyncExitStack from typing import TYPE_CHECKING, Optional, Union @@ -21,6 +22,23 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class TurnTokens: + """Tracks token counts for a single conversation turn.""" + + def __init__(self, input_tokens=0, output_tokens=0): + self.input_tokens = input_tokens + self.output_tokens = output_tokens + + def reset(self): + """Reset counters for a new turn.""" + self.input_tokens = 0 + self.output_tokens = 0 + + def copy(self): + """Create a copy of this turn's token counts.""" + return TurnTokens(self.input_tokens, self.output_tokens) + + class ConversationContext(ABC): @abstractmethod @@ -41,17 +59,32 @@ class ConversationContext(ABC): @abstractmethod async def init_tool_sessions(self, tool_server: Optional[ToolServer], - exit_stack: AsyncExitStack) -> None: + exit_stack: AsyncExitStack, + request_id: str) -> None: pass + @abstractmethod + async def cleanup_session(self) -> None: + raise NotImplementedError("Should not be called.") + class SimpleContext(ConversationContext): def __init__(self): self.last_output = None + self.num_prompt_tokens = 0 + self.num_output_tokens = 0 + self.num_cached_tokens = 0 + # todo num_reasoning_tokens is not implemented yet. 
+ self.num_reasoning_tokens = 0 def append_output(self, output) -> None: self.last_output = output + if not isinstance(output, RequestOutput): + raise ValueError("SimpleContext only supports RequestOutput.") + self.num_prompt_tokens = len(output.prompt_token_ids or []) + self.num_cached_tokens = output.num_cached_tokens or 0 + self.num_output_tokens += len(output.outputs[0].token_ids or []) def need_builtin_tool_call(self) -> bool: return False @@ -63,9 +96,13 @@ class SimpleContext(ConversationContext): raise NotImplementedError("Should not be called.") async def init_tool_sessions(self, tool_server: Optional[ToolServer], - exit_stack: AsyncExitStack) -> None: + exit_stack: AsyncExitStack, + request_id: str) -> None: pass + async def cleanup_session(self) -> None: + raise NotImplementedError("Should not be called.") + class HarmonyContext(ConversationContext): @@ -75,41 +112,139 @@ class HarmonyContext(ConversationContext): available_tools: list[str], ): self._messages = messages + self.finish_reason: Optional[str] = None self.available_tools = available_tools self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {} + self.called_tools: set[str] = set() self.parser = get_streamable_parser_for_assistant() self.num_init_messages = len(messages) self.num_prompt_tokens = 0 self.num_output_tokens = 0 - # TODO(woosuk): Implement the following fields. self.num_cached_tokens = 0 self.num_reasoning_tokens = 0 + self.num_tool_output_tokens = 0 - def _update_num_prompt_tokens(self, output: RequestOutput): - if output.prompt_token_ids and len(output.prompt_token_ids) > 0: - # NOTE: with built-in tools, there might be multiple rounds in - # the conversation, with the full conversation being resent - # as new prompt each time. Hence the sum. - self.num_prompt_tokens += len(output.prompt_token_ids) + # Turn tracking - replaces multiple individual tracking variables + self.current_turn = TurnTokens() + self.previous_turn = TurnTokens() + self.is_first_turn = True + self.first_tok_of_message = True # For streaming support - def _update_num_output_tokens(self, token_ids: Sequence[int]): - self.num_output_tokens += len(token_ids) + def _update_num_reasoning_tokens(self): + # Count all analysis and commentary channels as reasoning tokens + if self.parser.current_channel in {"analysis", "commentary"}: + self.num_reasoning_tokens += 1 - def append_output(self, output) -> None: + def append_output(self, output: Union[RequestOutput, + list[Message]]) -> None: if isinstance(output, RequestOutput): - self._update_num_prompt_tokens(output) output_token_ids = output.outputs[0].token_ids - self._update_num_output_tokens(output_token_ids) self.parser = get_streamable_parser_for_assistant() for token_id in output_token_ids: self.parser.process(token_id) + # Check if the current token is part of reasoning content + self._update_num_reasoning_tokens() + self._update_prefill_token_usage(output) + # Reset current turn output tokens for this turn + self.current_turn.output_tokens = 0 + self._update_decode_token_usage(output) + # Move current turn to previous turn for next turn's calculations + self.previous_turn = self.current_turn.copy() + # append_output is called only once before tool calling + # in non-streaming case + # so we can append all the parser messages to _messages output_msgs = self.parser.messages + # The responses finish reason is set in the last message + self.finish_reason = output.outputs[0].finish_reason else: # Tool output. 
output_msgs = output self._messages.extend(output_msgs) + def _update_prefill_token_usage(self, output: RequestOutput) -> None: + """Update token usage statistics for the prefill phase of generation. + + The prefill phase processes the input prompt tokens. This method: + 1. Counts the prompt tokens for this turn + 2. Calculates tool output tokens for multi-turn conversations + 3. Updates cached token counts + 4. Tracks state for next turn calculations + + Tool output tokens are calculated as: + current_prompt_tokens - last_turn_prompt_tokens - + last_turn_output_tokens + This represents tokens added between turns (typically tool responses). + + Args: + output: The RequestOutput containing prompt token information + """ + if output.prompt_token_ids is not None: + this_turn_input_tokens = len(output.prompt_token_ids) + else: + this_turn_input_tokens = 0 + logger.error( + "RequestOutput appended contains no prompt_token_ids.") + + # Update current turn input tokens + self.current_turn.input_tokens = this_turn_input_tokens + self.num_prompt_tokens += this_turn_input_tokens + + # Calculate tool tokens (except on first turn) + if self.is_first_turn: + self.is_first_turn = False + else: + # start counting tool after first turn + # tool tokens = this turn prefill - last turn prefill - + # last turn decode + this_turn_tool_tokens = (self.current_turn.input_tokens - + self.previous_turn.input_tokens - + self.previous_turn.output_tokens) + + # Handle negative tool token counts (shouldn't happen in normal + # cases) + if this_turn_tool_tokens < 0: + logger.error( + "Negative tool output tokens calculated: %d " + "(current_input=%d, previous_input=%d, " + "previous_output=%d). Setting to 0.", + this_turn_tool_tokens, self.current_turn.input_tokens, + self.previous_turn.input_tokens, + self.previous_turn.output_tokens) + this_turn_tool_tokens = 0 + + self.num_tool_output_tokens += this_turn_tool_tokens + + # Update cached tokens + if output.num_cached_tokens is not None: + self.num_cached_tokens += output.num_cached_tokens + + def _update_decode_token_usage(self, output: RequestOutput) -> int: + """Update token usage statistics for the decode phase of generation. + + The decode phase processes the generated output tokens. This method: + 1. Counts output tokens from all completion outputs + 2. Updates the total output token count + 3. Tracks tokens generated in the current turn + + In streaming mode, this is called for each token generated. + In non-streaming mode, this is called once with all output tokens. 
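Illustrative example (not part of this diff; token counts are invented) of the per-turn accounting described in the docstring above:

    # Turn 1: the prompt is 120 tokens and the model generates 30 tokens.
    previous_input, previous_output = 120, 30
    # Turn 2: the whole conversation plus the tool response is re-sent as
    # the new prompt, now 180 tokens long.
    current_input = 180
    # tool tokens = this turn prefill - last turn prefill - last turn decode
    tool_output_tokens = current_input - previous_input - previous_output
    assert tool_output_tokens == 30  # tokens added by the tool response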
+ + Args: + output: The RequestOutput containing generated token information + + Returns: + int: Number of output tokens processed in this call + """ + updated_output_token_count = 0 + if output.outputs: + for completion_output in output.outputs: + # only keep last round + updated_output_token_count += len(completion_output.token_ids) + self.num_output_tokens += updated_output_token_count + self.current_turn.output_tokens += updated_output_token_count + return updated_output_token_count + @property def messages(self) -> list: return self._messages @@ -118,7 +253,8 @@ class HarmonyContext(ConversationContext): last_msg = self.messages[-1] recipient = last_msg.recipient return recipient is not None and (recipient.startswith("browser.") - or recipient.startswith("python")) + or recipient.startswith("python") or + recipient.startswith("container.")) async def call_tool(self) -> list[Message]: if not self.messages: @@ -132,6 +268,9 @@ class HarmonyContext(ConversationContext): elif recipient.startswith("python"): return await self.call_python_tool( self._tool_sessions["python"], last_msg) + elif recipient.startswith("container."): + return await self.call_container_tool( + self._tool_sessions["container"], last_msg) raise ValueError("No tool call found") def render_for_completion(self) -> list[int]: @@ -140,6 +279,7 @@ class HarmonyContext(ConversationContext): async def call_search_tool(self, tool_session: Union["ClientSession", Tool], last_msg: Message) -> list[Message]: + self.called_tools.add("browser") if isinstance(tool_session, Tool): return await tool_session.get_result(self) tool_name = last_msg.recipient.split(".")[1] @@ -149,12 +289,16 @@ class HarmonyContext(ConversationContext): content = TextContent(text=result_str) author = Author(role=Role.TOOL, name=last_msg.recipient) return [ - Message(author=author, content=[content], recipient=Role.ASSISTANT) + Message(author=author, + content=[content], + recipient=Role.ASSISTANT, + channel=last_msg.channel) ] async def call_python_tool(self, tool_session: Union["ClientSession", Tool], last_msg: Message) -> list[Message]: + self.called_tools.add("python") if isinstance(tool_session, Tool): return await tool_session.get_result(self) param = { @@ -174,13 +318,63 @@ class HarmonyContext(ConversationContext): ] async def init_tool_sessions(self, tool_server: Optional[ToolServer], - exit_stack: AsyncExitStack) -> None: + exit_stack: AsyncExitStack, + request_id: str) -> None: if tool_server: for tool_name in self.available_tools: if tool_name not in self._tool_sessions: - self._tool_sessions[ - tool_name] = await exit_stack.enter_async_context( - tool_server.new_session(tool_name)) + tool_session = await exit_stack.enter_async_context( + tool_server.new_session(tool_name, request_id)) + self._tool_sessions[tool_name] = tool_session + exit_stack.push_async_exit(self.cleanup_session) + + async def call_container_tool(self, tool_session: Union["ClientSession", + Tool], + last_msg: Message) -> list[Message]: + """ + Call container tool. Expect this to be run in a stateful docker + with command line terminal. 
+ The official container tool would at least + expect the following format: + - for tool name: exec + - args: + { + "cmd":List[str] "command to execute", + "workdir":optional[str] "current working directory", + "env":optional[object/dict] "environment variables", + "session_name":optional[str] "session name", + "timeout":optional[int] "timeout in seconds", + "user":optional[str] "user name", + } + """ + self.called_tools.add("container") + if isinstance(tool_session, Tool): + return await tool_session.get_result(self) + tool_name = last_msg.recipient.split(".")[1].split(" ")[0] + args = json.loads(last_msg.content[0].text) + result = await tool_session.call_tool(tool_name, args) + result_str = result.content[0].text + content = TextContent(text=result_str) + author = Author(role=Role.TOOL, name=last_msg.recipient) + return [ + Message(author=author, + content=[content], + recipient=Role.ASSISTANT, + channel=last_msg.channel) + ] + + async def cleanup_session(self, *args, **kwargs) -> None: + """Can be used as a coroutine in __aexit__""" + + async def cleanup_tool_session(tool_session): + if not isinstance(tool_session, Tool): + logger.info("Cleaning up tool session for %s", + tool_session._client_info) + with contextlib.suppress(Exception): + await tool_session.call_tool("cleanup_session", {}) + + await asyncio.gather(*(cleanup_tool_session(self._tool_sessions[tool]) + for tool in self.called_tools)) class StreamingHarmonyContext(HarmonyContext): @@ -196,23 +390,36 @@ @property def messages(self) -> list: - return self.parser.messages + return self._messages - def append_output(self, output) -> None: + def append_output(self, output: Union[RequestOutput, + list[Message]]) -> None: if isinstance(output, RequestOutput): # append_output is called for each output token in streaming case, # so we only want to add the prompt tokens once for each message.
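Illustrative payload (not part of this diff; values are invented) for the container `exec` tool described above; `call_container_tool` json-decodes an object of this shape from the message body:

    import json

    args = {
        "cmd": ["ls", "-la"],              # required: command to execute
        "workdir": "/workspace",           # optional working directory
        "env": {"PYTHONUNBUFFERED": "1"},  # optional environment variables
        "session_name": "default",         # optional session name
        "timeout": 30,                     # optional timeout in seconds
        "user": "root",                    # optional user name
    }
    body = json.dumps(args)  # what last_msg.content[0].text would contain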
if self.first_tok_of_message: - self._update_num_prompt_tokens(output) + self._update_prefill_token_usage(output) + self.current_turn.output_tokens = 0 # Reset self.first_tok_of_message if needed: # if the current token is the last one of the current message # (finished=True), then the next token processed will mark the # beginning of a new message self.first_tok_of_message = output.finished - tok = output.outputs[0].token_ids[0] - self.parser.process(tok) - self._update_num_output_tokens(output.outputs[0].token_ids) + for tok in output.outputs[0].token_ids: + self.parser.process(tok) + self._update_decode_token_usage(output) + + # For streaming, update previous turn when message is complete + if output.finished: + self.previous_turn = self.current_turn.copy() + # Check if the current token is part of reasoning content + self._update_num_reasoning_tokens() self.last_tok = tok + if len(self._messages) - self.num_init_messages < len( + self.parser.messages): + self._messages.extend( + self.parser.messages[len(self._messages) - + self.num_init_messages:]) else: # Handle the case of tool output in direct message format assert len(output) == 1, "Tool output should be a single message" @@ -225,6 +432,7 @@ class StreamingHarmonyContext(HarmonyContext): for tok in toks: self.parser.process(tok) self.last_tok = toks[-1] + # TODO: add tool_output messages to self._messages def is_expecting_start(self) -> bool: return self.parser.state == StreamState.EXPECT_START diff --git a/vllm/entrypoints/harmony_utils.py b/vllm/entrypoints/harmony_utils.py index 078d316844257..1364a41be950d 100644 --- a/vllm/entrypoints/harmony_utils.py +++ b/vllm/entrypoints/harmony_utils.py @@ -1,5 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + import datetime import json from collections.abc import Iterable, Sequence @@ -13,12 +16,15 @@ from openai.types.responses.response_function_web_search import ( from openai.types.responses.response_reasoning_item import ( Content as ResponseReasoningTextContent) from openai.types.responses.tool import Tool -from openai_harmony import (Author, Conversation, DeveloperContent, - HarmonyEncodingName, Message, ReasoningEffort, - Role, StreamableParser, SystemContent, TextContent, - ToolDescription, load_harmony_encoding) +from openai_harmony import (Author, ChannelConfig, Conversation, + DeveloperContent, HarmonyEncodingName, Message, + ReasoningEffort, Role, StreamableParser, + SystemContent, TextContent, ToolDescription, + load_harmony_encoding) -from vllm.entrypoints.openai.protocol import ResponseInputOutputItem +from vllm import envs +from vllm.entrypoints.openai.protocol import (ChatCompletionToolsParam, + ResponseInputOutputItem) from vllm.utils import random_uuid REASONING_EFFORT = { @@ -29,6 +35,20 @@ REASONING_EFFORT = { _harmony_encoding = None +# Builtin tools that should be included in the system message when +# they are available and requested by the user. +# Tool args are provided by MCP tool descriptions. Output +# of the tools are stringified. 
+BUILTIN_TOOLS = { + "web_search_preview", + "code_interpreter", + "container", +} + + +def has_custom_tools(tool_types: list[str]) -> bool: + return not set(tool_types).issubset(BUILTIN_TOOLS) + def get_encoding(): global _harmony_encoding @@ -44,10 +64,19 @@ def get_system_message( start_date: Optional[str] = None, browser_description: Optional[str] = None, python_description: Optional[str] = None, + container_description: Optional[str] = None, + instructions: Optional[str] = None, + with_custom_tools: bool = False, ) -> Message: sys_msg_content = SystemContent.new() if model_identity is not None: sys_msg_content = sys_msg_content.with_model_identity(model_identity) + if (instructions is not None + and envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS): + current_identity = sys_msg_content.model_identity + new_identity = (f'{current_identity}\n{instructions}' + if current_identity else instructions) + sys_msg_content = sys_msg_content.with_model_identity(new_identity) if reasoning_effort is not None: sys_msg_content = sys_msg_content.with_reasoning_effort( REASONING_EFFORT[reasoning_effort]) @@ -59,32 +88,55 @@ def get_system_message( sys_msg_content = sys_msg_content.with_tools(browser_description) if python_description is not None: sys_msg_content = sys_msg_content.with_tools(python_description) + if container_description is not None: + sys_msg_content = sys_msg_content.with_tools(container_description) + if not with_custom_tools: + channel_config = sys_msg_content.channel_config + invalid_channel = "commentary" + new_config = ChannelConfig.require_channels( + [c for c in channel_config.valid_channels if c != invalid_channel]) + sys_msg_content = sys_msg_content.with_channel_config(new_config) sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content) return sys_msg -def get_developer_message(instructions: Optional[str] = None, - tools: Optional[list[Tool]] = None) -> Message: +def create_tool_definition(tool: Union[ChatCompletionToolsParam, Tool]): + if isinstance(tool, ChatCompletionToolsParam): + return ToolDescription.new( + name=tool.function.name, + description=tool.function.description, + parameters=tool.function.parameters, + ) + return ToolDescription.new( + name=tool.name, + description=tool.description, + parameters=tool.parameters, + ) + + +def get_developer_message( + instructions: Optional[str] = None, + tools: Optional[list[Union[Tool, ChatCompletionToolsParam]]] = None, +) -> Message: dev_msg_content = DeveloperContent.new() - if instructions is not None: + if (instructions is not None + and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS): dev_msg_content = dev_msg_content.with_instructions(instructions) if tools is not None: - function_tools = [] + function_tools: list[Union[Tool, ChatCompletionToolsParam]] = [] for tool in tools: - if tool.type in ("web_search_preview", "code_interpreter"): + if tool.type in ("web_search_preview", "code_interpreter", + "container"): # These are built-in tools that are added to the system message. 
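Quick illustration (not part of this diff) of the `has_custom_tools` check introduced above, which controls whether the "commentary" channel is kept in the system message:

    BUILTIN_TOOLS = {"web_search_preview", "code_interpreter", "container"}

    def has_custom_tools(tool_types: list[str]) -> bool:
        return not set(tool_types).issubset(BUILTIN_TOOLS)

    # Only builtin tools requested -> the commentary channel can be dropped.
    assert has_custom_tools(["web_search_preview", "container"]) is False
    # A "function" tool counts as a custom tool -> commentary is kept.
    assert has_custom_tools(["function"]) is True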
pass + elif tool.type == "function": function_tools.append(tool) else: raise ValueError(f"tool type {tool.type} not supported") if function_tools: function_tool_descriptions = [ - ToolDescription.new( - name=tool.name, - description=tool.description, - parameters=tool.parameters, - ) for tool in function_tools + create_tool_definition(tool) for tool in function_tools ] dev_msg_content = dev_msg_content.with_function_tools( function_tool_descriptions) @@ -120,6 +172,8 @@ def parse_response_input( TextContent(text=text_prefix + c["text"]) for c in content ] msg = Message.from_role_and_contents(role, contents) + if role == "assistant": + msg = msg.with_channel("final") elif response_msg["type"] == "function_call_output": call_id = response_msg["call_id"] call_response: Optional[ResponseFunctionToolCall] = None @@ -148,16 +202,46 @@ def parse_response_input( return msg -def parse_chat_input(chat_msg) -> Message: - role = chat_msg["role"] - content = chat_msg["content"] +def parse_chat_input(chat_msg) -> list[Message]: + if not isinstance(chat_msg, dict): + # Handle Pydantic models + chat_msg = chat_msg.model_dump(exclude_none=True) + + role = chat_msg.get("role") + + # Assistant message with tool calls + tool_calls = chat_msg.get("tool_calls") + if role == "assistant" and tool_calls: + msgs: list[Message] = [] + for call in tool_calls: + func = call.get("function", {}) + name = func.get("name", "") + arguments = func.get("arguments", "") or "" + msg = Message.from_role_and_content(Role.ASSISTANT, arguments) + msg = msg.with_channel("commentary") + msg = msg.with_recipient(f"functions.{name}") + msg = msg.with_content_type("json") + msgs.append(msg) + return msgs + + # Tool role message (tool output) + if role == "tool": + name = chat_msg.get("name", "") + content = chat_msg.get("content", "") or "" + msg = Message.from_author_and_content( + Author.new(Role.TOOL, f"functions.{name}"), + content).with_channel("commentary") + return [msg] + + # Default: user/assistant/system messages with content + content = chat_msg.get("content", "") if isinstance(content, str): contents = [TextContent(text=content)] else: # TODO: Support refusal. 
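Example inputs (not part of this diff; the tool name and arguments are invented) for the tool-call handling that `parse_chat_input` gains above:

    # An assistant message carrying tool calls becomes one commentary-channel
    # harmony message per call, addressed to "functions.<name>" with JSON content.
    assistant_msg = {
        "role": "assistant",
        "tool_calls": [{
            "type": "function",
            "id": "chatcmpl-tool-123",
            "function": {"name": "get_weather",
                         "arguments": '{"city": "Paris"}'},
        }],
    }
    # The matching tool-role result maps to a TOOL-authored message from
    # "functions.<name>", also on the "commentary" channel.
    tool_msg = {"role": "tool", "name": "get_weather",
                "content": '{"temp_c": 21}'}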
contents = [TextContent(text=c.get("text", "")) for c in content] msg = Message.from_role_and_contents(role, contents) - return msg + return [msg] def render_for_completion(messages: list[Message]) -> list[int]: @@ -227,7 +311,7 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]: call_id=f"call_{random_id}", type="function_call", name=function_name, - id=f"ft_{random_id}", + id=f"fc_{random_id}", ) output_items.append(response_item) elif recipient is not None and (recipient.startswith("python") @@ -303,7 +387,9 @@ def parse_remaining_state( id=f"msg_{random_uuid()}", content=[output_text], role="assistant", - status="completed", + # if the parser still has messages (ie if the generator got cut + # abruptly), this should be incomplete + status="incomplete", type="message", ) return [text_item] diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 4e852ba594930..c3195dbc4697f 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -12,7 +12,6 @@ from fastapi import FastAPI, Request, Response from vllm import envs from vllm.engine.async_llm_engine import AsyncEngineDeadError -from vllm.engine.multiprocessing import MQEngineDeadError from vllm.engine.protocol import EngineClient from vllm.entrypoints.constants import (H11_MAX_HEADER_COUNT_DEFAULT, H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT) @@ -95,7 +94,7 @@ async def serve_http(app: FastAPI, port = uvicorn_kwargs["port"] process = find_process_using_port(port) if process is not None: - logger.debug( + logger.warning( "port %s is used by process %s launched with command:\n%s", port, process, " ".join(process.cmdline())) logger.info("Shutting down FastAPI HTTP server.") @@ -156,7 +155,6 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None: @app.exception_handler(RuntimeError) @app.exception_handler(AsyncEngineDeadError) - @app.exception_handler(MQEngineDeadError) @app.exception_handler(EngineDeadError) @app.exception_handler(EngineGenerateError) async def runtime_exception_handler(request: Request, __): diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 9b2ad808eb03e..df6b16c73d6e7 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -15,8 +15,8 @@ import vllm.envs as envs from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, BeamSearchSequence, create_sort_beams_key_function) -from vllm.config import (CompilationConfig, ModelDType, TokenizerMode, - is_init_field) +from vllm.config import (CompilationConfig, ModelDType, + StructuredOutputsConfig, TokenizerMode, is_init_field) from vllm.engine.arg_utils import (ConvertOption, EngineArgs, HfOverrides, PoolerConfig, RunnerOption) from vllm.engine.llm_engine import LLMEngine @@ -110,6 +110,14 @@ class LLM: values will increase the KV cache size and thus improve the model's throughput. However, if the value is too high, it may cause out-of- memory (OOM) errors. + kv_cache_memory_bytes: Size of KV Cache per GPU in bytes. By default, + this is set to None and vllm can automatically infer the kv cache + size based on gpu_memory_utilization. However, users may want to + manually specify the kv cache memory size. kv_cache_memory_bytes + allows more fine-grain control of how much memory gets used when + compared with using gpu_memory_memory_utilization. Note that + kv_cache_memory_bytes (when not-None) ignores + gpu_memory_utilization swap_space: The size (GiB) of CPU memory per GPU to use as swap space. 
This can be used for temporarily storing the states of the requests when their `best_of` sampling parameters are larger than 1. If all @@ -184,6 +192,9 @@ class LLM: hf_overrides: Optional[HfOverrides] = None, mm_processor_kwargs: Optional[dict[str, Any]] = None, override_pooler_config: Optional[PoolerConfig] = None, + structured_outputs_config: Optional[Union[dict[ + str, Any], StructuredOutputsConfig]] = None, + kv_cache_memory_bytes: Optional[int] = None, compilation_config: Optional[Union[int, dict[str, Any], CompilationConfig]] = None, logits_processors: Optional[list[Union[str, @@ -204,7 +215,7 @@ class LLM: if "kv_transfer_config" in kwargs and isinstance( kwargs["kv_transfer_config"], dict): - from vllm.config import KVTransferConfig + from vllm.config.kv_transfer import KVTransferConfig raw_config_dict = kwargs["kv_transfer_config"] try: kwargs["kv_transfer_config"] = KVTransferConfig( @@ -227,14 +238,30 @@ class LLM: compilation_config_instance = CompilationConfig( level=compilation_config) elif isinstance(compilation_config, dict): - predicate = lambda x: is_init_field(CompilationConfig, x[0]) compilation_config_instance = CompilationConfig( - **dict(filter(predicate, compilation_config.items()))) + **{ + k: v + for k, v in compilation_config.items() + if is_init_field(CompilationConfig, k) + }) else: compilation_config_instance = compilation_config else: compilation_config_instance = CompilationConfig() + if structured_outputs_config is not None: + if isinstance(structured_outputs_config, dict): + structured_outputs_instance = StructuredOutputsConfig( + **{ + k: v + for k, v in structured_outputs_config.items() + if is_init_field(StructuredOutputsConfig, k) + }) + else: + structured_outputs_instance = structured_outputs_config + else: + structured_outputs_instance = StructuredOutputsConfig() + engine_args = EngineArgs( model=model, runner=runner, @@ -251,6 +278,7 @@ class LLM: tokenizer_revision=tokenizer_revision, seed=seed, gpu_memory_utilization=gpu_memory_utilization, + kv_cache_memory_bytes=kv_cache_memory_bytes, swap_space=swap_space, cpu_offload_gb=cpu_offload_gb, enforce_eager=enforce_eager, @@ -261,6 +289,7 @@ class LLM: hf_overrides=hf_overrides, mm_processor_kwargs=mm_processor_kwargs, override_pooler_config=override_pooler_config, + structured_outputs_config=structured_outputs_instance, compilation_config=compilation_config_instance, logits_processors=logits_processors, **kwargs, @@ -291,23 +320,17 @@ class LLM: self.io_processor = get_io_processor(self.llm_engine.vllm_config, io_processor_plugin) - def get_tokenizer( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - return self.llm_engine.get_tokenizer_group().get_lora_tokenizer( - lora_request) + def get_tokenizer(self) -> AnyTokenizer: + return self.llm_engine.get_tokenizer() def set_tokenizer(self, tokenizer: AnyTokenizer) -> None: - tokenizer_group = self.llm_engine.get_tokenizer_group() - # While CachedTokenizer is dynamic, have no choice but # compare class name. 
Misjudgment will arise from # user-defined tokenizer started with 'Cached' if tokenizer.__class__.__name__.startswith("Cached"): - tokenizer_group.tokenizer = tokenizer + self.llm_engine.tokenizer = tokenizer else: - tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer) + self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer) def get_default_sampling_params(self) -> SamplingParams: if self.default_sampling_params is None: @@ -693,6 +716,105 @@ class LLM: return outputs + def preprocess_chat( + self, + messages: Union[list[ChatCompletionMessageParam], + list[list[ChatCompletionMessageParam]]], + chat_template: Optional[str] = None, + chat_template_content_format: ChatTemplateContentFormatOption = "auto", + add_generation_prompt: bool = True, + continue_final_message: bool = False, + tools: Optional[list[dict[str, Any]]] = None, + chat_template_kwargs: Optional[dict[str, Any]] = None, + mm_processor_kwargs: Optional[dict[str, Any]] = None, + ) -> list[TokensPrompt]: + """ + Generate prompt for a chat conversation. The pre-processed + prompt can then be used as input for the other LLM methods. + + Refer to `chat` for a complete description of the arguments. + Returns: + A list of `TokensPrompts` objects containing the tokenized + prompt after chat template interpolation, and the + pre-processed multi-modal inputs. + """ + list_of_messages: list[list[ChatCompletionMessageParam]] + + # Handle multi and single conversations + if is_list_of(messages, list): + # messages is list[list[...]] + list_of_messages = cast(list[list[ChatCompletionMessageParam]], + messages) + else: + # messages is list[...] + list_of_messages = [ + cast(list[ChatCompletionMessageParam], messages) + ] + + tokenizer = self.get_tokenizer() + model_config = self.llm_engine.get_model_config() + resolved_content_format = resolve_chat_template_content_format( + chat_template, + tools, + chat_template_content_format, + tokenizer, + model_config=model_config, + ) + + _chat_template_kwargs: dict[str, Any] = dict( + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + tools=tools, + ) + _chat_template_kwargs.update(chat_template_kwargs or {}) + + prompts: list[TokensPrompt] = [] + + for msgs in list_of_messages: + # NOTE: _parse_chat_message_content_parts() currently doesn't + # handle mm_processor_kwargs, since there is no implementation in + # the chat message parsing for it. + conversation, mm_data, mm_uuids = parse_chat_messages( + msgs, + model_config, + tokenizer, + content_format=resolved_content_format, + ) + + if isinstance(tokenizer, MistralTokenizer): + prompt_token_ids = apply_mistral_chat_template( + tokenizer, + messages=msgs, + **_chat_template_kwargs, + ) + else: + prompt_str = apply_hf_chat_template( + tokenizer=tokenizer, + conversation=conversation, + model_config=model_config, + **_chat_template_kwargs, + ) + # Special tokens are already included in chat templates so + # should not be added by the tokenizer in this case. 
+ prompt_token_ids = tokenizer.encode(prompt_str, + add_special_tokens=False) + + prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) + + if mm_data is not None: + prompt["multi_modal_data"] = mm_data + + if mm_uuids is not None: + prompt["multi_modal_uuids"] = mm_uuids + + if mm_processor_kwargs is not None: + prompt["mm_processor_kwargs"] = mm_processor_kwargs + + prompts.append(prompt) + + return prompts + def chat( self, messages: Union[list[ChatCompletionMessageParam], @@ -759,77 +881,17 @@ class LLM: A list of `RequestOutput` objects containing the generated responses in the same order as the input messages. """ - list_of_messages: list[list[ChatCompletionMessageParam]] - # Handle multi and single conversations - if is_list_of(messages, list): - # messages is list[list[...]] - list_of_messages = cast(list[list[ChatCompletionMessageParam]], - messages) - else: - # messages is list[...] - list_of_messages = [ - cast(list[ChatCompletionMessageParam], messages) - ] - - tokenizer = self.get_tokenizer(lora_request) - model_config = self.llm_engine.get_model_config() - resolved_content_format = resolve_chat_template_content_format( - chat_template, - tools, - chat_template_content_format, - tokenizer, - model_config=model_config, - ) - - _chat_template_kwargs: dict[str, Any] = dict( + prompts = self.preprocess_chat( + messages=messages, chat_template=chat_template, + chat_template_content_format=chat_template_content_format, add_generation_prompt=add_generation_prompt, continue_final_message=continue_final_message, tools=tools, + chat_template_kwargs=chat_template_kwargs, + mm_processor_kwargs=mm_processor_kwargs, ) - _chat_template_kwargs.update(chat_template_kwargs or {}) - - prompts: list[Union[TokensPrompt, TextPrompt]] = [] - - for msgs in list_of_messages: - # NOTE: _parse_chat_message_content_parts() currently doesn't - # handle mm_processor_kwargs, since there is no implementation in - # the chat message parsing for it. - conversation, mm_data = parse_chat_messages( - msgs, - model_config, - tokenizer, - content_format=resolved_content_format, - ) - - if isinstance(tokenizer, MistralTokenizer): - prompt_token_ids = apply_mistral_chat_template( - tokenizer, - messages=msgs, - **_chat_template_kwargs, - ) - else: - prompt_str = apply_hf_chat_template( - tokenizer=tokenizer, - conversation=conversation, - model_config=model_config, - **_chat_template_kwargs, - ) - # Special tokens are already included in chat templates so - # should not be added by the tokenizer in this case. - prompt_token_ids = tokenizer.encode(prompt_str, - add_special_tokens=False) - - prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) - - if mm_data is not None: - prompt["multi_modal_data"] = mm_data - - if mm_processor_kwargs is not None: - prompt["mm_processor_kwargs"] = mm_processor_kwargs - - prompts.append(prompt) return self.generate( prompts, @@ -881,6 +943,10 @@ class LLM: considered legacy and may be deprecated in the future. You should instead pass them via the `inputs` parameter. 
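Usage sketch (not part of this diff; the model name and values are placeholders) of the new `LLM.preprocess_chat()` helper together with the constructor arguments added above (`kv_cache_memory_bytes`, `structured_outputs_config`):

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="some/model",                 # placeholder
        kv_cache_memory_bytes=8 * 1024**3,  # pin the KV cache to 8 GiB
        structured_outputs_config={"reasoning_parser": "deepseek_r1"},  # example
    )
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Write a haiku about GPUs."},
    ]
    # Apply the chat template and tokenize once...
    prompts = llm.preprocess_chat(messages)
    # ...then reuse the TokensPrompt objects with generate(), which is what
    # chat() now does internally.
    outputs = llm.generate(prompts, SamplingParams(max_tokens=64))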
""" + + if self.supported_tasks == ["encode"] and pooling_task is None: + pooling_task = "encode" + if pooling_task is None: if "embed" in self.supported_tasks: pooling_task = "embed" @@ -1440,6 +1506,11 @@ class LLM: for i, prompt in enumerate(it): + if isinstance(prompt, dict): + self._validate_mm_data_and_uuids( + prompt.get("multi_modal_data"), + prompt.get("multi_modal_uuids")) + param = params[i] if isinstance(params, Sequence) else params tokenization_kwargs: dict[str, Any] = {} @@ -1456,6 +1527,41 @@ class LLM: priority=priority[i] if priority else 0, ) + def _validate_mm_data_and_uuids( + self, + multi_modal_data: Optional[Any], # MultiModalDataDict + multi_modal_uuids: Optional[Any], # MultiModalUUIDDict + ): + """ + Validate that if any multi-modal data is skipped (i.e. None), + then its corresponding UUID must be set. + """ + if multi_modal_data is None: + return + + for modality, data in multi_modal_data.items(): + if isinstance(data, list): + for i, d in enumerate(data): + if d is None: + if multi_modal_uuids is None or modality not in multi_modal_uuids or multi_modal_uuids[ # noqa: E501 + modality] is None: + raise ValueError( + f"Multi-modal data for {modality} is None " + f"but UUID is not provided") + else: + if len( + multi_modal_uuids[modality] + ) <= i or multi_modal_uuids[modality][i] is None: + raise ValueError( + f"Multi-modal data for {modality} is None " + f"but UUID is not provided") + else: + if data is None and (multi_modal_uuids is None + or modality not in multi_modal_uuids + or multi_modal_uuids[modality] is None): + raise ValueError(f"Multi-modal data for {modality} is None" + f" but UUID is not provided") + def _add_request( self, prompt: PromptType, diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 3cebfdf885bec..912e664120929 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio -import atexit import gc import importlib import inspect @@ -15,9 +14,8 @@ import socket import tempfile import uuid from argparse import Namespace -from collections.abc import AsyncIterator, Awaitable +from collections.abc import AsyncGenerator, AsyncIterator, Awaitable from contextlib import asynccontextmanager -from functools import partial from http import HTTPStatus from typing import Annotated, Any, Callable, Optional @@ -41,8 +39,6 @@ import vllm.envs as envs from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine # type: ignore -from vllm.engine.multiprocessing.client import MQLLMEngineClient -from vllm.engine.multiprocessing.engine import run_mp_engine from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import (load_chat_template, resolve_hf_chat_template, @@ -70,7 +66,9 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, RerankRequest, RerankResponse, ResponsesRequest, ResponsesResponse, ScoreRequest, - ScoreResponse, TokenizeRequest, + ScoreResponse, + StreamingResponsesResponse, + TokenizeRequest, TokenizeResponse, TranscriptionRequest, TranscriptionResponse, @@ -101,13 +99,11 @@ from vllm.entrypoints.utils import (cli_env_setup, load_aware_call, log_non_default_args, with_cancellation) from vllm.logger import init_logger from vllm.reasoning import ReasoningParserManager -from vllm.transformers_utils.config import ( - 
maybe_register_config_serialize_by_value) from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.usage.usage_lib import UsageContext from vllm.utils import (Device, FlexibleArgumentParser, decorate_logs, - get_open_zmq_ipc_path, is_valid_ipv6_address, - set_ulimit) + is_valid_ipv6_address, set_ulimit) +from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.metrics.prometheus import get_prometheus_registry from vllm.version import __version__ as VLLM_VERSION @@ -236,8 +232,7 @@ async def build_async_engine_client_from_engine_args( async_llm.shutdown() # V0 AsyncLLM. - elif (MQLLMEngineClient.is_unsupported_config(vllm_config) - or disable_frontend_multiprocessing): + else: engine_client: Optional[EngineClient] = None try: @@ -251,96 +246,6 @@ async def build_async_engine_client_from_engine_args( if engine_client and hasattr(engine_client, "shutdown"): engine_client.shutdown() - # V0MQLLMEngine. - else: - if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: - # Make TemporaryDirectory for prometheus multiprocessing - # Note: global TemporaryDirectory will be automatically - # cleaned up upon exit. - global prometheus_multiproc_dir - prometheus_multiproc_dir = tempfile.TemporaryDirectory() - os.environ[ - "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name - else: - logger.warning( - "Found PROMETHEUS_MULTIPROC_DIR was set by user. " - "This directory must be wiped between vLLM runs or " - "you will find inaccurate metrics. Unset the variable " - "and vLLM will properly handle cleanup.") - - # Select random path for IPC. - ipc_path = get_open_zmq_ipc_path() - logger.debug("Multiprocessing frontend to use %s for IPC Path.", - ipc_path) - - # Start RPCServer in separate process (holds the LLMEngine). - # the current process might have CUDA context, - # so we need to spawn a new process - context = multiprocessing.get_context("spawn") - - # Ensure we can serialize transformer config before spawning - maybe_register_config_serialize_by_value() - - # The Process can raise an exception during startup, which may - # not actually result in an exitcode being reported. As a result - # we use a shared variable to communicate the information. - engine_alive = multiprocessing.Value('b', True, lock=False) - engine_process = context.Process( - target=run_mp_engine, - args=(vllm_config, UsageContext.OPENAI_API_SERVER, ipc_path, - engine_args.disable_log_stats, - engine_args.enable_log_requests, engine_alive)) - engine_process.start() - engine_pid = engine_process.pid - assert engine_pid is not None, "Engine process failed to start." - logger.info("Started engine process with PID %d", engine_pid) - - def _cleanup_ipc_path(): - socket_path = ipc_path.replace("ipc://", "") - if os.path.exists(socket_path): - os.remove(socket_path) - - # Ensure we clean up the local IPC socket file on exit. - atexit.register(_cleanup_ipc_path) - - # Build RPCClient, which conforms to EngineClient Protocol. - build_client = partial(MQLLMEngineClient, ipc_path, vllm_config, - engine_pid) - mq_engine_client = await asyncio.get_running_loop().run_in_executor( - None, build_client) - try: - while True: - try: - await mq_engine_client.setup() - break - except TimeoutError: - if (not engine_process.is_alive() - or not engine_alive.value): - raise RuntimeError( - "Engine process failed to start. 
See stack " - "trace for the root cause.") from None - - yield mq_engine_client # type: ignore[misc] - finally: - # Ensure rpc server process was terminated - engine_process.terminate() - - # Close all open connections to the backend - mq_engine_client.close() - - # Wait for engine process to join - engine_process.join(4) - if engine_process.exitcode is None: - # Kill if taking longer than 5 seconds to stop - engine_process.kill() - - # Lazy import for prometheus multiprocessing. - # We need to set PROMETHEUS_MULTIPROC_DIR environment variable - # before prometheus_client is imported. - # See https://prometheus.github.io/client_python/multiprocess/ - from prometheus_client import multiprocess - multiprocess.mark_process_dead(engine_process.pid) - async def validate_json_request(raw_request: Request): content_type = raw_request.headers.get("content-type", "").lower() @@ -447,8 +352,11 @@ def engine_client(request: Request) -> EngineClient: @router.get("/health", response_class=Response) async def health(raw_request: Request) -> Response: """Health check.""" - await engine_client(raw_request).check_health() - return Response(status_code=200) + try: + await engine_client(raw_request).check_health() + return Response(status_code=200) + except EngineDeadError: + return Response(status_code=503) @router.get("/load") @@ -577,6 +485,18 @@ async def show_version(): return JSONResponse(content=ver) +async def _convert_stream_to_sse_events( + generator: AsyncGenerator[StreamingResponsesResponse, None] +) -> AsyncGenerator[str, None]: + """Convert the generator to a stream of events in SSE format""" + async for event in generator: + event_type = getattr(event, 'type', 'unknown') + # https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format + event_data = (f"event: {event_type}\n" + f"data: {event.model_dump_json(indent=None)}\n\n") + yield event_data + + @router.post("/v1/responses", dependencies=[Depends(validate_json_request)], responses={ @@ -612,18 +532,29 @@ async def create_responses(request: ResponsesRequest, raw_request: Request): status_code=generator.error.code) elif isinstance(generator, ResponsesResponse): return JSONResponse(content=generator.model_dump()) - return StreamingResponse(content=generator, media_type="text/event-stream") + + return StreamingResponse(content=_convert_stream_to_sse_events(generator), + media_type="text/event-stream") @router.get("/v1/responses/{response_id}") -async def retrieve_responses(response_id: str, raw_request: Request): +async def retrieve_responses( + response_id: str, + raw_request: Request, + starting_after: Optional[int] = None, + stream: Optional[bool] = False, +): handler = responses(raw_request) if handler is None: return base(raw_request).create_error_response( message="The model does not support Responses API") try: - response = await handler.retrieve_responses(response_id) + response = await handler.retrieve_responses( + response_id, + starting_after=starting_after, + stream=stream, + ) except Exception as e: raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)) from e @@ -631,7 +562,10 @@ async def retrieve_responses(response_id: str, raw_request: Request): if isinstance(response, ErrorResponse): return JSONResponse(content=response.model_dump(), status_code=response.error.code) - return JSONResponse(content=response.model_dump()) + elif isinstance(response, ResponsesResponse): + return JSONResponse(content=response.model_dump()) + return 
StreamingResponse(content=_convert_stream_to_sse_events(response), + media_type="text/event-stream") @router.post("/v1/responses/{response_id}/cancel") @@ -1705,6 +1639,8 @@ async def init_app_state( if args.tool_server == "demo": tool_server: Optional[ToolServer] = DemoToolServer() + assert isinstance(tool_server, DemoToolServer) + await tool_server.init_and_validate() elif args.tool_server: tool_server = MCPToolServer() await tool_server.add_tool_server(args.tool_server) @@ -1746,7 +1682,7 @@ async def init_app_state( enable_auto_tools=args.enable_auto_tool_choice, tool_parser=args.tool_call_parser, tool_server=tool_server, - reasoning_parser=args.reasoning_parser, + reasoning_parser=args.structured_outputs_config.reasoning_parser, enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_force_include_usage=args.enable_force_include_usage, enable_log_outputs=args.enable_log_outputs, @@ -1765,7 +1701,7 @@ async def init_app_state( exclude_tools_when_tool_choice_none=args. exclude_tools_when_tool_choice_none, tool_parser=args.tool_call_parser, - reasoning_parser=args.reasoning_parser, + reasoning_parser=args.structured_outputs_config.reasoning_parser, enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_force_include_usage=args.enable_force_include_usage, enable_log_outputs=args.enable_log_outputs, @@ -1868,10 +1804,10 @@ def validate_api_server_args(args): f"(chose from {{ {','.join(valid_tool_parses)} }})") valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys() - if args.reasoning_parser \ - and args.reasoning_parser not in valid_reasoning_parses: + if ((reasoning_parser := args.structured_outputs_config.reasoning_parser) + and reasoning_parser not in valid_reasoning_parses): raise KeyError( - f"invalid reasoning parser: {args.reasoning_parser} " + f"invalid reasoning parser: {reasoning_parser} " f"(chose from {{ {','.join(valid_reasoning_parses)} }})") diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index d0b5d013eb9e5..1c2a6f58197d8 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -134,14 +134,13 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`""" """If specified, will run the OpenAI frontend server in the same process as the model serving engine.""" enable_request_id_headers: bool = False - """If specified, API server will add X-Request-Id header to responses. - Caution: this hurts performance at high QPS.""" + """If specified, API server will add X-Request-Id header to responses.""" enable_auto_tool_choice: bool = False - """If specified, exclude tool definitions in prompts when - tool_choice='none'.""" - exclude_tools_when_tool_choice_none: bool = False """Enable auto tool choice for supported models. Use `--tool-call-parser` to specify which parser to use.""" + exclude_tools_when_tool_choice_none: bool = False + """If specified, exclude tool definitions in prompts when + tool_choice='none'.""" tool_call_parser: Optional[str] = None """Select the tool call parser depending on the model that you're using. This is used to parse the model-generated tool call into OpenAI API format. @@ -172,8 +171,8 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`""" """Enable the /get_tokenizer_info endpoint. 
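Client-side sketch (not part of this diff; URL, response id, and the use of httpx are assumptions) for the stored-response retrieval added above: `GET /v1/responses/{id}` with `stream=true` replays events, optionally after a given sequence number, in the SSE format produced by `_convert_stream_to_sse_events`:

    import json

    import httpx  # any HTTP client works; httpx shown for illustration

    with httpx.stream(
            "GET",
            "http://localhost:8000/v1/responses/resp_abc123",  # placeholder id
            params={"stream": True, "starting_after": 5},
    ) as resp:
        for line in resp.iter_lines():
            # Each event arrives as "event: <type>" followed by "data: <json>".
            if line.startswith("data: "):
                event = json.loads(line[len("data: "):])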
May expose chat templates and other tokenizer configuration.""" enable_log_outputs: bool = False - """If set to True, enable logging of model outputs (generations) - in addition to the input logging that is enabled by default.""" + """If True, log model outputs (generations). + Requires --enable-log-requests.""" h11_max_incomplete_event_size: int = H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT """Maximum size (bytes) of an incomplete HTTP event (header or body) for h11 parser. Helps mitigate header abuse. Default: 4194304 (4 MB).""" @@ -204,7 +203,7 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`""" frontend_kwargs["lora_modules"]["type"] = optional_type(str) frontend_kwargs["lora_modules"]["action"] = LoRAParserAction - # Special case: Middleware needs append action + # Special case: Middleware needs to append action frontend_kwargs["middleware"]["action"] = "append" frontend_kwargs["middleware"]["type"] = str if "nargs" in frontend_kwargs["middleware"]: @@ -274,6 +273,9 @@ def validate_parsed_serve_args(args: argparse.Namespace): if args.enable_auto_tool_choice and not args.tool_call_parser: raise TypeError("Error: --enable-auto-tool-choice requires " "--tool-call-parser") + if args.enable_log_outputs and not args.enable_log_requests: + raise TypeError("Error: --enable-log-outputs requires " + "--enable-log-requests") def create_parser_for_docs() -> FlexibleArgumentParser: diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 4881022325625..7ad8e73d89d59 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -18,10 +18,19 @@ from openai.types.chat.chat_completion_audio import ( from openai.types.chat.chat_completion_message import ( Annotation as OpenAIAnnotation) # yapf: enable -from openai.types.responses import (ResponseFunctionToolCall, - ResponseInputItemParam, ResponseOutputItem, - ResponsePrompt, ResponseReasoningItem, - ResponseStatus) +from openai.types.responses import ( + ResponseCodeInterpreterCallCodeDeltaEvent, + ResponseCodeInterpreterCallCodeDoneEvent, + ResponseCodeInterpreterCallCompletedEvent, + ResponseCodeInterpreterCallInProgressEvent, + ResponseCodeInterpreterCallInterpretingEvent, ResponseCompletedEvent, + ResponseContentPartAddedEvent, ResponseContentPartDoneEvent, + ResponseCreatedEvent, ResponseFunctionToolCall, ResponseInProgressEvent, + ResponseInputItemParam, ResponseOutputItem, ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, ResponsePrompt, ResponseReasoningItem, + ResponseReasoningTextDeltaEvent, ResponseReasoningTextDoneEvent, + ResponseStatus, ResponseWebSearchCallCompletedEvent, + ResponseWebSearchCallInProgressEvent, ResponseWebSearchCallSearchingEvent) # Backward compatibility for OpenAI client versions try: # For older openai versions (< 1.100.0) @@ -30,7 +39,7 @@ except ImportError: # For newer openai versions (>= 1.100.0) from openai.types.responses import (ResponseFormatTextConfig as ResponseTextConfig) -from openai.types.responses.response import ToolChoice +from openai.types.responses.response import IncompleteDetails, ToolChoice from openai.types.responses.tool import Tool from openai.types.shared import Metadata, Reasoning from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter, @@ -45,8 +54,8 @@ from vllm.entrypoints.score_utils import (ScoreContentPartParam, from vllm.logger import init_logger from vllm.logprobs import Logprob from vllm.pooling_params import PoolingParams -from vllm.sampling_params import (BeamSearchParams, 
GuidedDecodingParams, - RequestOutputKind, SamplingParams) +from vllm.sampling_params import (BeamSearchParams, RequestOutputKind, + SamplingParams, StructuredOutputsParams) from vllm.utils import random_uuid, resolve_obj_by_qualname logger = init_logger(__name__) @@ -242,7 +251,7 @@ def get_logits_processors(processors: Optional[LogitsProcessors], elif processors: raise ValueError( "The `logits_processors` argument is not supported by this " - "server. See --logits-processor-pattern engine argugment " + "server. See --logits-processor-pattern engine argument " "for more information.") return None @@ -251,6 +260,26 @@ ResponseInputOutputItem: TypeAlias = Union[ResponseInputItemParam, ResponseReasoningItem, ResponseFunctionToolCall] +StreamingResponsesResponse: TypeAlias = Union[ + ResponseCreatedEvent, + ResponseInProgressEvent, + ResponseCompletedEvent, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, + ResponseContentPartAddedEvent, + ResponseContentPartDoneEvent, + ResponseReasoningTextDeltaEvent, + ResponseReasoningTextDoneEvent, + ResponseCodeInterpreterCallInProgressEvent, + ResponseCodeInterpreterCallCodeDeltaEvent, + ResponseWebSearchCallInProgressEvent, + ResponseWebSearchCallSearchingEvent, + ResponseWebSearchCallCompletedEvent, + ResponseCodeInterpreterCallCodeDoneEvent, + ResponseCodeInterpreterCallInterpretingEvent, + ResponseCodeInterpreterCallCompletedEvent, +] + class ResponsesRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation @@ -344,11 +373,12 @@ class ResponsesRequest(OpenAIBaseModel): stop_token_ids = default_sampling_params.get("stop_token_ids") # Structured output - guided_decoding = None + structured_outputs = None if self.text is not None and self.text.format is not None: response_format = self.text.format - if response_format.type == "json_schema": - guided_decoding = GuidedDecodingParams.from_optional( + if (response_format.type == "json_schema" + and response_format.schema_ is not None): + structured_outputs = StructuredOutputsParams( json=response_format.schema_) elif response_format.type == "json_object": raise NotImplementedError("json_object is not supported") @@ -363,7 +393,7 @@ class ResponsesRequest(OpenAIBaseModel): stop_token_ids=stop_token_ids, output_kind=(RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY), - guided_decoding=guided_decoding, + structured_outputs=structured_outputs, ) def is_include_output_logprobs(self) -> bool: @@ -518,42 +548,9 @@ class ChatCompletionRequest(OpenAIBaseModel): default=None, description=("Additional kwargs to pass to the HF processor."), ) - guided_json: Optional[Union[str, dict, BaseModel]] = Field( + structured_outputs: Optional[StructuredOutputsParams] = Field( default=None, - description=("If specified, the output will follow the JSON schema."), - ) - guided_regex: Optional[str] = Field( - default=None, - description=( - "If specified, the output will follow the regex pattern."), - ) - guided_choice: Optional[list[str]] = Field( - default=None, - description=( - "If specified, the output will be exactly one of the choices."), - ) - guided_grammar: Optional[str] = Field( - default=None, - description=( - "If specified, the output will follow the context free grammar."), - ) - structural_tag: Optional[str] = Field( - default=None, - description=( - "If specified, the output will follow the structural tag schema."), - ) - guided_decoding_backend: Optional[str] = Field( - default=None, - description=( - "If specified, will override the default guided decoding 
backend " - "of the server for this specific request. If set, must be either " - "'outlines' / 'lm-format-enforcer'"), - ) - guided_whitespace_pattern: Optional[str] = Field( - default=None, - description=( - "If specified, will override the default whitespace pattern " - "for guided json decoding."), + description="Additional kwargs for structured outputs", ) priority: int = Field( default=0, @@ -672,31 +669,33 @@ class ChatCompletionRequest(OpenAIBaseModel): if prompt_logprobs is None and self.echo: prompt_logprobs = self.top_logprobs - guided_json_object = None - if self.response_format is not None: - if self.response_format.type == "json_object": - guided_json_object = True - elif self.response_format.type == "json_schema": - json_schema = self.response_format.json_schema - assert json_schema is not None - self.guided_json = json_schema.json_schema - elif self.response_format.type == "structural_tag": - structural_tag = self.response_format - assert structural_tag is not None and isinstance( - structural_tag, StructuralTagResponseFormat) - s_tag_obj = structural_tag.model_dump(by_alias=True) - self.structural_tag = json.dumps(s_tag_obj) + response_format = self.response_format + json_schema_from_tool = self._get_json_schema_from_tool() + if response_format is not None or json_schema_from_tool is not None: + # If structured outputs wasn't already enabled, + # we must enable it for these features to work + if self.structured_outputs is None: + self.structured_outputs = StructuredOutputsParams() - guided_decoding = GuidedDecodingParams.from_optional( - json=self._get_guided_json_from_tool() or self.guided_json, - regex=self.guided_regex, - choice=self.guided_choice, - grammar=self.guided_grammar, - json_object=guided_json_object, - backend=self.guided_decoding_backend, - whitespace_pattern=self.guided_whitespace_pattern, - structural_tag=self.structural_tag, - ) + # Set structured output params for response format + if response_format is not None: + if response_format.type == "json_object": + self.structured_outputs.json_object = True + elif response_format.type == "json_schema": + json_schema = response_format.json_schema + assert json_schema is not None + self.structured_outputs.json = json_schema.json_schema + elif response_format.type == "structural_tag": + structural_tag = response_format + assert structural_tag is not None and isinstance( + structural_tag, StructuralTagResponseFormat) + s_tag_obj = structural_tag.model_dump(by_alias=True) + self.structured_outputs.structural_tag = json.dumps( + s_tag_obj) + + # Set structured output params for tool calling + if json_schema_from_tool is not None: + self.structured_outputs.json = json_schema_from_tool extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {} if self.kv_transfer_params: @@ -728,15 +727,14 @@ class ChatCompletionRequest(OpenAIBaseModel): truncate_prompt_tokens=self.truncate_prompt_tokens, output_kind=RequestOutputKind.DELTA if self.stream \ else RequestOutputKind.FINAL_ONLY, - guided_decoding=guided_decoding, + structured_outputs=self.structured_outputs, logit_bias=self.logit_bias, - bad_words= self.bad_words, + bad_words=self.bad_words, allowed_token_ids=self.allowed_token_ids, extra_args=extra_args or None, ) - def _get_guided_json_from_tool( - self) -> Optional[Union[str, dict, BaseModel]]: + def _get_json_schema_from_tool(self) -> Optional[Union[str, dict]]: # user has chosen to not use any tool if self.tool_choice == "none" or self.tools is None: return None @@ -822,18 +820,24 @@ class 
ChatCompletionRequest(OpenAIBaseModel): @classmethod def check_logprobs(cls, data): if (prompt_logprobs := data.get("prompt_logprobs")) is not None: - if data.get("stream") and prompt_logprobs > 0: + if data.get("stream") and (prompt_logprobs > 0 + or prompt_logprobs == -1): raise ValueError( "`prompt_logprobs` are not available when `stream=True`.") - if prompt_logprobs < 0: - raise ValueError("`prompt_logprobs` must be a positive value.") - + if prompt_logprobs < 0 and prompt_logprobs != -1: + raise ValueError( + "`prompt_logprobs` must be a positive value or -1.") + if prompt_logprobs == -1 and not envs.VLLM_USE_V1: + raise ValueError("`prompt_logprobs=-1` is only supported with " + "vLLM engine V1.") if (top_logprobs := data.get("top_logprobs")) is not None: - if top_logprobs < 0: - raise ValueError("`top_logprobs` must be a positive value.") + if top_logprobs < 0 and top_logprobs != -1: + raise ValueError( + "`top_logprobs` must be a positive value or -1.") - if top_logprobs > 0 and not data.get("logprobs"): + if (top_logprobs == -1 + or top_logprobs > 0) and not data.get("logprobs"): raise ValueError( "when using `top_logprobs`, `logprobs` must be set to true." ) @@ -842,28 +846,31 @@ class ChatCompletionRequest(OpenAIBaseModel): @model_validator(mode="before") @classmethod - def check_guided_decoding_count(cls, data): + def check_structured_outputs_count(cls, data): if isinstance(data, ValueError): raise data - guide_count = sum([ - "guided_json" in data and data["guided_json"] is not None, - "guided_regex" in data and data["guided_regex"] is not None, - "guided_choice" in data and data["guided_choice"] is not None - ]) - # you can only use one kind of guided decoding - if guide_count > 1: + if "structured_outputs" not in data: + return data + + structured_outputs_kwargs = data['structured_outputs'] + count = sum( + structured_outputs_kwargs.get(k) is not None + for k in ("json", "regex", "choice")) + # you can only use one kind of constraints for structured outputs + if count > 1: raise ValueError( - "You can only use one kind of guided decoding " - "('guided_json', 'guided_regex' or 'guided_choice').") - # you can only either use guided decoding or tools, not both - if guide_count > 1 and data.get("tool_choice", "none") not in ( + "You can only use one kind of constraints for structured " + "outputs ('json', 'regex' or 'choice').") + # you can only either use structured outputs or tools, not both + if count > 1 and data.get("tool_choice", "none") not in ( "none", "auto", "required", ): raise ValueError( - "You can only either use guided decoding or tools, not both.") + "You can only either use constraints for structured outputs " + "or tools, not both.") return data @model_validator(mode="before") @@ -1016,37 +1023,9 @@ class CompletionRequest(OpenAIBaseModel): ", {'type': 'structural_tag'}, or {'type': 'text' } is supported." 
), ) - guided_json: Optional[Union[str, dict, BaseModel]] = Field( + structured_outputs: Optional[StructuredOutputsParams] = Field( default=None, - description="If specified, the output will follow the JSON schema.", - ) - guided_regex: Optional[str] = Field( - default=None, - description=( - "If specified, the output will follow the regex pattern."), - ) - guided_choice: Optional[list[str]] = Field( - default=None, - description=( - "If specified, the output will be exactly one of the choices."), - ) - guided_grammar: Optional[str] = Field( - default=None, - description=( - "If specified, the output will follow the context free grammar."), - ) - guided_decoding_backend: Optional[str] = Field( - default=None, - description=( - "If specified, will override the default guided decoding backend " - "of the server for this specific request. If set, must be one of " - "'outlines' / 'lm-format-enforcer'"), - ) - guided_whitespace_pattern: Optional[str] = Field( - default=None, - description=( - "If specified, will override the default whitespace pattern " - "for guided json decoding."), + description="Additional kwargs for structured outputs", ) priority: int = Field( default=0, @@ -1177,20 +1156,10 @@ class CompletionRequest(OpenAIBaseModel): echo_without_generation = self.echo and self.max_tokens == 0 - guided_json_object = None - if (self.response_format is not None + if (self.structured_outputs is not None + and self.response_format is not None and self.response_format.type == "json_object"): - guided_json_object = True - - guided_decoding = GuidedDecodingParams.from_optional( - json=self.guided_json, - regex=self.guided_regex, - choice=self.guided_choice, - grammar=self.guided_grammar, - json_object=guided_json_object, - backend=self.guided_decoding_backend, - whitespace_pattern=self.guided_whitespace_pattern, - ) + self.structured_outputs.json_object = True extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {} if self.kv_transfer_params: @@ -1222,7 +1191,7 @@ class CompletionRequest(OpenAIBaseModel): truncate_prompt_tokens=self.truncate_prompt_tokens, output_kind=RequestOutputKind.DELTA if self.stream \ else RequestOutputKind.FINAL_ONLY, - guided_decoding=guided_decoding, + structured_outputs=self.structured_outputs, logit_bias=self.logit_bias, allowed_token_ids=self.allowed_token_ids, extra_args=extra_args or None, @@ -1230,29 +1199,35 @@ class CompletionRequest(OpenAIBaseModel): @model_validator(mode="before") @classmethod - def check_guided_decoding_count(cls, data): - guide_count = sum([ - "guided_json" in data and data["guided_json"] is not None, - "guided_regex" in data and data["guided_regex"] is not None, - "guided_choice" in data and data["guided_choice"] is not None - ]) - if guide_count > 1: + def check_structured_outputs_count(cls, data): + if "structured_outputs" not in data: + return data + + structured_outputs_kwargs = data['structured_outputs'] + count = sum( + structured_outputs_kwargs.get(k) is not None + for k in ("json", "regex", "choice")) + if count > 1: raise ValueError( - "You can only use one kind of guided decoding " - "('guided_json', 'guided_regex' or 'guided_choice').") + "You can only use one kind of constraints for structured " + "outputs ('json', 'regex' or 'choice').") return data @model_validator(mode="before") @classmethod def check_logprobs(cls, data): if (prompt_logprobs := data.get("prompt_logprobs")) is not None: - if data.get("stream") and prompt_logprobs > 0: + if data.get("stream") and (prompt_logprobs > 0 + or prompt_logprobs == 
-1): raise ValueError( "`prompt_logprobs` are not available when `stream=True`.") - if prompt_logprobs < 0: - raise ValueError("`prompt_logprobs` must be a positive value.") - + if prompt_logprobs < 0 and prompt_logprobs != -1: + raise ValueError( + "`prompt_logprobs` must be a positive value or -1.") + if prompt_logprobs == -1 and not envs.VLLM_USE_V1: + raise ValueError("`prompt_logprobs=-1` is only supported with " + "vLLM engine V1.") if (logprobs := data.get("logprobs")) is not None and logprobs < 0: raise ValueError("`logprobs` must be a positive value.") @@ -1270,9 +1245,20 @@ class CompletionRequest(OpenAIBaseModel): @model_validator(mode="before") @classmethod def validate_prompt_and_prompt_embeds(cls, data): - if data.get("prompt") is None and data.get("prompt_embeds") is None: + prompt = data.get("prompt") + prompt_embeds = data.get("prompt_embeds") + + prompt_is_empty = (prompt is None + or (isinstance(prompt, str) and prompt == "")) + embeds_is_empty = (prompt_embeds is None + or (isinstance(prompt_embeds, list) + and len(prompt_embeds) == 0)) + + if prompt_is_empty and embeds_is_empty: raise ValueError( - "At least one of `prompt` or `prompt_embeds` must be set.") + "Either prompt or prompt_embeds must be provided and non-empty." + ) + return data @model_validator(mode="before") @@ -1342,6 +1328,14 @@ class EmbeddingChatRequest(OpenAIBaseModel): truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None # --8<-- [start:chat-embedding-extra-params] + add_generation_prompt: bool = Field( + default=False, + description= + ("If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model."), + ) + add_special_tokens: bool = Field( default=False, description=( @@ -1833,7 +1827,8 @@ class InputTokensDetails(OpenAIBaseModel): class OutputTokensDetails(OpenAIBaseModel): - reasoning_tokens: int + reasoning_tokens: int = 0 + tool_output_tokens: int = 0 class ResponseUsage(OpenAIBaseModel): @@ -1848,7 +1843,7 @@ class ResponsesResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"resp_{random_uuid()}") created_at: int = Field(default_factory=lambda: int(time.time())) # error: Optional[ResponseError] = None - # incomplete_details: Optional[IncompleteDetails] = None + incomplete_details: Optional[IncompleteDetails] = None instructions: Optional[str] = None metadata: Optional[Metadata] = None model: str @@ -1884,9 +1879,18 @@ class ResponsesResponse(OpenAIBaseModel): status: ResponseStatus, usage: Optional[ResponseUsage] = None, ) -> "ResponsesResponse": + + incomplete_details: Optional[IncompleteDetails] = None + if status == 'incomplete': + incomplete_details = IncompleteDetails(reason='max_output_tokens') + # TODO: implement the other reason for incomplete_details, + # which is content_filter + # incomplete_details = IncompleteDetails(reason='content_filter') + return cls( id=request.request_id, created_at=created_time, + incomplete_details=incomplete_details, instructions=request.instructions, metadata=request.metadata, model=model_name, @@ -2089,7 +2093,7 @@ class DetokenizeResponse(OpenAIBaseModel): class TokenizerInfoResponse(OpenAIBaseModel): """ - Response containing tokenizer configuration + Response containing tokenizer configuration equivalent to tokenizer_config.json """ @@ -2179,7 +2183,7 @@ class TranscriptionRequest(OpenAIBaseModel): to_language: Optional[str] = None """The language of the output audio we transcribe to. 
- Please note that this is not currently used by supported models at this + Please note that this is not currently used by supported models at this time, but it is a placeholder for future use, matching translation api. """ diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 35edd2f85cd07..16564214e353a 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -186,9 +186,9 @@ class OpenAIServingChat(OpenAIServing): lora_request = self._maybe_get_adapters( request, supports_default_mm_loras=True) - model_name = self._get_model_name(request.model, lora_request) + model_name = self.models.model_name(lora_request) - tokenizer = await self.engine_client.get_tokenizer(lora_request) + tokenizer = await self.engine_client.get_tokenizer() tool_parser = self.tool_parser @@ -418,7 +418,7 @@ class OpenAIServingChat(OpenAIServing): if not function_name_returned: # get partly generated arguments from the latest tool call param_match = re.search(r'.*"parameters":\s*(.*)', - current_text) + current_text, re.DOTALL) arguments = param_match.group(1) if param_match else "" arguments, _ = OpenAIServingChat._filter_delta_text( arguments, previous_text) @@ -489,6 +489,8 @@ class OpenAIServingChat(OpenAIServing): get_streamable_parser_for_assistant() for _ in range(num_choices) ] + harmony_tools_streamed = [False] * num_choices + tools_streamed = [False] * num_choices if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): tool_choice_function_name = request.tool_choice.function.name @@ -662,13 +664,11 @@ class OpenAIServingChat(OpenAIServing): if self.use_harmony: harmony_parser = harmony_parsers[i] + prev_recipient = harmony_parser.current_recipient for token_id in output.token_ids: harmony_parser.process(token_id) - is_reasoning = \ - harmony_parser.current_channel == "analysis" - if not request.include_reasoning and is_reasoning: - # Skip the reasoning content. 
- continue + cur_channel = harmony_parser.current_channel + cur_recipient = harmony_parser.current_recipient delta_text = harmony_parser.last_content_delta or "" else: delta_text = output.text @@ -681,8 +681,7 @@ class OpenAIServingChat(OpenAIServing): delta_message: Optional[DeltaMessage] # just update previous_texts and previous_token_ids - if ((tool_choice_auto or self.reasoning_parser) - and not self.use_harmony): + if tool_choice_auto or self.reasoning_parser: assert previous_texts is not None assert all_previous_token_ids is not None previous_text = previous_texts[i] @@ -696,11 +695,54 @@ class OpenAIServingChat(OpenAIServing): current_token_ids = as_list(output.token_ids) if self.use_harmony: - if is_reasoning: - delta_message = DeltaMessage( - reasoning_content=delta_text) - else: + if cur_channel == "final": delta_message = DeltaMessage(content=delta_text) + elif cur_channel == "analysis": + if request.include_reasoning: + delta_message = DeltaMessage( + reasoning_content=delta_text) + else: + delta_message = None + elif (cur_channel == "commentary" and cur_recipient + and cur_recipient.startswith("functions.")): + # Count completed tool calls to determine index + base_index = 0 + for msg in harmony_parser.messages: + if (msg.channel == "commentary" + and msg.recipient + and msg.recipient.startswith( + "functions.")): + base_index += 1 + + if prev_recipient != cur_recipient: + tool_name = cur_recipient.split( + "functions.", 1)[1] + delta_message = DeltaMessage(tool_calls=[ + DeltaToolCall( + id=make_tool_call_id(), + type="function", + function=DeltaFunctionCall( + name=tool_name, + arguments="", + ), + index=base_index, + ) + ]) + elif delta_text: + delta_message = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=base_index, + function=DeltaFunctionCall( + arguments=delta_text), + ) + ]) + else: + delta_message = None + + if delta_message is not None: + harmony_tools_streamed[i] = True + else: + delta_message = None # handle streaming deltas for tools with named tool_choice elif tool_choice_function_name: if (self.reasoning_parser and not reasoning_end_arr[i] @@ -758,6 +800,7 @@ class OpenAIServingChat(OpenAIServing): delta_message = DeltaMessage(tool_calls=[ delta_tool_call, ]) + tools_streamed[i] = True elif request.tool_choice == "required": assert previous_texts is not None @@ -783,9 +826,7 @@ class OpenAIServingChat(OpenAIServing): if (delta_message and delta_message.tool_calls and delta_message.tool_calls[0].id is not None): history_tool_call_cnt += 1 - - # update the previous values for the next iteration - previous_texts[i] = current_text + tools_streamed[i] = True # handle streaming deltas for tools with "auto" tool choice # and reasoning parser @@ -859,6 +900,8 @@ class OpenAIServingChat(OpenAIServing): current_token_ids=current_token_ids, delta_token_ids=delta_token_ids, request=request)) + if delta_message and delta_message.tool_calls: + tools_streamed[i] = True # when only tool calls elif tool_choice_auto: assert tool_parser is not None @@ -871,6 +914,8 @@ class OpenAIServingChat(OpenAIServing): current_token_ids=current_token_ids, delta_token_ids=output.token_ids, request=request)) + if delta_message and delta_message.tool_calls: + tools_streamed[i] = True # when only reasoning elif self.reasoning_parser: @@ -907,7 +952,10 @@ class OpenAIServingChat(OpenAIServing): # wasn't ready to send a token, then # get the next token without streaming a chunk if delta_message is None: - continue + if output.finish_reason is None: + continue + else: + delta_message = 
DeltaMessage() # Log streaming delta if output logging is enabled if self.enable_log_outputs and self.request_logger: @@ -945,7 +993,7 @@ class OpenAIServingChat(OpenAIServing): # check to make sure we haven't "forgotten" to stream # any tokens that were generated but previously # matched by partial json parsing - # only happens if we are NOT using guided decoding + # only happens if we are NOT using structured outputs auto_tools_called = False if tool_parser: auto_tools_called = len( @@ -993,12 +1041,18 @@ class OpenAIServingChat(OpenAIServing): ]) # Send the finish response for each request.n only once + if auto_tools_called or tools_streamed[i] or ( + self.use_harmony + and harmony_tools_streamed[i]): + finish_reason_ = "tool_calls" + else: + finish_reason_ = output.finish_reason \ + if output.finish_reason else "stop" choice_data = ChatCompletionResponseStreamChoice( index=i, delta=delta_message, logprobs=logprobs, - finish_reason=output.finish_reason - if not auto_tools_called else "tool_calls", + finish_reason=finish_reason_, stop_reason=output.stop_reason, token_ids=(as_list(output.token_ids) if request.return_token_ids else None)) @@ -1117,6 +1171,7 @@ class OpenAIServingChat(OpenAIServing): for output in final_res.outputs: token_ids = output.token_ids out_logprobs = output.logprobs + tool_call_info = None if request.logprobs and request.top_logprobs is not None: assert out_logprobs is not None, "Did not output logprobs" @@ -1131,31 +1186,42 @@ class OpenAIServingChat(OpenAIServing): logprobs = None if self.use_harmony: - reasoning_content, final_content, is_tool_call = ( - parse_chat_output(token_ids)) - if not request.include_reasoning: - reasoning_content = None - - if is_tool_call: - # TODO(woosuk): Implement tool call for gpt-oss. - # For now, only Responses API supports tool call for - # gpt-oss. - raise NotImplementedError( - "Tool call in Chat Completion API is not supported " - "for gpt-oss yet. 
Please use Responses API instead.") - else: - # Normal message + if self.tool_parser is not None: + tool_parser = self.tool_parser(tokenizer) + # NOTE: We use token_ids for openai tool parser + tool_call_info = tool_parser.extract_tool_calls( + "", + request=request, + token_ids=token_ids, # type: ignore + ) + reasoning_content, content = None, tool_call_info.content + if request.include_reasoning: + reasoning_content, content, _ = parse_chat_output( + token_ids) message = ChatMessage( role=role, reasoning_content=reasoning_content, - content=final_content, + content=content, + tool_calls=tool_call_info.tool_calls, + ) + else: + reasoning_content, content, _ = parse_chat_output( + token_ids) + if not request.include_reasoning: + reasoning_content = None + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content=content, ) choice_data = ChatCompletionResponseChoice( index=output.index, message=message, logprobs=logprobs, - finish_reason="tool_calls" if is_tool_call else + finish_reason="tool_calls" if + (tool_call_info is not None + and tool_call_info.tools_called) else output.finish_reason if output.finish_reason else "stop", stop_reason=output.stop_reason, ) @@ -1419,9 +1485,10 @@ class OpenAIServingChat(OpenAIServing): step_top_logprobs = top_logprobs[i] if step_top_logprobs is None or step_top_logprobs.get( token_id) is None: - token = tokenizer.decode(token_id) if should_return_as_token_id: token = f"token_id:{token_id}" + else: + token = tokenizer.decode(token_id) logprobs_content.append( ChatCompletionLogProbsContent( @@ -1503,12 +1570,12 @@ class OpenAIServingChat(OpenAIServing): messages.append(sys_msg) # Add developer message. - dev_msg = get_developer_message() + dev_msg = get_developer_message(tools=request.tools) messages.append(dev_msg) # Add user message. for chat_msg in request.messages: - messages.append(parse_chat_input(chat_msg)) + messages.extend(parse_chat_input(chat_msg)) # Render prompt token ids. 
prompt_token_ids = render_for_completion(messages) diff --git a/vllm/entrypoints/openai/serving_classification.py b/vllm/entrypoints/openai/serving_classification.py index b4fdc36390319..fc56668aeb1b6 100644 --- a/vllm/entrypoints/openai/serving_classification.py +++ b/vllm/entrypoints/openai/serving_classification.py @@ -20,6 +20,7 @@ from vllm.entrypoints.openai.serving_engine import (ClassificationServeContext, OpenAIServing, ServeContext) from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.renderer import RenderConfig from vllm.logger import init_logger from vllm.outputs import ClassificationOutput, PoolingRequestOutput from vllm.pooling_params import PoolingParams @@ -49,19 +50,12 @@ class ClassificationMixin(OpenAIServing): return None try: - ctx.lora_request = self._maybe_get_adapters(ctx.request) + ctx.tokenizer = await self.engine_client.get_tokenizer() - ctx.tokenizer = await self.engine_client.get_tokenizer( - ctx.lora_request) - - ( - ctx.request_prompts, - ctx.engine_prompts, - ) = await self._preprocess_completion( - ctx.request, - ctx.tokenizer, - ctx.request.input, - ) + renderer = self._get_renderer(ctx.tokenizer) + ctx.engine_prompts = await renderer.render_prompt( + prompt_or_prompts=ctx.request.input, + config=self._build_render_config(ctx.request)) return None @@ -117,6 +111,12 @@ class ClassificationMixin(OpenAIServing): usage=usage, ) + def _build_render_config(self, + request: ClassificationRequest) -> RenderConfig: + return RenderConfig( + max_length=self.max_model_len, + truncate_prompt_tokens=request.truncate_prompt_tokens) + class ServingClassification(ClassificationMixin): request_id_prefix = "classify" @@ -143,7 +143,7 @@ class ServingClassification(ClassificationMixin): request: ClassificationRequest, raw_request: Request, ) -> Union[ClassificationResponse, ErrorResponse]: - model_name = self._get_model_name(request.model) + model_name = self.models.model_name() request_id = (f"{self.request_id_prefix}-" f"{self._base_request_id(raw_request)}") diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index b26140d4b9d7a..0c61c48da0bc8 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -26,14 +26,11 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs, PromptTokenUsageInfo, RequestResponseMetadata, UsageInfo) -from vllm.entrypoints.openai.serving_engine import ( - EmbedsPrompt as ServingEngineEmbedsPrompt) from vllm.entrypoints.openai.serving_engine import (OpenAIServing, - TextTokensPrompt, - clamp_prompt_logprobs, - is_text_tokens_prompt) + clamp_prompt_logprobs) # yapf: enable from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.utils import get_max_tokens from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt, is_tokens_prompt) @@ -115,6 +112,11 @@ class OpenAIServingCompletion(OpenAIServing): return self.create_error_response( "Echo is unsupported with prompt embeds.") + if (request.prompt_logprobs is not None + and request.prompt_embeds is not None): + return self.create_error_response( + "prompt_logprobs is not compatible with prompt embeds.") + request_id = ( f"cmpl-" f"{self._base_request_id(raw_request, request.request_id)}") @@ -130,14 +132,13 @@ class OpenAIServingCompletion(OpenAIServing): if self.model_config.skip_tokenizer_init: tokenizer = None else: - tokenizer = 
await self.engine_client.get_tokenizer(lora_request - ) + tokenizer = await self.engine_client.get_tokenizer() + renderer = self._get_renderer(tokenizer) - request_prompts, engine_prompts = await self._preprocess_completion( - request, - tokenizer, - request.prompt, - add_special_tokens=request.add_special_tokens, + engine_prompts = await renderer.render_prompt_and_embeds( + prompt_or_prompts=request.prompt, + prompt_embeds=request.prompt_embeds, + config=self._build_render_config(request), ) except ValueError as e: logger.exception("Error in preprocessing prompt inputs") @@ -198,7 +199,7 @@ class OpenAIServingCompletion(OpenAIServing): self._log_inputs( request_id_item, - request_prompts[i], + engine_prompt, params=sampling_params, lora_request=lora_request, ) @@ -235,7 +236,7 @@ class OpenAIServingCompletion(OpenAIServing): result_generator = merge_async_iterators(*generators) - model_name = self._get_model_name(request.model, lora_request) + model_name = self.models.model_name(lora_request) num_prompts = len(engine_prompts) # Similar to the OpenAI API, when n != best_of, we do not stream the @@ -249,7 +250,7 @@ class OpenAIServingCompletion(OpenAIServing): if stream: return self.completion_stream_generator( request, - request_prompts, + engine_prompts, result_generator, request_id, created_time, @@ -273,11 +274,9 @@ class OpenAIServingCompletion(OpenAIServing): # We did not pass it into vLLM engine to avoid being redundant # with the inputs token IDs if final_res.prompt is None: - request_prompt = request_prompts[i] - if is_text_tokens_prompt(request_prompt): - final_res.prompt = request_prompt["prompt"] - else: - final_res.prompt = None + engine_prompt = engine_prompts[i] + final_res.prompt = None if is_embeds_prompt( + engine_prompt) else engine_prompt.get("prompt") final_res_batch_checked = cast(list[RequestOutput], final_res_batch) @@ -313,8 +312,7 @@ class OpenAIServingCompletion(OpenAIServing): async def completion_stream_generator( self, request: CompletionRequest, - request_prompts: list[Union[TextTokensPrompt, - ServingEngineEmbedsPrompt]], + engine_prompts: list[Union[TokensPrompt, EmbedsPrompt]], result_generator: AsyncIterator[tuple[int, RequestOutput]], request_id: str, created_time: int, @@ -350,14 +348,11 @@ class OpenAIServingCompletion(OpenAIServing): num_cached_tokens = res.num_cached_tokens first_iteration = False - if res.prompt is not None: - prompt_text = res.prompt - else: - request_prompt = request_prompts[prompt_idx] - if is_text_tokens_prompt(request_prompt): - prompt_text = request_prompt["prompt"] - else: - prompt_text = None + prompt_text = res.prompt + if prompt_text is None: + engine_prompt = engine_prompts[prompt_idx] + prompt_text = None if is_embeds_prompt( + engine_prompt) else engine_prompt.get("prompt") # Prompt details are excluded from later streamed outputs if prompt_token_ids is not None: @@ -378,6 +373,8 @@ class OpenAIServingCompletion(OpenAIServing): assert request.max_tokens is not None if request.echo and not has_echoed[i]: assert prompt_token_ids is not None + if request.return_token_ids: + prompt_text = "" assert prompt_text is not None if request.max_tokens == 0: # only return the prompt @@ -525,6 +522,8 @@ class OpenAIServingCompletion(OpenAIServing): for output in final_res.outputs: assert request.max_tokens is not None if request.echo: + if request.return_token_ids: + prompt_text = "" assert prompt_text is not None if request.max_tokens == 0: token_ids = prompt_token_ids @@ -676,3 +675,18 @@ class 
OpenAIServingCompletion(OpenAIServing): tokens=out_tokens, top_logprobs=out_top_logprobs, ) + + def _build_render_config( + self, + request: CompletionRequest, + max_input_length: Optional[int] = None, + ) -> RenderConfig: + max_input_tokens_len = self.max_model_len - (request.max_tokens or 0) + return RenderConfig( + max_length=max_input_tokens_len, + truncate_prompt_tokens=request.truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + cache_salt=request.cache_salt, + needs_detokenization=bool(request.echo + and not request.return_token_ids), + ) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 0a0d98db2d0d8..647e7daed6598 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -24,12 +24,11 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest, ErrorResponse, UsageInfo) from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext, OpenAIServing, - RequestPrompt, ServeContext, TextTokensPrompt) # yapf: enable from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt +from vllm.entrypoints.renderer import RenderConfig from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger import init_logger from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, @@ -77,13 +76,13 @@ class EmbeddingMixin(OpenAIServing): try: ctx.lora_request = self._maybe_get_adapters(ctx.request) - tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request - ) + tokenizer = await self.engine_client.get_tokenizer() + renderer = self._get_renderer(tokenizer) if isinstance(ctx.request, EmbeddingChatRequest): ( _, - ctx.request_prompts, + _, ctx.engine_prompts, ) = await self._preprocess_chat( ctx.request, @@ -93,25 +92,33 @@ class EmbeddingMixin(OpenAIServing): or ctx.chat_template, chat_template_content_format=ctx. 
chat_template_content_format, - # In embedding requests, we are not generating tokens, - # so there is no need to append extra tokens to the input - add_generation_prompt=False, + add_generation_prompt=ctx.request.add_generation_prompt, continue_final_message=False, add_special_tokens=ctx.request.add_special_tokens, ) else: - (ctx.request_prompts, - ctx.engine_prompts) = await self._preprocess_completion( - ctx.request, - tokenizer, - ctx.request.input, - add_special_tokens=ctx.request.add_special_tokens, - ) + ctx.engine_prompts = await renderer.render_prompt( + prompt_or_prompts=ctx.request.input, + config=self._build_render_config(ctx.request), + ) return None except (ValueError, TypeError) as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) + def _build_render_config( + self, request: EmbeddingCompletionRequest) -> RenderConfig: + # Set max_length based on chunked processing capability + if self._should_use_chunked_processing(request): + max_length = None + else: + max_length = self.max_embed_len or self.max_model_len + + return RenderConfig( + max_length=max_length, + truncate_prompt_tokens=request.truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens) + @override def _build_response( self, @@ -287,8 +294,7 @@ class EmbeddingMixin(OpenAIServing): async def _create_single_prompt_generator( self, ctx: EmbeddingServeContext, - engine_prompt: Union[EngineTokensPrompt, EngineEmbedsPrompt], - request_prompt: RequestPrompt, + engine_prompt: EngineTokensPrompt, pooling_params: PoolingParams, trace_headers: Optional[Mapping[str, str]], prompt_index: int, @@ -297,16 +303,10 @@ class EmbeddingMixin(OpenAIServing): request_id_item = f"{ctx.request_id}-{prompt_index}" self._log_inputs(request_id_item, - request_prompt, + engine_prompt, params=pooling_params, lora_request=ctx.lora_request) - # Mypy has an existing bug related to inferring the variance - # of TypedDicts with `builtins.enumerate`: - # https://github.com/python/mypy/issues/8586#issuecomment-2867698435 - engine_prompt = cast(Union[EngineTokensPrompt, EngineEmbedsPrompt], - engine_prompt) - # Return the original generator without wrapping return self.engine_client.encode( engine_prompt, @@ -355,20 +355,14 @@ class EmbeddingMixin(OpenAIServing): return self.create_error_response( "Engine prompts not available") - if ctx.request_prompts is None: - return self.create_error_response( - "Request prompts not available") - max_pos_embeddings = self._get_max_position_embeddings() for i, engine_prompt in enumerate(ctx.engine_prompts): - request_prompt = ctx.request_prompts[i] - # Check if this specific prompt needs chunked processing - if self._is_text_tokens_prompt(request_prompt): + if self._is_text_tokens_prompt(engine_prompt): # Cast to TextTokensPrompt since we've verified # prompt_token_ids - text_tokens_prompt = cast(TextTokensPrompt, request_prompt) + text_tokens_prompt = cast(TextTokensPrompt, engine_prompt) if (len(text_tokens_prompt["prompt_token_ids"]) > max_pos_embeddings): # Use chunked processing for this prompt @@ -379,13 +373,8 @@ class EmbeddingMixin(OpenAIServing): continue # Normal processing for short prompts or non-token prompts - # Cast engine_prompt to the expected type for mypy - engine_prompt_typed = cast( - Union[EngineTokensPrompt, EngineEmbedsPrompt], - engine_prompt) generator = await self._create_single_prompt_generator( - ctx, engine_prompt_typed, request_prompt, pooling_params, - trace_headers, i) + ctx, engine_prompt, pooling_params, 
trace_headers, i) generators.append(generator) from vllm.utils import merge_async_iterators @@ -404,8 +393,8 @@ class EmbeddingMixin(OpenAIServing): ) -> Optional[ErrorResponse]: """Collect and aggregate batch results with support for chunked processing. - - For chunked requests, performs online aggregation to + + For chunked requests, performs online aggregation to minimize memory usage. For regular requests, collects results normally. """ @@ -421,10 +410,6 @@ class EmbeddingMixin(OpenAIServing): if not use_chunked: return await super()._collect_batch(ctx=ctx) - if ctx.request_prompts is None: - return self.create_error_response( - "Request prompts not available") - if ctx.result_generator is None: return self.create_error_response( "Result generator not available") @@ -540,7 +525,7 @@ class EmbeddingMixin(OpenAIServing): data=final_embedding) # Get original prompt token IDs for this prompt - original_prompt = ctx.request_prompts[prompt_idx] + original_prompt = ctx.engine_prompts[prompt_idx] if not self._is_text_tokens_prompt(original_prompt): return self.create_error_response( f"Chunked prompt {prompt_idx} is not a " @@ -613,7 +598,7 @@ class OpenAIServingEmbedding(EmbeddingMixin): See https://platform.openai.com/docs/api-reference/embeddings/create for the API specification. This API mimics the OpenAI Embedding API. """ - model_name = self._get_model_name(request.model) + model_name = self.models.model_name() request_id = ( f"{self.request_id_prefix}-" f"{self._base_request_id(raw_request, request.request_id)}") diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 796b8ab5fc2cb..4eb1f8b89d64f 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import asyncio -import io import json import sys import time @@ -9,10 +7,8 @@ import traceback from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence from concurrent.futures import ThreadPoolExecutor from http import HTTPStatus -from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional, - TypeVar, Union, cast, overload) +from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar, Union -import pybase64 import torch from fastapi import Request from pydantic import BaseModel, ConfigDict, Field @@ -62,10 +58,11 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, TranslationRequest) from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser +from vllm.entrypoints.renderer import (BaseRenderer, CompletionRenderer, + RenderConfig) # yapf: enable -from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt +from vllm.inputs.data import PromptType from vllm.inputs.data import TokensPrompt as EngineTokensPrompt -from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.logprobs import Logprob, PromptLogprobs from vllm.lora.request import LoRARequest @@ -82,16 +79,26 @@ from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of, logger = init_logger(__name__) -CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest, - EmbeddingCompletionRequest, RerankRequest, - ClassificationRequest, ScoreRequest, - TokenizeCompletionRequest] +CompletionLikeRequest = Union[ + CompletionRequest, + DetokenizeRequest, + EmbeddingCompletionRequest, + 
RerankRequest, + ClassificationRequest, + ScoreRequest, + TokenizeCompletionRequest, +] ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest, TokenizeChatRequest] SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest] -AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, SpeechToTextRequest, - ResponsesRequest, IOProcessorRequest] +AnyRequest = Union[ + CompletionLikeRequest, + ChatLikeRequest, + SpeechToTextRequest, + ResponsesRequest, + IOProcessorRequest, +] AnyResponse = Union[ CompletionResponse, @@ -135,9 +142,9 @@ class RequestProcessingMixin(BaseModel): Mixin for request processing, handling prompt preparation and engine input. """ + request_prompts: Optional[Sequence[RequestPrompt]] = [] - engine_prompts: Optional[Union[list[EngineTokensPrompt], - list[EngineEmbedsPrompt]]] = [] + engine_prompts: Optional[list[EngineTokensPrompt]] = [] model_config = ConfigDict(arbitrary_types_allowed=True) @@ -147,6 +154,7 @@ class ResponseGenerationMixin(BaseModel): Mixin for response generation, managing result generators and final batch results. """ + result_generator: Optional[AsyncGenerator[tuple[int, Union[ RequestOutput, PoolingRequestOutput]], None]] = None final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field( @@ -155,8 +163,12 @@ class ResponseGenerationMixin(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) -class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel, - Generic[RequestT]): +class ServeContext( + RequestProcessingMixin, + ResponseGenerationMixin, + BaseModel, + Generic[RequestT], +): # Shared across all requests request: RequestT raw_request: Optional[Request] = None @@ -227,6 +239,29 @@ class OpenAIServing: AsyncMicrobatchTokenizer] = {} self.log_error_stack = log_error_stack + def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer: + """ + Get a Renderer instance with the provided tokenizer. + Uses shared async tokenizer pool for efficiency. + """ + return CompletionRenderer( + model_config=self.model_config, + tokenizer=tokenizer, + async_tokenizer_pool=self._async_tokenizer_pool) + + def _build_render_config( + self, + request: Any, + ) -> RenderConfig: + """ + Build and return a `RenderConfig` for an endpoint. + + Used by the renderer to control how prompts are prepared + (e.g., tokenization and length handling). Endpoints should + implement this with logic appropriate to their request type. + """ + raise NotImplementedError + def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer: """ Return (and cache) an `AsyncMicrobatchTokenizer` bound to the @@ -298,8 +333,8 @@ class OpenAIServing: truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None) - if truncate_prompt_tokens is not None and \ - truncate_prompt_tokens > self.max_model_len: + if (truncate_prompt_tokens is not None + and truncate_prompt_tokens > self.max_model_len): return self.create_error_response( "truncate_prompt_tokens value is " "greater than max_model_len." 
@@ -340,21 +375,13 @@ class OpenAIServing: for i, engine_prompt in enumerate(ctx.engine_prompts): request_id_item = f"{ctx.request_id}-{i}" - if ctx.request_prompts is None: - return self.create_error_response( - "Request prompts not available") + self._log_inputs( + request_id_item, + engine_prompt, + params=pooling_params, + lora_request=ctx.lora_request, + ) - self._log_inputs(request_id_item, - ctx.request_prompts[i], - params=pooling_params, - lora_request=ctx.lora_request) - - # Mypy has an existing bug related to inferring the variance of - # TypedDicts with `builtins.enumerate`: - # https://github.com/python/mypy/issues/8586#issuecomment-2867698435 - engine_prompt = cast( - Union[EngineTokensPrompt, EngineEmbedsPrompt], - engine_prompt) generator = self.engine_client.encode( engine_prompt, pooling_params, @@ -410,10 +437,11 @@ class OpenAIServing: return self.create_error_response(str(e)) def create_error_response( - self, - message: str, - err_type: str = "BadRequestError", - status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse: + self, + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, + ) -> ErrorResponse: if self.log_error_stack: exc_type, _, _ = sys.exc_info() if exc_type is not None: @@ -424,10 +452,11 @@ class OpenAIServing: message=message, type=err_type, code=status_code.value)) def create_streaming_error_response( - self, - message: str, - err_type: str = "BadRequestError", - status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str: + self, + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, + ) -> str: json_str = json.dumps( self.create_error_response(message=message, err_type=err_type, @@ -438,25 +467,25 @@ class OpenAIServing: self, request: AnyRequest, ) -> Optional[ErrorResponse]: - error_response = None if self._is_model_supported(request.model): return None if request.model in self.models.lora_requests: return None - if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and ( - load_result := await self.models.resolve_lora(request.model)): + if (envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and + (load_result := await self.models.resolve_lora(request.model))): if isinstance(load_result, LoRARequest): return None - if isinstance(load_result, ErrorResponse) and \ - load_result.error.code == HTTPStatus.BAD_REQUEST.value: + if (isinstance(load_result, ErrorResponse) and + load_result.error.code == HTTPStatus.BAD_REQUEST.value): error_response = load_result return error_response or self.create_error_response( message=f"The model `{request.model}` does not exist.", err_type="NotFoundError", - status_code=HTTPStatus.NOT_FOUND) + status_code=HTTPStatus.NOT_FOUND, + ) def _get_active_default_mm_loras( self, request: AnyRequest) -> Optional[LoRARequest]: @@ -487,7 +516,6 @@ class OpenAIServing: request: AnyRequest, supports_default_mm_loras: bool = False, ) -> Optional[LoRARequest]: - if request.model in self.models.lora_requests: return self.models.lora_requests[request.model] @@ -548,13 +576,15 @@ class OpenAIServing: prompt, add_special_tokens=add_special_tokens, truncation=True, - max_length=self.max_model_len) + max_length=self.max_model_len, + ) else: encoded = await async_tokenizer( prompt, add_special_tokens=add_special_tokens, truncation=True, - max_length=truncate_prompt_tokens) + max_length=truncate_prompt_tokens, + ) input_ids = encoded.input_ids input_text = prompt @@ -595,16 +625,22 @@ class OpenAIServing: # Note: 
EmbeddingRequest, ClassificationRequest, # and ScoreRequest doesn't have max_tokens - if isinstance(request, - (EmbeddingChatRequest, EmbeddingCompletionRequest, - ScoreRequest, RerankRequest, ClassificationRequest)): - + if isinstance( + request, + ( + EmbeddingChatRequest, + EmbeddingCompletionRequest, + ScoreRequest, + RerankRequest, + ClassificationRequest, + ), + ): # Note: input length can be up to the entire model context length # since these requests don't generate tokens. if token_num > self.max_model_len: operations: dict[type[AnyRequest], str] = { ScoreRequest: "score", - ClassificationRequest: "classification" + ClassificationRequest: "classification", } operation = operations.get(type(request), "embedding generation") @@ -618,8 +654,11 @@ class OpenAIServing: # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens # and does not require model context length validation - if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest, - DetokenizeRequest)): + if isinstance( + request, + (TokenizeCompletionRequest, TokenizeChatRequest, + DetokenizeRequest), + ): return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) @@ -639,8 +678,8 @@ class OpenAIServing: f"{token_num} input tokens. Please reduce the length of " "the input messages.") - if max_tokens is not None and \ - token_num + max_tokens > self.max_model_len: + if (max_tokens is not None + and token_num + max_tokens > self.max_model_len): raise ValueError( "'max_tokens' or 'max_completion_tokens' is too large: " f"{max_tokens}. This model's maximum context length is " @@ -658,9 +697,7 @@ class OpenAIServing: add_special_tokens: bool = True, ) -> TextTokensPrompt: """ - A simpler implementation of - [`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs] - that assumes single input. + A simpler implementation that tokenizes a single prompt input. """ async for result in self._tokenize_prompt_inputs_async( request, @@ -679,9 +716,7 @@ class OpenAIServing: add_special_tokens: bool = True, ) -> AsyncGenerator[TextTokensPrompt, None]: """ - A simpler implementation of - [`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs] - that assumes multiple inputs. + A simpler implementation that tokenizes multiple prompt inputs. """ for prompt in prompt_inputs: if isinstance(prompt, str): @@ -698,156 +733,6 @@ class OpenAIServing: tokenizer=tokenizer, ) - async def _tokenize_prompt_input_or_inputs_async( - self, - request: AnyRequest, - tokenizer: Optional[AnyTokenizer], - input_or_inputs: Optional[Union[str, list[str], list[int], - list[list[int]]]], - add_special_tokens: bool = True, - ) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]: - """ - Tokenize/detokenize depending on the input format. - - According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_ - , each input can be a string or array of tokens. Note that each request - can pass one or more inputs. 
- """ - inputs_embeds = list[EmbedsPrompt]() - inputs_text = list[TextTokensPrompt]() - - truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", - None) - - if (truncate_prompt_tokens or 0) < 0: - truncate_prompt_tokens = self.max_model_len - - if (isinstance(request, CompletionRequest) - and request.prompt_embeds is not None): - inputs_embeds.extend( - self._load_prompt_embeds(request.prompt_embeds, - truncate_prompt_tokens)) - - # Empty prompts are okay as long as there are prompt embeddings - if input_or_inputs is None or (inputs_embeds - and input_or_inputs == ""): - return [], inputs_embeds - - # Although our type checking is based on mypy, - # VSCode Pyright extension should still work properly - # "is False" is required for Pyright to perform type narrowing - # See: https://github.com/microsoft/pyright/issues/7672 - - # Parse and batch the input prompts - batch_inputs = parse_and_batch_prompt(input_or_inputs) - - # Process each input in the batch concurrently - tasks = [] - for prompt_input in batch_inputs: - if prompt_input["is_tokens"] is False: - assert tokenizer is not None, \ - "Tokenizer is required for text prompts" - task = self._normalize_prompt_text_to_input( - request, - prompt_input["content"], - tokenizer=tokenizer, - add_special_tokens=add_special_tokens) - else: - task = self._normalize_prompt_tokens_to_input( - request, prompt_input["content"], tokenizer=tokenizer) - tasks.append(task) - - # Wait for all tokenization tasks to complete - results = await asyncio.gather(*tasks) - inputs_text.extend(results) - - return inputs_text, inputs_embeds - - @overload - async def _preprocess_completion( - self, - request: Union[DetokenizeRequest, EmbeddingCompletionRequest, - RerankRequest, ClassificationRequest, ScoreRequest, - TokenizeCompletionRequest], - tokenizer: Optional[AnyTokenizer], - input_or_inputs: Union[str, list[str], list[int], list[list[int]]], - add_special_tokens: bool = ..., - ) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]: - ... - - @overload - async def _preprocess_completion( - self, - request: CompletionRequest, - tokenizer: Optional[AnyTokenizer], - input_or_inputs: Optional[Union[str, list[str], list[int], - list[list[int]]]], - add_special_tokens: bool = ..., - ) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]], list[Union[ - EngineTokensPrompt, EngineEmbedsPrompt]]]: - ... 
- - async def _preprocess_completion( - self, - request: CompletionLikeRequest, - tokenizer: Optional[AnyTokenizer], - input_or_inputs: Optional[Union[str, list[str], list[int], - list[list[int]]]], - add_special_tokens: bool = True, - ) -> tuple[Union[list[TextTokensPrompt], list[Union[ - TextTokensPrompt, EmbedsPrompt]]], Union[ - list[EngineTokensPrompt], list[Union[EngineTokensPrompt, - EngineEmbedsPrompt]]]]: - if not isinstance(request, - CompletionRequest) and input_or_inputs is None: - raise ValueError( - "Prompt embeds with non-completion requests is not" - " currently supported.") - - (request_prompts_text, request_prompts_embeds - ) = await self._tokenize_prompt_input_or_inputs_async( - request, - tokenizer, - input_or_inputs, - add_special_tokens=add_special_tokens, - ) - - engine_prompts_text = [ - EngineTokensPrompt( - prompt_token_ids=request_prompt_text["prompt_token_ids"]) - for request_prompt_text in request_prompts_text - ] - cache_salt = request.cache_salt if ( - hasattr(request, "cache_salt") - and request.cache_salt is not None) else None - if cache_salt: - for prompt_text in engine_prompts_text: - prompt_text["cache_salt"] = cache_salt - - # This check is equivalent to simply checking if - # `request_prompts_embeds` is empty, but it's difficult to propagate - # overloads to the private helper functions to enable this check. - # This overload is needed because only TextPrompts are allowed for - # non-completion requests and if we don't add the overload here, - # everywhere this function is used outside of serving_completion will - # need logic asserting that only text prompts are in the request. - if not isinstance(request, - CompletionRequest) and input_or_inputs is not None: - return request_prompts_text, engine_prompts_text - - engine_prompts_embeds = [ - EngineEmbedsPrompt( - prompt_embeds=request_prompt_embeds["prompt_embeds"]) - for request_prompt_embeds in request_prompts_embeds - ] - if cache_salt: - for prompt_embed in engine_prompts_embeds: - prompt_embed["cache_salt"] = cache_salt - - request_prompts = request_prompts_embeds + request_prompts_text - engine_prompts = engine_prompts_embeds + engine_prompts_text - return request_prompts, engine_prompts - async def _preprocess_chat( self, request: Union[ChatLikeRequest, ResponsesRequest], @@ -862,8 +747,11 @@ class OpenAIServing: chat_template_kwargs: Optional[dict[str, Any]] = None, tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None, add_special_tokens: bool = False, - ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt], - list[EngineTokensPrompt]]: + ) -> tuple[ + list[ConversationMessage], + Sequence[RequestPrompt], + list[EngineTokensPrompt], + ]: model_config = self.model_config resolved_content_format = resolve_chat_template_content_format( @@ -873,7 +761,7 @@ class OpenAIServing: tokenizer, model_config=model_config, ) - conversation, mm_data_future = parse_chat_messages_futures( + conversation, mm_data_future, mm_uuids = parse_chat_messages_futures( messages, model_config, tokenizer, @@ -925,8 +813,8 @@ class OpenAIServing: if tokenizer is None: assert isinstance(request_prompt, str), ( - "Prompt has to be a string", \ - "when the tokenizer is not initialised" + "Prompt has to be a string", + "when the tokenizer is not initialised", ) prompt_inputs = TextTokensPrompt(prompt=request_prompt, prompt_token_ids=[1]) @@ -943,12 +831,17 @@ class OpenAIServing: "Prompt has to be either a string or a list of token ids") prompt_inputs = TextTokensPrompt( 
prompt=tokenizer.decode(request_prompt), - prompt_token_ids=request_prompt) + prompt_token_ids=request_prompt, + ) engine_prompt = EngineTokensPrompt( prompt_token_ids=prompt_inputs["prompt_token_ids"]) if mm_data is not None: engine_prompt["multi_modal_data"] = mm_data + + if mm_uuids is not None: + engine_prompt["multi_modal_uuids"] = mm_uuids + if request.mm_processor_kwargs is not None: engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs @@ -1007,49 +900,15 @@ class OpenAIServing: prompt_token_ids=prompt_token_ids) request_prompt = prompt_token_ids # Update the sampling params. - sampling_params.max_tokens = (self.max_model_len - - len(prompt_token_ids)) + sampling_params.max_tokens = self.max_model_len - len( + prompt_token_ids) # OPTIMIZATION priority = orig_priority - 1 - @staticmethod - def _load_prompt_embeds( - prompt_embeds: Optional[Union[bytes, list[bytes]]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None - ) -> list[EmbedsPrompt]: - - def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt: - tensor = torch.load(io.BytesIO( - pybase64.b64decode(embed, validate=True)), - weights_only=True, - map_location=torch.device("cpu")) - assert isinstance(tensor, torch.Tensor) and tensor.dtype in ( - torch.float32, - torch.bfloat16, - torch.float16, - ) - tensor = tensor.to_dense() - if tensor.dim() > 2: - tensor = tensor.squeeze(0) - assert tensor.dim() == 2 - if truncate_prompt_tokens is not None: - tensor = tensor[-truncate_prompt_tokens:] - return {"prompt_embeds": tensor} - - if prompt_embeds: - if isinstance(prompt_embeds, list): - return [ - _load_and_validate_embed(embed) for embed in prompt_embeds - ] - else: - return [_load_and_validate_embed(prompt_embeds)] - else: - return [] - def _log_inputs( self, request_id: str, - inputs: RequestPrompt, + inputs: Union[RequestPrompt, PromptType], params: Optional[Union[SamplingParams, PoolingParams, BeamSearchParams]], lora_request: Optional[LoRARequest], @@ -1061,11 +920,9 @@ class OpenAIServing: prompt = inputs elif isinstance(inputs, list): prompt_token_ids = inputs - elif 'prompt_embeds' in inputs: - prompt_embeds = inputs.get("prompt_embeds") else: - prompt = inputs["prompt"] - prompt_token_ids = inputs["prompt_token_ids"] + prompt = getattr(inputs, 'prompt', None) + prompt_token_ids = getattr(inputs, 'prompt_token_ids', None) self.request_logger.log_inputs( request_id, @@ -1101,10 +958,12 @@ class OpenAIServing: return raw_request.headers.get("X-Request-Id", default) @staticmethod - def _get_decoded_token(logprob: Logprob, - token_id: int, - tokenizer: AnyTokenizer, - return_as_token_id: bool = False) -> str: + def _get_decoded_token( + logprob: Logprob, + token_id: int, + tokenizer: AnyTokenizer, + return_as_token_id: bool = False, + ) -> str: if return_as_token_id: return f"token_id:{token_id}" @@ -1117,19 +976,10 @@ class OpenAIServing: return True return self.models.is_base_model(model_name) - def _get_model_name(self, - model_name: Optional[str] = None, - lora_request: Optional[LoRARequest] = None) -> str: - if lora_request: - return lora_request.lora_name - if not model_name: - return self.models.base_model_paths[0].name - return model_name - def clamp_prompt_logprobs( prompt_logprobs: Union[PromptLogprobs, - None]) -> Union[PromptLogprobs, None]: + None], ) -> Union[PromptLogprobs, None]: if prompt_logprobs is None: return prompt_logprobs @@ -1137,6 +987,6 @@ def clamp_prompt_logprobs( if logprob_dict is None: continue for logprob_values in logprob_dict.values(): - if 
logprob_values.logprob == float('-inf'): + if logprob_values.logprob == float("-inf"): logprob_values.logprob = -9999.0 return prompt_logprobs diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index 685c98c817c3d..0750c7ec3e9f1 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -4,7 +4,7 @@ import asyncio import base64 import time -from collections.abc import AsyncGenerator, Sequence +from collections.abc import AsyncGenerator from typing import Final, Literal, Optional, Union, cast import jinja2 @@ -26,8 +26,9 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, PoolingRequest, PoolingResponse, PoolingResponseData, UsageInfo) # yapf: enable -from vllm.entrypoints.openai.serving_engine import OpenAIServing, RequestPrompt +from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.utils import _validate_truncation_size from vllm.logger import init_logger from vllm.outputs import PoolingOutput, PoolingRequestOutput @@ -90,7 +91,7 @@ class OpenAIServingPooling(OpenAIServing): if error_check_ret is not None: return error_check_ret - model_name = self._get_model_name(request.model) + model_name = self.models.model_name() request_id = f"pool-{self._base_request_id(raw_request)}" created_time = int(time.time()) @@ -102,8 +103,8 @@ class OpenAIServingPooling(OpenAIServing): if self.model_config.skip_tokenizer_init: tokenizer = None else: - tokenizer = await self.engine_client.get_tokenizer(lora_request - ) + tokenizer = await self.engine_client.get_tokenizer() + renderer = self._get_renderer(tokenizer) if getattr(request, "dimensions", None) is not None: return self.create_error_response( @@ -126,14 +127,11 @@ class OpenAIServingPooling(OpenAIServing): engine_prompts = await self.io_processor.pre_process_async( prompt=validated_prompt, request_id=request_id) - request_prompts: Sequence[RequestPrompt] = [ - "" - ] * len(engine_prompts) elif isinstance(request, PoolingChatRequest): ( _, - request_prompts, + _, engine_prompts, ) = await self._preprocess_chat( request, @@ -149,13 +147,10 @@ class OpenAIServingPooling(OpenAIServing): add_special_tokens=request.add_special_tokens, ) elif isinstance(request, PoolingCompletionRequest): - (request_prompts, - engine_prompts) = await self._preprocess_completion( - request, - tokenizer, - request.input, - add_special_tokens=request.add_special_tokens, - ) + engine_prompts = await renderer.render_prompt( + prompt_or_prompts=request.input, + config=self._build_render_config(request), + ) else: raise ValueError( f"Unsupported request of type {type(request)}") @@ -177,7 +172,7 @@ class OpenAIServingPooling(OpenAIServing): request_id_item = f"{request_id}-{i}" self._log_inputs(request_id_item, - request_prompts[i], + engine_prompt, params=pooling_params, lora_request=lora_request) @@ -272,3 +267,10 @@ class OpenAIServingPooling(OpenAIServing): data=items, usage=usage, ) + + def _build_render_config( + self, request: PoolingCompletionRequest) -> RenderConfig: + return RenderConfig( + max_length=self.max_model_len, + truncate_prompt_tokens=request.truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens) diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index 4c15de3030998..469d74272b0e6 100644 --- 
a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -4,6 +4,8 @@ import asyncio import json import time +import uuid +from collections import deque from collections.abc import AsyncGenerator, AsyncIterator, Sequence from contextlib import AsyncExitStack from copy import copy @@ -11,20 +13,25 @@ from http import HTTPStatus from typing import Callable, Final, Optional, Union import jinja2 -import openai.types.responses as openai_responses_types from fastapi import Request -from openai import BaseModel # yapf conflicts with isort for this block # yapf: disable -from openai.types.responses import (ResponseCreatedEvent, - ResponseFunctionToolCall, - ResponseInProgressEvent, - ResponseOutputItem, - ResponseOutputItemDoneEvent, - ResponseOutputMessage, ResponseOutputText, - ResponseReasoningItem, - ResponseReasoningTextDeltaEvent, - ResponseReasoningTextDoneEvent) +from openai.types.responses import ( + ResponseCodeInterpreterCallCodeDeltaEvent, + ResponseCodeInterpreterCallCodeDoneEvent, + ResponseCodeInterpreterCallCompletedEvent, + ResponseCodeInterpreterCallInProgressEvent, + ResponseCodeInterpreterCallInterpretingEvent, + ResponseCodeInterpreterToolCallParam, ResponseCompletedEvent, + ResponseContentPartAddedEvent, ResponseContentPartDoneEvent, + ResponseCreatedEvent, ResponseFunctionToolCall, ResponseFunctionWebSearch, + ResponseInProgressEvent, ResponseOutputItem, ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, ResponseOutputMessage, ResponseOutputText, + ResponseReasoningItem, ResponseReasoningTextDeltaEvent, + ResponseReasoningTextDoneEvent, ResponseStatus, ResponseTextDeltaEvent, + ResponseTextDoneEvent, ResponseWebSearchCallCompletedEvent, + ResponseWebSearchCallInProgressEvent, ResponseWebSearchCallSearchingEvent, + response_function_web_search, response_text_delta_event) from openai.types.responses.response_output_text import (Logprob, LogprobTopLogprob) # yapf: enable @@ -41,21 +48,23 @@ from vllm.entrypoints.context import (ConversationContext, HarmonyContext, SimpleContext, StreamingHarmonyContext) from vllm.entrypoints.harmony_utils import ( get_developer_message, get_stop_tokens_for_assistant_actions, - get_system_message, get_user_message, parse_output_message, - parse_remaining_state, parse_response_input, render_for_completion) + get_system_message, get_user_message, has_custom_tools, + parse_output_message, parse_remaining_state, parse_response_input, + render_for_completion) from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable -from vllm.entrypoints.openai.protocol import (ErrorResponse, +from vllm.entrypoints.openai.protocol import (DeltaMessage, ErrorResponse, InputTokensDetails, OutputTokensDetails, RequestResponseMetadata, ResponsesRequest, - ResponsesResponse, ResponseUsage) + ResponsesResponse, ResponseUsage, + StreamingResponsesResponse) # yapf: enable from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.entrypoints.tool_server import MCPToolServer, ToolServer +from vllm.entrypoints.tool_server import ToolServer from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger import init_logger from vllm.logprobs import Logprob as SampleLogprob @@ -168,6 +177,12 @@ class OpenAIServingResponses(OpenAIServing): # never remove messages from the store. 
self.msg_store: dict[str, list[ChatCompletionMessageParam]] = {} + # HACK(wuhang): This is a hack. We should use a better store. + # FIXME: If enable_store=True, this may cause a memory leak since we + # never remove events from the store. + self.event_store: dict[str, tuple[deque[StreamingResponsesResponse], + asyncio.Event]] = {} + self.background_tasks: dict[str, asyncio.Task] = {} self.tool_server = tool_server @@ -176,7 +191,8 @@ class OpenAIServingResponses(OpenAIServing): self, request: ResponsesRequest, raw_request: Optional[Request] = None, - ) -> Union[AsyncGenerator[str, None], ResponsesResponse, ErrorResponse]: + ) -> Union[AsyncGenerator[StreamingResponsesResponse, None], + ResponsesResponse, ErrorResponse]: error_check_ret = await self._check_model(request) if error_check_ret is not None: logger.error("Error with model %s", error_check_ret) @@ -228,8 +244,8 @@ class OpenAIServingResponses(OpenAIServing): try: lora_request = self._maybe_get_adapters(request) - model_name = self._get_model_name(request.model, lora_request) - tokenizer = await self.engine_client.get_tokenizer(lora_request) + model_name = self.models.model_name(lora_request) + tokenizer = await self.engine_client.get_tokenizer() if self.use_harmony: messages, request_prompts, engine_prompts = ( @@ -249,15 +265,6 @@ class OpenAIServingResponses(OpenAIServing): if raw_request: raw_request.state.request_metadata = request_metadata - if self.tool_server is not None and isinstance( - self.tool_server, - MCPToolServer) and request.stream and request.tools and any( - tool.type in ["web_search_preview", "code_interpreter"] - for tool in request.tools): - return self.create_error_response( - "MCP tool server is not supported in background mode and " - "streaming mode") - # Schedule the request and get the result generator. generators: list[AsyncGenerator[ConversationContext, None]] = [] @@ -267,6 +274,8 @@ class OpenAIServingResponses(OpenAIServing): builtin_tool_list.append("browser") if self.tool_server.has_tool("python"): builtin_tool_list.append("python") + if self.tool_server.has_tool("container"): + builtin_tool_list.append("container") if self.tool_server is not None: available_tools = builtin_tool_list @@ -329,25 +338,44 @@ class OpenAIServingResponses(OpenAIServing): self.response_store[response.id] = response # Run the request in the background. - task = asyncio.create_task( - self._run_background_request( - request, - sampling_params, - result_generator, - context, - model_name, - tokenizer, - request_metadata, - created_time, - ), - name=f"create_{response.id}", - ) + if request.stream: + task = asyncio.create_task( + self._run_background_request_stream( + request, + sampling_params, + result_generator, + context, + model_name, + tokenizer, + request_metadata, + created_time, + ), + name=f"create_{request.request_id}", + ) + else: + task = asyncio.create_task( + self._run_background_request( + request, + sampling_params, + result_generator, + context, + model_name, + tokenizer, + request_metadata, + created_time, + ), + name=f"create_{response.id}", + ) # For cleanup. 
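For orientation, the event_store added above pairs a deque of streaming events with an asyncio.Event so a background task can keep producing events while a (re-)attached client replays them from any offset. The following is a minimal, self-contained sketch of that producer/consumer pattern; it is not part of the patch, and the names produce_events and replay_events are made up for illustration.

# Illustrative sketch of the deque + asyncio.Event pattern behind event_store.
import asyncio
from collections import deque

events: deque[str] = deque()
new_event = asyncio.Event()

async def produce_events() -> None:
    for i in range(3):
        events.append(f"response.output_text.delta #{i}")
        new_event.set()            # signal that a new event is available
        await asyncio.sleep(0)     # yield to any waiting reader
    events.append("response.completed")
    new_event.set()

async def replay_events(starting_after: int | None = None) -> None:
    # Mirrors responses_background_stream_generator: replay stored events
    # from the requested offset, then block until the producer appends more.
    index = 0 if starting_after is None else starting_after + 1
    while True:
        new_event.clear()
        while index < len(events):
            event = events[index]
            print(event)
            if event == "response.completed":
                return
            index += 1
        await new_event.wait()

async def main() -> None:
    await asyncio.gather(produce_events(), replay_events())

asyncio.run(main())

Clearing the signal before draining the deque and only then awaiting it avoids lost wakeups: if the producer appends between the drain and the wait, the event is already set and the reader resumes immediately.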
response_id = response.id self.background_tasks[response_id] = task task.add_done_callback( lambda _: self.background_tasks.pop(response_id, None)) + + if request.stream: + return self.responses_background_stream_generator( + request.request_id) return response if request.stream: @@ -430,7 +458,8 @@ class OpenAIServingResponses(OpenAIServing): async with AsyncExitStack() as exit_stack: try: - await context.init_tool_sessions(self.tool_server, exit_stack) + await context.init_tool_sessions(self.tool_server, exit_stack, + request.request_id) async for _ in result_generator: pass except asyncio.CancelledError: @@ -439,14 +468,22 @@ class OpenAIServingResponses(OpenAIServing): # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) + # NOTE: Implementation of status is still WIP, but for now + # we guarantee that if the status is not "completed", it is accurate. + # "completed" is implemented as the "catch-all" for now. + status: ResponseStatus = "completed" + if self.use_harmony: assert isinstance(context, HarmonyContext) output = self._make_response_output_items_with_harmony(context) - # TODO: these are all 0 for now! - num_prompt_tokens = context.num_prompt_tokens - num_generated_tokens = context.num_output_tokens - num_cached_tokens = context.num_cached_tokens - num_reasoning_tokens = context.num_reasoning_tokens + num_tool_output_tokens = context.num_tool_output_tokens + if len(output) > 0: + if context.finish_reason == "length": + status = "incomplete" + elif context.finish_reason == "abort": + status = "cancelled" + else: + status = "incomplete" else: assert isinstance(context, SimpleContext) final_res = context.last_output @@ -459,10 +496,13 @@ class OpenAIServingResponses(OpenAIServing): # Calculate usage. assert final_res.prompt_token_ids is not None - num_prompt_tokens = len(final_res.prompt_token_ids) - num_generated_tokens = len(final_output.token_ids) - num_cached_tokens = final_res.num_cached_tokens - num_reasoning_tokens = 0 + num_tool_output_tokens = 0 + + assert isinstance(context, (SimpleContext, HarmonyContext)) + num_prompt_tokens = context.num_prompt_tokens + num_generated_tokens = context.num_output_tokens + num_cached_tokens = context.num_cached_tokens + num_reasoning_tokens = context.num_reasoning_tokens usage = ResponseUsage( input_tokens=num_prompt_tokens, @@ -471,7 +511,8 @@ class OpenAIServingResponses(OpenAIServing): input_tokens_details=InputTokensDetails( cached_tokens=num_cached_tokens), output_tokens_details=OutputTokensDetails( - reasoning_tokens=num_reasoning_tokens), + reasoning_tokens=num_reasoning_tokens, + tool_output_tokens=num_tool_output_tokens), ) response = ResponsesResponse.from_request( request, @@ -479,7 +520,7 @@ class OpenAIServingResponses(OpenAIServing): model_name=model_name, created_time=created_time, output=output, - status="completed", + status=status, usage=usage, ) @@ -537,6 +578,28 @@ class OpenAIServingResponses(OpenAIServing): )) return out + def _create_stream_response_logprobs( + self, + token_ids: Sequence[int], + logprobs: Optional[SampleLogprobs], + tokenizer: AnyTokenizer, + top_logprobs: Optional[int] = None + ) -> list[response_text_delta_event.Logprob]: + lgs = self._create_response_logprobs(token_ids=token_ids, + logprobs=logprobs, + tokenizer=tokenizer, + top_logprobs=top_logprobs) + return [ + response_text_delta_event.Logprob( + token=lg.token, + logprob=lg.logprob, + top_logprobs=[ + response_text_delta_event.LogprobTopLogprob( + token=tl.token, logprob=tl.logprob) + for tl in
lg.top_logprobs + ]) for lg in lgs + ] + def _make_response_output_items( self, request: ResponsesRequest, @@ -614,7 +677,7 @@ class OpenAIServingResponses(OpenAIServing): self, context: HarmonyContext, ) -> list[ResponseOutputItem]: - output_items = [] + output_items: list[ResponseOutputItem] = [] num_init_messages = context.num_init_messages for msg in context.messages[num_init_messages:]: output_items.extend(parse_output_message(msg)) @@ -670,13 +733,21 @@ class OpenAIServingResponses(OpenAIServing): # New conversation. reasoning_effort = (request.reasoning.effort if request.reasoning else None) + # Temporary: OpenAI types doesn't have container tool + # so we used MCP to cover that, up for change tool_types = [tool.type for tool in request.tools] + if envs.VLLM_GPT_OSS_USE_CONTAINER_TOOL: + tool_types.append("container") enable_browser = ("web_search_preview" in tool_types and self.tool_server is not None and self.tool_server.has_tool("browser")) enable_code_interpreter = ("code_interpreter" in tool_types and self.tool_server is not None and self.tool_server.has_tool("python")) + enable_container = ("container" in tool_types + and self.tool_server is not None + and self.tool_server.has_tool("container")) + with_custom_tools = has_custom_tools(tool_types) sys_msg = get_system_message( reasoning_effort=reasoning_effort, browser_description=self.tool_server.get_tool_description( @@ -685,11 +756,17 @@ class OpenAIServingResponses(OpenAIServing): python_description=self.tool_server.get_tool_description( "python") if enable_code_interpreter and self.tool_server is not None else None, + container_description=self.tool_server.get_tool_description( + "container") + if enable_container and self.tool_server is not None else None, + instructions=request.instructions, + with_custom_tools=with_custom_tools, ) messages.append(sys_msg) - dev_msg = get_developer_message(request.instructions, - request.tools) - messages.append(dev_msg) + if with_custom_tools: + dev_msg = get_developer_message( + instructions=request.instructions, tools=request.tools) + messages.append(dev_msg) else: # Continue the previous conversation. # FIXME(woosuk): Currently, request params like reasoning and @@ -717,7 +794,7 @@ class OpenAIServingResponses(OpenAIServing): prev_msgs.append(msg) messages.extend(prev_msgs) # Append the new input. - # Reponses API supports simple text inputs without chat format. + # Responses API supports simple text inputs without chat format. if isinstance(request.input, str): messages.append(get_user_message(request.input)) else: @@ -736,6 +813,38 @@ class OpenAIServingResponses(OpenAIServing): prev_outputs.append(response_msg) return messages + async def _run_background_request_stream( + self, + request: ResponsesRequest, + *args, + **kwargs, + ): + event_deque: deque[StreamingResponsesResponse] = deque() + new_event_signal = asyncio.Event() + self.event_store[request.request_id] = (event_deque, new_event_signal) + response = None + try: + generator = self.responses_stream_generator( + request, *args, **kwargs) + async for event in generator: + event_deque.append(event) + new_event_signal.set() # Signal new event available + except Exception as e: + logger.exception("Background request failed for %s", + request.request_id) + response = self.create_error_response(str(e)) + finally: + new_event_signal.set() + + if response is not None and isinstance(response, ErrorResponse): + # If the request has failed, update the status to "failed". 
+ response_id = request.request_id + async with self.response_store_lock: + stored_response = self.response_store.get(response_id) + assert stored_response is not None + if stored_response.status not in ("completed", "cancelled"): + stored_response.status = "failed" + async def _run_background_request( self, request: ResponsesRequest, @@ -759,10 +868,38 @@ class OpenAIServingResponses(OpenAIServing): if stored_response.status not in ("completed", "cancelled"): stored_response.status = "failed" + async def responses_background_stream_generator( + self, + response_id: str, + starting_after: Optional[int] = None, + ) -> AsyncGenerator[StreamingResponsesResponse, None]: + if response_id not in self.event_store: + raise ValueError(f"Unknown response_id: {response_id}") + + event_deque, new_event_signal = self.event_store[response_id] + start_index = 0 if starting_after is None else starting_after + 1 + current_index = start_index + + while True: + new_event_signal.clear() + + # Yield existing events from start_index + while current_index < len(event_deque): + event = event_deque[current_index] + yield event + if getattr(event, 'type', 'unknown') == "response.completed": + return + current_index += 1 + + await new_event_signal.wait() + async def retrieve_responses( self, response_id: str, - ) -> Union[ErrorResponse, ResponsesResponse]: + starting_after: Optional[int], + stream: Optional[bool], + ) -> Union[ErrorResponse, ResponsesResponse, AsyncGenerator[ + StreamingResponsesResponse, None]]: if not response_id.startswith("resp_"): return self._make_invalid_id_error(response_id) @@ -771,6 +908,12 @@ class OpenAIServingResponses(OpenAIServing): if response is None: return self._make_not_found_error(response_id) + + if stream: + return self.responses_background_stream_generator( + response_id, + starting_after, + ) return response async def cancel_responses( @@ -829,7 +972,7 @@ class OpenAIServingResponses(OpenAIServing): status_code=HTTPStatus.BAD_REQUEST, ) - async def _process_streaming_events( + async def _process_simple_streaming_events( self, request: ResponsesRequest, sampling_params: SamplingParams, @@ -839,46 +982,289 @@ class OpenAIServingResponses(OpenAIServing): tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, created_time: int, - ) -> AsyncGenerator[str, None]: - sequence_number = 0 - - def _send_event(event: BaseModel): - nonlocal sequence_number - # Set sequence_number if the event has this attribute - if hasattr(event, 'sequence_number'): - event.sequence_number = sequence_number - sequence_number += 1 - # Get event type from the event's type field if it exists - event_type = getattr(event, 'type', 'unknown') - return (f"event: {event_type}\n" - f"data: {event.model_dump_json(indent=None)}\n\n") - - current_content_index = 0 # FIXME: this number is never changed + _increment_sequence_number_and_return: Callable[ + [StreamingResponsesResponse], StreamingResponsesResponse], + ) -> AsyncGenerator[StreamingResponsesResponse, None]: + current_content_index = 0 current_output_index = 0 - current_item_id = "" # FIXME: this number is never changed - sent_output_item_added = False + current_item_id = "" + reasoning_parser = None + if self.reasoning_parser: + reasoning_parser = self.reasoning_parser(tokenizer) + previous_text = "" + previous_token_ids: list[int] = [] + first_delta_sent = False + previous_delta_messages: list[DeltaMessage] = [] + async for ctx in result_generator: + assert isinstance(ctx, SimpleContext) + if ctx.last_output is None: + continue + if 
ctx.last_output.outputs: + output = ctx.last_output.outputs[0] + if reasoning_parser: + delta_message = \ + reasoning_parser.extract_reasoning_content_streaming( + previous_text=previous_text, + current_text=previous_text + output.text, + delta_text=output.text, + previous_token_ids=previous_token_ids, + current_token_ids=previous_token_ids + + output.token_ids, + delta_token_ids=output.token_ids, + ) + else: + delta_message = DeltaMessage(content=output.text, ) + previous_text += output.text + previous_token_ids += output.token_ids + if not delta_message: + continue + if not first_delta_sent: + current_item_id = str(uuid.uuid4()) + if delta_message.reasoning_content: + yield _increment_sequence_number_and_return( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=ResponseReasoningItem( + type="reasoning", + id=current_item_id, + summary=[], + status="in_progress", + ), + )) + else: + yield _increment_sequence_number_and_return( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=ResponseOutputMessage( + id=current_item_id, + type="message", + role="assistant", + content=[], + status="in_progress", + ), + )) + yield _increment_sequence_number_and_return( + ResponseContentPartAddedEvent( + type="response.content_part.added", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + content_index=current_content_index, + part=ResponseOutputText( + type="output_text", + text="", + annotations=[], + logprobs=[], + ), + )) + current_content_index += 1 + first_delta_sent = True + # todo(kebe7jun) tool call support - initial_response = ResponsesResponse.from_request( - request, - sampling_params, - model_name=model_name, - created_time=created_time, - output=[], - status="in_progress", - usage=None, - ).model_dump() - yield _send_event( - ResponseCreatedEvent( - type="response.created", - sequence_number=-1, - response=initial_response, - )) - yield _send_event( - ResponseInProgressEvent( - type="response.in_progress", - sequence_number=-1, - response=initial_response, - )) + # check delta message and previous delta message are + # same as content or reasoning content + if (previous_delta_messages + and previous_delta_messages[-1].reasoning_content + is not None and delta_message.content is not None): + # from reasoning to normal content, send done + # event for reasoning + reason_content = ''.join( + pm.reasoning_content for pm in previous_delta_messages + if pm.reasoning_content is not None) + yield _increment_sequence_number_and_return( + ResponseReasoningTextDoneEvent( + type="response.reasoning_text.done", + item_id=current_item_id, + sequence_number=-1, + output_index=current_output_index, + content_index=current_content_index, + text=reason_content, + )) + current_content_index = 0 + reasoning_item = ResponseReasoningItem( + type="reasoning", + content=[ + ResponseReasoningTextContent( + text=reason_content, + type="reasoning_text", + ), + ], + status="completed", + id=current_item_id, + summary=[], + ) + yield _increment_sequence_number_and_return( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=reasoning_item, + )) + yield _increment_sequence_number_and_return( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + 
item=ResponseOutputMessage( + id=current_item_id, + type="message", + role="assistant", + content=[], + status="in_progress", + ), + )) + current_output_index += 1 + current_item_id = str(uuid.uuid4()) + yield _increment_sequence_number_and_return( + ResponseContentPartAddedEvent( + type="response.content_part.added", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + content_index=current_content_index, + part=ResponseOutputText( + type="output_text", + text="", + annotations=[], + logprobs=[], + ), + )) + current_content_index += 1 + # reset previous delta messages + previous_delta_messages = [] + + if delta_message.reasoning_content is not None: + yield _increment_sequence_number_and_return( + ResponseReasoningTextDeltaEvent( + type="response.reasoning_text.delta", + sequence_number=-1, + content_index=current_content_index, + output_index=current_output_index, + item_id=current_item_id, + delta=delta_message.reasoning_content, + )) + elif delta_message.content is not None: + yield _increment_sequence_number_and_return( + ResponseTextDeltaEvent( + type="response.output_text.delta", + sequence_number=-1, + content_index=current_content_index, + output_index=current_output_index, + item_id=current_item_id, + delta=delta_message.content, + logprobs=self._create_stream_response_logprobs( + token_ids=output.token_ids, + logprobs=output.logprobs, + tokenizer=tokenizer, + top_logprobs=request.top_logprobs, + ) if request.is_include_output_logprobs() else [], + )) + current_content_index += 1 + + previous_delta_messages.append(delta_message) + if previous_delta_messages: + if previous_delta_messages[-1].reasoning_content is not None: + reason_content = ''.join(pm.reasoning_content + for pm in previous_delta_messages + if pm.reasoning_content is not None) + yield _increment_sequence_number_and_return( + ResponseReasoningTextDoneEvent( + type="response.reasoning_text.done", + item_id=current_item_id, + sequence_number=-1, + output_index=current_output_index, + content_index=current_content_index, + text=reason_content, + )) + current_content_index += 1 + reasoning_item = ResponseReasoningItem( + type="reasoning", + content=[ + ResponseReasoningTextContent( + text=reason_content, + type="reasoning_text", + ), + ], + status="completed", + id=current_item_id, + summary=[], + ) + yield _increment_sequence_number_and_return( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=reasoning_item, + )) + elif previous_delta_messages[-1].content is not None: + final_content = ''.join(pm.content + for pm in previous_delta_messages + if pm.content is not None) + yield _increment_sequence_number_and_return( + ResponseTextDoneEvent( + type="response.output_text.done", + sequence_number=-1, + output_index=current_output_index, + content_index=current_content_index, + text=final_content, + logprobs=[], + item_id=current_item_id, + )) + current_content_index += 1 + part = ResponseOutputText( + text=final_content, + type="output_text", + annotations=[], + ) + yield _increment_sequence_number_and_return( + ResponseContentPartDoneEvent( + type="response.content_part.done", + sequence_number=-1, + item_id=current_item_id, + output_index=current_output_index, + content_index=current_content_index, + part=part, + )) + current_content_index += 1 + item = ResponseOutputMessage( + type="message", + role="assistant", + content=[ + part, + ], + status="completed", + id=current_item_id, + summary=[], + ) + 
yield _increment_sequence_number_and_return( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=item, + )) + + async def _process_harmony_streaming_events( + self, + request: ResponsesRequest, + sampling_params: SamplingParams, + result_generator: AsyncIterator[Optional[ConversationContext]], + context: ConversationContext, + model_name: str, + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + created_time: int, + _increment_sequence_number_and_return: Callable[ + [StreamingResponsesResponse], StreamingResponsesResponse], + ) -> AsyncGenerator[StreamingResponsesResponse, None]: + current_content_index = -1 + current_output_index = 0 + current_item_id: str = "" + sent_output_item_added = False async for ctx in result_generator: @@ -906,7 +1292,7 @@ class OpenAIServingResponses(OpenAIServing): id=current_item_id, summary=[], ) - yield _send_event( + yield _increment_sequence_number_and_return( ResponseReasoningTextDoneEvent( type="response.reasoning_text.done", item_id=current_item_id, @@ -915,7 +1301,7 @@ class OpenAIServingResponses(OpenAIServing): content_index=current_content_index, text=previous_item.content[0].text, )) - yield _send_event( + yield _increment_sequence_number_and_return( ResponseOutputItemDoneEvent( type="response.output_item.done", sequence_number=-1, @@ -928,8 +1314,8 @@ class OpenAIServingResponses(OpenAIServing): text=previous_item.content[0].text, annotations=[], ) - yield _send_event( - openai_responses_types.ResponseTextDoneEvent( + yield _increment_sequence_number_and_return( + ResponseTextDoneEvent( type="response.output_text.done", sequence_number=-1, output_index=current_output_index, @@ -938,8 +1324,7 @@ class OpenAIServingResponses(OpenAIServing): logprobs=[], item_id=current_item_id, )) - yield _send_event( - openai_responses_types. + yield _increment_sequence_number_and_return( ResponseContentPartDoneEvent( type="response.content_part.done", sequence_number=-1, @@ -948,8 +1333,8 @@ class OpenAIServingResponses(OpenAIServing): content_index=current_content_index, part=text_content, )) - yield _send_event( - openai_responses_types.ResponseOutputItemDoneEvent( + yield _increment_sequence_number_and_return( + ResponseOutputItemDoneEvent( type="response.output_item.done", sequence_number=-1, output_index=current_output_index, @@ -962,19 +1347,19 @@ class OpenAIServingResponses(OpenAIServing): ), )) + # stream the output of a harmony message if ctx.parser.last_content_delta: if (ctx.parser.current_channel == "final" and ctx.parser.current_recipient is None): if not sent_output_item_added: sent_output_item_added = True - yield _send_event( - openai_responses_types. + current_item_id = f"msg_{random_uuid()}" + yield _increment_sequence_number_and_return( ResponseOutputItemAddedEvent( type="response.output_item.added", sequence_number=-1, output_index=current_output_index, - item=openai_responses_types. - ResponseOutputMessage( + item=ResponseOutputMessage( id=current_item_id, type="message", role="assistant", @@ -982,23 +1367,23 @@ class OpenAIServingResponses(OpenAIServing): status="in_progress", ), )) - yield _send_event( - openai_responses_types. 
+ current_content_index += 1 + yield _increment_sequence_number_and_return( ResponseContentPartAddedEvent( type="response.content_part.added", sequence_number=-1, output_index=current_output_index, item_id=current_item_id, content_index=current_content_index, - part=openai_responses_types.ResponseOutputText( + part=ResponseOutputText( type="output_text", text="", annotations=[], logprobs=[], ), )) - yield _send_event( - openai_responses_types.ResponseTextDeltaEvent( + yield _increment_sequence_number_and_return( + ResponseTextDeltaEvent( type="response.output_text.delta", sequence_number=-1, content_index=current_content_index, @@ -1012,36 +1397,35 @@ class OpenAIServingResponses(OpenAIServing): and ctx.parser.current_recipient is None): if not sent_output_item_added: sent_output_item_added = True - yield _send_event( - openai_responses_types. + current_item_id = f"msg_{random_uuid()}" + yield _increment_sequence_number_and_return( ResponseOutputItemAddedEvent( type="response.output_item.added", sequence_number=-1, output_index=current_output_index, - item=openai_responses_types. - ResponseReasoningItem( + item=ResponseReasoningItem( type="reasoning", id=current_item_id, summary=[], status="in_progress", ), )) - yield _send_event( - openai_responses_types. + current_content_index += 1 + yield _increment_sequence_number_and_return( ResponseContentPartAddedEvent( type="response.content_part.added", sequence_number=-1, output_index=current_output_index, item_id=current_item_id, content_index=current_content_index, - part=openai_responses_types.ResponseOutputText( + part=ResponseOutputText( type="output_text", text="", annotations=[], logprobs=[], ), )) - yield _send_event( + yield _increment_sequence_number_and_return( ResponseReasoningTextDeltaEvent( type="response.reasoning_text.delta", item_id=current_item_id, @@ -1058,14 +1442,13 @@ class OpenAIServingResponses(OpenAIServing): ) and ctx.parser.current_recipient == "python": if not sent_output_item_added: sent_output_item_added = True - yield _send_event( - openai_responses_types. + current_item_id = f"tool_{random_uuid()}" + yield _increment_sequence_number_and_return( ResponseOutputItemAddedEvent( type="response.output_item.added", sequence_number=-1, output_index=current_output_index, - item=openai_responses_types. - ResponseCodeInterpreterToolCallParam( + item=ResponseCodeInterpreterToolCallParam( type="code_interpreter_call", id=current_item_id, code=None, @@ -1074,8 +1457,7 @@ class OpenAIServingResponses(OpenAIServing): status="in_progress", ), )) - yield _send_event( - openai_responses_types. + yield _increment_sequence_number_and_return( ResponseCodeInterpreterCallInProgressEvent( type= "response.code_interpreter_call.in_progress", @@ -1083,8 +1465,7 @@ class OpenAIServingResponses(OpenAIServing): output_index=current_output_index, item_id=current_item_id, )) - yield _send_event( - openai_responses_types. 
+ yield _increment_sequence_number_and_return( ResponseCodeInterpreterCallCodeDeltaEvent( type="response.code_interpreter_call_code.delta", sequence_number=-1, @@ -1092,6 +1473,8 @@ class OpenAIServingResponses(OpenAIServing): item_id=current_item_id, delta=ctx.parser.last_content_delta, )) + + # stream tool call outputs if ctx.is_assistant_action_turn() and len(ctx.parser.messages) > 0: previous_item = ctx.parser.messages[-1] if (self.tool_server is not None @@ -1102,14 +1485,12 @@ class OpenAIServingResponses(OpenAIServing): action = None parsed_args = json.loads(previous_item.content[0].text) if function_name == "search": - action = (openai_responses_types. - response_function_web_search.ActionSearch( - type="search", - query=parsed_args["query"], - )) + action = (response_function_web_search.ActionSearch( + type="search", + query=parsed_args["query"], + )) elif function_name == "open": action = ( - openai_responses_types. response_function_web_search.ActionOpenPage( type="open_page", # TODO: translate to url @@ -1117,7 +1498,6 @@ class OpenAIServingResponses(OpenAIServing): )) elif function_name == "find": action = ( - openai_responses_types. response_function_web_search.ActionFind( type="find", pattern=parsed_args["pattern"], @@ -1128,13 +1508,13 @@ class OpenAIServingResponses(OpenAIServing): raise ValueError( f"Unknown function name: {function_name}") - yield _send_event( - openai_responses_types.ResponseOutputItemAddedEvent( + current_item_id = f"tool_{random_uuid()}" + yield _increment_sequence_number_and_return( + ResponseOutputItemAddedEvent( type="response.output_item.added", sequence_number=-1, output_index=current_output_index, - item=openai_responses_types. - response_function_web_search. + item=response_function_web_search. ResponseFunctionWebSearch( # TODO: generate a unique id for web search call type="web_search_call", @@ -1143,16 +1523,14 @@ class OpenAIServingResponses(OpenAIServing): status="in_progress", ), )) - yield _send_event( - openai_responses_types. + yield _increment_sequence_number_and_return( ResponseWebSearchCallInProgressEvent( type="response.web_search_call.in_progress", sequence_number=-1, output_index=current_output_index, item_id=current_item_id, )) - yield _send_event( - openai_responses_types. + yield _increment_sequence_number_and_return( ResponseWebSearchCallSearchingEvent( type="response.web_search_call.searching", sequence_number=-1, @@ -1161,21 +1539,19 @@ class OpenAIServingResponses(OpenAIServing): )) # enqueue - yield _send_event( - openai_responses_types. + yield _increment_sequence_number_and_return( ResponseWebSearchCallCompletedEvent( type="response.web_search_call.completed", sequence_number=-1, output_index=current_output_index, item_id=current_item_id, )) - yield _send_event( - openai_responses_types.ResponseOutputItemDoneEvent( + yield _increment_sequence_number_and_return( + ResponseOutputItemDoneEvent( type="response.output_item.done", sequence_number=-1, output_index=current_output_index, - item=openai_responses_types. - ResponseFunctionWebSearch( + item=ResponseFunctionWebSearch( type="web_search_call", id=current_item_id, action=action, @@ -1187,8 +1563,7 @@ class OpenAIServingResponses(OpenAIServing): and self.tool_server.has_tool("python") and previous_item.recipient is not None and previous_item.recipient.startswith("python")): - yield _send_event( - openai_responses_types. 
+ yield _increment_sequence_number_and_return( ResponseCodeInterpreterCallCodeDoneEvent( type="response.code_interpreter_call_code.done", sequence_number=-1, @@ -1196,29 +1571,26 @@ class OpenAIServingResponses(OpenAIServing): item_id=current_item_id, code=previous_item.content[0].text, )) - yield _send_event( - openai_responses_types. + yield _increment_sequence_number_and_return( ResponseCodeInterpreterCallInterpretingEvent( type="response.code_interpreter_call.interpreting", sequence_number=-1, output_index=current_output_index, item_id=current_item_id, )) - yield _send_event( - openai_responses_types. + yield _increment_sequence_number_and_return( ResponseCodeInterpreterCallCompletedEvent( type="response.code_interpreter_call.completed", sequence_number=-1, output_index=current_output_index, item_id=current_item_id, )) - yield _send_event( - openai_responses_types.ResponseOutputItemDoneEvent( + yield _increment_sequence_number_and_return( + ResponseOutputItemDoneEvent( type="response.output_item.done", sequence_number=-1, output_index=current_output_index, - item=openai_responses_types. - ResponseCodeInterpreterToolCallParam( + item=ResponseCodeInterpreterToolCallParam( type="code_interpreter_call", id=current_item_id, code=previous_item.content[0].text, @@ -1229,29 +1601,6 @@ class OpenAIServingResponses(OpenAIServing): ), )) - async def empty_async_generator(): - # A hack to trick Python to think this is a generator but in fact - # it immediately returns. - if False: - yield - - final_response = await self.responses_full_generator( - request, - sampling_params, - empty_async_generator(), - context, - model_name, - tokenizer, - request_metadata, - created_time=created_time, - ) - yield _send_event( - openai_responses_types.ResponseCompletedEvent( - type="response.completed", - sequence_number=-1, - response=final_response.model_dump(), - )) - async def responses_stream_generator( self, request: ResponsesRequest, @@ -1262,20 +1611,80 @@ class OpenAIServingResponses(OpenAIServing): tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, created_time: Optional[int] = None, - ) -> AsyncGenerator[str, None]: + ) -> AsyncGenerator[StreamingResponsesResponse, None]: # TODO: # 1. Handle disconnect - if not isinstance(context, StreamingHarmonyContext): - raise NotImplementedError( - "Streaming is not supported for responses API without Harmony." 
- ) - created_time = created_time or int(time.time()) + sequence_number = 0 + + def _increment_sequence_number_and_return( + event: StreamingResponsesResponse + ) -> StreamingResponsesResponse: + nonlocal sequence_number + # Set sequence_number if the event has this attribute + if hasattr(event, 'sequence_number'): + event.sequence_number = sequence_number + sequence_number += 1 + return event + async with AsyncExitStack() as exit_stack: - await context.init_tool_sessions(self.tool_server, exit_stack) - async for event_data in self._process_streaming_events( + processer = None + if self.use_harmony: + await context.init_tool_sessions(self.tool_server, exit_stack, + request.request_id) + processer = self._process_harmony_streaming_events + else: + processer = self._process_simple_streaming_events + + initial_response = ResponsesResponse.from_request( + request, + sampling_params, + model_name=model_name, + created_time=created_time, + output=[], + status="in_progress", + usage=None, + ).model_dump() + yield _increment_sequence_number_and_return( + ResponseCreatedEvent( + type="response.created", + sequence_number=-1, + response=initial_response, + )) + yield _increment_sequence_number_and_return( + ResponseInProgressEvent( + type="response.in_progress", + sequence_number=-1, + response=initial_response, + )) + + async for event_data in processer( request, sampling_params, result_generator, context, - model_name, tokenizer, request_metadata, created_time): + model_name, tokenizer, request_metadata, created_time, + _increment_sequence_number_and_return): yield event_data + + async def empty_async_generator(): + # A hack to trick Python to think this is a generator but + # in fact it immediately returns. + if False: + yield + + final_response = await self.responses_full_generator( + request, + sampling_params, + empty_async_generator(), + context, + model_name, + tokenizer, + request_metadata, + created_time=created_time, + ) + yield _increment_sequence_number_and_return( + ResponseCompletedEvent( + type="response.completed", + sequence_number=-1, + response=final_response.model_dump(), + )) diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 847c014a11dc3..623b1c863f779 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -269,7 +269,7 @@ class ServingScores(OpenAIServing): ) -> Union[list[PoolingRequestOutput], ErrorResponse]: lora_request = self._maybe_get_adapters(request) - tokenizer = await self.engine_client.get_tokenizer(lora_request) + tokenizer = await self.engine_client.get_tokenizer() truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None) @@ -353,7 +353,7 @@ class ServingScores(OpenAIServing): final_res_batch, request_id, created_time, - self._get_model_name(request.model), + self.models.model_name(), ) except asyncio.CancelledError: return self.create_error_response("Client disconnected") @@ -399,7 +399,7 @@ class ServingScores(OpenAIServing): return self.request_output_to_rerank_response( final_res_batch, request_id, - self._get_model_name(request.model), + self.models.model_name(), documents, top_n, ) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 2f258255d5f16..3918d08ebf81d 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -22,6 +22,7 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest, # yapf: enable from 
vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.renderer import RenderConfig from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -64,14 +65,15 @@ class OpenAIServingTokenization(OpenAIServing): try: lora_request = self._maybe_get_adapters(request) - tokenizer = await self.engine_client.get_tokenizer(lora_request) + tokenizer = await self.engine_client.get_tokenizer() + renderer = self._get_renderer(tokenizer) if isinstance(request, TokenizeChatRequest): tool_dicts = (None if request.tools is None else [tool.model_dump() for tool in request.tools]) ( _, - request_prompts, + _, engine_prompts, ) = await self._preprocess_chat( request, @@ -87,21 +89,18 @@ class OpenAIServingTokenization(OpenAIServing): add_special_tokens=request.add_special_tokens, ) else: - (request_prompts, - engine_prompts) = await self._preprocess_completion( - request, - tokenizer, - request.prompt, - add_special_tokens=request.add_special_tokens, - ) + engine_prompts = await renderer.render_prompt( + prompt_or_prompts=request.prompt, + config=self._build_render_config(request), + ) except (ValueError, TypeError, jinja2.TemplateError) as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(f"{e} {e.__cause__}") input_ids: list[int] = [] - for i, engine_prompt in enumerate(engine_prompts): + for engine_prompt in engine_prompts: self._log_inputs(request_id, - request_prompts[i], + engine_prompt, params=None, lora_request=lora_request) @@ -131,7 +130,7 @@ class OpenAIServingTokenization(OpenAIServing): lora_request = self._maybe_get_adapters(request) - tokenizer = await self.engine_client.get_tokenizer(lora_request) + tokenizer = await self.engine_client.get_tokenizer() self._log_inputs(request_id, request.tokens, @@ -158,6 +157,9 @@ class OpenAIServingTokenization(OpenAIServing): return self.create_error_response( f"Failed to get tokenizer info: {str(e)}") + def _build_render_config(self, request: TokenizeRequest) -> RenderConfig: + return RenderConfig(add_special_tokens=request.add_special_tokens) + @dataclass class TokenizerInfo: diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 44aa1208a54c7..35096b0461361 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -16,6 +16,7 @@ from .llama4_pythonic_tool_parser import Llama4PythonicToolParser from .llama_tool_parser import Llama3JsonToolParser from .minimax_tool_parser import MinimaxToolParser from .mistral_tool_parser import MistralToolParser +from .openai_tool_parser import OpenAIToolParser from .phi4mini_tool_parser import Phi4MiniJsonToolParser from .pythonic_tool_parser import PythonicToolParser from .qwen3coder_tool_parser import Qwen3CoderToolParser @@ -46,4 +47,5 @@ __all__ = [ "Qwen3CoderToolParser", "SeedOssToolParser", "Step3ToolParser", + "OpenAIToolParser", ] diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py index 6ef8fadf59ac5..37c360145b04a 100644 --- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -35,7 +35,7 @@ class Internlm2ToolParser(ToolParser): self, request: ChatCompletionRequest) -> ChatCompletionRequest: if request.tools and request.tool_choice != 'none': 
# do not skip special tokens because internlm use the special - # tokens to indicated the start and end of the tool calls + # tokens to indicate the start and end of the tool calls # information. request.skip_special_tokens = False return request @@ -60,8 +60,8 @@ class Internlm2ToolParser(ToolParser): if '<|action_start|>' not in current_text: self.position = len(current_text) return DeltaMessage(content=delta_text) - # if the tool call is sended, return a empty delta message - # to make sure the finish_reason will be send correctly. + # if the tool call is sent, return an empty delta message + # to make sure the finish_reason will be sent correctly. if self.current_tool_id > 0: return DeltaMessage(content='') @@ -89,7 +89,7 @@ class Internlm2ToolParser(ToolParser): try: parsable_arr = action - # tool calls are generated in an object in inernlm2 + # tool calls are generated in an object in internlm2 # it's not support parallel tool calls try: tool_call_arr: dict = partial_json_parser.loads( diff --git a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py index 6bf44a4345a9d..9a9a19ce2188e 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py @@ -176,7 +176,7 @@ class Llama4PythonicToolParser(ToolParser): index] += delta.function.arguments # HACK: serving_chat.py inspects the internal state of tool parsers - # when determining it's final streaming delta, automatically + # when determining its final streaming delta, automatically # adding autocompleted JSON. # These two lines avoid that nonsense while ensuring finish_reason # is set to tool_calls when at least one tool is called. diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index c0691f122904e..e6b300fd84e94 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -143,7 +143,7 @@ class MistralToolParser(ToolParser): except json.JSONDecodeError: # use a regex to find the part corresponding to the tool call. # NOTE: This use case should not happen if the model is trained - # correctly. It's a easy possible fix so it's included, but + # correctly. 
It's an easy possible fix so it's included, but # can be brittle for very complex / highly nested tool calls raw_tool_call = self.tool_call_regex.findall(tool_content)[0] function_call_arr = json.loads(raw_tool_call) diff --git a/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py new file mode 100644 index 0000000000000..c5d59514b9445 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING + +from vllm.entrypoints.harmony_utils import parse_output_into_messages +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) + +if TYPE_CHECKING: + from vllm.transformers_utils.tokenizer import AnyTokenizer + + +@ToolParserManager.register_module("openai") +class OpenAIToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + token_ids: Sequence[int] | None = None, + ) -> ExtractedToolCallInformation: + if token_ids is None: + raise NotImplementedError( + "OpenAIToolParser requires token IDs and does not support text-based extraction." # noqa: E501 + ) + + parser = parse_output_into_messages(token_ids) + tool_calls = [] + final_content = None + + if len(parser.messages) > 0: + for msg in parser.messages: + if msg.recipient and msg.recipient.startswith("functions."): + tool_calls.append( + ToolCall( + type="function", + function=FunctionCall( + name=msg.recipient.split("functions.")[1], + arguments=msg.content[0].text, + ), + )) + elif msg.channel == "final": + final_content = msg.content[0].text + + return ExtractedToolCallInformation( + tools_called=len(tool_calls) > 0, + tool_calls=tool_calls, + content=final_content, + ) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> DeltaMessage | None: + raise NotImplementedError( + "Not being used, manual parsing in serving_chat.py" # noqa: E501 + ) diff --git a/vllm/entrypoints/renderer.py b/vllm/entrypoints/renderer.py new file mode 100644 index 0000000000000..fb859d57be9fe --- /dev/null +++ b/vllm/entrypoints/renderer.py @@ -0,0 +1,395 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import io +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Annotated, Optional, Union + +import pybase64 +import torch +from pydantic import Field + +from vllm.config import ModelConfig +from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt +from vllm.inputs.data import TokensPrompt as EngineTokensPrompt +from vllm.inputs.parse import parse_and_batch_prompt +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import AsyncMicrobatchTokenizer + + +@dataclass(frozen=True) +class RenderConfig: + """Configuration to control how prompts are prepared.""" + + max_length: 
Optional[int] = None + """Maximum allowable total input token length. If provided, + token inputs longer than this raise ``ValueError``.""" + + truncate_prompt_tokens: Optional[int] = None + """Number of tokens to keep. ``None`` means no truncation. + ``0`` yields an empty list (and skips embeds). + ``-1`` maps to ``model_config.max_model_len``.""" + + add_special_tokens: Optional[bool] = True + """Whether to add model-specific special tokens during tokenization.""" + + cache_salt: Optional[str] = None + """String to disambiguate prefix cache entries.""" + + needs_detokenization: Optional[bool] = False + """If True, detokenize IDs back to text for inclusion in outputs.""" + + +class BaseRenderer(ABC): + """ + Base class for unified input processing and rendering. + + The Renderer serves as a unified input processor that consolidates + tokenization, chat template formatting, and multimodal input handling + into a single component. + It converts high-level API requests (OpenAI-style JSON) into token IDs and + multimodal features ready for engine consumption. + + Key responsibilities: + - Convert text prompts to token sequences with proper special tokens + - Apply chat templates and format conversations + - Handle multimodal inputs (images, audio, etc.) when applicable + - Manage prompt truncation and length validation + - Provide clean separation between API layer and engine core + """ + + def __init__( + self, + model_config: ModelConfig, + tokenizer: Optional[AnyTokenizer] = None, + ): + super().__init__() + self.model_config = model_config + self.tokenizer = tokenizer + + @abstractmethod + async def render_prompt( + self, + *, + prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]], + config: "RenderConfig", + ) -> list[EngineTokensPrompt]: + """ + Convert text or token inputs into engine-ready TokensPrompt objects. + + This method accepts text or token inputs and produces a + list of [`TokensPrompt`][vllm.inputs.data.TokensPrompt] objects + for the engine. + + Args: + prompt_or_prompts: One of: + - ``str``: Single text prompt. + - ``list[str]``: Batch of text prompts. + - ``list[int]``: Single pre-tokenized sequence. + - ``list[list[int]]``: Batch of pre-tokenized sequences. + config: Render configuration controlling how prompts are prepared + (e.g., tokenization and length handling). + + Returns: + list[EngineTokensPrompt]: Engine-ready token prompts. + + Raises: + ValueError: If input formats are invalid or length limits exceeded. + """ + raise NotImplementedError + + @abstractmethod + async def render_prompt_and_embeds( + self, + *, + prompt_or_prompts: Optional[Union[str, list[str], list[int], + list[list[int]]]] = None, + prompt_embeds: Optional[Union[bytes, list[bytes]]] = None, + config: "RenderConfig", + ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]: + """ + Convert text/token and/or base64-encoded embeddings inputs into + engine-ready prompt objects using a unified RenderConfig. + + At least one of ``prompt_or_prompts`` or ``prompt_embeds`` must be + provided and non-empty. If both are omitted or empty (e.g., empty + string and empty list), a ``ValueError`` is raised. + + Args: + prompt_or_prompts: Text or token inputs to include. + prompt_embeds: Base64-encoded bytes (or list thereof) containing a + torch-saved tensor to be used as prompt embeddings. + config: Render configuration controlling how prompts are prepared + (e.g., tokenization and length handling). 
+ + Returns: + list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]: + Engine-ready prompt objects. + + Raises: + ValueError: If both ``prompt_or_prompts`` and ``prompt_embeds`` + are omitted or empty (decoder prompt cannot be empty), or if + length limits are exceeded. + """ + raise NotImplementedError + + @classmethod + def load_prompt_embeds( + cls, + prompt_embeds: Union[bytes, list[bytes]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=0)]] = None, + cache_salt: Optional[str] = None, + ) -> list[EngineEmbedsPrompt]: + """Load and validate base64-encoded embeddings into prompt objects.""" + + def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt: + tensor = torch.load( + io.BytesIO(pybase64.b64decode(embed, validate=True)), + weights_only=True, + map_location=torch.device("cpu"), + ) + assert isinstance(tensor, torch.Tensor) and tensor.dtype in ( + torch.float32, + torch.bfloat16, + torch.float16, + ) + tensor = tensor.to_dense() + if tensor.dim() > 2: + tensor = tensor.squeeze(0) + assert tensor.dim() == 2 + if truncate_prompt_tokens is not None: + tensor = tensor[-truncate_prompt_tokens:] + embeds_prompt = EngineEmbedsPrompt(prompt_embeds=tensor) + if cache_salt is not None: + embeds_prompt["cache_salt"] = cache_salt + return embeds_prompt + + if isinstance(prompt_embeds, list): + return [_load_and_validate_embed(embed) for embed in prompt_embeds] + + return [_load_and_validate_embed(prompt_embeds)] + + +class CompletionRenderer(BaseRenderer): + + def __init__( + self, + model_config: ModelConfig, + tokenizer: Optional[AnyTokenizer] = None, + async_tokenizer_pool: Optional[dict[AnyTokenizer, + AsyncMicrobatchTokenizer]] = None, + ): + super().__init__(model_config, tokenizer) + self.async_tokenizer_pool = async_tokenizer_pool + self.async_tokenizer: Optional[AsyncMicrobatchTokenizer] = None + + async def render_prompt( + self, + *, + prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]], + config: "RenderConfig", + ) -> list[EngineTokensPrompt]: + """Implementation of prompt rendering for completion-style requests. + + Uses async tokenizer pooling for improved performance. See base class + for detailed parameter documentation. + """ + truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens( + config.truncate_prompt_tokens, config.max_length) + if truncate_prompt_tokens == 0: + return [] + + # Parse and batch the input prompts + batch_inputs = parse_and_batch_prompt(prompt_or_prompts) + + tasks = [] + for prompt_input in batch_inputs: + if prompt_input["is_tokens"] is True: + # Token input + # Note: detokenization is needed when echo is enabled, + # where the input token IDs are decoded back to text. 
+ task = self._maybe_detokenize(prompt_input["content"], + config.max_length, + truncate_prompt_tokens, + config.cache_salt, + config.needs_detokenization) + else: + # Text input + task = self._tokenize(prompt_input["content"], + config.max_length, + truncate_prompt_tokens, + config.add_special_tokens, + config.cache_salt) + tasks.append(task) + + # Wait for all text tokenization to finish + if tasks: + tokenized_text_prompts = await asyncio.gather(*tasks) + return tokenized_text_prompts + + return [] + + async def render_prompt_and_embeds( + self, + *, + prompt_or_prompts: Optional[Union[str, list[str], list[int], + list[list[int]]]] = None, + prompt_embeds: Optional[Union[bytes, list[bytes]]] = None, + config: "RenderConfig", + ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]: + """ + Render text/token prompts and/or precomputed embedding prompts. At + least one of `prompt_or_prompts` or `prompt_embeds` must be provided. + """ + truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens( + config.truncate_prompt_tokens, config.max_length) + if truncate_prompt_tokens == 0: + return [] + + rendered: list[Union[EngineTokensPrompt, EngineEmbedsPrompt]] = [] + + if prompt_embeds is not None: + rendered.extend( + self.load_prompt_embeds(prompt_embeds, truncate_prompt_tokens, + config.cache_salt)) + if prompt_or_prompts is None or prompt_or_prompts == "": + return rendered + + token_prompts = await self.render_prompt( + prompt_or_prompts=prompt_or_prompts, + config=config, + ) + rendered.extend(token_prompts) + + return rendered + + def _validate_and_normalize_truncate_tokens( + self, + truncate_prompt_tokens: Optional[int], + max_length: Optional[int], + ) -> Optional[int]: + """Validate and normalize truncate_prompt_tokens parameter.""" + if truncate_prompt_tokens is None: + return None + + if truncate_prompt_tokens == 0: + return 0 + + if truncate_prompt_tokens < 0: + truncate_prompt_tokens = self.model_config.max_model_len + + if max_length is not None and truncate_prompt_tokens > max_length: + raise ValueError( + f"truncate_prompt_tokens ({truncate_prompt_tokens}) " + f"cannot be greater than max_length ({max_length}). 
" + f"Please select a smaller truncation size.") + + return truncate_prompt_tokens + + def _maybe_apply_truncation( + self, token_ids: list[int], + truncate_prompt_tokens: Optional[int]) -> list[int]: + """Apply truncation to token sequence.""" + if truncate_prompt_tokens is None: + return token_ids + if truncate_prompt_tokens >= len(token_ids): + return token_ids + + return token_ids[-truncate_prompt_tokens:] + + async def _tokenize( + self, + text: str, + max_length: Optional[int], + truncate_prompt_tokens: Optional[int], + add_special_tokens: Optional[bool], + cache_salt: Optional[str], + ) -> EngineTokensPrompt: + """Tokenize text input asynchronously.""" + async_tokenizer = self._get_async_tokenizer() + + # Handle encoder-specific preprocessing + if (self.model_config.encoder_config is not None + and self.model_config.encoder_config.get( + "do_lower_case", False)): + text = text.lower() + + # Tokenize texts + if truncate_prompt_tokens is None: + encoded = await async_tokenizer( + text, add_special_tokens=add_special_tokens) + else: + encoded = await async_tokenizer( + text, + add_special_tokens=add_special_tokens, + truncation=True, + max_length=truncate_prompt_tokens) + + return self._create_tokens_prompt(encoded.input_ids, max_length, + cache_salt, text) + + async def _maybe_detokenize( + self, + token_ids: list[int], + max_length: Optional[int], + truncate_prompt_tokens: Optional[int], + cache_salt: Optional[str], + needs_detokenization: Optional[bool] = False, + ) -> EngineTokensPrompt: + """Optionally detokenize token IDs and build a tokens prompt.""" + token_ids = self._maybe_apply_truncation(token_ids, + truncate_prompt_tokens) + + prompt = None + if needs_detokenization is True: + async_tokenizer = self._get_async_tokenizer() + prompt = await async_tokenizer.decode(token_ids) + + return self._create_tokens_prompt(token_ids=token_ids, + max_length=max_length, + cache_salt=cache_salt, + prompt=prompt) + + def _get_async_tokenizer(self) -> AsyncMicrobatchTokenizer: + """Get or create async tokenizer using shared pool.""" + async_tokenizer = self.async_tokenizer + if async_tokenizer is not None: + return async_tokenizer + + tokenizer = self.tokenizer + if self.tokenizer is None: + raise ValueError( + "No tokenizer available for text input processing") + + if self.async_tokenizer_pool is None: + async_tokenizer = AsyncMicrobatchTokenizer(tokenizer) + else: + async_tokenizer = self.async_tokenizer_pool.get(tokenizer) + if async_tokenizer is None: + async_tokenizer = AsyncMicrobatchTokenizer(tokenizer) + self.async_tokenizer_pool[tokenizer] = async_tokenizer + self.async_tokenizer = async_tokenizer + return async_tokenizer + + def _create_tokens_prompt( + self, + token_ids: list[int], + max_length: Optional[int] = None, + cache_salt: Optional[str] = None, + prompt: Optional[str] = None, + ) -> EngineTokensPrompt: + """Create validated EngineTokensPrompt.""" + if max_length is not None and len(token_ids) > max_length: + raise ValueError( + f"This model's maximum context length is {max_length} tokens. " + f"However, your request has {len(token_ids)} input tokens. 
" + "Please reduce the length of the input messages.") + + tokens_prompt = EngineTokensPrompt(prompt_token_ids=token_ids) + if cache_salt is not None: + tokens_prompt["cache_salt"] = cache_salt + if prompt is not None: + tokens_prompt["prompt"] = prompt + return tokens_prompt diff --git a/vllm/entrypoints/tool.py b/vllm/entrypoints/tool.py index 758789a5e059d..f5f4d7d3b5565 100644 --- a/vllm/entrypoints/tool.py +++ b/vllm/entrypoints/tool.py @@ -4,6 +4,8 @@ import os from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any +from openai_harmony import Author, Message, Role, TextContent + from vllm.logger import init_logger if TYPE_CHECKING: @@ -99,6 +101,28 @@ class HarmonyPythonTool(Tool): return self.python_tool = PythonTool() + + async def validate(self): + if not self.enabled: + return + try: + message = Message( + author=Author(role=Role.ASSISTANT), + content=[TextContent(text="print('Hello, world!')")], + channel="analysis", + recipient="python", + content_type="code", + ) + msgs = [] + async for msg in self.python_tool.process(message): + msgs.append(msg) + assert msgs[0].content[0].text == "Hello, world!\n" + except Exception as e: + self.enabled = False + logger.warning_once( + "Code interpreter tool failed to initialize (%s), code " + "interpreter is disabled", e) + return logger.info_once("Code interpreter tool initialized") async def get_result(self, context: "ConversationContext") -> Any: diff --git a/vllm/entrypoints/tool_server.py b/vllm/entrypoints/tool_server.py index 2f28595f27c6a..056a571fb2fd1 100644 --- a/vllm/entrypoints/tool_server.py +++ b/vllm/entrypoints/tool_server.py @@ -86,7 +86,8 @@ class ToolServer(ABC): pass @abstractmethod - def new_session(self, tool_name: str) -> AbstractAsyncContextManager[Any]: + def new_session(self, tool_name: str, + session_id: str) -> AbstractAsyncContextManager[Any]: """ Create a session for the tool. 
""" @@ -124,7 +125,8 @@ class MCPToolServer(ToolServer): description=tool.description, parameters=tool.inputSchema) for tool in list_tools_response.tools - ]) + ], + ) self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp if tool_from_mcp.name not in self.urls: self.urls[tool_from_mcp.name] = url @@ -142,14 +144,16 @@ class MCPToolServer(ToolServer): return self.harmony_tool_descriptions.get(tool_name) @asynccontextmanager - async def new_session(self, tool_name: str): + async def new_session(self, tool_name: str, session_id: str): from mcp import ClientSession from mcp.client.sse import sse_client url = self.urls.get(tool_name) + headers = {"x-session-id": session_id} if not url: raise KeyError(f"Tool '{tool_name}' is not supported") - async with sse_client(url=url) as streams, ClientSession( - *streams) as session: + async with sse_client(url=url, + headers=headers) as streams, ClientSession( + *streams) as session: await session.initialize() yield session @@ -158,10 +162,13 @@ class DemoToolServer(ToolServer): def __init__(self): self.tools: dict[str, Tool] = {} + + async def init_and_validate(self): browser_tool = HarmonyBrowserTool() + python_tool = HarmonyPythonTool() + await python_tool.validate() if browser_tool.enabled: self.tools["browser"] = browser_tool - python_tool = HarmonyPythonTool() if python_tool.enabled: self.tools["python"] = python_tool logger.info("DemoToolServer initialized with tools: %s", @@ -182,7 +189,7 @@ class DemoToolServer(ToolServer): raise ValueError(f"Unknown tool {tool_name}") @asynccontextmanager - async def new_session(self, tool_name: str): + async def new_session(self, tool_name: str, session_id: str): if tool_name not in self.tools: raise KeyError(f"Tool '{tool_name}' is not supported") yield self.tools[tool_name] diff --git a/vllm/env_override.py b/vllm/env_override.py index ef425d433320d..b06703a2fbf9d 100644 --- a/vllm/env_override.py +++ b/vllm/env_override.py @@ -13,24 +13,6 @@ logger = init_logger(__name__) # that interact with vllm workers. # they are executed whenever `import vllm` is called. -if os.environ.get('NCCL_CUMEM_ENABLE', '0') != '0': - logger.warning( - "NCCL_CUMEM_ENABLE is set to %s, skipping override. " - "This may increase memory overhead with cudagraph+allreduce: " - "https://github.com/NVIDIA/nccl/issues/1234", - os.environ['NCCL_CUMEM_ENABLE']) -elif not os.path.exists('/dev/nvidia-caps-imex-channels'): - # NCCL requires NCCL_CUMEM_ENABLE to work with - # multi-node NVLink, typically on GB200-NVL72 systems. - # The ultimate way to detect multi-node NVLink is to use - # NVML APIs, which are too expensive to call here. - # As an approximation, we check the existence of - # /dev/nvidia-caps-imex-channels, used by - # multi-node NVLink to communicate across nodes. - # This will still cost some GPU memory, but it is worthwhile - # because we can get very fast cross-node bandwidth with NVLink. 
- os.environ['NCCL_CUMEM_ENABLE'] = '0' - # see https://github.com/vllm-project/vllm/pull/15951 # it avoids unintentional cuda initialization from torch.cuda.is_available() os.environ['PYTORCH_NVML_BASED_CUDA_CHECK'] = '1' diff --git a/vllm/envs.py b/vllm/envs.py index 1232bd7bf9635..19e2f8635275d 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -6,7 +6,7 @@ import json import os import sys import tempfile -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union if TYPE_CHECKING: VLLM_HOST_IP: str = "" @@ -37,6 +37,7 @@ if TYPE_CHECKING: VLLM_CONFIGURE_LOGGING: int = 1 VLLM_LOGGING_LEVEL: str = "INFO" VLLM_LOGGING_PREFIX: str = "" + VLLM_LOGGING_STREAM: str = "ext://sys.stdout" VLLM_LOGGING_CONFIG_PATH: Optional[str] = None VLLM_LOGITS_PROCESSOR_THREADS: Optional[int] = None VLLM_LOG_STATS_INTERVAL: float = 10. @@ -55,11 +56,12 @@ if TYPE_CHECKING: VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING: bool = True VLLM_USE_RAY_SPMD_WORKER: bool = False VLLM_USE_RAY_COMPILED_DAG: bool = False - VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "auto" + VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", + "shm"] = "auto" VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True VLLM_XLA_USE_SPMD: bool = False - VLLM_WORKER_MULTIPROC_METHOD: str = "fork" + VLLM_WORKER_MULTIPROC_METHOD: Literal["fork", "spawn"] = "fork" VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_VIDEO_FETCH_TIMEOUT: int = 30 @@ -69,13 +71,15 @@ if TYPE_CHECKING: VLLM_VIDEO_LOADER_BACKEND: str = "opencv" VLLM_MM_INPUT_CACHE_GIB: int = 4 VLLM_TARGET_DEVICE: str = "cuda" + VLLM_MAIN_CUDA_VERSION: str = "12.8" MAX_JOBS: Optional[str] = None NVCC_THREADS: Optional[str] = None VLLM_USE_PRECOMPILED: bool = False VLLM_DOCKER_BUILD_CONTEXT: bool = False VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL: bool = False VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: bool = False - CMAKE_BUILD_TYPE: Optional[str] = None + CMAKE_BUILD_TYPE: Optional[Literal["Debug", "Release", + "RelWithDebInfo"]] = None VERBOSE: bool = False VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False VLLM_RPC_TIMEOUT: int = 10000 # ms @@ -91,6 +95,7 @@ if TYPE_CHECKING: VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_SKIP_P2P_CHECK: bool = False VLLM_DISABLED_KERNELS: list[str] = [] + VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION: bool = False VLLM_USE_V1: bool = True VLLM_ROCM_USE_AITER: bool = False VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False @@ -130,29 +135,35 @@ if TYPE_CHECKING: VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None VLLM_TPU_USING_PATHWAYS: bool = False - VLLM_USE_DEEP_GEMM: bool = False + VLLM_USE_DEEP_GEMM: bool = True VLLM_USE_DEEP_GEMM_E8M0: bool = True VLLM_USE_DEEP_GEMM_E8M0_HOPPER: bool = False VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True VLLM_USE_FLASHINFER_MOE_FP8: bool = False VLLM_USE_FLASHINFER_MOE_FP4: bool = False - VLLM_FLASHINFER_MOE_BACKEND: str = "throughput" + VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", + "latency"] = "throughput" VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost" VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557 - VLLM_ALL2ALL_BACKEND: str = "naive" + VLLM_ALL2ALL_BACKEND: Literal["naive", "pplx", + "deepep_high_throughput", + "deepep_low_latency", + "allgather_reducescatter"] = \ + 
"allgather_reducescatter" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840 VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1 VLLM_SLEEP_WHEN_IDLE: bool = False VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16 VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300 - VLLM_KV_CACHE_LAYOUT: Optional[str] = None + VLLM_KV_CACHE_LAYOUT: Optional[Literal["NHD", "HND"]] = None VLLM_COMPUTE_NANS_IN_LOGITS: bool = False VLLM_USE_NVFP4_CT_EMULATIONS: bool = False - VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE" + VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: Literal["FP", "INT8", "INT6", "INT4", + "NONE"] = "NONE" VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120 @@ -162,12 +173,20 @@ if TYPE_CHECKING: VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False VLLM_ENABLE_RESPONSES_API_STORE: bool = False VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None + VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION: bool = False VLLM_HAS_FLASHINFER_CUBIN: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False + VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False + VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False + VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False + VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False + VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False + VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True + VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER" def get_default_cache_root(): @@ -196,6 +215,48 @@ def maybe_convert_bool(value: Optional[str]) -> Optional[bool]: return bool(int(value)) +def env_with_choices( + env_name: str, + default: Optional[str], + choices: Union[list[str], Callable[[], list[str]]], + case_sensitive: bool = True) -> Callable[[], Optional[str]]: + """ + Create a lambda that validates environment variable against allowed choices + + Args: + env_name: Name of the environment variable + default: Default value if not set (can be None) + choices: List of valid string options or callable that returns list + case_sensitive: Whether validation should be case sensitive + + Returns: + Lambda function for environment_variables dict + """ + + def _get_validated_env() -> Optional[str]: + value = os.getenv(env_name) + if value is None: + return default + + # Resolve choices if it's a callable (for lazy loading) + actual_choices = choices() if callable(choices) else choices + + if not case_sensitive: + check_value = value.lower() + check_choices = [choice.lower() for choice in actual_choices] + else: + check_value = value + check_choices = actual_choices + + if check_value not in check_choices: + raise ValueError(f"Invalid value '{value}' for {env_name}. " + f"Valid options: {actual_choices}.") + + return value + + return _get_validated_env + + def get_vllm_port() -> Optional[int]: """Get the port from VLLM_PORT environment variable. @@ -235,10 +296,15 @@ environment_variables: dict[str, Callable[[], Any]] = { # ================== Installation Time Env Vars ================== # Target device of vLLM, supporting [cuda (by default), - # rocm, neuron, cpu] + # rocm, cpu] "VLLM_TARGET_DEVICE": lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda").lower(), + # Main CUDA version of vLLM, supporting [12.6, 12.8, 12.9], + # 12.8 is the default. This follows PyTorch but can be overridden. 
+ "VLLM_MAIN_CUDA_VERSION": + lambda: os.getenv("VLLM_MAIN_CUDA_VERSION", "").lower() or "12.8", + # Maximum number of compilation jobs to run in parallel. # By default this is the number of CPUs "MAX_JOBS": @@ -271,7 +337,8 @@ environment_variables: dict[str, Callable[[], Any]] = { # If not set, defaults to "Debug" or "RelWithDebInfo" # Available options: "Debug", "Release", "RelWithDebInfo" "CMAKE_BUILD_TYPE": - lambda: os.getenv("CMAKE_BUILD_TYPE"), + env_with_choices("CMAKE_BUILD_TYPE", None, + ["Debug", "Release", "RelWithDebInfo"]), # If set, vllm will print verbose logs during installation "VERBOSE": @@ -431,6 +498,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_LOGGING_LEVEL": lambda: os.getenv("VLLM_LOGGING_LEVEL", "INFO").upper(), + # this is used for configuring the default logging stream + "VLLM_LOGGING_STREAM": + lambda: os.getenv("VLLM_LOGGING_STREAM", "ext://sys.stdout"), + # if set, VLLM_LOGGING_PREFIX will be prepended to all log messages "VLLM_LOGGING_PREFIX": lambda: os.getenv("VLLM_LOGGING_PREFIX", ""), @@ -456,15 +527,21 @@ environment_variables: dict[str, Callable[[], Any]] = { lambda: int(os.getenv("VLLM_TRACE_FUNCTION", "0")), # Backend for attention computation - # Available options: + # Example options: # - "TORCH_SDPA": use torch.nn.MultiheadAttention # - "FLASH_ATTN": use FlashAttention # - "XFORMERS": use XFormers # - "ROCM_FLASH": use ROCmFlashAttention # - "FLASHINFER": use flashinfer # - "FLASHMLA": use FlashMLA + # - "FLASH_ATTN_MLA": use FlashAttention for MLA + # - "FLASHINFER_MLA": use FlashInfer for MLA + # - "CUTLASS_MLA": use CUTLASS for MLA + # All possible options loaded dynamically from _Backend enum "VLLM_ATTENTION_BACKEND": - lambda: os.getenv("VLLM_ATTENTION_BACKEND", None), + env_with_choices("VLLM_ATTENTION_BACKEND", None, + lambda: list(__import__('vllm.platforms.interface', \ + fromlist=['_Backend'])._Backend.__members__.keys())), # If set, vllm will use flashinfer sampler "VLLM_USE_FLASHINFER_SAMPLER": @@ -527,7 +604,8 @@ environment_variables: dict[str, Callable[[], Any]] = { # - "shm": use shared memory and gRPC for communication # This flag is ignored if VLLM_USE_RAY_COMPILED_DAG is not set. "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE": - lambda: os.getenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "auto"), + env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "auto", + ["auto", "nccl", "shm"]), # If the env var is set, it enables GPU communication overlap # (experimental feature) in Ray's Compiled Graph. This flag is ignored if @@ -546,7 +624,8 @@ environment_variables: dict[str, Callable[[], Any]] = { # Use dedicated multiprocess context for workers. # Both spawn and fork work "VLLM_WORKER_MULTIPROC_METHOD": - lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "fork"), + env_with_choices("VLLM_WORKER_MULTIPROC_METHOD", "fork", + ["spawn", "fork"]), # Path to the cache for storing downloaded assets "VLLM_ASSETS_CACHE": @@ -729,6 +808,13 @@ environment_variables: dict[str, Callable[[], Any]] = { lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[ "VLLM_DISABLED_KERNELS"].split(","), + # Swaps the all reduce backend that we use to coordinate the DP padding + # information from NCCL to gloo. + "VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION": + lambda: + (os.getenv("VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION", "False").lower() in + ("true", "1")), + # If set, use the V1 code path. 
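# Illustrative sketch (not part of the patch): how the env_with_choices helper
# introduced above behaves. MY_BACKEND is a hypothetical variable name, and the
# helper is assumed to be importable as vllm.envs.env_with_choices.
import os
from vllm.envs import env_with_choices

getter = env_with_choices("MY_BACKEND", "auto", ["auto", "nccl", "shm"])
assert getter() == "auto"            # unset -> the default is returned
os.environ["MY_BACKEND"] = "nccl"
assert getter() == "nccl"            # a valid choice passes through unchanged
os.environ["MY_BACKEND"] = "bogus"   # calling getter() now raises ValueError

# Choices may also be a callable, which defers the import until the variable is
# actually read -- the pattern used for VLLM_ATTENTION_BACKEND above:
def backend_choices() -> list[str]:
    from vllm.platforms.interface import _Backend  # imported lazily
    return list(_Backend.__members__.keys())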
"VLLM_USE_V1": lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))), @@ -803,7 +889,8 @@ environment_variables: dict[str, Callable[[], Any]] = { # Choice of quantization level: FP, INT8, INT6, INT4 or NONE # Recommended for large models to get allreduce "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION": - lambda: os.getenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", "NONE").upper(), + env_with_choices("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", "NONE", + ["FP", "INT8", "INT6", "INT4", "NONE"]), # Custom quick allreduce kernel for MI3* cards # Due to the lack of the bfloat16 asm instruction, bfloat16 @@ -960,7 +1047,7 @@ environment_variables: dict[str, Callable[[], Any]] = { # Allow use of DeepGemm kernels for fused moe ops. "VLLM_USE_DEEP_GEMM": - lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))), + lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "1"))), # Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs. "VLLM_USE_DEEP_GEMM_E8M0": @@ -994,6 +1081,15 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8": lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "0"))), + # If set to 1, use the FlashInfer CUTLASS backend for + # MXFP8 (activation) x MXFP4 (weight) MoE. + # This is separate from the TRTLLMGEN path controlled by + # VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8. + "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS": + lambda: bool(int( + os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "0") + )), + # If set to 1, use the FlashInfer # BF16 (activation) x MXFP4 (weight) MoE backend. "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16": @@ -1031,26 +1127,29 @@ environment_variables: dict[str, Callable[[], Any]] = { # all2all backend for vllm's expert parallel communication # Available options: - # - "naive": naive all2all implementation using all-reduce + # - "naive": naive all2all implementation using broadcasts + # - "allgather_reducescatter": all2all implementation based on allgather and + # reducescatter # - "pplx": use pplx kernels # - "deepep_high_throughput", use deepep high-throughput kernels # - "deepep_low_latency", use deepep low-latency kernels "VLLM_ALL2ALL_BACKEND": - lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"), + env_with_choices("VLLM_ALL2ALL_BACKEND", "allgather_reducescatter", + ["naive", "pplx", + "deepep_high_throughput", + "deepep_low_latency", + "allgather_reducescatter"]), - # Flashinfer MoE backend for vLLM's fused Mixture-of-Experts support. Both - # require compute capability 10.0 or above. + # Flashinfer MoE backend for vLLM's fused Mixture-of-Experts support. + # Both require compute capability 10.0 or above. # Available options: # - "throughput": [default] # Uses CUTLASS kernels optimized for high-throughput batch inference. # - "latency": # Uses TensorRT-LLM kernels optimized for low-latency inference. - # To set this backend, define the environment variable: - # export VLLM_FLASHINFER_MOE_BACKEND=latency. - # If not set, defaults to "throughput". - "VLLM_FLASHINFER_MOE_BACKEND": lambda: os.getenv( - "VLLM_FLASHINFER_MOE_BACKEND", "throughput" - ), + "VLLM_FLASHINFER_MOE_BACKEND": + env_with_choices("VLLM_FLASHINFER_MOE_BACKEND", "throughput", + ["throughput", "latency"]), # Control the maximum number of tokens per expert supported by the # NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for @@ -1063,7 +1162,7 @@ environment_variables: dict[str, Callable[[], Any]] = { # vllm should use flashinfer fused allreduce. 
The variable should be a # JSON with the following format: # { <world size>: <max size in mb> } - # Unspecified world sizes will fallback to + # Unspecified world sizes will fall back to # { 2: 64, 4: 1, <everything else>: 0.5 } "VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB": lambda: json.loads(os.getenv( @@ -1106,7 +1205,7 @@ environment_variables: dict[str, Callable[[], Any]] = { # leave the layout choice to the backend. Mind that backends may only # implement and support a subset of all possible layouts. "VLLM_KV_CACHE_LAYOUT": - lambda: os.getenv("VLLM_KV_CACHE_LAYOUT", None), + env_with_choices("VLLM_KV_CACHE_LAYOUT", None, ["NHD", "HND"]), # Enable checking whether the generated logits contain NaNs, # indicating corrupted output. Useful for debugging low level bugs @@ -1131,9 +1230,16 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_USE_CUDNN_PREFILL": lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))), - # If set to 1, use the TRTLLM attention backend in flashinfer. + # If set to 1/True, use the TRTLLM attention backend in flashinfer. + # If set to 0/False, use the default attention backend in flashinfer. + # If not set, auto-detect the attention backend in flashinfer. "VLLM_USE_TRTLLM_ATTENTION": - lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None), + lambda: (None if "VLLM_USE_TRTLLM_ATTENTION" not in os.environ else + os.environ["VLLM_USE_TRTLLM_ATTENTION"].lower() in ("1", "true")), + + # If set to 1, when we use fp8 kv, we do not quantize Q to fp8 + "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION": + lambda: bool(int(os.getenv("VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "0"))), # If set, it means we pre-downloaded cubin files and flashinfer will # read the cubin files directly. @@ -1191,6 +1297,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_ENABLE_RESPONSES_API_STORE": lambda: bool(int(os.getenv("VLLM_ENABLE_RESPONSES_API_STORE", "0"))), + # If set, use the fp8 mfma in rocm paged attention. + "VLLM_ROCM_FP8_MFMA_PAGE_ATTN": + lambda: bool(int(os.getenv("VLLM_ROCM_FP8_MFMA_PAGE_ATTN", "0"))), + # Whether to use pytorch symmetric memory for allreduce "VLLM_ALLREDUCE_USE_SYMM_MEM": lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))), @@ -1199,6 +1309,29 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_TUNED_CONFIG_FOLDER": lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None), + # Allows vllm use container tool + "VLLM_GPT_OSS_USE_CONTAINER_TOOL": + lambda: bool(int(os.getenv("VLLM_GPT_OSS_USE_CONTAINER_TOOL", "0"))), + + # Allows harmony instructions to be injected on system messages + "VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS": + lambda: bool( + int(os.getenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "0"))), + + # Add optional custom scopes for profiling, disable to avoid overheads + "VLLM_CUSTOM_SCOPES_FOR_PROFILING": + lambda: bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))), + + # Represent block hashes in KV cache events as 64-bit integers instead of + # raw bytes. Defaults to True for backward compatibility. + "VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES": + lambda: bool(int(os.getenv("VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES", "1"))), + + # Name of the shared memory buffer used for object storage. + # Only effective when mm_config.mm_processor_cache_type == "shm". 
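# Illustrative sketch (not part of the patch): the tri-state parsing used for
# VLLM_USE_TRTLLM_ATTENTION above -- unset means auto-detect, any other value
# is coerced to a boolean.
import os

def parse_tristate(name: str):
    if name not in os.environ:
        return None                                    # auto-detect
    return os.environ[name].lower() in ("1", "true")   # "1"/"true" -> True, else False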
+ "VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME": + lambda: os.getenv("VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME", + "VLLM_OBJECT_STORAGE_SHM_BUFFER"), } # --8<-- [end:env-vars-definition] @@ -1269,9 +1402,11 @@ def compute_hash() -> str: "VLLM_USE_FLASHINFER_MOE_FP8", "VLLM_USE_FLASHINFER_MOE_FP4", "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", + "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "VLLM_USE_CUDNN_PREFILL", "VLLM_USE_TRTLLM_ATTENTION", + "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "VLLM_ROCM_USE_AITER", "VLLM_ROCM_USE_AITER_PAGED_ATTN", "VLLM_ROCM_USE_AITER_LINEAR", @@ -1287,6 +1422,7 @@ def compute_hash() -> str: "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", "VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", "VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", + "VLLM_ROCM_FP8_MFMA_PAGE_ATTN", ] for key in environment_variables_to_hash: # if this goes out of sync with environment_variables, diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 813232cd19281..d18bef1256af5 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -13,6 +13,7 @@ from typing_extensions import TypeVar import vllm.platforms from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput @@ -54,6 +55,7 @@ class ExecutorBase(ABC): self._init_executor() self.is_sleeping = False self.sleeping_tags: set[str] = set() + self.kv_output_aggregator = None @abstractmethod def _init_executor(self) -> None: @@ -231,7 +233,7 @@ class ExecutorBase(ABC): def shutdown(self) -> None: """Shutdown the executor.""" - return + self.collective_rpc("shutdown") def __del__(self): self.shutdown() @@ -252,6 +254,11 @@ class ExecutorBase(ABC): exception.""" self.check_health() + def init_kv_output_aggregator(self, finished_count: Optional[int]) -> None: + """Init KVOutputAggregator""" + self.kv_output_aggregator = KVOutputAggregator( + finished_count or self.parallel_config.world_size) + class DistributedExecutorBase(ExecutorBase): """Abstract superclass of distributed executor implementations.""" diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index 37c3fe59c65dd..78d0ee6c1e3fc 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -117,10 +117,12 @@ class RayDistributedExecutor(DistributedExecutorBase): self.driver_worker.execute_method) def shutdown(self) -> None: - logger.info( - "Shutting down Ray distributed executor. If you see error log " - "from logging.cc regarding SIGTERM received, please ignore because " - "this is the expected termination process in Ray.") + if logger: + # Somehow logger can be None here. + logger.info( + "Shutting down Ray distributed executor. 
If you see error log " + "from logging.cc regarding SIGTERM received, please ignore " + "because this is the expected termination process in Ray.") if hasattr(self, "forward_dag") and self.forward_dag is not None: self.forward_dag.teardown() import ray diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index aabc9ed9b80a2..3b566e88a9ec2 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import os +from concurrent.futures import Future, ThreadPoolExecutor +from functools import cached_property +from multiprocessing import Lock from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -10,9 +12,13 @@ import torch.distributed as dist import vllm.envs as envs from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import worker_receiver_cache_from_config from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, run_method) from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType +from vllm.v1.executor.utils import get_and_update_mm_cache +from vllm.v1.outputs import AsyncModelRunnerOutput from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -27,15 +33,7 @@ class UniProcExecutor(ExecutorBase): """ self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0) - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - local_rank = 0 - # set local rank as the device index if specified - device_info = self.vllm_config.device_config.device.__str__().split( - ":") - if len(device_info) > 1: - local_rank = int(device_info[1]) - rank = 0 + distributed_init_method, rank, local_rank = self._distributed_args() is_driver_worker = True kwargs = dict( vllm_config=self.vllm_config, @@ -44,19 +42,58 @@ class UniProcExecutor(ExecutorBase): distributed_init_method=distributed_init_method, is_driver_worker=is_driver_worker, ) + self.mm_receiver_cache = worker_receiver_cache_from_config( + self.vllm_config, MULTIMODAL_REGISTRY, Lock()) + + self.async_output_thread: Optional[ThreadPoolExecutor] = None + if self.max_concurrent_batches > 1: + self.async_output_thread = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="WorkerAsyncOutput") + self.collective_rpc("init_worker", args=([kwargs], )) self.collective_rpc("init_device") self.collective_rpc("load_model") + def _distributed_args(self) -> tuple[str, int, int]: + """Return (distributed_init_method, rank, local_rank).""" + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + # set local rank as the device index if specified + device_info = self.vllm_config.device_config.device.__str__().split( + ":") + local_rank = int(device_info[1]) if len(device_info) > 1 else 0 + return distributed_init_method, 0, local_rank + + @cached_property + def max_concurrent_batches(self) -> int: + return 2 if self.scheduler_config.async_scheduling else 1 + def collective_rpc(self, method: Union[str, Callable], timeout: Optional[float] = None, args: Tuple = (), - kwargs: Optional[Dict] = None) -> List[Any]: + kwargs: Optional[Dict] = None, + non_block: bool = False) -> List[Any]: if kwargs is None: kwargs = {} - answer = run_method(self.driver_worker, method, args, kwargs) - return [answer] + if 
self.mm_receiver_cache is not None and method == "execute_model": + get_and_update_mm_cache(self.mm_receiver_cache, args) + + if not non_block: + return [run_method(self.driver_worker, method, args, kwargs)] + + try: + result = run_method(self.driver_worker, method, args, kwargs) + if isinstance(result, AsyncModelRunnerOutput): + if (async_thread := self.async_output_thread) is not None: + return [async_thread.submit(result.get_output)] + result = result.get_output() + future = Future[Any]() + future.set_result(result) + except Exception as e: + future = Future[Any]() + future.set_exception(e) + return [future] def check_health(self) -> None: # UniProcExecutor will always be healthy as long as @@ -71,6 +108,10 @@ class UniProcExecutor(ExecutorBase): self.shutdown() return + def shutdown(self) -> None: + if worker := self.driver_worker: + worker.shutdown() + UniProcExecutorAsync = UniProcExecutor @@ -104,8 +145,9 @@ class ExecutorWithExternalLauncher(UniProcExecutor): assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, \ ("To get deterministic execution in V1, " "please set VLLM_ENABLE_V1_MULTIPROCESSING=0") - self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, - rpc_rank=0) + super()._init_executor() + + def _distributed_args(self) -> tuple[str, int, int]: # engines are launched in torchrun-compatible launchers # so we can use the env:// method. # required env vars: @@ -116,17 +158,7 @@ class ExecutorWithExternalLauncher(UniProcExecutor): distributed_init_method = "env://" rank = int(os.environ["RANK"]) local_rank = int(os.environ["LOCAL_RANK"]) - is_driver_worker = True - kwargs = dict( - vllm_config=self.vllm_config, - local_rank=local_rank, - rank=rank, - distributed_init_method=distributed_init_method, - is_driver_worker=is_driver_worker, - ) - self.collective_rpc("init_worker", args=([kwargs], )) - self.collective_rpc("init_device") - self.collective_rpc("load_model") + return distributed_init_method, rank, local_rank def determine_num_available_blocks(self) -> Tuple[int, int]: """ diff --git a/vllm/forward_context.py b/vllm/forward_context.py index c57c51d289ac8..3b535423f7bca 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -13,6 +13,8 @@ import torch.distributed as dist import vllm.envs as envs from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.v1.worker.ubatch_utils import UBatchSlices, is_second_ubatch_empty if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -75,14 +77,73 @@ class DPMetadata: Gather the num_tokens across all DP ranks and return results in a CPU tensor of size dp_size. """ + from vllm.distributed.parallel_state import get_dp_group + device = current_platform.device_type + group = get_dp_group().device_group + + # Transfering this tensor from GPU to CPU will introduce a GPU sync + # point that could adversely affect performance of vllm with asynch + # scheduling. This environment variable exists to quickly disable + # this optimization if we run into this case. 
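# Illustrative worked example (not part of the patch): with dp_size=2, rank 0
# contributes [5, 0] and rank 1 contributes [0, 7]; after the sum all-reduce
# below, every rank holds [5, 7], i.e. the token count of each DP rank.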
+ if envs.VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION: + logger.info_once( + "Using CPU all reduce to syncronize DP padding between ranks.") + device = "cpu" + group = get_dp_group().cpu_group num_tokens_across_dp = [0] * dp_size num_tokens_across_dp[dp_rank] = num_tokens num_tokens_tensor = torch.tensor(num_tokens_across_dp, - device="cpu", + device=device, dtype=torch.int32) + dist.all_reduce(num_tokens_tensor, group=group) + return num_tokens_tensor.cpu() + + @staticmethod + def should_ubatch_across_dp( + should_ubatch: bool, orig_num_tokens_per_ubatch: int, + padded_num_tokens_per_ubatch: int, dp_size: int, + dp_rank: int) -> tuple[bool, Optional[torch.Tensor]]: + """ + 1. Decides if each DP rank is going to microbatch. Either all ranks + run with microbatching or none of them do. If this function decides + not to run with microbatching. It will "abort" meaning that no padding + information will be returned to the caller. It will return (False, None) + + 2. Determines the total number of tokens that each rank will run. + All ranks will be padded out so that the run with the same number + of tokens + + Returns: tuple[ + should_ubatch: Are all DP ranks going to microbatch + num_tokens_after_padding: A tensor containing the total number of + tokens per-microbatch for each DP rank including padding. Will be + None if should_ubatch if False + ] + """ + + device = current_platform.device_type + tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32) + tensor[0][dp_rank] = orig_num_tokens_per_ubatch + tensor[1][dp_rank] = padded_num_tokens_per_ubatch + tensor[2][dp_rank] = 1 if should_ubatch else 0 + from vllm.distributed.parallel_state import get_dp_group - dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group) - return num_tokens_tensor + dist.all_reduce(tensor, group=get_dp_group().device_group) + + result: bool = bool(torch.all(tensor[2] == 1).item()) + if not result: + return result, None + + orig_num_tokens_tensor = tensor[0, :] + padded_num_tokens_tensor = tensor[1, :] + + orig_min_num_tokens = int(orig_num_tokens_tensor.min().item()) + padded_max_num_tokens = int(padded_num_tokens_tensor.max().item()) + if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens): + logger.debug("Aborting ubatching %s %s", orig_min_num_tokens, + padded_max_num_tokens) + return False, None + return result, padded_num_tokens_tensor.cpu() @staticmethod def make( @@ -106,14 +167,15 @@ class DPMetadata: # If num_tokens_across_dp is None, it will be computed by all_reduce # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize - assert (num_tokens_across_dp is None - or num_tokens_across_dp[dp_rank] == batchsize) + assert (num_tokens_across_dp is None or num_tokens_across_dp[dp_rank] + == batchsize), f"{num_tokens_across_dp[dp_rank]} {batchsize}" if num_tokens_across_dp is None: num_tokens_across_dp = DPMetadata.num_tokens_across_dp( batchsize, dp_size, dp_rank) max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp) cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0) - return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu) + return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu, + num_tokens_across_dp) @contextmanager def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int): @@ -166,9 +228,12 @@ class ForwardContext: Type AttentionMetadata for v0, Type Dict[str, AttentionMetadata] for v1, map from layer_name of each attention layer to its attention metadata - set dynamically for each forward pass + Type 
List[Dict[str, AttentionMetadata]] for DBO. List of size two, one + for each microbatch. + Set dynamically for each forward pass """ - attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]] + attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"], + list[dict[str, "AttentionMetadata"]]] # TODO: remove after making all virtual_engines share the same kv cache virtual_engine: int # set dynamically for each forward pass # set dynamically for each forward pass @@ -178,6 +243,8 @@ class ForwardContext: cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE batch_descriptor: Optional[BatchDescriptor] = None + ubatch_slices: Optional[UBatchSlices] = None + def __post_init__(self): assert self.cudagraph_runtime_mode in [ CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \ @@ -195,6 +262,39 @@ def get_forward_context() -> ForwardContext: return _forward_context +def create_forward_context( + attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0, + dp_metadata: Optional[DPMetadata] = None, + cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor: Optional[BatchDescriptor] = None, + ubatch_slices: Optional[UBatchSlices] = None): + return ForwardContext(no_compile_layers=vllm_config.compilation_config. + static_forward_context, + virtual_engine=virtual_engine, + attn_metadata=attn_metadata, + dp_metadata=dp_metadata, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ubatch_slices=ubatch_slices) + + +@contextmanager +def override_forward_context(forward_context: Optional[ForwardContext]): + """A context manager that overrides the current forward context. + This is used to override the forward context for a specific + forward pass. + """ + global _forward_context + prev_context = _forward_context + _forward_context = forward_context + try: + yield + finally: + _forward_context = prev_context + + @contextmanager def set_forward_context( attn_metadata: Any, @@ -203,7 +303,8 @@ def set_forward_context( num_tokens: Optional[int] = None, num_tokens_across_dp: Optional[torch.Tensor] = None, cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor: Optional[BatchDescriptor] = None): + batch_descriptor: Optional[BatchDescriptor] = None, + ubatch_slices: Optional[UBatchSlices] = None): """A context manager that stores the current forward context, can be attention metadata, etc. Here we can inject common logic for every model forward pass. @@ -212,6 +313,7 @@ def set_forward_context( need_to_track_batchsize = track_batchsize and attn_metadata is not None if need_to_track_batchsize: forward_start_time = time.perf_counter() + dp_metadata: Optional[DPMetadata] = None if vllm_config.parallel_config.data_parallel_size > 1 and ( attn_metadata is not None or num_tokens is not None): @@ -219,20 +321,14 @@ def set_forward_context( attn_metadata, num_tokens or 0, num_tokens_across_dp) - global _forward_context - prev_context = _forward_context - _forward_context = ForwardContext( - no_compile_layers=vllm_config.compilation_config. 
- static_forward_context, - virtual_engine=virtual_engine, - attn_metadata=attn_metadata, - dp_metadata=dp_metadata, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, - ) + forward_context = create_forward_context(attn_metadata, vllm_config, + virtual_engine, dp_metadata, + cudagraph_runtime_mode, + batch_descriptor, ubatch_slices) try: - yield + with override_forward_context(forward_context): + yield finally: global last_logging_time, batchsize_logging_interval if need_to_track_batchsize: @@ -269,5 +365,3 @@ def set_forward_context( logger.info(("Batchsize forward time stats " "(batchsize, count, median_time(ms)): %s"), forward_stats) - - _forward_context = prev_context diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 065d0ab59291a..6a005aa634e85 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -52,6 +52,9 @@ class TokensPrompt(TypedDict): prompt_token_ids: list[int] """A list of token IDs to pass to the model.""" + prompt: NotRequired[str] + """The prompt text corresponding to the token IDs, if available.""" + token_type_ids: NotRequired[list[int]] """A list of token type IDs to pass to the cross encoder model.""" diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 094fcf021b619..cb3a5cdb840e6 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -9,13 +9,11 @@ from typing_extensions import assert_never from vllm.config import ModelConfig from vllm.logger import init_logger -from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, MultiModalInputs, MultiModalUUIDDict) from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.transformers_utils.tokenizer_group import TokenizerGroup from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt, EncoderDecoderInputs, ProcessorInputs, PromptType, @@ -31,7 +29,7 @@ class InputPreprocessor: def __init__( self, model_config: ModelConfig, - tokenizer: Optional[TokenizerGroup], + tokenizer: Optional[AnyTokenizer], mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None, ) -> None: @@ -42,32 +40,28 @@ class InputPreprocessor: self.mm_registry = mm_registry self.mm_processor_cache = mm_processor_cache - def get_tokenizer_group(self) -> TokenizerGroup: + def get_tokenizer(self) -> AnyTokenizer: if self.tokenizer is None: raise ValueError("You cannot pass text prompts when " "`skip_tokenizer_init` is True") return self.tokenizer - def get_bos_token_id(self, - lora_request: Optional[LoRARequest] = None - ) -> Optional[int]: + def get_bos_token_id(self) -> Optional[int]: if self.tokenizer is None: logger.warning("Using None for BOS token id because tokenizer " "is not initialized") return None - return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id + return self.tokenizer.bos_token_id - def get_eos_token_id(self, - lora_request: Optional[LoRARequest] = None - ) -> Optional[int]: + def get_eos_token_id(self) -> Optional[int]: if self.tokenizer is None: logger.warning("Using None for EOS token id because tokenizer " "is not initialized") return None - return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id + return self.tokenizer.eos_token_id def get_decoder_start_token_id(self) -> Optional[int]: """ @@ -190,14 +184,13 @@ class InputPreprocessor: def 
_tokenize_prompt( self, prompt: str, - lora_request: Optional[LoRARequest], tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> list[int]: """ Apply the model's tokenizer to a text prompt, returning the corresponding token IDs. """ - tokenizer = self.get_tokenizer_group() + tokenizer = self.get_tokenizer() tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs) encoder_config = self.model_config.encoder_config @@ -205,50 +198,39 @@ class InputPreprocessor: if encoder_config and encoder_config.get("do_lower_case", False): prompt = prompt.lower() - return tokenizer.encode(prompt=prompt, - lora_request=lora_request, - **tokenization_kwargs) + return tokenizer.encode(prompt, **tokenization_kwargs) async def _tokenize_prompt_async( self, prompt: str, - lora_request: Optional[LoRARequest], tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> list[int]: """ Async version of [`_tokenize_prompt`][vllm.inputs.preprocess.InputPreprocessor._tokenize_prompt]. """ - tokenizer = self.get_tokenizer_group() + tokenizer = self.get_tokenizer() tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs) - return await tokenizer.encode_async(prompt=prompt, - lora_request=lora_request, - **tokenization_kwargs) + return tokenizer.encode(prompt, **tokenization_kwargs) - def _get_mm_tokenizer( - self, - lora_request: Optional[LoRARequest], - ) -> AnyTokenizer: + def _get_mm_tokenizer(self) -> AnyTokenizer: # PrithviGeoSpatialMAE needs to be initialized without a tokenizer # while using also multi-modal input if not self.tokenizer: return cast(AnyTokenizer, object()) # Dummy - tokenizer_group = self.get_tokenizer_group() - return tokenizer_group.get_lora_tokenizer(lora_request) + tokenizer = self.get_tokenizer() + return tokenizer - async def _get_mm_tokenizer_async( - self, - lora_request: Optional[LoRARequest], - ) -> AnyTokenizer: + async def _get_mm_tokenizer_async(self) -> AnyTokenizer: # PrithviGeoSpatialMAE needs to be initialized without a tokenizer # while using also multi-modal input if not self.tokenizer: return cast(AnyTokenizer, object()) # Dummy - tokenizer_group = self.get_tokenizer_group() - return await tokenizer_group.get_lora_tokenizer_async(lora_request) + tokenizer = self.get_tokenizer() + return tokenizer def _process_multimodal( self, @@ -256,16 +238,14 @@ class InputPreprocessor: mm_data: MultiModalDataDict, mm_processor_kwargs: Optional[Mapping[str, object]], tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[LoRARequest] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalInputs: """ Apply the model's multi-modal processor to a multi-modal prompt, returning the corresponding token IDs and metadata. """ - tokenizer = self._get_mm_tokenizer(lora_request) + tokenizer = self._get_mm_tokenizer() mm_processor = self.mm_registry.create_processor( self.model_config, @@ -276,13 +256,23 @@ class InputPreprocessor: if mm_processor_kwargs is None: mm_processor_kwargs = {} - return mm_processor.apply( + mm_input = mm_processor.apply( prompt, mm_data, hf_processor_mm_kwargs=mm_processor_kwargs, tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) + mm_hashes = mm_input["mm_hashes"] + + # Validate that all mm items have a string as their hash + if not contains_only_strings(mm_hashes): + raise ValueError( + f"mm_hashes must contain only strings, got: {mm_hashes}. 
" + "This is likely due to an incorrect custom implementation of " + "MultiModalProcessor.apply method.") + + return mm_input async def _process_multimodal_async( self, @@ -290,16 +280,14 @@ class InputPreprocessor: mm_data: MultiModalDataDict, mm_processor_kwargs: Optional[Mapping[str, object]], tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[LoRARequest] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalInputs: """ Async version of [`_process_multimodal`][vllm.inputs.preprocess.InputPreprocessor._process_multimodal]. """ - tokenizer = await self._get_mm_tokenizer_async(lora_request) + tokenizer = await self._get_mm_tokenizer_async() mm_processor = self.mm_registry.create_processor( self.model_config, @@ -310,13 +298,23 @@ class InputPreprocessor: if mm_processor_kwargs is None: mm_processor_kwargs = {} - return mm_processor.apply( + mm_input = mm_processor.apply( prompt, mm_data, hf_processor_mm_kwargs=mm_processor_kwargs, tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) + mm_hashes = mm_input["mm_hashes"] + + # Validate that all mm items have a string as their hash + if not contains_only_strings(mm_hashes): + raise ValueError( + f"mm_hashes must contain only strings, got: {mm_hashes}. " + "This is likely due to an incorrect custom implementation of " + "MultiModalProcessor.apply method.") + + return mm_input def _process_embeds( self, @@ -368,10 +366,8 @@ class InputPreprocessor: self, parsed_content: TokensPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[LoRARequest] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> Union[TokenInputs, MultiModalInputs]: prompt_token_ids = self._truncate_inputs( parsed_content["prompt_token_ids"], tokenization_kwargs) @@ -383,8 +379,7 @@ class InputPreprocessor: multi_modal_data, parsed_content.get("mm_processor_kwargs"), tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) else: inputs = token_inputs(prompt_token_ids=prompt_token_ids) @@ -398,10 +393,8 @@ class InputPreprocessor: self, parsed_content: TokensPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[LoRARequest] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> Union[TokenInputs, MultiModalInputs]: prompt_token_ids = self._truncate_inputs( parsed_content["prompt_token_ids"], tokenization_kwargs) @@ -413,8 +406,7 @@ class InputPreprocessor: multi_modal_data, parsed_content.get("mm_processor_kwargs"), tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) else: inputs = token_inputs(prompt_token_ids=prompt_token_ids, ) @@ -428,10 +420,8 @@ class InputPreprocessor: self, parsed_content: TextPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[LoRARequest] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> Union[TokenInputs, MultiModalInputs]: prompt_text = parsed_content["prompt"] @@ -442,13 +432,11 @@ class InputPreprocessor: 
multi_modal_data, parsed_content.get("mm_processor_kwargs"), tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) else: prompt_token_ids = self._tokenize_prompt( prompt_text, - lora_request=lora_request, tokenization_kwargs=tokenization_kwargs, ) inputs = token_inputs( @@ -465,10 +453,8 @@ class InputPreprocessor: self, parsed_content: TextPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[LoRARequest] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> Union[TokenInputs, MultiModalInputs]: prompt_text = parsed_content["prompt"] @@ -479,13 +465,11 @@ class InputPreprocessor: multi_modal_data, parsed_content.get("mm_processor_kwargs"), tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) else: prompt_token_ids = await self._tokenize_prompt_async( prompt_text, - lora_request=lora_request, tokenization_kwargs=tokenization_kwargs, ) inputs = token_inputs( @@ -502,10 +486,8 @@ class InputPreprocessor: self, prompt: SingletonPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[LoRARequest] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> SingletonInputs: """ Extract the singleton inputs from a prompt. @@ -513,7 +495,6 @@ class InputPreprocessor: Arguments: * prompt: single encoder or decoder input prompt - * lora_request: this is only valid for decoder prompts Returns: @@ -526,22 +507,19 @@ class InputPreprocessor: if parsed["type"] == "tokens": return self._process_tokens( parsed["content"], - lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) if parsed["type"] == "text": return self._process_text( parsed["content"], tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) if parsed["type"] == "str": return self._process_text( TextPrompt(prompt=parsed["content"]), tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) assert_never(parsed) @@ -550,10 +528,8 @@ class InputPreprocessor: self, prompt: SingletonPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[LoRARequest] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> SingletonInputs: """ Async version of @@ -566,22 +542,19 @@ class InputPreprocessor: if parsed["type"] == "tokens": return await self._process_tokens_async( parsed["content"], - lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) if parsed["type"] == "text": return await self._process_text_async( parsed["content"], tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) if parsed["type"] == "str": return await self._process_text_async( TextPrompt(prompt=parsed["content"]), tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) assert_never(parsed) @@ -692,8 +665,7 @@ class InputPreprocessor: prompt: PromptType, tokenization_kwargs: 
Optional[dict[str, Any]] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> EncoderDecoderInputs: """ For encoder/decoder models only: @@ -735,7 +707,7 @@ class InputPreprocessor: encoder_inputs = self._prompt_to_llm_inputs( prompt["encoder_prompt"], tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) if (decoder_input := prompt["decoder_prompt"]) is None: decoder_inputs = None @@ -751,7 +723,7 @@ class InputPreprocessor: inputs = self._prompt_to_llm_inputs( prompt, tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) if self.model_config.is_multimodal_model: # Encoder-Decoder Multimodal model @@ -768,8 +740,7 @@ class InputPreprocessor: prompt: PromptType, tokenization_kwargs: Optional[dict[str, Any]] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> EncoderDecoderInputs: """ Async version of @@ -782,7 +753,7 @@ class InputPreprocessor: encoder_task = self._prompt_to_llm_inputs_async( prompt["encoder_prompt"], tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) if (decoder_input := prompt["decoder_prompt"]) is None: @@ -792,7 +763,7 @@ class InputPreprocessor: decoder_task = self._prompt_to_llm_inputs_async( decoder_input, tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) encoder_inputs, decoder_inputs = await asyncio.gather( @@ -808,7 +779,7 @@ class InputPreprocessor: inputs = await self._prompt_to_llm_inputs_async( prompt, tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) if self.model_config.is_multimodal_model: # Encoder-Decoder Multimodal model @@ -834,10 +805,8 @@ class InputPreprocessor: self, prompt: SingletonPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[LoRARequest] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> DecoderOnlyInputs: """ For decoder-only models: @@ -847,7 +816,6 @@ class InputPreprocessor: Arguments: * prompt: input prompt - * lora_request Returns: @@ -857,8 +825,7 @@ class InputPreprocessor: prompt_comps = self._prompt_to_llm_inputs( prompt, tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) return self._build_decoder_only_llm_inputs(prompt_comps) @@ -867,10 +834,8 @@ class InputPreprocessor: self, prompt: SingletonPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[LoRARequest] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> DecoderOnlyInputs: """ Async version of @@ -879,8 +844,7 @@ class InputPreprocessor: prompt_comps = await self._prompt_to_llm_inputs_async( prompt, tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) return self._build_decoder_only_llm_inputs(prompt_comps) @@ -889,10 +853,8 @@ class InputPreprocessor: self, prompt: PromptType, tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[LoRARequest] = None, *, - 
mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> ProcessorInputs: """Preprocess the input prompt.""" if self.model_config.is_encoder_decoder: @@ -901,7 +863,7 @@ class InputPreprocessor: return self._process_encoder_decoder_prompt( prompt, tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) if is_explicit_encoder_decoder_prompt(prompt): @@ -912,18 +874,15 @@ class InputPreprocessor: return self._process_decoder_only_prompt( prompt, tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) async def preprocess_async( self, prompt: PromptType, tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[LoRARequest] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> ProcessorInputs: """ Async version of @@ -935,7 +894,7 @@ class InputPreprocessor: return await self._process_encoder_decoder_prompt_async( prompt, tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) if is_explicit_encoder_decoder_prompt(prompt): @@ -946,10 +905,21 @@ class InputPreprocessor: return await self._process_decoder_only_prompt_async( prompt, tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) def clear_cache(self) -> None: if self.mm_processor_cache is not None: self.mm_processor_cache.clear_cache() + + +# Helper function to validate that a nested dictionary contains +# only strings or list of strings as the leaf values. +def contains_only_strings(obj: object): + if isinstance(obj, str): + return True + if isinstance(obj, list): + return all(isinstance(x, str) for x in obj) + if isinstance(obj, dict): + return all(contains_only_strings(v) for v in obj.values()) + return False diff --git a/vllm/logger.py b/vllm/logger.py index 8f06eb03c7f93..2861e0f1686c4 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -20,9 +20,10 @@ VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH VLLM_LOGGING_LEVEL = envs.VLLM_LOGGING_LEVEL VLLM_LOGGING_PREFIX = envs.VLLM_LOGGING_PREFIX +VLLM_LOGGING_STREAM = envs.VLLM_LOGGING_STREAM _FORMAT = (f"{VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s " - "[%(filename)s:%(lineno)d] %(message)s") + "[%(fileinfo)s:%(lineno)d] %(message)s") _DATE_FORMAT = "%m-%d %H:%M:%S" DEFAULT_LOGGING_CONFIG = { @@ -38,7 +39,7 @@ DEFAULT_LOGGING_CONFIG = { "class": "logging.StreamHandler", "formatter": "vllm", "level": VLLM_LOGGING_LEVEL, - "stream": "ext://sys.stdout", + "stream": VLLM_LOGGING_STREAM, }, }, "loggers": { diff --git a/vllm/logging_utils/formatter.py b/vllm/logging_utils/formatter.py index 0affef10078dc..004b79f3ea6e2 100644 --- a/vllm/logging_utils/formatter.py +++ b/vllm/logging_utils/formatter.py @@ -2,16 +2,77 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import logging +from pathlib import Path + +from vllm import envs class NewLineFormatter(logging.Formatter): """Adds logging prefix to newlines to align multi-line messages.""" def __init__(self, fmt, datefmt=None, style="%"): - logging.Formatter.__init__(self, fmt, datefmt, style) + super().__init__(fmt, datefmt, style) + + self.use_relpath = envs.VLLM_LOGGING_LEVEL == "DEBUG" + if self.use_relpath: + self.root_dir = 
Path(__file__).resolve().parent.parent.parent def format(self, record): - msg = logging.Formatter.format(self, record) + + def shrink_path(relpath: Path) -> str: + """ + Shortens a file path for logging display: + - Removes leading 'vllm' folder if present. + - If path starts with 'v1', + keeps the first two and last two levels, + collapsing the middle as '...'. + - Otherwise, keeps the first and last two levels, + collapsing the middle as '...'. + - If the path is short, returns it as-is. + - Examples: + vllm/model_executor/layers/quantization/utils/fp8_utils.py -> + model_executor/.../quantization/utils/fp8_utils.py + vllm/model_executor/layers/quantization/awq.py -> + model_executor/layers/quantization/awq.py + vllm/v1/attention/backends/mla/common.py -> + v1/attention/backends/mla/common.py + + Args: + relpath (Path): The relative path to be shortened. + Returns: + str: The shortened path string for display. + """ + parts = list(relpath.parts) + new_parts = [] + if parts and parts[0] == "vllm": + parts = parts[1:] + if parts and parts[0] == "v1": + new_parts += parts[:2] + parts = parts[2:] + elif parts: + new_parts += parts[:1] + parts = parts[1:] + if len(parts) > 2: + new_parts += ["..."] + parts[-2:] + else: + new_parts += parts + return "/".join(new_parts) + + if self.use_relpath: + abs_path = getattr(record, "pathname", None) + if abs_path: + try: + relpath = Path(abs_path).resolve().relative_to( + self.root_dir) + except Exception: + relpath = Path(record.filename) + else: + relpath = Path(record.filename) + record.fileinfo = shrink_path(relpath) + else: + record.fileinfo = record.filename + + msg = super().format(record) if record.message != "": parts = msg.split(record.message) msg = msg.replace("\n", "\r\n" + parts[0]) diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py deleted file mode 100644 index 7fc4cfe026aee..0000000000000 --- a/vllm/lora/fully_sharded_layers.py +++ /dev/null @@ -1,355 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# pylint: disable=unused-argument -from typing import TYPE_CHECKING, Optional, Union, cast - -import torch -import torch.nn as nn -from transformers import PretrainedConfig - -from vllm.config import LoRAConfig -from vllm.distributed.communication_op import ( - tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import get_tensor_model_parallel_rank -from vllm.lora.layers import (ColumnParallelLinearWithLoRA, - MergedColumnParallelLinearWithLoRA, - MergedQKVParallelLinearWithLoRA, - QKVParallelLinearWithLoRA, - RowParallelLinearWithLoRA) -from vllm.platforms import current_platform - -if TYPE_CHECKING: - pass - - -def _fully_sharded_can_replace(can_replace): - """ - decorator which adds the condition of fully sharded loras - intended to wrap can_replace_layer() - """ - - def dec(*args, **kwargs): - return (can_replace(*args, **kwargs) - and kwargs["lora_config"].fully_sharded_loras) - - return dec - - -def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA): - """ - For `ColumnParallelLinearWithLoRA` or classes that inherit from - `ColumnParallelLinearWithLoRA`, they share the same `apply` logic. 
- """ - assert (layer.n_slices == len(layer.lora_a_stacked) == len( - layer.lora_b_stacked) == len(layer.output_slices)) - if layer.lora_bias_stacked is not None: - assert layer.n_slices == len(layer.lora_bias_stacked) - - output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias) - - x = x.view(-1, x.shape[-1]) - output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape - - # Since communication is needed, the buffer is directly initialized as a - # tensor rather than a tuple of tensor. - buffers = torch.zeros( - (layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]), - dtype=torch.float32, - device=x.device, - ) - - shrunk_buffers: Optional[torch.Tensor] = layer.punica_wrapper.add_shrink( - buffers, x, layer.lora_a_stacked, 1.0) - - if not current_platform.can_update_inplace(): - buffers = shrunk_buffers - - buffers = tensor_model_parallel_all_gather(buffers) - - lora_output: Optional[torch.Tensor] = layer.punica_wrapper.add_expand( - output, - buffers, - layer.lora_b_stacked, - layer.lora_bias_stacked, - layer.output_slices, - offset_start=0, - add_input=True) - - if not current_platform.can_update_inplace(): - output = lora_output - - output = output.view(*out_orig_shape) - # now have column partitioned and packed output - return output - - -# these layers are based on the tensor parallelism strategy given in -# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023, -# https://arxiv.org/abs/2311.03285. - - -class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): - """ - Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also. - - Based on S-LoRA, slicing happens along the rank dim. - """ - - # For all LoRA layers where the `base_layer` is `ColumnParallelLinear`, - # their `lora_a` and `lora_b` have different sharding patterns. After - # completing the `lora_a` GEMM , a gather operation is performed. - # Therefore, the sharding of `lora_a` only needs to correspond with the - # gather operation. - def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: - tp_rank = get_tensor_model_parallel_rank() - shard_size = self.lora_a_stacked[0].shape[2] - start_idx = tp_rank * shard_size - lora_a = lora_a[:, start_idx:start_idx + shard_size] - return lora_a - - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return _mcp_apply(x, bias, self) - - @classmethod - @_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - # specifying kwargs so they can be easily accessed in decorator - return super().can_replace_layer( - source_layer=source_layer, - lora_config=lora_config, - packed_modules_list=packed_modules_list, - model_config=model_config, - decorate=False, - ) - - -class MergedColumnParallelLinearWithShardedLoRA( - MergedColumnParallelLinearWithLoRA): - """ - Differs from MergedColumnParallelLinearWithLoRA by slicing the - LoRA A's also. - - Based on S-LoRA, slicing happens along the rank dim. - """ - - def slice_lora_a( - self, lora_a: list[Union[torch.Tensor, None]] - ) -> list[Union[torch.Tensor, None]]: - #NOTE: lora_a contains 2 subloras, and each sublora could be None. 
- output_shard_size = self.lora_a_stacked[0].shape[2] - output_start_idx = self.tp_rank * output_shard_size - lora_a = [ - lora_a[0][:, output_start_idx:output_start_idx + - output_shard_size] if lora_a[0] is not None else None, - lora_a[1][:, output_start_idx:output_start_idx + - output_shard_size] if lora_a[1] is not None else None, - ] - return lora_a - - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return _mcp_apply(x, bias, self) - - @classmethod - @_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - # specifying kwargs so they can be easily accessed in decorator - return super().can_replace_layer( - source_layer=source_layer, - lora_config=lora_config, - packed_modules_list=packed_modules_list, - model_config=model_config, - decorate=False, - ) - - -class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA): - """ - Differs from QKVParallelLinearWithLoRA by slicing the - LoRA A's also. - - Based on S-LoRA, slicing happens along the rank dim. - """ - - def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: - tp_rank = get_tensor_model_parallel_rank() - shard_size = self.lora_a_stacked[0].shape[2] - start_idx = tp_rank * shard_size - lora_a = lora_a[:, start_idx:start_idx + shard_size] - return lora_a - - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return _mcp_apply(x, bias, self) - - @classmethod - @_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig]) -> bool: - # specifying kwargs so they can be easily accessed in decorator - return super().can_replace_layer( - source_layer=source_layer, - lora_config=lora_config, - packed_modules_list=packed_modules_list, - model_config=model_config, - decorate=False, - ) - - -class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA): - """ - Differs from MergedQKVParallelLinearWithLoRA by slicing the - LoRA A's also. - - Based on S-LoRA, slicing happens along the rank dim. - """ - - def slice_lora_a( - self, lora_a: list[Union[torch.Tensor, None]] - ) -> list[Union[torch.Tensor, None]]: - # NOTE: lora_a contains 3 subloras, and each sublora could be None. 
- shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] - start_idx = [self.tp_rank * shard_size[i] for i in range(3)] - lora_a = [ - lora_a[0][:, start_idx[0]:start_idx[0] + - shard_size[0]] if lora_a[0] is not None else None, - lora_a[1][:, start_idx[1]:start_idx[1] + - shard_size[1]] if lora_a[1] is not None else None, - lora_a[2][:, start_idx[2]:start_idx[2] + - shard_size[2]] if lora_a[2] is not None else None, - ] - return lora_a - - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return _mcp_apply(x, bias, self) - - @classmethod - @_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - # specifying kwargs so they can be easily accessed in decorator - return super().can_replace_layer( - source_layer=source_layer, - lora_config=lora_config, - packed_modules_list=packed_modules_list, - model_config=model_config, - decorate=False, - ) - - -class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): - """ - Differs from RowParallelLinearWithLoRA by slicing the - LoRA B's also. - - Based on S-LoRA, slicing happens along the output dim. - This yields a combined partial sum from the row parallel base - layer and column partitioned output from the LoRA. - """ - - def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: - shard_size = self.lora_b_stacked[0].shape[2] - start_idx = self.tp_rank * shard_size - end_idx = (self.tp_rank + 1) * shard_size - lora_b = lora_b[:, start_idx:end_idx] - return lora_b - - def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - if bias is None: - return bias - self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], - self.lora_bias_stacked) - shard_size = self.lora_bias_stacked[0].shape[2] - start_idx = self.tp_rank * shard_size - end_idx = (self.tp_rank + 1) * shard_size - bias = bias[start_idx:end_idx] - return bias - - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - output = self.base_layer.quant_method.apply(self.base_layer, x) - - x = x.view(-1, x.shape[-1]) - output, out_orig_shape = output.view(-1, - output.shape[-1]), output.shape - buffer = torch.zeros( - (self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]), - dtype=torch.float32, - device=x.device, - ) - - shrunk_buffer: Optional[torch.Tensor] = self.punica_wrapper.add_shrink( - buffer, x, self.lora_a_stacked, 1.0) - if not current_platform.can_update_inplace(): - buffer = shrunk_buffer - - buffer = tensor_model_parallel_all_reduce(buffer) - - # following S-LoRA, allows the fusing of all_gather and all_reduce - # by adding the column partitioned lora output to a slice of output - # tensor, which is a partial sum due to row parallel. All that - # remains is a standard all_reduce. User should be aware though that - # the output is not the same as a normal row_parallel, it should be - # reduced before being used - # NOTE offset are based on the rank. 
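The comment above makes a concrete claim: adding each rank's column-partitioned LoRA output at a rank-dependent offset into the row-parallel partial sum, then doing the usual all-reduce, equals the reduced base output plus the full LoRA output. A toy check of that claim, with made-up sizes and the all-reduce replaced by a plain sum (none of this is part of the diff):

```python
# Toy equivalence check (illustration only). `base_partials` stand in for each
# rank's unreduced row-parallel output; `shrink` is x @ lora_a after reduction.
import torch

torch.manual_seed(0)
tp, tokens, rank, out = 2, 3, 4, 8
shard = out // tp
base_partials = [torch.randn(tokens, out) for _ in range(tp)]
shrink = torch.randn(tokens, rank)
lora_b = torch.randn(rank, out)

fused = torch.zeros(tokens, out)
for r in range(tp):
    contribution = base_partials[r].clone()
    # Each rank adds its LoRA-B column slice at its own offset (r * shard).
    cols = slice(r * shard, (r + 1) * shard)
    contribution[:, cols] += shrink @ lora_b[:, cols]
    fused += contribution          # stands in for the final all-reduce

reference = sum(base_partials) + shrink @ lora_b
assert torch.allclose(fused, reference, atol=1e-5)
```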
- shard_size = self.lora_b_stacked[0].shape[2] - offset_start = self.tp_rank * shard_size - lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_expand( - output, - buffer, - self.lora_b_stacked, - self.lora_bias_stacked, - self.output_slices, - offset_start=offset_start, - add_input=True, - ) - - if not current_platform.can_update_inplace(): - output = lora_output - - output = output.view(*out_orig_shape) - return output - - @classmethod - @_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - # specifying kwargs so they can be easily accessed in decorator - return super().can_replace_layer( - source_layer=source_layer, - lora_config=lora_config, - packed_modules_list=packed_modules_list, - model_config=model_config, - decorate=False, - ) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py deleted file mode 100644 index d8503b20459f6..0000000000000 --- a/vllm/lora/layers.py +++ /dev/null @@ -1,1192 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# pylint: disable=unused-argument -import math -from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, Union, cast - -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import PretrainedConfig - -from vllm.adapter_commons.layers import AdapterMapping -from vllm.config import LoRAConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - split_tensor_along_last_dim, - tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce) -from vllm.distributed.utils import divide -# yapf: disable -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearBase, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) -# yapf: enable -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) -from vllm.platforms import current_platform - -if TYPE_CHECKING: - from vllm.lora.punica_wrapper import PunicaWrapperBase - - -def _get_lora_device(base_layer: nn.Module) -> torch.device: - # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34 - """Returns the device for where to place the LoRA tensors.""" - # unquantizedLinear - if hasattr(base_layer, "weight"): - return base_layer.weight.device - # Compressed Tensor - elif hasattr(base_layer, "weight_packed"): - return base_layer.weight_packed.device - # GPTQ/AWQ - elif hasattr(base_layer, "qweight"): - return base_layer.qweight.device - # HQQ marlin - elif hasattr(base_layer, "W_q"): - return base_layer.W_q.device - else: - raise ValueError(f"Unsupported base layer: {base_layer}") - - -def _not_fully_sharded_can_replace(can_replace): - """ - decorator which adds the condition of not using fully sharded loras - intended to wrap can_replace_layer() - """ - - def dec(*args, **kwargs): - decorate = kwargs.pop("decorate") if "decorate" in kwargs else True - condition = (not kwargs["lora_config"].fully_sharded_loras - if decorate else True) - return can_replace(*args, **kwargs) and condition - - return dec - - -@dataclass -class LoRAMapping(AdapterMapping): - is_prefill: bool = False - - -class BaseLayerWithLoRA(nn.Module): - - def slice_lora_a( - self, 
lora_a: Union[torch.Tensor, list[Union[torch.Tensor, None]]] - ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]: - """Slice lora a if splitting for tensor parallelism.""" - ... - - def slice_lora_b( - self, lora_b: Union[torch.Tensor, list[Union[torch.Tensor, None]]] - ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]: - """Slice lora b if splitting with tensor parallelism.""" - ... - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - """Initializes lora matrices.""" - ... - - def reset_lora(self, index: int): - """Resets the lora weights at index back to 0.""" - ... - - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, - ): - """Overwrites lora tensors at index.""" - ... - - def set_mapping( - self, - punica_wrapper, - ): - self.punica_wrapper: PunicaWrapperBase = punica_wrapper - - @classmethod - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - """Returns True if the layer can be replaced by this LoRA layer.""" - raise NotImplementedError - - -class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): - - def __init__(self, base_layer: VocabParallelEmbedding) -> None: - super().__init__() - self.base_layer = base_layer - self.embeddings_slice: Optional[tuple[int, int]] - self.embeddings_weights: Optional[torch.Tensor] - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None) -> None: - - if self.base_layer.num_added_embeddings_per_partition > 0: - # We can start adding lora weights - self.embeddings_weights = self.base_layer.weight.data[ - self.base_layer.num_org_embeddings_per_partition:self. 
- base_layer.num_org_embeddings_per_partition + - self.base_layer.num_added_embeddings_per_partition] - self.embeddings_slice = ( - self.base_layer.shard_indices.added_vocab_start_index - - self.base_layer.org_vocab_size, - self.base_layer.shard_indices.added_vocab_end_index - - self.base_layer.org_vocab_size) - self.base_layer.weight.data[ - self.base_layer.num_org_embeddings_per_partition:].fill_(0) - else: - self.embeddings_slice = None - self.embeddings_weights = None - - self.embeddings_tensors = torch.zeros( - ( - max_loras, - lora_config.lora_extra_vocab_size, - self.base_layer.embedding_dim, - ), - dtype=self.base_layer.weight.dtype, - device=self.base_layer.weight.device, - ) - self.lora_a_stacked = torch.zeros( - ( - max_loras, - self.base_layer.org_vocab_size + - lora_config.lora_extra_vocab_size, - lora_config.max_lora_rank, - ), - dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, - ) - self.lora_b_stacked = torch.zeros( - ( - max_loras, - 1, - self.base_layer.embedding_dim, - lora_config.max_lora_rank, - ), - dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, - ) - self.lora_a_stacked_2d = self.lora_a_stacked.view( - self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], - self.lora_a_stacked.shape[2], - ) - - def reset_lora(self, index: int): - self.lora_a_stacked[index] = 0 - self.lora_b_stacked[index] = 0 - self.embeddings_tensors[index] = 0 - - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, - ): - self.reset_lora(index) - self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( - lora_a, non_blocking=True) - self.lora_b_stacked[index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) - if embeddings_tensor is not None: - self.embeddings_tensors[ - index, - :embeddings_tensor.shape[0], - :embeddings_tensor.shape[1], - ].copy_(embeddings_tensor, non_blocking=True) - if self.embeddings_slice is not None: - # TODO(yard1): Optimize this copy, we don't need to copy - # everything, just the modified part - embeddings = self.embeddings_tensors.view( - self.embeddings_tensors.shape[0] * - self.embeddings_tensors.shape[1], - self.embeddings_tensors.shape[2], - )[self.embeddings_slice[0]:self.embeddings_slice[1]] - assert self.embeddings_weights is not None - self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, - 1, 0) - - # NB: Don't use torch.narrow here. 
torch.narrow triggers some - # Dynamic Shape specialization in torch.compile - num_tokens = x.shape[0] - indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens] - indices_0 = self.punica_wrapper._embeddings_indices[0][:num_tokens] - - full_lora_a_embeddings = F.embedding( - x + indices_1, - self.lora_a_stacked_2d, - ) - full_output = self.base_layer.forward(x + - (indices_0 * added_tokens_mask)) - - full_output_org = full_output - if full_output.ndim == 3: - full_output = full_output.view( - full_output.shape[0] * full_output.shape[1], -1) - if full_lora_a_embeddings.ndim == 3: - full_lora_a_embeddings = full_lora_a_embeddings.view( - full_lora_a_embeddings.shape[0] * - full_lora_a_embeddings.shape[1], - -1, - ) - - lora_output: Optional[ - torch.Tensor] = self.punica_wrapper.add_lora_embedding( - full_output, - full_lora_a_embeddings, - self.lora_b_stacked, - add_input=True) - - if not current_platform.can_update_inplace(): - full_output = lora_output - - return full_output.view_as(full_output_org) - - @classmethod - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - return type(source_layer) is VocabParallelEmbedding - - @property - def weight(self): - return self.base_layer.weight - - -class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): - - def __init__(self, base_layer: LinearBase): - super().__init__() - self.base_layer = base_layer - self.input_size = self.base_layer.input_size - self.device = _get_lora_device(self.base_layer) - self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None - - self.output_slices: tuple[int, ...] - self.tp_size: int - self.output_size: int - self.n_slices: int - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - self.lora_config = lora_config - # - if isinstance(self.base_layer, ReplicatedLinear): - lora_a_out_size = lora_config.max_lora_rank - lora_b_out_size = self.output_size - - elif isinstance(self.base_layer, ColumnParallelLinear): - lora_a_out_size = (lora_config.max_lora_rank if - not lora_config.fully_sharded_loras else divide( - lora_config.max_lora_rank, self.tp_size)) - lora_b_out_size = self.output_size - - elif isinstance(self.base_layer, RowParallelLinear): - lora_a_out_size = lora_config.max_lora_rank - lora_b_out_size = (self.output_size if - not lora_config.fully_sharded_loras else divide( - self.output_size, self.tp_size)) - else: - raise NotImplementedError - - self.lora_a_stacked = tuple( - torch.zeros( - max_loras, - 1, - lora_a_out_size, - self.input_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) for _ in range(self.n_slices)) - self.lora_b_stacked = tuple( - torch.zeros( - max_loras, - 1, - lora_b_out_size, - lora_config.max_lora_rank, - dtype=lora_config.lora_dtype, - device=self.device, - ) for _ in range(self.n_slices)) - if lora_config.bias_enabled: - lora_bias_out_size = lora_b_out_size - self.lora_bias_stacked = tuple( - torch.zeros( - max_loras, - 1, - lora_bias_out_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) for _ in range(self.n_slices)) - self.output_slices = (self.lora_b_stacked[0].shape[2], ) - - def reset_lora(self, index: int): - for s_index in range(self.n_slices): - self.lora_a_stacked[s_index][index] = 0 - self.lora_b_stacked[s_index][index] = 0 - if self.lora_config.bias_enabled: - # Make mypy happy - self.lora_bias_stacked = 
cast(tuple[torch.Tensor, ...], - self.lora_bias_stacked) - self.lora_bias_stacked[s_index][index] = 0 - - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - lora_bias: Optional[torch.Tensor] = None, - ): - # Except for QKVParallelLinearWithLoRA and - # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers - # store weights in a tuple of size 1. These two layers will - # override this function. - assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) == - self.n_slices == 1) - - self.reset_lora(index) - if self.tp_size > 1: - lora_a = self.slice_lora_a(lora_a) - lora_b = self.slice_lora_b(lora_b) - if lora_bias is not None: - lora_bias = self.slice_bias(lora_bias) - - self.lora_a_stacked[0][index, - 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( - lora_a.T, non_blocking=True) - self.lora_b_stacked[0][index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) - if lora_bias is not None: - - self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], - self.lora_bias_stacked) - assert len(self.lora_bias_stacked) - self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_( - lora_bias.T, non_blocking=True) - - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - - # In transformers backend, x and output have extra batch dimension like - # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim), - # therefore we need to flatten the batch dimensions. - if x.ndim == 3 and output.ndim == 3: - output = output.flatten(0, 1) - x = x.flatten(0, 1) - - lora_output: Optional[ - torch.Tensor] = self.punica_wrapper.add_lora_linear( - output, x, self.lora_a_stacked, self.lora_b_stacked, - self.lora_bias_stacked, 1.0, self.output_slices) - if not current_platform.can_update_inplace(): - output = lora_output - - return output - - @property - def weight(self) -> torch.Tensor: - - # unquantizedLinear - if hasattr(self.base_layer, "weight"): - return self.base_layer.weight - # Compressed Tensor - elif hasattr(self.base_layer, "weight_packed"): - return self.base_layer.weight_packed - # GPTQ/AWQ - elif hasattr(self.base_layer, "qweight"): - return self.base_layer.qweight - # marlin - elif hasattr(self.base_layer, "B"): - return self.base_layer.B - # HQQ marlin - elif hasattr(self.base_layer, "W_q"): - return self.base_layer.W_q - else: - raise ValueError(f"Unsupported base layer: {self.base_layer}") - - @property - def bias(self) -> Optional[torch.Tensor]: - if hasattr(self.base_layer, "bias"): - return self.base_layer.bias - else: - return None - - -class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): - - def __init__(self, base_layer: ReplicatedLinear) -> None: - super().__init__(base_layer, ) - # To ensure interface compatibility, set to 1 always. - self.tp_size = 1 - self.output_size = self.base_layer.output_size - self.n_slices = 1 - - def forward( - self, input_: torch.Tensor - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: - """Forward of ReplicatedLinearWithLoRA - - Args: - input_: Tensor whose last dimension is `input_size`. - - Returns: - - output - - bias - """ - bias = (self.base_layer.bias - if not self.base_layer.skip_bias_add else None) - - # Matrix multiply. 
- output = self.apply(input_, bias) - - output_bias = (self.base_layer.bias - if self.base_layer.skip_bias_add else None) - - if not self.base_layer.return_bias: - return output - - return output, output_bias - - # ReplicatedLinear should always be replaced, regardless of the fully - # sharded LoRAs setting, because it is, by definition, copied per GPU. - @classmethod - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - return type(source_layer) is ReplicatedLinear - - -class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): - """ - LoRA on top of ColumnParallelLinear layer. - LoRA B is sliced for tensor parallelism. - There are two types for the `base_layer`: - 1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`. - 2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`. - """ - - def __init__(self, base_layer: ColumnParallelLinear) -> None: - super().__init__(base_layer) - # The base_layer type is ColumnParallelLinear or - # MergedColumnParallelLinear, their weight sharding logic is - # inconsistent when TP is greater than 1. - self.is_merged_col_linear = type( - base_layer) is MergedColumnParallelLinear - self.tp_size = get_tensor_model_parallel_world_size() - self.output_size = self.base_layer.output_size_per_partition - # There is only one LoRA layer - self.n_slices = 1 - - def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: - return lora_a - - def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: - # Applicable to cases where the base_layer is - # MergedColumnParallelLinear. - if self.is_merged_col_linear: - tp_rank = get_tensor_model_parallel_rank() - shard_size = self.output_size // 2 - offset = lora_b.shape[-1] // 2 - - left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) * - shard_size] - right_weight = lora_b[:, offset + tp_rank * shard_size:offset + - (tp_rank + 1) * shard_size] - lora_b = torch.cat([left_weight, right_weight], dim=1) - # Applicable to cases where the base_layer is - # ColumnParallelLinear. - else: - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - shard_size = self.output_size - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size - lora_b = lora_b[:, start_idx:end_idx] - return lora_b - - def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - # TODO: Fix the slicing logic of bias. - if bias is None: - return bias - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - shard_size = self.output_size - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size - bias = bias[start_idx:end_idx] - return bias - - def forward( - self, input_: torch.Tensor - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: - """Forward of ColumnParallelLinear - - Args: - input_: Tensor whose last dimension is `input_size`. - - Returns: - - output - - bias - """ - bias = (self.base_layer.bias - if not self.base_layer.skip_bias_add else None) - - # Matrix multiply. - output_parallel = self.apply(input_, bias) - if self.base_layer.gather_output: - # All-gather across the partitions. 
- output = tensor_model_parallel_all_gather(output_parallel) - else: - output = output_parallel - - if not self.base_layer.return_bias: - return output - - output_bias = (self.base_layer.bias - if self.base_layer.skip_bias_add else None) - return output, output_bias - - @classmethod - @_not_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - return type(source_layer) is ColumnParallelLinear or ( - type(source_layer) is MergedColumnParallelLinear - and len(packed_modules_list) == 1) - - -class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): - """ColumnParallelLinear layer that is composed of 2 sublayers (slices) - packed together (e.g. gate_proj + up_proj -> gate_up_proj). - - This means we have 2 LoRAs, each applied to one half of the layer. - - Both slices must have the same size. - """ - - def __init__( - self, base_layer: Union[MergedColumnParallelLinear, - QKVParallelLinear]) -> None: - super().__init__(base_layer) - # There are two LoRA layers - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - # the output_sizes in MergedColumnParallelLinear is not sharded by tp - # we need to divide it by the tp_size to get correct slices size - output_sizes = self.base_layer.output_sizes - self.output_slices = tuple( - divide(output_size, self.tp_size) for output_size in output_sizes) - self.n_slices = len(self.output_slices) - self.output_ids = (self.tp_rank, ) * self.n_slices - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - """ - The main reason for overriding this function is to enhance code - maintainability. 
- """ - self.lora_config = lora_config - - lora_a_output_size_per_partition = ( - lora_config.max_lora_rank if not lora_config.fully_sharded_loras - else divide(lora_config.max_lora_rank, self.tp_size)) - - self.lora_a_stacked = tuple( - torch.zeros( - max_loras, - 1, - lora_a_output_size_per_partition, - self.input_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) for _ in range(self.n_slices)) - self.lora_b_stacked = tuple( - torch.zeros( - max_loras, - 1, - output_size, - lora_config.max_lora_rank, - dtype=lora_config.lora_dtype, - device=self.device, - ) for output_size in self.output_slices) - if lora_config.bias_enabled: - self.lora_bias_stacked = tuple( - torch.zeros( - max_loras, - 1, - output_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) for output_size in self.output_slices) - - def slice_lora_a( - self, lora_a: list[Union[torch.Tensor, None]] - ) -> list[Union[torch.Tensor, None]]: - return lora_a - - def slice_lora_b( - self, lora_b: list[Union[torch.Tensor, None]] - ) -> list[Union[torch.Tensor, None]]: - sliced_lora_b = [None] * self.n_slices - for i, (shard_id, shard_size) in enumerate( - zip(self.output_ids, self.output_slices)): - if (lora_b_i := lora_b[i]) is not None: - sliced_lora_b[i] = lora_b_i[:, - shard_size * shard_id:shard_size * - (shard_id + 1)] - return sliced_lora_b - - def slice_bias( - self, bias: list[Union[torch.Tensor, - None]]) -> list[Union[torch.Tensor, None]]: - for i, (shard_id, shard_size) in enumerate( - zip(self.output_ids, self.output_slices)): - if (bias_i := bias[i]) is not None: - bias[i] = bias_i[shard_size * shard_id:shard_size * - (shard_id + 1)] - return bias - - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - lora_bias: Optional[torch.Tensor] = None, - ): - self.reset_lora(index) - - if self.tp_size > 1: - lora_a = self.slice_lora_a(lora_a) - lora_b = self.slice_lora_b(lora_b) - if lora_bias is not None: - lora_bias = self.slice_bias(lora_bias) - - for i in range(self.n_slices): - if (lora_a_i := lora_a[i]) is not None: - self.lora_a_stacked[i][ - index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_( - lora_a_i.T, non_blocking=True) - if (lora_b_i := lora_b[i]) is not None: - self.lora_b_stacked[i][ - index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_( - lora_b_i.T, non_blocking=True) - - if lora_bias is not None: - self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], - self.lora_bias_stacked) - for i in range(self.n_slices): - if (lora_bias_i := lora_bias[i]) is not None: - self.lora_bias_stacked[i][index, - 0, :lora_bias_i.shape[0]].copy_( - lora_bias_i.T, - non_blocking=True) - - @classmethod - @_not_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - return (type(source_layer) is MergedColumnParallelLinear - and len(packed_modules_list) == 2) - - -class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): - """ - ColumnParallelLinear layer that is specifically designed for - qkv_proj. Certain models, such as chatglm3 and baichuan-7b, - only contains a single LoRA within their qkv_proj layer. - - During inference with Tensor Parallel, the weights of lora_b - must be accurately partitioned according to the respective ranks. - - Q slice may have different shape than K and V slices (which both have - the same shape). 
- """ - - def __init__(self, base_layer: QKVParallelLinear) -> None: - super().__init__(base_layer) - self.q_proj_total_size = (self.base_layer.total_num_heads * - self.base_layer.head_size) - self.q_proj_shard_size = (self.base_layer.num_heads * - self.base_layer.head_size) - self.kv_proj_shard_size = (self.base_layer.num_kv_heads * - self.base_layer.head_size) - self.kv_proj_total_size = (self.base_layer.total_num_kv_heads * - self.base_layer.head_size) - # There is only one LoRA layer - self.n_slices = 1 - - def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: - tp_rank = get_tensor_model_parallel_rank() - self.q_shard_id = tp_rank - self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas - lora_b_q = lora_b[:, self.q_proj_shard_size * - self.q_shard_id:self.q_proj_shard_size * - (self.q_shard_id + 1)] - k_offset = self.q_proj_total_size - lora_b_k = lora_b[:, k_offset + - self.kv_proj_shard_size * self.kv_shard_id:k_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1)] - v_offset = k_offset + self.kv_proj_total_size - lora_b_v = lora_b[:, v_offset + - self.kv_proj_shard_size * self.kv_shard_id:v_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1)] - lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1) - return lora_b - - def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - bias_q = bias[self.q_proj_shard_size * - self.q_shard_id:self.q_proj_shard_size * - (self.q_shard_id + 1)] - k_offset = self.q_proj_total_size - bias_k = bias[k_offset + - self.kv_proj_shard_size * self.kv_shard_id:k_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1)] - v_offset = k_offset + self.kv_proj_total_size - bias_v = bias[v_offset + - self.kv_proj_shard_size * self.kv_shard_id:v_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1)] - bias = torch.cat([bias_q, bias_k, bias_v], dim=1) - return bias - - @classmethod - @_not_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig]) -> bool: - return type(source_layer) is QKVParallelLinear and len( - packed_modules_list) == 1 - - -class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): - """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices) - packed together in qkv proj fashion - (q_proj + k_proj + v_proj -> qkv_proj). - - This means we have 3 LoRAs, each applied to one slice of the layer. - - Q slice may have different shape than K and V slices (which both have - the same shape). - """ - - def __init__(self, base_layer: QKVParallelLinear) -> None: - super().__init__(base_layer) - # There are three LoRA layer. 
- self.n_slices = len(self.base_layer.output_sizes) - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - - self.q_proj_shard_size = (self.base_layer.num_heads * - self.base_layer.head_size) - self.kv_proj_shard_size = (self.base_layer.num_kv_heads * - self.base_layer.head_size) - self.q_shard_id = self.tp_rank - self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas - - self.output_slices = ( - self.q_proj_shard_size, - self.kv_proj_shard_size, - self.kv_proj_shard_size, - ) - self.output_ids = ( - self.q_shard_id, - self.kv_shard_id, - self.kv_shard_id, - ) - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - """ - The main reason for overloading this function is to handle inconsistent - weight dimensions in qkv lora. - """ - super().create_lora_weights(max_loras, lora_config, model_config) - - @classmethod - @_not_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - return (type(source_layer) is QKVParallelLinear - and len(packed_modules_list) == 3) - - -#TODO: Implement this -class QKVCrossParallelLinearWithLoRA(BaseLayerWithLoRA): - pass - - -class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): - - def __init__(self, base_layer: RowParallelLinear) -> None: - super().__init__(base_layer) - - self.tp_size = get_tensor_model_parallel_world_size() - # reset input_size - self.input_size = self.base_layer.input_size_per_partition - self.output_size = self.base_layer.output_size - - self.tp_rank = get_tensor_model_parallel_rank() - # There is only one LoRA layer. - self.n_slices = 1 - - def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: - - shard_size = self.input_size - start_idx = self.tp_rank * shard_size - end_idx = (self.tp_rank + 1) * shard_size - lora_a = lora_a[start_idx:end_idx, :] - return lora_a - - def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: - return lora_b - - def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - return bias - - def forward( - self, input_: torch.Tensor - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: - """Forward of RowParallelLinear - - Args: - input_: tensor whose last dimension is `input_size`. If - `input_is_parallel` is set, then the last dimension - is `input_size // tp_size`. - - Returns: - - output - - bias - """ - # set up backprop all-reduce. - if self.base_layer.input_is_parallel: - input_parallel = input_ - else: - # TODO: simplify code below - splitted_input = split_tensor_along_last_dim( - input_, num_partitions=self.base_layer.tp_size) - input_parallel = splitted_input[self.tp_rank].contiguous() - - # Matrix multiply. 
- output_parallel = self.apply(input_parallel) - if self.base_layer.reduce_results and self.base_layer.tp_size > 1: - output_ = tensor_model_parallel_all_reduce(output_parallel) - else: - output_ = output_parallel - - if not self.base_layer.skip_bias_add: - output = (output_ + self.base_layer.bias - if self.base_layer.bias is not None else output_) - output_bias = None - else: - output = output_ - output_bias = self.base_layer.bias - - if not self.base_layer.return_bias: - return output - - return output, output_bias - - @classmethod - @_not_fully_sharded_can_replace - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - return type(source_layer) is RowParallelLinear - - -class LogitsProcessorWithLoRA(BaseLayerWithLoRA): - """ - LoRA wrapper for LogitsProcessor, with extra logic to handle the - application of the LoRA adapter and added LoRA vocabulary. - - Args: - base_layer: LogitsProcessor layer - hidden_size: hidden size of the model - dtype: data type of the model - device: device of the model - sharded_to_full_mapping: index mapping from sharded vocab to full vocab - received from base_layer.get_sharded_to_full_mapping(). If None, - no reindexing will be done. - """ - - def __init__(self, base_layer: LogitsProcessor, hidden_size: int, - dtype: torch.dtype, device: torch.device, - sharded_to_full_mapping: Optional[list[int]]) -> None: - super().__init__() - self.base_layer = base_layer - self.hidden_size = hidden_size - self.dtype = dtype - self.device = device - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - self.sharded_to_full_mapping = sharded_to_full_mapping - - @property - def logits_as_input(self): - return self.base_layer.logits_as_input - - @property - def vocab_size(self): - return self.base_layer.vocab_size - - @property - def scale(self): - return self.base_layer.scale - - @property - def soft_cap(self): - return self.base_layer.soft_cap - - @property - def use_all_gather(self): - return self.base_layer.use_all_gather - - @property - def org_vocab_size(self): - return self.base_layer.org_vocab_size - - @property - def include_gpu_probs_tensor(self): - return self.base_layer.include_gpu_probs_tensor - - @property - def should_modify_greedy_probs_inplace(self): - return self.base_layer.should_modify_greedy_probs_inplace - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - # TODO: Verify if this condition can be further relaxed - if 32000 < self.base_layer.vocab_size > 257024: - raise ValueError("When using LoRA, vocab size must be " - "32000 >= vocab_size <= 257024") - self.lora_a_stacked = torch.zeros( - ( - max_loras, - 1, - lora_config.max_lora_rank, - self.hidden_size, - ), - dtype=lora_config.lora_dtype, - device=self.device, - ) - self.lora_b_stacked = torch.zeros( - ( - max_loras, - 1, - # Pad for kernel compatibility - math.ceil(self.base_layer.vocab_size / - lora_config.lora_vocab_padding_size) * - lora_config.lora_vocab_padding_size, - lora_config.max_lora_rank, - ), - dtype=lora_config.lora_dtype, - device=self.device, - ) - self.embeddings_tensors = torch.full( - (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size), - fill_value=float("-inf"), - dtype=self.dtype, - device=self.device, - ) - if self.sharded_to_full_mapping is not None: - self.sharded_to_full_mapping_gpu = torch.tensor( 
- self.sharded_to_full_mapping, - device=self.device, - dtype=torch.long) - else: - self.sharded_to_full_mapping_gpu = None - - def reset_lora(self, index: int): - self.lora_a_stacked[index] = 0 - self.lora_b_stacked[index] = 0 - self.embeddings_tensors[index] = float("-inf") - - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, - ): - self.reset_lora(index) - self.lora_a_stacked[index, - 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( - lora_a.T, non_blocking=True) - self.lora_b_stacked[index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) - if embeddings_tensor is not None: - self.embeddings_tensors[ - index, - :embeddings_tensor.shape[0], - :embeddings_tensor.shape[1], - ] = embeddings_tensor - - def _get_logits( - self, - hidden_states: torch.Tensor, - lm_head: VocabParallelEmbedding, - embedding_bias: Optional[torch.Tensor] = None, - ) -> Optional[torch.Tensor]: - # Get the logits for the next tokens. - logits = lm_head.quant_method.apply(lm_head, hidden_states) - if embedding_bias is not None: - logits += embedding_bias - - # Gather logits for TP - logits = self.base_layer._gather_logits(logits) - - if logits is None: - return None - - if self.sharded_to_full_mapping_gpu is not None: - # Reindex full logits tensor to ensure 1:1 mapping between - # index and token_id - # Example for: - # org_vocab_size = 4 - # added_vocab_size = 2 - # pad_to_size = 8 - # tp_size = 2 - - # indices: [0, 1, 2, 3, 4, 5, 6, 7] - # token_id: [0, 1, 4, -1, 2, 3, 5, -1] - - # Therefore, the mapping is expected to be: - # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex, - # we get: - # indices: [0, 1, 2, 3, 4, 5, 6, 7] - # token_id: [0, 1, 2, 3, 4, 5, -1, -1] - logits = logits[:, self.sharded_to_full_mapping_gpu] - - lora_logits = torch.empty( - self.embeddings_tensors.shape[0] + 1, - self.embeddings_tensors.shape[1], - hidden_states.shape[0], - dtype=self.embeddings_tensors.dtype, - device=self.embeddings_tensors.device, - ) - torch.matmul(self.embeddings_tensors, - hidden_states.T, - out=lora_logits[:-1]) - - neg_inf, pos_inf = current_platform.get_infinity_values( - lora_logits.dtype) - - lora_logits[-1] = neg_inf - lora_logits = lora_logits.mT - indices_padded = self.punica_wrapper.sampler_indices_padded - - if current_platform.is_tpu(): - indices_padded = indices_padded[:logits.size(0)] - - lora_logits = (lora_logits.reshape( - lora_logits.shape[0] * lora_logits.shape[1], - lora_logits.shape[2], - ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf, - posinf=pos_inf, - neginf=neg_inf)) - - logits[:, - self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + - lora_logits.shape[1]] = lora_logits - - lora_output: Optional[ - torch.Tensor] = self.punica_wrapper.add_lora_logits( - logits, hidden_states, self.lora_a_stacked, - self.lora_b_stacked, 1.0) - - if not current_platform.can_update_inplace(): - logits = lora_output - - # Remove paddings in vocab (if any). - logits = logits[:, :self.base_layer.vocab_size] - return logits - - def forward(self, *args, **kwargs): - return type(self.base_layer).forward(self, *args, **kwargs) - - @classmethod - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - # Special handling for the LogitsProcessor. 
- return False diff --git a/vllm/lora/layers/__init__.py b/vllm/lora/layers/__init__.py new file mode 100644 index 0000000000000..d3bb145dc7bf8 --- /dev/null +++ b/vllm/lora/layers/__init__.py @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.lora.layers.base import BaseLayerWithLoRA +from vllm.lora.layers.column_parallel_linear import ( + ColumnParallelLinearWithLoRA, ColumnParallelLinearWithShardedLoRA, + MergedColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithShardedLoRA, MergedQKVParallelLinearWithLoRA, + MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithLoRA, + QKVParallelLinearWithShardedLoRA) +from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA +from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA +from vllm.lora.layers.row_parallel_linear import ( + RowParallelLinearWithLoRA, RowParallelLinearWithShardedLoRA) +from vllm.lora.layers.utils import LoRAMapping +from vllm.lora.layers.vocal_parallel_embedding import ( + VocabParallelEmbeddingWithLoRA) + +__all__ = [ + "BaseLayerWithLoRA", + "VocabParallelEmbeddingWithLoRA", + "LogitsProcessorWithLoRA", + "ColumnParallelLinearWithLoRA", + "ColumnParallelLinearWithShardedLoRA", + "MergedColumnParallelLinearWithLoRA", + "MergedColumnParallelLinearWithShardedLoRA", + "MergedQKVParallelLinearWithLoRA", + "MergedQKVParallelLinearWithShardedLoRA", + "QKVParallelLinearWithLoRA", + "QKVParallelLinearWithShardedLoRA", + "RowParallelLinearWithLoRA", + "RowParallelLinearWithShardedLoRA", + "ReplicatedLinearWithLoRA", + "LoRAMapping", +] diff --git a/vllm/lora/layers/base.py b/vllm/lora/layers/base.py new file mode 100644 index 0000000000000..a80a033e39b40 --- /dev/null +++ b/vllm/lora/layers/base.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import TYPE_CHECKING, Optional, Union + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config.lora import LoRAConfig + +if TYPE_CHECKING: + from vllm.lora.punica_wrapper import PunicaWrapperBase + + +class BaseLayerWithLoRA(nn.Module): + + def slice_lora_a( + self, lora_a: Union[torch.Tensor, list[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]: + """Slice lora a if splitting for tensor parallelism.""" + ... + + def slice_lora_b( + self, lora_b: Union[torch.Tensor, list[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]: + """Slice lora b if splitting with tensor parallelism.""" + ... + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """Initializes lora matrices.""" + ... + + def reset_lora(self, index: int): + """Resets the lora weights at index back to 0.""" + ... + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + """Overwrites lora tensors at index.""" + ... 
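Since `BaseLayerWithLoRA` in the new `vllm/lora/layers/base.py` only defines the interface, a rough usage sketch may help. This is an assumption about how a LoRA manager typically drives a wrapped layer, not code from this diff, and the function names are invented:

```python
# Hypothetical driver code (not part of the diff): the interface above implies
# a one-time setup after layer replacement plus per-adapter slot updates.
def setup_wrapped_layer(layer, lora_config, punica_wrapper, max_loras=8):
    # Once, right after the base layer is swapped for its *WithLoRA wrapper.
    layer.create_lora_weights(max_loras, lora_config)
    layer.set_mapping(punica_wrapper)


def load_adapter_into_slot(layer, slot, lora_a, lora_b):
    # Whenever an adapter is (re)loaded into one of the `max_loras` slots.
    layer.reset_lora(slot)
    layer.set_lora(slot, lora_a=lora_a, lora_b=lora_b, embeddings_tensor=None)
```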
+ + def set_mapping( + self, + punica_wrapper, + ): + self.punica_wrapper: PunicaWrapperBase = punica_wrapper + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + """Returns True if the layer can be replaced by this LoRA layer.""" + raise NotImplementedError diff --git a/vllm/lora/layers/base_linear.py b/vllm/lora/layers/base_linear.py new file mode 100644 index 0000000000000..85a1f86ce6bf2 --- /dev/null +++ b/vllm/lora/layers/base_linear.py @@ -0,0 +1,184 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional, cast + +import torch +from transformers import PretrainedConfig + +from vllm.config.lora import LoRAConfig +from vllm.distributed.utils import divide +# yapf: disable +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearBase, ReplicatedLinear, + RowParallelLinear) +from vllm.platforms import current_platform + +from .base import BaseLayerWithLoRA +from .utils import _get_lora_device + + +class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: LinearBase): + super().__init__() + self.base_layer = base_layer + self.input_size = self.base_layer.input_size + self.device = _get_lora_device(self.base_layer) + self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None + + self.output_slices: tuple[int, ...] + self.tp_size: int + self.output_size: int + self.n_slices: int + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + self.lora_config = lora_config + # + if isinstance(self.base_layer, ReplicatedLinear): + lora_a_out_size = lora_config.max_lora_rank + lora_b_out_size = self.output_size + + elif isinstance(self.base_layer, ColumnParallelLinear): + lora_a_out_size = (lora_config.max_lora_rank if + not lora_config.fully_sharded_loras else divide( + lora_config.max_lora_rank, self.tp_size)) + lora_b_out_size = self.output_size + + elif isinstance(self.base_layer, RowParallelLinear): + lora_a_out_size = lora_config.max_lora_rank + lora_b_out_size = (self.output_size if + not lora_config.fully_sharded_loras else divide( + self.output_size, self.tp_size)) + else: + raise NotImplementedError + + self.lora_a_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_a_out_size, + self.input_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_b_out_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + if lora_config.bias_enabled: + lora_bias_out_size = lora_b_out_size + self.lora_bias_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_bias_out_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + self.output_slices = (self.lora_b_stacked[0].shape[2], ) + + def reset_lora(self, index: int): + for s_index in range(self.n_slices): + self.lora_a_stacked[s_index][index] = 0 + self.lora_b_stacked[s_index][index] = 0 + if self.lora_config.bias_enabled: + # Make mypy happy + self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], + self.lora_bias_stacked) + self.lora_bias_stacked[s_index][index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + 
embeddings_tensor: Optional[torch.Tensor], + lora_bias: Optional[torch.Tensor] = None, + ): + # Except for QKVParallelLinearWithLoRA and + # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers + # store weights in a tuple of size 1. These two layers will + # override this function. + assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) == + self.n_slices == 1) + + self.reset_lora(index) + if self.tp_size > 1: + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + if lora_bias is not None: + lora_bias = self.slice_bias(lora_bias) + + self.lora_a_stacked[0][index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[0][index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if lora_bias is not None: + + self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], + self.lora_bias_stacked) + assert len(self.lora_bias_stacked) + self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_( + lora_bias.T, non_blocking=True) + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + + # In transformers backend, x and output have extra batch dimension like + # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim), + # therefore we need to flatten the batch dimensions. + if x.ndim == 3 and output.ndim == 3: + output = output.flatten(0, 1) + x = x.flatten(0, 1) + + lora_output: Optional[ + torch.Tensor] = self.punica_wrapper.add_lora_linear( + output, x, self.lora_a_stacked, self.lora_b_stacked, + self.lora_bias_stacked, 1.0, self.output_slices) + if not current_platform.can_update_inplace(): + output = lora_output + + return output + + @property + def weight(self) -> torch.Tensor: + + # unquantizedLinear + if hasattr(self.base_layer, "weight"): + return self.base_layer.weight + # Compressed Tensor + elif hasattr(self.base_layer, "weight_packed"): + return self.base_layer.weight_packed + # GPTQ/AWQ + elif hasattr(self.base_layer, "qweight"): + return self.base_layer.qweight + # marlin + elif hasattr(self.base_layer, "B"): + return self.base_layer.B + # HQQ marlin + elif hasattr(self.base_layer, "W_q"): + return self.base_layer.W_q + else: + raise ValueError(f"Unsupported base layer: {self.base_layer}") + + @property + def bias(self) -> Optional[torch.Tensor]: + if hasattr(self.base_layer, "bias"): + return self.base_layer.bias + else: + return None diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py new file mode 100644 index 0000000000000..658fd23165da0 --- /dev/null +++ b/vllm/lora/layers/column_parallel_linear.py @@ -0,0 +1,622 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional, Union, cast + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config.lora import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) +from vllm.distributed.utils import divide +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear) +from vllm.platforms import current_platform + +from .base_linear import BaseLinearLayerWithLoRA +from .utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace + + +def _mcp_apply(x, bias, layer: 
"ColumnParallelLinearWithLoRA"): + """ + For `ColumnParallelLinearWithLoRA` or classes that inherit from + `ColumnParallelLinearWithLoRA`, they share the same `apply` logic. + """ + assert (layer.n_slices == len(layer.lora_a_stacked) == len( + layer.lora_b_stacked) == len(layer.output_slices)) + if layer.lora_bias_stacked is not None: + assert layer.n_slices == len(layer.lora_bias_stacked) + + output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape + + # Since communication is needed, the buffer is directly initialized as a + # tensor rather than a tuple of tensor. + buffers = torch.zeros( + (layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]), + dtype=torch.float32, + device=x.device, + ) + + shrunk_buffers: Optional[torch.Tensor] = layer.punica_wrapper.add_shrink( + buffers, x, layer.lora_a_stacked, 1.0) + + if not current_platform.can_update_inplace(): + buffers = shrunk_buffers + + buffers = tensor_model_parallel_all_gather(buffers) + + lora_output: Optional[torch.Tensor] = layer.punica_wrapper.add_expand( + output, + buffers, + layer.lora_b_stacked, + layer.lora_bias_stacked, + layer.output_slices, + offset_start=0, + add_input=True) + + if not current_platform.can_update_inplace(): + output = lora_output + + output = output.view(*out_orig_shape) + # now have column partitioned and packed output + return output + + +class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): + """ + LoRA on top of ColumnParallelLinear layer. + LoRA B is sliced for tensor parallelism. + There are two types for the `base_layer`: + 1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`. + 2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`. + """ + + def __init__(self, base_layer: ColumnParallelLinear) -> None: + super().__init__(base_layer) + # The base_layer type is ColumnParallelLinear or + # MergedColumnParallelLinear, their weight sharding logic is + # inconsistent when TP is greater than 1. + self.is_merged_col_linear = type( + base_layer) is MergedColumnParallelLinear + self.tp_size = get_tensor_model_parallel_world_size() + self.output_size = self.base_layer.output_size_per_partition + # There is only one LoRA layer + self.n_slices = 1 + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + # Applicable to cases where the base_layer is + # MergedColumnParallelLinear. + if self.is_merged_col_linear: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.output_size // 2 + offset = lora_b.shape[-1] // 2 + + left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) * + shard_size] + right_weight = lora_b[:, offset + tp_rank * shard_size:offset + + (tp_rank + 1) * shard_size] + lora_b = torch.cat([left_weight, right_weight], dim=1) + # Applicable to cases where the base_layer is + # ColumnParallelLinear. + else: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_size + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + # TODO: Fix the slicing logic of bias. 
+ if bias is None: + return bias + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_size + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + bias = bias[start_idx:end_idx] + return bias + + def forward( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: + """Forward of ColumnParallelLinear + + Args: + input_: Tensor whose last dimension is `input_size`. + + Returns: + - output + - bias + """ + bias = (self.base_layer.bias + if not self.base_layer.skip_bias_add else None) + + # Matrix multiply. + output_parallel = self.apply(input_, bias) + if self.base_layer.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + + if not self.base_layer.return_bias: + return output + + output_bias = (self.base_layer.bias + if self.base_layer.skip_bias_add else None) + return output, output_bias + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is ColumnParallelLinear or ( + type(source_layer) is MergedColumnParallelLinear + and len(packed_modules_list) == 1) + + +class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + """ColumnParallelLinear layer that is composed of 2 sublayers (slices) + packed together (e.g. gate_proj + up_proj -> gate_up_proj). + + This means we have 2 LoRAs, each applied to one half of the layer. + + Both slices must have the same size. + """ + + def __init__( + self, base_layer: Union[MergedColumnParallelLinear, + QKVParallelLinear]) -> None: + super().__init__(base_layer) + # There are two LoRA layers + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + # the output_sizes in MergedColumnParallelLinear is not sharded by tp + # we need to divide it by the tp_size to get correct slices size + output_sizes = self.base_layer.output_sizes + self.output_slices = tuple( + divide(output_size, self.tp_size) for output_size in output_sizes) + self.n_slices = len(self.output_slices) + self.output_ids = (self.tp_rank, ) * self.n_slices + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """ + The main reason for overriding this function is to enhance code + maintainability. 
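+        Allocates one stacked lora_a/lora_b buffer (and optional bias
+        buffer) per slice; the lora_b buffers are sized from
+        `output_slices`, and with `fully_sharded_loras` the lora_a rank
+        dimension is divided across TP ranks.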
+ """ + self.lora_config = lora_config + + lora_a_output_size_per_partition = ( + lora_config.max_lora_rank if not lora_config.fully_sharded_loras + else divide(lora_config.max_lora_rank, self.tp_size)) + + self.lora_a_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_a_output_size_per_partition, + self.input_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + output_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.device, + ) for output_size in self.output_slices) + if lora_config.bias_enabled: + self.lora_bias_stacked = tuple( + torch.zeros( + max_loras, + 1, + output_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for output_size in self.output_slices) + + def slice_lora_a( + self, lora_a: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: + return lora_a + + def slice_lora_b( + self, lora_b: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: + sliced_lora_b = [None] * self.n_slices + for i, (shard_id, shard_size) in enumerate( + zip(self.output_ids, self.output_slices)): + if (lora_b_i := lora_b[i]) is not None: + sliced_lora_b[i] = lora_b_i[:, + shard_size * shard_id:shard_size * + (shard_id + 1)] + return sliced_lora_b + + def slice_bias( + self, bias: list[Union[torch.Tensor, + None]]) -> list[Union[torch.Tensor, None]]: + for i, (shard_id, shard_size) in enumerate( + zip(self.output_ids, self.output_slices)): + if (bias_i := bias[i]) is not None: + bias[i] = bias_i[shard_size * shard_id:shard_size * + (shard_id + 1)] + return bias + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + lora_bias: Optional[torch.Tensor] = None, + ): + self.reset_lora(index) + + if self.tp_size > 1: + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + if lora_bias is not None: + lora_bias = self.slice_bias(lora_bias) + + for i in range(self.n_slices): + if (lora_a_i := lora_a[i]) is not None: + self.lora_a_stacked[i][ + index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_( + lora_a_i.T, non_blocking=True) + if (lora_b_i := lora_b[i]) is not None: + self.lora_b_stacked[i][ + index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_( + lora_b_i.T, non_blocking=True) + + if lora_bias is not None: + self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], + self.lora_bias_stacked) + for i in range(self.n_slices): + if (lora_bias_i := lora_bias[i]) is not None: + self.lora_bias_stacked[i][index, + 0, :lora_bias_i.shape[0]].copy_( + lora_bias_i.T, + non_blocking=True) + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return (type(source_layer) is MergedColumnParallelLinear + and len(packed_modules_list) == 2) + + +class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + """ + ColumnParallelLinear layer that is specifically designed for + qkv_proj. Certain models, such as chatglm3 and baichuan-7b, + only contains a single LoRA within their qkv_proj layer. + + During inference with Tensor Parallel, the weights of lora_b + must be accurately partitioned according to the respective ranks. + + Q slice may have different shape than K and V slices (which both have + the same shape). 
+ """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + self.q_proj_total_size = (self.base_layer.total_num_heads * + self.base_layer.head_size) + self.q_proj_shard_size = (self.base_layer.num_heads * + self.base_layer.head_size) + self.kv_proj_shard_size = (self.base_layer.num_kv_heads * + self.base_layer.head_size) + self.kv_proj_total_size = (self.base_layer.total_num_kv_heads * + self.base_layer.head_size) + # There is only one LoRA layer + self.n_slices = 1 + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + self.q_shard_id = tp_rank + self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas + lora_b_q = lora_b[:, self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + k_offset = self.q_proj_total_size + lora_b_k = lora_b[:, k_offset + + self.kv_proj_shard_size * self.kv_shard_id:k_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + v_offset = k_offset + self.kv_proj_total_size + lora_b_v = lora_b[:, v_offset + + self.kv_proj_shard_size * self.kv_shard_id:v_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1) + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + bias_q = bias[self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + k_offset = self.q_proj_total_size + bias_k = bias[k_offset + + self.kv_proj_shard_size * self.kv_shard_id:k_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + v_offset = k_offset + self.kv_proj_total_size + bias_v = bias[v_offset + + self.kv_proj_shard_size * self.kv_shard_id:v_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + bias = torch.cat([bias_q, bias_k, bias_v], dim=1) + return bias + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: list, + model_config: Optional[PretrainedConfig]) -> bool: + return type(source_layer) is QKVParallelLinear and len( + packed_modules_list) == 1 + + +class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): + """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices) + packed together in qkv proj fashion + (q_proj + k_proj + v_proj -> qkv_proj). + + This means we have 3 LoRAs, each applied to one slice of the layer. + + Q slice may have different shape than K and V slices (which both have + the same shape). + """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + # There are three LoRA layer. 
+ self.n_slices = len(self.base_layer.output_sizes) + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + self.q_proj_shard_size = (self.base_layer.num_heads * + self.base_layer.head_size) + self.kv_proj_shard_size = (self.base_layer.num_kv_heads * + self.base_layer.head_size) + self.q_shard_id = self.tp_rank + self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas + + self.output_slices = ( + self.q_proj_shard_size, + self.kv_proj_shard_size, + self.kv_proj_shard_size, + ) + self.output_ids = ( + self.q_shard_id, + self.kv_shard_id, + self.kv_shard_id, + ) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """ + The main reason for overloading this function is to handle inconsistent + weight dimensions in qkv lora. + """ + super().create_lora_weights(max_loras, lora_config, model_config) + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return (type(source_layer) is QKVParallelLinear + and len(packed_modules_list) == 3) + + +# These following layers are based on the tensor parallelism strategy given in +# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023, +# https://arxiv.org/abs/2311.03285. + + +class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): + """ + Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + # For all LoRA layers where the `base_layer` is `ColumnParallelLinear`, + # their `lora_a` and `lora_b` have different sharding patterns. After + # completing the `lora_a` GEMM , a gather operation is performed. + # Therefore, the sharding of `lora_a` only needs to correspond with the + # gather operation. + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.lora_a_stacked[0].shape[2] + start_idx = tp_rank * shard_size + lora_a = lora_a[:, start_idx:start_idx + shard_size] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class MergedColumnParallelLinearWithShardedLoRA( + MergedColumnParallelLinearWithLoRA): + """ + Differs from MergedColumnParallelLinearWithLoRA by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a( + self, lora_a: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: + #NOTE: lora_a contains 2 subloras, and each sublora could be None. 
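+        # Both sub-LoRAs are sliced along the rank dimension: each TP rank
+        # keeps the contiguous shard of lora_a that matches its offset into
+        # the stacked buffer (shape[2] is the per-rank rank size).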
+ output_shard_size = self.lora_a_stacked[0].shape[2] + output_start_idx = self.tp_rank * output_shard_size + lora_a = [ + lora_a[0][:, output_start_idx:output_start_idx + + output_shard_size] if lora_a[0] is not None else None, + lora_a[1][:, output_start_idx:output_start_idx + + output_shard_size] if lora_a[1] is not None else None, + ] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA): + """ + Differs from QKVParallelLinearWithLoRA by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.lora_a_stacked[0].shape[2] + start_idx = tp_rank * shard_size + lora_a = lora_a[:, start_idx:start_idx + shard_size] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: list, + model_config: Optional[PretrainedConfig]) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA): + """ + Differs from MergedQKVParallelLinearWithLoRA by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a( + self, lora_a: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: + # NOTE: lora_a contains 3 subloras, and each sublora could be None. 
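+        # Slice each q/k/v sub-LoRA along the rank dimension, using this TP
+        # rank's offset into the corresponding stacked lora_a buffer.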
+ shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] + start_idx = [self.tp_rank * shard_size[i] for i in range(3)] + lora_a = [ + lora_a[0][:, start_idx[0]:start_idx[0] + + shard_size[0]] if lora_a[0] is not None else None, + lora_a[1][:, start_idx[1]:start_idx[1] + + shard_size[1]] if lora_a[1] is not None else None, + lora_a[2][:, start_idx[2]:start_idx[2] + + shard_size[2]] if lora_a[2] is not None else None, + ] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) diff --git a/vllm/lora/layers/logits_processor.py b/vllm/lora/layers/logits_processor.py new file mode 100644 index 0000000000000..a50dcfa748f2f --- /dev/null +++ b/vllm/lora/layers/logits_processor.py @@ -0,0 +1,247 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from typing import Optional + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config.lora import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.platforms import current_platform + +from .base import BaseLayerWithLoRA + + +class LogitsProcessorWithLoRA(BaseLayerWithLoRA): + """ + LoRA wrapper for LogitsProcessor, with extra logic to handle the + application of the LoRA adapter and added LoRA vocabulary. + + Args: + base_layer: LogitsProcessor layer + hidden_size: hidden size of the model + dtype: data type of the model + device: device of the model + sharded_to_full_mapping: index mapping from sharded vocab to full vocab + received from base_layer.get_sharded_to_full_mapping(). If None, + no reindexing will be done. 
+ """ + + def __init__(self, base_layer: LogitsProcessor, hidden_size: int, + dtype: torch.dtype, device: torch.device, + sharded_to_full_mapping: Optional[list[int]]) -> None: + super().__init__() + self.base_layer = base_layer + self.hidden_size = hidden_size + self.dtype = dtype + self.device = device + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.sharded_to_full_mapping = sharded_to_full_mapping + + @property + def logits_as_input(self): + return self.base_layer.logits_as_input + + @property + def vocab_size(self): + return self.base_layer.vocab_size + + @property + def scale(self): + return self.base_layer.scale + + @property + def soft_cap(self): + return self.base_layer.soft_cap + + @property + def use_all_gather(self): + return self.base_layer.use_all_gather + + @property + def org_vocab_size(self): + return self.base_layer.org_vocab_size + + @property + def include_gpu_probs_tensor(self): + return self.base_layer.include_gpu_probs_tensor + + @property + def should_modify_greedy_probs_inplace(self): + return self.base_layer.should_modify_greedy_probs_inplace + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + # TODO: Verify if this condition can be further relaxed + if 32000 < self.base_layer.vocab_size > 257024: + raise ValueError("When using LoRA, vocab size must be " + "32000 >= vocab_size <= 257024") + self.lora_a_stacked = torch.zeros( + ( + max_loras, + 1, + lora_config.max_lora_rank, + self.hidden_size, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + # Pad for kernel compatibility + math.ceil(self.base_layer.vocab_size / + lora_config.lora_vocab_padding_size) * + lora_config.lora_vocab_padding_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.embeddings_tensors = torch.full( + (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size), + fill_value=float("-inf"), + dtype=self.dtype, + device=self.device, + ) + if self.sharded_to_full_mapping is not None: + self.sharded_to_full_mapping_gpu = torch.tensor( + self.sharded_to_full_mapping, + device=self.device, + dtype=torch.long) + else: + self.sharded_to_full_mapping_gpu = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = float("-inf") + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + self.reset_lora(index) + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, + :embeddings_tensor.shape[0], + :embeddings_tensor.shape[1], + ] = embeddings_tensor + + def _get_logits( + self, + hidden_states: torch.Tensor, + lm_head: VocabParallelEmbedding, + embedding_bias: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: + # Get the logits for the next tokens. 
+ logits = lm_head.quant_method.apply(lm_head, hidden_states) + if embedding_bias is not None: + logits += embedding_bias + + # Gather logits for TP + logits = self.base_layer._gather_logits(logits) + + if logits is None: + return None + + if self.sharded_to_full_mapping_gpu is not None: + # Reindex full logits tensor to ensure 1:1 mapping between + # index and token_id + # Example for: + # org_vocab_size = 4 + # added_vocab_size = 2 + # pad_to_size = 8 + # tp_size = 2 + + # indices: [0, 1, 2, 3, 4, 5, 6, 7] + # token_id: [0, 1, 4, -1, 2, 3, 5, -1] + + # Therefore, the mapping is expected to be: + # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex, + # we get: + # indices: [0, 1, 2, 3, 4, 5, 6, 7] + # token_id: [0, 1, 2, 3, 4, 5, -1, -1] + logits = logits[:, self.sharded_to_full_mapping_gpu] + + lora_logits = torch.empty( + self.embeddings_tensors.shape[0] + 1, + self.embeddings_tensors.shape[1], + hidden_states.shape[0], + dtype=self.embeddings_tensors.dtype, + device=self.embeddings_tensors.device, + ) + torch.matmul(self.embeddings_tensors, + hidden_states.T, + out=lora_logits[:-1]) + + neg_inf, pos_inf = current_platform.get_infinity_values( + lora_logits.dtype) + + lora_logits[-1] = neg_inf + lora_logits = lora_logits.mT + indices_padded = self.punica_wrapper.sampler_indices_padded + + if current_platform.is_tpu() or current_platform.is_xpu(): + indices_padded = indices_padded[:logits.size(0)] + + lora_logits = (lora_logits.reshape( + lora_logits.shape[0] * lora_logits.shape[1], + lora_logits.shape[2], + ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf, + posinf=pos_inf, + neginf=neg_inf)) + + logits[:, + self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + + lora_logits.shape[1]] = lora_logits + + lora_output: Optional[ + torch.Tensor] = self.punica_wrapper.add_lora_logits( + logits, hidden_states, self.lora_a_stacked, + self.lora_b_stacked, 1.0) + + if not current_platform.can_update_inplace(): + logits = lora_output + + # Remove paddings in vocab (if any). + logits = logits[:, :self.base_layer.vocab_size] + return logits + + def forward(self, *args, **kwargs): + return type(self.base_layer).forward(self, *args, **kwargs) + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + # Special handling for the LogitsProcessor. 
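+        # This layer is never selected through the generic replacement path;
+        # the LoRA model manager wraps the logits processor explicitly, so
+        # this always returns False.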
+ return False diff --git a/vllm/lora/layers/qkv_x_parallel_linear.py b/vllm/lora/layers/qkv_x_parallel_linear.py new file mode 100644 index 0000000000000..367482d0ee078 --- /dev/null +++ b/vllm/lora/layers/qkv_x_parallel_linear.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from .base import BaseLayerWithLoRA + + +#TODO: Implement this +class QKVCrossParallelLinearWithLoRA(BaseLayerWithLoRA): + pass diff --git a/vllm/lora/layers/replicated_linear.py b/vllm/lora/layers/replicated_linear.py new file mode 100644 index 0000000000000..3356297c1537a --- /dev/null +++ b/vllm/lora/layers/replicated_linear.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional, Union + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config.lora import LoRAConfig +from vllm.model_executor.layers.linear import ReplicatedLinear + +from .base_linear import BaseLinearLayerWithLoRA + + +class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): + + def __init__(self, base_layer: ReplicatedLinear) -> None: + super().__init__(base_layer, ) + # To ensure interface compatibility, set to 1 always. + self.tp_size = 1 + self.output_size = self.base_layer.output_size + self.n_slices = 1 + + def forward( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: + """Forward of ReplicatedLinearWithLoRA + + Args: + input_: Tensor whose last dimension is `input_size`. + + Returns: + - output + - bias + """ + bias = (self.base_layer.bias + if not self.base_layer.skip_bias_add else None) + + # Matrix multiply. + output = self.apply(input_, bias) + + output_bias = (self.base_layer.bias + if self.base_layer.skip_bias_add else None) + + if not self.base_layer.return_bias: + return output + + return output, output_bias + + # ReplicatedLinear should always be replaced, regardless of the fully + # sharded LoRAs setting, because it is, by definition, copied per GPU. 
+ @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is ReplicatedLinear diff --git a/vllm/lora/layers/row_parallel_linear.py b/vllm/lora/layers/row_parallel_linear.py new file mode 100644 index 0000000000000..18ef6fd1ddd78 --- /dev/null +++ b/vllm/lora/layers/row_parallel_linear.py @@ -0,0 +1,201 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional, Union, cast + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config.lora import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_reduce) +# yapf: disable +from vllm.model_executor.layers.linear import RowParallelLinear +from vllm.platforms import current_platform + +from .base_linear import BaseLinearLayerWithLoRA +from .utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace + + +class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): + + def __init__(self, base_layer: RowParallelLinear) -> None: + super().__init__(base_layer) + + self.tp_size = get_tensor_model_parallel_world_size() + # reset input_size + self.input_size = self.base_layer.input_size_per_partition + self.output_size = self.base_layer.output_size + + self.tp_rank = get_tensor_model_parallel_rank() + # There is only one LoRA layer. + self.n_slices = 1 + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + + shard_size = self.input_size + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_a = lora_a[start_idx:end_idx, :] + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + return bias + + def forward( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: + """Forward of RowParallelLinear + + Args: + input_: tensor whose last dimension is `input_size`. If + `input_is_parallel` is set, then the last dimension + is `input_size // tp_size`. + + Returns: + - output + - bias + """ + # set up backprop all-reduce. + if self.base_layer.input_is_parallel: + input_parallel = input_ + else: + # TODO: simplify code below + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.base_layer.tp_size) + input_parallel = splitted_input[self.tp_rank].contiguous() + + # Matrix multiply. 
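+        # apply() runs the base-layer matmul plus the LoRA contribution on
+        # this rank's input shard; when reduce_results is set and tp_size > 1
+        # the partial sums are combined with an all-reduce below.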
+ output_parallel = self.apply(input_parallel) + if self.base_layer.reduce_results and self.base_layer.tp_size > 1: + output_ = tensor_model_parallel_all_reduce(output_parallel) + else: + output_ = output_parallel + + if not self.base_layer.skip_bias_add: + output = (output_ + self.base_layer.bias + if self.base_layer.bias is not None else output_) + output_bias = None + else: + output = output_ + output_bias = self.base_layer.bias + + if not self.base_layer.return_bias: + return output + + return output, output_bias + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is RowParallelLinear + + + +# The following layer is based on the tensor parallelism strategy given in +# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023, +# https://arxiv.org/abs/2311.03285. + +class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): + """ + Differs from RowParallelLinearWithLoRA by slicing the + LoRA B's also. + + Based on S-LoRA, slicing happens along the output dim. + This yields a combined partial sum from the row parallel base + layer and column partitioned output from the LoRA. + """ + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + shard_size = self.lora_b_stacked[0].shape[2] + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + if bias is None: + return bias + self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], + self.lora_bias_stacked) + shard_size = self.lora_bias_stacked[0].shape[2] + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + bias = bias[start_idx:end_idx] + return bias + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, + output.shape[-1]), output.shape + buffer = torch.zeros( + (self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]), + dtype=torch.float32, + device=x.device, + ) + + shrunk_buffer: Optional[torch.Tensor] = self.punica_wrapper.add_shrink( + buffer, x, self.lora_a_stacked, 1.0) + if not current_platform.can_update_inplace(): + buffer = shrunk_buffer + + buffer = tensor_model_parallel_all_reduce(buffer) + + # following S-LoRA, allows the fusing of all_gather and all_reduce + # by adding the column partitioned lora output to a slice of output + # tensor, which is a partial sum due to row parallel. All that + # remains is a standard all_reduce. User should be aware though that + # the output is not the same as a normal row_parallel, it should be + # reduced before being used + # NOTE offset are based on the rank. 
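+        # Example: with tp_size=2 and a per-rank lora_b shard of size S,
+        # rank 0 adds its expanded LoRA slice into output[:, 0:S] and rank 1
+        # into output[:, S:2S]; the subsequent all-reduce then both completes
+        # the row-parallel partial sum and assembles the full LoRA output.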
+ shard_size = self.lora_b_stacked[0].shape[2] + offset_start = self.tp_rank * shard_size + lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_expand( + output, + buffer, + self.lora_b_stacked, + self.lora_bias_stacked, + self.output_slices, + offset_start=offset_start, + add_input=True, + ) + + if not current_platform.can_update_inplace(): + output = lora_output + + output = output.view(*out_orig_shape) + return output + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) diff --git a/vllm/lora/layers/utils.py b/vllm/lora/layers/utils.py new file mode 100644 index 0000000000000..772d32a44c22d --- /dev/null +++ b/vllm/lora/layers/utils.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass + +import torch +import torch.nn as nn + + +@dataclass +class LoRAMapping: + index_mapping: tuple[int, ...] + prompt_mapping: tuple[int, ...] + is_prefill: bool = False + + def __post_init__(self): + self.index_mapping = tuple(self.index_mapping) + self.prompt_mapping = tuple(self.prompt_mapping) + + +def _get_lora_device(base_layer: nn.Module) -> torch.device: + # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34 + """Returns the device for where to place the LoRA tensors.""" + # unquantizedLinear + if hasattr(base_layer, "weight"): + return base_layer.weight.device + # Compressed Tensor + elif hasattr(base_layer, "weight_packed"): + return base_layer.weight_packed.device + # GPTQ/AWQ + elif hasattr(base_layer, "qweight"): + return base_layer.qweight.device + # HQQ marlin + elif hasattr(base_layer, "W_q"): + return base_layer.W_q.device + else: + raise ValueError(f"Unsupported base layer: {base_layer}") + + +def _not_fully_sharded_can_replace(can_replace): + """ + decorator which adds the condition of not using fully sharded loras + intended to wrap can_replace_layer() + """ + + def dec(*args, **kwargs): + decorate = kwargs.pop("decorate") if "decorate" in kwargs else True + condition = (not kwargs["lora_config"].fully_sharded_loras + if decorate else True) + return can_replace(*args, **kwargs) and condition + + return dec + + +def _fully_sharded_can_replace(can_replace): + """ + decorator which adds the condition of fully sharded loras + intended to wrap can_replace_layer() + """ + + def dec(*args, **kwargs): + return (can_replace(*args, **kwargs) + and kwargs["lora_config"].fully_sharded_loras) + + return dec diff --git a/vllm/lora/layers/vocal_parallel_embedding.py b/vllm/lora/layers/vocal_parallel_embedding.py new file mode 100644 index 0000000000000..4d6218d970977 --- /dev/null +++ b/vllm/lora/layers/vocal_parallel_embedding.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig + +from vllm.config.lora import LoRAConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + 
VocabParallelEmbedding) +from vllm.platforms import current_platform + +from .base import BaseLayerWithLoRA + + +class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: VocabParallelEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + self.embeddings_slice: Optional[tuple[int, int]] + self.embeddings_weights: Optional[torch.Tensor] + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + + if self.base_layer.num_added_embeddings_per_partition > 0: + # We can start adding lora weights + self.embeddings_weights = self.base_layer.weight.data[ + self.base_layer.num_org_embeddings_per_partition:self. + base_layer.num_org_embeddings_per_partition + + self.base_layer.num_added_embeddings_per_partition] + self.embeddings_slice = ( + self.base_layer.shard_indices.added_vocab_start_index - + self.base_layer.org_vocab_size, + self.base_layer.shard_indices.added_vocab_end_index - + self.base_layer.org_vocab_size) + self.base_layer.weight.data[ + self.base_layer.num_org_embeddings_per_partition:].fill_(0) + else: + self.embeddings_slice = None + self.embeddings_weights = None + + self.embeddings_tensors = torch.zeros( + ( + max_loras, + lora_config.lora_extra_vocab_size, + self.base_layer.embedding_dim, + ), + dtype=self.base_layer.weight.dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked = torch.zeros( + ( + max_loras, + self.base_layer.org_vocab_size + + lora_config.lora_extra_vocab_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + self.base_layer.embedding_dim, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked_2d = self.lora_a_stacked.view( + self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], + self.lora_a_stacked.shape[2], + ) + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + self.reset_lora(index) + self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( + lora_a, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, + :embeddings_tensor.shape[0], + :embeddings_tensor.shape[1], + ].copy_(embeddings_tensor, non_blocking=True) + if self.embeddings_slice is not None: + # TODO(yard1): Optimize this copy, we don't need to copy + # everything, just the modified part + embeddings = self.embeddings_tensors.view( + self.embeddings_tensors.shape[0] * + self.embeddings_tensors.shape[1], + self.embeddings_tensors.shape[2], + )[self.embeddings_slice[0]:self.embeddings_slice[1]] + assert self.embeddings_weights is not None + self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, + 1, 0) + + # NB: Don't use torch.narrow here. 
torch.narrow triggers some + # Dynamic Shape specialization in torch.compile + num_tokens = x.shape[0] + indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens] + indices_0 = self.punica_wrapper._embeddings_indices[0][:num_tokens] + + full_lora_a_embeddings = F.embedding( + x + indices_1, + self.lora_a_stacked_2d, + ) + full_output = self.base_layer.forward(x + + (indices_0 * added_tokens_mask)) + + full_output_org = full_output + if full_output.ndim == 3: + full_output = full_output.view( + full_output.shape[0] * full_output.shape[1], -1) + if full_lora_a_embeddings.ndim == 3: + full_lora_a_embeddings = full_lora_a_embeddings.view( + full_lora_a_embeddings.shape[0] * + full_lora_a_embeddings.shape[1], + -1, + ) + + lora_output: Optional[ + torch.Tensor] = self.punica_wrapper.add_lora_embedding( + full_output, + full_lora_a_embeddings, + self.lora_b_stacked, + add_input=True) + + if not current_platform.can_update_inplace(): + full_output = lora_output + + return full_output.view_as(full_output_org) + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is VocabParallelEmbedding + + @property + def weight(self): + return self.base_layer.weight diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 3072047a2606c..25f90f2fa932b 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -4,19 +4,14 @@ import math import os from collections.abc import Sequence -from typing import Any, Callable, Optional, Union +from typing import Callable, Optional, TypeVar, Union import regex as re import safetensors.torch import torch from torch import nn -from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, - AdapterModelManager) -from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter, - get_adapter, list_adapters, - remove_adapter, set_adapter_mapping) -from vllm.config import LoRAConfig +from vllm.config.lora import LoRAConfig from vllm.logger import init_logger from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights @@ -33,10 +28,25 @@ from vllm.model_executor.models.interfaces import is_pooling_model from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper from vllm.model_executor.utils import get_packed_modules_mapping -from vllm.utils import is_pin_memory_available +from vllm.utils import LRUCache, is_pin_memory_available logger = init_logger(__name__) +T = TypeVar("T") + + +class AdapterLRUCache(LRUCache[int, T]): + + def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]): + super().__init__(capacity) + self.deactivate_fn = deactivate_fn + + def _on_remove(self, key: int, value: Optional[T]): + logger.debug("Removing adapter int id: %d", key) + self.deactivate_fn(key) + return super()._on_remove(key, value) + + _GLOBAL_LORA_ID = 0 @@ -57,7 +67,7 @@ def is_moe_model(model: nn.Module) -> bool: return False -class LoRAModel(AdapterModel): +class LoRAModel: """A LoRA fine-tuned model.""" def __init__( @@ -313,7 +323,7 @@ class LoRAModel(AdapterModel): weights_mapper=weights_mapper) -class LoRAModelManager(AdapterModelManager): +class LoRAModelManager: """A manager that manages multiple LoRA-fine-tuned models.""" def __init__( @@ -336,6 +346,11 @@ class LoRAModelManager(AdapterModelManager): vocab_size: the vocab 
size of the model. lora_config: the LoRA configuration. """ + self.model: SupportsLoRA = model + self._registered_adapters: dict[int, LoRAModel] = {} + # Dict instead of a set for compatibility with LRUCache. + self._active_adapters: dict[int, None] = {} + self.adapter_type = "LoRA" self.lora_config = lora_config self.device = device self.max_num_seqs = max_num_seqs @@ -347,9 +362,8 @@ class LoRAModelManager(AdapterModelManager): max_num_batched_tokens, max_batches=self.max_num_seqs, device=self.device, - max_loras=self.lora_config.max_loras) - - super().__init__(model) + max_loras=self.lora_config.max_loras, + ) self.supported_lora_modules = get_supported_lora_modules(self.model) assert self.supported_lora_modules, "No supported LoRA modules found in" @@ -370,7 +384,9 @@ class LoRAModelManager(AdapterModelManager): self._last_mapping: Optional[LoRAMapping] = None self._create_lora_modules() self.model.lora_manager = self - self.adapter_type = 'LoRA' + + def __len__(self) -> int: + return len(self._registered_adapters) @property def capacity(self) -> int: @@ -669,28 +685,39 @@ class LoRAModelManager(AdapterModelManager): return lora_model.get_lora(org_module_name) def deactivate_adapter(self, adapter_id: int) -> bool: - return deactivate_adapter(adapter_id, self._active_adapters, - self._deactivate_adapter) + if adapter_id not in self._active_adapters: + return False + self._deactivate_adapter(adapter_id) + self._active_adapters.pop(adapter_id, None) + return True def add_adapter(self, adapter: LoRAModel) -> bool: logger.debug("Adding lora. Model id: %d, " "int id: %d", adapter.id, adapter.id) - return add_adapter(adapter, self._registered_adapters, self.capacity, - self._add_adapter) + if adapter.id in self._registered_adapters: + return False + if len(self._registered_adapters) >= self.capacity: + raise RuntimeError("No free adapter slots.") + self._add_adapter(adapter) + return True def set_adapter_mapping(self, mapping: LoRAMapping) -> None: - self._last_mapping = set_adapter_mapping(mapping, self._last_mapping, - self._set_adapter_mapping) + if self._last_mapping != mapping: + self._set_adapter_mapping(mapping) + self._last_mapping = mapping def remove_adapter(self, adapter_id: int) -> bool: - return remove_adapter(adapter_id, self._registered_adapters, - self.deactivate_adapter) + self.deactivate_adapter(adapter_id) + if adapter_id not in self._registered_adapters: + return False + self._registered_adapters.pop(adapter_id, None) + return True - def list_adapters(self) -> dict[int, Any]: - return list_adapters(self._registered_adapters) + def list_adapters(self) -> dict[int, LoRAModel]: + return dict(self._registered_adapters) - def get_adapter(self, adapter_id: int) -> Optional[Any]: - return get_adapter(adapter_id, self._registered_adapters) + def get_adapter(self, adapter_id: int) -> Optional[LoRAModel]: + return self._registered_adapters.get(adapter_id) class LoRALRUCache(AdapterLRUCache[LoRAModel]): diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py index 8b8e5cb7d5fae..dc7249c386021 100644 --- a/vllm/lora/peft_helper.py +++ b/vllm/lora/peft_helper.py @@ -9,7 +9,7 @@ import os from dataclasses import MISSING, dataclass, field, fields from typing import Literal, Optional, Union -from vllm.config import LoRAConfig +from vllm.config.lora import LoRAConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.tensorizer import TensorizerConfig diff --git a/vllm/lora/punica_wrapper/punica_xpu.py b/vllm/lora/punica_wrapper/punica_xpu.py index 
572e39e0eced0..163bb412235ce 100644 --- a/vllm/lora/punica_wrapper/punica_xpu.py +++ b/vllm/lora/punica_wrapper/punica_xpu.py @@ -225,6 +225,13 @@ class PunicaWrapperXPU(PunicaWrapperBase): add_inputs=True, **kwargs) + @property + def sampler_indices_padded(self) -> torch.Tensor: + """ + This property provides access to padded sampler indices. + """ + return self._sampler_indices_padded[:] + def add_lora_logits(self, y: torch.Tensor, x: torch.Tensor, @@ -259,11 +266,11 @@ class PunicaWrapperXPU(PunicaWrapperBase): buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) - - bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) + sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0)) + bgmv_shrink(x, lora_a_stacked, buffer, sampler_indices, scale) bgmv_expand(buffer, lora_b_stacked, y, - self.sampler_indices, + sampler_indices, add_inputs=True) return y.view_as(y_org) diff --git a/vllm/lora/request.py b/vllm/lora/request.py index 5bbba7830c1b1..523525d46f0b3 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -6,8 +6,6 @@ from typing import Optional import msgspec -from vllm.adapter_commons.request import AdapterRequest - class LoRARequest( msgspec.Struct, @@ -24,8 +22,6 @@ class LoRARequest( lora_int_id must be globally unique for a given adapter. This is currently not enforced in vLLM. """ - __metaclass__ = AdapterRequest - lora_name: str lora_int_id: int lora_path: str = "" @@ -35,6 +31,8 @@ class LoRARequest( tensorizer_config_dict: Optional[dict] = None def __post_init__(self): + if self.lora_int_id < 1: + raise ValueError(f"id must be > 0, got {self.lora_int_id}") if self.lora_local_path: warnings.warn( "The 'lora_local_path' attribute is deprecated " diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index ab0a9fbd255de..10ba390bffd9e 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -11,23 +11,23 @@ from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, from torch import nn from transformers import PretrainedConfig -from vllm.config import LoRAConfig +from vllm.config.lora import LoRAConfig from vllm.logger import init_logger -from vllm.lora.fully_sharded_layers import ( - ColumnParallelLinearWithShardedLoRA, - MergedColumnParallelLinearWithShardedLoRA, - MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA, - RowParallelLinearWithShardedLoRA) # being imported for _all_lora_classes below # yapf conflicts with isort for this block # yapf: disable from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, + ColumnParallelLinearWithShardedLoRA, LogitsProcessorWithLoRA, MergedColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithShardedLoRA, MergedQKVParallelLinearWithLoRA, + MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithLoRA, + QKVParallelLinearWithShardedLoRA, ReplicatedLinearWithLoRA, RowParallelLinearWithLoRA, + RowParallelLinearWithShardedLoRA, VocabParallelEmbeddingWithLoRA) from vllm.model_executor.layers.linear import LinearBase @@ -239,7 +239,7 @@ def get_adapter_absolute_path(lora_path: str) -> str: except (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError, HFValidationError): # Handle errors that may occur during the download - # Return original path instead instead of throwing error here + # Return original path instead of throwing error here logger.exception("Error downloading the HuggingFace model") return lora_path diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 
248d2954f1ef4..e27b7d5fcf223 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -6,12 +6,7 @@ from typing import Any, Literal, Optional, Union import torch -from vllm.adapter_commons.utils import (add_adapter_worker, - apply_adapters_worker, - list_adapters_worker, - set_active_adapters_worker) -from vllm.adapter_commons.worker_manager import AbstractWorkerManager -from vllm.config import LoRAConfig +from vllm.config.lora import LoRAConfig from vllm.logger import init_logger from vllm.lora.models import (LoRAModel, LoRAModelManager, LRUCacheLoRAModelManager, create_lora_manager) @@ -22,7 +17,7 @@ from vllm.lora.utils import get_adapter_absolute_path logger = init_logger(__name__) -class WorkerLoRAManager(AbstractWorkerManager): +class WorkerLoRAManager: """WorkerLoRAManager that manages LoRA models on the worker side. Every request, the requested LoRAs will be loaded (unless they are already @@ -51,7 +46,7 @@ class WorkerLoRAManager(AbstractWorkerManager): self.vocab_size = vocab_size self.lora_config = lora_config self.max_position_embeddings = max_position_embeddings - super().__init__(device) + self.device = device # Lazily initialized by create_lora_manager. self._adapter_manager: LoRAModelManager @@ -164,19 +159,34 @@ class WorkerLoRAManager(AbstractWorkerManager): def set_active_adapters(self, requests: set[Any], mapping: Optional[Any]) -> None: - set_active_adapters_worker(requests, mapping, self._apply_adapters, - self._adapter_manager.set_adapter_mapping) + self._apply_adapters(requests) + if mapping is not None: + self._adapter_manager.set_adapter_mapping(mapping) def _apply_adapters(self, adapter_requests: set[Any]) -> None: - apply_adapters_worker(adapter_requests, self.list_adapters, - self._adapter_manager.adapter_slots, - self.remove_adapter, self.add_adapter) + existing_adapters = self.list_adapters() + models_map = { + adapter_request.adapter_id: adapter_request + for adapter_request in adapter_requests if adapter_request + } + if len(models_map) > self._adapter_manager.adapter_slots: + raise RuntimeError( + f"Number of requested models ({len(models_map)}) is greater " + "than the number of GPU model slots " + f"({self._adapter_manager.adapter_slots}).") + requested_ids = set(models_map) + for adapter_id in existing_adapters - requested_ids: + self.remove_adapter(adapter_id) + for adapter_id in requested_ids - existing_adapters: + self.add_adapter(models_map[adapter_id]) def add_adapter(self, adapter_request: Any) -> bool: - return add_adapter_worker(adapter_request, self.list_adapters, - self._load_adapter, - self._adapter_manager.add_adapter, - self._adapter_manager.activate_adapter) + if adapter_request.adapter_id in self.list_adapters(): + return False + loaded_adapter = self._load_adapter(adapter_request) + loaded = self._adapter_manager.add_adapter(loaded_adapter) + self._adapter_manager.activate_adapter(loaded_adapter.id) + return loaded def remove_adapter(self, adapter_id: int) -> bool: return self._adapter_manager.remove_adapter(adapter_id) @@ -185,7 +195,7 @@ class WorkerLoRAManager(AbstractWorkerManager): self._adapter_manager.remove_all_adapters() def list_adapters(self) -> set[int]: - return list_adapters_worker(self._adapter_manager.list_adapters) + return set(self._adapter_manager.list_adapters()) class LRUCacheWorkerLoRAManager(WorkerLoRAManager): diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 6b5a107396c92..e7eb8247d5efd 100644 --- a/vllm/model_executor/custom_op.py +++ 
b/vllm/model_executor/custom_op.py @@ -73,11 +73,6 @@ class CustomOp(nn.Module): # NOTE(woosuk): This is a placeholder for future extensions. return self.forward_native(*args, **kwargs) - def forward_neuron(self, *args, **kwargs): - # By default, we assume that Neuron ops are compatible with the - # PyTorch-native implementation. - return self.forward_native(*args, **kwargs) - def forward_oot(self, *args, **kwargs): # By default, we assume that OOT ops are compatible with the # PyTorch-native implementation. @@ -105,8 +100,6 @@ class CustomOp(nn.Module): return self.forward_tpu elif current_platform.is_xpu(): return self.forward_xpu - elif current_platform.is_neuron(): - return self.forward_neuron elif current_platform.is_out_of_tree(): return self.forward_oot else: diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index eb7e494e32861..235df1a77c5ce 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -95,13 +95,6 @@ class SiluAndMul(CustomOp): self.op(out, x) return out - def forward_neuron(self, x: torch.Tensor) -> torch.Tensor: - d = x.shape[-1] // 2 - x_reshaped = x.view(-1, x.shape[-1]) - s = x_reshaped[:, :d] * F.sigmoid(x_reshaped[:, :d]) - result = s * x_reshaped[:, d:] - return result.view(*x.shape[:-1], d) - @CustomOp.register("mul_and_silu") class MulAndSilu(CustomOp): @@ -362,7 +355,7 @@ class ReLUSquaredActivation(CustomOp): return torch.square(F.relu(x)) def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: - #TODO : implement cuda kenrels + #TODO : implement cuda kernels return self.forward_native(x) @@ -461,7 +454,7 @@ class XIELU(CustomOp): ) return result.view(original_shape) - def forward(self, input: torch.Tensor) -> torch.Tensor: + def forward_native(self, input: torch.Tensor) -> torch.Tensor: if self._xielu_cuda_obj is not None and input.is_cuda: if not torch._dynamo.is_compiling(): return self._xielu_cuda_fn(input) @@ -471,6 +464,9 @@ class XIELU(CustomOp): ) return self._xielu_python(input) + def forward_cuda(self, input: torch.Tensor) -> torch.Tensor: + return self.forward_native(input) + class ScaledActivation(nn.Module): """An activation function with post-scale parameters. diff --git a/vllm/model_executor/layers/fla/__init__.py b/vllm/model_executor/layers/fla/__init__.py new file mode 100644 index 0000000000000..0e89cf9f79439 --- /dev/null +++ b/vllm/model_executor/layers/fla/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang diff --git a/vllm/model_executor/layers/fla/ops/__init__.py b/vllm/model_executor/layers/fla/ops/__init__.py new file mode 100644 index 0000000000000..c19cc14ba6928 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/__init__.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +from .chunk import chunk_gated_delta_rule +from .fused_recurrent import fused_recurrent_gated_delta_rule +from .layernorm_guard import RMSNormGated + +__all__ = [ + "RMSNormGated", + "chunk_gated_delta_rule", + "fused_recurrent_gated_delta_rule", +] diff --git a/vllm/model_executor/layers/fla/ops/chunk.py b/vllm/model_executor/layers/fla/ops/chunk.py new file mode 100644 index 0000000000000..e7d295aff2392 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/chunk.py @@ -0,0 +1,225 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import warnings +from typing import Optional + +import torch +from einops import rearrange + +from .chunk_delta_h import chunk_gated_delta_rule_fwd_h +from .chunk_o import chunk_fwd_o +from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd +from .cumsum import chunk_local_cumsum +from .l2norm import l2norm_fwd +from .solve_tril import solve_tril +from .utils import SUPPRESS_LEVEL, input_guard +from .wy_fast import recompute_w_u_fwd + + +def chunk_gated_delta_rule_fwd(q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None): + g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) + # obtain WY representation. u is actually the new v. 
+ A = chunk_scaled_dot_kkt_fwd(k=k, + beta=beta, + g_cumsum=g, + cu_seqlens=cu_seqlens, + output_dtype=torch.float32) + A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + g_cumsum=g, + cu_seqlens=cu_seqlens, + ) + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + o = chunk_fwd_o( + q=q, + k=k, + v=v_new, + h=h, + g=g, + scale=scale, + cu_seqlens=cu_seqlens, + ) + if SUPPRESS_LEVEL < 3: + return g, o, A, final_state, None, None, None + elif SUPPRESS_LEVEL >= 3: + return g, o, A, final_state, w, h, v_new + + +class ChunkGatedDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @torch.amp.custom_fwd(device_type='cuda') + def forward(ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + use_qk_l2norm_in_kernel: bool = False): + if use_qk_l2norm_in_kernel: + q = l2norm_fwd(q) + k = l2norm_fwd(k) + + g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + ctx.scale = scale + ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + return o.to(q.dtype), final_state + + +@torch.compiler.disable +def chunk_gated_delta_rule(q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, + use_qk_l2norm_in_kernel: bool = False): + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g (torch.Tensor): + (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + beta (torch.Tensor): + betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. 
+ + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda') + >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid() + >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda')) + >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda') + >>> o, ht = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) + """ + assert q.dtype == k.dtype == v.dtype + assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16." + assert len( + beta.shape + ) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise." + + if head_first: + raise DeprecationWarning( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead.", + stacklevel=2) + q, k, v, beta, g = map( + lambda x: rearrange(x, 'b h t ... -> b t h ...'), + (q, k, v, beta, g)) + if not head_first and q.shape[1] < q.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...].", + stacklevel=2) + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing.") + if initial_state is not None and initial_state.shape[0] != len( + cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) + if scale is None: + scale = k.shape[-1]**-0.5 + o, final_state = ChunkGatedDeltaRuleFunction.apply( + q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, + use_qk_l2norm_in_kernel) + if head_first: + o = rearrange(o, 'b t h ... 
-> b h t ...') + return o, final_state diff --git a/vllm/model_executor/layers/fla/ops/chunk_delta_h.py b/vllm/model_executor/layers/fla/ops/chunk_delta_h.py new file mode 100644 index 0000000000000..34006f87f457b --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/chunk_delta_h.py @@ -0,0 +1,290 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + +from .index import prepare_chunk_indices, prepare_chunk_offsets +from .op import exp +from .utils import is_nvidia_hopper, use_cuda_graph + +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16] + + +@triton.heuristics({ + 'USE_G': lambda args: args['g'] is not None, + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'SAVE_NEW_VALUE': lambda args: args['v_new'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] for num_stages in [2, 3, 4] for BV in [32, 64] + ], + key=['H', 'K', 'V', 'BT', 'USE_G'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( + k, + v, + w, + v_new, + g, + h, + h0, + ht, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + SAVE_NEW_VALUE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_h1 = tl.zeros([64, BV], dtype=tl.float32) + if K > 64: + b_h2 = tl.zeros([64, BV], dtype=tl.float32) + if K > 128: + b_h3 = tl.zeros([64, BV], dtype=tl.float32) + if K > 192: + b_h4 = tl.zeros([64, BV], dtype=tl.float32) + + # calculate offset + h += (boh * H + i_h) * K * V + v += (bos * H + i_h) * V + k += (bos * Hg + i_h // (H // Hg)) * K + w += (bos * H + i_h) * K + if SAVE_NEW_VALUE: + v_new += (bos * H + i_h) * V + stride_v = H * V + stride_h = H * K * V + stride_k = Hg * K + stride_w = H * K + if USE_INITIAL_STATE: + h0 = h0 + i_nh * K * V + if STORE_FINAL_STATE: + ht = ht + i_nh * K * V + + # load initial state + if USE_INITIAL_STATE: + p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), + (1, 0)) + b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) + if K > 64: + p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), + (64, BV), (1, 0)) + b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) + if K > 128: + p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), + (64, BV), (1, 0)) + b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32) + if K > 192: 
+ p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), + (64, BV), (1, 0)) + b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32) + + # main recurrence + for i_t in range(NT): + p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), + (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), + (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h2, + b_h2.to(p_h2.dtype.element_ty), + boundary_check=(0, 1)) + if K > 128: + p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), + (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h3, + b_h3.to(p_h3.dtype.element_ty), + boundary_check=(0, 1)) + if K > 192: + p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), + (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h4, + b_h4.to(p_h4.dtype.element_ty), + boundary_check=(0, 1)) + + p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), + (BT, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), + (i_t * BT, i_v * BV), (BT, BV), + (1, 0)) if SAVE_NEW_VALUE else None + b_v_new = tl.zeros([BT, BV], dtype=tl.float32) + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), + (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype)) + if K > 64: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 64), + (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype)) + if K > 128: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 128), + (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype)) + if K > 192: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 192), + (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype)) + b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1)) + + if SAVE_NEW_VALUE: + p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), + (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_v_new, + b_v_new.to(p_v_new.dtype.element_ty), + boundary_check=(0, 1)) + + if USE_G: + m_t = (i_t * BT + tl.arange(0, BT)) < T + last_idx = min((i_t + 1) * BT, T) - 1 + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + p_g = tl.make_block_ptr(g + bos * H + i_h, (T, ), (H, ), + (i_t * BT, ), (BT, ), (0, )) + b_g = tl.load(p_g, boundary_check=(0, )) + b_v_new = b_v_new * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None] + b_g_last = exp(b_g_last) + b_h1 = b_h1 * b_g_last + if K > 64: + b_h2 = b_h2 * b_g_last + if K > 128: + b_h3 = b_h3 * b_g_last + if K > 192: + b_h4 = b_h4 * b_g_last + b_v_new = b_v_new.to(k.dtype.element_ty) + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), + (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h1 += tl.dot(b_k, b_v_new) + if K > 64: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), + (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h2 += tl.dot(b_k, b_v_new) + if K > 128: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), + (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h3 += tl.dot(b_k, b_v_new) + if K > 192: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), + (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h4 += tl.dot(b_k, b_v_new) + + # epilogue + if STORE_FINAL_STATE: + p_ht = 
tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), + (1, 0)) + tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV), + (64, BV), (1, 0)) + tl.store(p_ht, + b_h2.to(p_ht.dtype.element_ty), + boundary_check=(0, 1)) + if K > 128: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV), + (64, BV), (1, 0)) + tl.store(p_ht, + b_h3.to(p_ht.dtype.element_ty), + boundary_check=(0, 1)) + if K > 192: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV), + (64, BV), (1, 0)) + tl.store(p_ht, + b_h4.to(p_ht.dtype.element_ty), + boundary_check=(0, 1)) + + +def chunk_gated_delta_rule_fwd_h( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: Optional[torch.Tensor] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + chunk_size: int = 64, # SY: remove this argument and force chunk size 64? + save_new_value: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, Hg, K, V = *k.shape, u.shape[-1] + H = u.shape[-2] + BT = chunk_size + + chunk_indices = prepare_chunk_indices( + cu_seqlens, chunk_size) if cu_seqlens is not None else None + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(cu_seqlens) - 1, len( + chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) + assert K <= 256, "current kernel does not support head dimension larger than 256." + + h = k.new_empty(B, NT, H, K, V) + final_state = k.new_empty( + N, H, K, V, dtype=torch.float32) if output_final_state else None + + v_new = torch.empty_like(u) if save_new_value else None + + def grid(meta): + return (triton.cdiv(V, meta['BV']), N * H) + + chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid]( + k=k, + v=u, + w=w, + v_new=v_new, + g=g, + h=h, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT) + return h, v_new, final_state diff --git a/vllm/model_executor/layers/fla/ops/chunk_o.py b/vllm/model_executor/layers/fla/ops/chunk_o.py new file mode 100644 index 0000000000000..332751a1860a9 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/chunk_o.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
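For intuition, a rough dense-torch reference (not part of the patch) of the per-chunk state recurrence that chunk_gated_delta_rule_fwd_h above launches, under the assumptions of equal-length sequences, Hg == H, gating enabled, and T divisible by the chunk size:

import torch

def chunk_delta_h_reference(k, w, u, g, h0=None, BT=64):
    # k, w: [B, T, H, K]; u: [B, T, H, V]; g: [B, T, H] within-chunk cumulative log-decay
    B, T, H, K = k.shape
    V = u.shape[-1]
    NT = T // BT
    h = k.new_zeros(B, NT, H, K, V, dtype=torch.float32)
    state = h0.float().clone() if h0 is not None else k.new_zeros(B, H, K, V, dtype=torch.float32)
    v_new = torch.empty_like(u)
    for t in range(NT):
        sl = slice(t * BT, (t + 1) * BT)
        h[:, t] = state                                   # state entering the chunk
        b_v = u[:, sl].float() - torch.einsum('bthk,bhkv->bthv', w[:, sl].float(), state)
        v_new[:, sl] = b_v.to(u.dtype)                    # the "new v" saved by the kernel
        b_g = g[:, sl].float()
        g_last = b_g[:, -1]                               # [B, H], decay at the chunk end
        b_v = b_v * (g_last[:, None] - b_g).exp()[..., None]
        state = state * g_last.exp()[..., None, None] + torch.einsum(
            'bthk,bthv->bhkv', k[:, sl].float(), b_v)
    return h, v_new, state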
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ruff: noqa: E501 + +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + +from .index import prepare_chunk_indices +from .op import exp +from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper + +BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] + + +@triton.heuristics({ + 'USE_G': lambda args: args['g'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({ + 'BK': BK, + 'BV': BV + }, + num_warps=num_warps, + num_stages=num_stages) for BK in BKV_LIST + for BV in BKV_LIST for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_fwd_kernel_o( + q, + k, + v, + h, + g, + o, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to( + tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + # offset calculation + q += (bos * Hg + i_h // (H // Hg)) * K + k += (bos * Hg + i_h // (H // Hg)) * K + v += (bos * H + i_h) * V + o += (bos * H + i_h) * V + h += (i_tg * H + i_h).to(tl.int64) * K * V + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), + (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), + (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), + (BK, BV), (1, 0)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + + # [BT, BK] @ [BK, BV] -> [BT, BV] + b_o += tl.dot(b_q, b_h) + # [BT, BK] @ [BK, BT] -> [BT, BT] + b_A += tl.dot(b_q, b_k) + + if USE_G: + g += bos * H + i_h + p_g = tl.make_block_ptr(g, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, )) + b_g = tl.load(p_g, boundary_check=(0, )) + b_o = b_o * exp(b_g)[:, None] + b_A = b_A * exp(b_g[:, None] - b_g[None, :]) + + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0) + + p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), + (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), + (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # to fix mma -> mma layout conversion + # already solved by triton v3.2 or higher + b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_fwd_o( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + g: 
Optional[torch.Tensor] = None, # cumsum of log decay + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64) -> torch.Tensor: + B, T, Hg, K, V = *q.shape, v.shape[-1] + H = v.shape[-2] + if FLA_GDN_FIX_BT: + BT = 64 + else: + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + chunk_indices = prepare_chunk_indices( + cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + if scale is None: + scale = k.shape[-1]**-0.5 + + o = torch.empty_like(v) + + def grid(meta): + return (triton.cdiv(V, meta['BV']), NT, B * H) + + chunk_fwd_kernel_o[grid]( + q, + k, + v, + h, + g, + o, + cu_seqlens, + chunk_indices, + scale, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + ) + return o diff --git a/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py b/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py new file mode 100644 index 0000000000000..d1adc6978f245 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + +from .index import prepare_chunk_indices +from .op import exp + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, + 'USE_G': lambda args: args['g_cumsum'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64, 128] for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'BT', 'IS_VARLEN'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_scaled_dot_kkt_fwd_kernel( + k, + beta, + g_cumsum, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_G: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to( + tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + + p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T, ), (H, ), + (i_t * BT, ), (BT, ), (0, )) + b_beta = tl.load(p_beta, boundary_check=(0, )) + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + (bos * Hg + i_h // (H // Hg)) * K, (T, K), + (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), + (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = b_k * b_beta[:, None] + b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k)) + + if USE_G: + p_g = tl.make_block_ptr(g_cumsum + bos * H + i_h, (T, ), (H, ), + (i_t * BT, ), (BT, ), (0, )) + b_g = tl.load(p_g, boundary_check=(0, )) + b_g_diff = b_g[:, None] - b_g[None, :] + b_A = b_A * exp(b_g_diff) + + m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0) 
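+ # keep only the strictly lower-triangular entries (row > col) within the chunk; solve_tril later inverts (I + A) starting from exactly these entries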
+ p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), + (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_scaled_dot_kkt_fwd( + k: torch.Tensor, + beta: torch.Tensor, + g_cumsum: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, + output_dtype: torch.dtype = torch.float32) -> torch.Tensor: + r""" + Compute beta * K * K^T. + + Args: + k (torch.Tensor): + The key tensor of shape `[B, T, H, K]`. + beta (torch.Tensor): + The beta tensor of shape `[B, T, H]`. + g_cumsum (torch.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, H]`. + Default: None + cu_seqlens (torch.LongTensor): + The cumulative sequence lengths of the input tensor. + Default: None + chunk_size (int): + The chunk size. Default: 64. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float32` + + Returns: + beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size. + """ + + B, T, Hg, K = k.shape + + H = beta.shape[-1] + BT = chunk_size + chunk_indices = prepare_chunk_indices( + cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype) + chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)]( + k=k, + beta=beta, + g_cumsum=g_cumsum, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + Hg=Hg, + K=K, + BT=BT, + ) + return A diff --git a/vllm/model_executor/layers/fla/ops/cumsum.py b/vllm/model_executor/layers/fla/ops/cumsum.py new file mode 100644 index 0000000000000..370a45fe16358 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/cumsum.py @@ -0,0 +1,226 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
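A dense-torch reference (not part of the patch) of what chunk_scaled_dot_kkt_fwd above returns, assuming equal-length sequences, Hg == H, and T divisible by the chunk size; useful as a sanity check:

import torch

def chunk_scaled_dot_kkt_reference(k, beta, g_cumsum=None, chunk_size=64):
    # k: [B, T, H, K]; beta, g_cumsum: [B, T, H]; returns [B, T, H, BT]
    B, T, H, K = k.shape
    BT = chunk_size
    kc = k.float().view(B, T // BT, BT, H, K)
    bc = beta.float().view(B, T // BT, BT, H)
    # A[n, i, h, j] = beta_i * <k_i, k_j> for positions i, j inside chunk n
    A = torch.einsum('bnihk,bnjhk->bnihj', kc, kc) * bc[..., None]
    if g_cumsum is not None:
        gc = g_cumsum.float().view(B, T // BT, BT, H)
        A = A * (gc.unsqueeze(-1) - gc.permute(0, 1, 3, 2).unsqueeze(2)).exp()
    idx = torch.arange(BT, device=k.device)
    A = A * (idx[:, None] > idx[None, :])[None, None, :, None, :]  # strictly lower-triangular
    return A.reshape(B, T, H, BT)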
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import warnings +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + +from .index import prepare_chunk_indices +from .utils import check_shared_mem, input_guard + +BS_LIST = [32, 64] if check_shared_mem() else [16, 32] + + +@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None}) +@triton.autotune(configs=[ + triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8] +], + key=['B', 'H', 'BT', 'IS_VARLEN', 'REVERSE']) +@triton.jit(do_not_specialize=['T']) +def chunk_local_cumsum_scalar_kernel( + s, + o, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + BT: tl.constexpr, + REVERSE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to( + tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T, ), (1, ), + (i_t * BT, ), (BT, ), (0, )) + p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T, ), (1, ), + (i_t * BT, ), (BT, ), (0, )) + else: + p_s = tl.make_block_ptr(s + bos * H + i_h, (T, ), (H, ), (i_t * BT, ), + (BT, ), (0, )) + p_o = tl.make_block_ptr(o + bos * H + i_h, (T, ), (H, ), (i_t * BT, ), + (BT, ), (0, )) + # [BT] + b_s = tl.load(p_s, boundary_check=(0, )).to(tl.float32) + b_o = tl.cumsum(b_s, axis=0) + if REVERSE: + b_z = tl.sum(b_s, axis=0) + b_o = -b_o + b_z[None] + b_s + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, )) + + +@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None}) +@triton.autotune(configs=[ + triton.Config({'BS': BS}, num_warps=num_warps) for BS in BS_LIST + for num_warps in [2, 4, 8] +], + key=['B', 'H', 'S', 'BT', 'IS_VARLEN', 'REVERSE']) +@triton.jit(do_not_specialize=['T']) +def chunk_local_cumsum_vector_kernel( + s, + o, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + REVERSE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to( + tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_i = tl.arange(0, BT) + if REVERSE: + m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.) + else: + m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.) 
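+ # m_s is the within-chunk causal 0/1 mask (anti-causal when REVERSE); the dot product with it below computes the chunkwise cumulative sum in a single matmul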
+ + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + (bos * H + i_h * T) * S, (T, S), (S, 1), + (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h * T) * S, (T, S), (S, 1), + (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + else: + p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H * S, 1), + (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H * S, 1), + (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + b_o = tl.dot(m_s, b_s, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_local_cumsum_scalar( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float) -> torch.Tensor: + if head_first: + B, H, T = g.shape + else: + B, T, H = g.shape + assert chunk_size == 2**(chunk_size.bit_length() - + 1), "chunk_size must be a power of 2" + BT = chunk_size + chunk_indices = prepare_chunk_indices( + cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + grid = (NT, B * H) + chunk_local_cumsum_scalar_kernel[grid](g_org, + g, + cu_seqlens, + chunk_indices, + T=T, + B=B, + H=H, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse) + return g + + +def chunk_local_cumsum_vector( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float) -> torch.Tensor: + if head_first: + B, H, T, S = g.shape + else: + B, T, H, S = g.shape + BT = chunk_size + chunk_indices = prepare_chunk_indices( + cu_seqlens, chunk_size) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + assert chunk_size == 2**(chunk_size.bit_length() - + 1), "chunk_size must be a power of 2" + + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + + def grid(meta): + return (triton.cdiv(meta['S'], meta['BS']), NT, B * H) + + # keep cumulative normalizer in fp32 + # this kernel is equivalent to + # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1) + chunk_local_cumsum_vector_kernel[grid](g_org, + g, + cu_seqlens, + chunk_indices, + T=T, + B=B, + H=H, + S=S, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse) + return g + + +@input_guard +def chunk_local_cumsum(g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float, + **kwargs) -> torch.Tensor: + if not head_first and g.shape[1] < g.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. 
" + "Please verify your input tensor format matches the expected shape [B, T, H, ...].", + stacklevel=2) + if cu_seqlens is not None: + assert g.shape[ + 0] == 1, "Only batch size 1 is supported when cu_seqlens are provided" + if len(g.shape) == 3: + return chunk_local_cumsum_scalar(g, chunk_size, reverse, cu_seqlens, + head_first, output_dtype) + elif len(g.shape) == 4: + return chunk_local_cumsum_vector(g, chunk_size, reverse, cu_seqlens, + head_first, output_dtype) + else: + raise ValueError(f"Unsupported input shape {g.shape}. " + f"which should be (B, T, H, D) if `head_first=False` " + f"or (B, H, T, D) otherwise") diff --git a/vllm/model_executor/layers/fla/ops/fused_recurrent.py b/vllm/model_executor/layers/fla/ops/fused_recurrent.py new file mode 100644 index 0000000000000..b278e37415748 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/fused_recurrent.py @@ -0,0 +1,366 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + +from .op import exp + + +@triton.heuristics({ + 'USE_INITIAL_STATE': + lambda args: args['h0'] is not None, + 'IS_VARLEN': + lambda args: args['cu_seqlens'] is not None, + "IS_CONTINUOUS_BATCHING": + lambda args: args['ssm_state_indices'] is not None, + "IS_SPEC_DECODING": + lambda args: args['num_accepted_tokens'] is not None, +}) +@triton.jit(do_not_specialize=['N', 'T']) +def fused_recurrent_gated_delta_rule_fwd_kernel( + q, + k, + v, + g, + beta, + o, + h0, + ht, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + scale, + N: tl.constexpr, # num of sequences + T: tl.constexpr, # num of tokens + B: tl.constexpr, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + stride_init_state_token: tl.constexpr, + stride_final_state_token: tl.constexpr, + stride_indices_seq: tl.constexpr, + stride_indices_tok: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace + IS_BETA_HEADWISE: tl. 
+ constexpr, # whether beta is headwise vector or scalar, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + if T == 0: + # no tokens to process for this sequence + return + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + p_q = q + (bos * H + i_h) * K + o_k + p_k = k + (bos * H + i_h) * K + o_k + p_v = v + (bos * HV + i_hv) * V + o_v + if IS_BETA_HEADWISE: + p_beta = beta + (bos * HV + i_hv) * V + o_v + else: + p_beta = beta + bos * HV + i_hv + p_g = g + bos * HV + i_hv + p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + if IS_CONTINUOUS_BATCHING: + if IS_SPEC_DECODING: + i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 + else: + i_t = 0 + p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq + + i_t).to(tl.int64) * stride_init_state_token + else: + p_h0 = h0 + bos * HV * K * V + p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for i_t in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_g = tl.load(p_g).to(tl.float32) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) + b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) + b_q = b_q * scale + # [BK, BV] + b_h *= exp(b_g) + # [BV] + b_v -= tl.sum(b_h * b_k[:, None], 0) + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + b_v *= b_beta + # [BK, BV] + b_h += b_k[:, None] * b_v[None, :] + # [BV] + b_o = tl.sum(b_h * b_q[:, None], 0) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + # keep the states for multi-query tokens + if INPLACE_FINAL_STATE: + p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq + + i_t).to(tl.int64) * stride_final_state_token + else: + p_ht = ht + (bos + i_t) * stride_final_state_token + p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + p_q += H * K + p_k += H * K + p_o += HV * V + p_v += HV * V + p_g += HV + p_beta += HV * (V if IS_BETA_HEADWISE else 1) + + +def fused_recurrent_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + HV = v.shape[2] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not 
supported yet" + num_stages = 3 + num_warps = 1 + + o = q.new_empty(NK, *v.shape) + if inplace_final_state: + final_state = initial_state + else: + final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype) + + stride_init_state_token = initial_state.stride(0) + stride_final_state_token = final_state.stride(0) + + if ssm_state_indices is None: + stride_indices_seq, stride_indices_tok = 1, 1 + elif ssm_state_indices.ndim == 1: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 + else: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() + + grid = (NK, NV, N * HV) + fused_recurrent_gated_delta_rule_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + beta=beta, + o=o, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + scale=scale, + N=N, + T=T, + B=B, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + stride_init_state_token=stride_init_state_token, + stride_final_state_token=stride_final_state_token, + stride_indices_seq=stride_indices_seq, + stride_indices_tok=stride_indices_tok, + IS_BETA_HEADWISE=beta.ndim == v.ndim, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + INPLACE_FINAL_STATE=inplace_final_state, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o, final_state + + +class FusedRecurrentFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False): + o, final_state = fused_recurrent_gated_delta_rule_fwd( + q=q.contiguous(), + k=k.contiguous(), + v=v.contiguous(), + g=g.contiguous(), + beta=beta.contiguous(), + scale=scale, + initial_state=initial_state, + inplace_final_state=inplace_final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + + return o, final_state + + +def fused_recurrent_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor = None, + scale: float = None, + initial_state: torch.Tensor = None, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, HV, V]`. + GVA is applied if `HV > H`. + g (torch.Tensor): + g (decays) of shape `[B, T, HV]`. + beta (torch.Tensor): + betas of shape `[B, T, HV]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, HV, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + inplace_final_state: bool: + Whether to store the final state in-place to save memory. + Default: `True`. 
+ cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + ssm_state_indices (Optional[torch.Tensor]): + Indices to map the input sequences to the initial/final states. + num_accepted_tokens (Optional[torch.Tensor]): + Number of accepted tokens for each sequence during decoding. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HV, V]`. + final_state (torch.Tensor): + Final state of shape `[N, HV, K, V]`. + + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, HV, V, device='cuda') + >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda')) + >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid() + >>> h0 = torch.randn(B, HV, K, V, device='cuda') + >>> o, ht = fused_gated_recurrent_delta_rule( + q, k, v, g, beta, + initial_state=h0, + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = fused_gated_recurrent_delta_rule( + q, k, v, g, beta, + initial_state=h0, + cu_seqlens=cu_seqlens + ) + """ + if cu_seqlens is not None and q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing.") + if scale is None: + scale = k.shape[-1]**-0.5 + else: + assert scale > 0, "scale must be positive" + if beta is None: + beta = torch.ones_like(q[..., 0]) + o, final_state = FusedRecurrentFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + inplace_final_state, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + use_qk_l2norm_in_kernel, + ) + return o, final_state diff --git a/vllm/model_executor/layers/fla/ops/index.py b/vllm/model_executor/layers/fla/ops/index.py new file mode 100644 index 0000000000000..9eca32bc31a04 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/index.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
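As a cross-check on the fused kernel above, a straightforward per-token torch reference (not part of the patch) of the gated delta-rule recurrence, assuming a scalar beta per head, H == HV, and no variable-length batching or state indexing:

import torch

def gated_delta_rule_reference(q, k, v, g, beta, h0, scale):
    # q, k: [B, T, H, K]; v: [B, T, H, V]; g, beta: [B, T, H]; h0: [B, H, K, V]
    B, T, H, K = q.shape
    h = h0.float().clone()
    o = torch.empty_like(v, dtype=torch.float32)
    for t in range(T):
        b_q = q[:, t].float() * scale
        b_k, b_v = k[:, t].float(), v[:, t].float()
        h = h * g[:, t].float().exp()[..., None, None]          # gated decay of the state
        b_v = b_v - torch.einsum('bhkv,bhk->bhv', h, b_k)       # delta rule: subtract the stored value
        b_v = b_v * beta[:, t].float()[..., None]               # write strength
        h = h + torch.einsum('bhk,bhv->bhkv', b_k, b_v)
        o[:, t] = torch.einsum('bhkv,bhk->bhv', h, b_q)
    return o.to(v.dtype), h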
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import torch + +from vllm.triton_utils import triton + +from .utils import tensor_cache + + +@tensor_cache +def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return cu_seqlens[1:] - cu_seqlens[:-1] + + +@tensor_cache +def prepare_chunk_indices(cu_seqlens: torch.LongTensor, + chunk_size: int) -> torch.LongTensor: + indices = torch.cat([ + torch.arange(n) + for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist() + ]) + return torch.stack([indices.eq(0).cumsum(0) - 1, indices], + 1).to(cu_seqlens) + + +@tensor_cache +def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, + chunk_size: int) -> torch.LongTensor: + return torch.cat([ + cu_seqlens.new_tensor([0]), + triton.cdiv(prepare_lens(cu_seqlens), chunk_size) + ]).cumsum(-1) diff --git a/vllm/model_executor/layers/fla/ops/l2norm.py b/vllm/model_executor/layers/fla/ops/l2norm.py new file mode 100644 index 0000000000000..ef9788ceaf20e --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/l2norm.py @@ -0,0 +1,143 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import os +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + +BT_LIST = [8, 16, 32, 64, 128] + +USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0")) + + +@triton.autotune(configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16, 32] +], + key=['D']) +@triton.jit +def l2norm_fwd_kernel1( + x, + y, + D, + BD: tl.constexpr, + eps, +): + i_t = tl.program_id(0) + x += i_t * D + y += i_t * D + # Compute mean and variance + cols = tl.arange(0, BD) + mask = cols < D + b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32) + b_var = tl.sum(b_x * b_x, axis=0) + b_rstd = 1 / tl.sqrt(b_var + eps) + # tl.store(Rstd + i_t, rstd) + # Normalize and apply linear transformation + b_y = b_x * b_rstd + tl.store(y + cols, b_y, mask=mask) + + +@triton.autotune(configs=[ + triton.Config({'BT': BT}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST +], + key=['D']) +@triton.jit(do_not_specialize=["NB"]) +def l2norm_fwd_kernel( + x, + y, + eps, + NB, + T, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, +): + i_t = tl.program_id(0) + p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) + b_var = tl.sum(b_x * b_x, axis=1) + b_y = b_x / tl.sqrt(b_var + eps)[:, None] + p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * MBLOCK + row_idx = xoffset + tl.arange(0, MBLOCK)[:, None] + xmask = row_idx < M + rindex = tl.arange(0, N)[None, :] + xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32) + square = tl.broadcast_to(xs * xs, [MBLOCK, N]) + square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None] + rsqrt = 
tl.rsqrt(square_sum + eps) + tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask) + + +def l2norm_fwd(x: torch.Tensor, + eps: float = 1e-6, + output_dtype: Optional[torch.dtype] = None): + x_shape_og = x.shape + x = x.view(-1, x.shape[-1]) + # allocate output + if output_dtype is None: + y = torch.empty_like(x) + else: + y = torch.empty_like(x, dtype=output_dtype) + assert y.stride(-1) == 1 + T, D = x.shape[0], x.shape[-1] + # rstd = torch.empty((T,), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BD: + raise RuntimeError("This layer doesn't support feature dim >= 64KB.") + + if not USE_DEFAULT_FLA_NORM: + MBLOCK = 32 + # M, N = x.shape + l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK), )]( + x, + y, + eps, + T, + D, + MBLOCK, + ) + else: + if D <= 512: + NB = triton.cdiv(T, 2048) + + def grid(meta): + return (triton.cdiv(T, meta['BT']), ) + + l2norm_fwd_kernel[grid]( + x, + y, + eps, + NB=NB, + T=T, + D=D, + BD=BD, + ) + else: + l2norm_fwd_kernel1[(T, )]( + x, + y, + eps=eps, + D=D, + BD=BD, + ) + + return y.view(x_shape_og) diff --git a/vllm/model_executor/layers/fla/ops/layernorm_guard.py b/vllm/model_executor/layers/fla/ops/layernorm_guard.py new file mode 100644 index 0000000000000..a733c6c81e369 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/layernorm_guard.py @@ -0,0 +1,337 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Tri Dao +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2024, Tri Dao. + +# ruff: noqa: E501 +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. +# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. + +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from vllm.triton_utils import tl, triton + +from .utils import input_guard + + +def rms_norm_ref(x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + upcast=True): + dtype = x.dtype + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + z = z.float() if z is not None else z + if z is not None and not norm_before_gate: + x = x * F.silu(z) + if group_size is None: + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = (x * rstd * weight) + bias if bias is not None else (x * rstd * + weight) + else: + x_group = rearrange(x, "... (g d) -> ... g d", d=group_size) + rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + + eps) + out = rearrange(x_group * rstd, "... g d -> ... 
(g d)") * weight + if bias is not None: + out = out + bias + if z is not None and norm_before_gate: + out *= F.silu(z) + return out.to(dtype) + + +@triton.heuristics({ + "HAS_BIAS": lambda args: args["B"] is not None, + "HAS_Z": lambda args: args["Z"] is not None, +}) +@triton.jit +def layer_norm_fwd_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Z, # pointer to the other branch + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_z_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_N: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_Z: tl.constexpr, + NORM_BEFORE_GATE: tl.constexpr, + IS_RMS_NORM: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + group = tl.program_id(1) + X += row * stride_x_row + group * N + Y += row * stride_y_row + group * N + if HAS_Z: + Z += row * stride_z_row + group * N + if not IS_RMS_NORM: + Mean += group * M + Rstd += group * M + W += group * N + if HAS_BIAS: + B += group * N + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + if HAS_Z and not NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=cols < N).to(tl.float32) + x *= z * tl.sigmoid(z) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + if HAS_Z and NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask).to(tl.float32) + y *= z * tl.sigmoid(z) + # Write output + tl.store(Y + cols, y, mask=mask) + + +def layer_norm_fwd( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + z: torch.Tensor = None, + out: torch.Tensor = None, + group_size: int = None, + norm_before_gate: bool = True, + is_rms_norm: bool = False, +): + M, N = x.shape + if group_size is None: + group_size = N + assert N % group_size == 0 + ngroups = N // group_size + assert x.stride(-1) == 1 + if z is not None: + assert z.stride(-1) == 1 + assert z.shape == (M, N) + assert weight.shape == (N, ) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N, ) + # allocate output + if out is not None: + assert out.shape == x.shape + else: + out = torch.empty_like(x) + assert out.stride(-1) == 1 + mean = torch.empty((ngroups * M, ), dtype=torch.float32, + device=x.device) if not is_rms_norm else None + rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError( + "This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_N // 256, 1), 8) + grid = (M, ngroups) + 
layer_norm_fwd_kernel[grid](x, + out, + weight, + bias, + z, + mean, + rstd, + x.stride(0), + out.stride(0), + z.stride(0) if z is not None else 0, + M, + group_size, + eps, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + IS_RMS_NORM=is_rms_norm, + num_warps=num_warps) + return out, mean, rstd + + +class LayerNormFn(torch.autograd.Function): + + @input_guard + @staticmethod + def forward(ctx, + x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + is_rms_norm=False): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) + """ + + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if z is not None: + assert z.shape == x_shape_og + z = z.reshape(-1, z.shape[-1]) + if z.stride(-1) != 1: + z = z.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + y, mean, rstd = layer_norm_fwd( + x, + weight, + bias, + eps, + z=z, + group_size=group_size, + norm_before_gate=norm_before_gate, + is_rms_norm=is_rms_norm, + ) + ctx.save_for_backward(x, weight, bias, mean, rstd, z) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.group_size = group_size + ctx.norm_before_gate = norm_before_gate + ctx.is_rms_norm = is_rms_norm + return y.reshape(x_shape_og) + + +def layernorm_fn(x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + is_rms_norm=False): + return LayerNormFn.apply(x, weight, bias, z, eps, group_size, + norm_before_gate, is_rms_norm) + + +def rmsnorm_fn(x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True): + return LayerNormFn.apply(x, weight, bias, z, eps, group_size, + norm_before_gate, True) + + +class LayerNormGated(nn.Module): + + def __init__( + self, + hidden_size, + eps: float = 1e-5, + group_size: Optional[int] = None, + norm_before_gate: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + """If group_size is not None, we do GroupNorm with each group having group_size elements. + group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). + """ + + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.bias = nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.group_size = group_size + self.norm_before_gate = norm_before_gate + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + torch.nn.init.zeros_(self.bias) + + def forward(self, x, z=None): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) + """ + return layernorm_fn(x, + self.weight, + self.bias, + z=z, + group_size=self.group_size, + eps=self.eps, + norm_before_gate=self.norm_before_gate) + + +class RMSNormGated(nn.Module): + + def __init__( + self, + hidden_size, + eps: float = 1e-5, + group_size: Optional[int] = None, + norm_before_gate: bool = False, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + """If group_size is not None, we do GroupNorm with each group having group_size elements. + group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). 
+ """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.register_parameter("bias", None) + self.group_size = group_size + self.norm_before_gate = norm_before_gate + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + + def forward(self, x, z=None): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) + """ + return rmsnorm_fn(x, + self.weight, + self.bias, + z=z, + eps=self.eps, + group_size=self.group_size, + norm_before_gate=self.norm_before_gate) diff --git a/vllm/model_executor/layers/fla/ops/op.py b/vllm/model_executor/layers/fla/ops/op.py new file mode 100644 index 0000000000000..8c29434ca106a --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/op.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import os + +from vllm.triton_utils import tl, tldevice, triton + +if os.environ.get('FLA_USE_FAST_OPS', '0') == '1': + div = tldevice.fast_dividef + exp = tldevice.fast_expf + log = tldevice.fast_logf + log2 = tldevice.fast_log2f +else: + + @triton.jit + def div_normal(x, y): + return x / y + + div = div_normal + exp = tl.exp + log = tl.log + log2 = tl.log2 + + +if not hasattr(tl, 'gather'): + + @triton.jit + def gather(src, index, axis, _builder=None): + # This is a fallback implementation when tl.gather is not supported + # In order to pass triton compiler, there is no actual gather operation + return src +else: + gather = tl.gather diff --git a/vllm/model_executor/layers/fla/ops/solve_tril.py b/vllm/model_executor/layers/fla/ops/solve_tril.py new file mode 100644 index 0000000000000..97cb0d800d411 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/solve_tril.py @@ -0,0 +1,365 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + +from .index import prepare_chunk_indices +from .utils import input_guard + + +@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4, 8] for num_stages in [2, 3, 4, 5] + ], + key=['BT'], +) +@triton.jit(do_not_specialize=['T']) +def solve_tril_16x16_kernel( + A, + Ad, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to( + tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + A = A + (bos * H + i_h) * BT + Ad = Ad + (bos * H + i_h) * 16 + + offset = (i_t * 16) % BT + p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 16, offset), + (16, 16), (1, 0)) + p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), + (1, 0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32) + b_A = -tl.where( + tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0) + + o_i = tl.arange(0, 16) + for i in range(1, min(16, T - i_t * 16)): + b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) + mask = o_i == i + b_A = tl.where(mask[:, None], b_a, b_A) + b_A += o_i[:, None] == o_i[None, :] + tl.store(p_Ai, + b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + + +@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4, 8] for num_stages in [2, 3, 4, 5] + ], + key=['H', 'BT', 'IS_VARLEN'], +) +@triton.jit(do_not_specialize=['T']) +def merge_16x16_to_32x32_inverse_kernel(A, Ad, Ai, cu_seqlens, chunk_indices, + T, H: tl.constexpr, BT: tl.constexpr, + IS_VARLEN: tl.constexpr): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to( + tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + A += (bos * H + i_h) * 32 + Ad += (bos * H + i_h) * 16 + Ai += (bos * H + i_h) * 32 + + p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), + (16, 16), (1, 0)) + p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), + (16, 16), (1, 0)) + p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), + (16, 16), (1, 0)) + p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), + (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), + (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), + (16, 16), (1, 0)) + + A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + Ai_11 = tl.load(p_Ad_11, 
boundary_check=(0, 1)).to(tl.float32) + Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32) + Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision='ieee'), + Ai_11, + input_precision='ieee') + tl.store(p_Ai_11, + Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_22, + Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_21, + Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + + +@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] for num_stages in [2, 3, 4, 5] + ], + key=['H', 'BT', 'IS_VARLEN'], +) +@triton.jit(do_not_specialize=['T']) +def merge_16x16_to_64x64_inverse_kernel(A, Ad, Ai, cu_seqlens, chunk_indices, + T, H: tl.constexpr, BT: tl.constexpr, + IS_VARLEN: tl.constexpr): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to( + tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + A += (bos * H + i_h) * 64 + Ad += (bos * H + i_h) * 16 + Ai += (bos * H + i_h) * 64 + + p_A_21 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), + (16, 16), (1, 0)) + p_A_32 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), + (16, 16), (1, 0)) + p_A_31 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), + (16, 16), (1, 0)) + p_A_43 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), + (16, 16), (1, 0)) + p_A_42 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), + (16, 16), (1, 0)) + p_A_41 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), + (16, 16), (1, 0)) + p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64, 0), + (16, 16), (1, 0)) + p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 16, 0), + (16, 16), (1, 0)) + p_Ad_33 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 32, 0), + (16, 16), (1, 0)) + p_Ad_44 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 48, 0), + (16, 16), (1, 0)) + + A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32) + A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32) + A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32) + A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32) + A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32) + + Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32) + Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32) + Ai_33 = tl.load(p_Ad_33, boundary_check=(0, 1)).to(tl.float32) + Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1)).to(tl.float32) + + Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision='ieee'), + Ai_11, + input_precision='ieee') + Ai_32 = -tl.dot(tl.dot(Ai_33, A_32, input_precision='ieee'), + Ai_22, + input_precision='ieee') + Ai_43 = -tl.dot(tl.dot(Ai_44, A_43, input_precision='ieee'), + Ai_33, + input_precision='ieee') + + Ai_31 = -tl.dot(Ai_33, + tl.dot(A_31, Ai_11, input_precision='ieee') + + tl.dot(A_32, Ai_21, input_precision='ieee'), + input_precision='ieee') + Ai_42 = -tl.dot(Ai_44, + tl.dot(A_42, Ai_22, 
input_precision='ieee') + + tl.dot(A_43, Ai_32, input_precision='ieee'), + input_precision='ieee') + Ai_41 = -tl.dot(Ai_44, + tl.dot(A_41, Ai_11, input_precision='ieee') + + tl.dot(A_42, Ai_21, input_precision='ieee') + + tl.dot(A_43, Ai_31, input_precision='ieee'), + input_precision='ieee') + + p_Ai_11 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), + (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 16), + (16, 16), (1, 0)) + p_Ai_33 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 32), + (16, 16), (1, 0)) + p_Ai_44 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 48), + (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), + (16, 16), (1, 0)) + p_Ai_31 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), + (16, 16), (1, 0)) + p_Ai_32 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), + (16, 16), (1, 0)) + p_Ai_41 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), + (16, 16), (1, 0)) + p_Ai_42 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), + (16, 16), (1, 0)) + p_Ai_43 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), + (16, 16), (1, 0)) + tl.store(p_Ai_11, + Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_22, + Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_33, + Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_44, + Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_21, + Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_31, + Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_32, + Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_41, + Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_42, + Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_43, + Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + + fill_zeros = tl.zeros((16, 16), dtype=tl.float32) + p_Ai_12 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 16), + (16, 16), (1, 0)) + p_Ai_13 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 32), + (16, 16), (1, 0)) + p_Ai_14 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 48), + (16, 16), (1, 0)) + p_Ai_23 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 32), + (16, 16), (1, 0)) + p_Ai_24 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 48), + (16, 16), (1, 0)) + p_Ai_34 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 48), + (16, 16), (1, 0)) + tl.store(p_Ai_12, + fill_zeros.to(p_Ai_12.dtype.element_ty, + fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_13, + fill_zeros.to(p_Ai_13.dtype.element_ty, + fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_14, + fill_zeros.to(p_Ai_14.dtype.element_ty, + fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_23, + fill_zeros.to(p_Ai_23.dtype.element_ty, + fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_24, + fill_zeros.to(p_Ai_24.dtype.element_ty, + 
fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + tl.store(p_Ai_34, + fill_zeros.to(p_Ai_34.dtype.element_ty, + fp_downcast_rounding="rtne"), + boundary_check=(0, 1)) + + +@input_guard +def solve_tril(A: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None, + output_dtype: torch.dtype = torch.float) -> torch.Tensor: + """ + Compute the inverse of the lower triangular matrix + A should be strictly lower triangular, i.e., A.triu() == 0. + + Args: + A (torch.Tensor): + [B, T, H, K] + cu_seqlens (torch.Tensor): + The cumulative sequence lengths of the input tensor. + Default: None. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float` + + Returns: + (I + A)^-1 with the same shape as A + """ + assert A.shape[-1] in [16, 32, 64] + + B, T, H, BT = A.shape + Ad = torch.empty(B, + T, + H, + 16, + device=A.device, + dtype=torch.float if BT != 16 else output_dtype) + + chunk_indices = prepare_chunk_indices( + cu_seqlens, 16) if cu_seqlens is not None else None + NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, 16) + solve_tril_16x16_kernel[NT, B * H]( + A=A, + Ad=Ad, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + BT=BT, + ) + if BT == 16: + return Ad + + Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype) + merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel + chunk_indices = prepare_chunk_indices( + cu_seqlens, BT) if cu_seqlens is not None else None + NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT) + merge_fn[NT, B * H]( + A=A, + Ad=Ad, + Ai=Ai, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + BT=BT, + ) + return Ai diff --git a/vllm/model_executor/layers/fla/ops/utils.py b/vllm/model_executor/layers/fla/ops/utils.py new file mode 100644 index 0000000000000..7fd90cee45d0e --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/utils.py @@ -0,0 +1,180 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import contextlib +import functools +import logging +import os +from enum import Enum +from typing import Any, Callable, Literal, Optional + +import torch + +from vllm.triton_utils import triton + +logger = logging.getLogger(__name__) + +COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1" +FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1" +FLA_GDN_FIX_BT = os.getenv("FLA_GDN_FIX_BT", "0") == "1" + +SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0")) + + +def tensor_cache( + fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + """ + A decorator that caches the most recent results of a function with tensor inputs. + + This decorator will store the output of the decorated function for the most recent set of input tensors. + The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed. + + Args: + fn (Callable[..., torch.Tensor]): + The function to be decorated. It should take tensor inputs and return tensor outputs. + + Returns: + Callable[..., torch.Tensor]: + A wrapped version of the input function with single-entry caching. 
+ """ + + cache_entries: tuple[Optional[tuple], Optional[dict], Any] = [] + cache_size = 4 + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + nonlocal cache_entries, cache_size + for i, entry in enumerate(cache_entries): + last_args, last_kwargs, last_result = entry + if len(args) == len(last_args) and len(kwargs) == len(last_kwargs) \ + and all(a is b for a, b in zip(args, last_args)) \ + and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()): + cache_entries = cache_entries[:i] + cache_entries[i + 1:] + [ + (args, kwargs, last_result) + ] + return last_result + + result = fn(*args, **kwargs) + + if len(cache_entries) >= cache_size: + cache_entries = cache_entries[1:] + cache_entries.append((args, kwargs, result)) + return result + + return wrapper + + +def input_guard( + fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + """ + A decorator to make sure all input tensors are contiguous and set the device based on input tensors. + """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + contiguous_args = (i if not isinstance(i, torch.Tensor) else + i.contiguous() for i in args) + contiguous_kwargs = { + k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) + for k, v in kwargs.items() + } + + tensor = None + for arg in args: + if isinstance(arg, torch.Tensor): + tensor = arg + break + if tensor is None: + for value in kwargs.values(): + if isinstance(value, torch.Tensor): + tensor = value + break + + if tensor is not None: + ctx = torch.cuda.device(tensor.device.index) + else: + ctx = contextlib.nullcontext() + + with ctx: + return fn(*contiguous_args, **contiguous_kwargs) + + return wrapper + + +@functools.cache +def get_available_device() -> str: + try: + return triton.runtime.driver.active.get_current_target().backend + except BaseException: + return 'cpu' + + +@functools.cache +def _check_platform() -> Literal['nvidia', 'amd', 'intel', 'musa']: + device = get_available_device() + mapping = { + "cuda": "nvidia", + "hip": "amd", + "xpu": "intel", + } + # return the mapped value, or the original if not found + return mapping.get(device, device) + + +# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'. +# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs. +# Therefore, we need to check the triton backend to determine the actual GPU vendor. 
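+# For example, on a ROCm build Triton reports the backend as 'hip', so `device`
+# below resolves to 'cuda' (torch exposes AMD GPUs through torch.cuda) while
+# `device_platform` still resolves to 'amd'.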
+device = get_available_device() if get_available_device() != 'hip' else 'cuda' +device_torch_lib = getattr(torch, device) +device_platform = _check_platform() + +is_amd = (device_platform == 'amd') +is_intel = (device_platform == 'intel') +is_nvidia = (device_platform == 'nvidia') +is_intel_alchemist = (is_intel + and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0)) +is_nvidia_hopper = (is_nvidia + and ('NVIDIA H' in torch.cuda.get_device_name(0) + or torch.cuda.get_device_capability()[0] >= 9)) +use_cuda_graph = (is_nvidia + and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1') + + +def get_all_max_shared_mem(): + try: + return [ + triton.runtime.driver.active.utils.get_device_properties(i) + ['max_shared_mem'] for i in range(device_torch_lib.device_count()) + ] + except BaseException: + return [-1] + + +class Backend(Enum): + ADA = 101376 # RTX 4090 + AMPERE = 166912 # A100 + HOPPER = 232448 # H100 + DEFAULT = 102400 # Default + + @classmethod + def get_shared_memory(cls, arch: str) -> int: + try: + return cls[arch.upper()].value + except KeyError: + return cls.DEFAULT.value + + +@functools.cache +def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool: + try: + device_shared_mem_list = get_all_max_shared_mem() + max_shared_memory = device_shared_mem_list[tensor_idx] + return max_shared_memory >= Backend.get_shared_memory(arch) + except Exception: + return False diff --git a/vllm/model_executor/layers/fla/ops/wy_fast.py b/vllm/model_executor/layers/fla/ops/wy_fast.py new file mode 100644 index 0000000000000..70374eb650642 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/wy_fast.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ruff: noqa: E501 +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + +from .index import prepare_chunk_indices + + +@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'], +) +@triton.jit(do_not_specialize=['T']) +def recompute_w_u_fwd_kernel(k, v, beta, w, u, A, g, cu_seqlens, chunk_indices, + T, H: tl.constexpr, Hg: tl.constexpr, + K: tl.constexpr, V: tl.constexpr, + BT: tl.constexpr, BK: tl.constexpr, + BV: tl.constexpr, IS_VARLEN: tl.constexpr): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to( + tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T, ), (H, ), + (i_t * BT, ), (BT, ), (0, )) + p_g = tl.make_block_ptr(g + (bos * H + i_h), (T, ), (H, ), (i_t * BT, ), + (BT, ), (0, )) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), + (i_t * BT, 0), (BT, BT), (1, 0)) + b_beta = tl.load(p_beta, boundary_check=(0, )) + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_g = tl.exp(tl.load(p_g, boundary_check=(0, ))) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), + (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + (bos * H + i_h) * V, (T, V), (H * V, 1), + (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + (bos * Hg + i_h // (H // Hg)) * K, (T, K), + (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), + (1, 0)) + p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H * K, 1), + (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype) + b_w = tl.dot(b_A, b_kb) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + +def recompute_w_u_fwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + g_cumsum: torch.Tensor, + A: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor], +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, Hg, K, V = *k.shape, v.shape[-1] + H = v.shape[-2] + BT = A.shape[-1] + + chunk_indices = prepare_chunk_indices( + cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + BK = 64 + BV = 64 + u = torch.empty_like(v) + w = k.new_empty(B, T, H, K) + recompute_w_u_fwd_kernel[(NT, B * H)]( + k=k, + v=v, + beta=beta, + w=w, + u=u, + A=A, + g=g_cumsum, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return w, u diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 
3007643d7a288..6730f051e3d71 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -10,6 +10,7 @@ from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize) +from vllm.model_executor.layers.fused_moe.utils import activation_without_mul from vllm.triton_utils import HAS_TRITON _config: Optional[dict[str, Any]] = None @@ -36,6 +37,7 @@ __all__ = [ "FusedMoEPermuteExpertsUnpermute", "FusedMoEActivationFormat", "FusedMoEPrepareAndFinalize", + "activation_without_mul", "override_config", "get_config", ] @@ -43,7 +45,6 @@ __all__ = [ if HAS_TRITON: # import to register the custom ops import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa - import vllm.model_executor.layers.fused_moe.fused_moe # noqa from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( BatchedDeepGemmExperts) from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 @@ -56,13 +57,12 @@ if HAS_TRITON: from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedTritonExperts) from vllm.model_executor.layers.fused_moe.fused_moe import ( - TritonExperts, fused_experts, fused_moe, fused_topk, - get_config_file_name, grouped_topk) + TritonExperts, fused_experts, fused_topk, get_config_file_name, + grouped_topk) from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( TritonOrDeepGemmExperts) __all__ += [ - "fused_moe", "fused_topk", "fused_experts", "get_config_file_name", diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index a5326dfe84f6d..e9dfb22bea27b 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from math import log2 from typing import Optional import torch @@ -7,9 +8,12 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( + deep_gemm_block_shape) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked, is_deep_gemm_e8m0_used) @@ -24,35 +28,28 @@ def _silu_mul_fp8_quant_deep_gemm( y_q_ptr, # fp8 quantized activations (E, T, H) y_s_ptr, # 16-bit scales (E, T, G) counts_ptr, # int32 num tokens per expert (E) - # Sizes --------------------------------------------------------------- H: tl.constexpr, # hidden dimension (per output) GROUP_SIZE: tl.constexpr, # elements per group (usually 128) - # Strides for input (elements) --------------------------------------- stride_i_e, stride_i_t, stride_i_h, - # Strides for y_q (elements) ----------------------------------------- stride_yq_e, stride_yq_t, stride_yq_h, - # Strides for y_s (elements) ----------------------------------------- stride_ys_e, stride_ys_t, stride_ys_g, - # Stride for counts (elements) 
stride_counts_e, - # Numeric params ------------------------------------------------------ eps: tl.constexpr, fp8_min: tl.constexpr, fp8_max: tl.constexpr, use_ue8m0: tl.constexpr, - # Meta --------------------------------------------------------------- BLOCK: tl.constexpr, NUM_STAGES: tl.constexpr, @@ -101,17 +98,15 @@ def _silu_mul_fp8_quant_deep_gemm( tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s) -def silu_mul_fp8_quant_deep_gemm( +def silu_mul_fp8_quant_deep_gemm_cuda( y: torch.Tensor, # (E, T, 2*H) tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert + num_parallel_tokens=16, group_size: int = 128, - eps: float = 1e-10, ) -> tuple[torch.Tensor, torch.Tensor]: """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales - - y has shape (E, T, 2*H). The first half of the last dimension is + y has shape (E, T, 2*H). The first half of the last dimension is silu-activated, multiplied by the second half, then quantized into FP8. - Returns `(y_q, y_s)` where * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H] * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T) @@ -120,22 +115,17 @@ def silu_mul_fp8_quant_deep_gemm( E, T, H2 = y.shape assert H2 % 2 == 0, "last dim of y must be even (2*H)" H = H2 // 2 - G = H // group_size - assert H % group_size == 0, "H must be divisible by group_size" - assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, \ - "tokens_per_expert must be shape (E,)" + G = (H + group_size - 1) // group_size + assert H % 8 == 0, "H must be divisible by 8" + assert group_size == 128, "group_size must be 128" + assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E + tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32) - # allocate outputs fp8_dtype = torch.float8_e4m3fn y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device) - # strides (elements) - stride_i_e, stride_i_t, stride_i_h = y.stride() - stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride() - - # desired scale strides (elements): (T*G, 1, T) stride_ys_e = T * G stride_ys_t = 1 stride_ys_g = T @@ -144,68 +134,100 @@ def silu_mul_fp8_quant_deep_gemm( dtype=torch.float32, device=y.device) - stride_cnt_e = tokens_per_expert.stride()[0] + use_ue8m0 = is_deep_gemm_e8m0_used() - # Static grid over experts and H-groups. - # A loop inside the kernel handles the token dim - grid = (E * G, ) + if E <= 16: + max_empirical_parallelism = 64 + elif E <= 32: + max_empirical_parallelism = 16 + else: + max_empirical_parallelism = 4 - f_info = torch.finfo(fp8_dtype) - fp8_max = f_info.max - fp8_min = f_info.min + # We never want to launch more than Tx number of threads + # This computes the clip.
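+    # Example: with E=8 experts, T=50 tokens and the default
+    # num_parallel_tokens=16, max_empirical_parallelism is 64 and
+    # 2**int(log2(min(16, 50))) == 16, so num_parallel_tokens is clipped to 16.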
+ num_parallel_tokens = max( + 1, + min(max_empirical_parallelism, 2**int(log2(min(num_parallel_tokens, + T))))) + cuda_arch = current_platform.get_device_capability( + device_id=y.device.index).to_int() - _silu_mul_fp8_quant_deep_gemm[grid]( - y, - y_q, - y_s, - tokens_per_expert, - H, - group_size, - stride_i_e, - stride_i_t, - stride_i_h, - stride_yq_e, - stride_yq_t, - stride_yq_h, - stride_ys_e, - stride_ys_t, - stride_ys_g, - stride_cnt_e, - eps, - fp8_min, - fp8_max, - is_deep_gemm_e8m0_used(), - BLOCK=group_size, - NUM_STAGES=4, - num_warps=1, - ) + if cuda_arch >= 80: + torch.ops._C.silu_mul_fp8_quant_deep_gemm_cuda(y, tokens_per_expert, + y_q, y_s, group_size, + use_ue8m0, + num_parallel_tokens) + else: + # Default to triton if not on cuda or if arch is too old + y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device) + + stride_cnt_e = tokens_per_expert.stride()[0] + + # Static grid over experts and H-groups. + # A loop inside the kernel handles the token dim + grid = (E * G, ) + # strides (elements) + stride_i_e, stride_i_t, stride_i_h = y.stride() + stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride() + + # desired scale strides (elements): (T*G, 1, T) + stride_ys_e = T * G + stride_ys_t = 1 + stride_ys_g = T + y_s = torch.empty_strided( + (E, T, G), + (stride_ys_e, stride_ys_t, stride_ys_g), + dtype=torch.float32, + device=y.device, + ) + f_info = torch.finfo(fp8_dtype) + fp8_max = f_info.max + fp8_min = f_info.min + eps: float = 1e-10 + _silu_mul_fp8_quant_deep_gemm[grid]( + y, + y_q, + y_s, + tokens_per_expert, + H, + group_size, + stride_i_e, + stride_i_t, + stride_i_h, + stride_yq_e, + stride_yq_t, + stride_yq_h, + stride_ys_e, + stride_ys_t, + stride_ys_g, + stride_cnt_e, + eps, + fp8_min, + fp8_max, + is_deep_gemm_e8m0_used(), + BLOCK=group_size, + NUM_STAGES=4, + num_warps=1, + ) return y_q, y_s class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - # The Deep Gemm kernels only support block size of 128 - DEEPGEMM_BLOCK_SHAPE: list[int] = [128, 128] - - def __init__(self, - max_num_tokens: int, - num_dispatchers: int, - block_shape: list[int], - per_act_token_quant=False): + def __init__( + self, + max_num_tokens: int, + num_dispatchers: int, + quant_config: FusedMoEQuantConfig, + ): """ max_num_tokens: Maximum number of tokens from a DP Rank num_dispatchers: The number of DP dispatchers. - block_shape: Block quantization block shape. - per_act_token_quant: Per activation token quantization flag. 
+ quant_config: Quantization configuration """ - super().__init__( - FusedMoEQuantConfig( - quant_dtype=torch.float8_e4m3fn, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - )) - assert self.block_shape == self.DEEPGEMM_BLOCK_SHAPE + super().__init__(quant_config) + assert self.block_shape == deep_gemm_block_shape() self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers @@ -263,12 +285,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], @@ -294,11 +311,11 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): # for the M expectation of each batch, correctly setting this value # may lead to better performance. expected_m = max_num_tokens - fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale), + fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, self.w1_scale), workspace1, expert_num_tokens, expected_m) - a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1, - expert_num_tokens) + a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm_cuda( + workspace1, expert_num_tokens) - fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), output, - expert_num_tokens, expected_m) + fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, self.w2_scale), + output, expert_num_tokens, expected_m) diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index 89d7412ee2236..8b9070f098898 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -8,55 +8,37 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( BatchedDeepGemmExperts) from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( + deep_gemm_block_shape) from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedTritonExperts) class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__(self, - max_num_tokens: int, - num_dispatchers: int, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - block_shape: Optional[list[int]] = None, - per_act_token_quant: bool = False, - allow_deep_gemm: bool = False): - assert not use_int8_w8a8, "NYI" - assert not use_int8_w8a16, "NYI" - assert not use_int4_w4a16, "NYI" - - super().__init__( - FusedMoEQuantConfig.make( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - block_shape=block_shape, - per_act_token_quant=per_act_token_quant, - )) + def __init__( + self, + max_num_tokens: int, + num_dispatchers: int, + quant_config: FusedMoEQuantConfig, + allow_deep_gemm: bool = False, + ): + super().__init__(quant_config) self.batched_triton_experts = BatchedTritonExperts( max_num_tokens=max_num_tokens, num_dispatchers=num_dispatchers, - use_fp8_w8a8=use_fp8_w8a8, - 
use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - per_act_token_quant=self.per_act_token_quant, - block_shape=self.block_shape, + quant_config=self.quant_config, ) - self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8 - and self.block_shape - == BatchedDeepGemmExperts.DEEPGEMM_BLOCK_SHAPE) + self.allow_deep_gemm = (allow_deep_gemm + and self.quant_config.use_fp8_w8a8 and + self.block_shape == deep_gemm_block_shape()) self.batched_deep_gemm_experts = BatchedDeepGemmExperts( max_num_tokens=max_num_tokens, num_dispatchers=num_dispatchers, - block_shape=self.block_shape, # type: ignore[arg-type] + quant_config=self.quant_config, ) if self.allow_deep_gemm else None assert (self.batched_deep_gemm_experts is not None @@ -143,12 +125,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], @@ -158,7 +135,6 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): if self.allow_deep_gemm else self.batched_triton_experts) assert experts is not None experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids, - activation, global_num_experts, expert_map, w1_scale, - w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13, - workspace2, expert_tokens_meta, + activation, global_num_experts, expert_map, a1q_scale, + workspace13, workspace2, expert_tokens_meta, apply_router_weight_on_input) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 0b501cd87fb5d..742df3dbdc6af 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -1,103 +1,322 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union import torch -from compressed_tensors.quantization import (QuantizationArgs, - QuantizationStrategy, - QuantizationType) import vllm.envs as envs from vllm.config import ParallelConfig from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank from vllm.logger import init_logger -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.utils import cdiv +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) +from vllm.utils import cdiv, has_triton_kernels from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe +if TYPE_CHECKING and has_triton_kernels: + from triton_kernels.matmul_ogs import PrecisionConfig + logger = init_logger(__name__) -def _get_quant_config_quantization_args( - quant_config: Optional[QuantizationConfig], - prop_name: str, -) -> Optional[QuantizationArgs]: - if (quant_config is not None and hasattr(quant_config, 'target_scheme_map') - and "Linear" in quant_config.target_scheme_map and - "input_activations" in quant_config.target_scheme_map["Linear"]): - return quant_config.target_scheme_map["Linear"].get(prop_name) - else: - return None - - -def get_quant_config_input_quant( - quant_config: 
Optional[QuantizationConfig] -) -> Optional[QuantizationArgs]: - return _get_quant_config_quantization_args(quant_config, - "input_activations") - - -def get_quant_config_weight_quant( - quant_config: Optional[QuantizationConfig] -) -> Optional[QuantizationArgs]: - return _get_quant_config_quantization_args(quant_config, "weights") - - -def get_config_quant_dtype( - use_fp8_w8a8: bool, - use_int8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - use_mxfp4_w4a4: bool, -) -> Union[None, torch.dtype, str]: +def _get_config_dtype_str( + dtype: torch.dtype, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, +) -> Optional[str]: + """ + Return a string used to construct the filename that contains the + tuning info for a particular quantization scheme. See + try_get_optimal_moe_config in fused_moe.py. + """ if use_fp8_w8a8: - return torch.float8_e4m3fn - elif use_int8_w8a8: - return torch.int8 + return "fp8_w8a8" + elif use_int8_w8a16: + return "int8_w8a16" + elif use_int4_w4a16: + return "int4_w4a16" elif use_mxfp4_w4a4: - return "mxfp4" + return "mxfp4_w4a4" + elif dtype == torch.float: + # avoiding cases where kernel fails when float32 MoE + # use fp16/bfloat16 configs + return "float32" return None +def _quant_flags_to_group_shape( + quant_dtype: Union[torch.dtype, str, None], + per_act_token_quant: bool, + per_out_ch_quant: bool, + block_shape: Optional[list[int]], +) -> tuple[Optional[GroupShape], Optional[GroupShape]]: + """ + Convert MoE quantization flags into more generic GroupShapes. + """ + a_shape: Optional[GroupShape] + w_shape: Optional[GroupShape] + if block_shape is not None: + assert not per_act_token_quant + assert not per_out_ch_quant + # TODO(bnell): this is not quite right for activations since first + # dim should be 1. + a_shape = GroupShape(row=block_shape[0], col=block_shape[1]) + w_shape = GroupShape(row=block_shape[0], col=block_shape[1]) + else: + w_shape = None + a_shape = None if quant_dtype is None else GroupShape.PER_TENSOR + + if per_act_token_quant: + a_shape = GroupShape.PER_TOKEN + + if per_out_ch_quant: + w_shape = GroupShape.PER_TOKEN + + return a_shape, w_shape + + +@dataclass +class FusedMoEQuantDesc: + """ + A quantization descriptor for fused MoE ops. This class can describe + either activations or weights. + """ + + # The quantized type of this parameters. None means unquantized or + # already quantized. + # TODO (bnell): use scalar_type instead of Union. + dtype: Union[torch.dtype, str, None] = None + + # A field that describes the quantization group shape, from quant_utils.py. + # * (-1, -1) for per-tensor quantization + # * (1, -1) for per-row quantization + # * (-1, 1) for per-column quantization + # * (128, 128) for 128x128 deepseek style block quantization + # * (1, 128) for deepseek style activation quantization + # (i.e. per-token-per-group) + shape: Optional[GroupShape] = None + + # Quantization scales. + # TODO(bnell): maybe put PrecisionConfigs in subclass of QuantDesc? + scale: Union[torch.Tensor, "PrecisionConfig", None] = None + + # Quantization alphas or gscales, used for nvfp4 types. + # TODO(bnell): put some of these in subclasses + alpha_or_gscale: Optional[torch.Tensor] = None + + # Zero points for int4/int8 types + zp: Optional[torch.Tensor] = None + + # Biases for GPT triton MoE + bias: Optional[torch.Tensor] = None + + +# TODO(bnell): have subclasses for specific moe methods? +# e.g. for specific arguments bias, precision, etc. 
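+# Example (w1_s/w2_s/a1_s/a2_s are placeholder scale tensors): a per-tensor
+# fp8 config such as
+#     fp8_w8a8_moe_quant_config(w1_scale=w1_s, w2_scale=w2_s,
+#                               a1_scale=a1_s, a2_scale=a2_s)
+# (defined further below) expands to
+#     FusedMoEQuantConfig(
+#         _a1=FusedMoEQuantDesc(torch.float8_e4m3fn, GroupShape.PER_TENSOR, a1_s),
+#         _a2=FusedMoEQuantDesc(torch.float8_e4m3fn, GroupShape.PER_TENSOR, a2_s),
+#         _w1=FusedMoEQuantDesc(torch.float8_e4m3fn, None, w1_s),
+#         _w2=FusedMoEQuantDesc(torch.float8_e4m3fn, None, w2_s),
+#     )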
@dataclass class FusedMoEQuantConfig: - # The post quantization activation type. - # TODO (bnell): use scalar_type instead of Union. - quant_dtype: Union[torch.dtype, str, None] = None - per_act_token_quant: bool = False - per_out_ch_quant: bool = False - block_shape: Optional[list[int]] = None + """ + The FusedMoEQuantConfig contains all the quantization parameters for + a single FusedMoEMethodBase operation. It consists of four + FusedMoEQuantDescs, one for each activation and set of weights. - # TODO: add col major flag? - # add detailed quant info for input, intermediates, weights, etc? + Each FusedMoEMethodBase must implement a get_fused_moe_quant_config + method to construct a FusedMoEQuantConfig for use with that class. + + FusedMoEQuant configs are only used for modular kernels, fused_experts + (from fused_moe.py), cutlass_moe_fp[48], rocm_aiter_fused_experts and + triton_kernel_moe_forward. Other MoE methods can ignore the + FusedMoEQuantConfig (for now) and hardcode it to None. + + There are currently some restrictions on what can be expressed: + - Most MoE ops only support similar quantization strategies for + each parameter, e.g. both weights must have the same GroupShape + and both activations must share the same GroupShape. One exception to + this is the cutlass moe which allows per channel quantization on the + outputs. Note: this restrictions are not always rigorously checked. + - Not all fused MoE functions support all the parameters, e.g. zero points, + global scales, alphas and biases are not universally supported. + - Fully general GroupShapes are not allowed. Activations only support + per token, per tensor or K-blocked. + - Weights are not required to have a GroupShape since they have already + been quantized. + + Other notes: + - PrecisionConfigs are specific to GPT OSS Triton. + - As a follow up it would probably make sense to subclass FusedMoEQuantDesc + or FusedMoEQuantConfig for particular FusedMoEMethodBase subclasses + so that only the required quantization parameters are used/stored. + """ + + # TODO(bnell) make sure a1_scales/a2_scales don't interfere with chunking + _a1: FusedMoEQuantDesc + _a2: FusedMoEQuantDesc + _w1: FusedMoEQuantDesc + _w2: FusedMoEQuantDesc def __post_init__(self): assert (not self.per_act_token_quant or self.block_shape is None), "illegal quantization" + # + # Convenience accessors for various properties. 
+ # + + @property + def quant_dtype(self) -> Union[torch.dtype, str, None]: + return self._a1.dtype + @property def is_quantized(self) -> bool: return self.quant_dtype is not None @property def is_per_act_token(self) -> bool: - return self.per_act_token_quant + return self._a1.shape == GroupShape.PER_TOKEN + + @property + def per_act_token_quant(self) -> bool: + return self._a1.shape == GroupShape.PER_TOKEN + + @property + def per_out_ch_quant(self) -> bool: + return self._w1.shape == GroupShape.PER_TOKEN + + @property + def is_per_tensor(self) -> bool: + return self._a1.shape == GroupShape.PER_TENSOR + + @property + def block_shape(self) -> Optional[list[int]]: + if (self._a1.shape is not None + and self._a1.shape != GroupShape.PER_TENSOR + and self._a1.shape != GroupShape.PER_TOKEN): + return [self._a1.shape.row, self._a1.shape.col] + else: + return None @property def is_block_quantized(self) -> bool: return self.block_shape is not None @property - def is_per_tensor(self) -> bool: - return not self.per_act_token_quant and self.block_shape is None + def a1_scale(self) -> Optional[torch.Tensor]: + assert self._a1.scale is None or isinstance(self._a1.scale, + torch.Tensor) + return self._a1.scale + + @property + def a1_gscale(self) -> Optional[torch.Tensor]: + return self._a1.alpha_or_gscale + + @property + def a2_scale(self) -> Optional[torch.Tensor]: + assert self._a2.scale is None or isinstance(self._a2.scale, + torch.Tensor) + return self._a2.scale + + @property + def a2_gscale(self) -> Optional[torch.Tensor]: + return self._a2.alpha_or_gscale + + @property + def w1_scale(self) -> Optional[torch.Tensor]: + assert self._w1.scale is None or isinstance(self._w1.scale, + torch.Tensor) + return self._w1.scale + + @property + def w1_zp(self) -> Optional[torch.Tensor]: + return self._w1.zp + + @property + def w1_bias(self) -> Optional[torch.Tensor]: + return self._w1.bias + + @property + def w1_precision(self) -> Optional["PrecisionConfig"]: + assert self._w1.scale is None or isinstance(self._w1.scale, + PrecisionConfig) + return self._w1.scale + + @property + def g1_alphas(self) -> Optional[torch.Tensor]: + return self._w1.alpha_or_gscale + + @property + def w2_scale(self) -> Optional[torch.Tensor]: + assert self._w2.scale is None or isinstance(self._w2.scale, + torch.Tensor) + return self._w2.scale + + @property + def w2_zp(self) -> Optional[torch.Tensor]: + return self._w2.zp + + @property + def w2_bias(self) -> Optional[torch.Tensor]: + return self._w2.bias + + @property + def w2_precision(self) -> Optional["PrecisionConfig"]: + assert self._w2.scale is None or isinstance(self._w2.scale, + PrecisionConfig) + return self._w2.scale + + @property + def g2_alphas(self) -> Optional[torch.Tensor]: + return self._w2.alpha_or_gscale + + @property + def use_fp8_w8a8(self) -> bool: + return self.quant_dtype == torch.float8_e4m3fn + + @property + def use_int8_w8a8(self) -> bool: + return self.quant_dtype == torch.int8 + + @property + def use_int8_w8a16(self) -> bool: + return (self._a1.dtype is None and self._w1.dtype == torch.int8) + + @property + def use_int4_w4a16(self) -> bool: + return (self._a1.dtype is None and self._w1.dtype == "int4") + + @property + def use_mxfp4_w4a4(self) -> bool: + return self.quant_dtype == "mxfp4" + + @property + def use_nvfp4_w4a4(self) -> bool: + return self.quant_dtype == "nvfp4" + + def config_name(self, dtype: torch.dtype) -> Optional[str]: + """ + Return a string used to construct the filename that contains the + tuning info for a particular quantization scheme. 
See + try_get_optimal_moe_config in fused_moe.py. + """ + return _get_config_dtype_str( + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + use_mxfp4_w4a4=self.use_mxfp4_w4a4, + dtype=dtype, + ) def scale_shape( self, max_tokens: int, hidden_dim: int, ) -> Optional[tuple[int, int]]: + """ + Construct the proper activation scale shape for this + config. + """ if self.is_quantized: if self.is_block_quantized: assert self.block_shape is not None @@ -117,6 +336,10 @@ class FusedMoEQuantConfig: max_tokens: int, hidden_dim: int, ) -> Optional[tuple[int, int, int]]: + """ + Construct the proper activation batched scale shape for this + config, e.g. (num experts, *scale_shape). + """ if self.is_quantized: scale_shape = self.scale_shape(max_tokens, hidden_dim) assert scale_shape is not None @@ -126,38 +349,218 @@ class FusedMoEQuantConfig: @staticmethod def make( - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, + quant_dtype: Union[torch.dtype, str, None] = None, per_act_token_quant: bool = False, per_out_ch_quant: bool = False, block_shape: Optional[list[int]] = None, + w1_scale: Union[torch.Tensor, "PrecisionConfig", None] = None, + w2_scale: Union[torch.Tensor, "PrecisionConfig", None] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + g1_alphas: Optional[torch.Tensor] = None, + g2_alphas: Optional[torch.Tensor] = None, + a1_gscale: Optional[torch.Tensor] = None, + a2_gscale: Optional[torch.Tensor] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, ) -> "FusedMoEQuantConfig": - assert sum([ - int(flag) for flag in [ - use_fp8_w8a8, - use_int8_w8a8, - use_int8_w8a16, - use_int4_w4a16, - use_mxfp4_w4a4, - ] - ]) <= 1, "Quantization flags are mutually exclusive." + """ + General builder function for a FusedMoEQuantConfig. + - quant_dtype: Optional quantization type. None if activations are + unquantized or quantized prior to calling. Note: "nvfp4" and + "mxfp4" are the only valid string values for quant_dtype. + - per_act_token_quant: Activations have per token quantization. + - per_out_ch_quant: Outputs have per channel quantization. (only + for cutlass). + - block_shape: Optional block size for block-wise quantization. + Incompatible with per_act_token and per_out_ch quant. + - w1_scale: Optional scale to be used for w1. + - w2_scale: Optional scale to be used for w2. + - a1_scale: Optional scale to be used for a1. + - a2_scale: Optional scale to be used for a2. + - g1_alphas: Optional global quantization scales for w1 (for nvfp4). + - g2_alphas: Optional global quantization scales for w2 (for nvfp4). + - a1_gscale: Optional global quantization scales for a1 (for nvfp4). + - a2_gscale: Optional global quantization scales for a2 (for nvfp4). + - w1_bias: Optional biases for w1 (GPT OSS Triton). + - w2_bias: Optional biases for w1 (GPT OSS Triton). + - w1_zp: Optional w1 zero points for int4/int8 quantization. + - w2_zp: Optional w2 zero points for int4/int8 quantization. 
+ """ + assert (not isinstance(quant_dtype, str) or quant_dtype == "nvfp4" + or quant_dtype == "mxfp4") + a_shape, w_shape = _quant_flags_to_group_shape(quant_dtype, + per_act_token_quant, + per_out_ch_quant, + block_shape) + quant_config = FusedMoEQuantConfig( + _a1=FusedMoEQuantDesc(quant_dtype, a_shape, a1_scale, a1_gscale), + _a2=FusedMoEQuantDesc(quant_dtype, a_shape, a2_scale, a2_gscale), + _w1=FusedMoEQuantDesc(quant_dtype, w_shape, w1_scale, g1_alphas, + w1_zp, w1_bias), + _w2=FusedMoEQuantDesc(quant_dtype, w_shape, w2_scale, g2_alphas, + w2_zp, w2_bias), + ) + assert quant_config.per_act_token_quant == per_act_token_quant + assert quant_config.per_out_ch_quant == per_out_ch_quant + assert quant_config.block_shape == block_shape + return quant_config - quant_dtype = get_config_quant_dtype( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - ) - return FusedMoEQuantConfig( - quant_dtype, - per_act_token_quant, - per_out_ch_quant, - block_shape, - ) + +def fp8_w8a8_moe_quant_config( + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + per_act_token_quant: bool = False, + per_out_ch_quant: bool = False, + block_shape: Optional[list[int]] = None, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for fp8 activations and fp8 weights. + """ + return FusedMoEQuantConfig.make(torch.float8_e4m3fn, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_out_ch_quant, + block_shape=block_shape) + + +def int8_w8a8_moe_quant_config( + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + per_act_token_quant: bool = False, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for int8 activations and int8 weights. + """ + return FusedMoEQuantConfig.make( + torch.int8, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=False, + block_shape=None, + ) + + +def mxfp4_w4a4_moe_quant_config( + w1_scale: Union[torch.Tensor, "PrecisionConfig"], + w2_scale: Union[torch.Tensor, "PrecisionConfig"], + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for mxfp4 activations and mxfp4 weights. + """ + return FusedMoEQuantConfig.make( + "mxfp4", + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + w1_bias=w1_bias, + w2_bias=w2_bias, + per_act_token_quant=False, + per_out_ch_quant=False, + block_shape=block_shape, + ) + + +def nvfp4_moe_quant_config( + g1_alphas: torch.Tensor, + g2_alphas: torch.Tensor, + a1_gscale: torch.Tensor, + a2_gscale: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for mxfp4 activations and nvp4 weights. 
+ """ + return FusedMoEQuantConfig.make( + "nvfp4", + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_gscale=a1_gscale, + a2_gscale=a2_gscale, + g1_alphas=g1_alphas, + g2_alphas=g2_alphas, + per_act_token_quant=False, + per_out_ch_quant=False, + block_shape=None, + ) + + +def int4_w4a16_moe_quant_config( + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + block_shape: Optional[list[int]] = None, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for 16-bit float activations and int4 weights. + Note: Activations are pre-quantized. + """ + group_shape = GroupShape(*block_shape) if block_shape is not None else None + return FusedMoEQuantConfig( + _a1=FusedMoEQuantDesc(shape=group_shape), + _a2=FusedMoEQuantDesc(shape=group_shape), + _w1=FusedMoEQuantDesc("int4", group_shape, w1_scale, None, w1_zp), + _w2=FusedMoEQuantDesc("int4", group_shape, w2_scale, None, w2_zp), + ) + + +def int8_w8a16_moe_quant_config( + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + block_shape: Optional[list[int]] = None, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for 16-bit float activations and int8 weights. + Note: Activations are pre-quantized. + """ + group_shape = GroupShape(*block_shape) if block_shape is not None else None + return FusedMoEQuantConfig( + _a1=FusedMoEQuantDesc(shape=group_shape), + _a2=FusedMoEQuantDesc(shape=group_shape), + _w1=FusedMoEQuantDesc(torch.int8, group_shape, w1_scale, None, w1_zp), + _w2=FusedMoEQuantDesc(torch.int8, group_shape, w2_scale, None, w2_zp), + ) + + +def biased_moe_quant_config( + w1_bias: Optional[torch.Tensor], + w2_bias: Optional[torch.Tensor], +) -> FusedMoEQuantConfig: + """ + Construct a quant config for unquantized activations with biases. + """ + return FusedMoEQuantConfig( + _a1=FusedMoEQuantDesc(), + _a2=FusedMoEQuantDesc(), + _w1=FusedMoEQuantDesc(bias=w1_bias), + _w2=FusedMoEQuantDesc(bias=w2_bias), + ) + + +# A FusedMoEQuantConfig constant for an unquantized MoE op. +FUSED_MOE_UNQUANTIZED_CONFIG: FusedMoEQuantConfig = FusedMoEQuantConfig.make() @dataclass @@ -315,8 +718,6 @@ class FusedMoEConfig: # The activation type. in_dtype: torch.dtype - quant_config: Optional[FusedMoEQuantConfig] = None - max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE has_bias: bool = False @@ -328,34 +729,6 @@ class FusedMoEConfig: assert self.max_num_tokens > 0 - @property - def quant_dtype(self) -> Union[torch.dtype, str, None]: - if self.quant_config is not None: - return self.quant_config.quant_dtype - else: - return None - - @property - def block_shape(self) -> Optional[list[int]]: - if self.quant_config is not None: - return self.quant_config.block_shape - else: - return None - - @property - def per_act_token_quant(self) -> bool: - if self.quant_config is not None: - return self.quant_config.per_act_token_quant - else: - return False - - @property - def per_out_ch_quant(self) -> bool: - if self.quant_config is not None: - return self.quant_config.per_out_ch_quant - else: - return False - @property def tp_size(self): return self.moe_parallel_config.tp_size @@ -401,97 +774,6 @@ class FusedMoEConfig: """ Whether to use FlashInfer cutlass kernels for NVFP4 MoE. 
""" - return (self.quant_config is not None - and self.quant_config.quant_dtype == "nvfp4" - and envs.VLLM_USE_FLASHINFER_MOE_FP4 + return (envs.VLLM_USE_FLASHINFER_MOE_FP4 and has_flashinfer_cutlass_fused_moe() and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput") - - @staticmethod - def make( - num_experts: int, - experts_per_token: int, - hidden_dim: int, - num_local_experts: int, - moe_parallel_config: FusedMoEParallelConfig, - in_dtype: torch.dtype, - max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE, - quant_config: Optional[Union[FusedMoEQuantConfig, - QuantizationConfig]] = None, - has_bias: bool = False, - ) -> "FusedMoEConfig": - - _quant_config: Optional[FusedMoEQuantConfig] = None - - if quant_config is not None and isinstance(quant_config, - QuantizationConfig): - if hasattr(quant_config, 'weight_block_size'): - block_shape = quant_config.weight_block_size - else: - block_shape = None - per_act_token_quant = False - per_out_ch_quant = False - quant_dtype: Union[torch.dtype, str, None] = None - - input_quant = get_quant_config_input_quant(quant_config) - weight_quant = get_quant_config_weight_quant(quant_config) - - if input_quant is not None: - per_act_token_quant = (input_quant.strategy - == QuantizationStrategy.TOKEN - if input_quant is not None else False) - - if input_quant.num_bits == 8: - if input_quant.type == QuantizationType.FLOAT: - quant_dtype = torch.float8_e4m3fn - elif input_quant.type == QuantizationType.INT: - quant_dtype = torch.int8 - - from vllm.model_executor.layers.quantization.fp8 import Fp8Config - if quant_dtype is None and isinstance(quant_config, Fp8Config): - quant_dtype = torch.float8_e4m3fn - - from vllm.model_executor.layers.quantization.mxfp4 import ( - Mxfp4Config) - if (quant_dtype is None and isinstance(quant_config, Mxfp4Config) - and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8): - quant_dtype = "mxfp8" - - from vllm.model_executor.layers.quantization.modelopt import ( - ModelOptNvFp4Config) - if quant_dtype is None and isinstance(quant_config, - ModelOptNvFp4Config): - quant_dtype = "nvfp4" - - if weight_quant is not None: - per_out_ch_quant = ( - weight_quant.strategy == QuantizationStrategy.CHANNEL) - - if quant_dtype is not None: - _quant_config = FusedMoEQuantConfig( - quant_dtype=quant_dtype, - per_act_token_quant=per_act_token_quant, - per_out_ch_quant=per_out_ch_quant, - block_shape=block_shape, - ) - else: - _quant_config = FusedMoEQuantConfig() - if moe_parallel_config.dp_size > 1: - logger.warning_once("MoE DP setup unable to determine " - "quantization scheme or unsupported " - "quantization type. 
This model will " - "not run with DP enabled.") - else: - _quant_config = quant_config - - return FusedMoEConfig( - num_experts=num_experts, - experts_per_token=experts_per_token, - hidden_dim=hidden_dim, - num_local_experts=num_local_experts, - moe_parallel_config=moe_parallel_config, - in_dtype=in_dtype, - quant_config=_quant_config, - max_num_tokens=max_num_tokens, - has_bias=has_bias, - ) diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..e5059358c91e3 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json 
b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..db1b6e98df469 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json index b9dc2d71f6dcf..1bbb8aa613996 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -9,16 +9,16 @@ }, "2": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 
128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "4": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 @@ -26,15 +26,15 @@ "8": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "16": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 @@ -42,7 +42,7 @@ "24": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 @@ -53,12 +53,12 @@ "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "48": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 @@ -82,10 +82,10 @@ "128": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "256": { "BLOCK_SIZE_M": 16, @@ -98,8 +98,8 @@ "512": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 }, @@ -107,7 +107,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4 }, @@ -115,7 +115,7 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4 }, @@ -123,15 +123,15 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4 }, "3072": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 }, diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..8fb4947d62ab2 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..bdbaf3811c939 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { 
+ "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..6e17bcd214748 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + 
"num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..aa7610cd75e77 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} diff --git 
a/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..df920e8b39ba8 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..e8fe8ea67f246 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..0baf13cb6a5c5 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json index 307c9240938c5..c7998718dab4c 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json @@ -18,18 +18,18 @@ "4": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, "8": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "16": { "BLOCK_SIZE_M": 16, @@ -58,7 +58,7 @@ "48": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4 @@ -74,73 +74,73 @@ "96": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "128": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 2 + "num_stages": 4 }, "256": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 
64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, "num_stages": 4 }, "1024": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4 }, "1536": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 4 - }, - "2048": { - "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 5 + "num_warps": 4, + "num_stages": 3 }, - "3072": { - "BLOCK_SIZE_M": 128, + "2048": { + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 }, "4096": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 5 + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 } } diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000000..cc853947c19f5 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 
+ }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..b4e736bec9b65 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + 
"num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..bb71005a72bc5 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json 
b/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..ac53df14ce846 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..f1ed617d6308f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, 
+ "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=384,N=256,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=384,N=256,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..e72282dc5bcd9 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=384,N=256,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..4fc4868eaa85a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + 
"BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..d70adca05e779 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + 
"num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000000000..0f5867fea5f89 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_B200.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_B200.json new file mode 100644 index 0000000000000..d104aa5167b22 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_B200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000..22e3d09676d06 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json @@ -0,0 +1,147 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} + diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000000..94408e279b656 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, 
+ "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H20-3e.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H20-3e.json new file mode 100644 index 0000000000000..9f4c3cbc9b8a9 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H20-3e.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + 
"GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H200.json new file mode 100644 index 0000000000000..20146f53a6eba --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 
128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_B200.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_B200.json new file mode 100644 index 0000000000000..d0140252594f5 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_B200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json 
b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000..8bac7af0c2dac --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000000..cc1427c139e3e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 
128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H20-3e.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H20-3e.json new file mode 100644 index 0000000000000..68649395a23ed --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H20-3e.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + 
"BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H200.json new file mode 100644 index 0000000000000..2f0b45014e863 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + 
"num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_B200.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_B200.json new file mode 100644 index 0000000000000..5d69efe9ed5f9 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_B200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 
64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000..5910027e17f9b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000000..564ff499d43c4 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ 
+{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H20-3e.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H20-3e.json new file mode 100644 index 0000000000000..a68c83147eeb3 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H20-3e.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, 
+ "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H200.json new file mode 100644 index 0000000000000..e55df46b40269 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_B200.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_B200.json new file mode 100644 index 0000000000000..a0855a921f3f6 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_B200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + 
"BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_H20-3e.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_H20-3e.json new file mode 100644 index 0000000000000..5dd1a8e19c2ce --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_H20-3e.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_H200.json 
b/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_H200.json new file mode 100644 index 0000000000000..d5b6d02123d71 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 95d23ec0346c1..957ffca0d1246 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -211,21 +211,14 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, out_dtype: Optional[torch.dtype], - per_act_token_quant: bool, - per_out_ch_quant: bool, ab_strides1: torch.Tensor, ab_strides2: torch.Tensor, c_strides1: torch.Tensor, c_strides2: torch.Tensor, - block_shape: Optional[list[int]] = None, + quant_config: FusedMoEQuantConfig, ): - super().__init__( - FusedMoEQuantConfig( - quant_dtype=torch.float8_e4m3fn, - 
per_act_token_quant=per_act_token_quant, - per_out_ch_quant=per_out_ch_quant, - block_shape=block_shape, - )) + assert quant_config.use_fp8_w8a8 + super().__init__(quant_config) self.out_dtype = out_dtype self.ab_strides1 = ab_strides1 self.ab_strides2 = ab_strides2 @@ -247,19 +240,14 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, ): - assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE" - assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE" + assert self.w1_zp is None, "w1_zp is not supported in CUTLASS MoE" + assert self.w2_zp is None, "w2_zp is not supported in CUTLASS MoE" expert_num_tokens = None if expert_tokens_meta is not None: @@ -273,9 +261,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): in_dtype = hidden_states.dtype run_cutlass_moe_fp8( output, hidden_states, w1, w2, topk_ids, activation_callable, - global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale, - a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1, - self.c_strides2, workspace13, workspace2, expert_num_tokens, + global_num_experts, expert_map, self.w1_scale, self.w2_scale, + a1q_scale, self.a2_scale, self.ab_strides1, self.ab_strides2, + self.c_strides1, self.c_strides2, workspace13, workspace2, + expert_num_tokens, self.out_dtype if self.out_dtype is not None else in_dtype, self.per_act_token_quant, self.per_out_ch_quant, use_batched_format, topk_weights) @@ -286,23 +275,19 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base): def __init__( self, out_dtype: Optional[torch.dtype], - per_act_token_quant: bool, - per_out_ch_quant: bool, ab_strides1: torch.Tensor, ab_strides2: torch.Tensor, c_strides1: torch.Tensor, c_strides2: torch.Tensor, - block_shape: Optional[list[int]] = None, + quant_config: FusedMoEQuantConfig, ): super().__init__( out_dtype, - per_act_token_quant, - per_out_ch_quant, ab_strides1, ab_strides2, c_strides1, c_strides2, - block_shape, + quant_config, ) @property @@ -348,23 +333,19 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): max_experts_per_worker: int, num_dispatchers: int, out_dtype: Optional[torch.dtype], - per_act_token_quant: bool, - per_out_ch_quant: bool, ab_strides1: torch.Tensor, ab_strides2: torch.Tensor, c_strides1: torch.Tensor, c_strides2: torch.Tensor, - block_shape: Optional[list[int]] = None, + quant_config: FusedMoEQuantConfig, ): super().__init__( out_dtype, - per_act_token_quant, - per_out_ch_quant, ab_strides1, ab_strides2, c_strides1, c_strides2, - block_shape, + quant_config, ) assert max_experts_per_worker > 0 self.max_experts_per_worker = max_experts_per_worker @@ -414,16 +395,12 @@ def cutlass_moe_fp8( w2_q: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, ab_strides1: torch.Tensor, ab_strides2: torch.Tensor, c_strides1: torch.Tensor, c_strides2: torch.Tensor, - per_act_token: Optional[bool] = None, + quant_config: FusedMoEQuantConfig, activation: str = "silu", - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None, 
apply_router_weight_on_input: bool = False, global_num_experts: int = -1, @@ -475,10 +452,18 @@ def cutlass_moe_fp8( Returns: - torch.Tensor: The fp16 output tensor after applying the MoE layer. """ - if per_act_token is None: - per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) - per_out_ch = w1_scale.numel() != w1_q.size(0) + assert quant_config is not None + + if quant_config.a1_scale is not None: + assert (quant_config.per_act_token_quant == + quant_config.a1_scale.numel() != 1) + if quant_config.a2_scale is not None: + assert (quant_config.per_act_token_quant == + quant_config.a2_scale.numel() != 1) + + assert (quant_config.w1_scale is None + or (quant_config.per_out_ch_quant == (quant_config.w1_scale.size(1) + == w1_q.size(1)))) num_experts = global_num_experts if global_num_experts != -1 else w1_q.size( 0) @@ -487,12 +472,11 @@ def cutlass_moe_fp8( MoEPrepareAndFinalizeNoEP(), CutlassExpertsFp8( out_dtype=a.dtype, - per_act_token_quant=per_act_token, - per_out_ch_quant=per_out_ch, ab_strides1=ab_strides1, ab_strides2=ab_strides2, c_strides1=c_strides1, c_strides2=c_strides2, + quant_config=quant_config, ), ) @@ -502,14 +486,9 @@ def cutlass_moe_fp8( w2_q, topk_weights, topk_ids, - False, - activation, - num_experts, - expert_map, - w1_scale, - w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, + activation=activation, + global_num_experts=num_experts, + expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input, ) @@ -542,7 +521,7 @@ def run_cutlass_moe_fp4( ) -> None: """ MoE implementation for FP4 Inputs - + # Gemm 1 a: Input tensor: [m, k] (half/bfloat16) a1_gscale: Activation scale per expert: [e] (float32) @@ -552,16 +531,16 @@ def run_cutlass_moe_fp4( full precision) w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3) (Block size = 16 for NVFP4) - + # Gemm 2 a2_gscale: Activation scale per expert: [e] w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n] w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1) w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3 - + topk_weights: [m, topk] dtype: float8 topk_ids: [m, topk] dtype: float8 - + m, n, k: Unquantized weight shapes, dtype: int e: number of experts, dtype: int @@ -652,42 +631,21 @@ def run_cutlass_moe_fp4( return +# Split into batched and non-batched class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - g1_alphas: torch.Tensor, - g2_alphas: torch.Tensor, - a1_gscale: torch.Tensor, - a2_gscale: torch.Tensor, max_experts_per_worker: int, out_dtype: torch.dtype, - per_act_token_quant: bool, - per_out_ch_quant: bool, - block_shape: Optional[list[int]] = None, + quant_config: FusedMoEQuantConfig, use_batched_format: bool = False, ): - super().__init__( - # NVFP4 requires two levels of quantization, which involves - # computing some scaling factors dynamically. This makes it - # incompatible with the typical prepare -> MoE -> finalize - # pipeline. Move the quantization logic into the MoE body. - FusedMoEQuantConfig( - quant_dtype=None, # skip quantization in prepare/finalize - per_act_token_quant=per_act_token_quant, - per_out_ch_quant=per_out_ch_quant, - block_shape=block_shape, - )) + super().__init__(quant_config) self.max_experts_per_worker = max_experts_per_worker self.out_dtype = out_dtype self.use_batched_format = use_batched_format - # TODO(bnell): put this stuff into quant config? 
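With this refactor the weight and activation scales no longer ride along as separate keyword arguments; they travel on the FusedMoEQuantConfig that the experts classes receive. A rough sketch of the new cutlass_moe_fp8 calling convention, using the fp8_w8a8_moe_quant_config helper imported elsewhere in this diff; the tensors are placeholders, not values from the diff:

from vllm.model_executor.layers.fused_moe.config import (
    fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8

# Scales live on the config now. per_act_token_quant is expected to agree
# with a1_scale.numel() != 1, per the assertions above.
quant_config = fp8_w8a8_moe_quant_config(
    w1_scale=w1_scale,  # placeholder fp8 weight scales
    w2_scale=w2_scale,
    a1_scale=a1_scale,  # scalar tensor for per-tensor activation quant
    a2_scale=a2_scale,
)

out = cutlass_moe_fp8(
    a, w1_q, w2_q, topk_weights, topk_ids,
    ab_strides1, ab_strides2, c_strides1, c_strides2,
    quant_config=quant_config,
    expert_map=None,
)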
- self.g1_alphas = g1_alphas - self.g2_alphas = g2_alphas - self.a1_gscale = a1_gscale - self.a2_gscale = a2_gscale - @property def activation_formats( self @@ -746,12 +704,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: torch.Tensor, + a1q_scale: Optional[torch.Tensor], # unused workspace13: Optional[torch.Tensor], workspace2: Optional[torch.Tensor], expert_tokens_meta: Optional[mk.ExpertTokensMetadata], @@ -765,11 +718,11 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): a=hidden_states, a1_gscale=self.a1_gscale, w1_fp4=w1, - w1_blockscale=w1_scale, + w1_blockscale=self.w1_scale, w1_alphas=self.g1_alphas, a2_gscale=self.a2_gscale, w2_fp4=w2, - w2_blockscale=w2_scale, + w2_blockscale=self.w2_scale, w2_alphas=self.g2_alphas, topk_weights=topk_weights, topk_ids=topk_ids, @@ -788,14 +741,9 @@ def cutlass_moe_fp4( a: torch.Tensor, w1_fp4: torch.Tensor, w2_fp4: torch.Tensor, - w1_blockscale: torch.Tensor, - w2_blockscale: torch.Tensor, - g1_alphas: torch.Tensor, - g2_alphas: torch.Tensor, - a1_gscale: torch.Tensor, - a2_gscale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + quant_config: FusedMoEQuantConfig, m: int, n: int, k: int, @@ -805,17 +753,31 @@ def cutlass_moe_fp4( assert expert_map is None, ("Expert Parallelism / expert_map " "is currently not supported for " "ModelOptNvFp4FusedMoE's cutlass_moe_fp4.") + + # TODO(bnell): this feels a bit hacky + # NVFP4 requires two levels of quantization, which involves + # computing some scaling factors dynamically. This makes it + # incompatible with the typical prepare -> MoE -> finalize + # pipeline. Move the quantization logic into the MoE body. + quant_config = FusedMoEQuantConfig.make( + quant_dtype=None, # skip quantization in prepare/finalize + per_act_token_quant=quant_config.per_act_token_quant, + per_out_ch_quant=quant_config.per_out_ch_quant, + block_shape=quant_config.block_shape, + g1_alphas=quant_config.g1_alphas, + g2_alphas=quant_config.g2_alphas, + a1_gscale=quant_config.a1_gscale, + a2_gscale=quant_config.a2_gscale, + w1_scale=quant_config.w1_scale, + w2_scale=quant_config.w2_scale, + ) + fn = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), CutlassExpertsFp4( - g1_alphas, - g2_alphas, - a1_gscale, - a2_gscale, max_experts_per_worker=e, out_dtype=a.dtype, - per_act_token_quant=False, - per_out_ch_quant=False, + quant_config=quant_config, use_batched_format=False, ), ) @@ -830,10 +792,6 @@ def cutlass_moe_fp4( activation="silu", global_num_experts=e, expert_map=None, - w1_scale=w1_blockscale, - w2_scale=w2_blockscale, - a1_scale=None, - a2_scale=None, apply_router_weight_on_input=apply_router_weight_on_input, ) @@ -891,6 +849,7 @@ def _valid_cutlass_block_scaled_grouped_gemm( return True +# TODO(bnell): would be nice combine/integrate with regular cutlass_fp8. 
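The NVFP4 path gets the same treatment: the alphas and global scales that used to be constructor arguments are now expected on the quant config. A sketch of what the caller-side config might look like, assuming FusedMoEQuantConfig.make accepts the same keyword set as the internal call inside cutlass_moe_fp4 above and that "nvfp4" is the quant dtype (both are assumptions; tensor names are placeholders). The resulting object is then handed to cutlass_moe_fp4 through its new quant_config parameter:

quant_config = FusedMoEQuantConfig.make(
    quant_dtype="nvfp4",     # assumed; the diff only shows quant_dtype=None internally
    g1_alphas=g1_alphas,
    g2_alphas=g2_alphas,
    a1_gscale=a1_gscale,
    a2_gscale=a2_gscale,
    w1_scale=w1_blockscale,  # FP8 block scales for the packed E2M1 weights
    w2_scale=w2_blockscale,
)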
def run_cutlass_block_scaled_fused_experts( a: torch.Tensor, w1: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 7b8467a5a0cf0..8830b95df7cf0 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import functools from typing import Optional import torch @@ -9,9 +8,11 @@ from tqdm import tqdm import vllm.envs as env import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, fp8_w8a8_moe_quant_config) from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( - compute_aligned_M, deepgemm_moe_permute, deepgemm_unpermute_and_reduce) + compute_aligned_M, deep_gemm_block_shape, deepgemm_moe_permute, + deepgemm_unpermute_and_reduce) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( @@ -25,14 +26,6 @@ from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous logger = init_logger(__name__) -@functools.cache -def deep_gemm_block_shape() -> list[int]: - # Lazy import to avoid CUDA initialization problems. - import deep_gemm as dg - block = dg.get_m_alignment_for_contiguous_layout() - return [block, block] - - def _valid_deep_gemm_shape(M: int, N: int, K: int) -> bool: align = deep_gemm_block_shape()[0] return align <= M and N % align == 0 and K % align == 0 @@ -57,13 +50,14 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, if not _valid_deep_gemm_shape(M, N, K): logger.debug_once( "DeepGemm disabled due to unaligned problem size. " - "M: %s, N: %s, K: %s. M should >= align size " - "and N and K must be multiples of %s." + "M: %s, N: %s, K: %s. M should >= %s " + "and N and K must be multiples of %s. 
" "This is not an error and we will fall back to triton.", M, N, K, align, + align, ) return False elif N <= 512: @@ -162,13 +156,12 @@ def warmup_deepgemm_gg_contiguous_kernels(w1: torch.Tensor, w2: torch.Tensor, class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__(self): - super().__init__( - FusedMoEQuantConfig( - quant_dtype=torch.float8_e4m3fn, - per_act_token_quant=False, - block_shape=deep_gemm_block_shape(), - )) + def __init__(self, quant_config: FusedMoEQuantConfig): + super().__init__(quant_config) + assert quant_config.block_shape == deep_gemm_block_shape() + assert quant_config.quant_dtype == torch.float8_e4m3fn + assert not quant_config.per_act_token_quant + assert not quant_config.per_out_ch_quant @property def activation_formats( @@ -220,21 +213,17 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, ): - assert self.block_shape is not None assert a1q_scale is not None - assert w1_scale is not None - assert w2_scale is not None + assert self.a2_scale is None + assert self.block_shape is not None + assert self.w1_scale is not None + assert self.w2_scale is not None a1q = hidden_states _, N, K = w1.size() @@ -269,7 +258,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): aq_out=a1q_perm) assert a1q.size(0) == M_sum - m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale), + m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, self.w1_scale), mm1_out, expert_ids) self.activation(activation, act_out, mm1_out.view(-1, N)) @@ -280,7 +269,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): column_major_scales=True, out_q=quant_out) - m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale), + m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, self.w2_scale), mm2_out, expert_ids) if apply_router_weight_on_input: @@ -347,9 +336,16 @@ def deep_gemm_moe_fp8( Returns: - torch.Tensor: The bfloat16 output tensor after applying the MoE layer. 
""" + quant_config = fp8_w8a8_moe_quant_config( + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=deep_gemm_block_shape()) + fn = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), - DeepGemmExperts(), + DeepGemmExperts(quant_config), ) return fn( hidden_states, @@ -357,13 +353,9 @@ def deep_gemm_moe_fp8( w2, topk_weights, topk_ids, - inplace, - activation, - global_num_experts, - expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, + inplace=inplace, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input, ) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index 437e569d3130d..f390f0a25875e 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Callable, Optional, Union import deep_ep import torch @@ -25,6 +25,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): self.num_dispatchers_ = num_dispatchers self.dp_size = dp_size self.rank_expert_offset = rank_expert_offset + self.async_prepare = True + # The dispatch function returns a handle that the combine function # requires. We store the handle here so it is available to the # combine function. @@ -47,19 +49,25 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return torch.int64 def _get_dispatch_config(self) -> Optional[deep_ep.Config]: - if self.dp_size not in self.available_rank_configs: + if self.num_dispatchers_ not in self.available_rank_configs: return None - return deep_ep.Buffer.get_dispatch_config(self.dp_size) + return deep_ep.Buffer.get_dispatch_config(self.num_dispatchers_) def _get_combine_config(self) -> Optional[deep_ep.Config]: - if self.dp_size not in self.available_rank_configs: + if self.num_dispatchers_ not in self.available_rank_configs: return None - return deep_ep.Buffer.get_combine_config(self.dp_size) + return deep_ep.Buffer.get_combine_config(self.num_dispatchers_) - def _do_dispatch(self, tokens: torch.Tensor, - token_scales: Optional[torch.Tensor], - rank_topk_ids: torch.Tensor, - rank_topk_weights: torch.Tensor, num_experts: int): + def _do_dispatch( + self, + tokens: torch.Tensor, + token_scales: Optional[torch.Tensor], + rank_topk_ids: torch.Tensor, + rank_topk_weights: torch.Tensor, + num_experts: int, + a1_scale: Optional[torch.Tensor], + quant_config: FusedMoEQuantConfig, + ) -> Callable: has_scales = token_scales is not None @@ -93,9 +101,36 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): expert_alignment=1, config=self._get_dispatch_config(), previous_event=None, - async_finish=False, + async_finish=self.async_prepare, allocate_on_comm_stream=False) + return lambda: self._receiver( + event, + has_scales, + token_data, + expert_topk_ids, + num_experts, + expert_num_tokens_per_expert_list, + expert_topk_weights, + a1_scale, + quant_config, + ) + + def _receiver( + self, + event: deep_ep.EventOverlap, + has_scales: bool, + token_data: Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor], + expert_topk_ids: Optional[torch.Tensor], + num_experts: int, + expert_num_tokens_per_expert_list: list[int], + 
expert_topk_weights: Optional[torch.Tensor], + a1_scale: Optional[torch.Tensor], + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + if self.async_prepare: + event.current_stream_wait() + if has_scales: expert_x, expert_x_scale = token_data else: @@ -112,6 +147,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # DeepEP's topk_ids output refers to the local experts directly. Offset # the topk_ids to move it back to the global experts space so it aligns # with existing vLLM interfaces. + assert expert_topk_ids is not None expert_topk_ids = torch.where( expert_topk_ids == -1, num_experts - 1 if self.rank_expert_offset == 0 else 0, @@ -123,61 +159,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list( expert_num_tokens_per_expert_list, device=expert_x.device) - return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, - expert_topk_weights) - - def prepare( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool, - quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: - - if apply_router_weight_on_input: - topk = topk_ids.size(1) - # TODO: this only works for topK=1, will need to update for topK>1 - assert topk == 1, ( - "apply_router_weight_on_input is only implemented for topk=1") - a1 = a1 * topk_weights.to(a1.dtype) - - if quant_config.is_block_quantized: - # Quant and Dispatch - a1q, a1q_scale = moe_kernel_quantize_input( - a1, - a1_scale, - quant_dtype=quant_config.quant_dtype, - per_act_token_quant=quant_config.per_act_token_quant, - block_shape=quant_config.block_shape, - ) - if a1q_scale is not None and a1q_scale.numel() == 1: - a1q_scale = a1q_scale.view(1, 1) - (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, - expert_topk_weights) = self._do_dispatch( - tokens=a1q, - token_scales=a1q_scale, - rank_topk_ids=topk_ids, - rank_topk_weights=topk_weights, - num_experts=num_experts) - else: - # Dispatch and Quant - # DeepEP kernels only support dispatching block-quantized - # activation scales. - # Dispatch in bfloat16 - (expert_x, _, expert_tokens_meta, expert_topk_ids, - expert_topk_weights) = self._do_dispatch( - tokens=a1, - token_scales=None, - rank_topk_ids=topk_ids, - rank_topk_weights=topk_weights, - num_experts=num_experts) + # Dispatch and Quant + # DeepEP kernels only support dispatching block-quantized + # activation scales. + # Dispatch in bfloat16 and quantize afterwards + if not quant_config.is_block_quantized: # Quantize after dispatch. 
expert_x_scale = None if expert_x.numel() != 0: @@ -191,7 +177,70 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, expert_topk_weights) - def finalize( + def supports_async(self) -> bool: + return True + + def prepare_async( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> tuple[Callable, mk.ReceiverType]: + + if apply_router_weight_on_input: + topk = topk_ids.size(1) + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, ( + "apply_router_weight_on_input is only implemented for topk=1") + a1 = a1 * topk_weights.to(a1.dtype) + + if quant_config.is_block_quantized: + # Quant and Dispatch + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + quant_config.a1_scale, + quant_dtype=quant_config.quant_dtype, + per_act_token_quant=quant_config.per_act_token_quant, + block_shape=quant_config.block_shape, + ) + if a1q_scale is not None and a1q_scale.numel() == 1: + a1q_scale = a1q_scale.view(1, 1) + a1_post_scale = None + else: + a1q = a1 + a1q_scale = None + a1_post_scale = quant_config.a1_scale + + return (lambda *args: None, + self._do_dispatch(tokens=a1q, + token_scales=a1q_scale, + rank_topk_ids=topk_ids, + rank_topk_weights=topk_weights, + num_experts=num_experts, + a1_scale=a1_post_scale, + quant_config=quant_config)) + + def prepare( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + (_, receiver) = self.prepare_async(a1, topk_weights, topk_ids, + num_experts, expert_map, + apply_router_weight_on_input, + quant_config) + return receiver() + + def _finalize( self, output: torch.Tensor, fused_expert_output: torch.Tensor, @@ -199,7 +248,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): topk_ids: torch.Tensor, apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, - ) -> None: + do_async: bool, + ) -> Optional[Callable]: assert self.handle is not None @@ -222,7 +272,46 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): topk_weights=None, config=self._get_combine_config(), previous_event=None, - async_finish=False, + async_finish=do_async, allocate_on_comm_stream=False) - # Respect inplace outputs. - output.copy_(combined_x, non_blocking=True) + + if do_async: + + def _receiver(): + event.current_stream_wait() + # Respect inplace outputs. + output.copy_(combined_x, non_blocking=True) + + return lambda: _receiver() + else: + # Respect inplace outputs. 
+ output.copy_(combined_x, non_blocking=True) + return None + + def finalize_async( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> Callable: + receiver = self._finalize(output, fused_expert_output, topk_weights, + topk_ids, apply_router_weight_on_input, + weight_and_reduce_impl, True) + assert receiver is not None + return receiver + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: + self._finalize(output, fused_expert_output, topk_weights, topk_ids, + apply_router_weight_on_input, weight_and_reduce_impl, + False) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 93ac11fb4bfbf..101fc8798c427 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union +from typing import Callable, Optional, Union import deep_ep import torch @@ -11,6 +11,8 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input, normalize_batched_scales_shape) +from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled, + dbo_maybe_run_recv_hook) # DeepEP kernels quantize dispatch inputs in 128 element chunks. DEEPEP_QUANT_BLOCK_SIZE = 128 @@ -55,7 +57,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # The dispatch function returns a handle that the combine function # requires. We store the handle here so it is available to the # combine function. - self.handle = None + self.handles: list[Optional[tuple]] = [None, None] self.num_dispatchers_ = num_dispatchers def num_dispatchers(self) -> int: @@ -74,16 +76,13 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def _do_quant( self, x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], a1_dtype: torch.dtype, - quant_dtype: Union[torch.dtype, str, None], - per_act_token_quant: bool, - block_shape: Optional[list[int]], + quant_config: FusedMoEQuantConfig, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - block_k = block_shape[1] if block_shape is not None else None if self.use_fp8_dispatch: + block_k = quant_config.block_shape[ + 1] if quant_config.block_shape is not None else None if block_k == DEEPEP_QUANT_BLOCK_SIZE: # DeepEP kernels did the quantization for us. 
x, x_scales = x @@ -99,44 +98,46 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # TODO (varun): Optimization - Use a batched version of quant x = x.view((-1, hidden_dim)) - x, x_scales = moe_kernel_quantize_input(x, a1_scale, quant_dtype, - per_act_token_quant, - block_shape) + x, x_scales = moe_kernel_quantize_input( + x, quant_config.a1_scale, quant_config.quant_dtype, + quant_config.per_act_token_quant, quant_config.block_shape) x = x.view((num_experts, -1, hidden_dim)) - if quant_dtype is not None: + if quant_config.quant_dtype is not None: assert x_scales is not None x_scales = normalize_batched_scales_shape(x_scales, num_experts) return x, x_scales - def prepare( + def supports_async(self) -> bool: + return True + + def prepare_async( self, a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> tuple[Callable, mk.ReceiverType]: hidden_size = a1.size(1) assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \ (f"Hidden Size {hidden_size} not in supported list of hidden sizes" f"{self.SUPPORTED_HIDDEN_SIZES}") + a2a_idx = dbo_current_ubatch_id() + if self.use_fp8_dispatch: assert hidden_size % 128 == 0, \ "DeepEP kernels quantize the inputs in blocks of shape 128" - has_per_token_scales = a1_scale.numel( - ) != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) + has_per_token_scales = quant_config.a1_scale.numel( + ) != 1 if quant_config.a1_scale is not None else ( + quant_config.a2_scale.numel() != 1 + if quant_config.a2_scale is not None else False) assert not has_per_token_scales, ( "low_latency kernels doesn't support dispatching per-token scales") @@ -148,23 +149,112 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): a1 = a1 * topk_weights.to(a1.dtype) # Dispatch - expert_x, expert_num_tokens, self.handle, event, hook = \ + expert_x, expert_num_tokens, handle, _, hook= \ self.buffer.low_latency_dispatch(a1, topk_ids, self.max_tokens_per_rank, num_experts, use_fp8=self.use_fp8_dispatch, async_finish=False, - return_recv_hook=False) + return_recv_hook=True) + self.handles[a2a_idx] = handle - expert_x, expert_x_scale = self._do_quant( - expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype, - quant_config.per_act_token_quant, quant_config.block_shape) + return ( + hook, + lambda: self._receiver(expert_x, expert_num_tokens, quant_config. 
+ a1_scale, a1.dtype, quant_config)) + + def _receiver( + self, + expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + expert_num_tokens: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a1_dtype: torch.dtype, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + expert_x, expert_x_scale = self._do_quant(expert_x, a1_dtype, + quant_config) expert_tokens_meta = mk.ExpertTokensMetadata( expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None) - return (expert_x, expert_x_scale, expert_tokens_meta, None, None) + return expert_x, expert_x_scale, expert_tokens_meta, None, None + + def prepare( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + hook, receiver = self.prepare_async(a1, topk_weights, topk_ids, + num_experts, expert_map, + apply_router_weight_on_input, + quant_config) + hook() + return receiver() + + def _finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + do_async: bool, + ) -> Optional[Callable]: + assert isinstance( + weight_and_reduce_impl, TopKWeightAndReduceDelegate + ), ("Weight application and reduction happens in the combine kernel.") + + a2a_idx = dbo_current_ubatch_id() + do_recv_hook = dbo_enabled() or do_async + handle = self.handles[a2a_idx] + assert handle is not None + + combine_topk_weights = topk_weights + if apply_router_weight_on_input: + # weights have already been applied. + combine_topk_weights = torch.ones_like(topk_weights) + + # TODO (varun) : Enable zero copy mode + dbo_maybe_run_recv_hook() + _, _, recv_hook = self.buffer.low_latency_combine( + fused_expert_output, + topk_ids, + combine_topk_weights, + handle, + async_finish=False, + zero_copy=False, + return_recv_hook=do_recv_hook, + out=output) + + return recv_hook + + def finalize_async( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> Callable: + recv_hook = self._finalize( + output, + fused_expert_output, + topk_weights, + topk_ids, + apply_router_weight_on_input, + weight_and_reduce_impl, + do_async=True, + ) + assert recv_hook is not None + return recv_hook def finalize( self, @@ -175,23 +265,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, ) -> None: - assert isinstance( - weight_and_reduce_impl, TopKWeightAndReduceDelegate - ), ("Weight application and reduction happens in the combine kernel.") - assert self.handle is not None - - combine_topk_weights = topk_weights - if apply_router_weight_on_input: - # weights have already been applied. 
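Both DeepEP prepare/finalize classes now advertise supports_async() and split their work into a launch step and a completion callable. A rough sketch of how a caller that understands this protocol might overlap communication with other work; pf and the tensors are placeholders, and the synchronous prepare()/finalize() wrappers in the diff are just this pattern with the waits performed immediately:

# pf: a DeepEP prepare/finalize instance; every other name is a placeholder.
if pf.supports_async():
    hook, receiver = pf.prepare_async(
        a1, topk_weights, topk_ids, num_experts, expert_map,
        apply_router_weight_on_input, quant_config)
    # ... unrelated work can run while the dispatch is in flight ...
    hook()        # wait on the recv hook (a no-op lambda for the HT variant)
    (expert_x, expert_x_scale, expert_tokens_meta,
     expert_topk_ids, expert_topk_weights) = receiver()
else:
    (expert_x, expert_x_scale, expert_tokens_meta,
     expert_topk_ids, expert_topk_weights) = pf.prepare(
         a1, topk_weights, topk_ids, num_experts, expert_map,
         apply_router_weight_on_input, quant_config)

# ... run the fused experts on expert_x ...

wait = pf.finalize_async(output, fused_expert_output, topk_weights, topk_ids,
                         apply_router_weight_on_input, weight_and_reduce_impl)
# ... more overlappable work ...
wait()            # output is only valid once the returned callable finishes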
- combine_topk_weights = torch.ones_like(topk_weights) - - # TODO (varun) : Enable zero copy mode - _, event, hook = self.buffer.low_latency_combine( + self._finalize( + output, fused_expert_output, + topk_weights, topk_ids, - combine_topk_weights, - self.handle, - async_finish=False, - zero_copy=False, - return_recv_hook=False, - out=output) + apply_router_weight_on_input, + weight_and_reduce_impl, + do_async=False, + ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index feab3f74cac53..6eeec18a6ec87 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union +from typing import Optional import torch @@ -44,33 +44,20 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - g1_alphas: torch.Tensor, - g2_alphas: torch.Tensor, - a1_gscale: torch.Tensor, - a2_gscale: torch.Tensor, out_dtype: torch.dtype, - quant_dtype: Union[torch.dtype, str, None], + quant_config: FusedMoEQuantConfig, ep_rank: int = 0, ep_size: int = 1, tp_rank: int = 0, tp_size: int = 1, ): - super().__init__( - FusedMoEQuantConfig( - quant_dtype=quant_dtype, - per_act_token_quant=False, - block_shape=None, - )) - assert quant_dtype in ("nvfp4", torch.float8_e4m3fn), ( + super().__init__(quant_config) + assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn), ( "Only nvfp4,fp8 quantization are currently supported.") self.ep_rank = ep_rank self.ep_size = ep_size self.tp_rank = tp_rank self.tp_size = tp_size - self.g1_alphas = g1_alphas - self.g2_alphas = g2_alphas - self.a1_gscale = a1_gscale - self.a2_gscale = a2_gscale self.out_dtype = out_dtype @property @@ -141,12 +128,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], # Not used workspace13: Optional[torch.Tensor], workspace2: Optional[torch.Tensor], expert_tokens_meta: Optional[mk.ExpertTokensMetadata], @@ -162,17 +144,17 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): fc2_expert_weights = w2 else: # Ensure w1_scale and w2_scale are not None before calling view - assert w1_scale is not None and w2_scale is not None, ( + assert self.w1_scale is not None and self.w2_scale is not None, ( "w1_scale and w2_scale must not " "be None for FlashInferExperts") # Flashinfer CUTLASS kernel takes scalar global scales, # min because inv_scale. 
quant_scales = [ self.a1_gscale, - w1_scale.view(torch.int32), + self.w1_scale.view(torch.int32), self.g1_alphas, self.a2_gscale, - w2_scale.view(torch.int32), + self.w2_scale.view(torch.int32), self.g2_alphas, ] # FlashInfer API requires weight to be long for nvfp4 @@ -202,12 +184,7 @@ def flashinfer_cutlass_moe_fp4( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - g1_alphas: torch.Tensor, - g2_alphas: torch.Tensor, - a1_gscale: torch.Tensor, - a2_gscale: torch.Tensor, + quant_config: FusedMoEQuantConfig, inplace: bool = False, activation: str = "silu", global_num_experts: int = -1, @@ -216,15 +193,10 @@ def flashinfer_cutlass_moe_fp4( ) -> torch.Tensor: fused_experts = mk.FusedMoEModularKernel( - FlashInferCutlassMoEPrepareAndFinalize(use_dp=False, - a1_gscale=a1_gscale), + FlashInferCutlassMoEPrepareAndFinalize(use_dp=False), FlashInferExperts( - g1_alphas=g1_alphas, - g2_alphas=g2_alphas, - a1_gscale=a1_gscale, - a2_gscale=a2_gscale, out_dtype=hidden_states.dtype, - quant_dtype="nvfp4", + quant_config=quant_config, )) return fused_experts( @@ -237,7 +209,5 @@ def flashinfer_cutlass_moe_fp4( activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, apply_router_weight_on_input=apply_router_weight_on_input, ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 061b02172c446..8c7eff59f3cd1 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -22,13 +22,11 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def __init__( self, use_dp: bool, - a1_gscale: Optional[torch.Tensor], num_dispatchers: int = 1, ): super().__init__() self.num_dispatchers_ = num_dispatchers self.use_dp = use_dp - self.a1_gscale = a1_gscale self.local_tokens = None @property @@ -47,18 +45,13 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def prepare( self, a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], # Not used - a2_scale: Optional[torch.Tensor], # Not used topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, - # TODO(bnell): use quant_config + scales instead of ctor args quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> mk.PrepareResultType: if apply_router_weight_on_input: topk = topk_ids.size(1) @@ -69,7 +62,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): a1q, a1q_scale = moe_kernel_quantize_input( a1, - self.a1_gscale, + quant_config.a1_gscale, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py new file mode 100644 index 0000000000000..e358143fac7c7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -0,0 +1,185 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import List # noqa: UP035 +from typing import Optional + +import torch + +from 
vllm.model_executor.layers.fused_moe.utils import ( + moe_kernel_quantize_input) +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + calculate_tile_tokens_dim) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) +from vllm.utils import direct_register_custom_op + + +def flashinfer_fused_moe_blockscale_fp8( + routing_logits: torch.Tensor, + routing_bias: torch.Tensor, + x: torch.Tensor, + w13_weight: torch.Tensor, + w13_weight_scale_inv: torch.Tensor, + w2_weight: torch.Tensor, + w2_weight_scale_inv: torch.Tensor, + global_num_experts: int, + top_k: int, + num_expert_group: int, + topk_group: int, + intermediate_size: int, + expert_offset: int, + local_num_experts: int, + block_shape: List[int], #noqa: UP006 + routed_scaling: float = 1.0) -> torch.Tensor: + from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe + assert top_k <= global_num_experts + assert top_k <= 8 + assert topk_group <= 4 + assert global_num_experts > num_expert_group + assert global_num_experts % num_expert_group == 0 + assert global_num_experts % 4 == 0 + assert top_k < (topk_group * global_num_experts / num_expert_group) + assert block_shape == [128, 128] + + a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1]) + # NOTE: scales of hidden states have to be transposed! + a_sf_t = a_sf.t().contiguous() + return flashinfer_trtllm_fp8_block_scale_moe( + routing_logits=routing_logits, + routing_bias=routing_bias, + hidden_states=a_q, + hidden_states_scale=a_sf_t, + gemm1_weights=w13_weight, + gemm1_weights_scale=w13_weight_scale_inv, + gemm2_weights=w2_weight, + gemm2_weights_scale=w2_weight_scale_inv, + num_experts=global_num_experts, + top_k=top_k, + n_group=num_expert_group, + topk_group=topk_group, + intermediate_size=intermediate_size, + local_expert_offset=expert_offset, + local_num_experts=local_num_experts, + routed_scaling_factor=routed_scaling, + tile_tokens_dim=calculate_tile_tokens_dim(x.shape[0], top_k, + global_num_experts), + routing_method_type=2, # DeepSeek-styled routing method + use_shuffled_weight=False, + ) + + +def flashinfer_fused_moe_blockscale_fp8_fake( + routing_logits: torch.Tensor, + routing_bias: torch.Tensor, + x: torch.Tensor, + w13_weight: torch.Tensor, + w13_weight_scale_inv: torch.Tensor, + w2_weight: torch.Tensor, + w2_weight_scale_inv: torch.Tensor, + global_num_experts: int, + top_k: int, + num_expert_group: int, + topk_group: int, + intermediate_size: int, + expert_offset: int, + local_num_experts: int, + block_shape: list[int], + routed_scaling: float = 1.0) -> torch.Tensor: + return torch.empty_like(x) + + +# TODO(bnell): Does this really need to be a torch.op? 
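For a concrete sense of the routing constraints asserted above, DeepSeek-V3 style routing parameters satisfy all of them; the following standalone check is illustrative and not part of the diff:

# 256 routed experts split into 8 groups, 4 groups kept, top-8 experts,
# and the 128x128 weight block shape this kernel requires.
global_num_experts, num_expert_group, topk_group, top_k = 256, 8, 4, 8
block_shape = [128, 128]

assert top_k <= global_num_experts and top_k <= 8 and topk_group <= 4
assert global_num_experts > num_expert_group
assert global_num_experts % num_expert_group == 0                   # 256 % 8 == 0
assert global_num_experts % 4 == 0
assert top_k < topk_group * global_num_experts / num_expert_group   # 8 < 128
assert block_shape == [128, 128]

Callers normally reach this code through the custom op registered just below, i.e. torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(...), assuming vLLM's usual custom-op namespace.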
+direct_register_custom_op( + op_name="flashinfer_fused_moe_blockscale_fp8", + op_func=flashinfer_fused_moe_blockscale_fp8, + mutates_args=[], + fake_impl=flashinfer_fused_moe_blockscale_fp8_fake, + tags=(torch.Tag.needs_fixed_stride_order, ), +) + + +def flashinfer_fused_moe_per_tensor_scale_fp8( + routing_logits: torch.Tensor, + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + input_scale: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm2_weights: torch.Tensor, + output1_scales_scalar: torch.Tensor, + output1_scales_gate_scalar: torch.Tensor, + output2_scales_scalar: torch.Tensor, + num_experts: int, + top_k: int, + num_expert_group: Optional[int], + topk_group: Optional[int], + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + use_routing_scales_on_input: bool, + routing_method_type: int, + routed_scaling_factor: float = 1.0) -> torch.Tensor: + num_expert_group = num_expert_group if num_expert_group is not None else 0 + topk_group = topk_group if topk_group is not None else 0 + + quant_hidden_states, _ = moe_kernel_quantize_input( + hidden_states, + input_scale, + quant_dtype=torch.float8_e4m3fn, + per_act_token_quant=False) + + from vllm.utils.flashinfer import ( + flashinfer_trtllm_fp8_per_tensor_scale_moe) + return flashinfer_trtllm_fp8_per_tensor_scale_moe( + routing_logits=routing_logits, + routing_bias=routing_bias, + hidden_states=quant_hidden_states, + gemm1_weights=gemm1_weights, + output1_scales_scalar=output1_scales_scalar, + output1_scales_gate_scalar=output1_scales_gate_scalar, + gemm2_weights=gemm2_weights, + output2_scales_scalar=output2_scales_scalar, + num_experts=num_experts, + top_k=top_k, + n_group=num_expert_group, + topk_group=topk_group, + intermediate_size=intermediate_size, + local_expert_offset=local_expert_offset, + local_num_experts=local_num_experts, + routed_scaling_factor=routed_scaling_factor, + use_routing_scales_on_input=use_routing_scales_on_input, + tile_tokens_dim=calculate_tile_tokens_dim(hidden_states.shape[0], + top_k, num_experts), + routing_method_type=routing_method_type) + + +def flashinfer_fused_moe_per_tensor_scale_fp8_fake( + routing_logits: torch.Tensor, + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + input_scale: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm2_weights: torch.Tensor, + output1_scales_scalar: torch.Tensor, + output1_scales_gate_scalar: torch.Tensor, + output2_scales_scalar: torch.Tensor, + num_experts: int, + top_k: int, + num_expert_group: Optional[int], + topk_group: Optional[int], + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + use_routing_scales_on_input: bool, + routing_method_type: int, + routed_scaling_factor: float = 1.0) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +# TODO(bnell): Does this really need to be a torch.op? 
+direct_register_custom_op( + op_name="flashinfer_fused_moe_per_tensor_scale_fp8", + op_func=flashinfer_fused_moe_per_tensor_scale_fp8, + mutates_args=["hidden_states"], + fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake, + tags=(torch.Tag.needs_fixed_stride_order, ), +) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index b46f4be4b912e..fe6ac458a9593 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -8,7 +8,7 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_moe import ( - get_config_dtype_str, try_get_optimal_moe_config) + try_get_optimal_moe_config) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, TopKWeightAndReduceNaiveBatched) from vllm.model_executor.layers.fused_moe.utils import ( @@ -498,17 +498,13 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def prepare( self, a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> mk.PrepareResultType: assert a1.dim() == 2 assert topk_ids.dim() == 2 assert topk_ids.size(0) == a1.size(0) @@ -547,14 +543,13 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): dtype=torch.float32, device=a1.device) else: - assert a1_scale is None + assert quant_config.a1_scale is None b_a1_scale = None first_expert = num_local_experts * self.rank last_expert = first_expert + num_local_experts - a1_scale = normalize_scales_shape(a1_scale) - a2_scale = normalize_scales_shape(a2_scale) + a1_scale = normalize_scales_shape(quant_config.a1_scale) for expert_id in range(first_expert, last_expert): topks = torch.any(topk_ids == expert_id, dim=1).flatten() @@ -625,28 +620,13 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): self, max_num_tokens: int, num_dispatchers: int, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - block_shape: Optional[list[int]] = None, - per_act_token_quant: bool = False, + quant_config: FusedMoEQuantConfig, ): - super().__init__( - FusedMoEQuantConfig.make( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - )) - assert not use_int8_w8a8, "NYI" - assert not use_int8_w8a16, "NYI" - assert not use_int4_w4a16, "NYI" - assert not use_mxfp4_w4a4, "NYI" + super().__init__(quant_config) + assert not self.quant_config.use_int8_w8a8, "NYI" + assert not self.quant_config.use_int8_w8a16, "NYI" + assert not self.quant_config.use_int4_w4a16, "NYI" + assert not self.quant_config.use_mxfp4_w4a4, "NYI" self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers @@ -707,12 +687,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: 
int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], @@ -742,10 +717,10 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): tmp = _resize_cache(workspace2, (num, N)) if self.quant_config.is_quantized: - assert a1q_scale is not None and w1_scale is not None + assert a1q_scale is not None and self.w1_scale is not None input = self.dequant(hidden_states[expert, :, :], a1q_scale[expert]) - w1_dq = self.dequant(w1[expert], w1_scale[expert]) + w1_dq = self.dequant(w1[expert], self.w1_scale[expert]) input = input[:num] @ w1_dq.transpose(0, 1) else: input = hidden_states[expert, :num, :] @ w1[expert].transpose( @@ -754,8 +729,8 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): self.activation(activation, tmp, input.to(tmp.dtype)) if self.quant_config.is_quantized: - assert w2_scale is not None - w2_dq = self.dequant(w2[expert], w2_scale[expert]) + assert self.w2_scale is not None + w2_dq = self.dequant(w2[expert], self.w2_scale[expert]) else: w2_dq = w2[expert] @@ -842,35 +817,15 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): self, max_num_tokens: int, num_dispatchers: int, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_act_token_quant: bool = False, - block_shape: Optional[list[int]] = None, + quant_config: FusedMoEQuantConfig, ): - super().__init__( - FusedMoEQuantConfig.make( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - )) - assert not use_int8_w8a8, "NYI" - assert not use_int8_w8a16, "NYI" - assert not use_int4_w4a16, "NYI" - assert not use_mxfp4_w4a4, "NYI" + super().__init__(quant_config) + assert not self.quant_config.use_int8_w8a8, "NYI" + assert not self.quant_config.use_int8_w8a16, "NYI" + assert not self.quant_config.use_int4_w4a16, "NYI" + assert not self.quant_config.use_mxfp4_w4a4, "NYI" assert max_num_tokens > 0 assert num_dispatchers > 0 - self.use_fp8_w8a8 = use_fp8_w8a8 - self.use_int8_w8a8 = use_int8_w8a8 - self.use_int4_w4a16 = use_int4_w4a16 - self.use_int8_w8a16 = use_int8_w8a16 - self.use_mxfp4_w4a4 = use_mxfp4_w4a4 self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers @@ -923,19 +878,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, ): # Check constraints. 
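The batched experts classes above likewise drop their per-flag constructor arguments in favour of a single FusedMoEQuantConfig. A minimal unquantized instantiation, using the FUSED_MOE_UNQUANTIZED_CONFIG constant imported elsewhere in this diff (the numeric arguments are arbitrary placeholders); a quantized setup would instead pass e.g. an fp8_w8a8_moe_quant_config(...):

from vllm.model_executor.layers.fused_moe.config import (
    FUSED_MOE_UNQUANTIZED_CONFIG)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    NaiveBatchedExperts)

experts = NaiveBatchedExperts(max_num_tokens=256, num_dispatchers=1,
                              quant_config=FUSED_MOE_UNQUANTIZED_CONFIG)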
- if self.use_int4_w4a16: + if self.quant_config.use_int4_w4a16: assert hidden_states.size(-1) // 2 == w1.size(2), ( "Hidden size mismatch") else: @@ -960,11 +910,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): assert w1.size(0) == E assert w2.size(0) == E - config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, - use_mxfp4_w4a4=self.use_mxfp4_w4a4, - dtype=hidden_states.dtype) + config_dtype = self.quant_config.config_name(hidden_states.dtype) config = try_get_optimal_moe_config( w1.size(), @@ -994,7 +940,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): intermediate_cache2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2)) - if self.use_fp8_w8a8: + # TODO(bnell): should this be done for any quantized type? + if self.quant_config.use_fp8_w8a8: intermediate_cache1.fill_(0) a1q_scale = normalize_batched_scales_shape(a1q_scale, E) @@ -1007,11 +954,11 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): expert_num_tokens=expert_num_tokens, compute_type=compute_type, A_scale=a1q_scale, - B_scale=w1_scale, - B_zp=w1_zp, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, + B_scale=self.w1_scale, + B_zp=self.w1_zp, + use_fp8_w8a8=self.quant_config.use_fp8_w8a8, + use_int8_w8a16=self.quant_config.use_int8_w8a16, + use_int4_w4a16=self.quant_config.use_int4_w4a16, config=config, per_act_token_quant=self.per_act_token_quant, block_shape=self.block_shape) @@ -1023,7 +970,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): intermediate_cache1.view(-1, N)) qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input( - intermediate_cache2, a2_scale, max_num_tokens, E, N, + intermediate_cache2, self.a2_scale, max_num_tokens, E, N, expert_num_tokens, self.quant_dtype, self.per_act_token_quant, self.block_shape) @@ -1034,11 +981,11 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): expert_num_tokens=expert_num_tokens, compute_type=compute_type, A_scale=a2q_scale, - B_scale=w2_scale, - B_zp=w2_zp, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, + B_scale=self.w2_scale, + B_zp=self.w2_zp, + use_fp8_w8a8=self.quant_config.use_fp8_w8a8, + use_int8_w8a16=self.quant_config.use_int8_w8a16, + use_int4_w4a16=self.quant_config.use_int4_w4a16, config=config, per_act_token_quant=self.per_act_token_quant, block_shape=self.block_shape) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 1e3ac6cd79f68..cc51ecdbacdef 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -161,6 +161,9 @@ def fused_marlin_moe(hidden_states: torch.Tensor, if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) + elif activation == "gelu": + torch.ops._C.gelu_and_mul(intermediate_cache2, + intermediate_cache1.view(-1, 2 * N)) elif activation == "swigluoai": # alpha = 1.702, limit = 7.0 torch.ops._C.swigluoai_and_mul(intermediate_cache2, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index eb3e14180ecfe..d4de3f640865e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1,13 +1,13 @@ # 
SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Fused MoE kernel.""" +"""Fused MoE Triton kernels.""" import functools import json import os # torch.compile needs typing.List. It will fail torch.library.infer_schema # otherwise from typing import List # noqa: UP035 -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch import torch.nn.functional as F @@ -18,7 +18,7 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger # yapf: disable from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEQuantConfig, get_config_quant_dtype) + FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig, _get_config_dtype_str) from vllm.model_executor.layers.fused_moe.cutlass_moe import ( _valid_cutlass_block_scaled_grouped_gemm, run_cutlass_block_scaled_fused_experts) @@ -32,9 +32,7 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP) from vllm.model_executor.layers.fused_moe.utils import ( - _resize_cache, moe_kernel_quantize_input, per_token_group_quant_fp8) -from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - calculate_tile_tokens_dim) + _resize_cache, activation_without_mul, moe_kernel_quantize_input) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( dequant_mxfp4) from vllm.platforms import current_platform @@ -534,7 +532,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, EM = sorted_token_ids.size(0) if A.size(0) < config["BLOCK_SIZE_M"]: # optimize for small batch_size. - # We assume that top_ids of each token is unique, so + # We assume that top_ids of each token is unique, # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, # and we can skip some invalid blocks. 
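The per-shape JSON tables, such as the E=512,N=64 NVIDIA_H200 file added at the start of this diff, are keyed by token count and consumed by get_moe_configs in the next hunk. A standalone sketch of the lookup; the nearest-key selection rule is assumed to match try_get_optimal_moe_config and is not copied from the diff:

import json

def pick_tuned_config(path: str, M: int) -> dict:
    """Return the tuned Triton config whose token-count key is closest to M."""
    with open(path) as f:
        tuned = json.load(f)
    tuned.pop("triton_version", None)   # same scrub as get_moe_configs below
    configs = {int(k): v for k, v in tuned.items()}
    return configs[min(configs, key=lambda k: abs(k - M))]

# e.g. for the H200 table above, M=20 selects the "16" entry and M=700 the "512" entry.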
EM = min(sorted_token_ids.size(0), @@ -720,7 +718,10 @@ def get_moe_configs( logger.info("Using configuration from %s for MoE layer.", config_file_path) # If a configuration has been found, return it - return {int(key): val for key, val in json.load(f).items()} + tuned_config = json.load(f) + # Delete triton_version from tuned_config + tuned_config.pop("triton_version", None) + return {int(key): val for key, val in tuned_config.items()} # If no optimized configuration is available, we will use the default # configuration @@ -1044,87 +1045,66 @@ def fused_grouped_topk( return topk_values.to(torch.float32), topk_indices.to(torch.int32) -def get_config_dtype_str( - dtype: torch.dtype, - use_int4_w4a16: Optional[bool] = False, - use_int8_w8a16: Optional[bool] = False, - use_fp8_w8a8: Optional[bool] = False, - use_mxfp4_w4a4: Optional[bool] = False) -> Optional[str]: - if use_fp8_w8a8: - return "fp8_w8a8" - elif use_int8_w8a16: - return "int8_w8a16" - elif use_int4_w4a16: - return "int4_w4a16" - elif use_mxfp4_w4a4: - return "mxfp4_w4a4" - elif dtype == torch.float: - # avoiding cases where kernel fails when float32 MoE - # use fp16/bfloat16 configs - return "float32" - return None - - def inplace_fused_experts( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str = "silu", - is_act_and_mul: bool = True, - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None, #noqa: UP006 - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None) -> None: + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, #noqa: UP006 + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, +) -> None: fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, - activation, is_act_and_mul, - apply_router_weight_on_input, use_fp8_w8a8, + activation, apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4, per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape, w1_bias, w2_bias) -def inplace_fused_experts_fake(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: 
torch.Tensor, - activation: str = "silu", - is_act_and_mul: bool = True, - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None) -> None: +def inplace_fused_experts_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, #noqa: UP006 + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, +) -> None: pass @@ -1138,175 +1118,6 @@ direct_register_custom_op( ) -def flashinfer_fused_moe_blockscale_fp8( - routing_logits: torch.Tensor, - routing_bias: torch.Tensor, - x: torch.Tensor, - w13_weight: torch.Tensor, - w13_weight_scale_inv: torch.Tensor, - w2_weight: torch.Tensor, - w2_weight_scale_inv: torch.Tensor, - global_num_experts: int, - top_k: int, - num_expert_group: int, - topk_group: int, - intermediate_size: int, - expert_offset: int, - local_num_experts: int, - block_shape: List[int], #noqa: UP006 - routed_scaling: float = 1.0) -> torch.Tensor: - from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe - assert top_k <= global_num_experts - assert top_k <= 8 - assert topk_group <= 4 - assert global_num_experts > num_expert_group - assert global_num_experts % num_expert_group == 0 - assert global_num_experts % 4 == 0 - assert top_k < (topk_group * global_num_experts / num_expert_group) - assert block_shape == [128, 128] - - a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1]) - # NOTE: scales of hidden states have to be transposed! 
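[Illustrative sketch, not part of the patch] The *_fake functions above and the direct_register_custom_op calls follow the standard PyTorch custom-op pattern: the real op does the work, while a fake ("meta") implementation only describes output shapes so torch.compile can trace through the op. A minimal standalone example of the same pattern using the plain torch.library API (requires PyTorch 2.4+; the op name here is made up):

    import torch

    @torch.library.custom_op("editor_demo::scale_rows", mutates_args=())
    def scale_rows(x: torch.Tensor, factor: float) -> torch.Tensor:
        # Real implementation, runs on concrete tensors.
        return x * factor

    @scale_rows.register_fake
    def _(x: torch.Tensor, factor: float) -> torch.Tensor:
        # Fake implementation, shape/dtype only, used during tracing.
        return torch.empty_like(x)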
- a_sf_t = a_sf.t().contiguous() - return flashinfer_trtllm_fp8_block_scale_moe( - routing_logits=routing_logits, - routing_bias=routing_bias, - hidden_states=a_q, - hidden_states_scale=a_sf_t, - gemm1_weights=w13_weight, - gemm1_weights_scale=w13_weight_scale_inv, - gemm2_weights=w2_weight, - gemm2_weights_scale=w2_weight_scale_inv, - num_experts=global_num_experts, - top_k=top_k, - n_group=num_expert_group, - topk_group=topk_group, - intermediate_size=intermediate_size, - local_expert_offset=expert_offset, - local_num_experts=local_num_experts, - routed_scaling_factor=routed_scaling, - tile_tokens_dim=calculate_tile_tokens_dim(x.shape[0], top_k, - global_num_experts), - routing_method_type=2, # DeepSeek-styled routing method - use_shuffled_weight=False, - ) - - -def flashinfer_fused_moe_blockscale_fp8_fake( - routing_logits: torch.Tensor, - routing_bias: torch.Tensor, - x: torch.Tensor, - w13_weight: torch.Tensor, - w13_weight_scale_inv: torch.Tensor, - w2_weight: torch.Tensor, - w2_weight_scale_inv: torch.Tensor, - global_num_experts: int, - top_k: int, - num_expert_group: int, - topk_group: int, - intermediate_size: int, - expert_offset: int, - local_num_experts: int, - block_shape: list[int], - routed_scaling: float = 1.0) -> torch.Tensor: - return torch.empty_like(x) - - -direct_register_custom_op( - op_name="flashinfer_fused_moe_blockscale_fp8", - op_func=flashinfer_fused_moe_blockscale_fp8, - mutates_args=[], - fake_impl=flashinfer_fused_moe_blockscale_fp8_fake, - tags=(torch.Tag.needs_fixed_stride_order, ), -) - - -def flashinfer_fused_moe_per_tensor_scale_fp8( - routing_logits: torch.Tensor, - routing_bias: Optional[torch.Tensor], - hidden_states: torch.Tensor, - input_scale: torch.Tensor, - gemm1_weights: torch.Tensor, - gemm2_weights: torch.Tensor, - output1_scales_scalar: torch.Tensor, - output1_scales_gate_scalar: torch.Tensor, - output2_scales_scalar: torch.Tensor, - num_experts: int, - top_k: int, - num_expert_group: Optional[int], - topk_group: Optional[int], - intermediate_size: int, - local_expert_offset: int, - local_num_experts: int, - use_routing_scales_on_input: bool, - routing_method_type: int, - routed_scaling_factor: float = 1.0) -> torch.Tensor: - num_expert_group = num_expert_group if num_expert_group is not None else 0 - topk_group = topk_group if topk_group is not None else 0 - - quant_hidden_states, _ = moe_kernel_quantize_input( - hidden_states, - input_scale, - quant_dtype=torch.float8_e4m3fn, - per_act_token_quant=False) - - from vllm.utils.flashinfer import ( - flashinfer_trtllm_fp8_per_tensor_scale_moe) - return flashinfer_trtllm_fp8_per_tensor_scale_moe( - routing_logits=routing_logits, - routing_bias=routing_bias, - hidden_states=quant_hidden_states, - gemm1_weights=gemm1_weights, - output1_scales_scalar=output1_scales_scalar, - output1_scales_gate_scalar=output1_scales_gate_scalar, - gemm2_weights=gemm2_weights, - output2_scales_scalar=output2_scales_scalar, - num_experts=num_experts, - top_k=top_k, - n_group=num_expert_group, - topk_group=topk_group, - intermediate_size=intermediate_size, - local_expert_offset=local_expert_offset, - local_num_experts=local_num_experts, - routed_scaling_factor=routed_scaling_factor, - use_routing_scales_on_input=use_routing_scales_on_input, - tile_tokens_dim=calculate_tile_tokens_dim(hidden_states.shape[0], - top_k, num_experts), - routing_method_type=routing_method_type) - - -def flashinfer_fused_moe_per_tensor_scale_fp8_fake( - routing_logits: torch.Tensor, - routing_bias: Optional[torch.Tensor], - hidden_states: 
torch.Tensor, - input_scale: torch.Tensor, - gemm1_weights: torch.Tensor, - gemm2_weights: torch.Tensor, - output1_scales_scalar: torch.Tensor, - output1_scales_gate_scalar: torch.Tensor, - output2_scales_scalar: torch.Tensor, - num_experts: int, - top_k: int, - num_expert_group: Optional[int], - topk_group: Optional[int], - intermediate_size: int, - local_expert_offset: int, - local_num_experts: int, - use_routing_scales_on_input: bool, - routing_method_type: int, - routed_scaling_factor: float = 1.0) -> torch.Tensor: - pass - - -direct_register_custom_op( - op_name="flashinfer_fused_moe_per_tensor_scale_fp8", - op_func=flashinfer_fused_moe_per_tensor_scale_fp8, - mutates_args=["hidden_states"], - fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake, - tags=(torch.Tag.needs_fixed_stride_order, ), -) - - def outplace_fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -1314,7 +1125,6 @@ def outplace_fused_experts( topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str = "silu", - is_act_and_mul: bool = True, apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, @@ -1336,37 +1146,37 @@ def outplace_fused_experts( ) -> torch.Tensor: return fused_experts_impl( hidden_states, w1, w2, topk_weights, topk_ids, False, activation, - is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8, - use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4, - per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale, - w1_zp, w2_zp, a1_scale, a2_scale, block_shape, w1_bias, w2_bias) + apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, + use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4, per_channel_quant, + global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, + a1_scale, a2_scale, block_shape, w1_bias, w2_bias) def outplace_fused_experts_fake( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str = "silu", - is_act_and_mul: bool = True, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor: + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: return 
torch.empty_like(hidden_states) @@ -1398,45 +1208,36 @@ def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]: # TODO (bnell): replace this with modular op. Can get rid of inplace/outplace # torch ops. -def fused_experts(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - activation: str = "silu", - is_act_and_mul: bool = True, - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, - allow_deep_gemm: bool = False, - allow_cutlass_block_scaled_grouped_gemm: bool = False, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor: +def fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + quant_config: Optional[FusedMoEQuantConfig] = None, + allow_deep_gemm: bool = False, + allow_cutlass_block_scaled_grouped_gemm: bool = False, +) -> torch.Tensor: + + if quant_config is None: + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG + use_fp8_w8a8 = quant_config.use_fp8_w8a8 + # For now, disable DeepGemm for small N (<= 512) until better # permute/unpermute ops are available. # However, on B200, we use DeepGemm for all cases because they only support # E8M0 scale, which means we requantize the weight and input to the specific # scale. Fallen back to cutlass or triton for some cases would cause # accuracy issue. 
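[Illustrative sketch, not part of the patch] With the loose flags folded into quant_config (defaulting to FUSED_MOE_UNQUANTIZED_CONFIG), fused_experts() effectively picks its backend as sketched below; the boolean predicates are hypothetical stand-ins for is_deep_gemm_e8m0_used(), _valid_deep_gemm() and _valid_cutlass_block_scaled_grouped_gemm():

    def pick_moe_backend(use_fp8_w8a8: bool,
                         allow_deep_gemm: bool,
                         deep_gemm_ok: bool,
                         allow_cutlass_block_scaled: bool,
                         cutlass_ok: bool) -> str:
        # Preference order mirrors the branches that follow.
        if allow_deep_gemm and use_fp8_w8a8 and deep_gemm_ok:
            return "deep_gemm_moe_fp8"
        if allow_cutlass_block_scaled and use_fp8_w8a8 and cutlass_ok:
            return "cutlass_block_scaled_grouped_gemm"
        return "triton_fused_experts_impl"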
- if (allow_deep_gemm and use_fp8_w8a8 and + if (allow_deep_gemm and quant_config.use_fp8_w8a8 and (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))): + assert quant_config is not None assert apply_router_weight_on_input is False - assert is_act_and_mul, ( - "DeepGemm only supports is_act_and_mul=True for now.") return deep_gemm_moe_fp8( hidden_states=hidden_states, w1=w1, @@ -1447,22 +1248,23 @@ def fused_experts(hidden_states: torch.Tensor, activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, + w1_scale=quant_config.w1_scale, + w2_scale=quant_config.w2_scale, + a1_scale=quant_config.a1_scale, + a2_scale=quant_config.a2_scale, apply_router_weight_on_input=apply_router_weight_on_input, ) elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8 and _valid_cutlass_block_scaled_grouped_gemm( w1, w2, inplace, activation, apply_router_weight_on_input, expert_map)): + assert quant_config is not None return run_cutlass_block_scaled_fused_experts( a=hidden_states, w1=w1, w2=w2, - w1_scale=w1_scale, - w2_scale=w2_scale, + w1_scale=quant_config.w1_scale, + w2_scale=quant_config.w2_scale, topk_weights=topk_weights, topk_ids=topk_ids) else: @@ -1473,26 +1275,49 @@ def fused_experts(hidden_states: torch.Tensor, topk_weights=topk_weights, topk_ids=topk_ids, activation=activation, - is_act_and_mul=is_act_and_mul, apply_router_weight_on_input=apply_router_weight_on_input, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - per_channel_quant=per_channel_quant, + use_fp8_w8a8=quant_config.use_fp8_w8a8, + use_int8_w8a8=quant_config.use_int8_w8a8, + use_int8_w8a16=quant_config.use_int8_w8a16, + use_int4_w4a16=quant_config.use_int4_w4a16, + use_mxfp4_w4a4=quant_config.use_mxfp4_w4a4, + per_channel_quant=quant_config.per_act_token_quant, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=block_shape, - w1_bias=w1_bias, - w2_bias=w2_bias, - ) + w1_scale=quant_config.w1_scale, + w2_scale=quant_config.w2_scale, + w1_zp=quant_config.w1_zp, + w2_zp=quant_config.w2_zp, + a1_scale=quant_config.a1_scale, + a2_scale=quant_config.a2_scale, + block_shape=quant_config.block_shape, + w1_bias=quant_config.w1_bias, + w2_bias=quant_config.w2_bias) + + +SILU_NO_MUL: str = activation_without_mul("silu") +GELU_NO_MUL: str = activation_without_mul("gelu") + + +def _get_config_quant_dtype( + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_mxfp4_w4a4: bool, +) -> Union[None, torch.dtype, str]: + """ + Get the quantization type based on the quantization strategy flags. + We don't have a quant_config at this point so we need to work backwards. + A return type of None means no quantization is required because the + input is unquantized or has been quantized prior to calling + fused_experts_impl. 
+ """ + if use_fp8_w8a8: + return torch.float8_e4m3fn + elif use_int8_w8a8: + return torch.int8 + elif use_mxfp4_w4a4: + return "mxfp4" + return None def fused_experts_impl( @@ -1503,7 +1328,6 @@ def fused_experts_impl( topk_ids: torch.Tensor, inplace: bool = False, activation: str = "silu", - is_act_and_mul: bool = True, apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, @@ -1552,17 +1376,18 @@ def fused_experts_impl( # https://github.com/vllm-project/vllm/issues/5938 CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE M = min(num_tokens, CHUNK_SIZE) - config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - dtype=hidden_states.dtype) - qtype = get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4) + config_dtype = _get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, + dtype=hidden_states.dtype) + + # Note: for use_int8_w8a16 or use_int4_w4a16, the activations are + # quantized prior to calling fused_experts. + quant_dtype = _get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_mxfp4_w4a4=use_mxfp4_w4a4) get_config_func = functools.partial( try_get_optimal_moe_config, @@ -1635,7 +1460,7 @@ def fused_experts_impl( qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input( A=curr_hidden_states, A_scale=a1_scale, - quant_dtype=qtype, + quant_dtype=quant_dtype, per_act_token_quant=per_channel_quant, block_shape=block_shape) @@ -1666,30 +1491,29 @@ def fused_experts_impl( B_bias=w1_bias) # Activation function with multiplication - if activation == "silu" and is_act_and_mul: + if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - elif activation == "gelu" and is_act_and_mul: + elif activation == "gelu": torch.ops._C.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - elif activation == "swigluoai" and is_act_and_mul: + elif activation == "swigluoai": # alpha = 1.702, limit = 7.0 torch.ops._C.swigluoai_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) # Activation function without multiplication - elif activation == "silu": + elif activation == SILU_NO_MUL: intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N)) - elif activation == "gelu": + elif activation == GELU_NO_MUL: intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N)) else: - raise ValueError(f"Unsupported FusedMoe activation: {activation}, " - f"with is_act_and_mul={is_act_and_mul}.") + raise ValueError(f"Unsupported FusedMoe activation: {activation}.") qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( A=intermediate_cache2, A_scale=a2_scale, - quant_dtype=qtype, + quant_dtype=quant_dtype, per_act_token_quant=per_channel_quant, block_shape=block_shape) @@ -1721,164 +1545,13 @@ def fused_experts_impl( return out_hidden_states -def fused_moe( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - inplace: bool = False, - activation: str = "silu", - is_act_and_mul: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - use_fp8_w8a8: 
bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ - This function computes a Mixture of Experts (MoE) layer using two sets of - weights, w1 and w2, and top-k gating mechanism. - - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w1 (torch.Tensor): The first set of expert weights. - - w2 (torch.Tensor): The second set of expert weights. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - - activation (str): The activation function to apply after the first - MoE layer. - - is_act_and_mul (bool): If True, use activation-and-mul function for - activation (self-gated activation), otherwise use activation function - for activation (ungated activation). - - num_expert_group: Optional[int]: additional parameter for grouped_topk - - topk_group: Optional[int]: additional parameter for grouped_topk - - use_grouped_topk: If True, use grouped_topk instead of fused_topk - note: Deepseekv2 model uses grouped_topk - - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. - - use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. - - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 - activation to compute the inner products for w1 and w2. - Defaults to False. - - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 - activation to compute the inner products for w1 and w2. - Defaults to False. - - use_mxfp4_w4a4 (bool): If True, use matmul of OCP MXFP4 weight and - OCP MXFP4 activation to compute the inner products for w1 and w2. - Defaults to False. - - global_num_experts (int): The total number of experts in the global - expert space. - - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert - parallel shard. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for - w2. - - a1_scale (Optional[torch.Tensor]): Optional scale to be used for - a1. - - a2_scale (Optional[torch.Tensor]): Optional scale to be used for - a2. - - block_shape: (Optional[list[int]]): Optional block size for block-wise - quantization. - - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. 
- """ - if not is_act_and_mul: - assert inplace is False, ( - "is_act_and_mul=False is not supported with inplace=True") - - if use_grouped_topk: - assert num_expert_group is not None and topk_group is not None - topk_weights, topk_ids = grouped_topk(hidden_states, gating_output, - topk, renormalize, - num_expert_group, topk_group) - elif custom_routing_function is None: - topk_weights, topk_ids, token_expert_indices = fused_topk( - hidden_states, gating_output, topk, renormalize) - else: - topk_weights, topk_ids = custom_routing_function( - hidden_states, gating_output, topk, renormalize) - - return fused_experts(hidden_states, - w1, - w2, - topk_weights, - topk_ids, - inplace=inplace, - activation=activation, - is_act_and_mul=is_act_and_mul, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - per_channel_quant=per_channel_quant, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=block_shape, - w1_bias=w1_bias, - w2_bias=w2_bias) - - class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_act_token_quant: bool = False, - block_shape: Optional[list[int]] = None, + quant_config: FusedMoEQuantConfig, ): - super().__init__( - FusedMoEQuantConfig.make( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - )) - - self.use_fp8_w8a8 = use_fp8_w8a8 - self.use_int4_w4a16 = use_int4_w4a16 - self.use_int8_w8a8 = use_int8_w8a8 - self.use_int8_w8a16 = use_int8_w8a16 - self.use_mxfp4_w4a4 = use_mxfp4_w4a4 + super().__init__(quant_config) @property def activation_formats( @@ -1924,19 +1597,14 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, ): # Check constraints. 
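[Illustrative sketch, not part of the patch] The removed fused_moe() wrapper above selected between grouped_topk, fused_topk and a custom routing function before calling fused_experts(). As a reference for the default path, softmax-then-top-k routing with optional renormalization can be written in plain PyTorch like this (equivalent in semantics to fused_topk, ignoring the extra token_expert_indices output):

    import torch

    def reference_topk_routing(gating_output: torch.Tensor, topk: int,
                               renormalize: bool):
        # Router probabilities per token.
        probs = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
        topk_weights, topk_ids = torch.topk(probs, topk, dim=-1)
        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1,
                                                           keepdim=True)
        return topk_weights, topk_ids.to(torch.int32)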
- if self.use_int4_w4a16: + if self.quant_config.use_int4_w4a16: assert hidden_states.size(-1) // 2 == w1.size(2), ( "Hidden size mismatch") else: @@ -1959,17 +1627,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): if global_num_experts == -1: global_num_experts = E - config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, - use_mxfp4_w4a4=self.use_mxfp4_w4a4, - dtype=hidden_states.dtype) - config = try_get_optimal_moe_config( w1.size(), w2.size(), top_k_num, - config_dtype, + self.quant_config.config_name(hidden_states.dtype), num_tokens, block_shape=self.block_shape, ) @@ -2003,8 +1665,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): w1, intermediate_cache1, a1q_scale, - w1_scale, - w1_zp, + self.w1_scale, + self.w1_zp, None, # topk_weights sorted_token_ids, expert_ids, @@ -2013,13 +1675,13 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): top_k_num, config, compute_type=compute_type, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a8=self.use_int8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, + use_fp8_w8a8=self.quant_config.use_fp8_w8a8, + use_int8_w8a8=self.quant_config.use_int8_w8a8, + use_int8_w8a16=self.quant_config.use_int8_w8a16, + use_int4_w4a16=self.quant_config.use_int4_w4a16, per_channel_quant=self.per_act_token_quant, block_shape=self.block_shape, - B_bias=None # TODO support B_bias + B_bias=self.w1_bias, ) self.activation(activation, intermediate_cache2, @@ -2028,7 +1690,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): a2q_scale: Optional[torch.Tensor] = None qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( - intermediate_cache2, a2_scale, self.quant_dtype, + intermediate_cache2, self.a2_scale, self.quant_dtype, self.per_act_token_quant, self.block_shape) invoke_fused_moe_kernel( @@ -2036,8 +1698,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): w2, intermediate_cache3, a2q_scale, - w2_scale, - w2_zp, + self.w2_scale, + self.w2_zp, topk_weights, sorted_token_ids, expert_ids, @@ -2046,36 +1708,21 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): 1, config, compute_type=compute_type, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a8=self.use_int8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, + use_fp8_w8a8=self.quant_config.use_fp8_w8a8, + use_int8_w8a8=self.quant_config.use_int8_w8a8, + use_int8_w8a16=self.quant_config.use_int8_w8a16, + use_int4_w4a16=self.quant_config.use_int4_w4a16, per_channel_quant=self.per_act_token_quant, block_shape=self.block_shape, - B_bias=None # TODO support B_bias + B_bias=self.w2_bias, ) ops.moe_sum(intermediate_cache3, output) def modular_triton_fused_moe( - use_fp8_w8a8: bool, - use_int8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - use_mxfp4_w4a4: bool, - per_act_token_quant: bool, - block_shape: Optional[list[int]] = None, -) -> mk.FusedMoEModularKernel: + quant_config: FusedMoEQuantConfig) -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), - TritonExperts( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - ), + TritonExperts(quant_config), ) diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py 
b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index 312befe2c1d71..614a83ad1158c 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -1,11 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Optional +from typing import Optional import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import ( + FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate) from vllm.utils import has_triton_kernels @@ -23,9 +25,6 @@ if has_triton_kernels(): "Failed to import Triton kernels. Please make sure your triton " "version is compatible.") -if TYPE_CHECKING: - from triton_kernels.matmul_ogs import PrecisionConfig - def triton_kernel_moe_forward( hidden_states: torch.Tensor, @@ -35,20 +34,10 @@ def triton_kernel_moe_forward( topk: int, renormalize: bool, activation: str = "silu", + quant_config: Optional[FusedMoEQuantConfig] = None, apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None, - w1_precision: Optional["PrecisionConfig"] = None, - w2_precision: Optional["PrecisionConfig"] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, ) -> torch.Tensor: routing_data, gather_idx, scatter_idx = routing(gating_output, @@ -64,20 +53,10 @@ def triton_kernel_moe_forward( gather_idx, scatter_idx, activation=activation, + quant_config=quant_config, apply_router_weight_on_input=apply_router_weight_on_input, - use_fp8_w8a8=use_fp8_w8a8, - per_channel_quant=per_channel_quant, global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_bias=w1_bias, - w2_bias=w2_bias, - w1_precision=w1_precision, - w2_precision=w2_precision, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=block_shape) + expert_map=expert_map) # This is a triton implementation of the fused_experts function @@ -90,28 +69,23 @@ def triton_kernel_fused_experts( gather_indx, # GatherIndx scatter_indx, # ScatterIndx activation: str = "silu", + quant_config: Optional[FusedMoEQuantConfig] = None, swiglu_alpha: float = 1.702, swiglu_limit: float = 7.0, apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None, - w1_precision: Optional["PrecisionConfig"] = None, - w2_precision: Optional["PrecisionConfig"] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, + a1q_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if quant_config is None: + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG # type check, uint8 means mxfp4 assert 
hidden_states.dtype == torch.bfloat16 - assert w1_bias is None or w1_bias.dtype == torch.float32 - assert w2_bias is None or w2_bias.dtype == torch.float32 + assert (quant_config.w1_bias is None + or quant_config.w1_bias.dtype == torch.float32) + assert (quant_config.w2_bias is None + or quant_config.w2_bias.dtype == torch.float32) # Shape check, only check non-mxfp4 assert hidden_states.shape[-1] == w1.shape[-2] @@ -130,20 +104,20 @@ def triton_kernel_fused_experts( intermediate_cache1 = matmul_ogs( hidden_states, w1, - w1_bias, + quant_config.w1_bias, routing_data, gather_indx=gather_indx, - precision_config=w1_precision, + precision_config=quant_config.w1_precision, gammas=gammas if apply_router_weight_on_input else None, fused_activation=act) intermediate_cache3 = matmul_ogs( intermediate_cache1, w2, - w2_bias, + quant_config.w2_bias, routing_data, scatter_indx=scatter_indx, - precision_config=w2_precision, + precision_config=quant_config.w2_precision, gammas=None if apply_router_weight_on_input else gammas, y=output_tensor, ) @@ -154,21 +128,13 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - quant_config, max_num_tokens: int, num_dispatchers: int, - w1_precision: "PrecisionConfig", - w2_precision: "PrecisionConfig", - w1_bias: Optional[torch.Tensor], - w2_bias: Optional[torch.Tensor], + quant_config: FusedMoEQuantConfig, ): super().__init__(quant_config) self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers - self.w1_precision = w1_precision - self.w2_precision = w2_precision - self.w1_bias = w1_bias - self.w2_bias = w2_bias @property def activation_formats( @@ -212,12 +178,7 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], @@ -228,20 +189,12 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): hidden_states, w1, w2, - None, - None, - None, + routing_data=None, + gather_indx=None, + scatter_indx=None, activation=activation, + quant_config=self.quant_config, apply_router_weight_on_input=False, - use_fp8_w8a8=False, - per_channel_quant=False, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_bias=self.w1_bias, - w2_bias=self.w2_bias, - w1_precision=self.w1_precision, - w2_precision=self.w2_precision, - a1_scale=a1q_scale, - a2_scale=a2_scale) + a1q_scale=a1q_scale) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 3a2c9cbaf459e..da513d75da4da 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -4,7 +4,7 @@ from abc import abstractmethod from collections.abc import Iterable from enum import Enum -from typing import Callable, Literal, Optional, overload +from typing import Callable, Literal, Optional, Union, get_args, overload import torch import torch.nn.functional as F @@ -12,6 +12,7 @@ from torch.nn.parameter import UninitializedParameter import vllm.envs as envs from vllm.config import get_current_vllm_config +from vllm.config.parallel import ExpertPlacementStrategy from vllm.distributed import (get_dp_group, 
get_ep_group, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -21,7 +22,8 @@ from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp # yapf: disable from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEConfig, FusedMoEParallelConfig) + FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEConfig, FusedMoEParallelConfig, + FusedMoEQuantConfig, biased_moe_quant_config) # yapf: enable from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEActivationFormat, FusedMoEModularKernel, @@ -35,8 +37,9 @@ from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx, +from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx, round_up) +from vllm.v1.worker.ubatching import dbo_current_ubatch_id if current_platform.is_cuda_alike(): from .fused_batched_moe import BatchedTritonExperts @@ -76,11 +79,11 @@ class FusedMoeWeightScaleSupported(Enum): class FusedMoEMethodBase(QuantizeMethodBase): - # TODO(bnell): also pass quant_config? def __init__(self, moe: FusedMoEConfig): super().__init__() self.moe = moe - self.fused_experts: Optional[Callable] = None + self.moe_quant_config: Optional[FusedMoEQuantConfig] = None + self.fused_experts: Optional[FusedMoEModularKernel] = None self.topk_indices_dtype = None @abstractmethod @@ -101,23 +104,28 @@ class FusedMoEMethodBase(QuantizeMethodBase): @staticmethod def _maybe_make_prepare_finalize( - moe: FusedMoEConfig, ) -> Optional[FusedMoEPrepareAndFinalize]: + moe: FusedMoEConfig, + quant_config: Optional[FusedMoEQuantConfig], + ) -> Optional[FusedMoEPrepareAndFinalize]: all2all_manager = get_ep_group().device_communicator.all2all_manager assert all2all_manager is not None prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None + # TODO: could allow this now assert not moe.use_flashinfer_cutlass_kernels, \ "Must be created in modelopt.py" if moe.use_pplx_kernels: + assert quant_config is not None + hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes( moe.max_num_tokens, moe.hidden_dim, moe.in_dtype, - moe.quant_dtype, - per_act_token_quant=moe.per_act_token_quant, - block_shape=moe.block_shape, + quant_config.quant_dtype, + per_act_token_quant=quant_config.per_act_token_quant, + block_shape=quant_config.block_shape, ) all_to_all_args = dict( @@ -163,6 +171,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ) elif moe.use_deepep_ll_kernels: + assert quant_config is not None all_to_all_args = dict( max_num_tokens_per_dp_rank=moe.max_num_tokens, token_hidden_size=moe.hidden_dim, @@ -172,13 +181,11 @@ class FusedMoEMethodBase(QuantizeMethodBase): all2all_manager.world_size) handle = all2all_manager.get_handle(all_to_all_args) - # Note : We may want to use FP8 dispatch even otherwise just to - # reduce datamovement - use_fp8_dispatch = (moe.quant_config is not None - and moe.quant_config.quant_dtype - == current_platform.fp8_dtype() - and moe.quant_config.block_shape - == DEEPEP_QUANT_BLOCK_SHAPE) + # Note: We may want to use FP8 dispatch just to reduce + # data movement. 
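[Illustrative sketch, not part of the patch] The DeepEP low-latency branch above now reads its FP8-dispatch decision from the quant config alone: dispatch activations in FP8 only when they are quantized to the platform fp8 dtype with the block shape DeepEP expects. As a standalone predicate:

    from typing import Optional

    def should_use_fp8_dispatch(quant_dtype, platform_fp8_dtype,
                                block_shape: Optional[list],
                                deepep_block_shape: list) -> bool:
        # FP8 dispatch mainly saves all-to-all bandwidth; it is only valid
        # when the quantization layout matches what DeepEP expects.
        return (quant_dtype == platform_fp8_dtype
                and block_shape == deepep_block_shape)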
+ use_fp8_dispatch = ( + quant_config.quant_dtype == current_platform.fp8_dtype() + and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE) prepare_finalize = DeepEPLLPrepareAndFinalize( handle, @@ -190,11 +197,10 @@ class FusedMoEMethodBase(QuantizeMethodBase): return prepare_finalize def maybe_make_prepare_finalize( - self, - moe: FusedMoEConfig, - ) -> Optional[FusedMoEPrepareAndFinalize]: - if moe.moe_parallel_config.use_all2all_kernels: - return FusedMoEMethodBase._maybe_make_prepare_finalize(moe) + self) -> Optional[FusedMoEPrepareAndFinalize]: + if self.moe.moe_parallel_config.use_all2all_kernels: + return FusedMoEMethodBase._maybe_make_prepare_finalize( + self.moe, self.moe_quant_config) else: return None @@ -202,7 +208,13 @@ class FusedMoEMethodBase(QuantizeMethodBase): # prepare_communication_buffer_for_model. def init_prepare_finalize(self, layer: torch.nn.Module): assert self.moe is not None - prepare_finalize = self.maybe_make_prepare_finalize(self.moe) + + # We must get the quant config here so that the layer is + # completely initialized, i.e. all weights loaded and post + # processed. + self.moe_quant_config = self.get_fused_moe_quant_config(layer) + + prepare_finalize = self.maybe_make_prepare_finalize() if prepare_finalize is not None: logger.debug("%s for %s(%s)", prepare_finalize.__class__.__name__, @@ -211,16 +223,16 @@ class FusedMoEMethodBase(QuantizeMethodBase): assert self.fused_experts is None, \ f"Attempt to override experts for {id(self)}!" self.topk_indices_dtype = prepare_finalize.topk_indices_dtype() - experts = self.select_gemm_impl(prepare_finalize, self.moe, layer) + experts = self.select_gemm_impl(prepare_finalize, layer) self.fused_experts = FusedMoEModularKernel( prepare_finalize, experts, + layer.shared_experts, ) def select_gemm_impl( self, prepare_finalize: FusedMoEPrepareAndFinalize, - moe: FusedMoEConfig, layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: # based on the all2all implementation, select the appropriate @@ -229,6 +241,11 @@ class FusedMoEMethodBase(QuantizeMethodBase): f"{self.__class__.__name__} must select appropriate gemm " "implementation based on the prepare_finalize") + @abstractmethod + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + raise NotImplementedError + @abstractmethod def apply( self, @@ -252,7 +269,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: raise NotImplementedError @@ -262,7 +279,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def __init__(self, moe: FusedMoEConfig): super().__init__(moe) - self.has_bias = self.moe.has_bias self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() if self.rocm_aiter_moe_enabled: from .rocm_aiter_fused_moe import rocm_aiter_fused_experts @@ -270,23 +286,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): else: self.rocm_aiter_fused_experts = None # type: ignore + def maybe_make_prepare_finalize( + self) -> Optional[FusedMoEPrepareAndFinalize]: + if self.rocm_aiter_moe_enabled: + return None + else: + return super().maybe_make_prepare_finalize() + def select_gemm_impl( self, prepare_finalize: FusedMoEPrepareAndFinalize, - # TODO(bnell): Remove. Every layer should have an moe config object. 
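[Illustrative sketch, not part of the patch] The comment in init_prepare_finalize() spells out the new ordering: the quant config is derived from the layer only after its weights are loaded and post-processed, then cached on the method and handed to select_gemm_impl(). A stripped-down version of that flow:

    class QuantMethodSketch:
        """Toy stand-in for FusedMoEMethodBase's quant-config handling."""

        def __init__(self):
            self.moe_quant_config = None  # filled in after weight loading

        def get_fused_moe_quant_config(self, layer):
            # Subclasses derive the config from the layer's finalized
            # weights (scales, zero points, biases, block shape, ...).
            raise NotImplementedError

        def init_prepare_finalize(self, layer):
            # Safe to touch layer attributes here: weights are final.
            self.moe_quant_config = self.get_fused_moe_quant_config(layer)
            # ... then build prepare/finalize and the modular kernel ...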
- moe: FusedMoEConfig, layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: + assert self.moe_quant_config is not None if (prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts): logger.debug("BatchedTritonExperts %s", self.moe) return BatchedTritonExperts( max_num_tokens=self.moe.max_num_tokens, num_dispatchers=prepare_finalize.num_dispatchers(), + quant_config=self.moe_quant_config, ) else: logger.debug("TritonExperts %s", self.moe) - return TritonExperts() + return TritonExperts(self.moe_quant_config) def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -300,7 +323,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): requires_grad=False) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) - if self.has_bias: + if self.moe.has_bias: w13_bias = torch.nn.Parameter(torch.zeros( num_experts, 2 * intermediate_size_per_partition, @@ -317,7 +340,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): requires_grad=False) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) - if self.has_bias: + if self.moe.has_bias: w2_bias = torch.nn.Parameter(torch.zeros(num_experts, hidden_size, dtype=params_dtype), @@ -409,7 +432,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: assert expert_load_view is not None assert logical_to_physical_map is not None @@ -439,6 +462,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): logical_replica_count=logical_replica_count, ) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + if self.moe.has_bias: + return biased_moe_quant_config( + layer.w13_bias, + layer.w2_bias, + ) + else: + return FUSED_MOE_UNQUANTIZED_CONFIG + def forward_cuda( self, layer: torch.nn.Module, @@ -461,7 +494,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -483,6 +516,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): logical_replica_count=logical_replica_count) if self.rocm_aiter_moe_enabled: + assert self.fused_experts is None return self.rocm_aiter_fused_experts( hidden_states=x, w1=layer.w13_weight, @@ -493,7 +527,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): activation=activation, apply_router_weight_on_input=apply_router_weight_on_input) elif self.fused_experts is not None: - if self.has_bias: + if self.moe.has_bias: raise ValueError( "FusedMoEModularKernel does not support bias.") return self.fused_experts( @@ -514,12 +548,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, - w1_bias=layer.w13_bias if self.has_bias else None, - w2_bias=layer.w2_bias if self.has_bias else None, topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, activation=activation, + 
quant_config=self.moe_quant_config, apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map, @@ -547,7 +580,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ): + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb is not False or expert_load_view is not None or \ logical_to_physical_map is not None or \ logical_replica_count is not None: @@ -594,7 +627,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ): + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb is not False or expert_load_view is not None or \ logical_to_physical_map is not None or \ logical_replica_count is not None: @@ -633,7 +666,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert not use_grouped_topk assert num_expert_group is None assert topk_group is None @@ -674,8 +707,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def determine_expert_map( - ep_size: int, ep_rank: int, - global_num_experts: int) -> tuple[int, Optional[torch.Tensor]]: + ep_size: int, + ep_rank: int, + global_num_experts: int, + expert_placement_strategy: ExpertPlacementStrategy = "linear", +) -> tuple[int, Optional[torch.Tensor]]: """ Calculates how many experts should be assigned to each rank for EP and creates a mapping from global to local expert index. Experts are @@ -683,8 +719,11 @@ def determine_expert_map( last rank. Args: - ep_size (int): The size of the expert parallel group - global_num_experts (int): The total number of experts in the model. + ep_size: The size of the expert parallel group + ep_rank: The rank of the current process in the expert parallel + group + global_num_experts: The total number of experts in the model. + expert_placement_strategy: The expert placement strategy. 
Returns: tuple[int, Optional[torch.Tensor]]: A tuple containing: @@ -709,10 +748,24 @@ def determine_expert_map( # Create a tensor of size num_experts filled with -1 expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32) - # Create a expert map for the local experts - start_idx = ep_rank * base_experts + min(ep_rank, remainder) - expert_map[start_idx:start_idx + local_num_experts] = torch.arange( - 0, local_num_experts, dtype=torch.int32) + # Create an expert map for the local experts + if expert_placement_strategy == "linear": + start_idx = ep_rank * base_experts + min(ep_rank, remainder) + expert_map[start_idx:start_idx + local_num_experts] = torch.arange( + 0, local_num_experts, dtype=torch.int32) + elif expert_placement_strategy == "round_robin": + local_log_experts = torch.arange(ep_rank, + global_num_experts, + ep_size, + dtype=torch.int32) + + expert_map[local_log_experts] = torch.arange(0, + local_num_experts, + dtype=torch.int32) + else: + raise ValueError("Unsupported expert placement strategy " + f"'{expert_placement_strategy}', expected one of " + f"{get_args(ExpertPlacementStrategy)}") return (local_num_experts, expert_map) @@ -754,7 +807,7 @@ class FusedMoE(CustomOp): intermediate_size: Intermediate size of the experts params_dtype: Data type for the parameters. reduce_results: Whether to all all_reduce on the output of the layer - renomalize: Whether to renormalize the logits in the fused_moe kernel + renormalize: Whether to renormalize the logits in the fused_moe kernel quant_config: Quantization configure. enable_eplb: Whether to enable expert parallelism load balancer. """ @@ -785,6 +838,7 @@ class FusedMoE(CustomOp): enable_eplb: bool = False, num_redundant_experts: int = 0, has_bias: bool = False, + is_sequence_parallel=False, ): super().__init__() if params_dtype is None: @@ -796,6 +850,10 @@ class FusedMoE(CustomOp): dp_size_ = (dp_size if dp_size is not None else get_dp_group().world_size) + self.is_sequence_parallel = is_sequence_parallel + if self.is_sequence_parallel: + self.sp_size = tp_size_ + vllm_config = get_current_vllm_config() self.moe_parallel_config: FusedMoEParallelConfig = ( FusedMoEParallelConfig.make( @@ -805,11 +863,18 @@ class FusedMoE(CustomOp): self.global_num_experts = num_experts + num_redundant_experts - # we padding globally so EP buffer allocation works + # we are padding globally so EP buffer allocation works if quant_config and quant_config.get_name() == "mxfp4": - from vllm.model_executor.layers.quantization.mxfp4 import ( # noqa: E501 - should_use_flashinfer_mxfp4) - if current_platform.is_rocm() or should_use_flashinfer_mxfp4(): + from vllm.model_executor.layers.quantization.mxfp4 import ( + Mxfp4Backend, get_mxfp4_backend) + current_mxfp4_backend = get_mxfp4_backend() + if (current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16 + or current_mxfp4_backend + == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS): + hidden_size = round_up(hidden_size, 128) + elif (current_platform.is_rocm() or current_mxfp4_backend + == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or + current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16): hidden_size = round_up(hidden_size, 256) # For smuggling this layer into the fused moe custom op @@ -833,15 +898,36 @@ class FusedMoE(CustomOp): else: assert num_redundant_experts == 0, \ "Redundant experts are only supported with EPLB." 
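[Illustrative sketch, not part of the patch] determine_expert_map() now supports two placements; ignoring the uneven-remainder handling, the two strategies produce maps like this:

    import torch

    def sketch_expert_map(ep_size: int, ep_rank: int,
                          global_num_experts: int,
                          strategy: str = "linear") -> torch.Tensor:
        assert global_num_experts % ep_size == 0, "sketch: even split only"
        local = global_num_experts // ep_size
        expert_map = torch.full((global_num_experts, ), -1,
                                dtype=torch.int32)
        if strategy == "linear":
            # Contiguous slice per rank: rank 0 owns [0, local), rank 1
            # the next slice, and so on.
            start = ep_rank * local
            expert_map[start:start + local] = torch.arange(
                local, dtype=torch.int32)
        elif strategy == "round_robin":
            # Strided ownership: rank r owns experts r, r + ep_size, ...
            owned = torch.arange(ep_rank, global_num_experts, ep_size)
            expert_map[owned] = torch.arange(local, dtype=torch.int32)
        else:
            raise ValueError(f"unknown strategy {strategy!r}")
        return expert_map

    # ep_size=2, ep_rank=0, 8 experts:
    #   linear      -> [0, 1, 2, 3, -1, -1, -1, -1]
    #   round_robin -> [0, -1, 1, -1, 2, -1, 3, -1]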
+ + expert_placement_strategy = ( + vllm_config.parallel_config.expert_placement_strategy) + if expert_placement_strategy == "round_robin": + # TODO(Bruce): will support round robin expert placement with + # EPLB enabled in the future. + round_robin_supported = ((num_expert_group is not None + and num_expert_group > 1) + and num_redundant_experts == 0 + and not self.enable_eplb) + + if not round_robin_supported: + logger.warning( + "Round-robin expert placement is only supported for " + "models with multiple expert groups and no redundant " + "experts. Falling back to linear expert placement.") + expert_placement_strategy = "linear" + self.local_num_experts, self.expert_map = determine_expert_map( ep_size=self.ep_size, ep_rank=self.ep_rank, - global_num_experts=self.global_num_experts) + global_num_experts=self.global_num_experts, + expert_placement_strategy=expert_placement_strategy, + ) logger.info_once( - "[EP Rank %s/%s] Expert parallelism is enabled. Local/global" + "[EP Rank %s/%s] Expert parallelism is enabled. Expert " + "placement strategy: %s. Local/global" " number of experts: %s/%s. Experts local to global index map:" - " %s.", self.ep_rank, self.ep_size, self.local_num_experts, - self.global_num_experts, + " %s.", self.ep_rank, self.ep_size, expert_placement_strategy, + self.local_num_experts, self.global_num_experts, get_compressed_expert_map(self.expert_map)) else: self.local_num_experts, self.expert_map = (self.global_num_experts, @@ -877,16 +963,18 @@ class FusedMoE(CustomOp): # since model_config is not set in the pytest test. model_dtype = params_dtype - moe = FusedMoEConfig.make(num_experts=self.global_num_experts, - experts_per_token=top_k, - hidden_dim=hidden_size, - num_local_experts=self.local_num_experts, - moe_parallel_config=self.moe_parallel_config, - in_dtype=model_dtype, - max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, - quant_config=quant_config, - has_bias=has_bias) + moe = FusedMoEConfig( + num_experts=self.global_num_experts, + experts_per_token=top_k, + hidden_dim=hidden_size, + num_local_experts=self.local_num_experts, + moe_parallel_config=self.moe_parallel_config, + in_dtype=model_dtype, + max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, + has_bias=has_bias, + ) self.moe_config = moe + self.moe_quant_config: Optional[FusedMoEQuantConfig] = None self.quant_config = quant_config # Note: get_quant_method will look at the layer's local_num_experts @@ -934,19 +1022,38 @@ class FusedMoE(CustomOp): # Chunked all2all staging tensor self.batched_hidden_states: Optional[torch.Tensor] = None self.batched_router_logits: Optional[torch.Tensor] = None + + # TODO(bnell): flashinfer uses non-batched format. + # Does it really need a batched buffer? 
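[Illustrative sketch, not part of the patch] The allocation just below gives the all2all staging tensors an extra leading dimension of 2 when dual-batch overlap (DBO) is enabled, so each micro-batch gets its own slice (selected later via dbo_current_ubatch_id()). The resulting shapes:

    import torch

    def make_staging_buffers(enable_dbo: bool, max_tokens: int,
                             hidden_size: int, num_experts: int,
                             dtype=torch.bfloat16):
        # Without DBO: (max_tokens, hidden) and (max_tokens, num_experts).
        # With DBO: a leading ubatch dimension of 2 is prepended.
        lead = (2, ) if enable_dbo else ()
        hidden_buf = torch.zeros(*lead, max_tokens, hidden_size,
                                 dtype=dtype)
        logits_buf = torch.zeros(*lead, max_tokens, num_experts,
                                 dtype=dtype)
        return hidden_buf, logits_buf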
if (self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels or self.moe_config.use_flashinfer_cutlass_kernels): - self.batched_hidden_states = torch.zeros( - (moe.max_num_tokens, self.hidden_size), - dtype=moe.in_dtype, - device=torch.cuda.current_device()) + if vllm_config.parallel_config.enable_dbo: + self.batched_hidden_states = torch.zeros( + (2, moe.max_num_tokens, self.hidden_size), + dtype=moe.in_dtype, + device=torch.cuda.current_device()) - # Note here we use `num_experts` which is logical expert count - self.batched_router_logits = torch.zeros( - (moe.max_num_tokens, num_experts), - dtype=moe.in_dtype, - device=torch.cuda.current_device()) + # Note here we use `num_experts` which is logical expert count + self.batched_router_logits = torch.zeros( + (2, moe.max_num_tokens, num_experts), + dtype=moe.in_dtype, + device=torch.cuda.current_device()) + else: + self.batched_hidden_states = torch.zeros( + (moe.max_num_tokens, self.hidden_size), + dtype=moe.in_dtype, + device=torch.cuda.current_device()) + + # Note here we use `num_experts` which is logical expert count + self.batched_router_logits = torch.zeros( + (moe.max_num_tokens, num_experts), + dtype=moe.in_dtype, + device=torch.cuda.current_device()) + + @property + def shared_experts(self) -> Optional[torch.nn.Module]: + return None @property def tp_size(self): @@ -990,7 +1097,9 @@ class FusedMoE(CustomOp): @property def use_flashinfer_cutlass_kernels(self): - return self.moe_config.use_flashinfer_cutlass_kernels + return (self.moe_quant_config is not None + and self.moe_quant_config.quant_dtype == "nvfp4" + and self.moe_config.use_flashinfer_cutlass_kernels) def update_expert_map(self): # ep_size and ep_rank should already be updated @@ -1399,7 +1508,8 @@ class FusedMoE(CustomOp): return [ weight.view(self.local_num_experts, -1) for name, weight in weights - if name not in NON_EXPERT_WEIGHTS + if name not in NON_EXPERT_WEIGHTS and weight.shape != torch.Size( + []) and not name.startswith("_shared_experts.") ] def set_eplb_state( @@ -1419,6 +1529,11 @@ class FusedMoE(CustomOp): self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx] self.logical_replica_count = logical_replica_count[moe_layer_idx] + def ensure_moe_quant_config(self): + if self.quant_method.moe_quant_config is None: + self.quant_method.moe_quant_config = ( + self.quant_method.get_fused_moe_quant_config(self)) + @staticmethod def select_experts( hidden_states: torch.Tensor, @@ -1582,25 +1697,52 @@ class FusedMoE(CustomOp): else: return tensor_model_parallel_all_reduce(final_hidden_states) - def forward(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): + def forward_native( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: og_hidden_states = hidden_states.shape[-1] if self.hidden_size != og_hidden_states: hidden_states = F.pad(hidden_states, (0, self.hidden_size - og_hidden_states), mode='constant', value=0.0) - # TODO: Once the OOM issue for the TPU backend is resolved, we will - # switch to using the moe_forward custom op. 
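[Illustrative sketch, not part of the patch] forward_native() above pads the hidden dimension up to the (possibly rounded-up) layer hidden_size, runs the fused op, and slices the result back to the original width; with shared experts it applies the same slicing to both elements of the returned tuple. The wrap-around pattern in isolation:

    import torch
    import torch.nn.functional as F

    def pad_run_unpad(x: torch.Tensor, padded_hidden: int, run):
        og = x.shape[-1]
        if padded_hidden != og:
            x = F.pad(x, (0, padded_hidden - og), mode="constant",
                      value=0.0)
        out = run(x)
        return out[..., :og]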
- if current_platform.is_tpu(): - return self.forward_impl(hidden_states, router_logits) - else: - return torch.ops.vllm.moe_forward( - hidden_states, router_logits, - self.layer_name)[..., :og_hidden_states] - def forward_impl_chunked(self, full_hidden_states: torch.Tensor, - full_router_logits: torch.Tensor): + if self.shared_experts is None: + if current_platform.is_tpu(): + # TODO: Once the OOM issue for the TPU backend is resolved, we + # will switch to using the moe_forward custom op. + fused_output = self.forward_impl(hidden_states, router_logits) + assert not isinstance(fused_output, tuple) + else: + fused_output = torch.ops.vllm.moe_forward( + hidden_states, router_logits, self.layer_name) + return fused_output[..., :og_hidden_states] + else: + if current_platform.is_tpu(): + # TODO: Once the OOM issue for the TPU backend is resolved, we + # will switch to using the moe_forward custom op. + shared_output, fused_output = self.forward_impl( + hidden_states, router_logits) + else: + shared_output, fused_output = torch.ops.vllm.moe_forward_shared( + hidden_states, router_logits, self.layer_name) + return (shared_output[..., :og_hidden_states], + fused_output[..., :og_hidden_states]) + + def forward_cuda( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + return self.forward_native(hidden_states, router_logits) + + def forward_impl_chunked( + self, + full_hidden_states: torch.Tensor, + full_router_logits: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.batched_hidden_states is not None assert self.batched_router_logits is not None assert self.batched_hidden_states.dtype == full_hidden_states.dtype @@ -1611,21 +1753,41 @@ class FusedMoE(CustomOp): assert ( self.batched_router_logits.size(-1) == full_router_logits.size(-1)) - full_final_hidden_states = torch.empty_like(full_hidden_states) + self.ensure_moe_quant_config() + + full_fused_final_hidden_states = torch.empty_like(full_hidden_states) + if self.shared_experts is not None: + full_shared_final_hidden_states = torch.empty_like( + full_hidden_states) def process_chunk(chunk_start, chunk_end, skip_result_store=False): chunk_size = chunk_end - chunk_start hidden_states = full_hidden_states[chunk_start:chunk_end, :] router_logits = full_router_logits[chunk_start:chunk_end, :] - assert (self.batched_hidden_states.size(0) # type: ignore + assert self.batched_hidden_states is not None + assert self.batched_router_logits is not None + # This is only true when DBO has been enabled in the config. 
+ # Both tensors will have an outer dimension for the ubatch id + if self.batched_hidden_states.dim() == 3: + assert self.batched_router_logits.dim() == 3 + batch_buffer_idx = dbo_current_ubatch_id() + batched_hidden_states = self.batched_hidden_states[ + batch_buffer_idx, :] + batched_router_logits = self.batched_router_logits[ + batch_buffer_idx, :] + else: + batched_hidden_states = self.batched_hidden_states + batched_router_logits = self.batched_router_logits + + assert (batched_hidden_states.size(0) # type: ignore >= chunk_size) - assert (self.batched_router_logits.size(0) # type: ignore + assert (batched_router_logits.size(0) # type: ignore >= chunk_size) - staged_hidden_states = self.batched_hidden_states[: - chunk_size, :] # type: ignore - staged_router_logits = self.batched_router_logits[: - chunk_size, :] # type: ignore + staged_hidden_states = batched_hidden_states[: + chunk_size, :] # type: ignore + staged_router_logits = batched_router_logits[: + chunk_size, :] # type: ignore staged_hidden_states.copy_(hidden_states, non_blocking=True) staged_router_logits.copy_(router_logits, non_blocking=True) @@ -1652,20 +1814,40 @@ class FusedMoE(CustomOp): logical_replica_count=self.logical_replica_count, ) + assert self.shared_experts is None or isinstance( + final_hidden_states, tuple) + if not skip_result_store: - full_final_hidden_states[chunk_start:chunk_end, :].copy_( - final_hidden_states, non_blocking=True) + if self.shared_experts is None: + full_fused_final_hidden_states[ + chunk_start:chunk_end, :].copy_(final_hidden_states, + non_blocking=True) + else: + full_shared_final_hidden_states[ + chunk_start:chunk_end, :].copy_(final_hidden_states[0], + non_blocking=True) + full_fused_final_hidden_states[ + chunk_start:chunk_end, :].copy_(final_hidden_states[1], + non_blocking=True) ctx = get_forward_context() # flashinfer_cutlass_kernels can handle: optional DP + TP/EP - max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu + max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens + + # If the input to the MoE is sequence parallel then divide by sp_size + # to find the maximum number of tokens for any individual dispatcher. 
+ if self.is_sequence_parallel: + max_tokens_across_dispatchers = cdiv(max_tokens_across_dispatchers, + self.sp_size) + num_tokens = full_hidden_states.size(0) for chunk_idx, chunk_start_ in enumerate( - range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank)): + range(0, max_tokens_across_dispatchers, + moe_dp_chunk_size_per_rank)): chunk_start = chunk_start_ chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank, - max_tokens_across_dp) + max_tokens_across_dispatchers) # clamp start and end chunk_start = min(chunk_start, num_tokens - 1) chunk_end = min(chunk_end, num_tokens) @@ -1675,25 +1857,45 @@ class FusedMoE(CustomOp): chunk_end, skip_result_store=chunk_start_ >= num_tokens) - return full_final_hidden_states + if self.shared_experts is None: + return full_fused_final_hidden_states + else: + return (full_shared_final_hidden_states, + full_fused_final_hidden_states) - def forward_impl(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): + def forward_impl( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.quant_method is not None + + self.ensure_moe_quant_config() + # Route to the chunked forward path using the FlashInfer Cutlass kernel # only when data parallelism (DP) is enabled. - use_flashinfer_cutlass_kernels = ( - self.dp_size > 1 - and self.moe_config.use_flashinfer_cutlass_kernels) + _use_flashinfer_cutlass_kernels = (self.dp_size > 1 and + self.use_flashinfer_cutlass_kernels) + if (self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels - or use_flashinfer_cutlass_kernels): + or _use_flashinfer_cutlass_kernels): return self.forward_impl_chunked(hidden_states, router_logits) do_naive_dispatch_combine: bool = ( self.dp_size > 1 and not self.moe_parallel_config.use_deepep_ht_kernels and not self.moe_config.use_flashinfer_cutlass_kernels) + + # If there are shared experts but we are not using a modular kernel, the + # shared experts must be called here + if (not isinstance(self.quant_method.fused_experts, + FusedMoEModularKernel) + and self.shared_experts is not None): + shared_output = self.shared_experts(hidden_states) + else: + shared_output = None + if do_naive_dispatch_combine: hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits) @@ -1722,14 +1924,32 @@ class FusedMoE(CustomOp): logical_replica_count=self.logical_replica_count, ) - if do_naive_dispatch_combine: - final_hidden_states = get_ep_group().combine(final_hidden_states) - if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): - # Default set to False. (May have to add shared expert outputs. 
- final_hidden_states = self.maybe_all_reduce_tensor_model_parallel( - final_hidden_states) + if shared_output is not None: + assert not isinstance(final_hidden_states, tuple) + assert self.shared_experts is not None + final_hidden_states = ( + shared_output, + final_hidden_states, + ) - return final_hidden_states + def reduce_output(states: torch.Tensor, + do_combine: bool = True) -> torch.Tensor: + if do_naive_dispatch_combine and do_combine: + states = get_ep_group().combine(states) + + if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): + states = self.maybe_all_reduce_tensor_model_parallel(states) + + return states + + if self.shared_experts is None: + assert not isinstance(final_hidden_states, tuple) + return reduce_output(final_hidden_states) + else: + return ( + reduce_output(final_hidden_states[0], do_combine=False), + reduce_output(final_hidden_states[1]), + ) @classmethod def make_expert_params_mapping( @@ -1784,17 +2004,22 @@ class FusedMoE(CustomOp): return s -def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, - layer_name: str) -> torch.Tensor: +def moe_forward( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + layer_name: str, +) -> torch.Tensor: forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] - assert self.quant_method is not None - + assert self.shared_experts is None return self.forward_impl(hidden_states, router_logits) -def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, - layer_name: str) -> torch.Tensor: +def moe_forward_fake( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + layer_name: str, +) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -1807,6 +2032,37 @@ direct_register_custom_op( tags=(torch.Tag.needs_fixed_stride_order, ), ) + +def moe_forward_shared( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + layer_name: str, +) -> tuple[torch.Tensor, torch.Tensor]: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + assert self.shared_experts is not None + return self.forward_impl(hidden_states, router_logits) + + +def moe_forward_shared_fake( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + layer_name: str, +) -> tuple[torch.Tensor, torch.Tensor]: + shared_out = torch.empty_like(hidden_states) + fused_out = torch.empty_like(hidden_states) + return shared_out, fused_out + + +direct_register_custom_op( + op_name="moe_forward_shared", + op_func=moe_forward_shared, + mutates_args=["hidden_states"], + fake_impl=moe_forward_shared_fake, + dispatch_key=current_platform.dispatch_key, + tags=(torch.Tag.needs_fixed_stride_order, ), +) + # Mark the FusedMoE weight_loader as supporting MoE-specific parameters # to avoid expensive runtime reflection in model loading code FusedMoE.weight_loader.supports_moe_loading = True # type: ignore[attr-defined] diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 2ea6383d5ae90..729f8e39cf0f7 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum from math import prod -from typing import Optional, final +from typing import Callable, Optional, Union, final import torch @@ -13,6 +13,8 @@ from vllm.model_executor.layers.fused_moe.config 
import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable _resize_cache, count_expert_num_tokens) from vllm.utils import cdiv +from vllm.v1.worker.ubatching import (dbo_enabled, dbo_maybe_run_recv_hook, + dbo_register_recv_hook, dbo_yield) # # This file defines a set of base classes used to make MoE kernels more modular. @@ -141,6 +143,29 @@ class TopKWeightAndReduce(ABC): raise NotImplementedError +# +# PrepareResultType is a tuple of: +# - quantized + dispatched a. +# - quantized + dispatched a1_scales. +# - Optional ExpertTokensMetadata containing gpu/cpu tensors +# as big as the number of local experts with the information about the +# number of tokens assigned to each local expert. +# - Optional dispatched expert topk IDs +# - Optional dispatched expert topk weight +# +# See `prepare` method below. +# +PrepareResultType = tuple[ + torch.Tensor, + Optional[torch.Tensor], + Optional[ExpertTokensMetadata], + Optional[torch.Tensor], + Optional[torch.Tensor], +] + +ReceiverType = Callable[[], PrepareResultType] + + # TODO: pass FusedMoEParallelConfig in as ctor parameter? class FusedMoEPrepareAndFinalize(ABC): """ @@ -152,24 +177,56 @@ class FusedMoEPrepareAndFinalize(ABC): def prepare( self, a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[ - torch.Tensor, - Optional[torch.Tensor], - Optional[ExpertTokensMetadata], - Optional[torch.Tensor], - Optional[torch.Tensor], - ]: + ) -> PrepareResultType: """ - Perform any quantization (and/or) dispatching needed - for this kernel. + Perform any quantization (and/or) dispatching needed for this kernel. + - a1: The (unquantized) input to the MoE layer. + - topk_ids: The topk ids. + - topk_weights: The topk weights. + - num_experts: The total number of experts in the global expert space. + - expert_map: A tensor mapping expert indices from the global expert + space to the local expert space of the expert parallel shard. + - apply_router_weight_on_input: When True, apply the weights to the + activations, before quantization + dispatching. + - quant_config: Quantization info provided by the fused experts. + + Returns a tuple of: + - quantized + dispatched a. + - Optional quantized + dispatched a1_scales. + - Optional ExpertTokensMetadata containing gpu/cpu tensors + as big as the number of local experts with the information about the + number of tokens assigned to each local expert. + - Optional dispatched expert topk IDs + - Optional dispatched expert topk weight + """ + raise NotImplementedError + + def supports_async(self) -> bool: + """ + Indicates whether or not this class implements prepare_async and + finalize_async. + """ + return False + + def prepare_async( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> tuple[Callable, ReceiverType]: + """ + Perform any quantization (and/or) dispatching needed for this kernel + but do not wait for results from other workers. - a1: The (unquantized) input to the MoE layer. - a1_scale: Optional scales for a1 - a2_scale: Optional scales for the second MoE gemm. 
Required to make @@ -182,14 +239,15 @@ class FusedMoEPrepareAndFinalize(ABC): - apply_router_weight_on_input: When True, apply the weights to the activations, before quantization + dispatching. - Returns a tuple of: - - quantized + dispatched a. - - quantized + dispatched a1_scales. - - Optional ExpertTokensMetadata containing gpu/cpu tensors - as big as the number of local experts with the information about the - number of tokens assigned to each local expert. - - Optional dispatched expert topk IDs - - Optional dispatched expert topk weight + Returns a callback that when invoked waits for results from other + workers and has the same return signature as `prepare`, e.g. + + receiver = obj.prepare_async(...) + a, a_scales, expert_meta, topk_ids, topk_weights = receiver() + + is equivalent to: + + a, a_scales, expert_meta, topk_ids, topk_weights = obj.prepare(...) """ raise NotImplementedError @@ -218,6 +276,42 @@ class FusedMoEPrepareAndFinalize(ABC): """ raise NotImplementedError + def finalize_async( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: TopKWeightAndReduce, + ) -> Callable: + """ + Perform any combine plus apply weights and perform a reduction on the + fused experts output but do not wait for results from other workers. + - output: The output tensor, written in place. Must be (M, K) shape. + - fused_expert_output: The unweighted, unreduced output of the fused + experts, it will have (M, topk, K) shape. + - topk_weights: The weights to be applied to the fused_experts_output. + - topk_ids: The topk_ids. + - apply_router_weight_on_input: When False, apply the weights to + fused_expert_output. + - weight_and_reduce_impl: An optional TopKWeightAndReduce + implementation. + + Returns a callback that when invoked waits for results from other + workers and has the same return signature as `finalize`, e.g. + + receiver = obj.finalize_async(output, ...) + ... output not valid yet ... + receiver() + ... output valid here ... + + is equivalent to: + + obj.finalize(output, ...) + """ + raise NotImplementedError + @property @abstractmethod def activation_format(self) -> FusedMoEActivationFormat: @@ -241,7 +335,7 @@ class FusedMoEPrepareAndFinalize(ABC): def max_num_tokens_per_rank(self) -> Optional[int]: """ Some PrepareFinalize All2All implementations are batched. Meaning, - they can processes only as set of tokens at a time. This + they can process only as set of tokens at a time. This function returns the batch size i.e the maximum number of tokens the implementation can process at a time. Return None if there are no such restrictions. @@ -253,6 +347,7 @@ class FusedMoEPrepareAndFinalize(ABC): raise NotImplementedError +# TODO: add supported activations method (return string) class FusedMoEPermuteExpertsUnpermute(ABC): """ An abstract base class for the [Permute-Experts-Unpermute] step described @@ -261,12 +356,12 @@ class FusedMoEPermuteExpertsUnpermute(ABC): def __init__( self, - quant_config: Optional[FusedMoEQuantConfig], + quant_config: FusedMoEQuantConfig, ): - if quant_config is not None: - self.quant_config = quant_config - else: - self.quant_config = FusedMoEQuantConfig() + """ + quant_config: Quantization parameters for this experts instance. 
+ """ + self.quant_config = quant_config @property @abstractmethod @@ -278,6 +373,11 @@ class FusedMoEPermuteExpertsUnpermute(ABC): """ raise NotImplementedError + # + # Various helpers for accessing quantization parameters from the + # quant_config. + # + @property def quant_dtype(self) -> Optional[torch.dtype]: return self.quant_config.quant_dtype @@ -294,6 +394,54 @@ class FusedMoEPermuteExpertsUnpermute(ABC): def per_out_ch_quant(self) -> bool: return self.quant_config.per_out_ch_quant + @property + def a1_scale(self) -> Optional[torch.Tensor]: + return self.quant_config.a1_scale + + @property + def a2_scale(self) -> Optional[torch.Tensor]: + return self.quant_config.a2_scale + + @property + def a1_gscale(self) -> Optional[torch.Tensor]: + return self.quant_config.a1_gscale + + @property + def a2_gscale(self) -> Optional[torch.Tensor]: + return self.quant_config.a2_gscale + + @property + def w1_scale(self) -> Optional[torch.Tensor]: + return self.quant_config.w1_scale + + @property + def w2_scale(self) -> Optional[torch.Tensor]: + return self.quant_config.w2_scale + + @property + def w1_zp(self) -> Optional[torch.Tensor]: + return self.quant_config.w1_zp + + @property + def w2_zp(self) -> Optional[torch.Tensor]: + return self.quant_config.w2_zp + + @property + def w1_bias(self) -> Optional[torch.Tensor]: + return self.quant_config.w1_bias + + @property + def w2_bias(self) -> Optional[torch.Tensor]: + return self.quant_config.w2_bias + + @property + def g1_alphas(self) -> Optional[torch.Tensor]: + return self.quant_config.g1_alphas + + @property + def g2_alphas(self) -> Optional[torch.Tensor]: + return self.quant_config.g2_alphas + # TODO (bnell): make this return a CHUNK_SIZE or None instead? @abstractmethod def supports_chunking(self) -> bool: @@ -370,12 +518,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[ExpertTokensMetadata], @@ -392,7 +535,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC): - w1 (torch.Tensor): The first set of expert weights. - w2 (torch.Tensor): The second set of expert weights. - topk_weights: A map of row to expert weights. Some implementations - choose to do weight application. + choose to do weight application. - topk_ids (torch.Tensor): A map of row to expert id. - activation (str): The activation function to apply after the first MoE layer. @@ -401,15 +544,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC): - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices from the global expert space to the local expert space of the expert parallel shard. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. - - w1_zp (Optional[torch.Tensor]): Optional zero points to be used for - w1. - - w2_zp (Optional[torch.Tensor]): Optional zero points to be used for - w2. - a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be - used for a1. - - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2. + used for a1. Result of quantization from prepare/finalize and not + from the FusedMoEQuantConfig. 
- workspace13 (torch.Tensor): A scratch tensor used for gemm outputs must be large enough to hold output of either MoE gemm. - workspace2 (torch.Tensor): A scratch tensor used for the activation @@ -435,6 +572,20 @@ def _chunk_scales(scales: Optional[torch.Tensor], start: int, return None +class SharedResizableBuffer: + + def __init__(self): + self.buffer = None + + def get(self, shape: tuple[int, ...], device: torch.device, + dtype: torch.dtype): + shape_numel = prod(shape) + if (self.buffer is None or self.buffer.numel() < shape_numel + or self.buffer.device != device or self.buffer.dtype != dtype): + self.buffer = torch.empty(shape_numel, device=device, dtype=dtype) + return self.buffer[:shape_numel].view(*shape) + + @final class FusedMoEModularKernel(torch.nn.Module): """ @@ -448,15 +599,20 @@ class FusedMoEModularKernel(torch.nn.Module): layer due to any layer specific state that may be used by the component objects. """ + fused_out_buffer = SharedResizableBuffer() + workspace13_buffer = SharedResizableBuffer() + workspace2_buffer = SharedResizableBuffer() def __init__( self, prepare_finalize: FusedMoEPrepareAndFinalize, fused_experts: FusedMoEPermuteExpertsUnpermute, + shared_experts: Optional[torch.nn.Module] = None, ): super().__init__() self.prepare_finalize = prepare_finalize self.fused_experts = fused_experts + self.shared_experts = shared_experts assert prepare_finalize.activation_format == \ fused_experts.activation_formats[0], ( f"{prepare_finalize.__class__.__name__}." @@ -477,12 +633,7 @@ class FusedMoEModularKernel(torch.nn.Module): global_num_experts: int, local_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], expert_tokens_meta: Optional[ExpertTokensMetadata], apply_router_weight_on_input: bool, ) -> torch.Tensor: @@ -496,12 +647,12 @@ class FusedMoEModularKernel(torch.nn.Module): # We can reuse the memory between cache1 and cache3 because by the # time we need cache3, we're done with cache1. 
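# [Editorial sketch, not part of this patch] Behaviour of the
# SharedResizableBuffer introduced above: repeated get() calls hand out views
# of one flat allocation and only reallocate when a larger request (or a new
# device/dtype) arrives, so the workspace and fused-output buffers declared on
# FusedMoEModularKernel are reused rather than re-created per forward call.
import torch

buf = SharedResizableBuffer()
a = buf.get((4, 8), device=torch.device("cpu"), dtype=torch.float32)
b = buf.get((2, 8), device=torch.device("cpu"), dtype=torch.float32)
assert a.data_ptr() == b.data_ptr()    # smaller request reuses the storage
c = buf.get((16, 8), device=torch.device("cpu"), dtype=torch.float32)
assert c.data_ptr() != a.data_ptr()    # larger request triggers a realloc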
- workspace13 = torch.empty(prod(workspace13_shape), - device=a1.device, - dtype=workspace_dtype) - workspace2 = torch.empty(prod(workspace2_shape), - device=a1.device, - dtype=workspace_dtype) + workspace13 = self.workspace13_buffer.get(workspace13_shape, + device=a1.device, + dtype=workspace_dtype) + workspace2 = self.workspace2_buffer.get(workspace2_shape, + device=a1.device, + dtype=workspace_dtype) assert fused_out is None or fused_out.shape == fused_out_shape, ( f"fused_out {fused_out.shape} but expected {fused_out_shape}") @@ -519,12 +670,7 @@ class FusedMoEModularKernel(torch.nn.Module): activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, a1q_scale=a1q_scale, - a2_scale=a2_scale, workspace13=workspace13, workspace2=workspace2, expert_tokens_meta=expert_tokens_meta, @@ -545,12 +691,7 @@ class FusedMoEModularKernel(torch.nn.Module): global_num_experts: int, local_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], expert_tokens_meta: Optional[ExpertTokensMetadata], apply_router_weight_on_input: bool, ) -> torch.Tensor: @@ -576,12 +717,7 @@ class FusedMoEModularKernel(torch.nn.Module): global_num_experts=global_num_experts, local_num_experts=local_num_experts, expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, a1q_scale=a1q_scale, - a2_scale=a2_scale, expert_tokens_meta=expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, ) @@ -593,9 +729,9 @@ class FusedMoEModularKernel(torch.nn.Module): (_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes( a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts, expert_tokens_meta) - fused_out = torch.empty(fused_out_shape, - device=a1q.device, - dtype=a1.dtype) + fused_out = self.fused_out_buffer.get(fused_out_shape, + device=a1q.device, + dtype=a1.dtype) def slice_input_tensors( chunk_idx: int @@ -603,9 +739,13 @@ class FusedMoEModularKernel(torch.nn.Module): Optional[torch.Tensor], torch.Tensor, torch.Tensor]: s = chunk_idx * CHUNK_SIZE e = min(s + CHUNK_SIZE, M) - return (a1q[s:e], _chunk_scales(a1q_scale, s, e), - _chunk_scales(a2_scale, s, - e), topk_ids[s:e], topk_weights[s:e]) + return ( + a1q[s:e], + _chunk_scales(a1q_scale, s, e), + _chunk_scales(self.fused_experts.a2_scale, s, e), + topk_ids[s:e], + topk_weights[s:e], + ) def slice_output_tensor(chunk_idx: int) -> torch.Tensor: assert fused_out.size(0) % M == 0, ( @@ -662,12 +802,7 @@ class FusedMoEModularKernel(torch.nn.Module): global_num_experts=global_num_experts, local_num_experts=local_num_experts, expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, a1q_scale=c_a1q_scale, - a2_scale=c_a2_scale, expert_tokens_meta=c_expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, ) @@ -685,14 +820,8 @@ class FusedMoEModularKernel(torch.nn.Module): activation: str = "silu", global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = 
False, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """ This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. @@ -713,14 +842,6 @@ class FusedMoEModularKernel(torch.nn.Module): - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices from the global expert space to the local expert space of the expert parallel shard. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. - - w1_zp (Optional[torch.Tensor]): Optional zero points to be used for - w1. - - w2_zp (Optional[torch.Tensor]): Optional zero points to be used for - w2. - - a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1. - - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2. - apply_router_weight_on_input (bool): When true, the topk weights are applied directly on the inputs. This is only applicable when topk is 1. @@ -730,24 +851,53 @@ class FusedMoEModularKernel(torch.nn.Module): """ a1 = hidden_states - output = a1 if inplace else torch.zeros_like(a1) + if inplace and self.shared_experts is None: + output = a1 + else: + output = torch.zeros_like(a1) local_num_experts = w1.size(0) if global_num_experts == -1: global_num_experts = local_num_experts - (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, - _expert_topk_weights) = self.prepare_finalize.prepare( - a1, - a1_scale, - a2_scale, - topk_weights, - topk_ids, - global_num_experts, - expert_map, - apply_router_weight_on_input, - self.fused_experts.quant_config, - ) + if not self.prepare_finalize.supports_async(): + # We shouldn't be running an a2a kernel that doesn't + # support async prepare/finalize + assert not dbo_enabled() + + (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, + _expert_topk_weights) = self.prepare_finalize.prepare( + a1, + topk_weights, + topk_ids, + global_num_experts, + expert_map, + apply_router_weight_on_input, + self.fused_experts.quant_config, + ) + else: + # Overlap shared expert compute with all2all dispatch. + dbo_maybe_run_recv_hook() + hook, receiver = self.prepare_finalize.prepare_async( + a1, + topk_weights, + topk_ids, + global_num_experts, + expert_map, + apply_router_weight_on_input, + self.fused_experts.quant_config, + ) + + # If DBO is being used, register the hook with the ubatch context + # and call it in dbo_maybe_run_recv_hook instead of passing it to + # the receiver. + dbo_register_recv_hook(hook) + dbo_yield() + if not dbo_enabled(): + hook() + + (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, + _expert_topk_weights) = receiver() # Maybe prepare gathered topk_ids and topk_weights from other EP ranks. 
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids @@ -776,23 +926,47 @@ class FusedMoEModularKernel(torch.nn.Module): global_num_experts=global_num_experts, local_num_experts=local_num_experts, expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, a1q_scale=a1q_scale, - a2_scale=a2_scale, expert_tokens_meta=expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, ) - self.prepare_finalize.finalize( - output, - fused_out, - topk_weights, - topk_ids, - apply_router_weight_on_input, - self.fused_experts.finalize_weight_and_reduce_impl(), - ) + shared_output: Optional[torch.Tensor] = None - return output + if not self.prepare_finalize.supports_async(): + assert not dbo_enabled() + + self.prepare_finalize.finalize( + output, + fused_out, + topk_weights, + topk_ids, + apply_router_weight_on_input, + self.fused_experts.finalize_weight_and_reduce_impl(), + ) + if self.shared_experts is not None: + shared_output = self.shared_experts(a1) + else: + recv_hook = self.prepare_finalize.finalize_async( + output, + fused_out, + topk_weights, + topk_ids, + apply_router_weight_on_input, + self.fused_experts.finalize_weight_and_reduce_impl(), + ) + + if self.shared_experts is not None: + shared_output = self.shared_experts(a1) + + assert recv_hook is not None + dbo_register_recv_hook(recv_hook) + dbo_yield() + if not dbo_enabled(): + recv_hook() + + if self.shared_experts is None: + return output + else: + assert shared_output is not None + return shared_output, output diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 401f37922b7bb..ddddd2a3b7a2e 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union +from typing import Callable, Optional, Union import pplx_kernels as pplx import torch @@ -84,25 +84,24 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return self.max_num_tokens def topk_indices_dtype(self) -> Optional[torch.dtype]: - return torch.int32 + return torch.uint32 def num_dispatchers(self) -> int: return self.num_dispatchers_ - def prepare( + def supports_async(self) -> bool: + return True + + def prepare_async( self, a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> tuple[Callable, mk.ReceiverType]: num_tokens = a1.size(0) # M hidden_dim = a1.size(-1) # K @@ -129,8 +128,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): repeat_cols = 4 repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0) + # TODO(bnell): always pass quant_config.a1_scale? 
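# [Editorial sketch, not part of this patch] How the split prepare API above
# is meant to be consumed. Per the implementations in this diff,
# prepare_async() returns a (hook, receiver) pair: the hook blocks until the
# receive half of the all2all has completed, and receiver() then yields the
# same tuple prepare() would have returned. The helper below is hypothetical
# and only illustrates the overlap; under DBO the hook is instead registered
# via dbo_register_recv_hook(), as shown in FusedMoEModularKernel.forward.
def overlapped_prepare(pf, shared_experts, a1, *prepare_args):
    if not pf.supports_async():
        prepared = pf.prepare(a1, *prepare_args)
        shared_out = shared_experts(a1) if shared_experts is not None else None
        return prepared, shared_out
    hook, receiver = pf.prepare_async(a1, *prepare_args)
    # Useful work (e.g. the shared experts) runs while the dispatch is in
    # flight; hook() then waits for the dispatched data to arrive.
    shared_out = shared_experts(a1) if shared_experts is not None else None
    hook()
    return receiver(), shared_out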
a1q, a1q_scale = moe_kernel_quantize_input( - a1, (None if quant_config.per_act_token_quant else a1_scale), + a1, (None if quant_config.per_act_token_quant else + quant_config.a1_scale), quant_dtype=quant_config.quant_dtype, per_act_token_quant=quant_config.per_act_token_quant, block_shape=quant_config.block_shape) @@ -138,6 +139,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): _validate_scale_shape(a1q, a1q_scale, quant_config.per_act_token_quant, quant_config.block_shape) + orig_a_scale_block_shape: Optional[int] = None + if a1q_scale is not None: scalar_scales = a1q_scale.numel() == 1 @@ -205,10 +208,39 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): out_expert_x_scale=expert_x_scale, dp_x=a1q, dp_x_scale=a1q_scale, - indices=topk_ids.view(dtype=torch.uint32), + indices=topk_ids, bound_m=bound_m, + do_send=True, + do_recv=False, ) + hook = lambda: self.a2a.dispatch( + out_expert_num_tokens=expert_num_tokens, + out_expert_x=expert_x, + out_expert_x_scale=expert_x_scale, + dp_x=a1q, + dp_x_scale=a1q_scale, + indices=topk_ids, + bound_m=bound_m, + do_send=False, + do_recv=True, + ) + + return (hook, lambda: self._receiver( + expert_num_tokens, + expert_x, + expert_x_scale, + orig_a_scale_block_shape, + )) + + def _receiver( + self, + expert_num_tokens: torch.Tensor, + expert_x: torch.Tensor, + expert_x_scale: Optional[torch.Tensor], + orig_a_scale_block_shape: Optional[int], + ) -> mk.PrepareResultType: + if expert_x_scale is not None: expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape] assert expert_x_scale.ndim == 3 @@ -218,7 +250,29 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): return expert_x, expert_x_scale, expert_tokens_meta, None, None - def finalize( + def prepare( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + hook, receiver = self.prepare_async( + a1, + topk_weights, + topk_ids, + num_experts, + expert_map, + apply_router_weight_on_input, + quant_config, + ) + hook() + return receiver() + + def finalize_async( self, output: torch.Tensor, fused_expert_output: torch.Tensor, @@ -226,7 +280,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): topk_ids: torch.Tensor, apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, - ) -> None: + ) -> Callable: assert isinstance( weight_and_reduce_impl, TopKWeightAndReduceDelegate ), ("Weight application and reduction happens in the combine kernel.") @@ -249,8 +303,39 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): if apply_router_weight_on_input: topk_weights = torch.ones_like(topk_weights) + topk_ids_u32 = topk_ids.view(dtype=torch.uint32) + self.a2a.combine(out_tokens=output, - indices=topk_ids.view(dtype=torch.uint32), + indices=topk_ids_u32, weights=topk_weights, expert_y=fused_expert_output, - bound_m=bound_m) + bound_m=bound_m, + do_send=True, + do_recv=False) + + return lambda: self.a2a.combine(out_tokens=output, + indices=topk_ids_u32, + weights=topk_weights, + expert_y=fused_expert_output, + bound_m=bound_m, + do_send=False, + do_recv=True) + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: + receiver = self.finalize_async( + 
output, + fused_expert_output, + topk_weights, + topk_ids, + apply_router_weight_on_input, + weight_and_reduce_impl, + ) + receiver() diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index 567a0a88fec0a..588e5de865dd9 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -30,17 +30,13 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): def prepare( self, a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> mk.PrepareResultType: if apply_router_weight_on_input: topk = topk_ids.size(1) @@ -50,7 +46,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): a1.mul_(topk_weights.to(a1.dtype)) a1q, a1q_scale = moe_kernel_quantize_input( - a1, a1_scale, quant_config.quant_dtype, + a1, quant_config.a1_scale, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape) return a1q, a1q_scale, None, None, None diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index f14f13e2ade9d..f4972ff5f9cb0 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -7,6 +7,8 @@ from typing import Optional import torch from vllm import envs +from vllm.model_executor.layers.fused_moe.config import ( + FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig) from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op @@ -305,21 +307,18 @@ def rocm_aiter_grouped_topk( def rocm_aiter_fused_experts( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - per_channel_quant: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, - expert_map: Optional[torch.Tensor] = None) -> torch.Tensor: + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + expert_map: Optional[torch.Tensor] = None, + quant_config: Optional[FusedMoEQuantConfig] = None, +) -> torch.Tensor: + if quant_config is None: + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG activation_method = (ActivationMethod.SILU if activation == "silu" else ActivationMethod.GELU) @@ -333,7 +332,8 @@ def rocm_aiter_fused_experts( expert_mask = None # w8a8 per-channel quantization - if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8: + if (quant_config.per_act_token_quant and apply_router_weight_on_input + and quant_config.use_fp8_w8a8): # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input` # This applies topk_weights on the GEMM output of the first FC layer # rather than the second FC. 
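# [Editorial sketch, not part of this patch] Schematic of the tkw1 distinction
# described in the comment above, reduced to a single token and a single
# expert; real kernels fuse the gate/up projections, quantization and the
# top-k combine, so this is intuition only.
def _weight_on_fc1(x, w1, w2, topk_weight, act):
    # tkw1 path: router weight applied to the first GEMM's output.
    return (act(x @ w1.T) * topk_weight) @ w2.T

def _weight_on_fc2(x, w1, w2, topk_weight, act):
    # Default path: router weight applied after the second GEMM.
    return (act(x @ w1.T) @ w2.T) * topk_weight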
@@ -349,8 +349,8 @@ def rocm_aiter_fused_experts( w2, topk_weights, topk_ids, - fc1_scale=w1_scale, - fc2_scale=w2_scale, + fc1_scale=quant_config.w1_scale, + fc2_scale=quant_config.w2_scale, fc1_smooth_scale=None, fc2_smooth_scale=None, a16=False, @@ -362,14 +362,14 @@ def rocm_aiter_fused_experts( quant_method = QuantMethod.NO.value # w8a8 block-scaled - if block_shape is not None and use_fp8_w8a8: + if quant_config.block_shape is not None and quant_config.use_fp8_w8a8: assert not apply_router_weight_on_input, ( "apply_router_weight_on_input is\ not supported for block scaled moe") - assert w1_scale is not None - assert w2_scale is not None + assert quant_config.w1_scale is not None + assert quant_config.w2_scale is not None quant_method = QuantMethod.BLOCK_128x128.value - elif use_fp8_w8a8: + elif quant_config.use_fp8_w8a8: # Currently only per tensor quantization method is enabled. quant_method = QuantMethod.PER_TENSOR.value @@ -390,10 +390,10 @@ def rocm_aiter_fused_experts( expert_mask=expert_mask, quant_method=quant_method, activation_method=activation_method, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, + w1_scale=quant_config.w1_scale, + w2_scale=quant_config.w2_scale, + a1_scale=quant_config.a1_scale, + a2_scale=quant_config.a2_scale, doweight_stage1=apply_router_weight_on_input) @@ -420,9 +420,8 @@ def shuffle_weights( Args: *tensors: Variable number of torch.Tensor objects. - layout: A pair of integers specifying the - block sizes used to divide the tensors during shuffling. - Default is (16, 16). + layout: A pair of integers specifying the block sizes used to divide + the tensors during shuffling. Default is (16, 16). Returns: A Tuple of shuffled tensors. diff --git a/vllm/model_executor/layers/fused_moe/routing_simulator.py b/vllm/model_executor/layers/fused_moe/routing_simulator.py index c8b107f13cd0d..8758a570b3c63 100644 --- a/vllm/model_executor/layers/fused_moe/routing_simulator.py +++ b/vllm/model_executor/layers/fused_moe/routing_simulator.py @@ -10,7 +10,7 @@ like uniform random routing. """ from abc import ABC, abstractmethod -from typing import Optional +from typing import Any, Optional import torch @@ -50,7 +50,9 @@ class DistributionBasedRouting(RoutingStrategy): distributions for testing different routing patterns. """ - def __init__(self, distribution: str = "uniform", **distribution_params): + def __init__(self, + distribution: str = "uniform", + **distribution_params: Any): """ Initialize distribution-based routing. @@ -244,7 +246,7 @@ class RoutingSimulator: cls._routing_strategies[name] = strategy @classmethod - def get_available_strategies(cls): + def get_available_strategies(cls) -> list[str]: """ Get list of available routing strategy names. 
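# [Editorial note, not part of this patch] The quantization dispatch in
# rocm_aiter_fused_experts above, condensed to its decision tree now that all
# scales come from the single quant_config argument. The returned strings are
# descriptive labels only; QuantMethod is the AITER enum used in that file.
def _sketch_aiter_quant_dispatch(qc, apply_router_weight_on_input):
    if (qc.per_act_token_quant and apply_router_weight_on_input
            and qc.use_fp8_w8a8):
        return "tkw1 kernel (fp8 per-channel, router weight on FC1)"
    if qc.use_fp8_w8a8 and qc.block_shape is not None:
        return "QuantMethod.BLOCK_128x128 (block-scaled fp8)"
    if qc.use_fp8_w8a8:
        return "QuantMethod.PER_TENSOR (per-tensor fp8)"
    return "QuantMethod.NO (unquantized)"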
diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 6cd81d97f0298..b2dbc306a6148 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -7,7 +7,8 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape, + DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape) +from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( deep_gemm_block_shape) from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used @@ -17,40 +18,19 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, - per_act_token_quant: bool = False, - block_shape: Optional[list[int]] = None, + quant_config: FusedMoEQuantConfig, allow_deep_gemm: bool = False, ): - super().__init__( - FusedMoEQuantConfig.make( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - )) - self.triton_expert = TritonExperts( - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int4_w4a16=use_int4_w4a16, - use_int8_w8a16=use_int8_w8a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - ) + super().__init__(quant_config) - self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8 and + self.triton_expert = TritonExperts(quant_config) + + self.allow_deep_gemm = (allow_deep_gemm + and self.quant_config.use_fp8_w8a8 and self.block_shape == deep_gemm_block_shape()) self.deep_gemm_expert = DeepGemmExperts( - ) if self.allow_deep_gemm else None + self.quant_config) if self.allow_deep_gemm else None @property def activation_formats( @@ -130,12 +110,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], @@ -158,12 +133,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): activation, global_num_experts, expert_map, - w1_scale, - w2_scale, - w1_zp, - w2_zp, a1q_scale, - a2_scale, workspace13, workspace2, expert_tokens_meta, diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py index 14dfce4b0e3aa..8e5f6acc9df63 100644 --- a/vllm/model_executor/layers/fused_moe/trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py @@ -5,7 +5,8 @@ from typing import Optional import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.config import 
(FusedMoEConfig, + FusedMoEQuantConfig) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP) from vllm.utils import next_power_of_2 @@ -16,20 +17,17 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, moe: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, - w13_bias, - w2_bias, max_capture_size, ): - super().__init__(moe.quant_config) + super().__init__(quant_config) self.moe = moe self.gemm1_alpha = gemm1_alpha self.gemm1_beta = gemm1_beta self.gemm1_clamp_limit = gemm1_clamp_limit - self.w13_bias = w13_bias - self.w2_bias = w2_bias self.max_capture_size = max_capture_size @property @@ -104,12 +102,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], @@ -129,8 +122,8 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to( torch.bfloat16).view(torch.int16) - assert w1_scale is not None - assert w2_scale is not None + assert self.w1_scale is not None + assert self.w2_scale is not None kwargs = { "topk_ids": packed_tensor, @@ -143,9 +136,9 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): "gemm1_weights": w1, "gemm1_weights_scale": - w1_scale, + self.w1_scale, "gemm1_bias": - self.w13_bias, + self.w1_bias, "gemm1_alpha": self.gemm1_alpha, "gemm1_beta": @@ -155,7 +148,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): "gemm2_weights": w2, "gemm2_weights_scale": - w2_scale, + self.w2_scale, "gemm2_bias": self.w2_bias, "output1_scale_scalar": diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 1aeb3f92bc3ea..678942e568d86 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -268,3 +268,7 @@ def _validate_scale_shape( assert block_shape is not None expected = (a.shape[0], cdiv(a.shape[1], block_shape[1])) assert a_scale.shape == expected, f"{a_scale.shape} == {expected}" + + +def activation_without_mul(activation: str) -> str: + return activation + "_no_mul" diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index a5fc1db2dc10f..f875f712ba9c9 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -9,11 +9,11 @@ import torch.nn as nn import vllm.envs as envs from vllm.model_executor.custom_op import CustomOp from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op def is_rocm_aiter_rmsnorm_enabled() -> bool: - return current_platform.is_rocm() \ - and envs.VLLM_ROCM_USE_AITER_RMSNORM \ + return envs.VLLM_ROCM_USE_AITER_RMSNORM \ and envs.VLLM_ROCM_USE_AITER @@ -43,8 +43,22 @@ def fused_add_rms_norm( return x, residual -def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor, - variance_epsilon: float) -> torch.Tensor: +def poly_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, + variance_epsilon: float) -> torch.Tensor: + from vllm import _custom_ops as ops + out = torch.empty_like(x) + ops.poly_norm( + out, + 
x, + weight, + bias, + variance_epsilon, + ) + return out + + +def rocm_aiter_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor, + variance_epsilon: float) -> torch.Tensor: import aiter as rocm_aiter if x.dim() > 2: x_original_shape = x.shape @@ -55,7 +69,7 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor, return rocm_aiter.rms_norm(x, weight, variance_epsilon) -def rocm_aiter_fused_add_rms_norm( +def rocm_aiter_rmsnorm2d_fwd_with_add_impl( x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]: @@ -74,14 +88,48 @@ def rocm_aiter_fused_add_rms_norm( return output, residual_out -def dispatch_cuda_rmsnorm_func(add_residual: bool): - if add_residual: - if is_rocm_aiter_rmsnorm_enabled(): - return rocm_aiter_fused_add_rms_norm - return fused_add_rms_norm +def rocm_aiter_rms_norm_fake(x: torch.Tensor, weight: torch.Tensor, + variance_epsilon: float) -> torch.Tensor: + return torch.empty_like(x) - if is_rocm_aiter_rmsnorm_enabled(): - return rocm_aiter_rms_norm + +def rocm_aiter_rmsnorm2d_fwd_with_add_fake( + x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, + variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]: + return torch.empty_like(x), torch.empty_like(residual) + + +if current_platform.is_rocm(): + direct_register_custom_op( + op_name="rocm_aiter_rms_norm", + op_func=rocm_aiter_rms_norm_impl, + mutates_args=[], + fake_impl=rocm_aiter_rms_norm_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_rmsnorm2d_fwd_with_add", + op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl, + mutates_args=[], + fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake, + dispatch_key=current_platform.dispatch_key, + ) + + +def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype): + use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [ + torch.float16, torch.bfloat16 + ] + + if use_aiter and with_fused_add: + return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add + if use_aiter: + return torch.ops.vllm.rocm_aiter_rms_norm + + # fall back to CUDA implementation + if with_fused_add: + return fused_add_rms_norm return rms_norm @@ -114,6 +162,13 @@ class RMSNorm(CustomOp): self.weight = torch.ones(hidden_size) if self.has_weight: self.weight = nn.Parameter(self.weight) + weight_dtype = self.weight.data.dtype + + if current_platform.is_rocm(): + self.rocm_norm_func = dispatch_rocm_rmsnorm_func( + with_fused_add=False, dtype=weight_dtype) + self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func( + with_fused_add=True, dtype=weight_dtype) def forward_native( self, @@ -162,13 +217,27 @@ class RMSNorm(CustomOp): return self.forward_native(x, residual) add_residual = residual is not None - norm_func = dispatch_cuda_rmsnorm_func(add_residual) - if add_residual: - return norm_func(x, residual, self.weight.data, - self.variance_epsilon) + return fused_add_rms_norm(x, residual, self.weight.data, + self.variance_epsilon) else: - return norm_func(x, self.weight.data, self.variance_epsilon) + return rms_norm(x, self.weight.data, self.variance_epsilon) + + def forward_hip( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + if self.variance_size_override is not None: + return self.forward_native(x, residual) + + add_residual = residual is not None + if add_residual: + return self.rocm_norm_func_with_add(x, residual, self.weight.data, + self.variance_epsilon) + else: + 
return self.rocm_norm_func(x, self.weight.data, + self.variance_epsilon) def forward_xpu( self, @@ -265,3 +334,48 @@ class GemmaRMSNorm(CustomOp): self.forward_static) self._is_compiled = True return self.forward_native(x, residual) + + +@CustomOp.register("poly_norm") +class PolyNorm(CustomOp): + """Polynomial normalization. + + Computes x -> w_0 * RMSNorm(x^3) + w_1 * RMSNorm(x^2) + w_2 * RMSNorm(x) + b + where w_n is the learned weight and b is the bias. + Refer to https://arxiv.org/html/2411.03884v1 + """ + + def __init__( + self, + eps: float = 1e-6, + ) -> None: + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(3) / 3) + self.bias = torch.nn.Parameter(torch.zeros(1)) + self.variance_epsilon = eps + + def _norm(self, x): + return x / torch.sqrt( + x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon) + + def forward_native( + self, + x: torch.Tensor, + ) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward(). + + Refer to https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md + """ + + orig_dtype = x.dtype + x_float = x.to(torch.float32) + output = (self.weight[0] * self._norm(x_float**3) + + self.weight[1] * self._norm(x_float**2) + + self.weight[2] * self._norm(x_float) + self.bias) + return output.to(orig_dtype) + + def forward_cuda( + self, + x: torch.Tensor, + ) -> torch.Tensor: + return poly_norm(x, self.weight, self.bias, self.variance_epsilon) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 19ff63145024f..5bf96398bc710 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -9,7 +9,6 @@ import torch import torch.nn as nn from torch.nn.parameter import Parameter, UninitializedParameter -from vllm import envs from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, @@ -30,6 +29,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter, # yapf: enable from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform +from vllm.utils import GiB_bytes logger = init_logger(__name__) @@ -191,35 +191,36 @@ class UnquantizedLinearMethod(LinearMethodBase): output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): - weight = Parameter(torch.empty(sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype), - requires_grad=False) + # This method creates unquantized linear weights. + # The weights are not quantized, and they are not sharded. + # The amount of memory allocated for the weights is + # sum(output_partition_sizes) * input_size_per_partition. + try: + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype), + requires_grad=False) + except torch.cuda.OutOfMemoryError as e: + logger.error("Failed to create unquantized linear weights: %s", e) + if torch.cuda.is_available(): + logger.debug("CUDA device: %s", torch.cuda.current_device()) + logger.debug("Allocated: %.2f GiB", + torch.cuda.memory_allocated() / GiB_bytes) + logger.debug("Reserved: %.2f GiB", + torch.cuda.memory_reserved() / GiB_bytes) + raise RuntimeError( + "Failed to create unquantized linear weights. 
" + "This may be caused by insufficient memory to allocate " + "the weight.") from e set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) layer.register_parameter("weight", weight) set_weight_attrs(weight, extra_weight_attrs) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # special postprocessing for CPU SGL - if current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL: - from vllm.model_executor.layers.utils import check_cpu_sgl_kernel - N, K = layer.weight.size() - dtype = layer.weight.dtype - if check_cpu_sgl_kernel(N, K, dtype): - packed_weight = torch.ops._C.convert_weight_packed( - layer.weight) - assert packed_weight.size() == layer.weight.size() - layer.weight.copy_(packed_weight) - if layer.bias is not None: - layer.bias = Parameter(layer.bias.to(torch.float32), - requires_grad=False) - layer.use_cpu_sgl = True - else: - logger.warning( - "CPU SGL kernels require Intel AMX support," - " bf16/fp16/int8 weight, IC and OC are divisible by " - "32 and 16.") - layer.use_cpu_sgl = False + if current_platform.is_cpu(): + from vllm.model_executor.layers.utils import ( + dispatch_cpu_unquantized_gemm) + dispatch_cpu_unquantized_gemm(layer, remove_weight=True) def apply(self, layer: torch.nn.Module, @@ -240,6 +241,7 @@ class LinearBase(CustomOp): quant_config: Quantization configure. prefix: Prefix for parameter names. return_bias: If true, return bias together with outputs in forward pass. + disable_tp: If true, tensor parallelism will be disabled for this layer. """ def __init__( @@ -252,6 +254,7 @@ class LinearBase(CustomOp): prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): super().__init__() @@ -271,6 +274,17 @@ class LinearBase(CustomOp): self.quant_method = quant_config.get_quant_method(self, prefix=prefix) self.return_bias = return_bias + self.disable_tp = disable_tp + self.tp_rank = (get_tensor_model_parallel_rank() + if not disable_tp else 0) + self.tp_size = (get_tensor_model_parallel_world_size() + if not disable_tp else 1) + + def update_param_tp_status(self): + for param in self.parameters(): + if isinstance(param, BasevLLMParameter): + param.tp_rank = self.tp_rank + param.tp_size = self.tp_size @CustomOp.register("replicated_linear") @@ -287,6 +301,7 @@ class ReplicatedLinear(LinearBase): prefix: The name of the layer in the state dict, including all parents (e.g. model.layers.0.qkv_proj) return_bias: If true, return bias together with outputs in forward pass. + disable_tp: Take no effect for replicated linear layers. """ def __init__( @@ -300,26 +315,21 @@ class ReplicatedLinear(LinearBase): prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): - # If MergedReplicatedLinear, use output size of each partition. - if hasattr(self, "output_sizes"): - self.output_partition_sizes = self.output_sizes - else: - self.output_partition_sizes = [output_size] - super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix=prefix, - return_bias=return_bias) + return_bias=return_bias, + disable_tp=disable_tp) # All the linear layer supports quant method. assert self.quant_method is not None self.quant_method.create_weights(self, - self.input_size, - self.output_partition_sizes, + self.input_size, [self.output_size], self.input_size, self.output_size, self.params_dtype, @@ -375,74 +385,6 @@ class ReplicatedLinear(LinearBase): return s -class MergedReplicatedLinear(ReplicatedLinear): - """Replicated linear layer. - - Args: - input_size: input dimension of the linear layer. 
- output_sizes: list of output dimensions of the linear layer. - bias: If true, add bias. - skip_bias_add: If true, skip adding bias but instead return it. - params_dtype: Data type for the parameters. - quant_config: Quantization configure. - prefix: The name of the layer in the state dict, including all parents - (e.g. model.layers.0.qkv_proj) - return_bias: If true, return bias together with outputs in forward pass. - """ - - def __init__( - self, - input_size: int, - output_sizes: list[int], - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - *, - return_bias: bool = True, - ): - self.output_sizes = output_sizes - super().__init__(input_size, - sum(output_sizes), - bias, - skip_bias_add, - params_dtype, - quant_config, - prefix=prefix, - return_bias=return_bias) - - def weight_loader(self, - param: Union[Parameter, BasevLLMParameter], - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[int] = None): - assert loaded_shard_id is not None - assert loaded_shard_id < len(self.output_sizes) - - if isinstance(param, BlockQuantScaleParameter): - from vllm.model_executor.layers.quantization.fp8 import ( - Fp8LinearMethod, Fp8MoEMethod) - assert self.quant_method is not None - assert isinstance(self.quant_method, - (Fp8LinearMethod, Fp8MoEMethod)) - weight_block_size = self.quant_method.quant_config.weight_block_size - assert weight_block_size is not None - block_n, _ = weight_block_size[0], weight_block_size[1] - shard_offset = ( - (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // - block_n) - shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) // - block_n) - elif isinstance(param, PerTensorScaleParameter): - shard_offset = loaded_shard_id - shard_size = 1 - else: - shard_offset = sum(self.output_sizes[:loaded_shard_id]) - shard_size = self.output_sizes[loaded_shard_id] - - param.data[shard_offset:shard_offset + shard_size] = loaded_weight - - @CustomOp.register("column_parallel_linear") class ColumnParallelLinear(LinearBase): """Linear layer with column parallelism. @@ -465,7 +407,9 @@ class ColumnParallelLinear(LinearBase): output_sizes: list of output sizes packed into one output, like for QKV the list would be size 3. prefix: The name of the layer in the state dict, including all parents - (e.g. model.layers.0.qkv_proj) + (e.g. model.layers.0.qkv_proj) + return_bias: If true, return bias together with outputs in forward pass. + disable_tp: If true, weights matrix won't be sharded through tp rank. """ def __init__( @@ -481,9 +425,13 @@ class ColumnParallelLinear(LinearBase): prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): # Divide the weight matrix along the last dimension. 
- self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = (get_tensor_model_parallel_rank() + if not disable_tp else 0) + self.tp_size = (get_tensor_model_parallel_world_size() + if not disable_tp else 1) self.input_size_per_partition = input_size self.output_size_per_partition = divide(output_size, self.tp_size) self.output_partition_sizes = [self.output_size_per_partition] @@ -500,7 +448,8 @@ class ColumnParallelLinear(LinearBase): params_dtype, quant_config, prefix, - return_bias=return_bias) + return_bias=return_bias, + disable_tp=disable_tp) self.gather_output = gather_output @@ -528,8 +477,7 @@ class ColumnParallelLinear(LinearBase): }) else: self.register_parameter("bias", None) - - self.tp_rank = get_tensor_model_parallel_rank() + self.update_param_tp_status() def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): @@ -571,7 +519,8 @@ class ColumnParallelLinear(LinearBase): assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor): + def weight_loader_v2(self, param: BasevLLMParameter, + loaded_weight: torch.Tensor): # Special case for loading scales off disk, which often do not # have a shape (such as in the case of AutoFP8). if len(loaded_weight.shape) == 0: @@ -587,7 +536,7 @@ class ColumnParallelLinear(LinearBase): # Matrix multiply. assert self.quant_method is not None output_parallel = self.quant_method.apply(self, input_, bias) - if self.gather_output: + if self.gather_output and self.tp_size > 1: # All-gather across the partitions. output = tensor_model_parallel_all_gather(output_parallel) else: @@ -601,7 +550,7 @@ class ColumnParallelLinear(LinearBase): s = f"in_features={self.input_size}" s += f", output_features={self.output_size_per_partition}" s += f", bias={self.bias is not None}" - s += f", tp_size={get_tensor_model_parallel_world_size()}" + s += f", tp_size={self.tp_size}" s += f", gather_output={self.gather_output}" return s @@ -628,6 +577,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): prefix: The name of the layer in the state dict, including all parents (e.g. model.layers.0.qkv_proj) return_bias: If true, return bias together with outputs in forward pass. + disable_tp: If true, all weights matrix won't be sharded, this layer + will be treated as a "Replicated" MergedLinear. """ def __init__( @@ -642,10 +593,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear): prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): self.output_sizes = output_sizes - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = (get_tensor_model_parallel_world_size() + if not disable_tp else 1) + self.tp_rank = (get_tensor_model_parallel_rank() + if not disable_tp else 0) assert all(output_size % self.tp_size == 0 for output_size in output_sizes) @@ -657,7 +611,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): params_dtype=params_dtype, quant_config=quant_config, prefix=prefix, - return_bias=return_bias) + return_bias=return_bias, + disable_tp=disable_tp) def weight_loader(self, param: Parameter, @@ -722,8 +677,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): # If quantized, we need to adjust the offset and size to account # for the packing. 
if packed_dim == output_dim: - shard_size = shard_size // param.pack_factor - shard_offset = shard_offset // param.pack_factor + shard_size = shard_size // param.packed_factor + shard_offset = shard_offset // param.packed_factor # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) @@ -756,8 +711,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): # for the packing. packed_dim = getattr(param, "packed_dim", None) if packed_dim == output_dim: - shard_size = shard_size // param.pack_factor - shard_offset = shard_offset // param.pack_factor + shard_size = shard_size // param.packed_factor + shard_offset = shard_offset // param.packed_factor # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) @@ -803,7 +758,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): """ Handle special case for models where MLP layers are already fused on disk. In this case, we have no shard id. This function - determmines the shard id by splitting these layers and then calls + determines the shard id by splitting these layers and then calls the weight loader using the shard id. An example of a model with these fused layers: @@ -849,30 +804,28 @@ class MergedColumnParallelLinear(ColumnParallelLinear): assert loaded_shard_id < len(self.output_sizes) - tp_size = get_tensor_model_parallel_world_size() - if isinstance(param, BlockQuantScaleParameter): - from vllm.model_executor.layers.quantization.fp8 import ( - Fp8LinearMethod, Fp8MoEMethod) assert self.quant_method is not None - assert isinstance(self.quant_method, - (Fp8LinearMethod, Fp8MoEMethod)) - weight_block_size = self.quant_method.quant_config.weight_block_size + # Assume the weight block size has been set by quant method + assert hasattr(self, "weight_block_size") + weight_block_size = self.weight_block_size assert weight_block_size is not None block_n, _ = weight_block_size[0], weight_block_size[1] shard_offset = ( (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // - block_n) // tp_size + block_n) // self.tp_size shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) // - block_n // tp_size) + block_n // self.tp_size) else: - shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size - shard_size = self.output_sizes[loaded_shard_id] // tp_size + shard_offset = sum( + self.output_sizes[:loaded_shard_id]) // self.tp_size + shard_size = self.output_sizes[loaded_shard_id] // self.tp_size param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=loaded_shard_id, shard_offset=shard_offset, - shard_size=shard_size) + shard_size=shard_size, + tp_rank=self.tp_rank) class QKVParallelLinear(ColumnParallelLinear): @@ -900,6 +853,7 @@ class QKVParallelLinear(ColumnParallelLinear): prefix: The name of the layer in the state dict, including all parents (e.g. model.layers.0.qkv_proj) return_bias: If true, return bias together with outputs in forward pass. + disable_tp: If true, weights matrix won't be sharded through tp rank. """ def __init__( @@ -915,6 +869,7 @@ class QKVParallelLinear(ColumnParallelLinear): prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): self.hidden_size = hidden_size self.head_size = head_size @@ -923,7 +878,8 @@ class QKVParallelLinear(ColumnParallelLinear): total_num_kv_heads = total_num_heads self.total_num_kv_heads = total_num_kv_heads # Divide the weight matrix along the last dimension. 
- tp_size = get_tensor_model_parallel_world_size() + tp_size = (get_tensor_model_parallel_world_size() + if not disable_tp else 1) self.num_heads = divide(self.total_num_heads, tp_size) if tp_size >= self.total_num_kv_heads: self.num_kv_heads = 1 @@ -949,7 +905,8 @@ class QKVParallelLinear(ColumnParallelLinear): params_dtype=params_dtype, quant_config=quant_config, prefix=prefix, - return_bias=return_bias) + return_bias=return_bias, + disable_tp=disable_tp) def _get_shard_offset_mapping(self, loaded_shard_id: str): shard_offset_mapping = { @@ -973,7 +930,7 @@ class QKVParallelLinear(ColumnParallelLinear): """ Handle special case for models where QKV layers are already fused on disk. In this case, we have no shard id. This function - determmines the shard id by splitting these layers and then calls + determines the shard id by splitting these layers and then calls the weight loader using the shard id. An example of a model with these fused layers: @@ -1010,10 +967,13 @@ class QKVParallelLinear(ColumnParallelLinear): loaded_shard_id: Optional[str] = None): if loaded_shard_id is None: # special case for certain models if isinstance(param, PerTensorScaleParameter): - param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0) + param.load_qkv_weight(loaded_weight=loaded_weight, + shard_id=0, + tp_rank=self.tp_rank) return elif type(param) in (RowvLLMParameter, BasevLLMParameter): - param.load_qkv_weight(loaded_weight=loaded_weight) + param.load_qkv_weight(loaded_weight=loaded_weight, + tp_rank=self.tp_rank) return # TODO: @dsikka - move to parameter.py self._load_fused_module_from_checkpoint(param, loaded_weight) @@ -1027,8 +987,10 @@ class QKVParallelLinear(ColumnParallelLinear): # Note(simon): This is needed for Qwen3's fp8 quantization. if isinstance(param, BlockQuantScaleParameter): assert self.quant_method is not None - assert hasattr(self.quant_method, "quant_config") - weight_block_size = self.quant_method.quant_config.weight_block_size + # Assume the weight block size has been set by quant method + assert hasattr(self, "weight_block_size") + weight_block_size = self.weight_block_size + assert weight_block_size is not None block_n, _ = weight_block_size[0], weight_block_size[1] shard_offset = (shard_offset + block_n - 1) // block_n shard_size = (shard_size + block_n - 1) // block_n @@ -1037,7 +999,8 @@ class QKVParallelLinear(ColumnParallelLinear): num_heads=self.num_kv_head_replicas, shard_id=loaded_shard_id, shard_offset=shard_offset, - shard_size=shard_size) + shard_size=shard_size, + tp_rank=self.tp_rank) def weight_loader(self, param: Parameter, @@ -1107,8 +1070,8 @@ class QKVParallelLinear(ColumnParallelLinear): # If quantized, we need to adjust the offset and size to account # for the packing. if packed_dim == output_dim: - shard_size = shard_size // param.pack_factor - shard_offset = shard_offset // param.pack_factor + shard_size = shard_size // param.packed_factor + shard_offset = shard_offset // param.packed_factor # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( @@ -1155,8 +1118,8 @@ class QKVParallelLinear(ColumnParallelLinear): # for the packing. packed_dim = getattr(param, "packed_dim", None) if packed_dim == output_dim: - shard_size = shard_size // param.pack_factor - shard_offset = shard_offset // param.pack_factor + shard_size = shard_size // param.packed_factor + shard_offset = shard_offset // param.packed_factor # Special case for Marlin. 
shard_size, shard_offset = adjust_marlin_shard( @@ -1243,6 +1206,7 @@ class RowParallelLinear(LinearBase): prefix: The name of the layer in the state dict, including all parents (e.g. model.layers.0.down_proj) return_bias: If true, return bias together with outputs in forward pass. + disable_tp: If true, weights matrix won't be sharded through tp rank. """ def __init__( @@ -1258,10 +1222,13 @@ class RowParallelLinear(LinearBase): prefix: str = "", *, return_bias: bool = True, + disable_tp: bool = False, ): # Divide the weight matrix along the first dimension. - self.tp_rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = (get_tensor_model_parallel_rank() + if not disable_tp else 0) + self.tp_size = (get_tensor_model_parallel_world_size() + if not disable_tp else 1) self.input_size_per_partition = divide(input_size, self.tp_size) self.output_size_per_partition = output_size self.output_partition_sizes = [output_size] @@ -1272,7 +1239,8 @@ class RowParallelLinear(LinearBase): params_dtype, quant_config, prefix, - return_bias=return_bias) + return_bias=return_bias, + disable_tp=disable_tp) self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results @@ -1301,6 +1269,7 @@ class RowParallelLinear(LinearBase): }) else: self.register_parameter("bias", None) + self.update_param_tp_status() def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): input_dim = getattr(param, "input_dim", None) @@ -1356,10 +1325,9 @@ class RowParallelLinear(LinearBase): if self.input_is_parallel: input_parallel = input_ else: - tp_rank = get_tensor_model_parallel_rank() splitted_input = split_tensor_along_last_dim( input_, num_partitions=self.tp_size) - input_parallel = splitted_input[tp_rank].contiguous() + input_parallel = splitted_input[self.tp_rank].contiguous() # Matrix multiply. assert self.quant_method is not None diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index e93be9bfb1657..8a4ac214443eb 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -6,11 +6,11 @@ from concurrent.futures import ThreadPoolExecutor from typing import Optional import torch -import torch.nn as nn import vllm.envs as envs from vllm.distributed import (tensor_model_parallel_all_gather, tensor_model_parallel_gather) +from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -22,7 +22,8 @@ if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None: envs.VLLM_LOGITS_PROCESSOR_THREADS) -class LogitsProcessor(nn.Module): +@CustomOp.register("logits_processor") +class LogitsProcessor(CustomOp): """Process logits and apply logits processors from sampling metadata. 
This layer does the following: diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py index d93cef1a27ad4..5fe37a6289e01 100644 --- a/vllm/model_executor/layers/mamba/linear_attn.py +++ b/vllm/model_executor/layers/mamba/linear_attn.py @@ -83,17 +83,7 @@ class MiniMaxText01RMSNormTP(CustomOp): variance = tensor_model_parallel_all_reduce( variance) / self.tp_world x = x * torch.rsqrt(variance + self.variance_epsilon) - - weight = self.weight - if x.size(-1) != self.weight.size(0): - if self.weight.size(0) < x.size(-1): - repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1) - full_weight = self.weight.repeat(repeat_count) - weight = full_weight[:x.size(-1)] - else: - weight = self.weight[:x.size(-1)] - - x = x.to(orig_dtype) * weight + x = x.to(orig_dtype) * self.weight return x def forward( diff --git a/vllm/model_executor/layers/mamba/mamba2_metadata.py b/vllm/model_executor/layers/mamba/mamba2_metadata.py index 3256ac034aa11..c926e17a2c197 100644 --- a/vllm/model_executor/layers/mamba/mamba2_metadata.py +++ b/vllm/model_executor/layers/mamba/mamba2_metadata.py @@ -11,20 +11,20 @@ from vllm.attention.backends.placeholder_attn import ( PlaceholderAttentionMetadata) from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.platforms import current_platform +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata from vllm.v1.attention.backends.mamba2_attn import ( Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets) @dataclass class Mamba2Metadata: - - has_initial_states: torch.Tensor prep_initial_states: bool - chunk_size: int - seq_idx: torch.Tensor - chunk_indices: torch.Tensor - chunk_offsets: torch.Tensor + + has_initial_states_p: torch.Tensor + seq_idx_p: torch.Tensor + chunk_indices_p: torch.Tensor + chunk_offsets_p: torch.Tensor """ With continuous batching layout of `x` in vLLM, to enable a Triton program to handle a request in parallel, two supporting tensors are used @@ -46,8 +46,8 @@ class Mamba2Metadata: """ nums_dict: Optional[dict] = None cu_seqlen: Optional[int] = None - batch_ptr: Optional[torch.tensor] = None - token_chunk_offset_ptr: Optional[torch.tensor] = None + batch_ptr: Optional[torch.Tensor] = None + token_chunk_offset_ptr: Optional[torch.Tensor] = None def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]: @@ -68,7 +68,6 @@ def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]: def prepare_mamba2_metadata( chunk_size: int, attn_metadata: AttentionMetadata, - mamba2_metadata=None, ) -> Mamba2Metadata: # compute number of prefill and decode requests @@ -76,11 +75,11 @@ def prepare_mamba2_metadata( num_prefills = attn_metadata.num_prefills num_prefill_tokens = attn_metadata.num_prefill_tokens - seq_idx = None - chunk_indices, chunk_offsets = None, None + seq_idx_p = None + chunk_indices_p, chunk_offsets_p = None, None # Need flags to indicate if there are initial states # currently we really only support the FlashAttention backend - has_initial_states = None + has_initial_states_p = None prep_initial_states = False # Compute seq_idx, chunk_indices and chunk_offsets for prefill only @@ -91,49 +90,36 @@ def prepare_mamba2_metadata( # precompute flag to avoid device syncs later in mamba2 layer # forwards # prep is only needed for mamba2 ssd prefill processing - has_initial_states = attn_metadata.context_lens_tensor > 0 - prep_initial_states = torch.any( - has_initial_states[:num_prefills]).item() - query_start_loc = 
attn_metadata.query_start_loc[:num_prefills + 1] - seq_idx = torch.repeat_interleave(torch.arange( - num_prefills, dtype=torch.int32, device=query_start_loc.device), - query_start_loc.diff(), - output_size=num_prefill_tokens) - seq_idx.unsqueeze_(0) + has_initial_states_p = ( + attn_metadata.context_lens_tensor[:num_prefills] > 0) + prep_initial_states = torch.any(has_initial_states_p).item() + query_start_loc_p = attn_metadata.query_start_loc[:num_prefills + 1] + seq_idx_p = torch.repeat_interleave(torch.arange( + num_prefills, dtype=torch.int32, device=query_start_loc_p.device), + query_start_loc_p.diff(), + output_size=num_prefill_tokens) + seq_idx_p.unsqueeze_(0) # We compute metadata for chunked prefill once at the top level model # forward and reuse them in mamba layers. If not needed, they will be # ignored inside mamba kernels. if prep_initial_states: - chunk_indices, chunk_offsets = \ + chunk_indices_p, chunk_offsets_p = \ _query_start_loc_to_chunk_indices_offsets( - query_start_loc, chunk_size, num_prefill_tokens) + query_start_loc_p, chunk_size, num_prefill_tokens) - if mamba2_metadata is not None: - mamba2_metadata.has_initial_states = has_initial_states - mamba2_metadata.prep_initial_states = prep_initial_states - mamba2_metadata.chunk_size = chunk_size - mamba2_metadata.seq_idx = seq_idx - mamba2_metadata.chunk_indices = chunk_indices - mamba2_metadata.chunk_offsets = chunk_offsets - # We use 1 reset flag: - # * mamba2_metadata.cu_seqlen is None - # update config specific to (each input) - # (become available at first layer, e.g. conv_weights) - mamba2_metadata.cu_seqlen = None # suppose to be updated at each input - - return mamba2_metadata - return Mamba2Metadata(has_initial_states=has_initial_states, + return Mamba2Metadata(has_initial_states_p=has_initial_states_p, prep_initial_states=prep_initial_states, chunk_size=chunk_size, - seq_idx=seq_idx, - chunk_indices=chunk_indices, - chunk_offsets=chunk_offsets) + seq_idx_p=seq_idx_p, + chunk_indices_p=chunk_indices_p, + chunk_offsets_p=chunk_offsets_p) def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor, mamba2_metadata: Union[Mamba2Metadata, - Mamba2AttentionMetadata]): + Mamba2AttentionMetadata, + GDNAttentionMetadata]): """ this is triggered upon handling a new input at the first layer """ diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index bb3fdd38dbef3..02e6a9138c05f 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -19,6 +19,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata, @@ -261,12 +262,14 @@ class MambaMixer2(MambaBase, CustomOp): ), "Tensor parallel world size must divide num heads." assert (n_groups % self.tp_size) == 0 or n_groups == 1, ( - "If tensor parallel world size does not divide num_heads, " + "If tensor parallel world size does not divide num_groups, " "then num_groups must equal 1.") - assert ( - self.tp_size == 1 or quant_config is None - ), "Tensor parallel currently not supported for quantized models." 
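+ # Relaxed from the previous blanket restriction: quantized models are
+ # allowed with tensor parallelism as long as every rank is assigned
+ # whole groups (i.e. n_groups % tp_size == 0).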
+ assert (n_groups % self.tp_size == 0) or self.tp_size == 1 or \ + quant_config is None, ( + "Tensor parallel is currently supported for quantized models " + "only if the tensor parallel world size divides num_groups." + ) self.ssm_state_size = ssm_state_size self.conv_kernel_size = conv_kernel_size @@ -285,92 +288,84 @@ class MambaMixer2(MambaBase, CustomOp): n_groups, self.tp_size) self.n_groups = n_groups + groups - self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size - self.conv1d = ColumnParallelLinear( - input_size=conv_kernel_size, - output_size=self.conv_dim, - bias=use_conv_bias, - quant_config=None, - ) - # unsqueeze to fit conv1d weights shape into the linear weights shape. - # Can't do this in `weight_loader` since it already exists in - # `ColumnParallelLinear` and `set_weight_attrs` - # doesn't allow to override it - self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + self.groups_ssm_state_size = self.n_groups * self.ssm_state_size + self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size - self.in_proj = ColumnParallelLinear( - input_size=hidden_size, - output_size=intermediate_size + self.conv_dim + self.num_heads, - bias=use_bias, - quant_config=quant_config, - ) + if n_groups % self.tp_size == 0: + self.conv1d = MergedColumnParallelLinear( + input_size=conv_kernel_size, + output_sizes=[ + intermediate_size, + self.groups_ssm_state_size, + self.groups_ssm_state_size, + ], + bias=use_conv_bias, + quant_config=None, + prefix=f"{prefix}.conv1d", + ) - # - because in_proj is a concatenation of 3 weights, we - # need to interleave them before sharding - # - use the custom weight loader mamba_v2_sharded_weight_loader - # for conv1d.bias, covn1d.weight and in_proj.weight - # - need to set these settings, to assign the groups to the head shards - group_shard_settings = ( - self.n_groups * self.ssm_state_size, # expected model size - (self.n_groups - n_groups) * - self.ssm_state_size, # extra dims assigned - n_groups == 1, # if there was only one group - ) - intermediate_settings = (intermediate_size, 0, False) - head_settings = (self.num_heads, 0, False) + self.in_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[ + intermediate_size, + intermediate_size, + self.groups_ssm_state_size, + self.groups_ssm_state_size, + self.num_heads, + ], + bias=use_bias, + quant_config=quant_config, + prefix=f"{prefix}.in_proj", + ) + else: + # This is the n_groups == 1 case, + # where we need to duplicate groups if TP>1.
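+ # In this branch in_proj stays a single ColumnParallelLinear over the
+ # concatenated (gate, conv_dim, dt) output, and the custom
+ # mamba_v2_sharded_weight_loader attached below assigns the duplicated
+ # group dims to the head shards.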
- # - the weight already has a "weight_loader" attribute - # which set_weight_attrs will raise if we do not - # delete before trying to override it - # - ditto for the otther two weights below - delattr(self.conv1d.bias, "weight_loader") - set_weight_attrs( - self.conv1d.bias, - { - "weight_loader": - mamba_v2_sharded_weight_loader( - [ - intermediate_settings, - group_shard_settings, - group_shard_settings, - ], - self.tp_size, - tp_rank, - ) - }, - ) + self.conv1d = ColumnParallelLinear( + input_size=conv_kernel_size, + output_size=self.conv_dim, + bias=use_conv_bias, + quant_config=None, + prefix=f"{prefix}.conv1d", + ) - delattr(self.conv1d.weight, "weight_loader") - set_weight_attrs( - self.conv1d.weight, - { - "weight_loader": - mamba_v2_sharded_weight_loader( - [ - intermediate_settings, - group_shard_settings, - group_shard_settings, - ], - self.tp_size, - tp_rank, - ) - }, - ) + self.in_proj = ColumnParallelLinear( + input_size=hidden_size, + output_size=intermediate_size + self.conv_dim + self.num_heads, + bias=use_bias, + quant_config=quant_config, + prefix=f"{prefix}.in_proj", + ) - if quant_config is None: - # - quant layers do not have a weight loader - delattr(self.in_proj.weight, "weight_loader") + # - because in_proj is a concatenation of 3 weights, we + # need to interleave them before sharding + # - use the custom weight loader mamba_v2_sharded_weight_loader + # for conv1d.bias, covn1d.weight and in_proj.weight + # - need to set these settings, to assign the groups + # to the head shards + group_shard_settings = ( + self.groups_ssm_state_size, # expected model size + (self.n_groups - n_groups) * + self.ssm_state_size, # extra dims assigned + n_groups == 1, # if there was only one group + ) + intermediate_settings = (intermediate_size, 0, False) + head_settings = (self.num_heads, 0, False) + + # - the weight already has a "weight_loader" attribute + # which set_weight_attrs will raise if we do not + # delete before trying to override it + # - ditto for the other two weights below + delattr(self.conv1d.bias, "weight_loader") set_weight_attrs( - self.in_proj.weight, + self.conv1d.bias, { "weight_loader": mamba_v2_sharded_weight_loader( [ - intermediate_settings, # for gate intermediate_settings, group_shard_settings, group_shard_settings, - head_settings, # for dt ], self.tp_size, tp_rank, @@ -378,6 +373,50 @@ class MambaMixer2(MambaBase, CustomOp): }, ) + delattr(self.conv1d.weight, "weight_loader") + set_weight_attrs( + self.conv1d.weight, + { + "weight_loader": + mamba_v2_sharded_weight_loader( + [ + intermediate_settings, + group_shard_settings, + group_shard_settings, + ], + self.tp_size, + tp_rank, + ) + }, + ) + + if quant_config is None: + # - quant layers do not have a weight loader + delattr(self.in_proj.weight, "weight_loader") + set_weight_attrs( + self.in_proj.weight, + { + "weight_loader": + mamba_v2_sharded_weight_loader( + [ + intermediate_settings, # for gate + intermediate_settings, + group_shard_settings, + group_shard_settings, + head_settings, # for dt + ], + self.tp_size, + tp_rank, + ) + }, + ) + + # unsqueeze to fit conv1d weights shape into the linear weights shape. 
+ # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `MergedColumnParallelLinear`, + # and `set_weight_attrs` doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + # - these are TPed by heads to reduce the size of the # temporal shape self.A = nn.Parameter( @@ -402,6 +441,7 @@ class MambaMixer2(MambaBase, CustomOp): bias=use_bias, input_is_parallel=True, quant_config=quant_config, + prefix=f"{prefix}.out_proj", ) self.norm = Mixer2RMSNormGated(intermediate_size, @@ -478,24 +518,19 @@ class MambaMixer2(MambaBase, CustomOp): conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] state_indices_tensor = attn_metadata.state_indices_tensor - has_initial_states_p = attn_metadata.has_initial_states_p - prep_initial_states = attn_metadata.prep_initial_states - chunk_size = attn_metadata.chunk_size - seq_idx_p = attn_metadata.seq_idx_p - chunk_indices_p = attn_metadata.chunk_indices_p - chunk_offsets_p = attn_metadata.chunk_offsets_p else: conv_state = mamba_cache_params.conv_state ssm_state = mamba_cache_params.ssm_state state_indices_tensor = mamba_cache_params.state_indices_tensor - has_initial_states_p = mamba2_metadata.has_initial_states + + # Common members between V1 metadata and V0 metadata + if mamba2_metadata is not None: + has_initial_states_p = mamba2_metadata.has_initial_states_p prep_initial_states = mamba2_metadata.prep_initial_states chunk_size = mamba2_metadata.chunk_size - seq_idx_p = mamba2_metadata.seq_idx - chunk_indices_p = mamba2_metadata.chunk_indices - chunk_offsets_p = mamba2_metadata.chunk_offsets - - groups_time_state_size = self.n_groups * self.ssm_state_size + seq_idx_p = mamba2_metadata.seq_idx_p + chunk_indices_p = mamba2_metadata.chunk_indices_p + chunk_offsets_p = mamba2_metadata.chunk_offsets_p # 1. Gated MLP's linear projection projected_states, _ = self.in_proj(hidden_states) @@ -521,8 +556,8 @@ class MambaMixer2(MambaBase, CustomOp): hidden_states_B_C, [ self.intermediate_size // self.tp_size, - groups_time_state_size // self.tp_size, - groups_time_state_size // self.tp_size, + self.groups_ssm_state_size // self.tp_size, + self.groups_ssm_state_size // self.tp_size, ], dim=-1, ) @@ -639,15 +674,9 @@ class MambaMixer2(MambaBase, CustomOp): # 3. 
State Space Model sequence transformation initial_states = None if (has_initial_states_p is not None and prep_initial_states): - # making a copy of the states - if envs.VLLM_USE_V1: - initial_states = torch.where( - has_initial_states_p[:, None, None, None], - ssm_state[state_indices_tensor_p], 0) - else: - initial_states = torch.where( - has_initial_states_p[:num_prefills, None, None, None], - ssm_state[state_indices_tensor_p], 0) + initial_states = torch.where( + has_initial_states_p[:, None, None, None], + ssm_state[state_indices_tensor_p], 0) # NOTE: final output is an in-place update of out tensor varlen_state = mamba_chunk_scan_combined( diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 1dc46639640b0..a6c1af91de421 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -70,6 +70,15 @@ class MambaStateDtypeCalculator: model_dtype) return (conv_state_dtype, ) + @classmethod + def gated_delta_net_state_dtype( + cls, + model_dtype: Union[ModelDType, torch.dtype], + mamba_cache_dtype: MambaDType, + ) -> tuple[torch.dtype, torch.dtype]: + state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype) + return (state_dtype, state_dtype) + class MambaStateShapeCalculator: @@ -163,3 +172,31 @@ class MambaStateShapeCalculator: # for n_groups == 1, this is exactly tp_size - n_groups return tp_size - ngroups + + @classmethod + def gated_delta_net_state_shape( + cls, + tp_world_size: int, + num_k_heads: int, + num_v_heads: int, + head_k_dim: int, + head_v_dim: int, + conv_kernel_size: int, + num_spec: int = 0, + use_v1: bool = True, + ): + conv_dim = (head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads) + conv_state_shape = ( + divide(conv_dim, tp_world_size), + conv_kernel_size - 1 + num_spec, + ) + + # In V0, the conv_state shape was swapped during allocation in + # MambaCacheManager, but in V1 it needs to be determined here at the + # calculation level + if use_v1: + conv_state_shape = conv_state_shape[1], conv_state_shape[0] + + temporal_state_shape = (divide(num_v_heads, + tp_world_size), head_k_dim, head_v_dim) + return conv_state_shape, temporal_state_shape diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index b8d4bbc37105d..8cfd0962c5bfe 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -415,6 +415,9 @@ def causal_conv1d_fn( activation = "silu" args = None + # Store original dtype to cast back at the end + original_x_dtype = x.dtype + x = x.to(conv_states.dtype) out = torch.empty_like(x) if metadata is not None: cu_seqlen = metadata.cu_seqlen @@ -464,7 +467,9 @@ def causal_conv1d_fn( # 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx] # 4. 
computation can be skipped if cache_indices[idx] == pad_slot_id num_cache_lines = conv_states.size(0) - assert (num_cache_lines, dim, width - 1) == conv_states.shape + assert (num_cache_lines == conv_states.shape[0] + and dim == conv_states.shape[1] + and width - 1 <= conv_states.shape[2]) stride_istate_seq = conv_states.stride(0) stride_istate_dim = conv_states.stride(1) stride_istate_token = conv_states.stride(2) @@ -611,7 +616,7 @@ def causal_conv1d_fn( BLOCK_N=256, num_stages=2, ) - return out + return out.to(original_x_dtype) @triton.jit() @@ -623,6 +628,8 @@ def _causal_conv1d_update_kernel( conv_state_ptr, cache_seqlens_ptr, # circular buffer conv_state_indices_ptr, + num_accepted_tokens_ptr, + query_start_loc_ptr, # (batch + 1) o_ptr, # (batch, dim, seqlen) # Matrix dimensions batch: int, @@ -639,6 +646,7 @@ def _causal_conv1d_update_kernel( stride_conv_state_seq: tl.constexpr, stride_conv_state_dim: tl.constexpr, stride_conv_state_tok: tl.constexpr, + stride_state_indices: tl.constexpr, stride_o_seq: tl.constexpr, stride_o_dim: tl.constexpr, stride_o_token: tl.constexpr, @@ -648,7 +656,9 @@ def _causal_conv1d_update_kernel( HAS_BIAS: tl.constexpr, KERNEL_WIDTH: tl.constexpr, SILU_ACTIVATION: tl.constexpr, + IS_VARLEN: tl.constexpr, IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, NP2_STATELEN: tl.constexpr, USE_PAD_SLOT: tl.constexpr, BLOCK_N: tl.constexpr, @@ -663,8 +673,9 @@ def _causal_conv1d_update_kernel( if IS_CONTINUOUS_BATCHING: # mask = idx_seq < batch - conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to( - tl.int64) + conv_state_batch_coord = tl.load(conv_state_indices_ptr + + idx_seq * stride_state_indices).to( + tl.int64) else: conv_state_batch_coord = idx_seq if USE_PAD_SLOT: # noqa @@ -672,13 +683,51 @@ def _causal_conv1d_update_kernel( # not processing as this is not the actual sequence return + if IS_VARLEN: + query_start_index = tl.load(query_start_loc_ptr + idx_seq).to(tl.int64) + query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to( + tl.int64) + # revise state_len and seqlen + state_len = state_len - (seqlen - + (query_end_index - query_start_index)) + seqlen = query_end_index - query_start_index + x_offset = query_start_index * stride_x_token + o_offset = query_start_index * stride_o_token + else: + query_start_index = idx_seq * seqlen + query_end_index = query_start_index + seqlen + x_offset = idx_seq * stride_x_seq + o_offset = idx_seq * stride_o_seq + + if query_start_index == query_end_index: + return + + if IS_SPEC_DECODING: + # The rolling of conv state: + # + # Before forward, the conv_state is: + # [history1, history2, ..., historyM]. + # + # After forward, the conv_state becomes: + # [history2, ..., historyM, draft1, draft2, ..., draftN]. + # + # After acceptance, it becomes: + # + # - accept 1 tokens: [history2, ..., historyM, draft1] + # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] + # - and so on. 
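+ #
+ # The accepted prefix therefore starts (num_accepted_tokens - 1) slots
+ # past the base of the rolled state, which is exactly the
+ # conv_state_token_offset applied to the state reads below.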
+ conv_state_token_offset = ( + tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1) + else: + conv_state_token_offset = 0 + # STEP 1: READ init_state data conv_states_base = (conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim)) mask_w = idx_feats < dim - prior_tokens = conv_states_base + prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok if KERNEL_WIDTH >= 2: conv_states_ptrs = prior_tokens # [BLOCK_N] col0 = tl.load(conv_states_ptrs, mask_w, 0.0) @@ -688,26 +737,32 @@ def _causal_conv1d_update_kernel( if KERNEL_WIDTH >= 4: conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N] col2 = tl.load(conv_states_ptrs, mask_w, 0.0) - if KERNEL_WIDTH == 5: + if KERNEL_WIDTH >= 5: conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N] col3 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 6: + conv_states_ptrs = prior_tokens + 4 * stride_conv_state_tok # [BLOCK_N] + col4 = tl.load(conv_states_ptrs, mask_w, 0.0) # STEP 2: assume state_len > seqlen idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + # With speculative decoding, the conv_state updates works in a sliding + # window manner, at each forward pass, the tokens are shift by 1, so we + # load since idx_tokens + 1. conv_state_ptrs_source = ( conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + + conv_state_token_offset * stride_conv_state_tok + (idx_feats * stride_conv_state_dim)[None, :] + - ((idx_tokens + seqlen) * stride_conv_state_tok)[:, None] - ) # [BLOCK_M, BLOCK_N] + ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * + stride_conv_state_tok)[:, None]) # [BLOCK_M, BLOCK_N] mask = ((conv_state_batch_coord < num_cache_lines) & ((idx_tokens + seqlen) < state_len)[:, None] & (idx_feats < dim)[None, :]) conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) VAL = state_len - seqlen - x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim - ) # [BLOCK_N] + x_base = x_ptr + x_offset + (idx_feats * stride_x_dim) # [BLOCK_N] x_ptrs = x_base[None, :] + ( (idx_tokens - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] @@ -753,12 +808,18 @@ def _causal_conv1d_update_kernel( if KERNEL_WIDTH >= 4: w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor w_col3 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 5: + w_ptrs = w_base + (4 * stride_w_width) # [BLOCK_N] tensor + w_col4 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 6: + w_ptrs = w_base + (5 * stride_w_width) # [BLOCK_N] tensor + w_col5 = tl.load(w_ptrs, mask_w, other=0.0) x_base_1d = x_base # starting of chunk [BLOCK_N] mask_x_1d = idx_feats < dim # STEP 5: compute each token - for idx_token in tl.static_range(seqlen): + for idx_token in tl.range(seqlen): acc = acc_preload matrix_w = w_col0 @@ -788,6 +849,37 @@ def _causal_conv1d_update_kernel( matrix_w = w_col3 x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 5: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + matrix_x = col3 + elif j == 4: + matrix_w = w_col4 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 6: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + matrix_x = col3 + elif j == 4: + 
matrix_w = w_col4 + matrix_x = col4 + elif j == 5: + matrix_w = w_col5 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) acc += matrix_x * matrix_w # [BLOCK_N] @@ -800,14 +892,24 @@ def _causal_conv1d_update_kernel( col0 = col1 col1 = col2 col2 = matrix_x + elif KERNEL_WIDTH == 5: + col0 = col1 + col1 = col2 + col2 = col3 + col3 = matrix_x + elif KERNEL_WIDTH == 6: + col0 = col1 + col1 = col2 + col2 = col3 + col3 = col4 + col4 = matrix_x if SILU_ACTIVATION: acc = acc / (1 + tl.exp(-acc)) mask_1d = (idx_token < seqlen) & (idx_feats < dim ) # token-index # feature-index - o_ptrs = o_ptr + ( - idx_seq) * stride_o_seq + idx_token * stride_o_token + ( - idx_feats * stride_o_dim) + o_ptrs = o_ptr + o_offset + idx_token * stride_o_token + (idx_feats * + stride_o_dim) tl.store(o_ptrs, acc, mask=mask_1d) @@ -820,14 +922,21 @@ def causal_conv1d_update( activation: Union[bool, str, None] = None, cache_seqlens: Optional[torch.Tensor] = None, conv_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + query_start_loc: Optional[torch.Tensor] = None, + max_query_len: int = -1, pad_slot_id: int = PAD_SLOT_ID, metadata=None, validate_data=False, ): """ - x: (batch, dim) or (batch, dim, seqlen) - [shape=2: single token prediction] - [shape=3: single or multiple tokens prediction] + x: Input tensor which can take the following shapes: + + - `[batch, dim]` - single token prediction + - `[batch, dim, seqlen]` - single or multiple tokens prediction + - `[num_tokens, dim]` - continuous batching, where num_tokens is + the total tokens of all sequences in that batch + conv_state: (..., dim, state_len), where state_len >= width - 1 weight: (dim, width) bias: (dim,) @@ -840,13 +949,24 @@ def causal_conv1d_update( If not None, the conv_state is a larger tensor along the batch dim, and we are selecting the batch coords specified by conv_state_indices. Useful for a continuous batching scenario. + num_accepted_tokens: (batch,), dtype int32 + If not None, it indicates the number of accepted tokens for each + sequence in the batch. + This is used in speculative decoding, where the conv_state is updated + in a sliding window manner. + query_start_loc: (batch + 1,) int32 + If not None, the input is given in a varlen fashion and this indicates + the starting index of each sequence in the batch. + max_query_len: int + If query_start_loc is not None, this indicates the maximum query + length in the batch.
pad_slot_id: int if cache_indices is passed, lets the kernel identify padded entries that will not be processed, for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] in this case, the kernel will not process entries at indices 0 and 3 - out: (batch, dim) or (batch, dim, seqlen) + out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x` """ if validate_data: assert cache_seqlens is None # not implemented yet - ok for vLLM @@ -856,11 +976,20 @@ def causal_conv1d_update( activation = "silu" if activation is True else None elif activation is not None: assert activation in ["silu", "swish"] - unsqueeze = x.dim() == 2 + + original_x_dtype = x.dtype + x = x.to(conv_state.dtype) + unsqueeze = query_start_loc is None and x.dim() == 2 if unsqueeze: # make it (batch, dim, seqlen) with seqlen == 1 x = x.unsqueeze(-1) - batch, dim, seqlen = x.shape + if query_start_loc is None: + batch, dim, seqlen = x.shape + else: + assert conv_state_indices is not None + batch = conv_state_indices.size(0) + dim = x.size(1) + seqlen = max_query_len _, width = weight.shape # conv_state: (..., dim, state_len), where state_len >= width - 1 num_cache_lines, _, state_len = conv_state.size() @@ -886,14 +1015,25 @@ def causal_conv1d_update( out = x stride_w_dim, stride_w_width = weight.stride() - stride_x_seq, stride_x_dim, stride_x_token = x.stride( - ) # X (batch, dim, seqlen) - - stride_o_seq, stride_o_dim, stride_o_token = out.stride() + if query_start_loc is None: + # X (batch, dim, seqlen) + stride_x_seq, stride_x_dim, stride_x_token = x.stride() + stride_o_seq, stride_o_dim, stride_o_token = out.stride() + else: + # X (dim, cu_seqlen) + stride_x_token, stride_x_dim = x.stride() + stride_x_seq = 0 + stride_o_token, stride_o_dim = out.stride() + stride_o_seq = 0 stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride( ) - state_len = width - 1 + stride_state_indices = conv_state_indices.stride( + 0) if conv_state_indices is not None else 0 + if num_accepted_tokens is not None: + state_len = width - 1 + (seqlen - 1) # effective state_len needed + else: + state_len = width - 1 np2_statelen = triton.next_power_of_2(state_len) def grid(META): @@ -910,6 +1050,8 @@ def causal_conv1d_update( conv_state, cache_seqlens, conv_state_indices, + num_accepted_tokens, + query_start_loc, out, # Matrix dimensions batch, @@ -926,6 +1068,7 @@ def causal_conv1d_update( stride_istate_seq, stride_istate_dim, stride_istate_token, + stride_state_indices, stride_o_seq, stride_o_dim, stride_o_token, @@ -935,11 +1078,13 @@ def causal_conv1d_update( HAS_BIAS=bias is not None, KERNEL_WIDTH=width, SILU_ACTIVATION=activation in ["silu", "swish"], + IS_VARLEN=query_start_loc is not None, IS_CONTINUOUS_BATCHING=conv_state_indices is not None, + IS_SPEC_DECODING=num_accepted_tokens is not None, NP2_STATELEN=np2_statelen, USE_PAD_SLOT=pad_slot_id is not None, BLOCK_N=256, ) if unsqueeze: out = out.squeeze(-1) - return out + return out.to(original_x_dtype) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 365139e237c66..fb8350e191c94 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -289,6 +289,9 @@ def _chunk_scan_fwd_kernel( # get the cs at the offset boundary # - c_off == 0 is a passthrough + # - We need dA_cs at the boundary, defined by c_off - no need + # to increase pointer by pid_m (it is a constant offset, + # i.e. 
the same for all blocks) dA_cs_m_boundary = tl.load( dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize, mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)), diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index ad58a9918f03c..a7b3c814859ce 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -502,7 +502,7 @@ def _chunk_state_varlen_kernel( dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk - # If HAS_INITSTATES==True need to consider two possiblties + # If HAS_INITSTATES==True need to consider two possibilities # - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs # - if state_idx >= pid * chunk_size, then we need to insert initstates if ((start_idx < pid_c * chunk_size) # first chunk diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index d0b3e9e5235bf..fcc5c905bf77f 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -106,21 +106,24 @@ def _mamba_chunk_scan_combined_fwd(x, # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) # - for handling chunked prefill, this requires i) initial_states - # ii) seq_idx and iii) is_cont_batched to be all specified. + # ii) seq_idx iii) is_cont_batched and (iv) chunk_offsets to be all specified. # - When a new seq_idx is detected, we will stop passing the prev_state # and switch accordingly to the init_state corresponding to the new seq_idx. + # - We will also make sure that the dA_cumsum is taken only from the start of the + # sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries) # - this will ensure that states will be updated with the rightmost flushed seq_idx # of the previous chunk. This implies that the first chunk of states is either 0 # or equal to init_states of the first example. states, final_states = _state_passing_fwd( rearrange(states, "... p n -> ... (p n)"), - dA_cumsum[:, :, :, -1], + dA_cumsum, initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None, seq_idx=seq_idx, chunk_size=chunk_size, out_dtype=state_dtype if state_dtype is not None else C.dtype, - is_cont_batched=cu_seqlens is not None) + is_cont_batched=cu_seqlens is not None, + chunk_offsets=chunk_offsets) states, final_states = (rearrange(t, "... (p n) -> ... 
p n", n=dstate) for t in [states, final_states]) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index a28fc9ffad71b..d61c3a8cdbe9c 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -31,6 +31,8 @@ def _state_passing_fwd_kernel( dA_cs_ptr, initstates_ptr, seq_idx_ptr, + chunk_offsets_ptr, + chunk_meta_num, # Matrix dimensions dim, nchunks, @@ -51,6 +53,7 @@ def _state_passing_fwd_kernel( stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, + stride_dA_cs_csize, stride_initstates_batch, stride_initstates_head, stride_initstates_dim, @@ -66,7 +69,8 @@ def _state_passing_fwd_kernel( pid_h = tl.program_id(axis=2) pid_m = tl.program_id(axis=0) states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head - dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + ( + chunk_size - 1) * stride_dA_cs_csize out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head if HAS_INITSTATES: @@ -95,35 +99,62 @@ def _state_passing_fwd_kernel( tl.store(out_ptrs, states, mask=offs_m < dim) out_ptrs += stride_out_chunk - seq_idx = 0 + prev_seq_idx_chunk_end = 0 + logical_chunk_idx = 0 for c in range(nchunks): new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) dA_cs = tl.load(dA_cs_ptr).to(tl.float32) - scale = tl.exp(dA_cs) + scale_mask = True if HAS_SEQ_IDX: # - the seq to pass forward is the one that is flushed to the right # boundary. - # - that is given by seq_idx_new below. - seq_idx_new = tl.load(seq_idx_ptr + - (min((c + 1) * chunk_size, seqlen) - 1) * - stride_seq_idx_seqlen) + # - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk. + seq_idx_chunk_end = tl.load(seq_idx_ptr + (min( + (c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen) if HAS_INITSTATES: - if IS_CONT_BATCHED and seq_idx != seq_idx_new: + if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end: # this means in the current chunk the rightmost flushed seq # has changed. 
# - so we do not propagate the state from previous chunk # - but rather we load that sequence's init state - initstates_ptrs = initstates_ptr + seq_idx_new * stride_initstates_batch + initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch # - update state with seq_idx_new's init state states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - else: - scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) - seq_idx = seq_idx_new + # - we need to consider the cumsum only of the last sequence in the chunk + # - find its starting position (given by c_off of the logical chunk index) + # - and subtract the cumsum just before that position from the total cumsum + # - first, update the logical chunk index (add the number of sequences in the current physical chunk): + # sequence index at the start of the current chunk + seq_idx_chunk_start = tl.load(seq_idx_ptr + + min(c * chunk_size, seqlen) * + stride_seq_idx_seqlen) + logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start + # - load the chunk offset: + c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx, + mask=logical_chunk_idx < chunk_meta_num, + other=0) + # - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything + if c_off > 0: + # - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset + dA_cs_boundary = tl.load( + dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize + + (c_off - 1) * stride_dA_cs_csize, + mask=(c_off - 1) > -1 and c_off < chunk_size, + other=0.0) + dA_cs -= dA_cs_boundary + + # - increment logical chunk index for every physical chunk + logical_chunk_idx += 1 + else: + scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end + prev_seq_idx_chunk_end = seq_idx_chunk_end + + scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0) states = scale * states + new_states if c < nchunks - 1: tl.store(out_ptrs, states, mask=offs_m < dim) @@ -136,28 +167,36 @@ def _state_passing_fwd_kernel( def _state_passing_fwd( states, - dA_chunk_cumsum, + dA_cumsum, initial_states=None, seq_idx=None, chunk_size=None, out_dtype=None, is_cont_batched=False, + chunk_offsets=None, ): batch, nchunks, nheads, dim = states.shape - assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) + if chunk_size is None: + chunk_size = dA_cumsum.shape[-1] + else: + assert chunk_size == dA_cumsum.shape[-1] + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) if initial_states is not None: if is_cont_batched: # - if cu_seqlens is provided, then the initial states # are used for continuous batching. In which case we # require seq_idx to be provided - assert seq_idx is not None, "" + assert seq_idx is not None, "seq_idx must be provided for continuous batching" + # - we also need chunk_offsets to be provided, to account + # for computation of dA_cumsum from the start of the + # sequence + assert chunk_offsets is not None, "chunk_offsets must be provided for continuous batching" else: # - this is the regular batching case, where initial # states are used are for each example of the batch. 
assert initial_states.shape == (batch, nheads, dim) if seq_idx is not None: - assert chunk_size is not None seqlen = seq_idx.shape[-1] assert seq_idx.shape == (batch, seqlen) out_dtype = states.dtype if out_dtype is None else out_dtype @@ -173,13 +212,15 @@ def _state_passing_fwd( states, out, final_states, - dA_chunk_cumsum, + dA_cumsum, initial_states, seq_idx, + chunk_offsets, + len(chunk_offsets) if chunk_offsets is not None else 0, dim, nchunks, seqlen if seq_idx is not None else 0, - chunk_size if seq_idx is not None else 0, + chunk_size, states.stride(0), states.stride(1), states.stride(2), @@ -191,9 +232,10 @@ def _state_passing_fwd( final_states.stride(0), final_states.stride(1), final_states.stride(2), - dA_chunk_cumsum.stride(0), - dA_chunk_cumsum.stride(2), - dA_chunk_cumsum.stride(1), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), *((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2)) if initial_states is not None else (0, 0, 0)), diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py new file mode 100644 index 0000000000000..a05716190365f --- /dev/null +++ b/vllm/model_executor/layers/mla.py @@ -0,0 +1,158 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import Optional + +import torch + +from vllm.attention import Attention +from vllm.config import CacheConfig +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.quantization import QuantizationConfig + + +@dataclass +class MLAModules: + """Modules used in MLA. + """ + kv_a_layernorm: torch.nn.Module + kv_b_proj: torch.nn.Module + rotary_emb: torch.nn.Module + o_proj: torch.nn.Module + fused_qkv_a_proj: Optional[torch.nn.Module] + kv_a_proj_with_mqa: Optional[torch.nn.Module] + q_a_layernorm: Optional[torch.nn.Module] + q_b_proj: Optional[torch.nn.Module] + q_proj: Optional[torch.nn.Module] + + +@CustomOp.register("multi_head_latent_attention") +class MultiHeadLatentAttention(CustomOp): + """MLA layer registered as CustomOp. + Note that currently MLA ignores the enable/disable mechanism of CustomOp + because there is only one in-tree implementation in forward_native. + TODO: implement this with a new PluggableLayer mechanism. + + This class takes positions and hidden_states as input. + The input tensors can either contain prefill tokens or decode tokens. + The class does the following: + + 1. MLA Preprocess. + 2. Perform multi-head attention to prefill tokens and + multi-query attention to decode tokens separately. + 3. Return the output tensor. 
+ """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + scale: float, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + mla_modules: MLAModules, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.num_heads = num_heads + self.fused_qkv_a_proj = mla_modules.fused_qkv_a_proj + self.kv_a_proj_with_mqa = mla_modules.kv_a_proj_with_mqa + self.q_a_layernorm = mla_modules.q_a_layernorm + self.q_b_proj = mla_modules.q_b_proj + self.q_proj = mla_modules.q_proj + self.kv_a_layernorm = mla_modules.kv_a_layernorm + self.kv_b_proj = mla_modules.kv_b_proj + self.rotary_emb = mla_modules.rotary_emb + self.o_proj = mla_modules.o_proj + + # In the MLA backend, kv_cache includes both k_c and + # pe (i.e. decoupled position embeddings). In particular, + # the concat_and_cache_mla op requires + # k_c.size(1) + k_pe.size(1) == kv_cache.size(2) + # i.e. + # kv_lora_rank + qk_rope_head_dim == head_size + self.mla_attn = Attention( + num_heads=self.num_heads, + head_size=self.kv_lora_rank + self.qk_rope_head_dim, + scale=scale, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_mla=True, + # MLA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + qk_head_dim=self.qk_head_dim, + v_head_dim=self.v_head_dim, + kv_b_proj=self.kv_b_proj, + ) + + self.prefix = prefix + self.debug_layer_idx = int(self.prefix.split(".")[-2]) + + def forward_native( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + q_c = None + kv_lora = None + + if self.q_lora_rank is not None: + assert self.fused_qkv_a_proj is not None, \ + "fused_qkv_a_proj is required when q_lora_rank is not None" + assert self.q_a_layernorm is not None, \ + "q_a_layernorm is required when q_lora_rank is not None" + assert self.q_b_proj is not None, \ + "q_b_proj is required when q_lora_rank is not None" + qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] + q_c, kv_lora = qkv_lora.split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + dim=-1, + ) + q_c = self.q_a_layernorm(q_c) + q = self.q_b_proj(q_c)[0] + else: + assert self.kv_a_proj_with_mqa is not None, \ + "kv_a_proj_with_mqa is required when q_lora_rank is None" + assert self.q_proj is not None, \ + "q_proj is required when q_lora_rank is None" + kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0] + q = self.q_proj(hidden_states)[0] + + kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], + dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c) + + q = q.view(-1, self.num_heads, self.qk_head_dim) + # Add head dim of 1 to k_pe + k_pe = k_pe.unsqueeze(1) + + q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb( + positions, q[..., self.qk_nope_head_dim:], k_pe) + + attn_out = self.mla_attn( + q, + kv_c_normed, + k_pe, + output_shape=(hidden_states.shape[0], + self.num_heads * self.v_head_dim)) + return self.o_proj(attn_out)[0] + + def forward_cuda(self, *args, **kwargs): + return self.forward_native(*args, **kwargs) 
diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 62b3ee1abaca8..4a97438b1bb2c 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -5,15 +5,16 @@ from collections.abc import Mapping, Set from dataclasses import dataclass from enum import IntEnum from itertools import groupby -from typing import Callable, Optional, TypeVar, Union, cast +from typing import Callable, Optional, TypeVar, Union import torch import torch.nn as nn import torch.nn.functional as F from transformers import PretrainedConfig -from vllm.config import ModelConfig, PoolerConfig +from vllm.config import ModelConfig, PoolerConfig, get_current_vllm_config from vllm.logger import init_logger +from vllm.model_executor.models.adapters import _load_st_projector from vllm.pooling_params import PoolingParams from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput from vllm.tasks import PoolingTask @@ -362,14 +363,13 @@ class PoolerIdentity(PoolerActivation): class PoolerNormalize(PoolerActivation): def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: - x = F.normalize(pooled_data.float(), p=2, dim=-1) - return x.to(pooled_data.dtype) + return F.normalize(pooled_data, p=2, dim=-1) class PoolerMultiLabelClassify(PoolerActivation): def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: - return F.sigmoid(pooled_data.float()).to(pooled_data.dtype) + return F.sigmoid(pooled_data) class PoolerClassify(PoolerActivation): @@ -378,7 +378,6 @@ class PoolerClassify(PoolerActivation): super().__init__() if static_num_labels: - from vllm.config import get_current_vllm_config vllm_config = get_current_vllm_config() self.num_labels = getattr(vllm_config.model_config.hf_config, "num_labels", 0) @@ -394,9 +393,9 @@ class PoolerClassify(PoolerActivation): pooled_data.shape[-1]) if num_labels < 2: - return F.sigmoid(pooled_data.float()).to(pooled_data.dtype) + return F.sigmoid(pooled_data) - return F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype) + return F.softmax(pooled_data, dim=-1) class LambdaPoolerActivation(PoolerActivation): @@ -428,12 +427,11 @@ class EmbeddingPoolerHead(PoolerHead): super().__init__(activation=PoolerNormalize()) # Load ST projector if available - from vllm.config import get_current_vllm_config - from vllm.model_executor.models.adapters import _load_st_projector vllm_config = get_current_vllm_config() - self.projector = _load_st_projector( + self.projector: Optional[nn.Module] = _load_st_projector( vllm_config.model_config) if vllm_config else None + self.head_dtype = vllm_config.model_config.head_dtype def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], pooling_metadata: PoolingMetadata): @@ -442,16 +440,11 @@ class EmbeddingPoolerHead(PoolerHead): pooled_data = torch.stack(pooled_data) # pooled_data shape: [batchsize, hidden_dimension] + pooled_data = pooled_data.to(self.head_dtype) + # Apply ST projector if self.projector is not None: - projector = cast(nn.Module, self.projector) - - def _proj(x: torch.Tensor) -> torch.Tensor: - orig_dtype = x.dtype - y = projector(x.to(torch.float32)) - return y.to(orig_dtype) - - pooled_data = _proj(pooled_data) + pooled_data = self.projector(pooled_data) # pooled_data shape: [batchsize, embedding_dimension] pooling_params = get_pooling_params(pooling_metadata) @@ -494,8 +487,17 @@ class RewardPoolerHead(PoolerHead): def __init__(self) -> None: super().__init__(activation=PoolerClassify(static_num_labels=False)) + vllm_config = 
get_current_vllm_config() + self.head_dtype = vllm_config.model_config.head_dtype + def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], pooling_metadata: PoolingMetadata): + + if isinstance(pooled_data, list): + pooled_data = [p.to(self.head_dtype) for p in pooled_data] + else: + pooled_data = pooled_data.to(self.head_dtype) + pooling_params = get_pooling_params(pooling_metadata) # for softmax @@ -633,9 +635,14 @@ class ClassifierPooler(Pooler): ) -> None: super().__init__() + vllm_config = get_current_vllm_config() + self.pooling = pooling self.classifier = classifier self.act_fn = act_fn or PoolerClassify() + self.logit_bias: Optional[ + float] = vllm_config.model_config.pooler_config.logit_bias + self.head_dtype = vllm_config.model_config.head_dtype def get_supported_tasks(self) -> Set[PoolingTask]: return {"classify", "score"} @@ -650,10 +657,15 @@ class ClassifierPooler(Pooler): pooled_data = torch.stack(pooled_data) # pooled_data shape: [batchsize, hidden_size] + pooled_data = pooled_data.to(self.head_dtype) + if self.classifier is not None: pooled_data = self.classifier(pooled_data) # pooled_data shape: [batchsize, num_labels] + if self.logit_bias is not None: + pooled_data -= self.logit_bias + pooling_params = get_pooling_params(pooling_metadata) flags = [p.activation for p in pooling_params] @@ -714,3 +726,7 @@ class DispatchPooler(Pooler): offset += num_items return PoolerOutput(outputs) + + def extra_repr(self) -> str: + s = f"supported_task={self.get_supported_tasks()}" + return s diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index d73fcf368f261..8cac47b5a39a3 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -26,7 +26,6 @@ QuantizationMethods = Literal[ "bitsandbytes", "hqq", "experts_int8", - "neuron_quant", "ipex", "quark", "moe_wna16", @@ -108,7 +107,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config from .moe_wna16 import MoeWNA16Config from .mxfp4 import Mxfp4Config - from .neuron_quant import NeuronQuantConfig from .petit import PetitNvFp4Config from .ptpc_fp8 import PTPCFp8Config from .rtn import RTNConfig @@ -135,7 +133,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "ptpc_fp8": PTPCFp8Config, "hqq": HQQMarlinConfig, "experts_int8": ExpertsInt8Config, - "neuron_quant": NeuronQuantConfig, "ipex": IPEXConfig, "quark": QuarkConfig, "moe_wna16": MoeWNA16Config, diff --git a/vllm/model_executor/layers/quantization/auto_round.py b/vllm/model_executor/layers/quantization/auto_round.py index fb285413ba9ef..1ca92273430dd 100644 --- a/vllm/model_executor/layers/quantization/auto_round.py +++ b/vllm/model_executor/layers/quantization/auto_round.py @@ -327,6 +327,8 @@ class AutoRoundConfig(QuantizationConfig): if isinstance(layer, FusedMoE): if use_marlin: + return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe) + else: from vllm.model_executor.layers.quantization.moe_wna16 import ( MoeWNA16Config) @@ -339,7 +341,6 @@ class AutoRoundConfig(QuantizationConfig): } return MoeWNA16Config.from_config(config).get_quant_method( layer, prefix) - return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe) if isinstance(layer, (LinearBase, ParallelLMHead)): if use_marlin: diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 
5fcfbe7ce4f40..227a41a572232 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch from torch.nn import Parameter @@ -9,8 +9,10 @@ from torch.nn import Parameter import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, + FusedMoEQuantConfig) from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported, + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, UnquantizedFusedMoEMethod) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod, @@ -483,6 +485,10 @@ class AWQMoEMethod(FusedMoEMethodBase): if hasattr(layer, "w2_bias") and layer.w2_bias is not None: layer.w2_bias.data = marlin_permute_bias(layer.w2_bias) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + return None + def apply( self, layer: torch.nn.Module, @@ -505,7 +511,7 @@ class AWQMoEMethod(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: diff --git a/vllm/model_executor/layers/quantization/awq_triton.py b/vllm/model_executor/layers/quantization/awq_triton.py index ebc526d6db2f9..2e8894436a985 100644 --- a/vllm/model_executor/layers/quantization/awq_triton.py +++ b/vllm/model_executor/layers/quantization/awq_triton.py @@ -19,7 +19,7 @@ def awq_dequantize_kernel( num_rows, # input num rows in qweight BLOCK_SIZE_X: tl.constexpr, BLOCK_SIZE_Y: tl.constexpr): - # Setup the pids. + # Set up the pids. pid_x = tl.program_id(axis=0) pid_y = tl.program_id(axis=1) diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index 4a43351260e9f..6fd94afbe5566 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -128,7 +128,7 @@ class QuantizationConfig(ABC): @staticmethod def get_from_keys_or(config: dict[str, Any], keys: list[str], default: Any) -> Any: - """Get a optional value from the model's quantization config.""" + """Get an optional value from the model's quantization config.""" try: return QuantizationConfig.get_from_keys(config, keys) except ValueError: diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py index 39bd34d351f61..d05c0c0d5473c 100644 --- a/vllm/model_executor/layers/quantization/bitblas.py +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -202,7 +202,7 @@ class BitBLASLinearMethod(LinearMethodBase): output_size: int, params_dtype: torch.dtype, **extra_weight_attrs, - ): + ) -> None: """Creates quantized weights for use in linear operations. 
The function initializes and returns a dictionary containing quantized @@ -211,7 +211,7 @@ class BitBLASLinearMethod(LinearMethodBase): Args: input_size_per_partition: The size of the input partition. - output_size_per_partition: The size of the output partition. + output_partition_sizes: List of output partition sizes. input_size: The total size of the input (unused). output_size: The total size of the output (unused). params_dtype: @@ -222,9 +222,9 @@ class BitBLASLinearMethod(LinearMethodBase): scales ('scales'), and zeros ('zeros'). Raises: - ValueError: If `params_dtype` is not `torch.float16` or if the - input size per partition is not divisible by the group size in - `quant_config`. + ValueError: If `params_dtype` is not `torch.float16` or if the input + size per partition is not divisible by the group size + in `quant_config`. """ del input_size, output_size # Unused arguments. weight_loader = extra_weight_attrs["weight_loader"] diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 9713757df9b07..650dab8df87e3 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -6,8 +6,9 @@ from typing import Any, Callable, Optional, Union import torch from packaging import version +from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, + FusedMoEQuantConfig) from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, - FusedMoEConfig, FusedMoEMethodBase) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod, @@ -452,6 +453,10 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): **extra_weight_attrs, ) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + return None + def apply( self, layer: torch.nn.Module, @@ -474,7 +479,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: from vllm.model_executor.layers.fused_moe import fused_experts assert self.fused_experts is None @@ -509,6 +514,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map, + quant_config=self.moe_quant_config, ) def _create_weights_4bit( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index b07bf675ca47d..d6550dd16892f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -12,7 +12,6 @@ from compressed_tensors.quantization import (QuantizationArgs, QuantizationStrategy, QuantizationType) from compressed_tensors.transform import TransformConfig -from pydantic import BaseModel import vllm.envs as envs from vllm.logger import init_logger @@ -63,7 +62,7 @@ class CompressedTensorsConfig(QuantizationConfig): sparsity_ignore_list: list[str], kv_cache_scheme: Optional[dict[str, Any]] = None, config: Optional[dict[str, Any]] = None, - transform_config: Optional[TransformConfig] = None, + transform_config: Optional[dict[str, Any]] = None, ): 
super().__init__() self.ignore = ignore @@ -75,7 +74,7 @@ class CompressedTensorsConfig(QuantizationConfig): self.sparsity_ignore_list = sparsity_ignore_list self.config = config - if transform_config is not None: + if transform_config: self.transform_config = TransformConfig.model_validate( transform_config) else: @@ -129,7 +128,7 @@ class CompressedTensorsConfig(QuantizationConfig): # choose transform method if any((input_tfms, output_tfms)): return CompressedTensorsLinearTransformMethod.from_schemes( - quant_method, input_tfms, output_tfms) + quant_method, quant_scheme, input_tfms, output_tfms) else: return quant_method @@ -268,7 +267,8 @@ class CompressedTensorsConfig(QuantizationConfig): else: return False - def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel): + def _is_fp4a4_nvfp4(self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs): if weight_quant is None or input_quant is None: return False @@ -288,8 +288,8 @@ class CompressedTensorsConfig(QuantizationConfig): return (is_tensor_group_quant and is_float_type and is_4_bits and is_group_size_16 and is_symmetric) - def _is_fp4a16_nvfp4(self, weight_quant: BaseModel, - input_quant: BaseModel): + def _is_fp4a16_nvfp4(self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs): is_weight_only = weight_quant is not None and input_quant is None is_tensor_group_quant = ( @@ -303,8 +303,8 @@ class CompressedTensorsConfig(QuantizationConfig): return (is_weight_only and is_tensor_group_quant and is_float_type and is_4_bits and is_group_size_16 and is_symmetric) - def _is_static_tensor_w8a8(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_static_tensor_w8a8(self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs) -> bool: is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 weight_strategy = ( weight_quant.strategy == QuantizationStrategy.TENSOR.value @@ -317,8 +317,8 @@ class CompressedTensorsConfig(QuantizationConfig): # Only symmetric weight quantization supported. return is_8_bits and is_tensor and weight_quant.symmetric and is_static - def _is_dynamic_token_w8a8(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_dynamic_token_w8a8(self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs) -> bool: is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 weight_strategy = ( weight_quant.strategy == QuantizationStrategy.TENSOR.value @@ -331,8 +331,8 @@ class CompressedTensorsConfig(QuantizationConfig): # Only symmetric weight quantization supported. return is_8_bits and is_token and weight_quant.symmetric and is_dynamic - def _is_dynamic_token_w4a8_int(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_dynamic_token_w4a8_int(self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs) -> bool: is_weight_4_bits = weight_quant.num_bits == 4 is_activation_8_bits = input_quant.num_bits == 8 weight_strategy = ( @@ -347,8 +347,8 @@ class CompressedTensorsConfig(QuantizationConfig): return (is_weight_4_bits and is_activation_8_bits and is_token and weight_quant.symmetric and is_dynamic) - def _is_fp8_w8a8(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_fp8_w8a8(self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs) -> bool: # Confirm weights and activations quantized. 
if weight_quant is None or input_quant is None: return False @@ -358,11 +358,12 @@ class CompressedTensorsConfig(QuantizationConfig): and input_quant.type == QuantizationType.FLOAT) is_symmetric_weight = weight_quant.symmetric is_static_weight = not weight_quant.dynamic - is_per_tensor_or_channel_weight = (weight_quant.strategy in [ - QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL + is_tensor_or_channel_or_block_weight = (weight_quant.strategy in [ + QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL, + QuantizationStrategy.BLOCK ]) if not (is_floating_point and is_symmetric_weight and is_static_weight - and is_per_tensor_or_channel_weight): + and is_tensor_or_channel_or_block_weight): return False # Dynamic quantization is always supported if weights supported. @@ -375,8 +376,8 @@ class CompressedTensorsConfig(QuantizationConfig): input_quant.strategy == QuantizationStrategy.TENSOR) return is_symmetric_activation and is_per_tensor_activation - def _is_fp8_w4a8(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_fp8_w4a8(self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs) -> bool: if not weight_quant or not input_quant: return False is_weight_4_bits = weight_quant.num_bits == 4 @@ -392,24 +393,24 @@ class CompressedTensorsConfig(QuantizationConfig): return (is_weight_4_bits and is_activation_8_bits and is_token and is_symmetric and is_dynamic) - def _is_fp8_w4a8_sm90(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_fp8_w4a8_sm90(self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs) -> bool: return (self._check_scheme_supported(90, error=False, match_exact=True) and self._is_fp8_w4a8(weight_quant, input_quant)) - def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_fp8_w8a8_sm90(self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs) -> bool: return (self._check_scheme_supported(90, error=False, match_exact=True) and self._is_fp8_w8a8(weight_quant, input_quant)) - def _is_fp8_w8a8_sm100(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_fp8_w8a8_sm100(self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs) -> bool: return (self._check_scheme_supported( 100, error=False, match_exact=True) and self._is_fp8_w8a8(weight_quant, input_quant)) - def _is_fp8_w8a16(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_fp8_w8a16(self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs) -> bool: # Confirm weights quantized. if weight_quant is None: return False @@ -421,18 +422,19 @@ class CompressedTensorsConfig(QuantizationConfig): # Confirm weight scheme is supported. is_symmetric_weight = weight_quant.symmetric is_static_weight = not weight_quant.dynamic - is_per_tensor_or_channel_weight = (weight_quant.strategy in [ - QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL + is_tensor_or_channel_or_block_weight = (weight_quant.strategy in [ + QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL, + QuantizationStrategy.BLOCK ]) if not (is_symmetric_weight and is_static_weight # noqa: SIM103 - and is_per_tensor_or_channel_weight): + and is_tensor_or_channel_or_block_weight): return False # All conditions satisfied. 
return True - def _is_wNa16_group_channel(self, weight_quant: BaseModel, - input_quant: BaseModel) -> bool: + def _is_wNa16_group_channel(self, weight_quant: QuantizationArgs, + input_quant: QuantizationArgs) -> bool: input_quant_none = input_quant is None is_channel_group = ( weight_quant.strategy == QuantizationStrategy.CHANNEL.value @@ -443,8 +445,8 @@ class CompressedTensorsConfig(QuantizationConfig): def _get_scheme_from_parts( self, - weight_quant: BaseModel, - input_quant: BaseModel, + weight_quant: QuantizationArgs, + input_quant: QuantizationArgs, format: Optional[str] = None) -> "CompressedTensorsScheme": # use the per-layer format if defined, otherwise, use global format @@ -496,7 +498,7 @@ class CompressedTensorsConfig(QuantizationConfig): CompressedTensorsW8A8Fp8.get_min_capability(), error=False) if is_fp8_w8a8_supported: return CompressedTensorsW8A8Fp8( - strategy=weight_quant.strategy, + weight_quant=weight_quant, is_static_input_scheme=(input_quant and not input_quant.dynamic)) else: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 1deb019298d0e..e04fce591a265 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -3,7 +3,7 @@ import enum from enum import Enum -from typing import Callable, Optional +from typing import Callable, Optional, Union import torch from compressed_tensors import CompressionFormat @@ -16,8 +16,11 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase, - FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, - FusedMoeWeightScaleSupported) + FusedMoEPermuteExpertsUnpermute, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, + int4_w4a16_moe_quant_config, int8_w8a8_moe_quant_config, + int8_w8a16_moe_quant_config, nvfp4_moe_quant_config) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( is_valid_flashinfer_cutlass_fused_moe) from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa @@ -122,7 +125,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): return CompressedTensorsWNA16MarlinMoEMethod( quant_config, layer.moe_config) elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): - return CompressedTensorsW4A4MoeMethod(layer.moe_config, layer) + return CompressedTensorsW4A4MoeMethod(layer.moe_config) elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant) or quant_config._is_fp8_w8a8(weight_quant, input_quant)): @@ -138,7 +141,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): - def __init__(self, moe: FusedMoEConfig, layer: torch.nn.Module): + def __init__(self, moe: FusedMoEConfig): from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 detect_nvfp4_moe_support) super().__init__(moe) @@ -147,7 +150,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): self.allow_flashinfer = _nvfp4.allow_flashinfer self.use_marlin = _nvfp4.use_marlin self.group_size = 16 - 
self.layer = layer def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -305,37 +307,46 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): (layer.w2_input_global_scale), requires_grad=False) def maybe_make_prepare_finalize( - self, - moe: FusedMoEConfig, - ) -> Optional[mk.FusedMoEPrepareAndFinalize]: - if not self.allow_flashinfer: - return super().maybe_make_prepare_finalize(moe) + self) -> Optional[mk.FusedMoEPrepareAndFinalize]: + if self.use_marlin: + return None + elif not self.allow_flashinfer: + return super().maybe_make_prepare_finalize() prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( - moe, - a1_gscale=self.layer.w13_input_scale_quant, - ) + self.moe) logger.debug_once("%s", prepare_finalize.__class__.__name__) return prepare_finalize def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, - moe: FusedMoEConfig, layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: + assert self.moe_quant_config is not None """Return the appropriate GEMM experts implementation.""" experts = select_nvfp4_gemm_impl( - moe, - g1_alphas=self.layer.g1_alphas, - g2_alphas=self.layer.g2_alphas, - a1_gscale=self.layer.w13_input_scale_quant, - a2_gscale=self.layer.w2_input_scale_quant, + self.moe, + self.moe_quant_config, allow_flashinfer=self.allow_flashinfer, ) logger.debug_once("Using %s", experts.__class__.__name__) return experts + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + if self.use_marlin: + return None + + return nvfp4_moe_quant_config( + g1_alphas=layer.g1_alphas, + g2_alphas=layer.g2_alphas, + a1_gscale=layer.w13_input_scale_quant, + a2_gscale=layer.w2_input_scale_quant, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + ) + def apply( self, layer: torch.nn.Module, @@ -358,9 +369,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - assert self.fused_experts is None - + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: raise NotImplementedError("EPLB not supported for " "`CompressedTensorsW4A4MoeMethod` yet.") @@ -380,7 +389,12 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): indices_type=self.topk_indices_dtype, ) + # + # Note: the order here is important. self.fused_experts can override + # flashinfer cutlass, cutlass fp4 or fused_experts but not marlin. + # if self.use_marlin: + assert self.fused_experts is None return torch.ops.vllm.fused_marlin_moe( x, layer.w13_weight, @@ -398,7 +412,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, activation=activation, - expert_map=expert_map) + expert_map=expert_map, + workspace=layer.workspace) assert activation == "silu", "Only SiLU activation is supported." 
@@ -418,11 +433,10 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, apply_router_weight_on_input=apply_router_weight_on_input, ) + # FlashInfer fused experts path elif self.allow_flashinfer: from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 flashinfer_cutlass_moe_fp4) @@ -431,51 +445,46 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): x, layer.w13_weight, layer.w2_weight), ( "Flashinfer CUTLASS Fused MoE not applicable!") + assert self.moe_quant_config is not None + return flashinfer_cutlass_moe_fp4( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, + quant_config=self.moe_quant_config, inplace=False, # TODO(shuw): fix later, now output is high prec activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - g1_alphas=layer.g1_alphas, - g2_alphas=layer.g2_alphas, - a1_gscale=layer.w13_input_scale_quant, - a2_gscale=layer.w2_input_scale_quant, apply_router_weight_on_input=apply_router_weight_on_input, ) + else: + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_fp4) - assert expert_map is None, ("Expert Parallelism / expert_map " - "is currently not supported for " - "CompressedTensorsW4A4MoeMethod.") - from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp4) + assert expert_map is None, ("Expert Parallelism / expert_map " + "is currently not supported for " + "CompressedTensorsW4A4MoeMethod.") + assert self.moe_quant_config is not None - # Cutlass moe takes in activations in BF16/Half precision - # and fp4 quantized weights loaded from the checkpoint - return cutlass_moe_fp4( - a=x, - w1_fp4=layer.w13_weight, - w2_fp4=layer.w2_weight, - w1_blockscale=layer.w13_weight_scale, - w2_blockscale=layer.w2_weight_scale, - g1_alphas=layer.g1_alphas, - g2_alphas=layer.g2_alphas, - a1_gscale=layer.w13_input_scale_quant, - a2_gscale=layer.w2_input_scale_quant, - topk_weights=topk_weights, - topk_ids=topk_ids, - m=x.shape[0], - n=layer.w2_weight.shape[2] * 2, - k=x.shape[1], - e=layer.w13_weight.shape[0], - apply_router_weight_on_input=apply_router_weight_on_input).to( - x.dtype) + # Cutlass moe takes in activations in BF16/Half precision + # and fp4 quantized weights loaded from the checkpoint + return cutlass_moe_fp4( + a=x, + w1_fp4=layer.w13_weight, + w2_fp4=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + quant_config=self.moe_quant_config, + apply_router_weight_on_input=apply_router_weight_on_input, + # TODO(bnell): derive these from arguments + m=x.shape[0], + n=layer.w2_weight.shape[2] * 2, + k=x.shape[1], + e=layer.w13_weight.shape[0], + ).to(x.dtype) class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): @@ -693,16 +702,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) - self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts elif self.use_marlin: prepare_moe_fp8_layer_for_marlin(layer, False) # Activations not quantized for marlin. 
del layer.w13_input_scale del layer.w2_input_scale - self.fused_experts_func = None - else: - from vllm.model_executor.layers.fused_moe import fused_experts - self.fused_experts_func = fused_experts if self.use_cutlass: device = layer.w13_weight.device @@ -723,11 +727,20 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): device=device, dtype=torch.int64) + def maybe_make_prepare_finalize( + self) -> Optional[mk.FusedMoEPrepareAndFinalize]: + if self.use_marlin or self.rocm_aiter_moe_enabled: + return None + else: + return super().maybe_make_prepare_finalize() + def select_gemm_impl( - self, prepare_finalize: FusedMoEPrepareAndFinalize, - moe: FusedMoEConfig, - layer: torch.nn.Module) -> FusedMoEPermuteExpertsUnpermute: + self, + prepare_finalize: mk.FusedMoEPrepareAndFinalize, + layer: torch.nn.Module, + ) -> FusedMoEPermuteExpertsUnpermute: # cutlass path + assert self.moe_quant_config is not None if self.use_cutlass: from vllm.model_executor.layers.fused_moe import ( CutlassBatchedExpertsFp8, CutlassExpertsFp8) @@ -741,26 +754,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): logger.debug("CutlassBatchedExpertsFp8(%s)", self.__class__.__name__) experts = CutlassBatchedExpertsFp8( - moe.num_local_experts, + self.moe.num_local_experts, num_dispatchers, - moe.in_dtype, - self.input_quant.strategy == QuantizationStrategy.TOKEN, - self.weight_quant.strategy == QuantizationStrategy.CHANNEL, + self.moe.in_dtype, ab_strides1=self.ab_strides1_c_strides2, ab_strides2=self.ab_strides2, c_strides1=self.c_strides1, c_strides2=self.ab_strides1_c_strides2, + quant_config=self.moe_quant_config, ) else: logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__) experts = CutlassExpertsFp8( - moe.in_dtype, - self.input_quant.strategy == QuantizationStrategy.TOKEN, - self.weight_quant.strategy == QuantizationStrategy.CHANNEL, + self.moe.in_dtype, ab_strides1=self.ab_strides1_c_strides2, ab_strides2=self.ab_strides2, c_strides1=self.c_strides1, c_strides2=self.ab_strides1_c_strides2, + quant_config=self.moe_quant_config, ) self.disable_expert_map = (num_dispatchers > 1 @@ -775,29 +786,40 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): assert not self.rocm_aiter_moe_enabled and not self.use_marlin - logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__) - if (prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts): max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank( ) assert max_num_tokens_per_rank is not None + logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__) return BatchedTritonExperts( max_num_tokens=max_num_tokens_per_rank, num_dispatchers=prepare_finalize.num_dispatchers(), - use_fp8_w8a8=True, - block_shape=self.quant_config.weight_block_size, - per_act_token_quant=( - self.input_quant.strategy == QuantizationStrategy.TOKEN), + quant_config=self.moe_quant_config, ) else: - return TritonExperts( - use_fp8_w8a8=True, - block_shape=self.quant_config.weight_block_size, - per_act_token_quant=( - self.input_quant.strategy == QuantizationStrategy.TOKEN), - ) + logger.debug("TritonExperts(%s)", self.__class__.__name__) + return TritonExperts(self.moe_quant_config) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + if self.use_marlin: + return None + + per_act_token = ( + self.input_quant.strategy == QuantizationStrategy.TOKEN) + per_channel_quant = ( + self.weight_quant.strategy == QuantizationStrategy.CHANNEL) + + 
return fp8_w8a8_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + per_act_token_quant=per_act_token, + per_out_ch_quant=per_channel_quant, + ) def apply( self, @@ -821,7 +843,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: raise NotImplementedError( "EPLB not supported for " @@ -842,90 +864,17 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): indices_type=self.topk_indices_dtype, ) - # cutlass path - if self.use_cutlass: - per_act_token = ( - self.input_quant.strategy == QuantizationStrategy.TOKEN) - per_channel_quant = ( - self.weight_quant.strategy == QuantizationStrategy.CHANNEL) + per_act_token = ( + self.input_quant.strategy == QuantizationStrategy.TOKEN) + per_channel_quant = ( + self.weight_quant.strategy == QuantizationStrategy.CHANNEL) - # small-batch fallback on SM100 - if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8: - from vllm.model_executor.layers.fused_moe import fused_experts - return fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - use_fp8_w8a8=True, - per_channel_quant=per_channel_quant, - global_num_experts=global_num_experts, - expert_map=None if self.disable_expert_map else expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale) - - if self.fused_experts is None: - from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp8) - return cutlass_moe_fp8( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - per_act_token=per_act_token, - activation=activation, - global_num_experts=global_num_experts, - expert_map=None if self.disable_expert_map else expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - ab_strides1=self.ab_strides1_c_strides2, - ab_strides2=self.ab_strides2, - c_strides1=self.c_strides1, - c_strides2=self.ab_strides1_c_strides2, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - ) - else: - return self.fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - activation=activation, - global_num_experts=global_num_experts, - expert_map=None if self.disable_expert_map else expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - ) - - if self.rocm_aiter_moe_enabled: - return self.rocm_aiter_fused_experts_func( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - use_fp8_w8a8=True, - per_channel_quant=self.weight_quant.strategy == - QuantizationStrategy.CHANNEL, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - expert_map=expert_map) + # + # Note: the order here is important. 
self.fused_experts can override + # cutlass fp8 or fused_experts but not marlin or rocm. + # if self.use_marlin: + assert self.fused_experts is None return torch.ops.vllm.fused_marlin_moe( x, layer.w13_weight, @@ -941,28 +890,98 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, activation=activation, - expert_map=expert_map) + expert_map=expert_map, + workspace=layer.workspace) - assert self.fused_experts_func is not None + elif self.rocm_aiter_moe_enabled: + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 + rocm_aiter_fused_experts) + assert per_act_token == per_channel_quant + assert self.moe_quant_config is not None + assert self.fused_experts is None + return rocm_aiter_fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + quant_config=self.moe_quant_config, + ) - return self.fused_experts_func( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - use_fp8_w8a8=True, - per_channel_quant=self.weight_quant.strategy == - QuantizationStrategy.CHANNEL, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale) + elif self.fused_experts is not None: + return self.fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=None if self.disable_expert_map else expert_map, + ) + + # cutlass path + elif self.use_cutlass: + assert self.moe_quant_config is not None + + # small-batch fallback on SM100 + if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8: + from vllm.model_executor.layers.fused_moe import fused_experts + assert per_act_token == per_channel_quant + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=None if self.disable_expert_map else expert_map, + quant_config=self.moe_quant_config, + ) + else: + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_fp8) + assert per_act_token == per_channel_quant + assert self.moe_quant_config is not None + return cutlass_moe_fp8( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + quant_config=self.moe_quant_config, + activation=activation, + global_num_experts=global_num_experts, + expert_map=None if self.disable_expert_map else expert_map, + ab_strides1=self.ab_strides1_c_strides2, + ab_strides2=self.ab_strides2, + c_strides1=self.c_strides1, + c_strides2=self.ab_strides1_c_strides2, + ) + + else: + from vllm.model_executor.layers.fused_moe import fused_experts + assert per_act_token == per_channel_quant + assert self.moe_quant_config is not None + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + 
activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + quant_config=self.moe_quant_config, + ) class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): @@ -1048,6 +1067,16 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: pass + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + return int8_w8a8_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + per_act_token_quant=True, + ) + def apply( self, layer: torch.nn.Module, @@ -1070,7 +1099,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: @@ -1103,14 +1132,10 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): inplace=True, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, - use_int8_w8a8=True, - per_channel_quant=True, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale) + quant_config=self.moe_quant_config, + ) class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): @@ -1354,6 +1379,10 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): layer.workspace = marlin_make_workspace_new(device, 4) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + return None + def apply( self, layer: torch.nn.Module, @@ -1376,7 +1405,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: @@ -1585,6 +1614,20 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): layer.w2_weight_scale.transpose(1, 2).contiguous(), requires_grad=False) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + assert self.num_bits == 4 or self.num_bits == 8 + config_builder = (int4_w4a16_moe_quant_config if self.num_bits == 4 + else int8_w8a16_moe_quant_config) + + return config_builder( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + w1_zp=None, + w2_zp=None, + block_shape=[0, self.group_size], + ) + def apply( self, layer: torch.nn.Module, @@ -1607,7 +1650,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: @@ -1638,13 +1681,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): topk_ids=topk_ids, inplace=True, 
activation=activation, - use_int4_w4a16=self.num_bits == 4, - use_int8_w8a16=self.num_bits == 8, - global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - w1_zp=None, - w2_zp=None, - block_shape=[0, self.group_size]) + quant_config=self.moe_quant_config, + ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index d984e89d9e02a..d42ae22c51393 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -4,28 +4,41 @@ from typing import Callable, Optional import torch -from compressed_tensors.quantization import QuantizationStrategy +from compressed_tensors.quantization import (QuantizationArgs, + QuantizationStrategy) from torch.nn import Parameter from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + apply_fp8_block_linear, check_aiter_fp8_linear_support, + create_fp8_input_scale, create_fp8_scale_parameter, + create_fp8_weight_parameter, maybe_post_process_fp8_weight_block, + process_fp8_weight_block_strategy, process_fp8_weight_channel_strategy, + process_fp8_weight_tensor_strategy, validate_fp8_block_shape) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz, - requantize_with_max_scale) -from vllm.model_executor.parameter import (ChannelQuantScaleParameter, - ModelWeightParameter, + Fp8LinearOp, cutlass_block_fp8_supported, maybe_create_device_identity) +from vllm.model_executor.parameter import (BlockQuantScaleParameter, + ChannelQuantScaleParameter, PerTensorScaleParameter) -from vllm.platforms import current_platform __all__ = ["CompressedTensorsW8A8Fp8"] +strategy_to_parameter_type = { + QuantizationStrategy.BLOCK: BlockQuantScaleParameter, + QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter, + QuantizationStrategy.TENSOR: PerTensorScaleParameter, +} + class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): - def __init__(self, strategy: str, is_static_input_scheme: bool): - self.strategy = strategy + def __init__(self, weight_quant: QuantizationArgs, + is_static_input_scheme: bool): + self.weight_quant = weight_quant + self.strategy = weight_quant.strategy self.out_dtype = torch.get_default_dtype() self.is_static_input_scheme = is_static_input_scheme self.act_q_group_shape = GroupShape.PER_TENSOR \ @@ -34,61 +47,84 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): act_quant_static=self.is_static_input_scheme, act_quant_group_shape=self.act_q_group_shape) + self.weight_block_size = self.weight_quant.block_structure + self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() + self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() + @classmethod def get_min_capability(cls) -> int: # lovelace and up return 89 + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], input_size: int, + output_size: int, 
params_dtype: torch.dtype, + weight_loader: Callable, **kwargs): + maybe_create_device_identity() + + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + layer.weight_block_size = None + + if self.strategy == QuantizationStrategy.BLOCK: + assert self.weight_block_size is not None + layer.weight_block_size = self.weight_block_size + # Validate block quantization shapes + validate_fp8_block_shape(layer, input_size, output_size, + input_size_per_partition, + output_partition_sizes, + self.weight_block_size) + + # WEIGHT + weight = create_fp8_weight_parameter(output_size_per_partition, + input_size_per_partition, + weight_loader) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + weight_scale = create_fp8_scale_parameter( + strategy_to_parameter_type[self.strategy], output_partition_sizes, + input_size_per_partition, layer.weight_block_size, weight_loader) + layer.register_parameter("weight_scale", weight_scale) + + # INPUT SCALE + if self.is_static_input_scheme: + input_scale = create_fp8_input_scale(output_partition_sizes, + weight_loader) + layer.register_parameter("input_scale", input_scale) + def process_weights_after_loading(self, layer) -> None: - # If per tensor, when we have a fused module (e.g. QKV) with per - # tensor scales (thus N scales being passed to the kernel), - # requantize so we can always run per tensor if self.strategy == QuantizationStrategy.TENSOR: - max_w_scale, weight = requantize_with_max_scale( - weight=layer.weight, - weight_scale=layer.weight_scale, - logical_widths=layer.logical_widths, - ) + weight, weight_scale, input_scale = ( + process_fp8_weight_tensor_strategy( + layer.weight, layer.weight_scale, layer.logical_widths, + getattr(layer, 'input_scale', None))) + weight = weight.t() - if current_platform.is_fp8_fnuz(): - input_scale = getattr(layer, 'input_scale', None) - - weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( - weight=weight, - weight_scale=max_w_scale, - input_scale=input_scale) - if input_scale is not None: - layer.input_scale = Parameter(input_scale, - requires_grad=False) - - layer.weight = Parameter(weight.t(), requires_grad=False) - layer.weight_scale = Parameter(max_w_scale, requires_grad=False) - - # If channelwise, scales are already lined up, so just transpose. 
elif self.strategy == QuantizationStrategy.CHANNEL: - weight = layer.weight + weight, weight_scale, input_scale = ( + process_fp8_weight_channel_strategy( + layer.weight, layer.weight_scale, + getattr(layer, 'input_scale', None))) + weight = weight.t() - if current_platform.is_fp8_fnuz(): - input_scale = getattr(layer, 'input_scale', None) - - weight, weight_scale, input_scale = \ - normalize_e4m3fn_to_e4m3fnuz( - weight=weight, - weight_scale=layer.weight_scale, - input_scale=input_scale) - if input_scale is not None: - layer.input_scale = Parameter(input_scale, - requires_grad=False) - else: - weight_scale = layer.weight_scale.data - - layer.weight = Parameter(weight.t(), requires_grad=False) - # required by torch.compile to be torch.nn.Parameter - layer.weight_scale = Parameter(weight_scale, requires_grad=False) + elif self.strategy == QuantizationStrategy.BLOCK: + assert self.is_static_input_scheme is False + weight, weight_scale = process_fp8_weight_block_strategy( + layer.weight, layer.weight_scale) + input_scale = None else: raise ValueError(f"Unknown quantization strategy {self.strategy}") + # required by torch.compile to be torch.nn.Parameter + layer.weight = Parameter(weight.data, requires_grad=False) + layer.weight_scale = Parameter(weight_scale.data, requires_grad=False) + if input_scale is not None: + layer.input_scale = Parameter(input_scale.data, + requires_grad=False) + # INPUT SCALE if self.is_static_input_scheme and hasattr(layer, 'input_scale'): layer.input_scale = Parameter(layer.input_scale.max(), @@ -96,58 +132,23 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): else: layer.input_scale = None - def create_weights(self, layer: torch.nn.Module, - output_partition_sizes: list[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): - maybe_create_device_identity() - - output_size_per_partition = sum(output_partition_sizes) - layer.logical_widths = output_partition_sizes - - # WEIGHT - weight = ModelWeightParameter(data=torch.empty( - output_size_per_partition, - input_size_per_partition, - dtype=torch.float8_e4m3fn), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) - layer.register_parameter("weight", weight) - - # WEIGHT SCALE - # TODO: update create_xxx_parameter functions to return - # the newly added parameters - if self.strategy == QuantizationStrategy.CHANNEL: - weight_scale = ChannelQuantScaleParameter( - data=torch.empty((sum(output_partition_sizes), 1), - dtype=torch.float32), - output_dim=0, - weight_loader=weight_loader) - else: - assert self.strategy == QuantizationStrategy.TENSOR - weight_scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) - - # min requirement for fp8 kernels - weight_scale[:] = torch.finfo(torch.float32).min - layer.register_parameter("weight_scale", weight_scale) - - # INPUT SCALE - if self.is_static_input_scheme: - input_scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) - input_scale[:] = torch.finfo(torch.float32).min - layer.register_parameter("input_scale", input_scale) + if self.strategy == QuantizationStrategy.BLOCK: + maybe_post_process_fp8_weight_block( + layer, self.cutlass_block_fp8_supported) def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: + if layer.weight_block_size is not None: + return apply_fp8_block_linear( + layer, + 
input=x, + bias=bias, + cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, + use_aiter_and_is_supported=self.use_aiter_and_is_supported) + return self.fp8_linear.apply(input=x, weight=layer.weight, weight_scale=layer.weight_scale, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py index 2fc94b3c257e6..d098185146e41 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py @@ -12,6 +12,8 @@ from compressed_tensors.utils import is_match from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED, LinearMethodBase, QKVCrossParallelLinear) +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsScheme) from vllm.model_executor.layers.quantization.compressed_tensors.transform.module import ( # noqa: E501 HadamardTransform) from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501 @@ -26,14 +28,22 @@ class CompressedTensorsLinearTransformMethod(LinearMethodBase): @classmethod def from_schemes( - cls, quant_method: LinearMethodBase, input_tfms: dict[int, - TransformTuple], - output_tfms: dict[int, TransformTuple] + cls, + quant_method: LinearMethodBase, + quant_scheme: Optional[CompressedTensorsScheme], + input_tfms: dict[int, TransformTuple], + output_tfms: dict[int, TransformTuple], ) -> "CompressedTensorsLinearTransformMethod": + from vllm.model_executor.layers.quantization.compressed_tensors.transform.schemes.linear_qutlass_nvfp4 import ( # noqa: E501 + QutlassNvFP4LinearMethod, is_qutlass_fp4_scheme) + assert input_tfms or output_tfms - # TODO (@ksayers): implement QutlassLinearMethodNvFP4 - # hadacore and fwht can be selected by Transform module + if is_qutlass_fp4_scheme(quant_scheme, input_tfms): + return QutlassNvFP4LinearMethod(quant_method, input_tfms, + output_tfms) + + # hadacore or dense gemm is selected by Transform module return cls(quant_method, input_tfms, output_tfms) @@ -129,11 +139,12 @@ class CompressedTensorsLinearTransformMethod(LinearMethodBase): assert bias is None x = self.quant_method.apply(layer, x, bias) - # TODO (@ksayers): Write a triton kernel to do this in parallel + # In most cases, input transforms are preferred over output transforms + # (@ksayers): confirm that this is done concurrently if self.output_transform is not None: for part_id, (start, length) in enumerate(self.partition_ranges): x[:, start:start + length] = self.output_transform( - x[:, start:start + length], part_id=part_id) + x[:, start:start + length].contiguous(), part_id=part_id) return x diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py index b3be254717734..5e863354715e4 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py @@ -2,12 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from collections.abc import Hashable -from typing import Callable, Optional +from typing import Callable import torch -from compressed_tensors.transform import TransformLocation, TransformScheme +from compressed_tensors.transform import (TransformArgs, 
TransformLocation, + TransformScheme) from torch import Tensor +import vllm._custom_ops as ops from vllm.distributed.parallel_state import ( get_tensor_model_parallel_world_size) from vllm.model_executor.layers.linear import LinearBase @@ -28,16 +30,12 @@ class HadamardTransform(torch.nn.Module): transforms: dict[int, TransformTuple] # info parsed from transforms config weight: SharedWeightParameter # container for shared tensors - kernel: Callable # function used during application scales: dict[int, float] # hadamard scale, usually sqrt(matrix.size(0)) - def __init__(self, - transforms: dict[int, TransformTuple], - layer: torch.nn.Module, - weight_loader: Callable, + def __init__(self, transforms: dict[int, TransformTuple], + layer: torch.nn.Module, weight_loader: Callable, input_size_per_partition: int, - output_partition_sizes: list[int], - kernel: Optional[Callable] = None): + output_partition_sizes: list[int]): super().__init__() self.transforms = transforms self.scales = {} @@ -55,7 +53,7 @@ class HadamardTransform(torch.nn.Module): for part_index, (_scheme_name, scheme, args) in self.transforms.items(): output_size = output_partition_sizes[part_index] - weight_size = self._get_weight_size(layer, args.location, + weight_size = self._get_weight_size(layer, scheme, args, input_size, output_size) data_key = self._get_data_key(scheme, weight_size) @@ -69,9 +67,6 @@ class HadamardTransform(torch.nn.Module): # validate that shared tensors and schemes are correct self._validate_input_transforms() - # select kernel based on transform schemes - self.kernel = self._infer_kernel() if kernel is None else kernel - def process_weights_after_loading(self): for part_id in self.weight.partitions: data = self.weight.partitions[part_id].data @@ -83,39 +78,66 @@ class HadamardTransform(torch.nn.Module): # do not fold into weight in order to utilize FWHT self.scales[part_id] = 1 / math.sqrt(data.size(0)) - # FUTURE: avoid runtime tranpose by processing weights + # FUTURE: avoid runtime transpose by processing weights # prior to apply def forward(self, value: Tensor, part_id: int = 0) -> Tensor: if part_id not in self.weight.partitions: return value - weight = self.weight.partitions[part_id] - weight = weight if self.transforms[ - part_id].args.inverse else weight.T # linear := x(W.T) - scale = self.scales[part_id] - return self.kernel(self, value.to(weight.dtype), weight, None).to( - value.dtype) * scale + # use hadacore if possible + if self.transforms[part_id].scheme.type == "hadamard": + if self.transforms[part_id].scheme.head_dim is not None: + weight_size = self.transforms[part_id].scheme.head_dim + value = value.unflatten(-1, (-1, weight_size)) + value = ops.hadacore_transform(value) + value = value.flatten(-2, -1) + + return value + + # sylvester transforms are symmetric, inv => transpose => original + return ops.hadacore_transform(value) + + # fall back to dense + else: + weight = self.weight.partitions[part_id] + weight = weight if self.transforms[ + part_id].args.inverse else weight.T # linear := x(W.T) + scale = self.scales[part_id] + + if self.transforms[part_id].scheme.head_dim is not None: + value = value.unflatten(-1, (-1, weight.size(0))) + value = dispatch_unquantized_gemm()(self, value.to( + weight.dtype), weight, None).to(value.dtype) * scale + value = value.flatten(-2, -1) + + return value + + return dispatch_unquantized_gemm()(self, value.to( + weight.dtype), weight, None).to(value.dtype) * scale def _get_data_key(self, scheme: TransformScheme, weight_size: int) -> Hashable: return 
(id(scheme), weight_size) - def _get_weight_size(self, layer: torch.nn.Module, - location: TransformLocation, input_size: int, + def _get_weight_size(self, layer: torch.nn.Module, scheme: TransformScheme, + args: TransformArgs, input_size: int, output_size: int) -> int: + if scheme.head_dim is not None: + return scheme.head_dim + if isinstance(layer, LinearBase): - if location == TransformLocation.INPUT: + if args.location == TransformLocation.INPUT: return input_size - elif location == TransformLocation.OUTPUT: + elif args.location == TransformLocation.OUTPUT: return output_size elif isinstance(layer, VocabParallelEmbedding): - if location == TransformLocation.INPUT: + if args.location == TransformLocation.INPUT: return output_size - elif location == TransformLocation.OUTPUT: + elif args.location == TransformLocation.OUTPUT: return input_size raise ValueError() @@ -129,7 +151,3 @@ class HadamardTransform(torch.nn.Module): for partition in self.weight.partitions.values(): if partition.data.data_ptr() != first_data.data_ptr(): raise ValueError("") - - def _infer_kernel(self) -> Callable: - # TODO (@ksayers): use fwht, hadacore - return dispatch_unquantized_gemm() diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py index f42258f9f9d7f..69b39f31eec1e 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py @@ -4,18 +4,43 @@ from typing import Optional import torch +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsScheme, CompressedTensorsW4A4Fp4) from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501 - CompressedTensorsLinearTransformMethod) + CompressedTensorsLinearTransformMethod, TransformTuple) + +__all__ = ["is_qutlass_fp4_scheme", "QutlassNvFP4LinearMethod"] -# Because qutlass fuses hadamard with quantization, it cannot automatically be -# composed with kernels in the way CompressedTensorsLinearTransformMethod does. 
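The `head_dim` handling in `HadamardTransform.forward` above groups the last dimension into `(-1, head_dim)` chunks, applies the transform per chunk, and flattens back. A rough pure-PyTorch sketch of that pattern, substituting a naive normalized Walsh-Hadamard transform for the fused `hadacore_transform` GPU kernel (sketch only, not the kernel itself):

```python
# Sketch: grouped Hadamard transform over the last dimension, mirroring the
# unflatten -> transform -> flatten pattern used for head_dim transforms.
import math
import torch

def fwht(x: torch.Tensor) -> torch.Tensor:
    # Naive iterative Walsh-Hadamard transform over the last dim
    # (size must be a power of two), normalized to be orthonormal.
    n = x.size(-1)
    h = 1
    while h < n:
        y = x.view(*x.shape[:-1], -1, 2, h)
        a, b = y[..., 0, :], y[..., 1, :]
        x = torch.stack((a + b, a - b), dim=-2).reshape(*y.shape[:-3], -1)
        h *= 2
    return x / math.sqrt(n)

value = torch.randn(4, 2048)
head_dim = 128
out = fwht(value.unflatten(-1, (-1, head_dim))).flatten(-2, -1)
assert out.shape == value.shape
```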
-# Therefore, a separate scheme must be created for each quantized dtype -class QutlassLinearMethodNvFP4(CompressedTensorsLinearTransformMethod): +def is_qutlass_fp4_scheme(quant_scheme: Optional[CompressedTensorsScheme], + input_tfms: dict[int, TransformTuple]) -> bool: + return isinstance( + quant_scheme, + (CompressedTensorsW4A4Fp4, )) and len(input_tfms) == 1 and input_tfms[ + 0].scheme.head_dim == quant_scheme.group_size + + +class QutlassNvFP4LinearMethod(CompressedTensorsLinearTransformMethod): + + def create_weights(self, layer, input_size_per_partition, + output_partition_sizes, input_size, output_size, + params_dtype, **extra_weight_attrs): + # initializes fp4 qparams + assert isinstance(layer.scheme, (CompressedTensorsW4A4Fp4, )) + ret = super().create_weights(layer, input_size_per_partition, + output_partition_sizes, input_size, + output_size, params_dtype, + **extra_weight_attrs) + + assert self.input_transform is not None + assert len(self.input_transform.weight) == 1 + assert self.input_transform.weight[0].size( + 0) == layer.scheme.group_size + + return ret def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - # fused hadamard quant linear method raise NotImplementedError() diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index 099d8613fc1a7..b2dd2501095f8 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -94,7 +94,7 @@ def find_matched_target( config that a layer corresponds to. Recall that a compressed-tensors configs has a concept of - config_groups, where each layer can be quantized with with a different + config_groups, where each layer can be quantized with a different scheme. 
targets in each config_group will be a list of either layer names diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 2d8a684bc7d90..8555e9ff20346 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -1,13 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, FusedMoEMethodBase) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, int8_w8a16_moe_quant_config) from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import QuantizationMethods @@ -106,6 +108,13 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): requires_grad=False) layer.register_parameter("w2_scale", w2_scale) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + return int8_w8a16_moe_quant_config(w1_scale=layer.w13_scale, + w2_scale=layer.w2_scale, + w1_zp=None, + w2_zp=None) + def apply( self, layer: torch.nn.Module, @@ -128,7 +137,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: @@ -159,12 +168,11 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): topk_ids=topk_ids, inplace=True, activation=activation, - use_int8_w8a16=True, - global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_scale, - w2_scale=layer.w2_scale) + quant_config=self.moe_quant_config, + ) @staticmethod def quantizing_weight_loader(layer, weight_loader): diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index c0321082070f7..617c601dc2965 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,10 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch -import torch.nn.functional as F from torch.nn import Module from torch.nn.parameter import Parameter @@ -14,9 +13,11 @@ from vllm import _custom_ops as ops from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( - FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase, + FusedMoE, FusedMoEActivationFormat, FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, fp8_w8a8_moe_quant_config) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) from 
vllm.model_executor.layers.quantization import QuantizationMethods @@ -30,7 +31,12 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, select_cutlass_fp8_gemm_impl, swap_w13_to_w31) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace) + apply_fp8_block_linear, check_aiter_fp8_linear_support, + create_fp8_input_scale, create_fp8_scale_parameter, + create_fp8_weight_parameter, get_col_major_tma_aligned_tensor, + maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy, + process_fp8_weight_tensor_strategy, requant_weight_ue8m0_inplace, + validate_fp8_block_shape) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, prepare_moe_fp8_layer_for_marlin) @@ -39,8 +45,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported, cutlass_fp8_supported, maybe_create_device_identity, - normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize, - requantize_with_max_scale) + normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) from vllm.model_executor.parameter import (BlockQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter) @@ -230,14 +235,10 @@ class Fp8LinearMethod(LinearMethodBase): if current_platform.is_rocm(): self.use_marlin = False - # AITER is only supported on ROCm and only for FP8_FNUZ - # and at the moment are MI300 series - self.use_aiter_and_is_supported = (current_platform.is_rocm() - and envs.VLLM_ROCM_USE_AITER - and envs.VLLM_ROCM_USE_AITER_LINEAR - and current_platform.is_fp8_fnuz()) + self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() - self.block_quant = self.quant_config.weight_block_size is not None + self.weight_block_size = self.quant_config.weight_block_size + self.block_quant = self.weight_block_size is not None self.act_q_static = self.quant_config.activation_scheme == "static" # Use per-token quantization for better perf if dynamic and cutlass if not self.act_q_static and cutlass_fp8_supported(): @@ -270,50 +271,27 @@ class Fp8LinearMethod(LinearMethodBase): layer.weight_block_size = None if self.block_quant: - tp_size = get_tensor_model_parallel_world_size() - assert self.quant_config.weight_block_size is not None - layer.weight_block_size = self.quant_config.weight_block_size - block_n, block_k = ( - self.quant_config.weight_block_size[0], - self.quant_config.weight_block_size[1], - ) - # Required by row parallel - if (tp_size > 1 - and input_size // input_size_per_partition == tp_size - and input_size_per_partition % block_k != 0): - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible by " - f"weight quantization block_k = {block_k}.") - # Required by column parallel or enabling merged weights - is_tp_split = (tp_size > 1 and - output_size // output_size_per_partition == tp_size) - is_merged_gemm = len(output_partition_sizes) > 1 - if is_tp_split or is_merged_gemm: - sizes_to_check = output_partition_sizes - if not is_tp_split and is_merged_gemm: - # In case of merged matrices, we allow the last - # matrix to not be a multiple of block size - sizes_to_check = output_partition_sizes[:-1] - for output_partition_size in sizes_to_check: - if output_partition_size % block_n != 0: - raise 
ValueError( - f"Weight output_partition_size = " - f"{output_partition_size} is not divisible by " - f"weight quantization block_n = {block_n}.") + assert self.weight_block_size is not None + layer.weight_block_size = self.weight_block_size + validate_fp8_block_shape(layer, input_size, output_size, + input_size_per_partition, + output_partition_sizes, + self.weight_block_size) # WEIGHT - weight_dtype = (torch.float8_e4m3fn - if self.quant_config.is_checkpoint_fp8_serialized else - params_dtype) - - weight = ModelWeightParameter(data=torch.empty( - output_size_per_partition, - input_size_per_partition, - dtype=weight_dtype), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + if self.quant_config.is_checkpoint_fp8_serialized: + weight = create_fp8_weight_parameter(output_size_per_partition, + input_size_per_partition, + weight_loader) + else: + # For non-serialized checkpoints, use original dtype + weight = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=params_dtype), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) layer.register_parameter("weight", weight) # If checkpoint is serialized fp8, load them. @@ -321,145 +299,87 @@ class Fp8LinearMethod(LinearMethodBase): if self.quant_config.is_checkpoint_fp8_serialized: # WEIGHT SCALE if not self.block_quant: - scale = PerTensorScaleParameter( - data=torch.empty(len(output_partition_sizes), - dtype=torch.float32), - weight_loader=weight_loader, - ) - scale[:] = torch.finfo(torch.float32).min + scale = create_fp8_scale_parameter(PerTensorScaleParameter, + output_partition_sizes, + input_size_per_partition, + None, weight_loader) set_weight_attrs(scale, {"scale_type": "weight_scale"}) layer.register_parameter("weight_scale", scale) else: - assert self.quant_config.activation_scheme == "dynamic" - scale = BlockQuantScaleParameter( - data=torch.empty( - (output_size_per_partition + block_n - 1) // block_n, - (input_size_per_partition + block_k - 1) // block_k, - dtype=torch.float32, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, - ) - scale[:] = torch.finfo(torch.float32).min + assert not self.act_q_static + assert self.weight_block_size is not None + scale = create_fp8_scale_parameter(BlockQuantScaleParameter, + output_partition_sizes, + input_size_per_partition, + self.weight_block_size, + weight_loader) set_weight_attrs(scale, {"scale_type": "weight_scale"}) # The weight_scale_inv name is intentional for deepseekv3 layer.register_parameter("weight_scale_inv", scale) # INPUT ACTIVATION SCALE - if self.quant_config.activation_scheme == "static": - scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) - - scale[:] = torch.finfo(torch.float32).min + if self.act_q_static: + scale = create_fp8_input_scale(output_partition_sizes, + weight_loader) set_weight_attrs(scale, {"scale_type": "input_scale"}) layer.register_parameter("input_scale", scale) else: layer.register_parameter("input_scale", None) - def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: - # Pad the weight tensor. 
This is an optimization on ROCm platform, which - # can benefit from tensors located far enough from one another in memory - if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm() - and weight.stride(-1) == 1 - and (weight.stride(-2) * weight.element_size()) % 512 == 0): - num_pad = 256 // weight.element_size() - weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] - torch.cuda.empty_cache() - return weight - def process_weights_after_loading(self, layer: Module) -> None: size_k_first = True + input_scale = None # TODO(rob): refactor block quant into separate class. if self.block_quant: - assert self.quant_config.activation_scheme == "dynamic" + assert not self.act_q_static size_k_first = False - if current_platform.is_fp8_fnuz(): - weight, weight_scale_inv, _ = \ - normalize_e4m3fn_to_e4m3fnuz( - weight=layer.weight, - weight_scale=layer.weight_scale_inv) - else: - weight = layer.weight.data - weight_scale_inv = layer.weight_scale_inv.data - weight = self._maybe_pad_weight(weight) - - # Torch.compile cannot use Parameter subclasses. - layer.weight = Parameter(weight, requires_grad=False) - layer.weight_scale_inv = Parameter(weight_scale_inv, - requires_grad=False) + weight, weight_scale = process_fp8_weight_block_strategy( + layer.weight, layer.weight_scale_inv) + # Delete the weight_scale_inv parameter to avoid confusion + # with the weight_scale parameter + del layer.weight_scale_inv # If checkpoint not serialized fp8, quantize the weights. elif not self.quant_config.is_checkpoint_fp8_serialized: qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None) + weight = qweight.t() - # Update the layer with the new values. - layer.weight = Parameter(qweight.t(), requires_grad=False) - layer.weight_scale = Parameter(weight_scale, requires_grad=False) - # layer.input_scale is None indicates dynamic quant and scale is - # computed from input. - layer.input_scale = None - - # If checkpoint is fp8, handle that there are N scales for N + # If checkpoint is fp8 per-tensor, handle that there are N scales for N # shards in a fused module else: - layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, - requires_grad=False) - if self.quant_config.activation_scheme == "static": - layer.input_scale = torch.nn.Parameter(layer.input_scale.data, - requires_grad=False) - weight = layer.weight weight_scale = layer.weight_scale # If using w8a8, torch._scaled_mm needs per tensor, so # requantize the logical shards as a single weight. if not self.use_marlin: - # Dequant -> Quant with max scale so we can run per tensor. - if current_platform.is_fp8_fnuz(): - weight, weight_scale, input_scale = \ - normalize_e4m3fn_to_e4m3fnuz( - weight=weight, - weight_scale=weight_scale, - input_scale=layer.input_scale) - if input_scale is not None: - layer.input_scale = Parameter(input_scale, - requires_grad=False) + weight, weight_scale, input_scale = ( + process_fp8_weight_tensor_strategy( + weight, weight_scale, layer.logical_widths, + getattr(layer, 'input_scale', None))) + if self.act_q_static: + assert input_scale is not None + input_scale = input_scale.max() + weight = weight.t() - weight_scale, weight = requantize_with_max_scale( - weight=weight, - weight_scale=weight_scale, - logical_widths=layer.logical_widths, - ) - - weight = self._maybe_pad_weight(weight) - # Update layer with new values. 
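For fused modules (e.g. QKV) checkpointed with one per-tensor scale per logical shard, the per-tensor path folds the N scales into a single max scale so the GEMM can run with one per-tensor scale, as the removed `requantize_with_max_scale` call above did. A minimal sketch of that dequantize-then-requantize step, assuming fp8 e4m3 with a roughly ±448 representable range (`process_fp8_weight_tensor_strategy` is assumed to perform this alongside the transpose and fnuz handling):

```python
# Sketch of per-tensor requantization with the max shard scale (illustrative).
import torch

def requantize_with_max_scale(weight: torch.Tensor,
                              weight_scale: torch.Tensor,
                              logical_widths: list[int]):
    # weight: fp8 tensor of shape (sum(logical_widths), in_features)
    # weight_scale: one fp32 scale per logical shard
    max_scale = weight_scale.max()
    shards, start = [], 0
    for scale, width in zip(weight_scale.tolist(), logical_widths):
        shard = weight[start:start + width].float() * scale      # dequantize
        shards.append((shard / max_scale).clamp(-448.0, 448.0))  # requantize
        start += width
    return max_scale, torch.cat(shards).to(weight.dtype)
```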
- layer.weight = Parameter(weight.t(), requires_grad=False) - layer.weight_scale = Parameter(weight_scale, requires_grad=False) - if self.quant_config.activation_scheme == "static": - layer.input_scale = Parameter(layer.input_scale.max(), - requires_grad=False) + # Update layer with new values. + layer.weight = Parameter(weight.data, requires_grad=False) + layer.weight_scale = Parameter(weight_scale.data, requires_grad=False) + layer.input_scale = Parameter( + input_scale, + requires_grad=False) if input_scale is not None else None if self.use_marlin: prepare_fp8_layer_for_marlin(layer, size_k_first) # Activations not quantized for marlin. del layer.input_scale + return - # On B200, if E8M0 for DeepGemm is used, we need to - # requantize the weight and input to the specific scale - # at the same time. - if is_deep_gemm_e8m0_used(): - assert layer.weight_block_size is not None - block_sz = tuple(layer.weight_block_size) - requant_weight_ue8m0_inplace( - layer.weight.data, - layer.weight_scale_inv.data if hasattr( - layer, "weight_scale_inv") else layer.weight_scale.data, - block_sz, - ) + if self.block_quant: + maybe_post_process_fp8_weight_block( + layer, self.cutlass_block_fp8_supported) def apply(self, layer: torch.nn.Module, @@ -477,18 +397,12 @@ class Fp8LinearMethod(LinearMethodBase): bias=bias) if self.block_quant: - assert self.quant_config.weight_block_size is not None - - return torch.ops.vllm.apply_w8a8_block_fp8_linear( + return apply_fp8_block_linear( + layer, input=x, - weight=layer.weight, - block_size=self.quant_config.weight_block_size, - weight_scale=layer.weight_scale_inv, - input_scale=layer.input_scale, bias=bias, cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, - use_aiter_and_is_supported=self.use_aiter_and_is_supported, - ) + use_aiter_and_is_supported=self.use_aiter_and_is_supported) return self.fp8_linear.apply(input=x, weight=layer.weight, @@ -515,7 +429,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): super().__init__(layer.moe_config) self.layer = layer self.quant_config = quant_config - self.block_quant = self.quant_config.weight_block_size is not None + self.weight_block_size = self.quant_config.weight_block_size + self.block_quant = self.weight_block_size is not None self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None self.fused_experts: Optional[ @@ -564,20 +479,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): "CutlassBlockScaledGroupedGemm not supported on the current " "platform.") - def maybe_make_prepare_finalize( - self, - moe: FusedMoEConfig, - ) -> Optional[mk.FusedMoEPrepareAndFinalize]: - if self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS: - return super().maybe_make_prepare_finalize(moe) - - prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( - moe, - layer=self.layer, - ) - logger.debug_once("%s", prepare_finalize.__class__.__name__) - return prepare_finalize - def create_weights(self, layer: Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): @@ -591,12 +492,12 @@ class Fp8MoEMethod(FusedMoEMethodBase): if self.quant_config.is_checkpoint_fp8_serialized: params_dtype = torch.float8_e4m3fn if self.block_quant: - assert self.quant_config.weight_block_size is not None - layer.weight_block_size = self.quant_config.weight_block_size + assert self.weight_block_size is not None + layer.weight_block_size = self.weight_block_size tp_size = get_tensor_model_parallel_world_size() block_n, block_k = ( - 
self.quant_config.weight_block_size[0], - self.quant_config.weight_block_size[1], + self.weight_block_size[0], + self.weight_block_size[1], ) # NOTE: To ensure proper alignment of the block-wise quantization # scales, the output_size of the weights for both the gate and up @@ -756,16 +657,15 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) - # DeepGemm scales need to be transposed and aligned. We try to do + # DeepGemm scales need to be transposed and aligned. We try to do # it ahead of time for performance reasons. if self.allow_deep_gemm and not is_deep_gemm_e8m0_used(): - # Lazy import to avoid CUDA initialization problems. if _is_col_major(layer.w13_weight_scale_inv): layer.w13_weight_scale_inv = \ - get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous() + get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv) if _is_col_major(layer.w2_weight_scale_inv): layer.w2_weight_scale_inv = \ - get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous() + get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv) # If checkpoint is fp16, quantize in place. elif not self.quant_config.is_checkpoint_fp8_serialized: @@ -895,7 +795,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): del layer.w13_input_scale del layer.w2_input_scale - if is_deep_gemm_e8m0_used(): + if is_deep_gemm_e8m0_used() and self.block_quant: assert layer.weight_block_size is not None # Re-quantise the expert weights so their scales are UE8M0. block_sz = tuple(layer.weight_block_size) @@ -913,15 +813,28 @@ class Fp8MoEMethod(FusedMoEMethodBase): # Ensure column-major TMA alignment expected by DeepGEMM. if _is_col_major(layer.w13_weight_scale_inv): layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor( - layer.w13_weight_scale_inv).contiguous() + layer.w13_weight_scale_inv) if _is_col_major(layer.w2_weight_scale_inv): layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor( - layer.w2_weight_scale_inv).contiguous() + layer.w2_weight_scale_inv) + + def maybe_make_prepare_finalize( + self) -> Optional[mk.FusedMoEPrepareAndFinalize]: + if (self.rocm_aiter_moe_enabled or self.use_marlin + or self.flashinfer_moe_backend + == FlashinferMoeBackend.TENSORRT_LLM): + return None + elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + prepare_finalize = ( + build_flashinfer_fp8_cutlass_moe_prepare_finalize(self.moe)) + logger.debug_once("%s", prepare_finalize.__class__.__name__) + return prepare_finalize + else: + return super().maybe_make_prepare_finalize() def select_gemm_impl( self, prepare_finalize: FusedMoEPrepareAndFinalize, - moe: FusedMoEConfig, layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: from vllm.model_executor.layers.fused_moe import ( @@ -930,6 +843,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): assert not self.use_marlin and not self.rocm_aiter_moe_enabled, ( "Marlin and ROCm AITER are not supported with all2all yet.") + assert self.moe_quant_config is not None + if (prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts): max_num_tokens_per_rank = ( @@ -939,33 +854,44 @@ class Fp8MoEMethod(FusedMoEMethodBase): "BatchedTritonOrDeepGemmExperts(%s): " "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s", self.__class__.__name__, max_num_tokens_per_rank, - self.quant_config.weight_block_size, False) + self.weight_block_size, False) return BatchedTritonOrDeepGemmExperts( max_num_tokens=max_num_tokens_per_rank, num_dispatchers=prepare_finalize.num_dispatchers(), - 
use_fp8_w8a8=True, - block_shape=self.quant_config.weight_block_size, - per_act_token_quant=False, + quant_config=self.moe_quant_config, allow_deep_gemm=self.allow_deep_gemm, ) elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: experts = select_cutlass_fp8_gemm_impl( - moe, - self.layer, + self.moe, + self.moe_quant_config, ) logger.debug_once("Using %s", experts.__class__.__name__) return experts else: logger.debug( "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s", - self.__class__.__name__, self.quant_config.weight_block_size, - False) + self.__class__.__name__, self.weight_block_size, False) return TritonOrDeepGemmExperts( - use_fp8_w8a8=True, - block_shape=self.quant_config.weight_block_size, + quant_config=self.moe_quant_config, allow_deep_gemm=self.allow_deep_gemm, ) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + if self.use_marlin: + return None + + return fp8_w8a8_moe_quant_config( + w1_scale=(layer.w13_weight_scale_inv + if self.block_quant else layer.w13_weight_scale), + w2_scale=(layer.w2_weight_scale_inv + if self.block_quant else layer.w2_weight_scale), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + block_shape=self.weight_block_size, + ) + def apply( self, layer: torch.nn.Module, @@ -988,19 +914,21 @@ class Fp8MoEMethod(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: assert expert_load_view is not None assert logical_to_physical_map is not None assert logical_replica_count is not None assert isinstance(layer, FusedMoE) - if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + if (self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM + and self.fused_experts is None): assert activation == 'silu', ( f"Expected 'silu' activation but got {activation}") assert scoring_func == 'sigmoid', ( f"Expected 'sigmoid' scoring func but got {scoring_func}") if self.block_quant: + import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 assert (renormalize and use_grouped_topk and custom_routing_function is None) @@ -1019,7 +947,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): intermediate_size=layer.intermediate_size_per_partition, expert_offset=layer.ep_rank * layer.local_num_experts, local_num_experts=layer.local_num_experts, - block_shape=self.quant_config.weight_block_size, + block_shape=self.weight_block_size, routed_scaling=routed_scaling_factor, ) else: @@ -1056,9 +984,14 @@ class Fp8MoEMethod(FusedMoEMethodBase): logical_replica_count=logical_replica_count, ) + # + # Note: the order of checks is important since self.fused_experts + # can override fused_experts or cutlass but not rocm or marlin. 
+ # if self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 rocm_aiter_fused_experts) + assert self.fused_experts is None return rocm_aiter_fused_experts( x, layer.w13_weight, @@ -1066,17 +999,11 @@ class Fp8MoEMethod(FusedMoEMethodBase): topk_weights=topk_weights, topk_ids=topk_ids, activation=activation, - use_fp8_w8a8=True, apply_router_weight_on_input=apply_router_weight_on_input, - w1_scale=(layer.w13_weight_scale_inv - if self.block_quant else layer.w13_weight_scale), - w2_scale=(layer.w2_weight_scale_inv - if self.block_quant else layer.w2_weight_scale), - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - block_shape=self.quant_config.weight_block_size, - expert_map=expert_map) + expert_map=expert_map, + quant_config=self.moe_quant_config) elif self.use_marlin: + assert self.fused_experts is None return torch.ops.vllm.fused_marlin_moe( x, layer.w13_weight, @@ -1092,41 +1019,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, activation=activation, - expert_map=expert_map) - elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: - assert self.block_quant is None - assert (not renormalize and custom_routing_function is not None) - assert activation == 'silu', ( - f"Expected 'silu' activation but got {activation}") - assert scoring_func == 'sigmoid', ( - f"Expected 'sigmoid' scoring func but got {scoring_func}") - if self.fused_experts is not None: - return self.fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - inplace=False, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - else: - return flashinfer_cutlass_moe_fp8( - x, - layer, - topk_weights, - topk_ids, - inplace=False, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - else: - common_kwargs = dict( + expert_map=expert_map, + workspace=layer.workspace) + elif self.fused_experts: + return self.fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -1137,26 +1033,43 @@ class Fp8MoEMethod(FusedMoEMethodBase): global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input, expert_map=expert_map, - w1_scale=(layer.w13_weight_scale_inv - if self.block_quant else layer.w13_weight_scale), - w2_scale=(layer.w2_weight_scale_inv - if self.block_quant else layer.w2_weight_scale), - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, ) + elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + assert self.block_quant is None + assert (not renormalize and custom_routing_function is not None) + assert activation == 'silu', ( + f"Expected 'silu' activation but got {activation}") + assert scoring_func == 'sigmoid', ( + f"Expected 'sigmoid' scoring func but got {scoring_func}") - if self.fused_experts is not None: - return self.fused_experts(**common_kwargs) - else: - from vllm.model_executor.layers.fused_moe import fused_experts - return fused_experts( - **common_kwargs, - use_fp8_w8a8=True, - block_shape=self.quant_config.weight_block_size, - allow_deep_gemm=self.allow_deep_gemm, - allow_cutlass_block_scaled_grouped_gemm=( - self.allow_cutlass_block_scaled_grouped_gemm), - ) + return flashinfer_cutlass_moe_fp8( + x, + layer, + topk_weights, + 
topk_ids, + inplace=False, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + else: + from vllm.model_executor.layers.fused_moe import fused_experts + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + quant_config=self.moe_quant_config, + allow_deep_gemm=self.allow_deep_gemm, + allow_cutlass_block_scaled_grouped_gemm=( + self.allow_cutlass_block_scaled_grouped_gemm)) class Fp8KVCacheMethod(BaseKVCacheMethod): diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index ad648df238194..a631dfdab6544 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import gguf import torch @@ -10,8 +10,9 @@ from torch.nn.parameter import Parameter, UninitializedParameter from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, + FusedMoEQuantConfig) from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, - FusedMoEConfig, FusedMoEMethodBase) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) @@ -518,6 +519,10 @@ class GGUFMoEMethod(FusedMoEMethodBase): set_weight_attrs(w2_qweight_type, extra_weight_attrs) layer.register_parameter("w2_qweight_type", w2_qweight_type) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + return None + def apply( self, layer: torch.nn.Module, @@ -540,7 +545,7 @@ class GGUFMoEMethod(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ): + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py index d03074f861848..6462292586482 100644 --- a/vllm/model_executor/layers/quantization/gptq_bitblas.py +++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py @@ -265,9 +265,9 @@ class GPTQBitBLASLinearMethod(LinearMethodBase): scales ('scales'), and zeros ('zeros'). Raises: - ValueError: If `params_dtype` is not `torch.float16` or - if the input size per partition is not divisible by the - group size in `quant_config`. + ValueError: If `params_dtype` is not `torch.float16` or if the input + size per partition is not divisible by the group size + in `quant_config`. 
""" if params_dtype != torch.float16: raise ValueError("Parameter data type must be torch.float16, " diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index d42073c0d47bf..81190d56f2bf8 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -9,8 +9,10 @@ import torch import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, + FusedMoEQuantConfig) from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported, + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, UnquantizedFusedMoEMethod) from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) @@ -469,7 +471,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ) layer.register_parameter("w2_scales", w2_scales) set_weight_attrs(w2_scales, extra_weight_attrs) - # dont shard the w2 scales when running act order + # don't shard the w2 scales when running act order set_weight_attrs(w2_scales, {"load_full_w2": self.quant_config.desc_act}) # up_proj scales @@ -493,7 +495,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ) layer.register_parameter("w2_qzeros", w2_qzeros) set_weight_attrs(w2_qzeros, extra_weight_attrs) - # dont shard the w2 scales when running act order + # don't shard the w2 scales when running act order set_weight_attrs(w2_qzeros, {"load_full_w2": self.quant_config.desc_act}) w13_g_idx = torch.nn.Parameter( @@ -632,6 +634,10 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): if hasattr(layer, "w2_bias") and layer.w2_bias is not None: layer.w2_bias.data = marlin_permute_bias(layer.w2_bias) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + return None + def apply( self, layer: torch.nn.Module, @@ -654,7 +660,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index e1a9bdde9334d..31182f40b48f6 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -23,28 +23,39 @@ _FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0) @CustomOp.register("quant_fp8") class QuantFP8(CustomOp): """ - Quantize input tensor to per-tensor or per-token FP8. + Quantize input tensor to FP8 (per-tensor, per-token, or per-group). This CustomOp supports both static and dynamic quantization. 
""" def __init__(self, static: bool, group_shape: GroupShape, - num_token_padding: Optional[int] = None): + num_token_padding: Optional[int] = None, + column_major_scales: bool = False): """ - :param static: static or dynamic quantization - :param group_shape: quantization group shape (PER_TOKEN or PER_TENSOR) - :param num_token_padding: Pad the token dimension of output to this size + :param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR, + or arbitrary block size) + :param num_token_padding: Pad the token dimension of output to this + size + :param column_major_scales: For group quantization, output scales in + column major format """ super().__init__() - self.num_token_padding = num_token_padding - assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR} - assert not static or group_shape == GroupShape.PER_TENSOR, \ - "Only per-tensor scales supported for static quantization." self.static = static self.group_shape = group_shape - self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN + self.num_token_padding = num_token_padding + self.column_major_scales = column_major_scales + + self.is_group_quant = group_shape.is_per_group() + if self.is_group_quant: + assert not static, "Group quantization only supports dynamic mode" + self.group_size = group_shape.col + else: + assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR} + assert not static or group_shape == GroupShape.PER_TENSOR, \ + "Only per-tensor scales supported for static quantization." + self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN def forward_cuda( self, @@ -52,11 +63,19 @@ class QuantFP8(CustomOp): scale: Optional[torch.Tensor] = None, scale_ub: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: + if self.is_group_quant: + assert scale is None, "Group quantization is always dynamic" + from vllm.model_executor.layers.quantization.utils import fp8_utils + return fp8_utils.per_token_group_quant_fp8( + x, + group_size=self.group_size, + column_major_scales=self.column_major_scales, + dtype=_FP8_DTYPE) + assert (scale is not None) == self.static assert scale_ub is None or (not self.static and self.group_shape == GroupShape.PER_TOKEN and scale_ub.numel() == 1) - return ops.scaled_fp8_quant( x, scale, @@ -70,6 +89,10 @@ class QuantFP8(CustomOp): scale: Optional[torch.Tensor] = None, scale_ub: Optional[torch.Tensor] = None, ): + if self.is_group_quant: + assert scale is None, "Group quantization is always dynamic" + return self._quantize_group_native(x) + assert (scale is not None) == self.static assert scale_ub is None or (not self.static and self.group_shape == GroupShape.PER_TOKEN @@ -84,8 +107,7 @@ class QuantFP8(CustomOp): else: x_max = x.abs().max().unsqueeze(-1).to(torch.float32) - scale = x_max / _FP8_MAX - scale = scale.clamp(min=_FP8_MIN_SCALING_FACTOR) + scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR) # Even for dynamic per-token scales, # reciprocal performs slightly better than division @@ -101,3 +123,34 @@ class QuantFP8(CustomOp): out = F.pad(out, (0, 0, 0, padding), "constant", 0.0) return out, scale + + def _quantize_group_native( + self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + orig_shape = x.shape + hidden_dim = x.shape[-1] + num_groups = (hidden_dim + self.group_size - 1) // self.group_size + padded_dim = num_groups * self.group_size + + if padded_dim != hidden_dim: + padding = padded_dim - hidden_dim + x = F.pad(x, (0, padding), mode='constant', value=0.0) + + x_grouped = x.view(-1, 
num_groups, self.group_size) + absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float() + scales = (absmax / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR) + + x_scaled = x_grouped / scales + x_quant = x_scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) + + x_quant = x_quant.view(-1, padded_dim) + if padded_dim != hidden_dim: + x_quant = x_quant[..., :hidden_dim] + x_quant = x_quant.view(orig_shape) + + scales = scales.squeeze(-1) + scales = scales.reshape(orig_shape[:-1] + (num_groups, )) + + if self.column_major_scales: + scales = scales.transpose(-2, -1).contiguous() + + return x_quant, scales diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py index 5f9d4814274c8..c83b0b47a4b7e 100644 --- a/vllm/model_executor/layers/quantization/ipex_quant.py +++ b/vllm/model_executor/layers/quantization/ipex_quant.py @@ -11,6 +11,7 @@ from torch.nn.parameter import Parameter from vllm._ipex_ops import ipex_ops as ops from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import QuantizationMethods @@ -375,6 +376,10 @@ class XPUFp8MoEMethod(FusedMoEMethodBase): use_prepack=True, ) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + return None + def apply( self, layer: torch.nn.Module, diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index 4bcfcd04b3d8b..f10d20999bee3 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -46,11 +46,11 @@ def choose_mp_linear_kernel( performance. Args: - config (MPLinearLayerConfig): Description of the linear layer to be - implemented. + config (MPLinearLayerConfig): Description of the linear layer to be + implemented. compute_capability (Optional[int], optional): The compute capability of - the target device, if None uses `current_platform` to get the compute - capability. Defaults to None. + the target device, if None uses `current_platform` to get + the compute capability. Defaults to None. Raises: ValueError: If no kernel can implement the given config. diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index e5604670fb4c1..275a1c43fdd2b 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -88,6 +88,7 @@ class BaseKVCacheMethod(QuantizeMethodBase): "Setting it to k_scale. 
This only matters for " "the flash-attn backend.") layer._q_scale.copy_(k_scale) + layer._q_scale_float = k_scale # These are used in the final Attention.forward() layer._k_scale.copy_(k_scale) @@ -124,6 +125,9 @@ class BaseKVCacheMethod(QuantizeMethodBase): # These are used in the final Attention.forward() layer._q_scale.copy_(q_scale) + layer._q_scale_float = q_scale.item() if isinstance( + q_scale, torch.Tensor) else q_scale + layer._prob_scale.copy_(prob_scale) if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0 or prob_scale == 1.0): diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 27aba609ce39e..5d7d9689d465e 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch from torch.nn import Module @@ -11,7 +11,9 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, + nvfp4_moe_quant_config) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( is_valid_flashinfer_cutlass_fused_moe) from vllm.model_executor.layers.fused_moe.layer import ( @@ -45,6 +47,9 @@ from vllm.utils import next_power_of_2 from vllm.utils.flashinfer import (flashinfer_scaled_fp4_mm, has_flashinfer, has_flashinfer_moe) +if TYPE_CHECKING: + from vllm.model_executor.models.utils import WeightsMapper + logger = init_logger(__name__) QUANT_ALGOS = ["FP8", "NVFP4"] @@ -63,7 +68,7 @@ class ModelOptFp8Config(QuantizationConfig): super().__init__() self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized self.kv_cache_quant_method = kv_cache_quant_method - self.exclude_modules = exclude_modules + self.exclude_modules = exclude_modules or [] if is_checkpoint_fp8_serialized: logger.warning("Detected ModelOpt fp8 checkpoint. 
Please note that" " the format is experimental and could change.") @@ -84,6 +89,11 @@ class ModelOptFp8Config(QuantizationConfig): def get_config_filenames(cls) -> list[str]: return ["hf_quant_config.json"] + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): + if self.exclude_modules is not None: + self.exclude_modules = hf_to_vllm_mapper.apply_list( + self.exclude_modules) + @classmethod def override_quantization_method( cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: @@ -170,7 +180,9 @@ class ModelOptFp8Config(QuantizationConfig): prefix: str) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import if isinstance(layer, LinearBase): - if self.is_layer_excluded(prefix): + if (is_layer_skipped(prefix, self.exclude_modules, + self.packed_modules_mapping) + or self.is_layer_excluded(prefix)): return UnquantizedLinearMethod() return ModelOptFp8LinearMethod(self) elif isinstance(layer, Attention): @@ -284,8 +296,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): cutlass_fp8_supported) self.cutlass_fp8_supported = cutlass_fp8_supported() self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None - self.fused_experts: Optional[ - mk.FusedMoEModularKernel] = None # type: ignore if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe(): self.flashinfer_moe_backend = get_flashinfer_moe_backend() logger.info_once( @@ -293,29 +303,27 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ) def maybe_make_prepare_finalize( - self, - moe: FusedMoEConfig, - ) -> Optional[mk.FusedMoEPrepareAndFinalize]: - if self.fused_experts is not None or \ - self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS: - return super().maybe_make_prepare_finalize(moe) - - prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( - moe, - layer=self.layer, - ) - logger.debug_once("%s", prepare_finalize.__class__.__name__) - return prepare_finalize + self, ) -> Optional[mk.FusedMoEPrepareAndFinalize]: + # TRT LLM not supported with all2all yet. 
+ if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + return None + elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + prepare_finalize = ( + build_flashinfer_fp8_cutlass_moe_prepare_finalize(self.moe)) + logger.debug_once("%s", prepare_finalize.__class__.__name__) + return prepare_finalize + else: + return super().maybe_make_prepare_finalize() def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, - moe: FusedMoEConfig, layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: + assert self.moe_quant_config is not None experts = select_cutlass_fp8_gemm_impl( - moe, - self.layer, + self.moe, + self.moe_quant_config, ) logger.debug_once("Using %s", experts.__class__.__name__) return experts @@ -469,6 +477,19 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + return None + + return fp8_w8a8_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + per_act_token_quant=False, + ) + def apply( self, layer: torch.nn.Module, @@ -491,12 +512,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: raise NotImplementedError( "EPLB not supported for `ModelOptFp8MoEMethod` yet.") if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + assert self.fused_experts is None assert activation == 'silu', ( f"Expected 'silu' activation but got {activation}") assert not renormalize @@ -527,55 +549,56 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): indices_type=self.topk_indices_dtype, ) - if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + # + # Note: the order here is important. self.fused_experts can override + # cutlass or fused_experts. 
+ # + if self.fused_experts is not None: + return self.fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + inplace=False, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: assert not renormalize assert activation == 'silu', ( f"Expected 'silu' activation but got {activation}") - if self.fused_experts is not None: - return self.fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - inplace=False, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - else: - return flashinfer_cutlass_moe_fp8( - x, - layer, - topk_weights, - topk_ids, - inplace=False, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_experts) - return fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - use_fp8_w8a8=True, - per_channel_quant=False, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - apply_router_weight_on_input=apply_router_weight_on_input, - ) + return flashinfer_cutlass_moe_fp8( + x, + layer, + topk_weights, + topk_ids, + inplace=False, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + else: + from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_experts) + assert self.moe_quant_config is not None + + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + quant_config=self.moe_quant_config, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) class ModelOptNvFp4Config(QuantizationConfig): @@ -615,6 +638,11 @@ class ModelOptNvFp4Config(QuantizationConfig): def get_config_filenames(cls) -> list[str]: return ["hf_quant_config.json"] + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): + if self.exclude_modules is not None: + self.exclude_modules = hf_to_vllm_mapper.apply_list( + self.exclude_modules) + @classmethod def override_quantization_method( cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: @@ -763,7 +791,8 @@ class ModelOptNvFp4Config(QuantizationConfig): prefix: str) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import if isinstance(layer, LinearBase): - if (is_layer_skipped(prefix, self.exclude_modules) + if (is_layer_skipped(prefix, self.exclude_modules, + self.packed_modules_mapping) or self.is_layer_excluded(prefix, self.exclude_modules)): return UnquantizedLinearMethod() return ModelOptNvFp4LinearMethod(self) @@ -1018,33 +1047,30 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): " for ModelOptNvFp4FusedMoE.") def maybe_make_prepare_finalize( - self, - moe: FusedMoEConfig, - ) -> Optional[mk.FusedMoEPrepareAndFinalize]: - if (self.allow_flashinfer and 
self.flashinfer_moe_backend - == FlashinferMoeBackend.CUTLASS): + self) -> Optional[mk.FusedMoEPrepareAndFinalize]: + if (self.use_marlin + or (self.allow_flashinfer and self.flashinfer_moe_backend + == FlashinferMoeBackend.TENSORRT_LLM)): + return None + elif (self.allow_flashinfer + and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS): + # For now, fp4 moe only works with the flashinfer dispatcher. prepare_finalize = ( - build_flashinfer_fp4_cutlass_moe_prepare_finalize( - moe, - a1_gscale=self.layer.w13_input_scale_quant, - )) + build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe)) logger.debug_once("%s", prepare_finalize.__class__.__name__) return prepare_finalize - - return super().maybe_make_prepare_finalize(moe) + else: + return super().maybe_make_prepare_finalize() def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, - moe: FusedMoEConfig, layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: + assert self.moe_quant_config is not None experts = select_nvfp4_gemm_impl( - moe, - g1_alphas=self.layer.g1_alphas, - g2_alphas=self.layer.g2_alphas, - a1_gscale=self.layer.w13_input_scale_quant, - a2_gscale=self.layer.w2_input_scale_quant, + self.moe, + self.moe_quant_config, allow_flashinfer=self.allow_flashinfer, ) logger.debug_once("Using %s", experts.__class__.__name__) @@ -1344,6 +1370,21 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + if (self.use_marlin or self.flashinfer_moe_backend + == FlashinferMoeBackend.TENSORRT_LLM): + return None + + return nvfp4_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + g1_alphas=layer.g1_alphas, + g2_alphas=layer.g2_alphas, + a1_gscale=layer.w13_input_scale_quant, + a2_gscale=layer.w2_input_scale_quant, + ) + def apply( self, layer: torch.nn.Module, @@ -1366,18 +1407,20 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ): + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: raise NotImplementedError( "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.") - if self.allow_flashinfer and \ - self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + if (self.allow_flashinfer and self.flashinfer_moe_backend + == FlashinferMoeBackend.TENSORRT_LLM): import flashinfer from vllm.model_executor.models.llama4 import Llama4MoE assert activation == "silu", "Only SiLU activation is supported." + assert self.fused_experts is None + a1_gscale = layer.w13_input_scale_quant (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( @@ -1441,7 +1484,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype) + # + # Note: the order here is important. self.fused_experts can override + # flashinfer cutlass, cutlass fp4 or fused_experts but not marlin or + # trtllm. 
+ # if self.use_marlin: + assert self.fused_experts is None return torch.ops.vllm.fused_marlin_moe( x, layer.w13_weight, @@ -1459,11 +1508,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, activation=activation, - expert_map=expert_map) + expert_map=expert_map, + workspace=layer.workspace) - assert activation == "silu", "Only SiLU activation is supported." - - if self.fused_experts is not None: + elif self.fused_experts is not None: assert self.allow_flashinfer and \ self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS @@ -1471,7 +1519,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): x, layer.w13_weight, layer.w2_weight), ( "Flashinfer CUTLASS Fused MoE not applicable!") - out = self.fused_experts( + return self.fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -1481,28 +1529,22 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, apply_router_weight_on_input=apply_router_weight_on_input, ) elif (self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS): from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 flashinfer_cutlass_moe_fp4) + assert self.moe_quant_config is not None - out = flashinfer_cutlass_moe_fp4( + return flashinfer_cutlass_moe_fp4( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - g1_alphas=layer.g1_alphas, - g2_alphas=layer.g2_alphas, - a1_gscale=layer.w13_input_scale_quant, - a2_gscale=layer.w2_input_scale_quant, - inplace=False, # TODO(shuw): fix later, now output is high prec + quant_config=self.moe_quant_config, + inplace=False, activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, @@ -1513,23 +1555,19 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): # only (no EP). 
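+ # The block scales, global alphas and input global scales are now carried
+ # in self.moe_quant_config instead of being passed as individual arguments.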
from vllm.model_executor.layers.fused_moe.cutlass_moe import ( cutlass_moe_fp4) - out = cutlass_moe_fp4( + assert self.moe_quant_config is not None + return cutlass_moe_fp4( a=x, w1_fp4=layer.w13_weight, w2_fp4=layer.w2_weight, - w1_blockscale=layer.w13_weight_scale, - w2_blockscale=layer.w2_weight_scale, - g1_alphas=layer.g1_alphas, - g2_alphas=layer.g2_alphas, - a1_gscale=layer.w13_input_scale_quant, - a2_gscale=layer.w2_input_scale_quant, topk_weights=topk_weights, topk_ids=topk_ids, + quant_config=self.moe_quant_config, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + # TODO: derive from arguments m=x.shape[0], n=layer.w2_weight.shape[2] * 2, k=x.shape[1], e=layer.w13_weight.shape[0], - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input) - - return out + ) diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index fb3e4b518bf6c..145b614237fb3 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -1,11 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, int4_w4a16_moe_quant_config, + int8_w8a16_moe_quant_config) from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, @@ -190,7 +193,7 @@ class MoeWNA16Method(FusedMoEMethodBase): group_size = self.quant_config.group_size group_size_div_factor = 1 - # make intermediate_size and hidden_size diviable by group_size + # make intermediate_size and hidden_size divisible by group_size # we reduce the group size to ensure that # and we would repeat the loaded_weight later while intermediate_size_per_partition % group_size or \ @@ -283,6 +286,22 @@ class MoeWNA16Method(FusedMoEMethodBase): layer.register_parameter(key, param) set_weight_attrs(param, extra_weight_attrs) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + weight_bits = self.quant_config.weight_bits + has_zp = self.quant_config.has_zp + assert weight_bits == 4 or weight_bits == 8 + config_builder = (int4_w4a16_moe_quant_config + if weight_bits == 4 else int8_w8a16_moe_quant_config) + + return config_builder( + w1_scale=layer.w13_scales, + w2_scale=layer.w2_scales, + w1_zp=layer.w13_qzeros if has_zp else None, + w2_zp=layer.w2_qzeros if has_zp else None, + block_shape=[0, layer.group_size], + ) + def apply( self, layer: torch.nn.Module, @@ -305,7 +324,7 @@ class MoeWNA16Method(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: raise NotImplementedError( @@ -327,9 +346,6 @@ class MoeWNA16Method(FusedMoEMethodBase): e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype) - weight_bits = self.quant_config.weight_bits - has_zp = self.quant_config.has_zp - return fused_experts( x, 
layer.w13_qweight, @@ -337,16 +353,11 @@ class MoeWNA16Method(FusedMoEMethodBase): topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, - use_int4_w4a16=weight_bits == 4, - use_int8_w8a16=weight_bits == 8, - global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_scales, - w2_scale=layer.w2_scales, - w1_zp=layer.w13_qzeros if has_zp else None, - w2_zp=layer.w2_qzeros if has_zp else None, - block_shape=[0, layer.group_size]) + quant_config=self.moe_quant_config, + ) @staticmethod def get_weight_loader(layer, weight_loader): diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index a2301779c77e4..28c1e60ccd08a 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, Optional +from enum import Enum +from typing import Callable, Optional, Union import torch from torch.nn.parameter import Parameter @@ -11,6 +12,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, FusedMoEMethodBase) from vllm.model_executor.layers.fused_moe import modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config) from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) @@ -33,33 +36,72 @@ from vllm.utils.flashinfer import has_flashinfer logger = init_logger(__name__) -def _should_use_flashinfer_mxfp4_bf16(): - """Determine if FlashInfer MXFP4 BF16 should be used.""" - # If explicitly set, respect the setting - if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"): - return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 +# enum for mxfp4 backend +class Mxfp4Backend(Enum): + NONE = 0 - # Enable by default on SM100 if MXFP8 is not explicitly enabled - if (current_platform.is_device_capability(100) and has_flashinfer() - and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")): - logger.info_once( - "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell. 
" - "For faster performance, consider setting " - "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, " - "though this may impact accuracy.") - return True + # FlashInfer Backend + SM100_FI_MXFP4_MXFP8_TRTLLM = 1 + SM100_FI_MXFP4_MXFP8_CUTLASS = 2 + SM100_FI_MXFP4_BF16 = 3 + SM90_FI_MXFP4_BF16 = 4 - return False + # Marlin Backend + MARLIN = 5 + + # Triton Backend + TRITON = 6 -def _should_use_flashinfer_mxfp4_mxfp8(): - """Determine if FlashInfer MXFP4 MXFP8 should be used.""" - return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 +def get_mxfp4_backend(): + # Backend Selection + if current_platform.is_cuda(): + if (current_platform.is_device_capability(90) and has_flashinfer() + and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): + logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90") + return Mxfp4Backend.SM90_FI_MXFP4_BF16 + elif (current_platform.is_device_capability(100) and has_flashinfer() + and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS): + logger.info_once( + "Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100") + return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS + elif (current_platform.is_device_capability(100) and has_flashinfer() + and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8): + logger.info_once( + "Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100, " + "for high concurrency throughput workloads consider setting " + "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1 for better " + "performance") + return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM + elif current_platform.is_device_capability(100) and has_flashinfer(): + logger.info_once( + "Using FlashInfer MXFP4 BF16 backend for SM100, " + "For faster performance on SM100, consider setting " + "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, though this may impact " + "accuracy.") + return Mxfp4Backend.SM100_FI_MXFP4_BF16 + elif ((current_platform.is_device_capability(100) + or current_platform.is_device_capability(90)) + and not has_flashinfer()): + logger.warning_once( + "MXFP4 MoE is enabled on Hopper/Blackwell but FlashInfer " + "is not available. This may result in degraded performance. " + "Please `pip install vllm[flashinfer]` for best results.") + # If FlashInfer is not available, try either Marlin or Triton + if current_platform.get_device_capability( + )[0] < 9 or not has_triton_kernels() or not is_torch_equal_or_newer( + "2.8.0"): + logger.info_once("Using Marlin backend") + return Mxfp4Backend.MARLIN + else: + logger.info_once("Using Triton backend") + return Mxfp4Backend.TRITON + elif current_platform.is_rocm() and has_triton_kernels(): + logger.info_once("Using Triton backend") + return Mxfp4Backend.TRITON -def should_use_flashinfer_mxfp4(): - return (_should_use_flashinfer_mxfp4_mxfp8() - or _should_use_flashinfer_mxfp4_bf16()) + return Mxfp4Backend.NONE class Mxfp4Config(QuantizationConfig): @@ -113,29 +155,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): super().__init__(moe) self.topk_indices_dtype = None self.moe = moe - self.use_marlin = self._should_use_marlin() + self.mxfp4_backend = get_mxfp4_backend() self.max_capture_size = get_current_vllm_config( ).compilation_config.max_capture_size - if current_platform.is_device_capability(100) and not has_flashinfer(): - logger.warning_once( - "MXFP4 MoE is enabled on Blackwell but FlashInfer " - "is not available. This may result in degraded performance. 
" - "Please `pip install vllm[flashinfer]` for best results.") - - def _should_use_marlin(self): - if envs.VLLM_MXFP4_USE_MARLIN is not None: - return envs.VLLM_MXFP4_USE_MARLIN - if current_platform.is_cuda() and \ - not current_platform.is_device_capability(100): - if not current_platform.has_device_capability(90): - # marlin kernel has better performance on ampere - return True - if not has_triton_kernels(): - return True - if not is_torch_equal_or_newer("2.8.0"): - return True - return False + assert self.mxfp4_backend != Mxfp4Backend.NONE, ( + "No MXFP4 MoE backend (FlashInfer/Marlin/Triton) available." + "Please check your environment and try again.") + self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {} def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -156,7 +183,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): intermediate_size_per_partition_after_pad = \ intermediate_size_per_partition - if self.use_marlin: + if self.mxfp4_backend == Mxfp4Backend.MARLIN: # The moe marlin kernel requires that for each linear # n % 256 == 0 and k % 128 == 0. # In gate_up_proj: @@ -174,16 +201,20 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): layer.hidden_size = hidden_size layer.intermediate_size_per_partition = \ intermediate_size_per_partition_after_pad - elif should_use_flashinfer_mxfp4(): + elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM + or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16): # pad the intermediate size to be a multiple of 2 * mxfp4_block # for to hold non-uniform sharded tensor as well as swizzling # other padding to increase performance intermediate_size_per_partition_after_pad = round_up( intermediate_size_per_partition, 256) hidden_size = round_up(hidden_size, 256) - elif current_platform.is_rocm(): + elif current_platform.is_rocm() or ( + self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS + or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16): intermediate_size_per_partition_after_pad = round_up( intermediate_size_per_partition, 128) + hidden_size = round_up(hidden_size, 128) else: intermediate_size_per_partition_after_pad = round_up( intermediate_size_per_partition, 64) @@ -263,10 +294,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): set_weight_attrs(w2_bias, extra_weight_attrs) def process_weights_after_loading(self, layer): - if self.use_marlin: + if self.mxfp4_backend == Mxfp4Backend.MARLIN: prepare_moe_fp4_layer_for_marlin(layer) - elif should_use_flashinfer_mxfp4(): - from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a + elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM + or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16): + from flashinfer.fp4_quantization import ( + nvfp4_block_scale_interleave) + from flashinfer.fused_moe.core import ( + _maybe_get_cached_w2_permute_indices) layer.gemm1_alpha = Parameter(torch.tensor( [1.702] * self.num_experts, dtype=torch.float32).cuda(), requires_grad=False) @@ -310,7 +345,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): w13_bias = layer.w13_bias.data.to(torch.float32) w2_bias = layer.w2_bias.data.to(torch.float32) - # Swap w1 and w3 as the defenition of + # Swap w1 and w3 as the definition of # swiglu is different in the trtllm-gen def swap_every_two_rows(x, axis=-1): shape = x.shape @@ -343,25 +378,63 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): gemm2_bias_shuffled = [] epilogue_tile_m = 128 # FIXME: this depends on the kernel internals for i in 
range(self.num_experts): - gemm1_weights_mxfp4_shuffled.append( - shuffle_matrix_a(w13_weight[i].view(torch.uint8), - epilogue_tile_m)) + # w13 weight shuffling + permute_indices = _maybe_get_cached_w2_permute_indices( + self._cache_permute_indices, + w13_weight[i].view(torch.uint8), + epilogue_tile_m, + ) + gemm1_weights_mxfp4_shuffled.append(w13_weight[i].view( + torch.uint8)[permute_indices.to( + w13_weight.device)].contiguous()) + # w13 scale shuffling + permute_sf_indices = _maybe_get_cached_w2_permute_indices( + self._cache_permute_indices, + w13_weight_scale[i].view(torch.uint8), + epilogue_tile_m, + num_elts_per_sf=16, + ) gemm1_scales_mxfp4_shuffled.append( - shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8), - epilogue_tile_m)) - gemm1_bias_shuffled.append( - shuffle_matrix_a(w13_bias[i].clone().reshape(-1, 1), - epilogue_tile_m)) - - gemm2_weights_mxfp4_shuffled.append( - shuffle_matrix_a(w2_weight[i].view(torch.uint8), - epilogue_tile_m)) + nvfp4_block_scale_interleave(w13_weight_scale[i].view( + torch.uint8)[permute_sf_indices.to( + w13_weight_scale.device)].contiguous())) + # w13 bias shuffling + permute_bias_indices = _maybe_get_cached_w2_permute_indices( + self._cache_permute_indices, + w13_bias[i].clone().reshape(-1, 1), + epilogue_tile_m, + ) + gemm1_bias_shuffled.append(w13_bias[i].clone().reshape( + -1, + 1)[permute_bias_indices.to(w13_bias.device)].contiguous()) + # w2 weight shuffling + permute_indices = _maybe_get_cached_w2_permute_indices( + self._cache_permute_indices, + w2_weight[i].view(torch.uint8), + epilogue_tile_m, + ) + gemm2_weights_mxfp4_shuffled.append(w2_weight[i].view( + torch.uint8)[permute_indices.to( + w2_weight.device)].contiguous()) + # w2 scale shuffling + permute_sf_indices = _maybe_get_cached_w2_permute_indices( + self._cache_permute_indices, + w2_weight_scale[i].view(torch.uint8), + epilogue_tile_m, + num_elts_per_sf=16, + ) gemm2_scales_mxfp4_shuffled.append( - shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8), - epilogue_tile_m)) - gemm2_bias_shuffled.append( - shuffle_matrix_a(w2_bias[i].clone().reshape(-1, 1), - epilogue_tile_m)) + nvfp4_block_scale_interleave(w2_weight_scale[i].view( + torch.uint8)[permute_sf_indices.to( + w2_weight_scale.device)].contiguous())) + # w2 bias shuffling + permute_indices = _maybe_get_cached_w2_permute_indices( + self._cache_permute_indices, + w2_bias[i].clone().reshape(-1, 1), + epilogue_tile_m, + ) + gemm2_bias_shuffled.append(w2_bias[i].clone().reshape( + -1, 1)[permute_indices.to(w2_bias.device)].contiguous()) w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled) w13_weight_scale = torch.stack( @@ -387,7 +460,116 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape( self.num_experts, -1), requires_grad=False) - else: + elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS + or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16): + layer.gemm1_alpha = Parameter(torch.tensor( + [1.702] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False) + layer.gemm1_beta = Parameter(torch.tensor( + [1.0] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False) + layer.gemm1_clamp_limit = Parameter(torch.tensor( + [7.0] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False) + + sf_block_size = 32 # mxfp4 block size + + # Common shape assertions + assert (layer.w13_weight.dim() == 3 + and layer.w13_weight.shape[0] == self.num_experts + and layer.w13_weight.shape[1] == 
self.intermediate_size * 2 + and layer.w13_weight.shape[2] == self.hidden_size // 2) + assert (layer.w13_weight_scale.dim() == 3 + and layer.w13_weight_scale.shape[0] == self.num_experts + and layer.w13_weight_scale.shape[1] + == self.intermediate_size * 2 + and layer.w13_weight_scale.shape[2] + == self.hidden_size // sf_block_size) + assert (layer.w2_weight.dim() == 3 + and layer.w2_weight.shape[0] == self.num_experts + and layer.w2_weight.shape[1] == self.hidden_size and + layer.w2_weight.shape[2] == self.intermediate_size // 2) + assert (layer.w2_weight_scale.dim() == 3 + and layer.w2_weight_scale.shape[1] == self.hidden_size + and layer.w2_weight_scale.shape[2] + == self.intermediate_size // sf_block_size) + assert (layer.w13_bias.dim() == 2 + and layer.w13_bias.shape[0] == self.num_experts + and layer.w13_bias.shape[1] == self.intermediate_size * 2) + assert (layer.w2_bias.dim() == 2 + and layer.w2_bias.shape[0] == self.num_experts + and layer.w2_bias.shape[1] == self.hidden_size) + + # De-interleave and swap for w13 weight, bias, and scales + w13_w = layer.w13_weight.data + gate_w, up_w = w13_w[:, ::2, :], w13_w[:, 1::2, :] + deinterleaved_w13_w = torch.cat([gate_w, up_w], dim=1) + w1_w, w3_w = torch.chunk(deinterleaved_w13_w, 2, dim=1) + w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1) + + w13_b = layer.w13_bias.data.to(torch.float32) + gate_b, up_b = w13_b[:, ::2], w13_b[:, 1::2] + deinterleaved_w13_b = torch.cat([gate_b, up_b], dim=1) + b1, b3 = torch.chunk(deinterleaved_w13_b, 2, dim=-1) + w13_bias_swapped = torch.cat([b3, b1], dim=-1).to(torch.bfloat16) + + w13_s = layer.w13_weight_scale.data + gate_s, up_s = w13_s[:, ::2, :], w13_s[:, 1::2, :] + deinterleaved_w13_s = torch.cat([gate_s, up_s], dim=1) + s1, s3 = torch.chunk(deinterleaved_w13_s, 2, dim=1) + w13_scale_swapped = torch.cat([s3, s1], dim=1) + + if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS: + from flashinfer import block_scale_interleave + + orig_shape = w13_scale_swapped.shape + w13_scale_interleaved = block_scale_interleave( + w13_scale_swapped.view(torch.uint8)).reshape(orig_shape) + + w2_s = layer.w2_weight_scale.data + orig_shape = w2_s.shape + w2_scale_interleaved = block_scale_interleave( + w2_s.view(torch.uint8)).reshape(orig_shape) + + layer.w13_weight = Parameter(w13_weight_swapped, + requires_grad=False) + layer.w13_weight_scale = Parameter(w13_scale_interleaved, + requires_grad=False) + layer.w13_bias = Parameter(w13_bias_swapped, + requires_grad=False) + layer.w2_weight_scale = Parameter(w2_scale_interleaved, + requires_grad=False) + elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16: + + def _interleave_mxfp4_cutlass_sm90(w): + w_shape = w.shape + w_interleaved = w.reshape(w_shape[0], w_shape[1], + (w_shape[2] // 4), 4) + w_interleaved = w_interleaved.permute(0, 2, 1, 3) + w_interleaved = w_interleaved.reshape( + w_shape[0], w_shape[2] // 4, w_shape[1] * 4) + return w_interleaved + + w31_scales = w13_scale_swapped.to(torch.uint8).view( + torch.uint8) + w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90( + w31_scales) + + w2_weight_scale = layer.w2_weight_scale.data + w2_scales = w2_weight_scale.to(torch.uint8).view(torch.uint8) + w2_scales_interleaved = _interleave_mxfp4_cutlass_sm90( + w2_scales) + + layer.w13_weight = torch.nn.Parameter(torch.cat([w3_w, w1_w], + dim=1), + requires_grad=False) + layer.w13_bias = torch.nn.Parameter(w13_bias_swapped, + requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + w31_scales_interleaved, requires_grad=False) + 
layer.w2_weight_scale = torch.nn.Parameter( + w2_scales_interleaved, requires_grad=False) + elif self.mxfp4_backend == Mxfp4Backend.TRITON: from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig w13_bias = layer.w13_bias.to(torch.float32) @@ -422,6 +604,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): layer.w13_weight = None layer.w2_weight = None torch.cuda.empty_cache() + else: + raise ValueError(f"Unsupported backend: {self.mxfp4_backend}") def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int): # Number of tokens in the input tensor. @@ -447,10 +631,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): return tile_tokens_dim + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + + if self.mxfp4_backend == Mxfp4Backend.MARLIN: + return None + + if self.mxfp4_backend == Mxfp4Backend.TRITON: + w1_scale = layer.w13_precision_config + w2_scale = layer.w2_precision_config + else: + w1_scale = layer.w13_weight_scale + w2_scale = layer.w2_weight_scale + + return mxfp4_w4a4_moe_quant_config( + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + w1_scale=w1_scale, + w2_scale=w2_scale, + ) + def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, - moe: FusedMoEConfig, layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: if (prepare_finalize.activation_format == @@ -458,17 +661,19 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): raise NotImplementedError( "Mxfp4 does not support batched experts format for EP") else: - if should_use_flashinfer_mxfp4(): + if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM + or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16): # B200 code-path kwargs = { "gemm1_alpha": layer.gemm1_alpha, "gemm1_beta": layer.gemm1_beta, "gemm1_clamp_limit": layer.gemm1_clamp_limit, - "w13_bias": layer.w13_bias, - "w2_bias": layer.w2_bias, + # TODO(bnell): part of quant_config "max_capture_size": self.max_capture_size, } - return TrtLlmGenExperts(moe, **kwargs) + assert self.moe_quant_config is not None + return TrtLlmGenExperts(self.moe, self.moe_quant_config, + **kwargs) else: # Use matmul_ogs from triton_kernels here! 
raise NotImplementedError( @@ -527,8 +732,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, apply_router_weight_on_input=apply_router_weight_on_input, ) @@ -554,12 +757,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: raise NotImplementedError("EPLB is not supported for mxfp4") - if self.use_marlin: + if self.mxfp4_backend == Mxfp4Backend.MARLIN: topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, @@ -623,16 +826,19 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): logical_replica_count), ( "MXFP4 are not supported with this configuration.") - if should_use_flashinfer_mxfp4(): - from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe - if _should_use_flashinfer_mxfp4_bf16(): + if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM + or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16): + from flashinfer import trtllm_fp4_block_scale_moe + if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16: assert x.dtype == torch.bfloat16 x_quant = x x_scale = None - else: + elif self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM: + from flashinfer import mxfp8_quantize x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8 x_scale = x_scale.view(torch.float8_e4m3fn).reshape( *x.shape[:-1], -1) + trtllm_gen_output = trtllm_fp4_block_scale_moe( router_logits.to(torch.bfloat16), None, # routing_bias @@ -664,7 +870,86 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): tune_max_num_tokens=self.max_capture_size, )[0] return trtllm_gen_output - else: + elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS + or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16): + from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + + # Backend-specific preparation + if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS: + + from flashinfer import mxfp8_quantize + + x_quant, x_scale = mxfp8_quantize(x, True, 32) + + fake_input_scale = torch.ones(self.num_experts, + device=x.device) + quant_scales = [ + layer.w13_weight_scale.contiguous().view(torch.int32), + fake_input_scale, + layer.w2_weight_scale.contiguous().view(torch.int32), + fake_input_scale, + ] + + fi_input = x_quant + extra_kwargs = dict( + use_mxfp8_act_scaling=True, + input_sf=x_scale, + fc1_expert_weights=layer.w13_weight.contiguous().view( + torch.long), + fc2_expert_weights=layer.w2_weight.contiguous().view( + torch.long), + ) + elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16: + assert x.dtype == torch.bfloat16 + + quant_scales = [ + layer.w13_weight_scale, + layer.w2_weight_scale, + ] + + fi_input = x + extra_kwargs = dict( + use_w4_group_scaling=True, + fc1_expert_weights=layer.w13_weight, + fc2_expert_weights=layer.w2_weight, + ) + + output = 
torch.empty_like(x, dtype=torch.bfloat16) + _ = flashinfer_cutlass_fused_moe( + input=fi_input, + token_selected_experts=topk_ids.to(torch.int).contiguous(), + token_final_scales=topk_weights, + output_dtype=torch.bfloat16, + output=output, + quant_scales=quant_scales, + fc1_expert_biases=layer.w13_bias, + fc2_expert_biases=layer.w2_bias, + swiglu_alpha=layer.gemm1_alpha, + swiglu_beta=layer.gemm1_beta, + swiglu_limit=layer.gemm1_clamp_limit, + tp_size=self.moe.tp_size, + tp_rank=self.moe.tp_rank, + ep_size=self.moe.ep_size, + ep_rank=self.moe.ep_rank, + tune_max_num_tokens=self.max_capture_size, + **extra_kwargs, + ) + + return output + elif self.mxfp4_backend == Mxfp4Backend.TRITON: from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501 triton_kernel_moe_forward) return triton_kernel_moe_forward( @@ -676,9 +961,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): renormalize=renormalize, global_num_experts=global_num_experts, expert_map=expert_map, - w1_bias=layer.w13_bias, - w2_bias=layer.w2_bias, - w1_precision=self.w13_precision_config, - w2_precision=self.w2_precision_config, + quant_config=self.moe_quant_config, apply_router_weight_on_input=apply_router_weight_on_input, ) + else: + raise ValueError(f"Unsupported backend: {self.mxfp4_backend}") diff --git a/vllm/model_executor/layers/quantization/neuron_quant.py b/vllm/model_executor/layers/quantization/neuron_quant.py deleted file mode 100644 index 8040236663dd1..0000000000000 --- a/vllm/model_executor/layers/quantization/neuron_quant.py +++ /dev/null @@ -1,76 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os -from importlib.util import find_spec -from typing import Any, Optional - -from torch.nn import Module - -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) - -SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn'] - - -class AlwaysSupportedDtypes(list): - - def __contains__(self, item): - return True - - -class NeuronQuantConfig(QuantizationConfig): - """Int8 Quantization Config class for Neuron Backend.""" - - def __init__( - self, - dequant_dtype: str = "f16", - quantize_method: str = "vector_dynamic", - ) -> None: - super().__init__() - self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8") - if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST: - raise ValueError( - f"Neuron quantization datatype {self.quant_dtype} is not valid," - f" the quantization datatype should match one of the below " - f"types {SUPPORTED_QUANT_DTYPE_LIST}") - self.dequant_dtype = dequant_dtype - self.quantize_method = quantize_method - - def get_name(self) -> QuantizationMethods: - return "neuron_quant" - - def get_supported_act_dtypes(self) -> list[str]: - # Neuron implements custom handling logic for quantization support - return AlwaysSupportedDtypes() - - @classmethod - def get_min_capability(cls) -> int: - raise NotImplementedError( - "This function should not be called with Neuron Backend") - - @staticmethod - def get_config_filenames() -> list[str]: - return [] - - @classmethod - def from_config(cls, config: dict[str, Any]) -> "NeuronQuantConfig": - quantize_method = cls.get_from_keys(config, ["quantize_method"]) - dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"]) - return cls(dequant_dtype=dequant_dtype, - quantize_method=quantize_method) - - def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]: - 
if find_spec("transformers_neuronx") is not None: - return self.get_quantization_config() - else: - raise NotImplementedError( - "Neuron Quantization is only supported through" - " transformers_neuronx.") - - def get_quantization_config(self): - from transformers_neuronx.config import QuantizationConfig - return QuantizationConfig(quant_dtype=self.quant_dtype, - dequant_dtype=self.dequant_dtype, - quantize_method=self.quantize_method) diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py index 466fd5fba7685..45ea8e3520f1d 100644 --- a/vllm/model_executor/layers/quantization/ptpc_fp8.py +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -92,13 +92,13 @@ class PTPCFp8LinearMethod(Fp8LinearMethod): """ def __init__(self, quant_config: PTPCFp8Config): + assert current_platform.is_rocm(), \ + "PTPCFp8LinearMethod is only supported on ROCm." super().__init__(quant_config=quant_config) # Force weight quantization self.quant_config.is_checkpoint_fp8_serialized = False self.fp8_linear = Fp8LinearOp( - act_quant_static=False, - act_quant_group_shape=GroupShape.PER_TOKEN, - force_fp8_e4m3fnuz=True) + act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.weight = torch.nn.Parameter(layer.weight.data, diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index fdf03ded04480..d2d990e46bcf3 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -1,21 +1,32 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch +import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, + mxfp4_w4a4_moe_quant_config) +from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + is_rocm_aiter_moe_enabled) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + prepare_moe_fp8_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( OCP_MX_BLOCK_SIZE) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -67,21 +78,45 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): self.weight_quant = weight_config self.input_quant = input_config - weight_qscheme = self.weight_quant.get("qscheme") - input_qscheme = self.input_quant.get("qscheme") - if not (weight_qscheme == "per_tensor" - and input_qscheme == "per_tensor"): + self.weight_qscheme = self.weight_quant.get("qscheme") + self.input_qscheme = self.input_quant.get("qscheme") + per_tensor = (self.weight_qscheme == "per_tensor" + and self.input_qscheme == "per_tensor") + per_channel = (self.weight_qscheme == "per_channel" 
+ and self.input_qscheme == "per_channel") + self.act_quant_group_shape = GroupShape.PER_TOKEN \ + if per_channel else GroupShape.PER_TENSOR + if not (per_tensor or per_channel): raise ValueError( - "For FP8 Fused MoE layers, only per-tensor scales " - "for weights and activations are supported. Found " - f"{weight_qscheme}, {input_qscheme}") # noqa E501 + "For FP8 Fused MoE layers, only per-tensor and per-channel " + "scales for weights and activations are supported. Found " + f"{self.weight_qscheme}, {self.input_qscheme}") # noqa E501 self.static_input_scales = not self.input_quant.get("is_dynamic") + if self.static_input_scales and per_channel: + raise ValueError( + "For FP8 Fused MoE layer, we require either per tensor or " + "channelwise, dynamic per token quantization.") + + # For GPUs that lack FP8 hardware support, we can leverage the Marlin + # kernel for fast weight-only FP8 quantization + self.use_marlin = (not current_platform.has_device_capability(89) + or envs.VLLM_TEST_FORCE_FP8_MARLIN) + # Disable marlin for rocm + if current_platform.is_rocm(): + self.use_marlin = False + + self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): + layer.intermediate_size_per_partition = intermediate_size_per_partition + layer.hidden_size = hidden_size + layer.num_experts = num_experts + layer.orig_dtype = params_dtype + layer.weight_block_size = None params_dtype = torch.float8_e4m3fn # WEIGHTS @@ -104,24 +139,39 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): set_weight_attrs(w2_weight, extra_weight_attrs) # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. - w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, - 2, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - - w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - # Add the quantization method used (per tensor/grouped/channel) - # to ensure the weight scales are loaded in properly - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) + if self.weight_qscheme == "per_tensor": + # Allocate 2 scales for w1 and w3 respectively. + # They are combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, 2, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-TENSOR quantization for FusedMoE.weight_loader. + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + elif self.weight_qscheme == "per_channel": + # quark's scale is 1 dim. 
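+ # One scale per output channel, so the parameters below are 2-D:
+ #   w13_weight_scale: [num_experts, 2 * intermediate_size_per_partition]
+ #   w2_weight_scale:  [num_experts, hidden_size]
+ # A trailing unit dim is added in process_weights_after_loading when
+ # per-token activation quantization is used.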
+ w13_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, hidden_size, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-CHANNEL quantization for FusedMoE.weight_loader. + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) # INPUT_SCALES if self.static_input_scales: @@ -185,24 +235,70 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): layer.w2_input_scale = torch.nn.Parameter(w2_input_scale, requires_grad=False) - # Fp8 moe kernel needs single weight scale for w13 per expert. - # We take the max then dequant and requant each expert. - assert layer.w13_weight_scale is not None - shard_size = layer.intermediate_size_per_partition - max_w13_scales = layer.w13_weight_scale.max(dim=1).values - for expert_id in range(layer.local_num_experts): - start = 0 - for shard_id in range(2): - dq_weight = per_tensor_dequantize( - layer.w13_weight[expert_id][start:start + shard_size, :], - layer.w13_weight_scale[expert_id][shard_id]) - layer.w13_weight[expert_id][ - start:start + shard_size, :], _ = ops.scaled_fp8_quant( - dq_weight, max_w13_scales[expert_id]) - start += shard_size + # For per-tensor case, Fp8 moe kernel needs single weight scale + # for w13 per expert. Use max then dequant and requant each expert. + if self.weight_qscheme == "per_tensor": + assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.local_num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start:start + + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id]) + layer.w13_weight[expert_id][ + start:start + shard_size, :], _ = ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id]) + start += shard_size - layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, - requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, + requires_grad=False) + # quark's scale is 1 dim. + elif self.weight_qscheme == "per_channel": + if self.act_quant_group_shape == GroupShape.PER_TOKEN: + w13_weight_scale = layer.w13_weight_scale.unsqueeze(-1) + layer.w13_weight_scale = torch.nn.Parameter( + w13_weight_scale, requires_grad=False) + w2_weight_scale = layer.w2_weight_scale.unsqueeze(-1) + layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, + requires_grad=False) + # Property to determine if AITER is used + if self.rocm_aiter_moe_enabled: + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 + rocm_aiter_fused_experts, shuffle_weights) + + # reshaping weights is required for aiter moe kernel. + shuffled_w13, shuffled_w2 = shuffle_weights( + layer.w13_weight.data, layer.w2_weight.data) + + layer.w13_weight = torch.nn.Parameter(shuffled_w13, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(shuffled_w2, + requires_grad=False) + + self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts + elif self.use_marlin: + + prepare_moe_fp8_layer_for_marlin(layer, False) + # Activations not quantized for marlin. 
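+ # Marlin performs weight-only FP8 (activations stay in the original dtype),
+ # so the input scales created in create_weights are unused and dropped here.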
+ del layer.w13_input_scale + del layer.w2_input_scale + self.fused_experts_func = None + else: + from vllm.model_executor.layers.fused_moe import fused_experts + self.fused_experts_func = fused_experts + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + return fp8_w8a8_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + per_act_token_quant=self.weight_qscheme == "per_channel", + ) def apply( self, @@ -226,15 +322,13 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: raise NotImplementedError( "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.") - from vllm.model_executor.layers.fused_moe import fused_experts - topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, @@ -249,22 +343,50 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype) - return fused_experts( - x, - layer.w13_weight, - layer.w2_weight, + if self.rocm_aiter_moe_enabled: + return self.rocm_aiter_fused_experts_func( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + quant_config=self.moe_quant_config, + expert_map=expert_map) + if self.use_marlin: + assert activation == "silu", ( + f"{activation} not supported for Marlin MoE.") + return torch.ops.vllm.fused_marlin_moe( + x, + layer.w13_weight, + layer.w2_weight, + None, + None, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + quant_type_id=scalar_types.float8_e4m3fn.id, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map) + + assert self.fused_experts_func is not None + + return self.fused_experts_func( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, - use_fp8_w8a8=True, - global_num_experts=global_num_experts, + activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - activation=activation) + quant_config=self.moe_quant_config) class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): @@ -368,6 +490,16 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + return mxfp4_w4a4_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=None, + a2_scale=None, + block_shape=None, + ) + def apply( self, layer: torch.nn.Module, @@ -390,7 +522,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, 
logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: @@ -420,15 +552,10 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, - use_mxfp4_w4a4=True, + activation=activation, global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input, expert_map=expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=None, - a2_scale=None, - block_shape=None, - activation=activation, + quant_config=self.moe_quant_config, ) return out diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py index 8f72b8cbea7a7..ed90e2e26460e 100644 --- a/vllm/model_executor/layers/quantization/rtn.py +++ b/vllm/model_executor/layers/quantization/rtn.py @@ -3,7 +3,7 @@ # Copyright © 2025, Oracle and/or its affiliates. import os -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch import torch.nn.functional as F @@ -12,6 +12,9 @@ from torch.nn.parameter import Parameter from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, FusedMoEMethodBase) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, int4_w4a16_moe_quant_config, + int8_w8a16_moe_quant_config) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization import QuantizationMethods @@ -269,6 +272,21 @@ class RTNMoEMethod(FusedMoEMethodBase): fix_weights(layer, "w13_weight", weight_bits == 4) fix_weights(layer, "w2_weight", weight_bits == 4) + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + weight_bits = self.quant_config.weight_bits + group_size = self.quant_config.group_size + assert weight_bits == 4 or weight_bits == 8 + config_builder = (int4_w4a16_moe_quant_config + if weight_bits == 4 else int8_w8a16_moe_quant_config) + return config_builder( + w1_scale=layer.w13_scale, + w2_scale=layer.w2_scale, + w1_zp=None, + w2_zp=None, + block_shape=[0, group_size], + ) + def apply( self, layer: torch.nn.Module, @@ -291,7 +309,7 @@ class RTNMoEMethod(FusedMoEMethodBase): expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: @@ -314,10 +332,7 @@ class RTNMoEMethod(FusedMoEMethodBase): e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype) - weight_bits = self.quant_config.weight_bits - group_size = self.quant_config.group_size - - ret = fused_experts( + return fused_experts( x, layer.w13_weight, layer.w2_weight, @@ -325,16 +340,11 @@ class RTNMoEMethod(FusedMoEMethodBase): topk_ids=topk_ids, inplace=True, activation=activation, - use_int4_w4a16=weight_bits == 4, - use_int8_w8a16=weight_bits == 8, - global_num_experts=global_num_experts, - w1_scale=layer.w13_scale, - w2_scale=layer.w2_scale, apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, expert_map=expert_map, - block_shape=[0, group_size]) - - return ret + quant_config=self.moe_quant_config, + ) def rtn_quantize(tensor: 
torch.Tensor, num_bits: int, diff --git a/vllm/model_executor/layers/quantization/torchao.py b/vllm/model_executor/layers/quantization/torchao.py index 63b2ab6bab063..2efb605f203fd 100644 --- a/vllm/model_executor/layers/quantization/torchao.py +++ b/vllm/model_executor/layers/quantization/torchao.py @@ -144,34 +144,36 @@ def torchao_quantize_param_data(param: torch.Tensor, """Quantize a Tensor with torchao quantization specified by torchao_config Args: - `param`: weight parameter of the linear module - `torchao_config`: type of quantization and their arguments we want to - use to quantize the Tensor + param: weight parameter of the linear module + torchao_config: type of quantization and their arguments we want to + use to quantize the Tensor """ from torchao.core.config import AOBaseConfig from torchao.quantization import quantize_ assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}" - """ - Avoid real weight allocation for faster load, since we will + """ + Avoid real weight allocation for faster load, since we will end up setting it to param. """ with torch.device("meta"): - dummy_linear = torch.nn.Linear(param.shape[1], - param.shape[0], - bias=False) + # linear can't be top level module since quantize_ is inplace + # while some of our configs need to do module swap, and only non-top + # level modules support module swap + dummy_linear = torch.nn.Sequential( + torch.nn.Linear(param.shape[1], param.shape[0], bias=False)) - dummy_linear.weight = param + dummy_linear[0].weight = param quantize_(dummy_linear, torchao_config) - return dummy_linear.weight + return dummy_linear[0].weight class TorchAOLinearMethod(LinearMethodBase): """Linear method for torchao. Args: - torchao_config: The torchao quantization config, a string - that encodes the type of quantization and all relevant arguments. + quant_config: The torchao quantization config, a string that encodes + the type of quantization and all relevant arguments. 
""" def __init__(self, quant_config: TorchAOConfig): diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index f5d7c57fe2a87..fabf855b36e68 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -7,7 +7,8 @@ import torch import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, + FusedMoEQuantConfig) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 @@ -47,32 +48,23 @@ def reorder_w1w3_to_w3w1(weight: torch.Tensor, def build_flashinfer_fp4_cutlass_moe_prepare_finalize( - moe: FusedMoEConfig, - a1_gscale: torch.Tensor, -) -> mk.FusedMoEPrepareAndFinalize: + moe: FusedMoEConfig) -> mk.FusedMoEPrepareAndFinalize: """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" use_dp = moe.moe_parallel_config.dp_size > 1 - return FlashInferCutlassMoEPrepareAndFinalize(use_dp, a1_gscale=a1_gscale) + return FlashInferCutlassMoEPrepareAndFinalize(use_dp) def select_nvfp4_gemm_impl( moe: FusedMoEConfig, - g1_alphas: torch.Tensor, - g2_alphas: torch.Tensor, - a1_gscale: torch.Tensor, - a2_gscale: torch.Tensor, + moe_quant_config: FusedMoEQuantConfig, allow_flashinfer: bool, ) -> mk.FusedMoEPermuteExpertsUnpermute: """Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers""" if allow_flashinfer: return FlashInferExperts( - g1_alphas=g1_alphas, - g2_alphas=g2_alphas, - a1_gscale=a1_gscale, - a2_gscale=a2_gscale, out_dtype=moe.in_dtype, - quant_dtype="nvfp4", + quant_config=moe_quant_config, ep_rank=moe.moe_parallel_config.ep_rank, ep_size=moe.moe_parallel_config.ep_size, tp_rank=moe.moe_parallel_config.tp_rank, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 9889808f0760f..aa66a42c588a7 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -8,7 +8,8 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import envs from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, + FusedMoEQuantConfig) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 @@ -99,6 +100,8 @@ def apply_flashinfer_per_tensor_scale_fp8( apply_router_weight_on_input: bool, ) -> torch.Tensor: from flashinfer.fused_moe import RoutingMethodType + + import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 assert layer.output1_scales_scalar is not None, ( "Expected output1_scales_scalar to be initialized") assert layer.output1_scales_scalar is not None, ( @@ -167,34 +170,23 @@ def register_moe_scaling_factors(layer: torch.nn.Module) -> None: def build_flashinfer_fp8_cutlass_moe_prepare_finalize( - moe: Optional[FusedMoEConfig], - layer: torch.nn.Module, -) -> 
mk.FusedMoEPrepareAndFinalize: + moe: Optional[FusedMoEConfig], ) -> mk.FusedMoEPrepareAndFinalize: """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False - return FlashInferCutlassMoEPrepareAndFinalize( - use_dp, a1_gscale=layer.w13_input_scale) + return FlashInferCutlassMoEPrepareAndFinalize(use_dp) def select_cutlass_fp8_gemm_impl( moe: Optional[FusedMoEConfig], - layer: torch.nn.Module, + quant_config: FusedMoEQuantConfig, out_dtype: Optional[torch.dtype] = None, ) -> mk.FusedMoEPermuteExpertsUnpermute: """Return a GEMM *experts* implementation for fused-MoE layers""" - from vllm.model_executor.models.llama4 import Llama4MoE - assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \ - "FusedMoE flashinfer kernels are only supported for Llama4" - if moe is not None: return FlashInferExperts( - g1_alphas=layer.output1_scales_gate_scalar, - g2_alphas=layer.output2_scales_scalar, - a1_gscale=layer.w13_input_scale, - a2_gscale=layer.w2_input_scale_inv, out_dtype=moe.in_dtype, - quant_dtype=torch.float8_e4m3fn, + quant_config=quant_config, ep_rank=moe.moe_parallel_config.ep_rank, ep_size=moe.moe_parallel_config.ep_size, tp_rank=moe.moe_parallel_config.tp_rank, @@ -204,12 +196,8 @@ def select_cutlass_fp8_gemm_impl( assert out_dtype is not None, ( "If moe config is None, out_dtype must be passed") return FlashInferExperts( - g1_alphas=layer.output1_scales_gate_scalar, - g2_alphas=layer.output2_scales_scalar, - a1_gscale=layer.w13_input_scale, - a2_gscale=layer.w2_input_scale_inv, out_dtype=out_dtype, - quant_dtype=torch.float8_e4m3fn, + quant_config=quant_config, ) @@ -224,11 +212,13 @@ def flashinfer_cutlass_moe_fp8( expert_map: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, ) -> torch.Tensor: + quant_config = layer.quant_method.get_fused_moe_quant_config(layer) + assert quant_config is not None + fused_experts = mk.FusedMoEModularKernel( - build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None, - layer=layer), + build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None), select_cutlass_fp8_gemm_impl(moe=None, - layer=layer, + quant_config=quant_config, out_dtype=hidden_states.dtype)) return fused_experts( diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 7b324dce3c367..fc12483de0c0e 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -17,6 +17,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( group_broadcast) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_BLOCK_FP8_SUPPORTED) +from vllm.model_executor.parameter import (BlockQuantScaleParameter, + ChannelQuantScaleParameter, + PerTensorScaleParameter) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import cdiv, direct_register_custom_op @@ -40,11 +43,14 @@ def cutlass_scaled_mm( block_size: list[int], output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: - return ops.cutlass_scaled_mm(A, - B.T, - out_dtype=output_dtype, - scale_a=As, - scale_b=Bs.T) + return ops.cutlass_scaled_mm( + A, + B.T, + out_dtype=output_dtype, + scale_a=As, + # SM90 block FP8 requires row-major scale_b, which we do ahead of time + scale_b=Bs if block_size is not None + and current_platform.is_device_capability(90) else Bs.T) def 
rocm_aiter_gemm_w8a8_blockscale_impl( @@ -152,35 +158,32 @@ def apply_w8a8_block_fp8_linear( output += bias return output.to(dtype=output_dtype).view(*output_shape) - if current_platform.is_cuda(): - if current_platform.has_device_capability(100): - - use_cutlass = cutlass_block_fp8_supported and ( - cdiv(weight.shape[0], 128) == weight_scale.shape[0] - and cdiv(weight.shape[1], 128) == weight_scale.shape[1]) - else: - # TODO: update this after switching to public sm90 block scale gemm - # as it also supports weight.shape % 128 != 0 - use_cutlass = cutlass_block_fp8_supported and ( - weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) - else: - use_cutlass = False - w8a8_blockscale_func = dispatch_w8a8_blockscale_func( - use_cutlass, use_aiter_and_is_supported) - if use_cutlass: - q_input, x_scale = per_token_group_quant_fp8( - input_2d, block_size[1], column_major_scales=use_cutlass) + cutlass_block_fp8_supported, use_aiter_and_is_supported) + if cutlass_block_fp8_supported: + num_pad = 0 + if current_platform.is_device_capability(90): + # pad first dimension to be divisible by 4 due to + # cutlass blockwise gemm limitation for hopper + num_pad = 4 - (input_2d.shape[0] % 4) + if num_pad > 0: + input_2d = torch.nn.functional.pad(input_2d, + (0, 0, 0, num_pad), + "constant", 0) + q_input, x_scale = per_token_group_quant_fp8(input_2d, + block_size[1], + column_major_scales=True) output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, block_size, input.dtype) - + if num_pad > 0: + output = output[:-num_pad] else: if use_aiter_and_is_supported: q_input, x_scale = aiter_per1x128_quant( input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) else: q_input, x_scale = per_token_group_quant_fp8( - input_2d, block_size[1], column_major_scales=use_cutlass) + input_2d, block_size[1], column_major_scales=False) output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, block_size, input.dtype) @@ -411,6 +414,7 @@ def per_token_group_quant_fp8( x_s = torch.empty(shape, device=x.device, dtype=torch.float32) # prefer CUDA kernel if available + # TODO(bnell): this causes some fp8 moe test to fail. if current_platform.is_cuda() and x.is_contiguous(): torch.ops._C.per_token_group_fp8_quant(x, x_q, x_s, group_size, eps, fp8_min, fp8_max, use_ue8m0) @@ -793,3 +797,220 @@ def requant_weight_ue8m0_inplace( # Write back the results in-place. w_q.copy_(w_requant) s_old.copy_(s_requant) + + +def check_aiter_fp8_linear_support() -> bool: + """AITER is only supported on ROCm and only for FP8_FNUZ + and at the moment are MI300 series""" + return (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER + and envs.VLLM_ROCM_USE_AITER_LINEAR + and current_platform.is_fp8_fnuz()) + + +def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor: + """Pad the weight tensor. 
This is an optimization on ROCm platform, which + can benefit from tensors located far enough from one another in memory""" + if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm() + and weight.stride(-1) == 1 + and (weight.stride(-2) * weight.element_size()) % 512 == 0): + num_pad = 256 // weight.element_size() + import torch.nn.functional as F + weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] + torch.cuda.empty_cache() + return weight + + +def validate_fp8_block_shape(layer: torch.nn.Module, input_size: int, + output_size: int, input_size_per_partition: int, + output_partition_sizes: list[int], + block_size: list[int]) -> None: + """Validate block quantization shapes for tensor parallelism.""" + from vllm.distributed import get_tensor_model_parallel_world_size + + tp_size = getattr(layer, "tp_size", get_tensor_model_parallel_world_size()) + block_n, block_k = block_size[0], block_size[1] + + # Required by row parallel + if (tp_size > 1 and input_size // input_size_per_partition == tp_size + and input_size_per_partition % block_k != 0): + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition} " + f"is not divisible by weight quantization block_k = {block_k}.") + + # Required by column parallel or enabling merged weights + is_tp_split = (tp_size > 1 + and output_size // sum(output_partition_sizes) == tp_size) + is_merged_gemm = len(output_partition_sizes) > 1 + if is_tp_split or is_merged_gemm: + sizes_to_check = output_partition_sizes + if not is_tp_split and is_merged_gemm: + # In case of merged matrices, we allow the last + # matrix to not be a multiple of block size + sizes_to_check = output_partition_sizes[:-1] + for output_partition_size in sizes_to_check: + if output_partition_size % block_n != 0: + raise ValueError( + f"Weight output_partition_size = " + f"{output_partition_size} is not divisible by " + f"weight quantization block_n = {block_n}.") + + +def create_fp8_weight_parameter( + output_size_per_partition: int, input_size_per_partition: int, + weight_loader: Optional[Callable]) -> torch.nn.Parameter: + """Create FP8 weight parameter.""" + from vllm.model_executor.parameter import ModelWeightParameter + + return ModelWeightParameter(data=torch.empty(output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + +def create_fp8_scale_parameter( + parameter_type: torch.nn.Parameter, output_partition_sizes: list[int], + input_size_per_partition: int, block_size: Optional[list[int]], + weight_loader: Optional[Callable]) -> torch.nn.Parameter: + """Create scale parameter based on quantization strategy.""" + if parameter_type == ChannelQuantScaleParameter: + scale = parameter_type(data=torch.empty( + (sum(output_partition_sizes), 1), dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) + elif parameter_type == BlockQuantScaleParameter: + assert block_size is not None + block_n, block_k = block_size[0], block_size[1] + output_size_per_partition = sum(output_partition_sizes) + scale = parameter_type( + data=torch.empty( + (output_size_per_partition + block_n - 1) // block_n, + (input_size_per_partition + block_k - 1) // block_k, + dtype=torch.float32, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + elif parameter_type == PerTensorScaleParameter: + scale = parameter_type(data=torch.empty(len(output_partition_sizes), + dtype=torch.float32), + weight_loader=weight_loader) + else: + raise 
ValueError(f"Unknown parameter type: {parameter_type}") + + scale[:] = torch.finfo(torch.float32).min + return scale + + +def create_fp8_input_scale( + output_partition_sizes: list[int], + weight_loader: Optional[Callable]) -> torch.nn.Parameter: + """Create input scale parameter for static activation quantization.""" + from vllm.model_executor.parameter import PerTensorScaleParameter + + scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + scale[:] = torch.finfo(torch.float32).min + return scale + + +def process_fp8_weight_tensor_strategy( + weight: torch.Tensor, + weight_scale: torch.Tensor, + logical_widths: list[int], + input_scale: Optional[torch.Tensor] = None +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Process weights for tensor-wise quantization strategy.""" + from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale) + + if current_platform.is_fp8_fnuz(): + weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, weight_scale=weight_scale, input_scale=input_scale) + + # Requantize with max scale + weight_scale, weight = requantize_with_max_scale( + weight=weight, + weight_scale=weight_scale, + logical_widths=logical_widths, + ) + + weight = _maybe_pad_fp8_weight(weight) + return weight, weight_scale, input_scale + + +def process_fp8_weight_channel_strategy( + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Process weights for channel-wise quantization strategy.""" + from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + normalize_e4m3fn_to_e4m3fnuz) + + if current_platform.is_fp8_fnuz(): + weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, weight_scale=weight_scale, input_scale=input_scale) + + return weight, weight_scale, input_scale + + +def process_fp8_weight_block_strategy( + weight: torch.Tensor, + weight_scale: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Process weights for block-wise quantization strategy.""" + from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + normalize_e4m3fn_to_e4m3fnuz) + + if current_platform.is_fp8_fnuz(): + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, weight_scale=weight_scale) + + weight = _maybe_pad_fp8_weight(weight) + return weight, weight_scale + + +def maybe_post_process_fp8_weight_block(layer: torch.nn.Module, + cutlass_block_fp8_supported: bool): + assert layer.weight_block_size is not None + + from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used, + should_use_deepgemm_for_fp8_linear) + + # On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to + # requantize the weight and input to the specific scale + # at the same time. 
+ if is_deep_gemm_e8m0_used(): + block_sz = tuple(layer.weight_block_size) + requant_weight_ue8m0_inplace(layer.weight.data, + layer.weight_scale.data, block_sz) + # SM90 Block FP8 CUTLASS requires row-major weight scales + elif (current_platform.is_device_capability(90) + and cutlass_block_fp8_supported + and not should_use_deepgemm_for_fp8_linear(torch.bfloat16, + layer.weight)): + layer.weight_scale = torch.nn.Parameter( + layer.weight_scale.data.T.contiguous(), requires_grad=False) + + +def apply_fp8_block_linear(layer: torch.nn.Module, input: torch.Tensor, + bias: Optional[torch.Tensor], + cutlass_block_fp8_supported: bool, + use_aiter_and_is_supported: bool) -> torch.Tensor: + """Apply block-wise FP8 linear operation.""" + assert layer.weight_block_size is not None + + return torch.ops.vllm.apply_w8a8_block_fp8_linear( + input=input, + weight=layer.weight, + block_size=layer.weight_block_size, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias, + cutlass_block_fp8_supported=cutlass_block_fp8_supported, + use_aiter_and_is_supported=use_aiter_and_is_supported, + ) diff --git a/vllm/model_executor/layers/quantization/utils/int8_utils.py b/vllm/model_executor/layers/quantization/utils/int8_utils.py index 6840cabbf1ae3..62e458ec3c93e 100644 --- a/vllm/model_executor/layers/quantization/utils/int8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/int8_utils.py @@ -423,7 +423,7 @@ def w8a8_block_int8_matmul( Bs: The per-block quantization scale for `B`. block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128]. - output_dytpe: The dtype of the returned tensor. + output_dtype: The dtype of the returned tensor. Returns: torch.Tensor: The result of matmul. diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 02057b476c6e2..317ad079b392d 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -201,7 +201,7 @@ def marlin_make_workspace(output_size_per_partition: int, def marlin_make_workspace_new(device: torch.device, max_blocks_per_sm: int = 1) -> torch.Tensor: # In the new marlin kernel, we use the num of threadblocks as workspace - # size. The num of threadblocks is is sms_count * max_blocks_per_sm. + # size. The num of threadblocks is sms_count * max_blocks_per_sm. 
sms = torch.cuda.get_device_properties(device).multi_processor_count return torch.zeros(sms * max_blocks_per_sm, dtype=torch.int, diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py index b2c228c242532..f5acd03cc6622 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py @@ -19,7 +19,7 @@ class MarlinWorkspace: def __init__(self, out_features, min_thread_n, max_parallel): assert (out_features % min_thread_n == 0), ( - "out_features = {} is undivisible by min_thread_n = {}".format( + "out_features = {} is indivisible by min_thread_n = {}".format( out_features, min_thread_n)) max_workspace_size = ((out_features // min_thread_n) * max_parallel) diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index f4ff875adb21c..5339c6043cc16 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -34,6 +34,15 @@ class GroupShape(_GroupShape): PER_TENSOR: ClassVar['GroupShape'] PER_TOKEN: ClassVar['GroupShape'] + def is_per_tensor(self) -> bool: + return self.row == -1 and self.col == -1 + + def is_per_token(self) -> bool: + return self.row == 1 and self.col == -1 + + def is_per_group(self) -> bool: + return self.row == 1 and self.col >= 1 + GroupShape.PER_TENSOR = GroupShape(-1, -1) GroupShape.PER_TOKEN = GroupShape(1, -1) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 5333bbd310ff9..8cda1789e6c97 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -23,7 +23,7 @@ TORCH_DEVICE_IDENTITY = None # The condition to determine if it is on a platform that supports # torch._scaled_mm rowwise feature. # The condition is determined once as the operations -# are time consuming. +# are time-consuming. 
USE_ROWWISE_TORCH_SCALED_MM = (current_platform.is_rocm() and version.parse( torch.__version__) >= version.parse("2.7") and current_platform.has_device_capability(94)) @@ -171,13 +171,15 @@ def flashinfer_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, bias=bias) -def rocm_per_tensor_w8a8_scaled_mm_impl( - qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, - scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor) -> torch.Tensor: +def rocm_per_tensor_w8a8_scaled_mm_impl(qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: torch.Tensor) -> torch.Tensor: from vllm.platforms.rocm import on_mi3xx if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx( - ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0: + ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0 and bias is None: output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b, current_platform.get_cu_count()) else: @@ -190,10 +192,12 @@ def rocm_per_tensor_w8a8_scaled_mm_impl( return output -def rocm_per_tensor_w8a8_scaled_mm_fake( - qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, - scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor) -> torch.Tensor: +def rocm_per_tensor_w8a8_scaled_mm_fake(qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: torch.Tensor) -> torch.Tensor: return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]), dtype=out_dtype) @@ -203,11 +207,10 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, out_dtype: torch.dtype, scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor, output_shape: list) -> torch.Tensor: output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl( - qinput, weight, out_dtype, scale_a, scale_b, bias, input_2d) - return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + qinput, weight, out_dtype, scale_a, scale_b, bias) + return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape) direct_register_custom_op( @@ -224,7 +227,6 @@ def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, out_dtype: torch.dtype, scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor, output_shape: list) -> torch.Tensor: output = torch._scaled_mm(qinput, weight, @@ -237,7 +239,7 @@ def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, if type(output) is tuple and len(output) == 2: output = output[0] - return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape) def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, @@ -245,7 +247,7 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, out_dtype: torch.dtype, scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor, output_shape: list, + output_shape: list, **kwargs) -> torch.Tensor: # Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM # when using it. 
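The refactored torch_per_token_w8a8_scaled_mm above delegates the row-wise scaling math to torch._scaled_mm on supported hardware. The standalone sketch below is not vLLM code (the function name is invented); it emulates the same per-token-activation / per-channel-weight quantize-matmul-dequantize arithmetic in plain PyTorch, so it runs on any device with float8 dtypes available (PyTorch 2.1+):

import torch

def per_token_fp8_matmul_reference(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Unfused reference: quantize with per-token activation scales and
    per-output-channel weight scales, then dequantize after the matmul."""
    f8 = torch.float8_e4m3fn
    fmax = torch.finfo(f8).max
    # Row-wise (per-token) scale for activations, per-channel scale for weights.
    x_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fmax
    w_scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fmax
    x_q = (x / x_scale).to(f8)
    w_q = (w / w_scale).to(f8)
    # C = sx * sw^T * (Xq @ Wq^T), the same scaling the fused kernels apply.
    return (x_q.float() @ w_q.float().t()) * x_scale * w_scale.t()

x = torch.randn(4, 64)
w = torch.randn(128, 64)
err = (per_token_fp8_matmul_reference(x, w) - x @ w.t()).abs().max()
print(f"max abs error vs. full-precision matmul: {err:.3f}")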
@@ -265,7 +267,7 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, scale_b=scale_b.t(), bias=bias) - output = torch.narrow(output, 0, 0, input_2d.shape[0]) + output = torch.narrow(output, 0, 0, qinput.shape[0]) output = output.view(*output_shape) return output @@ -275,7 +277,6 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, out_dtype: torch.dtype, scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor, output_shape: list, **kwargs) -> torch.Tensor: # Use unfused DQ due to limitations with scaled_mm @@ -305,8 +306,8 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, if type(output) is tuple and len(output) == 2: output = output[0] # Unpad (undo num_token_padding) - output = torch.narrow(output, 0, 0, input_2d.shape[0]) - x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0]) + output = torch.narrow(output, 0, 0, qinput.shape[0]) + x_scale = torch.narrow(scale_a, 0, 0, qinput.shape[0]) # DQ # C = sw * sx * (X * W) + bias @@ -355,12 +356,10 @@ class Fp8LinearOp: def __init__(self, act_quant_static: bool, act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR, - pad_output: Optional[bool] = None, - force_fp8_e4m3fnuz: bool = False): + pad_output: Optional[bool] = None): if current_platform.is_rocm(): self.preferred_backend = "rocm" - elif current_platform.is_cuda( - ) and not force_fp8_e4m3fnuz and cutlass_fp8_supported(): + elif current_platform.is_cuda() and cutlass_fp8_supported(): if has_flashinfer() and current_platform.has_device_capability( 100): self.preferred_backend = "flashinfer" @@ -432,7 +431,6 @@ class Fp8LinearOp: scale_a=x_scale, scale_b=weight_scale, bias=bias, - input_2d=input_2d, output_shape=output_shape) diff --git a/vllm/model_executor/layers/rotary_embedding/__init__.py b/vllm/model_executor/layers/rotary_embedding/__init__.py index 564f9a5c00750..c9653aa9e4405 100644 --- a/vllm/model_executor/layers/rotary_embedding/__init__.py +++ b/vllm/model_executor/layers/rotary_embedding/__init__.py @@ -103,6 +103,8 @@ def get_rope( is_neox_style, dtype, mrope_section=rope_scaling["mrope_section"], + mrope_interleaved=rope_scaling.get("mrope_interleaved", + False), ) else: rotary_emb = RotaryEmbedding( diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 10fce857a8ae2..db50eb08db3ff 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -7,7 +7,7 @@ import torch from vllm.model_executor.custom_op import CustomOp -from .common import apply_rotary_emb_dispatch, apply_rotary_emb_torch +from .common import apply_rotary_emb_torch @CustomOp.register("rotary_embedding") @@ -62,11 +62,8 @@ class RotaryEmbedding(CustomOp): positions: torch.Tensor, query: torch.Tensor, key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """A PyTorch-native implementation of forward().""" - if offsets is not None: - positions = positions + offsets positions = positions.flatten() num_tokens = positions.shape[0] cos_sin = self.cos_sin_cache.index_select(0, positions) @@ -96,7 +93,6 @@ class RotaryEmbedding(CustomOp): positions: torch.Tensor, query: torch.Tensor, key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: from vllm import _custom_ops as ops @@ -107,16 +103,10 @@ class RotaryEmbedding(CustomOp): self.cos_sin_cache = 
self.cos_sin_cache.to(query.device, dtype=query.dtype) - # ops.rotary_embedding()/batched_rotary_embedding() - # are in-place operations that update the query and key tensors. - if offsets is not None: - ops.batched_rotary_embedding(positions, query, key, self.head_size, - self.cos_sin_cache, - self.is_neox_style, self.rotary_dim, - offsets) - else: - ops.rotary_embedding(positions, query, key, self.head_size, - self.cos_sin_cache, self.is_neox_style) + # ops.rotary_embedding() is an in-place operation + # that updates the query and key tensors. + ops.rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, self.is_neox_style) return query, key def forward_xpu( @@ -124,110 +114,21 @@ class RotaryEmbedding(CustomOp): positions: torch.Tensor, query: torch.Tensor, key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: from vllm._ipex_ops import ipex_ops as ops self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype) - # ops.rotary_embedding()/batched_rotary_embedding() - # are in-place operations that update the query and key tensors. + # ops.rotary_embedding() is an in-place operation + # that updates the query and key tensors. if key is None: # XPU kernel doesn't support key=None so fall back to native impl # TODO(sarckk): add support for optional key in # ipex.llm.functional.rotary_embedding_batched - return self.forward_native(positions, query, key, offsets) + return self.forward_native(positions, query, key) else: - if offsets is not None: - ops.batched_rotary_embedding(positions, query, key, - self.head_size, - self.cos_sin_cache, - self.is_neox_style, - self.rotary_dim, offsets) - else: - ops.rotary_embedding(positions, query, key, self.head_size, - self.cos_sin_cache, self.is_neox_style) - return query, key - - def forward_neuron( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - - def _apply_rotary_emb_neuron( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - is_neox_style: bool, - ) -> torch.Tensor: - cos = cos.unsqueeze(-2).to(x.dtype) - sin = sin.unsqueeze(-2).to(x.dtype) - if is_neox_style: - x1, x2 = torch.chunk(x, 2, dim=-1) - else: - # x1 = x[..., ::2] - - # x2 = x[..., 1::2] - d = x.shape[-1] // 2 - x_reshaped = x.view(-1, x.shape[-1]) - x1 = x_reshaped[:, ::2].view(*x.shape[:-1], d) - x2 = x_reshaped[:, 1::2].view(*x.shape[:-1], d) - o1 = x1 * cos - x2 * sin - o2 = x2 * cos + x1 * sin - if is_neox_style: - return torch.cat((o1, o2), dim=-1) - else: - return torch.stack((o1, o2), dim=-1).flatten(-2) - - if offsets is not None: - positions = positions + offsets - - self.cos_sin_cache = self.cos_sin_cache.to(query.device, - dtype=query.dtype) - - positions = positions.flatten() - num_tokens = positions.shape[0] - cos_sin = self.cos_sin_cache.index_select(0, positions) - cos, sin = cos_sin.chunk(2, dim=-1) - - query_shape = query.shape - query = query.view(num_tokens, -1, self.head_size) - if key is not None: - key_shape = key.shape - key = key.view(num_tokens, -1, self.head_size) - - if self.rotary_dim == self.head_size: - query = apply_rotary_emb_dispatch(query, cos, sin, - self.is_neox_style) - query = query.reshape(query_shape) - if key is not None: - key = apply_rotary_emb_dispatch(key, cos, sin, - self.is_neox_style) - key = key.reshape(key_shape) - else: - head_size = query.shape[-1] - query_reshaped = 
query.view(-1, head_size) - query_pass = query_reshaped[:, self.rotary_dim:].view( - *query.shape[:-1], head_size - self.rotary_dim) - query_rot = query_reshaped[:, :self.rotary_dim].view( - *query.shape[:-1], self.rotary_dim) - query_rot = _apply_rotary_emb_neuron(query_rot, cos, sin, - self.is_neox_style) - query = torch.cat((query_rot, query_pass), - dim=-1).reshape(query_shape) - - if key is not None: - key_reshaped = key.view(-1, head_size) - key_pass = key_reshaped[:, self.rotary_dim:].view( - *key.shape[:-1], head_size - self.rotary_dim) - key_rot = key_reshaped[:, :self.rotary_dim].view( - *key.shape[:-1], self.rotary_dim) - key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin, - self.is_neox_style) - key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + ops.rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, self.is_neox_style) return query, key def extra_repr(self) -> str: diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py index cd888b733426b..7ac2e4bb6c34f 100644 --- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py @@ -88,7 +88,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): cache = torch.cat((cos, sin), dim=-1) return cache - def forward( + def forward_native( self, positions: torch.Tensor, query: torch.Tensor, @@ -129,3 +129,12 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): query = query_rot key = key_rot return query, key + + def forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + return self.forward_native(positions, query, key, offsets) diff --git a/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py b/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py index 3d8da0fa9d8f5..27e41dd0fa97e 100644 --- a/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py @@ -111,7 +111,7 @@ class DualChunkRotaryEmbedding(CustomOp): device=self.device) return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache - def forward( + def forward_native( self, positions: torch.Tensor, query: torch.Tensor, @@ -161,6 +161,15 @@ class DualChunkRotaryEmbedding(CustomOp): dim=-1) return query, key + def forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + return self.forward_native(positions, query, key, offsets) + def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass): cos, sin = cos_sin.chunk(2, dim=-1) if self.is_neox_style: diff --git a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py index 05322e56f2620..4960c20f4060a 100644 --- a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py @@ -12,7 +12,7 @@ from .mrope import MRotaryEmbedding class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding): """3D rotary positional embedding. 
3D is t:time h:height w:width""" - def forward( + def forward_native( # type: ignore[override] self, positions: torch.Tensor, query: torch.Tensor, @@ -70,3 +70,11 @@ class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding): self.is_neox_style) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key + + def forward_cuda( # type: ignore[override] + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + return self.forward_native(positions, query, key) \ No newline at end of file diff --git a/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py b/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py index 415a85ab698bc..37ead43e22bc4 100644 --- a/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py @@ -53,7 +53,7 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding): torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)) return cache - def forward( + def forward_native( # type: ignore[override] self, query: torch.Tensor, key: Optional[torch.Tensor] = None, @@ -72,3 +72,10 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding): query_out = torch.view_as_real(query_ * freqs_ci).flatten(3) key_out = torch.view_as_real(key_ * freqs_ci).flatten(3) return query_out.type_as(query), key_out.type_as(key) + + def forward_cuda( # type: ignore[override] + self, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + return self.forward_native(query, key) diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index 0ab4bc5375daf..ef61dbc1a5ab1 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -8,7 +8,6 @@ import numpy as np import torch from transformers import PretrainedConfig -from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from .base import RotaryEmbedding @@ -136,8 +135,8 @@ def triton_mrope( """Qwen2VL mrope kernel. Args: - query: [num_tokens, num_heads * head_size] - key: [num_tokens, num_kv_heads * head_size] + q: [num_tokens, num_heads * head_size] + k: [num_tokens, num_kv_heads * head_size] cos: [3, num_tokens, head_size //2 ] (T/H/W positions with multimodal inputs) sin: [3, num_tokens, head_size //2 ] @@ -178,6 +177,18 @@ def triton_mrope( return q, k +def apply_interleaved_rope(x: torch.Tensor, + mrope_section: list[int]) -> torch.Tensor: + """Apply interleaved MRoPE to 3D rotary embeddings. + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + interleaved [THTHWHTHW...TT], preserving frequency continuity. + """ + x_t = x[0].clone() + x_t[..., 1:mrope_section[1] * 3:3] = x[1, ..., 1:mrope_section[1] * 3:3] + x_t[..., 2:mrope_section[2] * 3:3] = x[2, ..., 2:mrope_section[2] * 3:3] + return x_t + + class MRotaryEmbedding(RotaryEmbedding): """Rotary Embedding with Multimodal Sections.""" @@ -190,6 +201,7 @@ class MRotaryEmbedding(RotaryEmbedding): is_neox_style: bool, dtype: torch.dtype, mrope_section: Optional[list[int]] = None, + mrope_interleaved: Optional[bool] = False, ) -> None: # In Qwen2.5-VL, the maximum index value is related to the duration of # the input video. 
We enlarge max_position_embeddings to 4 times to get @@ -199,31 +211,10 @@ class MRotaryEmbedding(RotaryEmbedding): base, is_neox_style, dtype) self.mrope_section = mrope_section + self.mrope_interleaved = mrope_interleaved if self.mrope_section: assert sum(self.mrope_section) == rotary_dim // 2 - self.use_triton = current_platform.is_cuda_alike() - - def forward( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - """MRope forward. - - Args: - positions: - [num_tokens,] (text only) or - [3, num_tokens] (T/H/W positions with multimodal inputs) - query: [num_tokens, num_heads * head_size] - key: [num_tokens, num_kv_heads * head_size] - """ - if self.use_triton: - return self.forward_cuda(positions, query, key) - else: - return self.forward_native(positions, query, key) - def forward_native( self, positions: torch.Tensor, @@ -248,17 +239,20 @@ class MRotaryEmbedding(RotaryEmbedding): cos, sin = cos_sin.chunk(2, dim=-1) if positions.ndim == 2: assert self.mrope_section - - cos = torch.cat([ - m[i] - for i, m in enumerate(cos.split(self.mrope_section, dim=-1)) - ], - dim=-1) - sin = torch.cat([ - m[i] - for i, m in enumerate(sin.split(self.mrope_section, dim=-1)) - ], - dim=-1) + if self.mrope_interleaved: + cos = apply_interleaved_rope(cos, self.mrope_section) + sin = apply_interleaved_rope(sin, self.mrope_section) + else: + cos = torch.cat([ + m[i] for i, m in enumerate( + cos.split(self.mrope_section, dim=-1)) + ], + dim=-1) + sin = torch.cat([ + m[i] for i, m in enumerate( + sin.split(self.mrope_section, dim=-1)) + ], + dim=-1) query_shape = query.shape query = query.view(num_tokens, -1, self.head_size) @@ -288,6 +282,10 @@ class MRotaryEmbedding(RotaryEmbedding): assert positions.ndim == 1 or positions.ndim == 2 assert key is not None + if self.mrope_interleaved: + # TODO: add triton implementation to support mrope-interleaved + return self.forward_native(positions, query, key) + num_tokens = positions.shape[-1] cos_sin = self.cos_sin_cache[positions] cos, sin = cos_sin.chunk(2, dim=-1) @@ -323,6 +321,24 @@ class MRotaryEmbedding(RotaryEmbedding): key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key + def forward_xpu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + return self.forward_native(positions, query, key, offsets) + + def forward_cpu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + return self.forward_native(positions, query, key, offsets) + @classmethod def get_input_positions( cls, @@ -393,6 +409,15 @@ class MRotaryEmbedding(RotaryEmbedding): context_len=context_len, seq_len=seq_len, ) + elif hf_config.model_type in ["qwen3_vl", "qwen3_vl_moe"]: + return cls._qwen3vl_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + context_len=context_len, + seq_len=seq_len, + ) elif hf_config.model_type in ["ernie4_5_moe_vl", "ernie4_5_vl"]: return cls._ernie_get_input_positions_tensor( input_tokens=input_tokens, @@ -531,6 +556,98 @@ class MRotaryEmbedding(RotaryEmbedding): len(input_tokens)).item() return llm_positions, mrope_position_delta + @classmethod + def 
_qwen3vl_get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + context_len: int = 0, + seq_len: Optional[int] = None, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value.""" + + video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw + for _ in range(t)] + + image_token_id = hf_config.image_token_id + video_token_id = hf_config.video_token_id + vision_start_token_id = hf_config.vision_start_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + + input_tokens_tensor = torch.tensor(input_tokens) + vision_start_indices = torch.argwhere( + input_tokens_tensor == vision_start_token_id).squeeze(1) + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_videos = image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = \ + t, h // spatial_merge_size, w // spatial_merge_size + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( + llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( + llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + mrope_position_delta = (llm_positions.max() + 1 - + len(input_tokens)).item() + llm_positions = llm_positions[:, context_len:seq_len] + return llm_positions, mrope_position_delta + @classmethod def _ernie_get_input_positions_tensor( cls, diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 829dd82b0bd4d..9d93cad2420ad 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -649,7 +649,7 @@ def _sample_with_torch( else: sampled_token_ids_tensor = None - # Counterintiutively, having two loops here is actually faster. 
+ # Counterintuitively, having two loops here is actually faster. # The first loop can run without waiting on GPU<->CPU sync. for sampling_type in SamplingType: sample_indices = categorized_sample_indices[sampling_type] diff --git a/vllm/model_executor/layers/shared_fused_moe/__init__.py b/vllm/model_executor/layers/shared_fused_moe/__init__.py new file mode 100644 index 0000000000000..b87c69d3edd04 --- /dev/null +++ b/vllm/model_executor/layers/shared_fused_moe/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.model_executor.layers.shared_fused_moe.shared_fused_moe import ( + SharedFusedMoE) + +__all__ = ["SharedFusedMoE"] diff --git a/vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py new file mode 100644 index 0000000000000..e1e3d188d9852 --- /dev/null +++ b/vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +from vllm.distributed import tensor_model_parallel_all_reduce +from vllm.model_executor.layers.fused_moe.layer import FusedMoE + + +# TODO(bnell): Add shared + fused combo function? e.g. + +class SharedFusedMoE(FusedMoE): + """ + A FusedMoE operation that also computes the results of shared experts. + If an all2all communicator is being used the shared expert computation + can be interleaved with the fused all2all dispatch communication step. + """ + + def __init__( + self, + shared_experts: torch.nn.Module, + use_overlapped: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self._shared_experts = shared_experts + self.use_overlapped = use_overlapped + + @property + def shared_experts(self) -> Optional[torch.nn.Module]: + return self._shared_experts if self.use_overlapped else None + + def forward( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + if not self.use_overlapped: + shared_out = self._shared_experts(hidden_states) + + # Reduce outputs if necessary, since the MLP should + # have been created with reduce_results=False. 
+ if (self.reduce_results and self.tp_size > 1 + and self.must_reduce_shared_expert_outputs()): + shared_out = tensor_model_parallel_all_reduce(shared_out) + + fused_out = super().forward( + hidden_states=hidden_states, + router_logits=router_logits, + ) + else: + shared_out, fused_out = super().forward( + hidden_states=hidden_states, + router_logits=router_logits, + ) + return shared_out, fused_out diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 2897f75b3129e..d2b135c1e4d4e 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -142,20 +142,49 @@ direct_register_custom_op( ) -def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype): +def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype) -> bool: return (torch._C._cpu._is_amx_tile_supported() and (dtype in (torch.bfloat16, torch.int8)) and k % 32 == 0 and n % 16 == 0) +def dispatch_cpu_unquantized_gemm( + layer: torch.nn.Module, + remove_weight: bool, +) -> None: + N, K = layer.weight.size() + dtype = layer.weight.dtype + if envs.VLLM_CPU_SGL_KERNEL and check_cpu_sgl_kernel(N, K, dtype): + packed_weight = torch.ops._C.convert_weight_packed(layer.weight) + if getattr(layer, "bias", None) is not None: + bias_f32 = layer.bias.to(torch.float32) + else: + bias_f32 = None + layer.cpu_linear = ( + lambda x, weight, bias: torch.ops._C.weight_packed_linear( + x, packed_weight, bias_f32 + if bias is not None else None, True)) + if remove_weight: + layer.weight = torch.nn.Parameter(torch.empty(0), + requires_grad=False) + elif ops._supports_onednn: + origin_weight = layer.weight + if remove_weight: + layer.weight = torch.nn.Parameter(torch.empty(0), + requires_grad=False) + handler = ops.create_onednn_mm(origin_weight.t(), 32) + layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm( + handler, x, bias) + else: + layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear( + x, weight, bias) + + def cpu_unquantized_gemm(layer: torch.nn.Module, x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None): - if getattr(layer, "use_cpu_sgl", False): - return torch.ops._C.weight_packed_linear(x, weight, bias, True) - else: - return torch.nn.functional.linear(x, weight, bias) + return layer.cpu_linear(x, weight, bias) def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]: diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 9f223998e554f..aa64d4e09ae18 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -40,6 +40,12 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): layer.register_parameter("weight", weight) set_weight_attrs(weight, extra_weight_attrs) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if current_platform.is_cpu(): + from vllm.model_executor.layers.utils import ( + dispatch_cpu_unquantized_gemm) + dispatch_cpu_unquantized_gemm(layer, remove_weight=False) + def apply(self, layer: torch.nn.Module, x: torch.Tensor, @@ -252,7 +258,7 @@ class VocabParallelEmbedding(CustomOp): if params_dtype is None: params_dtype = torch.get_default_dtype() - # Divide the weight matrix along the vocaburaly dimension. + # Divide the weight matrix along the vocabulary dimension. 
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size self.num_embeddings_per_partition = divide(self.num_embeddings_padded, self.tp_size) @@ -393,7 +399,7 @@ class VocabParallelEmbedding(CustomOp): param[:loaded_weight.shape[0]].data.copy_(loaded_weight) param[loaded_weight.shape[0]:].data.fill_(0) - def forward(self, input_): + def forward_native(self, input_): if self.tp_size > 1: # Build the mask. masked_input, input_mask = get_masked_input_and_mask( @@ -414,6 +420,9 @@ class VocabParallelEmbedding(CustomOp): output = tensor_model_parallel_all_reduce(output_parallel) return output + def forward_cuda(self, input_): + return self.forward_native(input_) + def extra_repr(self) -> str: s = f"num_embeddings={self.num_embeddings_per_partition}" s += f", embedding_dim={self.embedding_dim}" @@ -423,6 +432,7 @@ class VocabParallelEmbedding(CustomOp): return s +@CustomOp.register("parallel_lm_head") class ParallelLMHead(VocabParallelEmbedding): """Parallelized LM head. diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py index 2dada794a8f3e..138a2ff30b622 100644 --- a/vllm/model_executor/model_loader/__init__.py +++ b/vllm/model_executor/model_loader/__init__.py @@ -5,7 +5,8 @@ from typing import Literal, Optional from torch import nn -from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.config import ModelConfig, VllmConfig +from vllm.config.load import LoadConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.bitsandbytes_loader import ( @@ -67,7 +68,7 @@ def register_model_loader(load_format: str): load_format (str): The model loader format name. Examples: - >>> from vllm.config import LoadConfig + >>> from vllm.config.load import LoadConfig >>> from vllm.model_executor.model_loader import get_model_loader, register_model_loader >>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader >>> diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py index 4cf6c7988960d..ab538a3c95620 100644 --- a/vllm/model_executor/model_loader/base_loader.py +++ b/vllm/model_executor/model_loader/base_loader.py @@ -5,7 +5,8 @@ from abc import ABC, abstractmethod import torch import torch.nn as nn -from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.config import ModelConfig, VllmConfig +from vllm.config.load import LoadConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.utils import ( initialize_model, process_weights_after_loading, set_default_torch_dtype) diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py index b8393956eed3f..4edf193b54ac5 100644 --- a/vllm/model_executor/model_loader/bitsandbytes_loader.py +++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py @@ -16,7 +16,8 @@ from packaging import version from torch import nn from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from vllm.config import LoadConfig, ModelConfig +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) # yapf: enable @@ -69,6 +70,7 @@ class BitsAndBytesModelLoader(BaseModelLoader): # Store all module names (from transformers) that support # BNB quantization. 
self.target_modules: list[str] = [] + self.tp_disabled_modules: list[str] = [] # Store the mapping of expert parameters for MoE models. self.expert_params_mapping: list[tuple[str, str, int, str]] = [] # mapping weight names from transformers to vllm. @@ -322,25 +324,36 @@ class BitsAndBytesModelLoader(BaseModelLoader): quant_state_dict) -> Generator: from bitsandbytes.functional import quantize_4bit - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - + global_tp_size = get_tensor_model_parallel_world_size() + global_tp_rank = get_tensor_model_parallel_rank() + check_match = (lambda weight_name, module_name: weight_name. + removesuffix(".weight") == module_name) for ( org_weight_name, mapped_weight_name, weight_tensor, ) in self._hf_weight_iter(hf_weights_files, use_safetensors): + + # override tp_size and tp_rank if the module has disabled TP + if any(tp_disabled_module in mapped_weight_name + for tp_disabled_module in self.tp_disabled_modules): + tp_size = 1 + tp_rank = 0 + else: + tp_size = global_tp_size + tp_rank = global_tp_rank + if any(target_module in mapped_weight_name for target_module in self.target_modules ) and mapped_weight_name.endswith(".weight"): # Without sharding if any( - mapped_weight_name.startswith(module) + check_match(mapped_weight_name, module) for module in self.unsharded_weights_modules): weight_sub_tensor = weight_tensor # Shard by column elif any( - mapped_weight_name.startswith(module) + check_match(mapped_weight_name, module) for module in self.column_sharded_weights_modules): total_size = weight_tensor.size(-1) start_index = total_size // tp_size * tp_rank @@ -350,14 +363,14 @@ class BitsAndBytesModelLoader(BaseModelLoader): # Weights have fused on disk. In this case, we assume that the # weight and module use same name. elif any( - mapped_weight_name.startswith(module) + check_match(mapped_weight_name, module) for module in self.maybe_fused_weights_modules): # special case for fused weights # get the size of each shard weight tensor total_shard_sizes = next( (sizes for module, sizes in self.maybe_fused_weights_modules.items() - if mapped_weight_name.startswith(module))) + if check_match(mapped_weight_name, module))) total_size = weight_tensor.size(0) assert total_size == sum(total_shard_sizes) # get the start/end index of each shard weight tensor @@ -418,12 +431,16 @@ class BitsAndBytesModelLoader(BaseModelLoader): # Map vllm's names to transformers's names. rep_name, sub_modules = modules_info for sub_name in sub_modules: - self.target_modules.append( - name.replace(rep_name, sub_name)) + new_name = name.replace(rep_name, sub_name) + self.target_modules.append(new_name) + if module.disable_tp: + self.tp_disabled_modules.append(new_name) # Add original module name even if the module has stacked map, # in case model has a mixture of disk-merged and disk-split # weights with same last name. self.target_modules.append(name) + if module.disable_tp: + self.tp_disabled_modules.append(name) elif isinstance(module, FusedMoE) and hasattr( module.quant_method, "quant_config"): # TODO: support FusedMoE with prequant and 8bit. 
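The check_match helper introduced in the BitsAndBytes loader above replaces the earlier startswith() test when mapping a weight name onto a target module, presumably to guard against prefix collisions between module names that share a common prefix. A minimal, hypothetical illustration (module names below are invented, not taken from any real model):

def check_match(weight_name: str, module_name: str) -> bool:
    # Exact match after stripping the ".weight" suffix, mirroring the lambda
    # added in the loader above.
    return weight_name.removesuffix(".weight") == module_name

# startswith() would treat "mlp.gate_proj_extra.weight" as belonging to
# "mlp.gate_proj"; the exact match does not.
print("mlp.gate_proj_extra.weight".startswith("mlp.gate_proj"))    # True
print(check_match("mlp.gate_proj_extra.weight", "mlp.gate_proj"))  # False
print(check_match("mlp.gate_proj.weight", "mlp.gate_proj"))        # True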
diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index 34b8d8e4ed622..d1bdec21fd974 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -7,19 +7,20 @@ import time from collections.abc import Generator, Iterable from typing import Optional, cast -import huggingface_hub import torch from torch import nn from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from vllm import envs -from vllm.config import LoadConfig, ModelConfig +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, download_weights_from_hf, fastsafetensors_weights_iterator, filter_duplicate_safetensors_files, - filter_files_not_needed_for_inference, get_lock, np_cache_weights_iterator, + filter_files_not_needed_for_inference, maybe_download_from_modelscope, + multi_thread_pt_weights_iterator, + multi_thread_safetensors_weights_iterator, np_cache_weights_iterator, pt_weights_iterator, safetensors_weights_iterator) from vllm.platforms import current_platform @@ -29,6 +30,9 @@ logger = init_logger(__name__) class DefaultModelLoader(BaseModelLoader): """Model loader that can load different file types from disk.""" + # Default number of threads to use when multithread weight loading is enabled + DEFAULT_NUM_THREADS = 8 + @dataclasses.dataclass class Source: """A source for weights.""" @@ -53,38 +57,15 @@ class DefaultModelLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) - if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") - def _maybe_download_from_modelscope( - self, model: str, revision: Optional[str]) -> Optional[str]: - """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. + extra_config = load_config.model_loader_extra_config + allowed_keys = {"enable_multithread_load", "num_threads"} + unexpected_keys = set(extra_config.keys()) - allowed_keys - Returns the path to the downloaded model, or None if the model is not - downloaded from ModelScope.""" - if envs.VLLM_USE_MODELSCOPE: - # download model from ModelScope hub, - # lazy import so that modelscope is not required for normal use. - # pylint: disable=C. - from modelscope.hub.snapshot_download import snapshot_download - - # Use file lock to prevent multiple processes from - # downloading the same model weights at the same time. - with get_lock(model, self.load_config.download_dir): - if not os.path.exists(model): - model_path = snapshot_download( - model_id=model, - cache_dir=self.load_config.download_dir, - local_files_only=huggingface_hub.constants. - HF_HUB_OFFLINE, - revision=revision, - ignore_file_pattern=self.load_config.ignore_patterns, - ) - else: - model_path = model - return model_path - return None + if unexpected_keys: + raise ValueError(f"Unexpected extra config keys for load format " + f"{load_config.load_format}: " + f"{unexpected_keys}") def _prepare_weights( self, @@ -96,7 +77,7 @@ class DefaultModelLoader(BaseModelLoader): """Prepare weights for the model.
If the model is not local, it will be downloaded.""" - model_name_or_path = (self._maybe_download_from_modelscope( + model_name_or_path = (maybe_download_from_modelscope( model_name_or_path, revision) or model_name_or_path) is_local = os.path.isdir(model_name_or_path) @@ -175,6 +156,7 @@ class DefaultModelLoader(BaseModelLoader): self, source: "Source" ) -> Generator[tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" + extra_config = self.load_config.model_loader_extra_config hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( source.model_or_path, source.revision, source.fall_back_to_pt, source.allow_patterns_overrides) @@ -195,23 +177,42 @@ class DefaultModelLoader(BaseModelLoader): self.load_config.use_tqdm_on_load, ) else: - weights_iterator = safetensors_weights_iterator( + if extra_config.get("enable_multithread_load"): + weights_iterator = ( + multi_thread_safetensors_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + max_workers=extra_config.get( + "num_threads", self.DEFAULT_NUM_THREADS), + )) + else: + weights_iterator = safetensors_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + self.load_config.safetensors_load_strategy, + ) + else: + if extra_config.get("enable_multithread_load"): + weights_iterator = multi_thread_pt_weights_iterator( hf_weights_files, self.load_config.use_tqdm_on_load, + self.load_config.pt_load_map_location, + max_workers=extra_config.get("num_threads", + self.DEFAULT_NUM_THREADS), + ) + else: + weights_iterator = pt_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + self.load_config.pt_load_map_location, ) - else: - weights_iterator = pt_weights_iterator( - hf_weights_files, - self.load_config.use_tqdm_on_load, - self.load_config.pt_load_map_location, - ) if current_platform.is_tpu(): from vllm.platforms.tpu import USE_TPU_COMMONS if not USE_TPU_COMMONS: # In PyTorch XLA, we should call `xm.mark_step` - # requently so that not too many ops are accumulated + # frequently so that not too many ops are accumulated # in the XLA program. 
import torch_xla.core.xla_model # as xm import torch_xla.core.xla_model as xm diff --git a/vllm/model_executor/model_loader/dummy_loader.py b/vllm/model_executor/model_loader/dummy_loader.py index f4a7da5744e04..5b8c6268f64ef 100644 --- a/vllm/model_executor/model_loader/dummy_loader.py +++ b/vllm/model_executor/model_loader/dummy_loader.py @@ -2,7 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch.nn as nn -from vllm.config import LoadConfig, ModelConfig +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.weight_utils import ( initialize_dummy_weights) diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index 9877cb3b7c06e..aaee8f3f76353 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -9,7 +9,8 @@ import torch.nn as nn from huggingface_hub import hf_hub_download from transformers import AutoModelForCausalLM -from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.config import ModelConfig, VllmConfig +from vllm.config.load import LoadConfig from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.utils import ( initialize_model, process_weights_after_loading, set_default_torch_dtype) diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py deleted file mode 100644 index ee484e9a7b0a4..0000000000000 --- a/vllm/model_executor/model_loader/neuron.py +++ /dev/null @@ -1,476 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Utilities for selecting and loading Neuron models in transformers-neuronx -framework.""" -import ast -import copy -import importlib -import os -from typing import Optional - -import torch -import torch.nn as nn -from transformers import PretrainedConfig - -from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig) -from vllm.logprobs import Logprob -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import get_quantization_config -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import CompletionSequenceGroupOutput, SequenceOutput - -TORCH_DTYPE_TO_NEURON_AMP = { - "auto": "f32", - "half": "f16", - "float16": "f16", - "bfloat16": "bf16", - "float": "f32", - "float32": "f32", - torch.float16: "f16", - torch.bfloat16: "bf16", - torch.float32: "f32", -} - -# Models supported by Neuron. 
-_NEURON_SUPPORTED_MODELS: dict[str, tuple[str, str, str]] = { - "LlamaForCausalLM": ("transformers_neuronx.llama.model", - "LlamaForSampling", "LlamaForCausalLM"), - "MistralForCausalLM": ("transformers_neuronx.mistral.model", - "MistralForSampling", "MistralForCausalLM") -} - - -class NeuronCausalLM(nn.Module): - - def __init__(self, - config: PretrainedConfig, - on_device_sampling_disabled: bool = False) -> None: - super().__init__() - self.config = config - self.logits_processor = LogitsProcessor(config.vocab_size, - logits_as_input=True) - - self.on_device_sampling_disabled = on_device_sampling_disabled - if self.on_device_sampling_disabled: - # Use default sampler - self.sampler = Sampler() - - # Lazy initialized - self.model: nn.Module - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - input_block_ids: torch.Tensor, - ) -> torch.Tensor: - logits = self.model(input_ids, - cache_ids=positions, - start_ids=input_block_ids) - return logits - - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(None, hidden_states, sampling_metadata) - return logits - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - - if self.on_device_sampling_disabled: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - # On-device sampling outputs the token ids directly. - sampled_token_ids = logits.flatten() - next_tokens = [] - sample_idx = 0 - for seq_group in sampling_metadata.seq_groups: - samples = [] - for seq_id in seq_group.seq_ids: - token_id = sampled_token_ids[sample_idx].item() - samples.append( - SequenceOutput(parent_seq_id=seq_id, - output_token=token_id, - logprobs={token_id: Logprob(token_id)})) - sample_idx += 1 - next_tokens.append( - CompletionSequenceGroupOutput(samples=samples, - prompt_logprobs=None)) - - return SamplerOutput(outputs=next_tokens) - - def load_weights(self, model_name_or_path: str, **kwargs): - arch = _get_model_architecture(self.config) - neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = ( - _NEURON_SUPPORTED_MODELS[arch]) - neuronx_module = importlib.import_module(neuronx_module_path) - neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) - - self.model = neuronx_model_cls.from_pretrained(model_name_or_path, - **kwargs) - self.model.to_neuron() - - -class NeuronSpeculationCausalLM(nn.Module): - """A Neuron-optimized causal language model with speculative decoding.""" - - SPECULATION_TERMINATION_ID = -1 - - def __init__(self, speculation_model) -> None: - super().__init__() - self.model = speculation_model - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - input_block_ids: torch.Tensor, - ) -> torch.Tensor: - tokens, counts = self.model.speculative_iteration( - input_ids, positions, input_block_ids) - - # Mark the end of accepted speculative tokens for each sequence with the - # speculation termination id. - batch_size, steps = tokens.shape - mask = torch.arange(steps).expand(batch_size, -1) >= counts - tokens[mask] = self.SPECULATION_TERMINATION_ID - - return tokens - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[list[SamplerOutput]]: - batch_size, num_steps = logits.shape - seq_ids = [ - seq_id for sg in sampling_metadata.seq_groups - for seq_id in sg.seq_ids - ] - # Organize input tensors by step instead of by sequence. 
- accepted_token_ids_by_step = logits.transpose(0, 1) - accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() - - sampler_output_list = [] - for step_index in range(num_steps): - if all(token_id == self.SPECULATION_TERMINATION_ID - for token_id in accepted_token_ids_by_step[step_index]): - break - step_output_token_ids = [] - for sequence_index in range(batch_size): - token_id = accepted_token_ids_by_step[step_index][ - sequence_index] - step_output_token_ids.append( - CompletionSequenceGroupOutput(samples=[ - SequenceOutput(parent_seq_id=seq_ids[sequence_index], - output_token=token_id, - logprobs={token_id: Logprob(token_id)}) - ], - prompt_logprobs=None)) - sampler_output_list.append( - SamplerOutput(outputs=step_output_token_ids)) - return sampler_output_list - - -def _get_model_architecture(config: PretrainedConfig) -> str: - architectures = getattr(config, "architectures", []) - for arch in architectures: - if arch in _NEURON_SUPPORTED_MODELS: - return arch - raise ValueError( - f"Model architectures {architectures} are not supported on Neuron " - f"for now. Supported architectures: " - f"{list(_NEURON_SUPPORTED_MODELS.keys())}") - - -def _get_buckets(env: str, default_value: list[int]) -> list[int]: - env_value = os.getenv(env) - if env_value is None: - return default_value - buckets_remove_empty = filter( - lambda x: x is not None and len(x.strip()) > 0, env_value.split(",")) - buckets_int = map(int, buckets_remove_empty) - buckets_list = list(buckets_int) - return buckets_list - - -def _get_default_neuron_config(model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig): - """Generate a neuron config based on vllm config args.""" - from transformers_neuronx.config import ContinuousBatchingConfig - from transformers_neuronx.constants import LAYOUT_BSH - - continuous_batching_config = ContinuousBatchingConfig( - batch_size_for_shared_caches=scheduler_config.max_num_seqs) - quant_config = dict( - dequant_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], - quantize_method="vector_dynamic") - neuron_quantization_config_builder = lambda quant: get_quantization_config( - quant).from_config(quant_config).get_quant_method(None, "") - # TODO: Add Paged attention config to the default neuron arguments. 
- default_neuron_args = dict( - collectives_layout=LAYOUT_BSH, - attention_layout=LAYOUT_BSH, - fuse_qkv=True, - quant=neuron_quantization_config_builder(model_config.quantization) - if model_config.quantization else None, - continuous_batching=continuous_batching_config, - weight_tiling=bool(model_config.quantization), - on_device_generation=_get_neuron_on_device_generation_config( - model_config)) - return default_neuron_args - - -def _get_default_neuron_config_for_speculation( - model_config: ModelConfig, parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig): - """Generate a neuron config for speculative decoding based on - vllm config args.""" - from transformers_neuronx.config import ContinuousBatchingConfig - from transformers_neuronx.constants import LAYOUT_BSH - - continuous_batching_config = ContinuousBatchingConfig( - batch_size_for_shared_caches=scheduler_config.max_num_seqs) - - default_neuron_args = dict(collectives_layout=LAYOUT_BSH, - attention_layout=LAYOUT_BSH, - fuse_qkv=True, - on_device_embedding=True, - continuous_batching=continuous_batching_config, - on_device_generation=copy.deepcopy( - model_config.neuron_sampling_params)) - return default_neuron_args - - -def _get_neuron_on_device_generation_config(model_config: ModelConfig): - if not _is_neuron_on_device_sampling_disabled(model_config): - return copy.deepcopy(model_config.neuron_sampling_params) - return None - - -def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool: - return not getattr(model_config, "neuron_sampling_params", None) - - -def _get_neuron_config_after_override(default_neuron_config, - overridden_neuron_config): - from transformers_neuronx.config import (ContinuousBatchingConfig, - GenerationConfig, - KVCacheQuantizationConfig, - NeuronConfig, QuantizationConfig, - SparseAttnConfig) - - sparse_attn = overridden_neuron_config.pop("sparse_attn", {}) - if sparse_attn: - overridden_neuron_config["sparse_attn"] = SparseAttnConfig( - **sparse_attn) - - kv_cache_quant = overridden_neuron_config.pop("kv_cache_quant", {}) - if kv_cache_quant: - overridden_neuron_config["kv_cache_quant"] = KVCacheQuantizationConfig( - **kv_cache_quant) - - continuous_batching = overridden_neuron_config.pop("continuous_batching", - {}) - if continuous_batching: - overridden_neuron_config[ - "continuous_batching"] = ContinuousBatchingConfig( - **continuous_batching) - - quant = overridden_neuron_config.pop("quant", {}) - if quant: - overridden_neuron_config["quant"] = QuantizationConfig(**quant) - - on_device_generation = overridden_neuron_config.pop( - "on_device_generation", {}) - if on_device_generation: - overridden_neuron_config["on_device_generation"] = GenerationConfig( - **on_device_generation) - default_neuron_config.update(overridden_neuron_config) - return NeuronConfig(**default_neuron_config) - - -def get_neuron_model(model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig) -> nn.Module: - """Initializes a neuron-optimized model for inference.""" - # Create a model instance. 
- model = NeuronCausalLM( - model_config.hf_config, - _is_neuron_on_device_sampling_disabled(model_config)) - - default_neuron_config_args = _get_default_neuron_config( - model_config, parallel_config, scheduler_config) - - neuron_config = _get_neuron_config_after_override( - default_neuron_config_args, model_config.override_neuron_config) - - context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", - [scheduler_config.max_model_len]) - n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", - [scheduler_config.max_model_len]) - - model.load_weights(model_config.model, - tp_degree=parallel_config.tensor_parallel_size, - amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], - neuron_config=neuron_config, - context_length_estimate=context_length_estimates, - n_positions=n_positions, - batch_size=scheduler_config.max_num_seqs) - - return model.eval() - - -def get_neuron_speculation_model(model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - speculation_config: SpeculativeConfig): - """Initializes a neuron-optimized speculation model for inference. - - This method is only applicable for speculation with a standalone draft model - """ - from transformers_neuronx.fused_speculation import FusedSpeculativeDecoder - - # For Eagle SD, we need to pass in additional parameters in neuron config. - is_eagle = getattr(speculation_config.draft_model_config.hf_config, - "is_eagle", False) - - # Create target model instance. - target_model = NeuronCausalLM(model_config.hf_config) - - default_neuron_config_args = _get_default_neuron_config_for_speculation( - model_config, parallel_config, scheduler_config) - if is_eagle: - default_neuron_config_args['is_eagle_target'] = True - - neuron_config = _get_neuron_config_after_override( - default_neuron_config_args, model_config.override_neuron_config) - - context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", - [scheduler_config.max_model_len]) - n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", - [scheduler_config.max_model_len]) - - target_model.load_weights( - model_config.model, - tp_degree=parallel_config.tensor_parallel_size, - amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], - neuron_config=neuron_config, - context_length_estimate=context_length_estimates, - n_positions=n_positions, - batch_size=scheduler_config.max_num_seqs) - - target_model.eval() - - # Create draft model instance. - draft_model = NeuronCausalLM( - speculation_config.draft_model_config.hf_config) - - default_draft_neuron_config_args = ( - _get_default_neuron_config_for_speculation( - speculation_config.draft_model_config, parallel_config, - scheduler_config)) - if is_eagle: - default_draft_neuron_config_args['is_eagle_draft'] = True - default_draft_neuron_config_args['has_pre_attention_norm'] = False - - draft_neuron_config = _get_neuron_config_after_override( - default_draft_neuron_config_args, - speculation_config.draft_model_config.override_neuron_config) - - draft_model.load_weights(speculation_config.draft_model_config.model, - tp_degree=speculation_config. - draft_parallel_config.tensor_parallel_size, - amp=TORCH_DTYPE_TO_NEURON_AMP[ - speculation_config.draft_model_config.dtype], - neuron_config=draft_neuron_config, - context_length_estimate=context_length_estimates, - n_positions=n_positions, - batch_size=scheduler_config.max_num_seqs) - - draft_model.eval() - - num_speculative_tokens = speculation_config.num_speculative_tokens - # Create speculation model instance. 
- speculation_model = FusedSpeculativeDecoder(draft_model.model, - target_model.model, - num_speculative_tokens) - speculation_model.to_neuron() - - return NeuronSpeculationCausalLM(speculation_model) - - -def get_neuron_eagle_speculation_model(model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - speculation_config: SpeculativeConfig): - """Initializes a neuron-optimized EAGLE speculation model for inference.""" - from transformers_neuronx.eagle_speculation import EagleSpeculativeDecoder - - # Create target model instance. - target_model = NeuronCausalLM(model_config.hf_config) - - default_neuron_config_args = _get_default_neuron_config_for_speculation( - model_config, parallel_config, scheduler_config) - default_neuron_config_args['is_eagle_target'] = True - neuron_config = _get_neuron_config_after_override( - default_neuron_config_args, model_config.override_neuron_config) - - context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", - [scheduler_config.max_model_len]) - n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", - [scheduler_config.max_model_len]) - - target_model.load_weights( - model_config.model, - tp_degree=parallel_config.tensor_parallel_size, - amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], - neuron_config=neuron_config, - context_length_estimate=context_length_estimates, - n_positions=n_positions, - batch_size=scheduler_config.max_num_seqs) - - target_model.eval() - - # Create draft model instance. - draft_model = NeuronCausalLM( - speculation_config.draft_model_config.hf_config) - - default_draft_neuron_config_args = ( - _get_default_neuron_config_for_speculation( - speculation_config.draft_model_config, parallel_config, - scheduler_config)) - default_draft_neuron_config_args['is_eagle_draft'] = True - default_draft_neuron_config_args['has_pre_attention_norm'] = False - draft_neuron_config = _get_neuron_config_after_override( - default_draft_neuron_config_args, - speculation_config.draft_model_config.override_neuron_config) - - draft_model.load_weights(speculation_config.draft_model_config.model, - tp_degree=speculation_config. 
- draft_parallel_config.tensor_parallel_size, - amp=TORCH_DTYPE_TO_NEURON_AMP[ - speculation_config.draft_model_config.dtype], - neuron_config=draft_neuron_config, - context_length_estimate=context_length_estimates, - n_positions=n_positions, - batch_size=scheduler_config.max_num_seqs) - - draft_model.eval() - - token_tree: dict[int, list[int]] = ast.literal_eval( - speculation_config.speculative_token_tree) - - speculation_model = EagleSpeculativeDecoder(draft_model.model, - target_model.model, - token_tree=token_tree) - speculation_model.to_neuron() - - return NeuronSpeculationCausalLM(speculation_model) diff --git a/vllm/model_executor/model_loader/neuronx_distributed.py b/vllm/model_executor/model_loader/neuronx_distributed.py deleted file mode 100644 index 34bf43fe7b57c..0000000000000 --- a/vllm/model_executor/model_loader/neuronx_distributed.py +++ /dev/null @@ -1,685 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Utilities for selecting and loading Neuron models in -neuronx-distributed-inference framework.""" -# Disabling yapf because yapf and isort have conflicts for the below imports -# yapf: disable -import copy -import hashlib -import importlib -import multiprocessing -import os -import shutil -from typing import Optional - -import torch -import torch.nn as nn -from neuronx_distributed_inference.models.config import ( - FusedSpecNeuronConfig, OnDeviceSamplingConfig) -from neuronx_distributed_inference.models.mllama.utils import ( - create_vision_mask) -from neuronx_distributed_inference.modules.lora_serving import ( - LoraServingConfig) -from neuronx_distributed_inference.utils.hf_adapter import ( - load_pretrained_config) -from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig - -from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig) -from vllm.logger import init_logger -from vllm.logprobs import Logprob -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import CompletionSequenceGroupOutput, SequenceOutput - -# yapf: enable -logger = init_logger(__name__) - -TORCH_DTYPE_TO_NEURON_AMP = { - "auto": "float32", - "half": "float16", - "float16": "float16", - "bfloat16": "bfloat16", - "float": "float32", - "float32": "float32", - torch.float16: "float16", - torch.bfloat16: "bfloat16", - torch.float32: "float32", -} - -# Models supported by Neuronx distributed for inference. 
-_NEURON_SUPPORTED_MODELS: dict[str, tuple[str, str]] = { - "LlamaForCausalLM": - ("neuronx_distributed_inference.models.llama.modeling_llama", - "NeuronLlamaForCausalLM"), - "MistralForCausalLM": - ("neuronx_distributed_inference.models.llama.modeling_llama", - "NeuronLlamaForCausalLM"), - "DbrxForCausalLM": - ("neuronx_distributed_inference.models.dbrx.modeling_dbrx", - "NeuronDbrxForCausalLM"), - "MixtralForCausalLM": - ("neuronx_distributed_inference.models.mixtral.modeling_mixtral", - "NeuronMixtralForCausalLM"), - "MllamaForConditionalGeneration": - ("neuronx_distributed_inference.models.mllama.modeling_mllama", - "NeuronMllamaForCausalLM"), -} - - -class NeuronCausalLM(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - ) -> None: - super().__init__() - self.config = config - self.logits_processor = LogitsProcessor(config.vocab_size, - logits_as_input=True) - self.sampler = Sampler() - - # Lazy initialized - self.model: nn.Module - - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - input_block_ids: torch.Tensor, - sampling_params: torch.Tensor, - prev_hidden: Optional[torch.Tensor] = None, - adapter_ids: Optional[torch.Tensor] = None) -> torch.Tensor: - # sort block ids sequentially for perf/neuron support reasons - sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids) - input_ids = torch.index_select(input_ids, 0, sorted_indices) - positions = torch.index_select(positions, 0, sorted_indices) - sampling_params = torch.index_select(sampling_params, 0, - sorted_indices) - output = self.model(input_ids, - attention_mask=None, - position_ids=positions, - seq_ids=sorted_input_block_ids, - sampling_params=sampling_params, - prev_hidden=prev_hidden, - adapter_ids=adapter_ids) - # on-device sampling - if self.config.neuron_config.on_device_sampling_config: - output = output.hidden_states - else: - output = output.logits[:, -1, :] - - restored_indices = torch.argsort(sorted_indices) - if input_block_ids.shape[0] != 1: - output = torch.index_select(output, 0, restored_indices) - - return output - - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(None, hidden_states, sampling_metadata) - return logits - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - # on-device sampling - if self.config.neuron_config.on_device_sampling_config: - batch_size = logits.shape - seq_ids = [ - seq_id for sg in sampling_metadata.seq_groups - for seq_id in sg.seq_ids - ] - assert len(seq_ids) == list(batch_size)[0], "batch size mismatch" - # Organize input tensors by step instead of by sequence. 
- accepted_token_ids_by_step = logits.flatten() - accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() - - step_output_token_ids = [] - for i, seq_id in enumerate(seq_ids): - token_id = accepted_token_ids_by_step[i] - step_output_token_ids.append( - CompletionSequenceGroupOutput(samples=[ - SequenceOutput(parent_seq_id=seq_id, - output_token=token_id, - logprobs={token_id: Logprob(token_id)}) - ], - prompt_logprobs=None)) - return SamplerOutput(outputs=step_output_token_ids) - else: - return self.sampler(logits, sampling_metadata) - - def load_weights(self, model_name_or_path: str, **kwargs): - arch = _get_model_architecture(self.config) - neuronx_module_path, neuronx_model_cls_name = ( - _NEURON_SUPPORTED_MODELS[arch]) - neuronx_module = importlib.import_module(neuronx_module_path) - neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) - neuron_config = neuronx_model_cls.get_neuron_config_cls()( - **kwargs['neuron_config']) - self.config.neuron_config = neuron_config - config = neuronx_model_cls.get_config_cls()( - neuron_config, - load_config=load_pretrained_config(model_name_or_path)) - hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'), - usedforsecurity=False).hexdigest() - if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: - compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") - elif os.path.exists(model_name_or_path): - compiled_model_path = os.path.join(model_name_or_path, - "neuron-compiled-artifacts", - hashed_config) - shutil.rmtree(compiled_model_path, ignore_errors=True) - else: - compiled_model_path = os.path.join("local-models", - model_name_or_path, - "neuron-compiled-artifacts", - hashed_config) - shutil.rmtree(compiled_model_path, ignore_errors=True) - try: - self.model = neuronx_model_cls(compiled_model_path) - override_neuron_config = kwargs["override_neuron_config"] - for k, v in override_neuron_config.items(): - setattr(self.model.config.neuron_config, k, v) - self.model.load(compiled_model_path) - return - except (FileNotFoundError, ValueError) as e: - logger.warning("Exception: %s", e) - logger.warning("Failed to load the model from %s, Recompiling...", - compiled_model_path) - if not os.path.exists(model_name_or_path): - hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) - saved_path = os.path.join("local-models", model_name_or_path) - hf_model.save_pretrained(saved_path) - model_name_or_path = saved_path - self.model = neuronx_model_cls(model_name_or_path, config) - self.model.compile(compiled_model_path) - self.model.load(compiled_model_path) - - -class NeuronMllamaForCausalLM(nn.Module): - - def __init__(self, - config: PretrainedConfig, - on_device_sampling_disabled: bool = False) -> None: - super().__init__() - # has_image is the only multimodal input that is used in - # token-generation - # This is a cache (on CPU) that saves has_image data per sequence id - # The number of entries in this cache is <= Batch-Size - self.has_image_cache: dict[int, torch.Tensor] = {} - self.config = config - self.logits_processor = LogitsProcessor( - config.get_text_config().vocab_size, logits_as_input=True) - - self.on_device_sampling_disabled = on_device_sampling_disabled - if self.on_device_sampling_disabled: - # Use default sampler - self.sampler = Sampler() - - # Lazy initialized - self.model: nn.Module - self.is_reorder_needed: bool = True - - def read_from_has_image_cache(self, seq_ids: torch.Tensor): - has_image_list = [] - for index in range(len(seq_ids)): - seq_id = seq_ids[index].item() - if 
seq_id in self.has_image_cache: - has_image_list.append(self.has_image_cache[seq_id]) - else: - has_image_list.append(torch.tensor([0])) - return torch.tensor(has_image_list) - - def write_to_has_image_cache(self, seq_ids: torch.Tensor, - has_image: torch.Tensor): - for index in range(len(seq_ids)): - seq_id = seq_ids[index].item() - if index < len(has_image): - self.has_image_cache[seq_id] = has_image[index] - else: - self.has_image_cache[seq_id] = torch.zeros(1) - - def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - seq_ids: torch.Tensor, pixel_values: torch.Tensor, - aspect_ratios: torch.Tensor, num_chunks: torch.Tensor, - has_image: torch.Tensor, sampling_params) -> torch.Tensor: - - # We update the has_image cache during prefill - # and read the has_image cache during decode - if input_ids.shape[-1] > 1: # prefill - self.write_to_has_image_cache(seq_ids, has_image) - else: - has_image = self.read_from_has_image_cache(seq_ids) - bs = input_ids.shape[0] - num_chunks = torch.zeros((bs, 1)) - aspect_ratios = torch.zeros((bs, 1, 2)) - - input_block_ids = seq_ids - origin_input_block_ids = seq_ids - if self.is_reorder_needed: - # sort block ids sequentially for perf/neuron support reasons - input_block_ids, sorted_indices = torch.sort(input_block_ids) - input_ids = torch.index_select(input_ids, 0, sorted_indices) - positions = torch.index_select(positions, 0, sorted_indices) - sampling_params = torch.index_select(sampling_params, 0, - sorted_indices) - pixel_values = torch.index_select(pixel_values, 0, sorted_indices) - aspect_ratios = torch.index_select(aspect_ratios, 0, - sorted_indices) - num_chunks = torch.index_select(num_chunks, 0, sorted_indices) - has_image = torch.index_select(has_image, 0, sorted_indices) - - self.vision_mask = create_vision_mask(input_ids, self.vision_token_id) - output = self.model( - input_ids.to(torch.int32), - attention_mask=None, - position_ids=positions.to(torch.int32), - seq_ids=seq_ids.flatten().to(torch.int32), - pixel_values=pixel_values.to( - self.config.vision_config.torch_dtype), - aspect_ratios=aspect_ratios.to(torch.int32), - vision_mask=self.vision_mask.to(torch.int32), - sampling_params=sampling_params, - num_chunks=num_chunks.to(torch.int32), - has_image=has_image.to(torch.int32), - ) - if self.config.neuron_config.on_device_sampling_config: - output = output.hidden_states - else: - output = output.logits[:, -1, :] - - if self.is_reorder_needed and origin_input_block_ids.shape[0] != 1: - restored_indices = torch.argsort(sorted_indices) - output = torch.index_select(output, 0, restored_indices) - return output - - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(None, hidden_states, sampling_metadata) - return logits - - def sample(self, hidden_states, sampling_metadata): - if not self.on_device_sampling_disabled: - with torch.profiler.record_function("sample"): - hidden_states = hidden_states.flatten() - res = [] - sample_idx = 0 - for seq_group in sampling_metadata.seq_groups: - seq_ids = seq_group.seq_ids - samples = [] - for seq_id in seq_ids: - token_id = hidden_states[sample_idx].item() - samples.append( - SequenceOutput( - parent_seq_id=seq_id, - output_token=token_id, - logprobs={token_id: Logprob(token_id)})) - sample_idx += 1 - res.append( - CompletionSequenceGroupOutput(samples=samples, - prompt_logprobs=None)) - next_tokens = SamplerOutput(outputs=res) - else: - next_tokens = self.sampler(None, hidden_states, 
sampling_metadata) - return next_tokens - - def load_weights(self, model_name_or_path: str, **kwargs): - arch = _get_model_architecture(self.config) - neuronx_module_path, neuronx_model_cls_name = ( - _NEURON_SUPPORTED_MODELS[arch]) - neuronx_module = importlib.import_module(neuronx_module_path) - neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) - neuron_config = neuronx_model_cls.get_neuron_config_cls()( - **kwargs['neuron_config']) - self.config.neuron_config = neuron_config - logger.info("neuron_config buckets: %s", - self.config.neuron_config.buckets) - config = neuronx_model_cls.get_config_cls()( - neuron_config, - load_config=load_pretrained_config(model_name_or_path)) - hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'), - usedforsecurity=False).hexdigest() - if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: - compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") - elif os.path.exists(model_name_or_path): - compiled_model_path = os.path.join(model_name_or_path, - "neuron-compiled-artifacts", - hashed_config) - else: - compiled_model_path = os.path.join("local-models", - model_name_or_path, - "neuron-compiled-artifacts", - hashed_config) - try: - self.model = neuronx_model_cls(compiled_model_path) - tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) - self.vision_token_id = tokenizer( - "<|image|>", add_special_tokens=False).input_ids[0] - self.model.load(compiled_model_path) - return - except (FileNotFoundError, ValueError): - logger.warning("Failed to load the model from %s, Recompiling...", - compiled_model_path) - if not os.path.exists(model_name_or_path): - hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) - saved_path = os.path.join("local-models", model_name_or_path) - hf_model.save_pretrained(saved_path) - model_name_or_path = saved_path - self.model = neuronx_model_cls(model_name_or_path, config) - - logger.info("\nCompiling and saving model to %s", model_name_or_path) - - p = multiprocessing.Process(target=compile_model, - args=(self, compiled_model_path)) - p.start() - p.join() - - tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) - tokenizer.save_pretrained(compiled_model_path) - logger.info("Successfully compiled and saved the model in %s", - compiled_model_path) - - # Read "<|image|>" token_id from the tokenizer - self.vision_token_id = tokenizer("<|image|>", - add_special_tokens=False).input_ids[0] - logger.info("\nLoading model from compiled checkpoint...") - self.model.load(compiled_model_path) - - -def compile_model(neuron_model, traced_model_path): - neuron_model.model.compile(traced_model_path) - - -class NeuronSpeculationCausalLM(nn.Module): - """A Neuron-optimized causal language model with speculative decoding.""" - - def __init__( - self, - config: PretrainedConfig, - ) -> None: - super().__init__() - self.config = config - self.logits_processor = LogitsProcessor(config.vocab_size, - logits_as_input=True) - # Lazy initialized - self.model: nn.Module - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - input_block_ids: torch.Tensor, - sampling_params: torch.Tensor, - ) -> torch.Tensor: - # sort block ids sequentially for perf/neuron support reasons - sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids) - input_ids = torch.index_select(input_ids, 0, sorted_indices) - positions = torch.index_select(positions, 0, sorted_indices) - sampling_params = torch.index_select(sampling_params, 0, - sorted_indices) - - output = 
self.model(input_ids, - attention_mask=None, - position_ids=positions, - seq_ids=sorted_input_block_ids, - sampling_params=sampling_params) - restored_indices = torch.argsort(sorted_indices) - - # CTX encoding - if (positions[:, 0]).sum().item() == 0: - output = output.fused_outputs[0][:, 0:1] - if input_block_ids.shape[0] != 1: - output = torch.index_select(output, 0, restored_indices) - return output - - # Fused Spec (Generation) - accepted_tokens_with_padding = output.fused_outputs[0] - next_pos_ids = output.fused_outputs[-1] - generated_token_counts = next_pos_ids - positions - - assert torch.any(generated_token_counts == 0).item() is False, \ - "NxDI model generated no output for one or more sequences." - - batch_size, steps = accepted_tokens_with_padding.shape - mask = torch.arange(steps).expand(batch_size, - -1) >= generated_token_counts - accepted_tokens_with_padding[mask] = -1 - - if input_block_ids.shape[0] != 1: - accepted_tokens_with_padding = torch.index_select( - accepted_tokens_with_padding, 0, restored_indices) - - return accepted_tokens_with_padding - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[list[SamplerOutput]]: - batch_size, num_steps = logits.shape - seq_ids = [ - seq_id for sg in sampling_metadata.seq_groups - for seq_id in sg.seq_ids - ] - # Organize input tensors by step instead of by sequence. - accepted_token_ids_by_step = logits.transpose(0, 1) - accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() - - sampler_output_list = [] - for step_index in range(num_steps): - if all(token_id == -1 - for token_id in accepted_token_ids_by_step[step_index]): - break - step_output_token_ids = [] - for sequence_index in range(batch_size): - token_id = accepted_token_ids_by_step[step_index][ - sequence_index] - step_output_token_ids.append( - CompletionSequenceGroupOutput(samples=[ - SequenceOutput(parent_seq_id=seq_ids[sequence_index], - output_token=token_id, - logprobs={token_id: Logprob(token_id)}) - ], - prompt_logprobs=None)) - sampler_output_list.append( - SamplerOutput(outputs=step_output_token_ids)) - return sampler_output_list - - def load_weights(self, model_name_or_path: str, - draft_model_name_or_path: str, **kwargs): - arch = _get_model_architecture(self.config) - neuronx_module_path, neuronx_model_cls_name = ( - _NEURON_SUPPORTED_MODELS[arch]) - neuronx_module = importlib.import_module(neuronx_module_path) - neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) - neuron_config = neuronx_model_cls.get_neuron_config_cls()( - **kwargs['neuron_config']) - config = neuronx_model_cls.get_config_cls()( - neuron_config, - load_config=load_pretrained_config(model_name_or_path)) - - draft_neuron_config = copy.deepcopy(config.neuron_config) - if not config.neuron_config.enable_eagle_speculation: - draft_neuron_config.speculation_length = 0 - draft_neuron_config.trace_tokengen_model = True - draft_neuron_config.enable_fused_speculation = False - if getattr(config.neuron_config, "draft_model_modules_to_not_convert", - None): - draft_neuron_config.modules_to_not_convert = ( - draft_neuron_config.draft_model_modules_to_not_convert) - if config.neuron_config.enable_eagle_speculation: - draft_neuron_config.is_eagle_draft = True - draft_neuron_config.sequence_parallel_enabled = False - draft_config = neuronx_model_cls.get_config_cls()( - draft_neuron_config, - load_config=load_pretrained_config(draft_model_name_or_path)) - fused_spec_config = (FusedSpecNeuronConfig( - neuronx_model_cls._model_cls, 
- draft_config=draft_config, - draft_model_path=draft_model_name_or_path)) - config.fused_spec_config = fused_spec_config - self.config.neuron_config = neuron_config - - hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'), - usedforsecurity=False).hexdigest() - if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: - compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") - elif os.path.exists(model_name_or_path): - compiled_model_path = os.path.join(model_name_or_path, - "neuron-compiled-artifacts", - hashed_config) - shutil.rmtree(compiled_model_path, ignore_errors=True) - else: - compiled_model_path = os.path.join("local-models", - model_name_or_path, - "neuron-compiled-artifacts", - hashed_config) - shutil.rmtree(compiled_model_path, ignore_errors=True) - try: - self.model = neuronx_model_cls(compiled_model_path) - override_neuron_config = kwargs["override_neuron_config"] - for k, v in override_neuron_config.items(): - setattr(self.model.config.neuron_config, k, v) - self.model.load(compiled_model_path) - return - except (FileNotFoundError, ValueError) as e: - logger.warning("Exception: %s", e) - logger.warning("Failed to load the model from %s Recompiling...", - compiled_model_path) - if not os.path.exists(model_name_or_path): - hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) - saved_path = os.path.join("local-models", model_name_or_path) - hf_model.save_pretrained(saved_path) - model_name_or_path = saved_path - if not os.path.exists(draft_model_name_or_path): - if draft_model_name_or_path != model_name_or_path: - hf_model = AutoModelForCausalLM.from_pretrained( - draft_model_name_or_path) - saved_path = os.path.join("local-models", - draft_model_name_or_path) - hf_model.save_pretrained(saved_path) - draft_model_name_or_path = saved_path - else: - draft_model_name_or_path = model_name_or_path - config.fused_spec_config.draft_model_path = draft_model_name_or_path - self.model = neuronx_model_cls(model_name_or_path, config) - self.model.compile(compiled_model_path) - self.model.load(compiled_model_path) - - -def _get_model_architecture(config: PretrainedConfig) -> str: - architectures = getattr(config, "architectures", []) - for arch in architectures: - if arch in _NEURON_SUPPORTED_MODELS: - return arch - raise ValueError( - f"Model architectures {architectures} are not supported on Neuron " - f"for now. 
Supported architectures: " - f"{list(_NEURON_SUPPORTED_MODELS.keys())}") - - -def _get_default_neuron_config(model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - lora_serving_config: LoraServingConfig): - """Generate a neuron config based on vllm config args.""" - on_device_sampling_config = OnDeviceSamplingConfig(dynamic=True, - deterministic=False) - batch_size = scheduler_config.max_num_seqs - - neuron_config = dict( - tp_degree=parallel_config.tensor_parallel_size, - ctx_batch_size=1, - batch_size=batch_size, - max_context_length=scheduler_config.max_model_len, - seq_len=scheduler_config.max_model_len, - enable_bucketing=True, - is_continuous_batching=True, - quantized=False, - torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], - padding_side="right", - on_device_sampling_config=on_device_sampling_config, - sequence_parallel_enabled=True, - lora_serving_config=lora_serving_config) - return neuron_config - - -def _get_default_speculation_config(model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - speculation_config: SpeculativeConfig): - """Generate a neuron config for speculative decoding based on vllm config - args.""" - neuron_config = dict( - tp_degree=parallel_config.tensor_parallel_size, - ctx_batch_size=1, - batch_size=scheduler_config.max_num_seqs, - max_context_length=scheduler_config.max_model_len, - seq_len=scheduler_config.max_model_len, - speculation_length=speculation_config.num_speculative_tokens, - trace_tokengen_model=False, - enable_fused_speculation=True, - enable_bucketing=True, - is_continuous_batching=True, - quantized=False, - torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], - on_device_sampling_config=dict( - top_k=1, - do_sample=False, - )) - return neuron_config - - -def _get_neuron_config_after_override(default_neuron_config, - overridden_neuron_config): - """Update default neuron config values with override args""" - overridden_neuron_config = overridden_neuron_config or {} - default_neuron_config.update(overridden_neuron_config) - return default_neuron_config - - -def get_neuron_model(model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - lora_serving_config: LoraServingConfig) -> nn.Module: - """Initializes a neuron-optimized model for inference.""" - model_arch = _get_model_architecture(model_config.hf_config) - if model_arch == "MllamaForConditionalGeneration": - model = NeuronMllamaForCausalLM(model_config.hf_config) - else: - model = NeuronCausalLM(model_config.hf_config) - default_neuron_config_args = _get_default_neuron_config( - model_config, parallel_config, scheduler_config, lora_serving_config) - neuron_config = _get_neuron_config_after_override( - default_neuron_config_args, model_config.override_neuron_config) - - override_neuron_config = model_config.override_neuron_config - model.load_weights(model_config.model, - neuron_config=neuron_config, - override_neuron_config=override_neuron_config) - return model.eval() - - -def get_neuron_speculation_model(model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - speculation_config: SpeculativeConfig): - """Initializes a neuron-optimized speculation model for inference. - - This model handles speculation using both a draft model and an EAGLE draft. 
- """ - model = NeuronSpeculationCausalLM(model_config.hf_config) - default_neuron_config_args = _get_default_speculation_config( - model_config, parallel_config, scheduler_config, speculation_config) - neuron_config = _get_neuron_config_after_override( - default_neuron_config_args, model_config.override_neuron_config) - - override_neuron_config = model_config.override_neuron_config - model.load_weights(model_config.model, - speculation_config.draft_model_config.model, - neuron_config=neuron_config, - override_neuron_config=override_neuron_config) - return model.eval() diff --git a/vllm/model_executor/model_loader/runai_streamer_loader.py b/vllm/model_executor/model_loader/runai_streamer_loader.py index 83e0f386c1082..dc941401a04e0 100644 --- a/vllm/model_executor/model_loader/runai_streamer_loader.py +++ b/vllm/model_executor/model_loader/runai_streamer_loader.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # ruff: noqa: SIM117 -import glob import os from collections.abc import Generator from typing import Optional @@ -10,13 +9,14 @@ import torch from torch import nn from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from vllm.config import LoadConfig, ModelConfig +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, download_weights_from_hf, runai_safetensors_weights_iterator) -from vllm.transformers_utils.s3_utils import glob as s3_glob -from vllm.transformers_utils.utils import is_s3 +from vllm.transformers_utils.runai_utils import (is_runai_obj_uri, + list_safetensors) class RunaiModelStreamerLoader(BaseModelLoader): @@ -53,27 +53,22 @@ class RunaiModelStreamerLoader(BaseModelLoader): If the model is not local, it will be downloaded.""" - is_s3_path = is_s3(model_name_or_path) + is_object_storage_path = is_runai_obj_uri(model_name_or_path) is_local = os.path.isdir(model_name_or_path) safetensors_pattern = "*.safetensors" index_file = SAFE_WEIGHTS_INDEX_NAME - hf_folder = (model_name_or_path if - (is_local or is_s3_path) else download_weights_from_hf( + hf_folder = (model_name_or_path if (is_local or is_object_storage_path) + else download_weights_from_hf( model_name_or_path, self.load_config.download_dir, [safetensors_pattern], revision, ignore_patterns=self.load_config.ignore_patterns, )) - if is_s3_path: - hf_weights_files = s3_glob(path=hf_folder, - allow_pattern=[safetensors_pattern]) - else: - hf_weights_files = glob.glob( - os.path.join(hf_folder, safetensors_pattern)) + hf_weights_files = list_safetensors(path=hf_folder) - if not is_local and not is_s3_path: + if not is_local and not is_object_storage_path: download_safetensors_index_file_from_hf( model_name_or_path, index_file, self.load_config.download_dir, revision) diff --git a/vllm/model_executor/model_loader/sharded_state_loader.py b/vllm/model_executor/model_loader/sharded_state_loader.py index 3edd4ec4007e8..a85ca065d1d27 100644 --- a/vllm/model_executor/model_loader/sharded_state_loader.py +++ b/vllm/model_executor/model_loader/sharded_state_loader.py @@ -10,7 +10,8 @@ from typing import Any, Optional import torch from torch import nn -from vllm.config import LoadConfig, ModelConfig +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader 
import BaseModelLoader from vllm.model_executor.model_loader.weight_utils import ( diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 3d491be3156b6..58296131fadb9 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -171,51 +171,52 @@ class TensorizerConfig(MutableMapping): _is_sharded: bool = field(init=False, default=False) _fields: ClassVar[tuple[str, ...]] _keys: ClassVar[frozenset[str]] - """ - Args for the TensorizerConfig class. These are used to configure the - behavior of model serialization and deserialization using Tensorizer. + """Configuration class for Tensorizer settings. - Args: - tensorizer_uri: Path to serialized model tensors. Can be a local file - path or a S3 URI. This is a required field unless lora_dir is - provided and the config is meant to be used for the - `tensorize_lora_adapter` function. Unless a `tensorizer_dir` or - `lora_dir` is passed to this object's initializer, this is a required - argument. - tensorizer_dir: Path to a directory containing serialized model tensors, - and all other potential model artifacts to load the model, such as - configs and tokenizer files. Can be passed instead of `tensorizer_uri` - where the `model.tensors` file will be assumed to be in this - directory. - vllm_tensorized: If True, indicates that the serialized model is a - vLLM model. This is used to determine the behavior of the - TensorDeserializer when loading tensors from a serialized model. - It is far faster to deserialize a vLLM model as it utilizes - tensorizer's optimized GPU loading. Note that this is now - deprecated, as serialized vLLM models are now automatically - inferred as vLLM models. - verify_hash: If True, the hashes of each tensor will be verified against - the hashes stored in the metadata. A `HashMismatchError` will be - raised if any of the hashes do not match. - num_readers: Controls how many threads are allowed to read concurrently - from the source file. Default is `None`, which will dynamically set - the number of readers based on the number of available - resources and model size. This greatly increases performance. - encryption_keyfile: File path to a binary file containing a - binary key to use for decryption. `None` (the default) means - no decryption. See the example script in - examples/others/tensorize_vllm_model.py. - s3_access_key_id: The access key for the S3 bucket. Can also be set via - the S3_ACCESS_KEY_ID environment variable. - s3_secret_access_key: The secret access key for the S3 bucket. Can also - be set via the S3_SECRET_ACCESS_KEY environment variable. - s3_endpoint: The endpoint for the S3 bucket. Can also be set via the - S3_ENDPOINT_URL environment variable. - lora_dir: Path to a directory containing LoRA adapter artifacts for - serialization or deserialization. When serializing LoRA adapters - this is the only necessary parameter to pass to this object's - initializer. - """ + These settings configure the behavior of model serialization and + deserialization using Tensorizer. + + Attributes: + tensorizer_uri: Path to serialized model tensors. Can be a local file + path or a S3 URI. This is a required field unless lora_dir is + provided and the config is meant to be used for the + `tensorize_lora_adapter` function. Unless a `tensorizer_dir` or + `lora_dir` is passed to this object's initializer, this is + a required argument. 
+ tensorizer_dir: Path to a directory containing serialized model tensors, + and all other potential model artifacts to load the model, such as + configs and tokenizer files. Can be passed instead of + `tensorizer_uri` where the `model.tensors` file will be assumed + to be in this directory. + vllm_tensorized: If True, indicates that the serialized model is a + vLLM model. This is used to determine the behavior of the + TensorDeserializer when loading tensors from a serialized model. + It is far faster to deserialize a vLLM model as it utilizes + tensorizer's optimized GPU loading. Note that this is now + deprecated, as serialized vLLM models are now automatically + inferred as vLLM models. + verify_hash: If True, the hashes of each tensor will be verified + against the hashes stored in the metadata. A `HashMismatchError` + will be raised if any of the hashes do not match. + num_readers: Controls how many threads are allowed to read concurrently + from the source file. Default is `None`, which will dynamically set + the number of readers based on the number of available + resources and model size. This greatly increases performance. + encryption_keyfile: File path to a binary file containing a + binary key to use for decryption. `None` (the default) means + no decryption. See the example script in + examples/others/tensorize_vllm_model.py. + s3_access_key_id: The access key for the S3 bucket. Can also be set via + the S3_ACCESS_KEY_ID environment variable. + s3_secret_access_key: The secret access key for the S3 bucket. Can also + be set via the S3_SECRET_ACCESS_KEY environment variable. + s3_endpoint: The endpoint for the S3 bucket. Can also be set via the + S3_ENDPOINT_URL environment variable. + lora_dir: Path to a directory containing LoRA adapter artifacts for + serialization or deserialization. When serializing LoRA adapters + this is the only necessary parameter to pass to this object's + initializer. 
+ """ def __post_init__(self): # check if the configuration is for a sharded vLLM model diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py index fa01758ab4cee..65ea49c642944 100644 --- a/vllm/model_executor/model_loader/tensorizer_loader.py +++ b/vllm/model_executor/model_loader/tensorizer_loader.py @@ -8,7 +8,8 @@ from typing import Union import torch from torch import nn -from vllm.config import LoadConfig, ModelConfig, ParallelConfig, VllmConfig +from vllm.config import ModelConfig, ParallelConfig, VllmConfig +from vllm.config.load import LoadConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.tensorizer import ( diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index f57ebdb1abcbc..0c2441a6db44d 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -19,10 +19,11 @@ from vllm.logger import init_logger from vllm.model_executor.layers.linear import QKVCrossParallelLinear from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.models.adapters import (as_embedding_model, - as_reward_model, - as_seq_cls_model) -from vllm.model_executor.models.interfaces import SupportsQuant +from vllm.model_executor.models.adapters import ( + as_embedding_model, as_reward_model, as_seq_cls_model, + try_create_mm_pooling_model_cls) +from vllm.model_executor.models.interfaces import (SupportsQuant, + supports_multimodal) from vllm.utils import is_pin_memory_available logger = init_logger(__name__) @@ -169,22 +170,6 @@ def get_model_architecture( model_config: ModelConfig) -> tuple[type[nn.Module], str]: architectures = getattr(model_config.hf_config, "architectures", []) - # Special handling for quantized Mixtral. - # FIXME(woosuk): This is a temporary hack. 
- mixtral_supported = [ - "fp8", - "compressed-tensors", - "gptq_marlin", - "awq_marlin", - "quark", - "bitsandbytes", - ] - - if (model_config.quantization is not None - and model_config.quantization not in mixtral_supported - and "MixtralForCausalLM" in architectures): - architectures = ["QuantMixtralForCausalLM"] - model_cls, arch = model_config.registry.resolve_model_cls( architectures, model_config=model_config, @@ -199,6 +184,15 @@ def get_model_architecture( "performance may not be optimal.", arch) convert_type = model_config.convert_type + if convert_type != "none" and supports_multimodal(model_cls): + logger.debug_once("Detected conversion of Multi Modal model.") + converted = try_create_mm_pooling_model_cls(model_cls) + if converted is not None: + logger.debug_once("Creating wrapper class to forward pooler.") + return converted, arch + else: + logger.debug_once("Attempting direct conversion.") + if convert_type == "none": pass elif convert_type == "embed": diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index f87eeaa4563ff..f2c66763d0816 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Utilities for downloading and initializing model weights.""" +import concurrent.futures import fnmatch import glob import hashlib @@ -18,10 +19,12 @@ import huggingface_hub.constants import numpy as np import torch from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download -from safetensors.torch import load_file, safe_open, save_file +from safetensors.torch import load, load_file, safe_open, save_file from tqdm.auto import tqdm -from vllm.config import LoadConfig, ModelConfig +from vllm import envs +from vllm.config import ModelConfig +from vllm.config.load import LoadConfig from vllm.distributed import get_tensor_model_parallel_rank from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (QuantizationConfig, @@ -95,6 +98,41 @@ def get_lock(model_name_or_path: Union[str, Path], return lock +def maybe_download_from_modelscope( + model: str, + revision: Optional[str] = None, + download_dir: Optional[str] = None, + ignore_patterns: Optional[Union[str, list[str]]] = None, + allow_patterns: Optional[Union[list[str], + str]] = None) -> Optional[str]: + """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. + + Returns the path to the downloaded model, or None if the model is not + downloaded from ModelScope.""" + if envs.VLLM_USE_MODELSCOPE: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. + from modelscope.hub.snapshot_download import snapshot_download + + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. 
+ with get_lock(model, download_dir): + if not os.path.exists(model): + model_path = snapshot_download( + model_id=model, + cache_dir=download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + revision=revision, + ignore_file_pattern=ignore_patterns, + allow_patterns=allow_patterns, + ) + else: + model_path = model + return model_path + return None + + def _shared_pointers(tensors): ptrs = defaultdict(list) for k, v in tensors.items(): @@ -169,7 +207,13 @@ def get_quant_config(model_config: ModelConfig, # Inflight BNB quantization if model_config.quantization == "bitsandbytes": return quant_cls.from_config({}) - is_local = os.path.isdir(model_config.model) + model_name_or_path = maybe_download_from_modelscope( + model_config.model, + revision=model_config.revision, + download_dir=load_config.download_dir, + allow_patterns=["*.json"], + ) or model_config.model + is_local = os.path.isdir(model_name_or_path) if not is_local: # Download the config files. with get_lock(model_config.model, load_config.download_dir): @@ -182,7 +226,7 @@ def get_quant_config(model_config: ModelConfig, tqdm_class=DisabledTqdm, ) else: - hf_folder = model_config.model + hf_folder = model_name_or_path possible_config_filenames = quant_cls.get_config_filenames() @@ -475,18 +519,58 @@ def np_cache_weights_iterator( def safetensors_weights_iterator( hf_weights_files: list[str], use_tqdm_on_load: bool, + safetensors_load_strategy: str = "lazy", ) -> Generator[tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model safetensor files.""" + loading_desc = "Loading safetensors checkpoint shards" + if safetensors_load_strategy == "eager": + loading_desc += " (eager)" + for st_file in tqdm( hf_weights_files, - desc="Loading safetensors checkpoint shards", + desc=loading_desc, disable=not enable_tqdm(use_tqdm_on_load), bar_format=_BAR_FORMAT, ): - with safe_open(st_file, framework="pt") as f: - for name in f.keys(): # noqa: SIM118 - param = f.get_tensor(name) - yield name, param + if safetensors_load_strategy == "eager": + with open(st_file, "rb") as f: + state_dict = load(f.read()) + yield from state_dict.items() + else: + with safe_open(st_file, framework="pt") as f: + for name in f.keys(): # noqa: SIM118 + param = f.get_tensor(name) + yield name, param + + +def multi_thread_safetensors_weights_iterator( + hf_weights_files: list[str], + use_tqdm_on_load: bool, + max_workers: int = 4, +) -> Generator[tuple[str, torch.Tensor], None, None]: + """Multi-Thread iterate over the weights in the model safetensor files.""" + + def _load_file(st_file: str): + result = load_file(st_file, device="cpu") + return result + + with concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers) as executor: + futures = [ + executor.submit(_load_file, st_file) + for st_file in hf_weights_files + ] + futures_iter = tqdm( + concurrent.futures.as_completed(futures), + total=len(hf_weights_files), + desc="Multi-thread loading shards", + disable=not enable_tqdm(use_tqdm_on_load), + bar_format=_BAR_FORMAT, + ) + + for future in futures_iter: + state_dict = future.result() + yield from state_dict.items() def runai_safetensors_weights_iterator( @@ -569,6 +653,39 @@ def pt_weights_iterator( del state +def multi_thread_pt_weights_iterator( + hf_weights_files: list[str], + use_tqdm_on_load: bool, + pt_load_map_location: Union[str, dict[str, str]] = "cpu", + max_workers: int = 4, +) -> Generator[tuple[str, torch.Tensor], None, None]: + """Multi-Thread iterate over the weights in the model bin/pt files.""" + + def 
_load_file(bin_file: str): + return torch.load(bin_file, + map_location=pt_load_map_location, + weights_only=True) + + with concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers) as executor: + futures = [ + executor.submit(_load_file, bin_file) + for bin_file in hf_weights_files + ] + futures_iter = tqdm( + concurrent.futures.as_completed(futures), + total=len(hf_weights_files), + desc="Multi-thread loading pt checkpoint shards", + disable=not enable_tqdm(use_tqdm_on_load), + bar_format=_BAR_FORMAT, + ) + + for future in futures_iter: + state = future.result() + yield from state.items() + del state + + def get_gguf_extra_tensor_names( gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> list[str]: reader = gguf.GGUFReader(gguf_file) diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 50c2cd97f3d09..c4328a176a5de 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -1,12 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import ast +import inspect from collections.abc import Iterable from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast import torch import torch.nn as nn +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.models.config import VerifyAndUpdateConfig @@ -49,26 +52,28 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]: if not dense_modules: return None - module = dense_modules[0] - folder = module.get("path", "") + layers = [] + for module in dense_modules: + folder = module.get("path", "") - config_path = f"{folder}/config.json" if folder else "config.json" - layer_config = get_hf_file_to_dict(config_path, model_config.model, - model_config.revision) - if not layer_config: - return None + config_path = f"{folder}/config.json" if folder else "config.json" + layer_config = get_hf_file_to_dict(config_path, model_config.model, + model_config.revision) + if not layer_config: + continue - linear = nn.Linear(layer_config.get("in_features", 768), - layer_config.get("out_features", 768), - bias=layer_config.get("bias", True), - dtype=torch.float32) + linear = nn.Linear(layer_config.get("in_features", 768), + layer_config.get("out_features", 768), + bias=layer_config.get("bias", True), + dtype=model_config.head_dtype) - if _load_dense_weights(linear, folder, model_config): - layers = [linear] + if not _load_dense_weights(linear, folder, model_config): + continue + + layers.append(linear) if act_name := layer_config.get("activation_function"): layers.append(get_act_fn(act_name)) - return nn.Sequential(*layers).to(dtype=torch.float32) - + return nn.Sequential(*layers).to(dtype=model_config.head_dtype) except Exception: logger.exception("ST projector loading failed") @@ -103,15 +108,13 @@ def _load_dense_weights(linear: nn.Linear, folder: str, if weight_key in state_dict: weight_loader = getattr(linear.weight, "weight_loader", default_weight_loader) - weight_loader(linear.weight, - state_dict[weight_key].to(torch.float32)) + weight_loader(linear.weight, state_dict[weight_key]) bias_key = weight_key.replace("weight", "bias") if linear.bias is not None and bias_key in state_dict: bias_loader = getattr(linear.bias, "weight_loader", default_weight_loader) - bias_loader(linear.bias, - state_dict[bias_key].to(torch.float32)) + bias_loader(linear.bias, state_dict[bias_key]) return True 
except Exception: logger.exception("Failed to load %s", filename) @@ -129,6 +132,41 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str: return model_name + pooling_suffix +def try_create_mm_pooling_model_cls(orig_cls: _T) -> _T: + + class CallVisitor(ast.NodeVisitor): + + def __init__(self): + self.calls = [] + + def visit_Call(self, node): + if isinstance(node.func, ast.Name): + self.calls.append(node.func.id) + self.generic_visit(node) + + visitor = CallVisitor() + visitor.visit(ast.parse(inspect.getsource(orig_cls))) + if "init_vllm_registered_model" not in visitor.calls: + return None + + class ModelForPooling(orig_cls, VllmModelForPooling): + + is_pooling_model = True + + def __init__( + self, + *, + vllm_config: "VllmConfig", + prefix: str = "", + **kwargs: Any, + ) -> None: + super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) + + self.pooler = self.get_language_model().pooler + + return ModelForPooling # type: ignore + + def _create_pooling_model_cls(orig_cls: _T) -> _T: # Lazy import from .utils import AutoWeightsLoader, WeightsMapper @@ -255,7 +293,7 @@ def as_seq_cls_model(cls: _T) -> _T: from vllm.model_executor.models.interfaces import SupportsCrossEncoding from vllm.sequence import IntermediateTensors - from .utils import maybe_prefix + from .utils import get_model_hidden_size, maybe_prefix class ModelForSequenceClassification(_create_pooling_model_cls(cls), SupportsCrossEncoding): @@ -263,9 +301,10 @@ def as_seq_cls_model(cls: _T) -> _T: def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config + hidden_size = get_model_hidden_size(config) self.score = ReplicatedLinear( - config.hidden_size, + hidden_size, config.num_labels, bias=False, params_dtype=torch.float32, @@ -398,6 +437,7 @@ def load_weights_using_from_2_way_softmax( from vllm.model_executor.models.utils import AutoWeightsLoader model_config = model.vllm_config.model_config + tokens = getattr(model.config, "classifier_from_token", []) tokens = cast(list[int], tokens) assert len(tokens) == 2 @@ -405,9 +445,10 @@ def load_weights_using_from_2_way_softmax( if model.config.tie_word_embeddings: model.lm_head = model.model.embed_tokens else: + quant_config = model.vllm_config.quant_config model.lm_head = ParallelLMHead(model.config.vocab_size, model.config.hidden_size, - quant_config=model.quant_config) + quant_config=quant_config) loader = AutoWeightsLoader(model) loaded_weights = loader.load_weights(weights) @@ -451,9 +492,10 @@ def load_weights_no_post_processing(model, if model.config.tie_word_embeddings: model.lm_head = model.model.embed_tokens else: + quant_config = model.vllm_config.quant_config model.lm_head = ParallelLMHead(model.config.vocab_size, model.config.hidden_size, - quant_config=model.quant_config) + quant_config=quant_config) loader = AutoWeightsLoader(model) loaded_weights = loader.load_weights(weights) diff --git a/vllm/model_executor/models/apertus.py b/vllm/model_executor/models/apertus.py index 0de683d2cd060..f6400b05e110a 100644 --- a/vllm/model_executor/models/apertus.py +++ b/vllm/model_executor/models/apertus.py @@ -415,6 +415,12 @@ class ApertusModel(nn.Module): (".qkv_proj", ".v_proj", "v"), ] params_dict = dict(self.named_parameters()) + + # we need to load the buffers for beta and eps (XIELU) + for name, buffer in self.named_buffers(): + if name.endswith(".beta") or name.endswith(".eps"): + params_dict[name] = buffer + loaded_params: set[str] = set() 
for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: diff --git a/vllm/model_executor/models/arcee.py b/vllm/model_executor/models/arcee.py index 13ed4da0602ad..be82c2fd59644 100644 --- a/vllm/model_executor/models/arcee.py +++ b/vllm/model_executor/models/arcee.py @@ -342,7 +342,7 @@ class ArceeModel(nn.Module): class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): """Arcee Model for causal language modeling, integrated with vLLM runtime.""" - # Map fused module names to their sub-module components + # Map fused module names to their submodule components # (for quantization and LoRA) packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index c566611266af7..b6dd559968415 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -427,6 +427,7 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant): self.vocab_size, config.hidden_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 1c7960fa3e0a5..a7cb6b35a4ab4 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -143,16 +143,8 @@ class AriaProjector(nn.Module): projects ViT's outputs into MoE's inputs. Args: - patch_to_query_dict (dict): Maps patch numbers to their corresponding - query numbers, - e.g., {1225: 128, 4900: 256}. This allows for different query sizes - based on image resolution. - embed_dim (int): Embedding dimension. - num_heads (int): Number of attention heads. - kv_dim (int): Dimension of key and value. - ff_dim (int): Hidden dimension of the feed-forward network. - output_dim (int): Output dimension. - norm_layer (nn.Module): Normalization layer. Default is nn.LayerNorm. + config: [AriaConfig](https://huggingface.co/docs/transformers/main/model_doc/aria#transformers.AriaConfig) + containing projector configuration parameters. Outputs: A tensor with the shape of (batch_size, query_number, output_dim) @@ -282,8 +274,8 @@ class AriaTextMoELayer(nn.Module): Forward pass of the MoE Layer. Args: - hidden_states (torch.Tensor): Input tensor of shape (batch_size, - sequence_length, hidden_size). + hidden_states: Input tensor of shape + (batch_size, sequence_length, hidden_size). Returns: torch.Tensor: Output tensor after passing through the MoE layer. 
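A minimal, self-contained sketch of the AST-based detection used by try_create_mm_pooling_model_cls in the adapters.py hunk above: the class source is parsed and direct call names are collected to decide whether the model calls init_vllm_registered_model. The helper name `calls_helper` and the `Demo` class are illustrative only, not part of vLLM.

import ast
import inspect


def calls_helper(cls: type, helper_name: str) -> bool:
    """Return True if the source of `cls` contains a call to `helper_name`."""
    calls: list[str] = []

    class CallVisitor(ast.NodeVisitor):

        def visit_Call(self, node: ast.Call) -> None:
            # Only direct-name calls (e.g. init_vllm_registered_model(...))
            # are recorded; attribute calls such as self.foo() are skipped.
            if isinstance(node.func, ast.Name):
                calls.append(node.func.id)
            self.generic_visit(node)

    CallVisitor().visit(ast.parse(inspect.getsource(cls)))
    return helper_name in calls


class Demo:

    def build(self):
        # Never executed; the check only needs the call to appear in source.
        return init_vllm_registered_model()  # noqa: F821


assert calls_helper(Demo, "init_vllm_registered_model")

When the check fails, try_create_mm_pooling_model_cls returns None, so get_model_architecture (in the utils.py hunk earlier) falls back to the existing direct-conversion path.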
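A similarly reduced sketch of the apertus.py change above, which merges selected named buffers (the xIELU beta/eps state mentioned in the diff's comment) into the parameter dict so the ordinary name-keyed loading loop can populate them. `TinyBlock` and `load_checkpoint` are invented for this illustration.

import torch
import torch.nn as nn


class TinyBlock(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # Buffers are state, not parameters, so named_parameters() skips them.
        self.register_buffer("beta", torch.zeros(4))
        self.register_buffer("eps", torch.tensor(1e-6))


def load_checkpoint(module: nn.Module, weights) -> set[str]:
    targets = dict(module.named_parameters())
    # Expose the buffers under their names as well, mirroring
    # `params_dict[name] = buffer` in the diff above.
    for name, buf in module.named_buffers():
        if name.endswith("beta") or name.endswith("eps"):
            targets[name] = buf
    loaded: set[str] = set()
    for name, tensor in weights:
        if name in targets:
            targets[name].data.copy_(tensor)
            loaded.add(name)
    return loaded


block = TinyBlock()
ckpt = [("linear.weight", torch.ones(4, 4)), ("beta", torch.full((4, ), 0.5))]
print(load_checkpoint(block, ckpt))  # {'linear.weight', 'beta'}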
@@ -547,6 +539,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): config.text_config.hidden_size, org_num_embeddings=self.language_model.org_vocab_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 4563c356666ac..ae25033410407 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -51,7 +51,8 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: @@ -394,7 +395,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, position_embedding=position_embedding) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head")) self.lm_head.weight.weight_loader = self.lm_head_weight_loader if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py index a42640cef9d44..5f6025abf315c 100644 --- a/vllm/model_executor/models/bailing_moe.py +++ b/vllm/model_executor/models/bailing_moe.py @@ -43,7 +43,6 @@ from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, - ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( @@ -68,6 +67,7 @@ class BailingAttention(nn.Module): config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, prefix: str = "", ): super().__init__() @@ -84,10 +84,11 @@ class BailingAttention(nn.Module): self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads) self.q_size_per_rank = self.head_dim * self.num_heads - self.num_kv_heads = self.total_kv_heads // tp_size self.kv_size_per_rank = self.num_kv_heads * self.head_dim self.scale = self.head_dim**-0.5 + self.use_qk_norm = getattr(config, "use_qk_norm", False) + self.use_rmsnorm = getattr(config, "use_rmsnorm", False) self.query_key_value = QKVParallelLinear( self.hidden_size, @@ -99,28 +100,45 @@ class BailingAttention(nn.Module): prefix=f"{prefix}.query_key_value", ) + if self.use_qk_norm: + self.query_layernorm = (RMSNorm( + self.head_dim, eps=config.rms_norm_eps) if self.use_rmsnorm + else nn.LayerNorm(self.head_dim, eps=1e-6)) + self.key_layernorm = (RMSNorm( + self.head_dim, eps=config.rms_norm_eps) if self.use_rmsnorm + else nn.LayerNorm(self.head_dim, eps=1e-6)) + self.dense = RowParallelLinear( self.total_num_heads * self.head_dim, self.hidden_size, bias=config.use_bias, quant_config=quant_config, + reduce_results=reduce_results, prefix=f"{prefix}.dense", ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scale, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - 
prefix=f"{prefix}.attn") + self.partial_rotary_factor = getattr(config, "partial_rotary_factor", + 1.0) + + self.rotary_dim = getattr(config, "rotary_dim", self.head_dim) self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, + rotary_dim=self.rotary_dim, max_position=config.max_position_embeddings, base=config.rope_theta, is_neox_style=True, rope_scaling=config.rope_scaling, + partial_rotary_factor=self.partial_rotary_factor, + ) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scale, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn", ) def forward( @@ -135,6 +153,14 @@ class BailingAttention(nn.Module): ], dim=-1) + if self.use_qk_norm: + q = q.view(-1, self.num_heads, self.head_dim) + k = k.view(-1, self.num_kv_heads, self.head_dim) + q = self.query_layernorm(q) + k = self.key_layernorm(k) + q = q.view(-1, self.q_size_per_rank) + k = k.view(-1, self.kv_size_per_rank) + q, k = self.rotary_emb(position_ids, q, k) context_layer = self.attn(q, k, v) @@ -198,24 +224,72 @@ class BailingMoE(nn.Module): self.hidden_size = config.hidden_size self.quant_config = quant_config self.num_shared_experts = config.num_shared_experts - # Gate always runs at half / full precision for now. - self.gate = ReplicatedLinear(self.hidden_size, - self.num_experts, - bias=False, - quant_config=None) + self.score_function = getattr(config, "score_function", None) + self.n_group = getattr(config, "n_group", None) + self.topk_group = getattr(config, "topk_group", None) + self.use_grouped_topk = (self.n_group is not None + and self.topk_group is not None) + self.routed_scaling_factor = getattr(config, "routed_scaling_factor", + 1.0) - self.experts = FusedMoE(num_experts=self.num_experts, - top_k=self.top_k, - hidden_size=self.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=self.norm_expert_prob, - quant_config=quant_config, - prefix=f"{prefix}.experts") + router_dtype = getattr(config, "router_dtype", None) + if router_dtype is None: + self.router_dtype = None + elif router_dtype == "fp32": + self.router_dtype = torch.float32 + else: + self.router_dtype = torch.bfloat16 + + self.gate = nn.Linear( + self.hidden_size, + self.num_experts, + bias=False, + dtype=self.router_dtype, + ) + + if getattr(config, "moe_router_enable_expert_bias", False): + self.gate.expert_bias = nn.Parameter( + torch.empty((config.num_experts, ), dtype=torch.float32)) + else: + self.gate.expert_bias = None + + self.correction_bias = (self.gate.expert_bias.data + if self.gate.expert_bias is not None else None) + + if self.score_function is not None: + assert ( + self.score_function == "softmax" + and self.correction_bias is None + ) or ( + self.score_function == "sigmoid" + and self.correction_bias is not None + ), "score_function and correction_bias should be in 2 combination (softmax, None) or (sigmoid, not None)" # noqa: E501 + else: + # default value for scoring_func + self.score_function = "softmax" + + self.experts = FusedMoE( + num_experts=self.num_experts, + top_k=self.top_k, + hidden_size=self.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=self.norm_expert_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + scoring_func=self.score_function, + e_score_correction_bias=self.gate.expert_bias, + num_expert_group=self.n_group, + topk_group=self.topk_group, + use_grouped_topk=self.use_grouped_topk, + ) if self.num_shared_experts > 0: - 
intermediate_size = (config.moe_intermediate_size * - self.num_shared_experts) + if hasattr(config, "moe_shared_expert_intermediate_size"): + intermediate_size = config.moe_shared_expert_intermediate_size + else: + intermediate_size = config.moe_intermediate_size + intermediate_size *= config.num_shared_experts self.shared_experts = BailingMLP( intermediate_size=intermediate_size, config=config, @@ -228,14 +302,18 @@ class BailingMoE(nn.Module): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_size = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_size) - if self.num_shared_experts > 0: + if self.shared_experts: shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) + router_logits = self.gate(hidden_states.to(self.router_dtype)) + router_logits = router_logits.to(hidden_states.dtype) + final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits) - if self.num_shared_experts > 0: + final_hidden_states *= self.routed_scaling_factor + + if self.shared_experts: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: @@ -254,20 +332,30 @@ class BailingMoeBlock(nn.Module): prefix: str = "", ): super().__init__() + layer_idx = int(prefix.split('.')[-1]) + self.config = config hidden_size = config.hidden_size intermediate_size = config.intermediate_size + self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps) self.attention = BailingAttention(config, cache_config, quant_config, prefix=f"{prefix}.attention") + self.post_attention_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps) - self.mlp = BailingMoE(intermediate_size, - config, - quant_config, - True, - prefix=f"{prefix}.mlp") + + # Choose MLP class based on the number of experts and layer index + if layer_idx < config.first_k_dense_replace: + mlp_class = BailingMLP + else: + mlp_class = BailingMoE + self.mlp = mlp_class(intermediate_size, + config, + quant_config, + True, + prefix=f"{prefix}.mlp") def forward( self, @@ -310,11 +398,17 @@ class BailingMoeModel(nn.Module): self.config = config self.vocab_size = config.vocab_size self.embed_dim = config.hidden_size + self.tie_word_embeddings = getattr(config, "tie_word_embeddings", + False) - if get_pp_group().is_first_rank or (config.tie_word_embeddings + if get_pp_group().is_first_rank or (self.tie_word_embeddings and get_pp_group().is_last_rank): self.word_embeddings = VocabParallelEmbedding( - self.vocab_size, self.embed_dim) + self.vocab_size, + self.embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.word_embeddings", + ) else: self.word_embeddings = PPMissingLayer() @@ -372,8 +466,11 @@ class BailingMoeModel(nn.Module): "hidden_states": hidden_states, "residual": residual }) - - hidden_states, _ = self.norm(hidden_states, residual) + else: + if residual is None: + hidden_states = self.norm(hidden_states) + else: + hidden_states, _ = self.norm(hidden_states, residual) return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: @@ -396,7 +493,8 @@ class BailingMoeModel(nn.Module): loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: - if self.config.norm_head and "lm_head.weight" in name: + if (hasattr(self.config, "norm_head") and self.config.norm_head + and "lm_head.weight" in name): loaded_weight = F.normalize(loaded_weight, dim=0, p=2, @@ -430,13 +528,17 @@ class 
BailingMoeModel(nn.Module): if is_pp_missing_parameter(name, self): continue + if name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: if name.endswith(".bias") and name not in params_dict: @@ -473,19 +575,30 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ) -> None: super().__init__() - config = vllm_config.model_config.hf_config + config = vllm_config.model_config.hf_config.get_text_config() + vllm_config.model_config.hf_config = config quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config self.config = config + self.lora_config = lora_config self.quant_config = quant_config self.max_position_embeddings = config.max_position_embeddings self.model = BailingMoeModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + self.tie_word_embeddings = getattr(config, "tie_word_embeddings", + False) + if get_pp_group().is_last_rank: - self.lm_head = (self.word_embeddings if config.tie_word_embeddings - else ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config)) + if self.tie_word_embeddings: + self.lm_head = self.model.word_embeddings + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.lm_head", + ) self.logits_processor = LogitsProcessor(config.vocab_size) else: self.lm_head = PPMissingLayer() @@ -520,10 +633,13 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.tie_word_embeddings else None), ) return loader.load_weights(weights) def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() + + +class BailingMoeV2ForCausalLM(BailingMoeForCausalLM): + pass diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index a72bbdebe5317..397089f31cdf6 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -514,6 +514,7 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, # We need bigger padding if using lora for kernel # compatibility if not lora_config else lora_config.lora_vocab_padding_size, + prefix=maybe_prefix(prefix, "lm_head"), ) # Used to track and store by the Mamba cache between steps. self.mamba_cache: Optional[MambaCacheManager] = None diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py deleted file mode 100644 index 32551d8102f32..0000000000000 --- a/vllm/model_executor/models/bart.py +++ /dev/null @@ -1,1342 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Derived from BART implementation posted on HuggingFace; license below: -# -# coding=utf-8 -# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch BART model.""" -import math -from collections.abc import Iterable -from typing import Optional - -import torch -from torch import nn -from transformers import BartConfig -from transformers.utils import logging - -from vllm.attention import Attention, AttentionType -from vllm.config import CacheConfig, LoRAConfig, VllmConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVCrossParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors - -from .interfaces import SupportsQuant, SupportsV0Only -from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, - maybe_prefix) - -logger = logging.get_logger(__name__) - - -def get_bsz_seq_len(input_ids): - shp = input_ids.shape - ndim = len(shp) - if ndim == 1: - return 1, input_ids.numel() - else: - return shp[:2] - - -class BartLearnedPositionalEmbedding(VocabParallelEmbedding): - """ - This module learns positional embeddings up to a fixed maximum size. - """ - - def __init__(self, num_embeddings: int, embedding_dim: int): - # Bart is set up so that if padding_idx is - # specified then offset the embedding ids by 2 - # and adjust num_embeddings appropriately. - # Other models don't have this hack - self.offset = 2 - super().__init__(num_embeddings + self.offset, embedding_dim) - - def forward( - self, - positions: torch.Tensor, - ) -> torch.Tensor: - """`input_ids' shape is expected to be [bsz x seqlen].""" - return super().forward(positions + self.offset) - - -class BartScaledWordEmbedding(VocabParallelEmbedding): - """ - This module overrides VocabParallelEmbedding's - forward by multiplying with embeddings scale. 
- """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - embed_scale: float = 1.0): - super().__init__(num_embeddings, embedding_dim) - self.embed_scale = embed_scale - - def forward(self, input_ids: torch.Tensor) -> torch.Tensor: - return super().forward(input_ids) * self.embed_scale - - -class BartParallelLMHead(ParallelLMHead): - """ - This module overrides ParallelLMHead's - forward by dividing by embeddings scale, - yielding effectively the inverse of - BartScaledWordEmbedding - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - embed_scale: float = 1.0): - super().__init__(num_embeddings, embedding_dim) - self.embed_scale = embed_scale - - def forward(self, input_ids: torch.Tensor) -> torch.Tensor: - return super().forward(input_ids) / self.embed_scale - - -class BartEncoderAttention(nn.Module): - - def __init__( - self, - embed_dim: int, - num_heads: int, - bias: bool = True, - config: Optional[BartConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.d_model = config.d_model - self.embed_dim = embed_dim - self.total_num_heads = num_heads - self.total_num_kv_heads = self.total_num_heads - self.head_dim = embed_dim // num_heads - self.config = config - - if (self.head_dim * num_heads) != self.embed_dim: - raise ValueError(f"embed_dim must be divisible by num_heads " - f"(got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads}).") - self.scaling = self.head_dim**-0.5 - - self.qkv_proj = QKVParallelLinear( - self.d_model, - self.d_model // self.total_num_heads, - self.total_num_heads, - self.total_num_kv_heads, - bias=bias, - quant_config=quant_config, - ) - - self.out_proj = RowParallelLinear( - embed_dim, - embed_dim, - bias=bias, - quant_config=quant_config, - ) - - tp_world_size = get_tensor_model_parallel_world_size() - assert self.total_num_heads % tp_world_size == 0 - self.num_heads = self.total_num_heads // tp_world_size - - if self.total_num_kv_heads >= tp_world_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_world_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. 
- assert tp_world_size % self.total_num_kv_heads == 0 - self.num_kv_heads = self.num_heads - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - attn_type=AttentionType.ENCODER) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """Input shape: Batch x Time x Channel""" - - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - - attn_output = self.attn(q, k, v) - - output, _ = self.out_proj(attn_output) - return output - - -class BartDecoderSelfAttention(nn.Module): - - def __init__( - self, - embed_dim: int, - num_heads: int, - bias: bool = True, - config: Optional[BartConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.d_model = config.d_model - self.embed_dim = embed_dim - self.total_num_heads = num_heads - self.total_num_kv_heads = self.total_num_heads - self.head_dim = embed_dim // num_heads - self.config = config - - if (self.head_dim * num_heads) != self.embed_dim: - raise ValueError(f"embed_dim must be divisible by num_heads " - f"(got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads}).") - self.scaling = self.head_dim**-0.5 - - self.qkv_proj = QKVParallelLinear( - self.d_model, - self.d_model // self.total_num_heads, - self.total_num_heads, - self.total_num_kv_heads, - bias=bias, - quant_config=quant_config, - ) - - self.out_proj = RowParallelLinear( - embed_dim, - embed_dim, - bias=bias, - quant_config=quant_config, - ) - - tp_world_size = get_tensor_model_parallel_world_size() - assert self.total_num_heads % tp_world_size == 0 - self.num_heads = self.total_num_heads // tp_world_size - - if self.total_num_kv_heads >= tp_world_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_world_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. 
- assert tp_world_size % self.total_num_kv_heads == 0 - self.num_kv_heads = self.num_heads - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - attn_type=AttentionType.DECODER) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """Input shape: Batch x Time x Channel""" - - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - - attn_output = self.attn(q, k, v) - - output, _ = self.out_proj(attn_output) - return output - - -class BartCrossAttention(nn.Module): - - def __init__( - self, - embed_dim: int, - num_heads: int, - bias: bool = True, - config: Optional[BartConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.d_model = config.d_model - self.embed_dim = embed_dim - self.total_num_heads = num_heads - self.total_num_kv_heads = self.total_num_heads - self.head_dim = embed_dim // num_heads - self.config = config - - if (self.head_dim * num_heads) != self.embed_dim: - raise ValueError(f"embed_dim must be divisible by num_heads " - f"(got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads}).") - self.scaling = self.head_dim**-0.5 - - # TP sharding sizes is accounted for within "*Parallel" layers. - self.qkv_proj = QKVCrossParallelLinear(self.d_model, - self.d_model // - self.total_num_heads, - self.total_num_heads, - self.total_num_kv_heads, - bias, - quant_config=quant_config) - - self.out_proj = RowParallelLinear( - embed_dim, - embed_dim, - bias=bias, - quant_config=quant_config, - ) - - tp_world_size = get_tensor_model_parallel_world_size() - assert self.total_num_heads % tp_world_size == 0 - self.num_heads = self.total_num_heads // tp_world_size - - if self.total_num_kv_heads >= tp_world_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_world_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. 
- assert tp_world_size % self.total_num_kv_heads == 0 - self.num_kv_heads = self.num_heads # No GQA in bart - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - attn_type=AttentionType.ENCODER_DECODER) - - def forward( - self, - decoder_hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Input shape: Batch x Time x Channel""" - - q, k, v = self.qkv_proj(decoder_hidden_states, encoder_hidden_states) - - attn_output = self.attn(q, k, v) - - output, _ = self.out_proj(attn_output) - return output - - -class BartEncoderLayer(nn.Module): - - def __init__( - self, - config: BartConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.embed_dim = config.d_model - - self.self_attn = BartEncoderAttention( - embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, - config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.activation_fn = get_act_fn(config.activation_function) - - ffn_hidden_size = self.embed_dim - ffn_intermediate_size = config.encoder_ffn_dim - ffn_has_bias = True - self.fc1 = ColumnParallelLinear( - ffn_hidden_size, - ffn_intermediate_size, - bias=ffn_has_bias, - quant_config=quant_config, - ) - self.act = get_act_fn("gelu") - self.fc2 = RowParallelLinear( - ffn_intermediate_size, - ffn_hidden_size, - bias=ffn_has_bias, - quant_config=quant_config, - ) - - self.final_layer_norm = nn.LayerNorm(self.embed_dim) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - r""" - Args: - hidden_states - torch.Tensor of *encoder* input embeddings. - Returns: - Encoder layer output torch.Tensor - """ - residual = hidden_states - hidden_states = self.self_attn(hidden_states=hidden_states) - - hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - residual = hidden_states - fc1_out, _ = self.fc1(hidden_states) - hidden_states = self.activation_fn(fc1_out) - - hidden_states, _ = self.fc2(hidden_states) - - hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) - - if hidden_states.dtype == torch.float16 and ( - torch.isinf(hidden_states).any() - or torch.isnan(hidden_states).any()): - hidden_states = cast_overflow_tensors(hidden_states) - - return hidden_states - - -class BartDecoderLayer(nn.Module): - - def __init__( - self, - config: BartConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.embed_dim = config.d_model - - self.self_attn = BartDecoderSelfAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - self.activation_fn = get_act_fn(config.activation_function) - - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - ''' - afeldman-nm: personally I would call this "cross-attention", - however I left the name as "encoder_attn" to maintain consistency - with the name of the pretrained weights. 
- ''' - self.encoder_attn = BartCrossAttention( - self.embed_dim, - config.decoder_attention_heads, - config=config, - prefix=f"{prefix}.encoder_attn", - ) - self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) - - ffn_hidden_size = self.embed_dim - ffn_intermediate_size = config.encoder_ffn_dim - ffn_has_bias = True - self.fc1 = ColumnParallelLinear( - ffn_hidden_size, - ffn_intermediate_size, - bias=ffn_has_bias, - quant_config=quant_config, - ) - self.fc2 = RowParallelLinear( - ffn_intermediate_size, - ffn_hidden_size, - bias=ffn_has_bias, - quant_config=quant_config, - ) - - self.final_layer_norm = nn.LayerNorm(self.embed_dim) - - def forward( - self, - decoder_hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - r""" - Args: - decoder_hidden_states - torch.Tensor of *decoder* input embeddings. - encoder_hidden_states - torch.Tensor of *encoder* input embeddings. - Returns: - Decoder layer output torch.Tensor - """ - residual = decoder_hidden_states - - # Self Attention - hidden_states = self.self_attn(hidden_states=decoder_hidden_states) - - hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Cross-Attention Block - - residual = hidden_states - - hidden_states = self.encoder_attn( - decoder_hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - ) - - hidden_states = residual + hidden_states - hidden_states = self.encoder_attn_layer_norm(hidden_states) - - # Fully Connected - residual = hidden_states - fc1_out, _ = self.fc1(hidden_states) - hidden_states = self.activation_fn(fc1_out) - - hidden_states, _ = self.fc2(hidden_states) - - hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) - - return hidden_states - - -class BartEncoder(nn.Module): - """ - Transformer encoder consisting of *config.encoder_layers* - self attention layers. Each layer is a [`BartEncoderLayer`]. - Args: - config: BartConfig - embed_tokens (nn.Embedding): output embedding - """ - - def __init__(self, - config: BartConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - embed_tokens: Optional[nn.Embedding] = None, - prefix: str = ""): - super().__init__() - - self.cache_config = cache_config - self.quant_config = quant_config - self.lora_config = lora_config - embed_dim = config.d_model - self.max_source_positions = config.max_position_embeddings - embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - - self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, - embed_dim, - embed_scale=embed_scale) - - if embed_tokens is not None: - self.embed_tokens.weight = embed_tokens.weight - - self.embed_positions = BartLearnedPositionalEmbedding( - config.max_position_embeddings, - embed_dim, - ) - self.layers = nn.ModuleList([ - BartEncoderLayer(config, - cache_config, - quant_config, - prefix=f"{prefix}.layers.{layer_idx}") - for layer_idx in range(config.encoder_layers) - ]) - - self.layernorm_embedding = nn.LayerNorm(embed_dim) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - r""" - Args: - input_ids - Indices of *encoder* input sequence tokens in the vocabulary. - Padding will be ignored by default should you - provide it. - positions - Positions of *encoder* input sequence tokens. 
- Returns: - Decoder output torch.Tensor - """ - # retrieve input_ids and inputs_embeds - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - embed_pos = self.embed_positions(positions) - embed_pos = embed_pos.to(inputs_embeds.device) - - hidden_states = inputs_embeds + embed_pos - hidden_states = self.layernorm_embedding(hidden_states) - - for encoder_layer in self.layers: - hidden_states = encoder_layer(hidden_states=hidden_states) - - return hidden_states - - -class BartDecoder(nn.Module): - """ - Transformer decoder consisting of *config.decoder_layers* layers. - Each layer is a [`BartDecoderLayer`] - Args: - config: BartConfig - embed_tokens (nn.Embedding): output embedding - """ - - def __init__( - self, - config: BartConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - embed_tokens: Optional[nn.Embedding] = None, - prefix: str = "", - ): - super().__init__() - self.cache_config = cache_config - self.quant_config = quant_config - self.lora_config = lora_config - self.max_target_positions = config.max_position_embeddings - embed_scale = math.sqrt( - config.d_model) if config.scale_embedding else 1.0 - - self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, - config.d_model, - embed_scale=embed_scale) - - if embed_tokens is not None: - self.embed_tokens.weight = embed_tokens.weight - - self.embed_positions = BartLearnedPositionalEmbedding( - config.max_position_embeddings, - config.d_model, - ) - - self.layers = nn.ModuleList( - [BartDecoderLayer(config,cache_config,quant_config, - prefix=f"{prefix}.layers.{layer_idx}") \ - for layer_idx in range(config.decoder_layers)]) - - self.layernorm_embedding = nn.LayerNorm(config.d_model) - - def forward( - self, - decoder_input_ids: torch.Tensor, - decoder_positions: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - r""" - Args: - decoder_input_ids - Indices of *decoder* input sequence tokens in the vocabulary. - Padding will be ignored by default should you - provide it. - decoder_positions - Positions of *decoder* input sequence tokens. 
- encoder_hidden_states: - Tensor of encoder output embeddings - Returns: - Decoder output torch.Tensor - """ - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(decoder_input_ids) - else: - decoder_positions = inputs_embeds[:, -1] - - # embed positions - embed_pos = self.embed_positions(decoder_positions) - embed_pos = embed_pos.to(inputs_embeds.device) - - hidden_states = inputs_embeds + embed_pos - hidden_states = self.layernorm_embedding(hidden_states) - - # decoder layers - - for decoder_layer in self.layers: - hidden_states = decoder_layer( - decoder_hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - ) - - return hidden_states - - -class BartModel(nn.Module, SupportsQuant): - _tied_weights_keys = [ - "encoder.embed_tokens.weight", - "decoder.embed_tokens.weight", - ] - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - - self.config = config - - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size - - self.encoder = BartEncoder(config, - cache_config, - quant_config=quant_config, - prefix=f"{prefix}.encoder") - self.decoder = BartDecoder(config, - cache_config, - quant_config=quant_config, - prefix=f"{prefix}.decoder") - - def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor) -> torch.Tensor: - r""" - Args: - input_ids - Indices of *decoder* input sequence tokens in the vocabulary. - Padding will be ignored by default should you - provide it. - positions - Positions of *decoder* input sequence tokens. - encoder_input_ids - Indices of *encoder* input sequence tokens in the vocabulary. - encoder_positions: - Positions of *encoder* input sequence tokens. 
- Returns: - Model output torch.Tensor - """ - - encoder_hidden_states = None - - if encoder_input_ids.numel() > 0: - # Run encoder attention if a non-zero number of encoder tokens - # are provided as input - encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, - positions=encoder_positions) - - # decoder outputs consists of - # (dec_features, past_key_value, dec_hidden, dec_attn) - decoder_outputs = self.decoder( - decoder_input_ids=input_ids, - decoder_positions=positions, - encoder_hidden_states=encoder_hidden_states) - - return decoder_outputs - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - - other_weights = [] - loaded_stacked_params = [] - model_params_dict = dict(self.named_parameters()) - - for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - if name not in model_params_dict: - continue - param = model_params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - loaded_stacked_params.append(name) - break - else: - if name in model_params_dict: - other_weights.append((name, loaded_weight)) - - loader = AutoWeightsLoader(self) - loaded_params = loader.load_weights(other_weights) - loaded_params.update(loaded_stacked_params) - return loaded_params - - -class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant): - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={ - "decoder.": "model.decoder.", - "encoder.": "model.encoder.", - "shared.": "model.shared." - }, - orig_to_new_substr={ - "beta": "bias", - "gamma": "weight", - "LayerNorm": "layernorm", - }, - ) - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - - super().__init__() - config = vllm_config.model_config.hf_config - lora_config = vllm_config.lora_config - # currently all existing BART models have `tie_word_embeddings` enabled - assert config.tie_word_embeddings - self.config = config - self.model = BartModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - - embed_scale = math.sqrt( - config.d_model) if config.scale_embedding else 1.0 - - self.lm_head = BartParallelLMHead(config.vocab_size, - config.d_model, - embed_scale=embed_scale) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - *, - encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor, - **kwargs, - ) -> torch.Tensor: - r""" - Args: - input_ids - torch.Tensor of *decoder* input token ids. - positions - torch.Tensor of *decoder* position indices. - encoder_input_ids - torch.Tensor of *encoder* input token ids. 
- encoder_positions - torch.Tensor of *encoder* position indices - Returns: - Output torch.Tensor - """ - return self.model(input_ids, positions, encoder_input_ids, - encoder_positions) - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - weights_tuple_list = list(weights) - - shared_embedding_weight = None - for name, loaded_weight in weights_tuple_list: - if ('shared.weight' in name - or 'encoder.embed_tokens.weight' in name - or 'decoder.embed_tokens.weight' in name - or 'lm_head.weight' in name): - assert shared_embedding_weight is None, ( - "Conflicting embedding weights.") - shared_embedding_weight = loaded_weight - - loader = AutoWeightsLoader( - self, - skip_prefixes=(["cls.", "pooler."]), - ) - loaded_params = loader.load_weights(weights_tuple_list, - mapper=self.hf_to_vllm_mapper) - - if shared_embedding_weight is not None: - weight_loader = getattr(self.lm_head.weight, "weight_loader", - default_weight_loader) - weight_loader(self.lm_head.weight, shared_embedding_weight) - - self.model.encoder.embed_tokens.weight = self.lm_head.weight - self.model.decoder.embed_tokens.weight = self.lm_head.weight - loaded_params.update({ - 'model.encoder.embed_tokens.weight', 'lm_head.weight', - 'model.decoder.embed_tokens.weight' - }) - - return loaded_params - - -class MBartEncoderLayer(BartEncoderLayer): - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - r""" - Args: - hidden_states - torch.Tensor of *encoder* input embeddings. - Returns: - Encoder layer output torch.Tensor - """ - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states = self.self_attn(hidden_states=hidden_states) - - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - fc1_out, _ = self.fc1(hidden_states) - hidden_states = self.activation_fn(fc1_out) - - hidden_states, _ = self.fc2(hidden_states) - - hidden_states = residual + hidden_states - - if hidden_states.dtype == torch.float16 and ( - torch.isinf(hidden_states).any() - or torch.isnan(hidden_states).any()): - hidden_states = cast_overflow_tensors(hidden_states) - - return hidden_states - - -class MBartDecoderLayer(BartDecoderLayer): - - def forward( - self, - decoder_hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - residual = decoder_hidden_states - hidden_states = self.self_attn_layer_norm(decoder_hidden_states) - - # Self Attention - hidden_states = self.self_attn(hidden_states=hidden_states) - - hidden_states = residual + hidden_states - - # Cross-Attention Block - - residual = hidden_states - hidden_states = self.encoder_attn_layer_norm(hidden_states) - - hidden_states = self.encoder_attn( - decoder_hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - ) - - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - fc1_out, _ = self.fc1(hidden_states) - hidden_states = self.activation_fn(fc1_out) - - hidden_states, _ = self.fc2(hidden_states) - - hidden_states = residual + hidden_states - - return hidden_states - - -class MBartEncoder(nn.Module): - """ - Transformer encoder consisting of 
*config.encoder_layers* - self attention layers. Each layer is a [`BartEncoderLayer`]. - Args: - config: BartConfig - embed_tokens (nn.Embedding): output embedding - """ - - def __init__(self, - config: BartConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - embed_tokens: Optional[nn.Embedding] = None, - prefix: str = ""): - super().__init__() - - self.cache_config = cache_config - self.quant_config = quant_config - self.lora_config = lora_config - embed_dim = config.d_model - self.max_source_positions = config.max_position_embeddings - embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - - self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, - embed_dim, - embed_scale=embed_scale) - - if embed_tokens is not None: - self.embed_tokens.weight = embed_tokens.weight - - self.embed_positions = BartLearnedPositionalEmbedding( - config.max_position_embeddings, - embed_dim, - ) - self.layers = nn.ModuleList([ - MBartEncoderLayer(config, - cache_config, - quant_config, - prefix=f"{prefix}.layers.{layer_idx}") - for layer_idx in range(config.encoder_layers) - ]) - - self.layernorm_embedding = nn.LayerNorm(embed_dim) - self.layer_norm = nn.LayerNorm(config.d_model) # changed - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - r""" - Args: - input_ids - Indices of *encoder* input sequence tokens in the vocabulary. - Padding will be ignored by default should you - provide it. - positions - Positions of *encoder* input sequence tokens. - Returns: - Encoder output torch.Tensor - """ - # retrieve input_ids and inputs_embeds - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - embed_pos = self.embed_positions(positions) - embed_pos = embed_pos.to(inputs_embeds.device) - - hidden_states = inputs_embeds + embed_pos - hidden_states = self.layernorm_embedding(hidden_states) - - for encoder_layer in self.layers: - hidden_states = encoder_layer(hidden_states=hidden_states) - - hidden_states = self.layer_norm(hidden_states) - return hidden_states - - -class MBartDecoder(nn.Module): - """ - Transformer decoder consisting of *config.decoder_layers* layers.
- Each layer is a [`BartDecoderLayer`] - Args: - config: BartConfig - embed_tokens (nn.Embedding): output embedding - """ - - def __init__( - self, - config: BartConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - embed_tokens: Optional[nn.Embedding] = None, - prefix: str = "", - ): - super().__init__() - self.cache_config = cache_config - self.quant_config = quant_config - self.lora_config = lora_config - self.max_target_positions = config.max_position_embeddings - embed_scale = math.sqrt( - config.d_model) if config.scale_embedding else 1.0 - - self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, - config.d_model, - embed_scale=embed_scale) - - if embed_tokens is not None: - self.embed_tokens.weight = embed_tokens.weight - - self.embed_positions = BartLearnedPositionalEmbedding( - config.max_position_embeddings, - config.d_model, - ) - - self.layers = nn.ModuleList( - [MBartDecoderLayer(config, cache_config, quant_config, - prefix=f"{prefix}.layers.{layer_idx}") \ - for layer_idx in range(config.decoder_layers)]) - - self.layernorm_embedding = nn.LayerNorm(config.d_model) - self.layer_norm = nn.LayerNorm(config.d_model) - - def forward( - self, - decoder_input_ids: torch.Tensor, - decoder_positions: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - r""" - Args: - decoder_input_ids - Indices of *decoder* input sequence tokens in the vocabulary. - Padding will be ignored by default should you - provide it. - decoder_positions - Positions of *decoder* input sequence tokens. - encoder_hidden_states: - Tensor of encoder output embeddings - Returns: - Decoder output torch.Tensor - """ - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(decoder_input_ids) - else: - decoder_positions = inputs_embeds[:, -1] - - # embed positions - embed_pos = self.embed_positions(decoder_positions) - embed_pos = embed_pos.to(inputs_embeds.device) - - hidden_states = inputs_embeds + embed_pos - hidden_states = self.layernorm_embedding(hidden_states) - - # decoder layers - - for decoder_layer in self.layers: - hidden_states = decoder_layer( - decoder_hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - ) - - hidden_states = self.layer_norm(hidden_states) - return hidden_states - - -class MBartModel(nn.Module, SupportsQuant): - _tied_weights_keys = [ - "encoder.embed_tokens.weight", "decoder.embed_tokens.weight" - ] - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - - self.config = config - - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size - - self.encoder = MBartEncoder(config, - cache_config, - quant_config=quant_config, - prefix=f"{prefix}.encoder") - self.decoder = MBartDecoder(config, - cache_config, - quant_config=quant_config, - prefix=f"{prefix}.decoder") - - def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor) -> torch.Tensor: - r""" - Args: - input_ids - Indices of *decoder* input sequence tokens in the vocabulary. 
- Padding will be ignored by default should you - provide it. - positions - Positions of *decoder* input sequence tokens. - encoder_input_ids - Indices of *encoder* input sequence tokens in the vocabulary. - encoder_positions: - Positions of *encoder* input sequence tokens. - Returns: - Model output torch.Tensor - """ - - encoder_hidden_states = None - - if encoder_input_ids.numel() > 0: - # Run encoder attention if a non-zero number of encoder tokens - # are provided as input - encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, - positions=encoder_positions) - - # decoder outputs consists of - # (dec_features, past_key_value, dec_hidden, dec_attn) - decoder_outputs = self.decoder( - decoder_input_ids=input_ids, - decoder_positions=positions, - encoder_hidden_states=encoder_hidden_states) - - return decoder_outputs - - -class MBartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant): - base_model_prefix = "model" - - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={ - "decoder.": "model.decoder.", - "encoder.": "model.encoder.", - "shared.": "model.shared." - }, - orig_to_new_substr={ - "beta": "bias", - "gamma": "weight", - "LayerNorm": "layernorm", - }, - ) - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - lora_config = vllm_config.lora_config - assert config.tie_word_embeddings - self.config = config - self.model = MBartModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - - embed_scale = math.sqrt( - config.d_model) if config.scale_embedding else 1.0 - - self.lm_head = BartParallelLMHead(config.vocab_size, - config.d_model, - embed_scale=embed_scale) - - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - *, - encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor, - **kwargs, - ) -> torch.Tensor: - return self.model(input_ids, positions, encoder_input_ids, - encoder_positions) - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - model_params_dict = dict(self.named_parameters()) - loaded_params = set() - remaining_weights = [] - shared_embedding_weight = None - - for name, loaded_weight in weights: - if any(skip in name - for skip in ["cls.", "pooler.", "final_logits_bias"]): - continue - if any(embed_name in name for embed_name in [ - 'shared.weight', 'encoder.embed_tokens.weight', - 'decoder.embed_tokens.weight' - ]): - if shared_embedding_weight is None: - shared_embedding_weight = loaded_weight - continue - is_stacked = False - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - vllm_name = name - for src, dst in self.hf_to_vllm_mapper.orig_to_new_substr.items( - ): - vllm_name = vllm_name.replace(src, dst) - for src, dst in self.hf_to_vllm_mapper.orig_to_new_prefix.items( - ): - if 
vllm_name.startswith(src): - vllm_name = dst + vllm_name[len(src):] - break - vllm_name = vllm_name.replace(weight_name, param_name) - if vllm_name in model_params_dict: - param = model_params_dict[vllm_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight, shard_id) - loaded_params.add(vllm_name) - is_stacked = True - break - if not is_stacked: - remaining_weights.append((name, loaded_weight)) - loader = AutoWeightsLoader(self, skip_prefixes=["cls.", "pooler."]) - auto_loaded_params = loader.load_weights(remaining_weights, - mapper=self.hf_to_vllm_mapper) - loaded_params.update(auto_loaded_params) - if shared_embedding_weight is not None: - lm_head_param = self.lm_head.weight - weight_loader = getattr(lm_head_param, "weight_loader", - default_weight_loader) - weight_loader(lm_head_param, shared_embedding_weight) - self.model.encoder.embed_tokens.weight = self.lm_head.weight - self.model.decoder.embed_tokens.weight = self.lm_head.weight - loaded_params.update({ - 'model.encoder.embed_tokens.weight', 'lm_head.weight', - 'model.decoder.embed_tokens.weight' - }) - return loaded_params diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 8f23439655ed7..ee32587f6b1b4 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -562,7 +562,9 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, self.bert = BertPoolingModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "bert"), embedding_class=BertEmbedding) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) + self.classifier = nn.Linear(config.hidden_size, + config.num_labels, + dtype=vllm_config.model_config.head_dtype) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None @@ -609,3 +611,55 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, positions=positions, inputs_embeds=inputs_embeds, intermediate_tensors=intermediate_tensors) + + +@default_pooling_type("ALL") +class BertForTokenClassification(nn.Module): + is_pooling_model = True + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.head_dtype = vllm_config.model_config.head_dtype + self.num_labels = config.num_labels + self.bert = BertModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "bert"), + embedding_class=BertEmbedding) + self.classifier = nn.Linear(config.hidden_size, + config.num_labels, + dtype=self.head_dtype) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler({ + "encode": + Pooler.for_encode(pooler_config), + }) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + loaded_params = loader.load_weights(weights) + return loaded_params + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + if token_type_ids is not None: + assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT) + assert input_ids is not None + _encode_token_type_ids(input_ids, token_type_ids) + + hidden_states = self.bert(input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors) 
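As an illustrative aside on the token_type_ids handling above: the segment id is packed into unused high bits of each input id so that a single tensor can carry both values through the model. A minimal standalone sketch of that idea, assuming a 30-bit shift and functional (non-in-place) helpers; the names below are placeholders, not vLLM's internal API:

import torch

TOKEN_TYPE_SHIFT = 30  # assumed value; vocab_size must stay below 1 << 30

def pack_token_type_ids(input_ids: torch.Tensor,
                        token_type_ids: torch.Tensor) -> torch.Tensor:
    # Vocab id in the low bits, segment (token type) id in the high bits.
    return input_ids | (token_type_ids << TOKEN_TYPE_SHIFT)

def unpack_token_type_ids(packed: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    token_type_ids = packed >> TOKEN_TYPE_SHIFT
    input_ids = packed & ((1 << TOKEN_TYPE_SHIFT) - 1)
    return input_ids, token_type_ids

packed = pack_token_type_ids(torch.tensor([101, 7592, 102]),
                             torch.tensor([0, 1, 1]))
ids, types = unpack_token_type_ids(packed)
assert ids.tolist() == [101, 7592, 102] and types.tolist() == [0, 1, 1]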
+ + hidden_states = hidden_states.to(self.head_dtype) + return self.classifier(hidden_states) diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index 3be7e11d947d5..bfc1408ddf880 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -15,8 +15,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import (get_act_and_mul_fn, get_act_fn) -from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, torch_vllm_outplace_fused_experts) +from vllm.model_executor.layers.fused_moe import (activation_without_mul, + fused_topk) from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, @@ -230,7 +230,7 @@ class NomicMoE(nn.Module): self.hidden_size = hidden_size self.total_intermediate_size = intermediate_size self.intermediate_size = divide(intermediate_size, self.tp_size) - self.hidden_act = hidden_act + self.hidden_act = activation_without_mul(hidden_act) if params_dtype is None: params_dtype = torch.get_default_dtype() @@ -297,14 +297,14 @@ class NomicMoE(nn.Module): router_logits, self.top_k, renormalize=False) - final_hidden_states = torch_vllm_outplace_fused_experts( + + final_hidden_states = torch.ops.vllm.outplace_fused_experts( hidden_states=hidden_states, w1=self.w1, w2=self.w2, topk_weights=topk_weights, topk_ids=topk_ids, activation=self.hidden_act, - is_act_and_mul=False, ) if self.tp_size > 1: @@ -637,14 +637,14 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding): self.new = GteNewModel(vllm_config=vllm_config, prefix=prefix, add_pooling_layer=True) - self.classifier = RowParallelLinear(config.hidden_size, - config.num_labels, - input_is_parallel=False, - bias=True, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "classifier"), - return_bias=False) + self.classifier = ReplicatedLinear( + config.hidden_size, + config.num_labels, + bias=True, + quant_config=quant_config, + params_dtype=vllm_config.model_config.head_dtype, + prefix=maybe_prefix(prefix, "classifier"), + return_bias=False) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index ed98a3008c567..c1e7a7d498b11 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -678,7 +678,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, Args: input_ids: Flattened (concatenated) input_ids corresponding to a batch. - pixel_values: The pixels in each input image. 
Info: [Blip2ImageInputs][] diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 13ecda0122be6..4c37622b049c8 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -257,7 +257,7 @@ class BloomModel(nn.Module): config.hidden_size)) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.word_embeddings_layernorm(self.word_embeddings(input_ids)) + return self.word_embeddings(input_ids) def forward( self, @@ -271,6 +271,7 @@ class BloomModel(nn.Module): hidden_states = inputs_embeds else: hidden_states = self.get_input_embeddings(input_ids) + hidden_states = self.word_embeddings_layernorm(hidden_states) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -329,7 +330,9 @@ class BloomForCausalLM(nn.Module, SupportsPP, SupportsQuant): self.lm_head = self.transformer.word_embeddings else: self.lm_head = ParallelLMHead(self.config.vocab_size, - self.config.hidden_size) + self.config.hidden_size, + prefix=maybe_prefix( + prefix, "lm_head")) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 28a1a66c23291..7a56236483749 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -960,6 +960,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 377b7bf26a07a..ce3d23763ed64 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -24,6 +24,14 @@ class VerifyAndUpdateConfig: raise NotImplementedError +class Gemma3TextModelConfig: + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + hf_config = vllm_config.model_config.hf_config + hf_config.is_causal = not hf_config.use_bidirectional_attention + + class GteNewModelConfig(VerifyAndUpdateConfig): @staticmethod @@ -210,8 +218,10 @@ class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig): @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: config = vllm_config.model_config.hf_config - config.num_labels = 1 + pooler_config = vllm_config.model_config.pooler_config + if pooler_config.logit_bias is None: + pooler_config.logit_bias = 2.65 class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig): @@ -252,9 +262,9 @@ class GptOssForCausalLMConfig(VerifyAndUpdateConfig): @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: - decoding_config = vllm_config.decoding_config - if decoding_config.reasoning_backend == "": - decoding_config.reasoning_backend = "GptOss" + structured_outputs_config = vllm_config.structured_outputs_config + if structured_outputs_config.reasoning_parser == "": + structured_outputs_config.reasoning_parser = "openai_gptoss" # Increase the max capture size from 512 to 1024 for performance. 
# NOTE(woosuk): This will increase the number of CUDA graphs @@ -302,7 +312,8 @@ class MambaModelConfig(VerifyAndUpdateConfig): # TODO(tdoublep): remove as full cuda graph support is added FCG_NOT_SUPPORTED_MODELS = [ - "Lfm2ForCausalLM", "MiniMaxText01ForCausalLM" + "Lfm2ForCausalLM", + "MiniMaxText01ForCausalLM", ] if (model_config.architecture not in FCG_NOT_SUPPORTED_MODELS @@ -407,6 +418,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { "GteModel": SnowflakeGteNewModelConfig, "GteNewModel": GteNewModelConfig, "GteNewForSequenceClassification": GteNewModelConfig, + "Gemma3TextModel": Gemma3TextModelConfig, "NomicBertModel": NomicBertModelConfig, "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig, "Qwen2ForRewardModel": Qwen2ForRewardModelConfig, diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 519cd522213b2..003cf4563a22f 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -438,6 +438,7 @@ class DbrxForCausalLM(nn.Module, SupportsPP): org_num_embeddings=config.vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 3f9349d766df6..59c9921881497 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -37,7 +37,7 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, @@ -163,13 +163,19 @@ class DeepseekMoE(nn.Module): shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = fused_moe(hidden_states, - self.w1, - self.w2, - router_logits, - self.top_k, - renormalize=self.config.norm_topk_prob, - inplace=True) + + topk_weights, topk_ids, _ = fused_topk( + hidden_states, + router_logits, + self.top_k, + renormalize=self.config.norm_topk_prob) + + final_hidden_states = fused_experts(hidden_states, + self.w1, + self.w2, + topk_weights, + topk_ids, + inplace=True) if self.config.n_shared_experts is not None: final_hidden_states = final_hidden_states + shared_output @@ -453,9 +459,12 @@ class DeepseekForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.quant_config = quant_config self.model = DeepseekModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py index 0c9c83cf61000..b1d7f24c2f18b 100644 --- a/vllm/model_executor/models/deepseek_eagle.py +++ 
b/vllm/model_executor/models/deepseek_eagle.py @@ -37,8 +37,6 @@ class DeepseekV2Model(nn.Module): super().__init__() self.config = vllm_config. \ speculative_config.draft_model_config.hf_config - model_config = vllm_config.model_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config self.vocab_size = self.config.vocab_size @@ -51,11 +49,8 @@ class DeepseekV2Model(nn.Module): self.layers = nn.ModuleList([ DeepseekV2DecoderLayer( - self.config, + vllm_config, prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, ) for i in range(self.config.num_hidden_layers) ]) @@ -204,7 +199,8 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM): self.lm_head = ParallelLMHead(self.config.vocab_size, self.config.hidden_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head")) logit_scale = getattr(self.config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.config.vocab_size, diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 0ad001be71c19..8fbf16d206a86 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -7,7 +7,7 @@ import torch import torch.nn as nn from transformers import PretrainedConfig -from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.config import VllmConfig from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -43,23 +43,19 @@ class SharedHead(nn.Module): class DeepSeekMultiTokenPredictorLayer(nn.Module): - def __init__( - self, - config: PretrainedConfig, - prefix: str, - model_config: ModelConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: + def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False) self.shared_head = SharedHead(config=config, quant_config=quant_config) - self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config, - cache_config, quant_config) + self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix) def forward( self, @@ -95,13 +91,8 @@ class DeepSeekMultiTokenPredictor(nn.Module): # to map the exact layer index from weights self.layers = torch.nn.ModuleDict({ str(idx): - DeepSeekMultiTokenPredictorLayer( - config, - f"{prefix}.layers.{idx}", - model_config=vllm_config.model_config, - cache_config=vllm_config.cache_config, - quant_config=vllm_config.quant_config, - ) + DeepSeekMultiTokenPredictorLayer(vllm_config, + f"{prefix}.layers.{idx}") for idx in range(self.mtp_start_layer_idx, self.mtp_start_layer_idx + self.num_mtp_layers) }) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 36c9427e474e9..636554bd648f2 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -32,29 +32,34 @@ import torch from torch import nn from transformers import DeepseekV2Config, DeepseekV3Config +import vllm.envs as envs from 
vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import (CacheConfig, ModelConfig, VllmConfig, - get_current_vllm_config) +from vllm.config import CacheConfig, ParallelConfig, VllmConfig from vllm.distributed import (get_ep_group, get_pp_group, - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, - MergedReplicatedLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttention from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +from vllm.utils import cdiv, direct_register_custom_op from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP from .utils import (PPMissingLayer, is_pp_missing_parameter, @@ -71,19 +76,27 @@ class DeepseekV2MLP(nn.Module): hidden_act: str, quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = True, + is_sequence_parallel=False, prefix: str = "", ) -> None: super().__init__() + + # If is_sequence_parallel, the input and output tensors are sharded + # across the ranks within the tp_group. In this case the weights are + # replicated and no collective ops are needed. + # Otherwise we use standard TP with an allreduce at the end. self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config, + disable_tp=is_sequence_parallel, prefix=f"{prefix}.gate_up_proj") self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, quant_config=quant_config, reduce_results=reduce_results, + disable_tp=is_sequence_parallel, prefix=f"{prefix}.down_proj") if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " @@ -97,17 +110,58 @@ class DeepseekV2MLP(nn.Module): return x +# Chunk x along the num_tokens axis for sequence parallelism +# NOTE: This is wrapped in a torch custom op to work around the following issue: +# The output tensor can have a sequence length 0 at small input sequence lengths +# even though we explicitly pad to avoid this. 
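For reference, the pad-and-narrow pattern that the custom op defined just below implements can be sketched standalone, assuming a fixed tp_size and no vLLM distributed state; concatenating the per-rank chunks and dropping the padding recovers the original input:

import torch
import torch.nn.functional as F

def chunk_for_rank(x: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    remainder = x.size(0) % tp_size
    if remainder != 0:
        # Pad the token dimension so it divides evenly across ranks.
        x = F.pad(x, (0, 0, 0, tp_size - remainder))
    chunk = x.size(0) // tp_size
    return torch.narrow(x, 0, tp_rank * chunk, chunk)

x = torch.randn(10, 8)  # 10 tokens, hidden size 8
parts = [chunk_for_rank(x, rank, 4) for rank in range(4)]
# Simulated all_gather along the token dimension, then drop the padding.
assert torch.equal(torch.cat(parts, dim=0)[:10], x)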
+def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + # all_gather needs the sequence length to be divisible by tp_size + seq_len = x.size(0) + remainder = seq_len % tp_size + if remainder != 0: + pad_len = tp_size - remainder + x = nn.functional.pad(x, (0, 0, 0, pad_len)) + + chunk = x.shape[0] // tp_size + start = tp_rank * chunk + return torch.narrow(x, 0, start, chunk) + + +def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor: + tp_size = get_tensor_model_parallel_world_size() + seq_len = cdiv(x.size(0), tp_size) + shape = list(x.shape) + shape[0] = seq_len + out = torch.empty(shape, dtype=x.dtype, device=x.device) + return out + + +direct_register_custom_op( + op_name="sequence_parallel_chunk", + op_func=sequence_parallel_chunk, + mutates_args=[], + fake_impl=sequence_parallel_chunk_fake, + dispatch_key=current_platform.dispatch_key, + tags=(torch.Tag.needs_fixed_stride_order, ), +) + + class DeepseekV2MoE(nn.Module): def __init__( self, config: Union[DeepseekV2Config, DeepseekV3Config], + parallel_config: ParallelConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", - enable_eplb: bool = False, ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.routed_scaling_factor = config.routed_scaling_factor self.ep_group = get_ep_group().device_group @@ -116,6 +170,21 @@ class DeepseekV2MoE(nn.Module): self.n_routed_experts: int = config.n_routed_experts self.n_shared_experts: int = config.n_shared_experts + # The all_reduce at the end of attention (during o_proj) means that + # inputs are replicated across each rank of the tensor parallel group. + # If using expert-parallelism with DeepEP All2All ops, replicated + # tokens results in useless duplicate computation and communication. + # + # In this case, ensure the input to the experts is sequence parallel + # to avoid the excess work. + # + # Not needed for pplx-kernels as it can handle duplicate input tokens. + self.is_sequence_parallel = (envs.VLLM_ALL2ALL_BACKEND + in ("deepep_high_throughput", + "deepep_low_latency") + and parallel_config.enable_expert_parallel + and self.tp_size > 1) + if config.hidden_act != "silu": raise ValueError(f"Unsupported activation: {config.hidden_act}. " "Only silu is supported for now.") @@ -132,9 +201,8 @@ class DeepseekV2MoE(nn.Module): self.gate.e_score_correction_bias = None # Load balancing settings. 
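To make the duplicate-work argument above concrete, a toy calculation with assumed numbers (purely illustrative): with replicated inputs every TP rank dispatches all tokens to the experts, while sequence-parallel sharding dispatches each token exactly once across the group.

import math

tp_size, num_tokens = 4, 1024
# Replicated inputs: every rank pushes all tokens through dispatch + experts.
dispatched_replicated = tp_size * num_tokens
# Sequence-parallel inputs: each rank handles only its shard of the tokens.
dispatched_seq_parallel = tp_size * math.ceil(num_tokens / tp_size)
assert dispatched_replicated == tp_size * dispatched_seq_parallel  # tp_size x less work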
- vllm_config = get_current_vllm_config() - eplb_config = vllm_config.parallel_config.eplb_config - self.enable_eplb = enable_eplb + eplb_config = parallel_config.eplb_config + self.enable_eplb = parallel_config.enable_eplb self.n_redundant_experts = eplb_config.num_redundant_experts self.n_logical_experts = self.n_routed_experts @@ -147,64 +215,105 @@ class DeepseekV2MoE(nn.Module): self.physical_expert_end = (self.physical_expert_start + self.n_local_physical_experts) - self.experts = FusedMoE( - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - prefix=f"{prefix}.experts", - scoring_func=config.scoring_func, - routed_scaling_factor=self.routed_scaling_factor, - e_score_correction_bias=self.gate.e_score_correction_bias, - enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts) - - if config.n_shared_experts is not None: + if config.n_shared_experts is None: + self.experts = FusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + # we do scaling outside, set factor to 1.0 to avoid double mul + routed_scaling_factor=1.0, + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + ) + self.shared_experts = None + else: intermediate_size = (config.moe_intermediate_size * config.n_shared_experts) + self.shared_experts = DeepseekV2MLP( hidden_size=config.hidden_size, intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=self.experts.must_reduce_shared_expert_outputs( - ), + is_sequence_parallel=self.is_sequence_parallel, + reduce_results=False, prefix=f"{prefix}.shared_experts", ) + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + # we do scaling outside, set factor to 1.0 to avoid double mul + routed_scaling_factor=1.0, + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + ) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - if self.n_shared_experts is not None: - shared_output = self.shared_experts(hidden_states) + + # Chunk the hidden states so they aren't replicated across TP ranks. + # This avoids duplicate computation in self.experts. 
+ # TODO: We can replace the all_reduce at the end of attn with a + # reduce_scatter instead of chunking here. + if self.is_sequence_parallel: + hidden_states = torch.ops.vllm.sequence_parallel_chunk( + hidden_states) + # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - if hidden_states.dtype != torch.float16: - final_hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits) * self.routed_scaling_factor - else: - # Fix FP16 overflow - # See DeepseekV2DecoderLayer for more details. - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) - if shared_output is not None: - if hidden_states.dtype != torch.float16: - final_hidden_states = final_hidden_states + shared_output - else: - # Fix FP16 overflow - # See DeepseekV2DecoderLayer for more details. - final_hidden_states = final_hidden_states + shared_output \ - * (1. / self.routed_scaling_factor) + fused_moe_out = self.experts(hidden_states=hidden_states, + router_logits=router_logits) - if self.tp_size > 1: + if self.shared_experts is not None: + shared_output, final_hidden_states = fused_moe_out + else: + shared_output = None + final_hidden_states = fused_moe_out + + # Fix FP16 overflow + # See DeepseekV2DecoderLayer for more details. + if hidden_states.dtype != torch.float16: + final_hidden_states *= self.routed_scaling_factor + elif self.shared_experts is not None: + assert shared_output is not None + shared_output *= (1. / self.routed_scaling_factor) + + if self.shared_experts is not None: + assert shared_output is not None + final_hidden_states += shared_output + + if self.is_sequence_parallel: + final_hidden_states = tensor_model_parallel_all_gather( + final_hidden_states, 0) + final_hidden_states = final_hidden_states[:num_tokens] + elif self.tp_size > 1: final_hidden_states = ( self.experts.maybe_all_reduce_tensor_model_parallel( final_hidden_states)) @@ -410,12 +519,13 @@ class DeepseekV2MLAAttention(nn.Module): self.max_position_embeddings = max_position_embeddings if self.q_lora_rank is not None: - self.fused_qkv_a_proj = MergedReplicatedLinear( + self.fused_qkv_a_proj = MergedColumnParallelLinear( self.hidden_size, [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], bias=False, quant_config=quant_config, - prefix=f"{prefix}.fused_qkv_a_proj") + prefix=f"{prefix}.fused_qkv_a_proj", + disable_tp=True) else: self.kv_a_proj_with_mqa = ReplicatedLinear( self.hidden_size, @@ -468,86 +578,54 @@ class DeepseekV2MLAAttention(nn.Module): mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale - # In the MLA backend, kv_cache includes both k_c and - # pe (i.e. decoupled position embeddings). In particular, - # the concat_and_cache_mla op requires - # k_c.size(1) + k_pe.size(1) == kv_cache.size(2) - # i.e. 
- # kv_lora_rank + qk_rope_head_dim == head_size - self.mla_attn = Attention( - num_heads=self.num_local_heads, - head_size=self.kv_lora_rank + self.qk_rope_head_dim, - scale=self.scaling, - num_kv_heads=1, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - use_mla=True, - # MLA Args - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, - qk_nope_head_dim=self.qk_nope_head_dim, - qk_rope_head_dim=self.qk_rope_head_dim, - qk_head_dim=self.qk_head_dim, - v_head_dim=self.v_head_dim, + mla_modules = MLAModules( + kv_a_layernorm=self.kv_a_layernorm, kv_b_proj=self.kv_b_proj, + rotary_emb=self.rotary_emb, + o_proj=self.o_proj, + fused_qkv_a_proj=self.fused_qkv_a_proj + if self.q_lora_rank is not None else None, + kv_a_proj_with_mqa=self.kv_a_proj_with_mqa + if self.q_lora_rank is None else None, + q_a_layernorm=self.q_a_layernorm + if self.q_lora_rank is not None else None, + q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None, + q_proj=self.q_proj if self.q_lora_rank is None else None, + ) + self.mla_attn = MultiHeadLatentAttention( + self.hidden_size, + self.num_local_heads, + self.scaling, + self.qk_nope_head_dim, + self.qk_rope_head_dim, + self.v_head_dim, + self.q_lora_rank, + self.kv_lora_rank, + mla_modules, + cache_config, + quant_config, + prefix, ) - - self.prefix = prefix - self.debug_layer_idx = int(self.prefix.split(".")[-2]) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: - q_c = None - kv_lora = None - - if self.q_lora_rank is not None: - qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] - q_c, kv_lora = qkv_lora.split( - [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], - dim=-1, - ) - q_c = self.q_a_layernorm(q_c) - q = self.q_b_proj(q_c)[0] - else: - kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0] - q = self.q_proj(hidden_states)[0] - - kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], - dim=-1) - kv_c_normed = self.kv_a_layernorm(kv_c) - - q = q.view(-1, self.num_local_heads, self.qk_head_dim) - # Add head dim of 1 to k_pe - k_pe = k_pe.unsqueeze(1) - - q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb( - positions, q[..., self.qk_nope_head_dim:], k_pe) - - attn_out = self.mla_attn( - q, - kv_c_normed, - k_pe, - output_shape=(hidden_states.shape[0], - self.num_local_heads * self.v_head_dim)) - return self.o_proj(attn_out)[0] + return self.mla_attn(positions, hidden_states) class DeepseekV2DecoderLayer(nn.Module): - def __init__( - self, - config: Union[DeepseekV2Config, DeepseekV3Config], - prefix: str, - model_config: ModelConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - enable_eplb: bool = False, - ) -> None: + def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: super().__init__() + + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + parallel_config = vllm_config.parallel_config + self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) @@ -584,9 +662,9 @@ class DeepseekV2DecoderLayer(nn.Module): and layer_idx % config.moe_layer_freq == 0): self.mlp = DeepseekV2MoE( config=config, + parallel_config=parallel_config, quant_config=quant_config, prefix=f"{prefix}.mlp", - enable_eplb=enable_eplb, ) else: self.mlp = DeepseekV2MLP( @@ -656,10 +734,7 @@ class 
DeepseekV2Model(nn.Module): super().__init__() config = vllm_config.model_config.hf_config - model_config = vllm_config.model_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - enable_eplb = vllm_config.parallel_config.enable_eplb self.config = config self.vocab_size = config.vocab_size @@ -675,14 +750,7 @@ class DeepseekV2Model(nn.Module): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: DeepseekV2DecoderLayer( - config, - prefix, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - enable_eplb=enable_eplb, - ), + lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix), prefix=f"{prefix}.layers") if get_pp_group().is_last_rank: @@ -755,9 +823,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, self.model = DeepseekV2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) if get_pp_group().is_last_rank: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index 5eab02b17151c..d7ae8206baca5 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -21,7 +21,8 @@ from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models.transformers import replace_linear_class from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) + MultiModalKwargsItems, MultiModalUUIDDict, + NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -290,7 +291,7 @@ class DeepseekVL2MultiModalProcessor( mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - mm_hash_overrides: Optional[dict[str, list[str]]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: # The processor logic is different for len(images) <= 2 vs > 2 # Since the processing cache assumes that the processor output is @@ -302,7 +303,7 @@ class DeepseekVL2MultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) return super()._cached_apply_hf_processor( @@ -310,7 +311,7 @@ class DeepseekVL2MultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) diff --git a/vllm/model_executor/models/donut.py b/vllm/model_executor/models/donut.py deleted file mode 100644 index c00db52371b68..0000000000000 --- a/vllm/model_executor/models/donut.py +++ /dev/null @@ -1,387 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import math -from collections.abc import Iterable, Mapping, Sequence -from typing 
import Annotated, Literal, Optional, Union - -import torch -import torch.nn as nn -from transformers import BatchFeature, NougatProcessor - -from vllm.config import VllmConfig -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.bart import BartParallelLMHead, MBartDecoder -from vllm.model_executor.models.interfaces import (MultiModalEmbeddings, - SupportsMultiModal, - SupportsV0Only) -from vllm.model_executor.models.swin import SwinModel -from vllm.model_executor.models.utils import (AutoWeightsLoader, - _flatten_embeddings, flatten_bn) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import MultiModalDataItems -from vllm.multimodal.processing import (BaseProcessingInfo, - EncDecMultiModalProcessor, - PromptIndexTargets, PromptInsertion, - PromptUpdate) -from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.utils.tensor_schema import TensorSchema, TensorShape - - -class MBartDecoderWrapper(nn.Module): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - - self.decoder = MBartDecoder(config, - cache_config, - quant_config=quant_config, - prefix=f"{prefix}.decoder") - - def forward(self, *args, **kwargs): - return self.decoder(*args, **kwargs) - - -class DonutLanguageForConditionalGeneration(nn.Module, SupportsV0Only): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - - self.config = config - self.model = MBartDecoderWrapper(vllm_config=vllm_config, - prefix=f"{prefix}.model") - embed_scale = math.sqrt( - config.d_model) if config.scale_embedding else 1.0 - - self.vocab_size = config.vocab_size - self.lm_head = BartParallelLMHead(self.vocab_size, - config.d_model, - embed_scale=embed_scale) - - self.logits_processor = LogitsProcessor(self.vocab_size, - config.vocab_size) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - inputs_embeds: torch.Tensor, - **kwargs, - ) -> torch.Tensor: - r""" - Args: - input_ids - torch.Tensor of *decoder* input token ids. - positions - torch.Tensor of *decoder* position indices. 
- Returns: - Output torch.Tensor - """ - - return self.model(decoder_input_ids=input_ids, - decoder_positions=positions, - encoder_hidden_states=inputs_embeds) - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - if "final_logits_bias" in name: - continue - # if self.config.tie_word_embeddings and "embed_tokens" in name: - # continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -class DonutImagePixelInputs(TensorSchema): - """ - Dimensions: - - b: Batch size - - c: Number of channels (3) - - h: Height - - w: Width - """ - type: Literal["pixel_values"] - data: Annotated[torch.Tensor, TensorShape("b", 3, "h", "w")] - - -class DonutProcessingInfo(BaseProcessingInfo): - - def get_hf_config(self): - return self.ctx.get_hf_config() - - def get_hf_processor(self): - return self.ctx.get_hf_processor() - - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": 1} - - def get_num_image_tokens(self) -> int: - return 1 - - -class DonutDummyInputsBuilder(BaseDummyInputsBuilder[DonutProcessingInfo]): - - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: - return "" - - def get_dummy_mm_data( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> MultiModalDataDict: - num_images = mm_counts.get("image", 0) - - target_width, target_height = self.info.get_hf_config( - ).encoder.image_size - - return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) - } - - -class DonutMultiModalProcessor(EncDecMultiModalProcessor[DonutProcessingInfo]): - - def _hf_processor_applies_updates( - self, - prompt_text: str, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - tokenization_kwargs: Mapping[str, object], - ) -> bool: - return False - - def create_encoder_prompt( - self, - prompt: Union[str, list[int]], - mm_data: MultiModalDataDict, - ) -> Union[str, list[int]]: - return prompt - - def create_decoder_prompt( - self, - prompt: Union[str, list[int]], - mm_data: MultiModalDataDict, - ) -> Union[str, list[int]]: - return prompt - - @property - def pad_dummy_encoder_prompt(self) -> bool: - return True - - def _call_hf_processor( - self, - prompt: str, - mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object], - tok_kwargs: Mapping[str, object], - ) -> BatchFeature: - hf_processor = self.info.get_hf_processor() - if mm_data: - processed_outputs = super()._call_hf_processor( - prompt, mm_data, mm_kwargs, tok_kwargs) - if isinstance(hf_processor, NougatProcessor): - processed_outputs["input_ids"] = 
processed_outputs["labels"] - else: - tokenizer = hf_processor.tokenizer - processed_outputs = tokenizer(prompt, - add_special_tokens=False, - return_tensors="pt") - return processed_outputs - - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - return dict(pixel_values=MultiModalFieldConfig.batched("image")) - - def _get_prompt_updates( - self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargsItems, - ) -> Sequence[PromptUpdate]: - hf_processor = self.info.get_hf_processor() - tokenizer = hf_processor.tokenizer - pad_token_id = tokenizer.pad_token_id - num_image_tokens = self.info.get_num_image_tokens() - image_tokens = [pad_token_id] * num_image_tokens - - return [ - PromptInsertion( - modality="image", - target=PromptIndexTargets.start(), - insertion=image_tokens, - ) - ] - - -@MULTIMODAL_REGISTRY.register_processor(DonutMultiModalProcessor, - info=DonutProcessingInfo, - dummy_inputs=DonutDummyInputsBuilder) -class DonutForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsV0Only): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - processor_config = vllm_config.model_config.hf_image_processor_config - - self.config = config - self.vision_config = config.encoder - self.processor_config = processor_config - self.encoder = SwinModel(config=config.encoder) - - self.decoder = DonutLanguageForConditionalGeneration( - vllm_config=vllm_config.with_hf_config(config.decoder), - prefix=f"{prefix}.decoder", - ) - self.pad_token_id = config.pad_token_id - - def _parse_and_validate_image_input(self, **kwargs: object): - pixel_values: Optional[Union[list[list[torch.Tensor]], - list[torch.Tensor], - torch.Tensor]] = kwargs.pop( - "pixel_values", None) - image_embeds: Optional[Union[list[list[torch.Tensor]], - list[torch.Tensor], - torch.Tensor]] = kwargs.pop( - "image_embeds", None) - - if pixel_values is None and image_embeds is None: - return None - - if pixel_values is not None and image_embeds is not None: - raise ValueError( - "Both pixel values and image embeds are provided.") - - if pixel_values is not None: - h, w = self.config.encoder.image_size - return DonutImagePixelInputs(type="pixel_values", - data=flatten_bn(pixel_values, - concat=True), - resolve_bindings={ - "h": h, - "w": w, - }) - - if image_embeds is not None: - raise NotImplementedError - - raise AssertionError("This line should be unreachable.") - - def _process_image_input( - self, image_input: DonutImagePixelInputs) -> torch.Tensor: - assert image_input["type"] == "pixel_values" - pixel_values = image_input["data"] - dtype = next(self.encoder.parameters()).dtype - pixel_values = pixel_values.to(dtype) - return self.encoder(pixel_values) - - def get_language_model(self) -> torch.nn.Module: - return self.decoder - - def get_multimodal_embeddings( - self, **kwargs: object) -> Optional[MultiModalEmbeddings]: - image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is None: - return None - vision_embeddings = self._process_image_input(image_input) - return vision_embeddings - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: MultiModalEmbeddings, - ) -> torch.Tensor: - return _flatten_embeddings(multimodal_embeddings) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - *, - 
encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor, - **kwargs, - ) -> torch.Tensor: - r""" - Args: - input_ids - torch.Tensor of *decoder* input token ids. - positions - torch.Tensor of *decoder* position indices. - encoder_input_ids - torch.Tensor of *encoder* input token ids. - encoder_positions - torch.Tensor of *encoder* position indices - Returns: - Output torch.Tensor - """ - - inputs_embeds = None - if encoder_input_ids.numel() > 0: - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings(encoder_input_ids, - vision_embeddings) - - hidden_states = self.decoder(input_ids, - positions, - inputs_embeds=inputs_embeds) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - return self.decoder.compute_logits(hidden_states, sampling_metadata) - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) - return loader.load_weights(weights) diff --git a/vllm/model_executor/models/dots1.py b/vllm/model_executor/models/dots1.py index a5477af8694b4..20555e48b73d4 100644 --- a/vllm/model_executor/models/dots1.py +++ b/vllm/model_executor/models/dots1.py @@ -137,7 +137,8 @@ class Dots1MoE(nn.Module): topk_group=config.topk_group, prefix=f"{prefix}.experts", scoring_func=config.scoring_func, - routed_scaling_factor=self.routed_scaling_factor, + # we do scaling outside, set factor to 1.0 to avoid double mul + routed_scaling_factor=1.0, e_score_correction_bias=self.gate.e_score_correction_bias) if config.n_shared_experts is not None: @@ -503,7 +504,9 @@ class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) diff --git a/vllm/model_executor/models/ernie45_moe.py b/vllm/model_executor/models/ernie45_moe.py index 33ec27fc630e0..ebab018ed67e7 100644 --- a/vllm/model_executor/models/ernie45_moe.py +++ b/vllm/model_executor/models/ernie45_moe.py @@ -562,7 +562,9 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index d880fc434e20f..3396c67f42b7b 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -34,6 +34,7 @@ import torch.nn.functional as F from einops import rearrange, repeat from transformers import BatchFeature +from vllm.attention.layer import check_upstream_fa_availability from vllm.config import VllmConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils @@ -66,8 +67,6 @@ from .vision import get_vit_attn_backend logger = init_logger(__name__) -_MAX_FRAMES_PER_VIDEO = 16 - # === Vision Transformer === # @@ -172,7 +171,16 @@ class Ernie4_5_VisionAttention(nn.Module): prefix=f"{prefix}.proj") # Detect attention implementation. 
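As an aside on the routed_scaling_factor=1.0 changes here and in the DeepseekV2 MoE above: multiplying by the routed scaling factor commutes with the weighted sum over the selected experts, so scaling once outside the fused kernel matches scaling inside it, and doing both would double-scale the routed output. A toy check with plain tensors and illustrative values only:

import torch

routed_scaling_factor = 2.5
topk_weights = torch.tensor([0.7, 0.3])      # routing weights for one token
expert_outputs = torch.tensor([[1.0, 2.0],   # outputs of the two selected experts
                               [3.0, 4.0]])

scaled_inside = (routed_scaling_factor * topk_weights[:, None] * expert_outputs).sum(dim=0)
scaled_outside = routed_scaling_factor * (topk_weights[:, None] * expert_outputs).sum(dim=0)
assert torch.allclose(scaled_inside, scaled_outside)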
- self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=self.hidden_size_per_attention_head, + dtype=torch.get_default_dtype()) + + self.use_upstream_fa = False + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability(torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN + self.use_upstream_fa = True + if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.ROCM_AITER_FA @@ -235,7 +243,10 @@ class Ernie4_5_VisionAttention(nn.Module): if self.attn_backend == _Backend.ROCM_AITER_FA: from aiter import flash_attn_varlen_func else: - from flash_attn import flash_attn_varlen_func + if self.use_upstream_fa: + from flash_attn import flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) @@ -459,7 +470,11 @@ class Ernie4_5_VisionTransformer(nn.Module): ), "vit's config.hidden must be equal to config.embed_dim" self.ln = nn.LayerNorm(hidden_size, eps=1e-6) - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, dtype=torch.get_default_dtype()) + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability(torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -839,6 +854,15 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None, "video": None} + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + max_image_tokens = self.get_max_image_tokens() + max_video_tokens = self.get_max_video_tokens(seq_len, mm_counts) + return {"image": max_image_tokens, "video": max_video_tokens} + def _get_vision_info( self, *, @@ -964,8 +988,7 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo): max_image_tokens = self.get_max_image_tokens() * max_images max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens) - max_frames_per_video = min(max_total_frames // max(max_videos, 1), - _MAX_FRAMES_PER_VIDEO) + max_frames_per_video = max_total_frames // max(max_videos, 1) return max(max_frames_per_video, 2) @@ -1315,7 +1338,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal, raise ValueError(f"{name} should be 2D or batched 3D tensor. " f"Got ndim: {mm_input.ndim} " f"(shape={mm_input.shape})") - return torch.concat(list(mm_input)) + return mm_input.reshape(-1, mm_input.shape[-1]) else: return torch.concat(mm_input) @@ -1423,7 +1446,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal, return None # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). + # tensor corresponding to a multimodal data item (image or video). multimodal_embeddings: tuple[torch.Tensor, ...] 
= () # NOTE: It is important to iterate over the keys in this dictionary diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py index 780974c3b758e..7f791852ceb91 100644 --- a/vllm/model_executor/models/ernie45_vl_moe.py +++ b/vllm/model_executor/models/ernie45_vl_moe.py @@ -287,8 +287,13 @@ class Ernie4_5_VLMoeMoE(nn.Module): if self.has_shared_experts: shared_output = self.shared_experts(hidden_states) - if visual_token_mask is not None and visual_token_mask.any(): - # assert visual_token_mask.shape[0] != hidden_states.shape[0] + if visual_token_mask is not None and visual_token_mask.all(): + # only vision modal input + router_logits, _ = self.vision_experts_gate(hidden_states) + final_hidden_states = self.vision_experts( + hidden_states=hidden_states, router_logits=router_logits) + elif visual_token_mask is not None and visual_token_mask.any(): + # text and vision modals input visual_token_mask = visual_token_mask.repeat( 1, self.hidden_size).bool() text_token_mask = ~visual_token_mask @@ -310,7 +315,7 @@ class Ernie4_5_VLMoeMoE(nn.Module): hidden_states=vision_hidden_states, router_logits=vision_router_logits).flatten() else: - # text modal input processing directly + # only text modal input text_router_logits, _ = self.text_experts_gate(hidden_states) final_hidden_states = self.text_experts( @@ -552,7 +557,9 @@ class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP): if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/ernie_mtp.py b/vllm/model_executor/models/ernie_mtp.py index 90a1267b28f0a..57c5348874378 100644 --- a/vllm/model_executor/models/ernie_mtp.py +++ b/vllm/model_executor/models/ernie_mtp.py @@ -158,7 +158,8 @@ class ErnieMTP(nn.Module, SupportsPP): prefix=maybe_prefix( prefix, "model")) self.lm_head = ParallelLMHead(self.config.vocab_size, - self.config.hidden_size) + self.config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head")) self.sampler = get_sampler() if self.config.tie_word_embeddings: diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 942db0143a457..f503fb0f9364a 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -502,6 +502,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): # compatibility if not lora_config else lora_config.lora_vocab_padding_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head.weight = self.transformer.wte.weight diff --git a/vllm/model_executor/models/exaone4.py b/vllm/model_executor/models/exaone4.py index 971fcbd2aa275..9f7d57d938140 100644 --- a/vllm/model_executor/models/exaone4.py +++ b/vllm/model_executor/models/exaone4.py @@ -164,8 +164,8 @@ class Exaone4Attention(nn.Module): is_sliding = config.layer_types[layer_idx] == "sliding_attention" self.sliding_window = config.sliding_window if is_sliding else None - # apply rotary embeddings to every layer - self.apply_all_layers = not is_sliding + # apply rotary embeddings to every layer in full attention models + self.apply_rope_all_layers = "sliding_attention" not in config.layer_types self.rotary_emb = get_rope( self.head_dim, @@ -201,7 +201,7 @@ class Exaone4Attention(nn.Module): k = self.k_norm(k) k = k.flatten(-2, -1) - if 
self.sliding_window or self.apply_all_layers: + if self.sliding_window or self.apply_rope_all_layers: q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) @@ -485,6 +485,7 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): # compatibility if not lora_config else lora_config.lora_vocab_padding_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index a9fe0924babd8..42c378e5c389a 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -473,6 +473,7 @@ class FalconForCausalLM(nn.Module, SupportsPP): config.vocab_size, config.hidden_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index 5e2b6d69124c8..757051b3b1447 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -607,6 +607,7 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, # compatibility if not lora_config else lora_config.lora_vocab_padding_size), + prefix=maybe_prefix(prefix, "lm_head"), ) self.lm_head_multiplier = config.lm_head_multiplier if self.tie_word_embeddings: diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py deleted file mode 100644 index d0881231fb1e7..0000000000000 --- a/vllm/model_executor/models/florence2.py +++ /dev/null @@ -1,1107 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import math -from collections import OrderedDict -from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Literal, Optional, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from transformers import BartTokenizer, BatchFeature, PretrainedConfig - -from vllm.config import VllmConfig -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.bart import (BartDecoder, BartEncoder, - BartParallelLMHead, - BartScaledWordEmbedding) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import MultiModalDataItems -from vllm.multimodal.processing import (BaseProcessingInfo, - EncDecMultiModalProcessor, - PromptIndexTargets, PromptInsertion, - PromptUpdate) -from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.sequence import IntermediateTensors -from vllm.utils.tensor_schema import TensorSchema, TensorShape - -from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, - SupportsV0Only) -from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings - - -class Florence2ImagePixelInputs(TensorSchema): - """ - Dimensions: - - b: Batch size - - c: Number of channels (3) - - h: Height of the image - - w: Width of the image - """ - - type: Literal["pixel_values"] - - data: Annotated[ - 
torch.Tensor, - TensorShape("b", 3, "h", "w"), - ] - - -# ViT implementation are all copied from -# https://huggingface.co/microsoft/Florence-2-base/blob/main/modeling_florence2.py -class LearnedAbsolutePositionEmbedding2D(nn.Module): - """ - This module learns positional embeddings up to a fixed maximum size. - """ - - def __init__(self, embedding_dim=256, num_pos=50): - super().__init__() - self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2) - self.column_embeddings = nn.Embedding( - num_pos, embedding_dim - (embedding_dim // 2)) - - def forward(self, pixel_values): - """ - pixel_values: (batch_size, height, width, num_channels) - returns: (batch_size, height, width, embedding_dim * 2) - """ - if len(pixel_values.shape) != 4: - raise ValueError('pixel_values must be a 4D tensor') - height, width = pixel_values.shape[1:3] - width_values = torch.arange(width, device=pixel_values.device) - height_values = torch.arange(height, device=pixel_values.device) - x_emb = self.column_embeddings(width_values) - y_emb = self.row_embeddings(height_values) - # (height, width, embedding_dim * 2) - pos = torch.cat([ - x_emb.unsqueeze(0).repeat(height, 1, 1), - y_emb.unsqueeze(1).repeat(1, width, 1) - ], - dim=-1) - # (embedding_dim * 2, height, width) - pos = pos.permute(2, 0, 1) - pos = pos.unsqueeze(0) - # (batch_size, embedding_dim * 2, height, width) - pos = pos.repeat(pixel_values.shape[0], 1, 1, 1) - # (batch_size, height, width, embedding_dim * 2) - pos = pos.permute(0, 2, 3, 1) - return pos - - -class PositionalEmbeddingCosine1D(nn.Module): - """ - This class implements a very simple positional encoding. It follows closely - the encoder from the link below: - https://pytorch.org/tutorials/beginner/translation_transformer.html - Args: - embed_dim: The dimension of the embeddings. - dropout_prob: The dropout probability. - max_seq_len: The maximum length to precompute the positional encodings. - """ - - def __init__(self, embed_dim: int = 512, max_seq_len: int = 1024) -> None: - super().__init__() - self.embed_dim = embed_dim - self.max_seq_len = max_seq_len - # Generate the sinusoidal arrays. - factor = math.log(10000) - denominator = torch.exp(-factor * torch.arange(0, self.embed_dim, 2) / - self.embed_dim) - # Matrix where rows correspond to a positional embedding as a function - # of the position index (i.e., the row index). - frequencies = \ - torch.arange(0, self.max_seq_len) \ - .reshape(self.max_seq_len, 1) * denominator - pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim)) - # Populate uneven entries. - pos_idx_to_embed[:, 0::2] = torch.sin(frequencies) - pos_idx_to_embed[:, 1::2] = torch.cos(frequencies) - # Save the positional embeddings in a constant buffer. - # self.register_buffer("pos_idx_to_embed", pos_idx_to_embed) - self.pos_idx_to_embed = nn.Parameter(pos_idx_to_embed, - requires_grad=False) - - def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor: - """ - Args: - seq_embeds: The sequence embeddings in order. Allowed size: - 1. [T, D], where T is the length of the sequence, and D is the - frame embedding dimension. - 2. [B, T, D], where B is the batch size and T and D are the - same as above. - Returns a tensor of with the same dimensions as the input: i.e., - [1, T, D] or [T, D]. - """ - shape_len = len(seq_embeds.shape) - assert 2 <= shape_len <= 3 - len_seq = seq_embeds.size(-2) - assert len_seq <= self.max_seq_len - pos_embeds = self.pos_idx_to_embed[0:seq_embeds.size(-2), :] - # Adapt pre-computed positional embeddings to the input. 
- if shape_len == 3: - pos_embeds = pos_embeds.view( - (1, pos_embeds.size(0), pos_embeds.size(1))) - return pos_embeds - - -class MySequential(nn.Sequential): - - def forward(self, *inputs): - for module in self._modules.values(): - if isinstance(inputs, tuple): - inputs = module(*inputs) - else: - inputs = module(inputs) - return inputs - - -class PreNorm(nn.Module): - - def __init__(self, norm, fn): - super().__init__() - self.norm = norm - self.fn = fn - - def forward(self, x, *args, **kwargs): - shortcut = x - if self.norm is not None: - x, size = self.fn(self.norm(x), *args, **kwargs) - else: - x, size = self.fn(x, *args, **kwargs) - - x = shortcut + x - - return x, size - - -class Mlp(nn.Module): - - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - ): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.net = nn.Sequential( - OrderedDict([("fc1", nn.Linear(in_features, hidden_features)), - ("act", act_layer()), - ("fc2", nn.Linear(hidden_features, out_features))])) - - def forward(self, x, size): - return self.net(x), size - - -class DepthWiseConv2d(nn.Module): - - def __init__( - self, - dim_in, - kernel_size, - padding, - stride, - bias=True, - ): - super().__init__() - self.dw = nn.Conv2d(dim_in, - dim_in, - kernel_size=kernel_size, - padding=padding, - groups=dim_in, - stride=stride, - bias=bias) - - def forward(self, x, size): - B, N, C = x.shape - H, W = size - assert N == H * W - - x = self.dw(x.transpose(1, 2).view(B, C, H, W)) - size = (x.size(-2), x.size(-1)) - x = x.flatten(2).transpose(1, 2) - return x, size - - -class ConvEmbed(nn.Module): - """ Image to Patch Embedding - """ - - def __init__(self, - patch_size=7, - in_chans=3, - embed_dim=64, - stride=4, - padding=2, - norm_layer=None, - pre_norm=True): - super().__init__() - self.patch_size = patch_size - - self.proj = nn.Conv2d(in_chans, - embed_dim, - kernel_size=patch_size, - stride=stride, - padding=padding) - - dim_norm = in_chans if pre_norm else embed_dim - self.norm = norm_layer(dim_norm) if norm_layer else None - - self.pre_norm = pre_norm - - def forward(self, x, size): - H, W = size - if len(x.size()) == 3: - if self.norm and self.pre_norm: - x = self.norm(x) - x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W) - - x = self.proj(x) - - _, _, H, W = x.shape - x = rearrange(x, 'b c h w -> b (h w) c') - if self.norm and not self.pre_norm: - x = self.norm(x) - - return x, (H, W) - - -class ChannelAttention(nn.Module): - - def __init__(self, dim, groups=8, qkv_bias=True): - super().__init__() - - self.groups = groups - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.proj = nn.Linear(dim, dim) - - def forward(self, x, size): - B, N, C = x.shape - - qkv = self.qkv(x).reshape(B, N, 3, self.groups, - C // self.groups).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] - - q = q * (float(N)**-0.5) - attention = q.transpose(-1, -2) @ k - attention = attention.softmax(dim=-1) - x = (attention @ v.transpose(-1, -2)).transpose(-1, -2) - x = x.transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - return x, size - - -class ChannelBlock(nn.Module): - - def __init__(self, - dim, - groups, - mlp_ratio=4., - qkv_bias=True, - drop_path_rate=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - conv_at_attn=True, - conv_at_ffn=True): - super().__init__() - - self.conv1 = PreNorm(None, DepthWiseConv2d( - dim, 3, 1, 1)) if conv_at_attn else None - self.channel_attn = PreNorm( - 
norm_layer(dim), - ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias), - ) - self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, - 1)) if conv_at_ffn else None - self.ffn = PreNorm( - norm_layer(dim), - Mlp(in_features=dim, - hidden_features=int(dim * mlp_ratio), - act_layer=act_layer), - ) - - def forward(self, x, size): - if self.conv1: - x, size = self.conv1(x, size) - x, size = self.channel_attn(x, size) - - if self.conv2: - x, size = self.conv2(x, size) - x, size = self.ffn(x, size) - - return x, size - - -def window_partition(x, window_size: int): - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, - C) - windows = x.permute(0, 1, 3, 2, 4, - 5).contiguous().view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int): - B = batch_size - - x = windows.view(B, H // window_size, W // window_size, window_size, - window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -class WindowAttention(nn.Module): - - def __init__(self, dim, num_heads, window_size, qkv_bias=True): - - super().__init__() - self.dim = dim - self.window_size = window_size - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = float(head_dim)**-0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.proj = nn.Linear(dim, dim) - - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, size): - - H, W = size - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - - x = x.view(B, H, W, C) - - pad_l = pad_t = 0 - pad_r = (self.window_size - W % self.window_size) % self.window_size - pad_b = (self.window_size - H % self.window_size) % self.window_size - x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) - _, Hp, Wp, _ = x.shape - - x = window_partition(x, self.window_size) - x = x.view(-1, self.window_size * self.window_size, C) - - # W-MSA/SW-MSA - # attn_windows = self.attn(x_windows) - - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, - C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - attn = self.softmax(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - - # merge windows - x = x.view(-1, self.window_size, self.window_size, C) - x = window_reverse(x, B, self.window_size, Hp, Wp) - - if pad_r > 0 or pad_b > 0: - x = x[:, :H, :W, :].contiguous() - - x = x.view(B, H * W, C) - - return x, size - - -class SpatialBlock(nn.Module): - - def __init__(self, - dim, - num_heads, - window_size, - mlp_ratio=4., - qkv_bias=True, - drop_path_rate=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - conv_at_attn=True, - conv_at_ffn=True): - super().__init__() - - self.conv1 = PreNorm(None, DepthWiseConv2d( - dim, 3, 1, 1)) if conv_at_attn else None - self.window_attn = PreNorm( - norm_layer(dim), - WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias), - ) - self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, - 1)) if conv_at_ffn else None - self.ffn = PreNorm( - norm_layer(dim), - Mlp(in_features=dim, - hidden_features=int(dim * mlp_ratio), - act_layer=act_layer), - ) - - def forward(self, x, size): - if self.conv1: - x, size = self.conv1(x, size) - x, size = self.window_attn(x, size) - - if self.conv2: - x, size = self.conv2(x, size) - x, size = self.ffn(x, size) - return x, size - - -class DaViT(nn.Module): - - def __init__( - self, - in_chans=3, - 
num_classes=1000, - depths=(1, 1, 3, 1), - patch_size=(7, 2, 2, 2), - patch_stride=(4, 2, 2, 2), - patch_padding=(3, 0, 0, 0), - patch_prenorm=(False, False, False, False), - embed_dims=(64, 128, 192, 256), - num_heads=(3, 6, 12, 24), - num_groups=(3, 6, 12, 24), - window_size=7, - mlp_ratio=4., - qkv_bias=True, - drop_path_rate=0.1, - norm_layer=nn.LayerNorm, - enable_checkpoint=False, - conv_at_attn=True, - conv_at_ffn=True, - ): - super().__init__() - - self.num_classes = num_classes - self.embed_dims = embed_dims - self.num_heads = num_heads - self.num_groups = num_groups - self.num_stages = len(self.embed_dims) - self.enable_checkpoint = enable_checkpoint - assert self.num_stages == len(self.num_heads) == len(self.num_groups) - - num_stages = len(embed_dims) - dpr = [ - x.item() for x in torch.linspace(0, drop_path_rate, - sum(depths) * 2) - ] - - depth_offset = 0 - convs = [] - blocks = [] - for i in range(num_stages): - conv_embed = ConvEmbed( - patch_size=patch_size[i], - stride=patch_stride[i], - padding=patch_padding[i], - in_chans=in_chans if i == 0 else self.embed_dims[i - 1], - embed_dim=self.embed_dims[i], - norm_layer=norm_layer, - pre_norm=patch_prenorm[i]) - convs.append(conv_embed) - - block = MySequential(*[ - MySequential( - OrderedDict([('spatial_block', - SpatialBlock( - embed_dims[i], - num_heads[i], - window_size, - drop_path_rate=dpr[depth_offset + j * 2], - qkv_bias=qkv_bias, - mlp_ratio=mlp_ratio, - conv_at_attn=conv_at_attn, - conv_at_ffn=conv_at_ffn, - )), - ('channel_block', - ChannelBlock( - embed_dims[i], - num_groups[i], - drop_path_rate=dpr[depth_offset + j * 2 + - 1], - qkv_bias=qkv_bias, - mlp_ratio=mlp_ratio, - conv_at_attn=conv_at_attn, - conv_at_ffn=conv_at_ffn, - ))])) for j in range(depths[i]) - ]) - blocks.append(block) - depth_offset += depths[i] * 2 - - self.convs = nn.ModuleList(convs) - self.blocks = nn.ModuleList(blocks) - - self.avgpool = nn.AdaptiveAvgPool1d(1) - - @property - def dim_out(self): - return self.embed_dims[-1] - - def forward_features_unpool(self, x): - """ - forward until avg pooling - Args: - x (_type_): input image tensor - """ - input_size = (x.size(2), x.size(3)) - for conv, block in zip(self.convs, self.blocks): - x, input_size = conv(x, input_size) - x, input_size = block(x, input_size) - return x - - def forward_features(self, x): - x = self.forward_features_unpool(x) - - # (batch_size, num_tokens, token_dim) - x = self.avgpool(x.transpose(1, 2)) - # (batch_size, 1, num_tokens) - x = torch.flatten(x, 1) - x = self.norms(x) - - return x - - def forward(self, x): - x = self.forward_features(x) - x = self.head(x) - return x - - @classmethod - def from_config(cls, config): - return cls( - depths=config.depths, - embed_dims=config.dim_embed, - num_heads=config.num_heads, - num_groups=config.num_groups, - patch_size=config.patch_size, - patch_stride=config.patch_stride, - patch_padding=config.patch_padding, - patch_prenorm=config.patch_prenorm, - drop_path_rate=config.drop_path_rate, - window_size=config.window_size, - ) - - -# Language backbone and processor implementation -class Florence2LanguageModel(nn.Module): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - - self.config = config - - self.vocab_size = config.vocab_size - - self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model) - self.encoder = BartEncoder(config, - 
cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.encoder") - self.decoder = BartDecoder(config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.decoder") - - if self.config.tie_word_embeddings: - self.encoder.embed_tokens.weight = self.shared.weight - self.decoder.embed_tokens.weight = self.shared.weight - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - r""" - Args: - input_ids - Indices of *decoder* input sequence tokens in the vocabulary. - Padding will be ignored by default should you - provide it. - positions - Positions of *decoder* input sequence tokens. - encoder_input_ids - Indices of *encoder* input sequence tokens in the vocabulary. - encoder_positions: - Positions of *encoder* input sequence tokens. - Returns: - Model output torch.Tensor - """ - - encoder_hidden_states = None - - if ((inputs_embeds is not None and inputs_embeds.numel() > 0) - or encoder_input_ids.numel() > 0): - # Run encoder attention if a non-zero number of encoder tokens - # are provided as input - encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, - positions=encoder_positions, - inputs_embeds=inputs_embeds) - - # decoder outputs consists of - # (dec_features, past_key_value, dec_hidden, dec_attn) - decoder_outputs = self.decoder( - decoder_input_ids=input_ids, - decoder_positions=positions, - encoder_hidden_states=encoder_hidden_states) - - return decoder_outputs - - -class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - - self.config = config - self.model = Florence2LanguageModel(vllm_config=vllm_config, - prefix=f"{prefix}.model") - embed_scale = math.sqrt( - config.d_model) if config.scale_embedding else 1.0 - - self.vocab_size = config.vocab_size - self.lm_head = BartParallelLMHead(self.vocab_size, - config.d_model, - embed_scale=embed_scale) - if self.config.tie_word_embeddings: - self.lm_head.tie_weights(self.model.shared) - - self.logits_processor = LogitsProcessor(self.vocab_size, - config.vocab_size) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - r""" - Args: - input_ids - torch.Tensor of *decoder* input token ids. - positions - torch.Tensor of *decoder* position indices. - encoder_input_ids - torch.Tensor of *encoder* input token ids. 
- encoder_positions - torch.Tensor of *encoder* position indices - Returns: - Output torch.Tensor - """ - - return self.model(input_ids, - positions, - encoder_input_ids, - encoder_positions, - inputs_embeds=inputs_embeds) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.encoder.embed_tokens(input_ids) - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - if "final_logits_bias" in name: - continue - if self.config.tie_word_embeddings and ("embed_tokens" in name - or "lm_head" in name): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -class Florence2ProcessingInfo(BaseProcessingInfo): - - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": 1} - - def get_num_image_tokens(self) -> int: - processor_config = self.ctx.get_hf_image_processor_config() - return processor_config["image_seq_length"] - - -class Florence2DummyInputsBuilder( - BaseDummyInputsBuilder[Florence2ProcessingInfo]): - - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: - return "" - - def get_dummy_mm_data( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> MultiModalDataDict: - num_images = mm_counts.get("image", 0) - - target_width = target_height = self.info.get_hf_config().projection_dim - - return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) - } - - -class Florence2MultiModalProcessor( - EncDecMultiModalProcessor[Florence2ProcessingInfo]): - - def _hf_processor_applies_updates( - self, - prompt_text: str, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - tokenization_kwargs: Mapping[str, object], - ) -> bool: - return False - - def create_encoder_prompt( - self, - prompt: Union[str, list[int]], - mm_data: MultiModalDataDict, - ) -> Union[str, list[int]]: - return prompt - - def create_decoder_prompt( - self, - prompt: Union[str, list[int]], - mm_data: MultiModalDataDict, - ) -> Union[str, list[int]]: - return [self.info.get_hf_config().eos_token_id] - - def _apply_hf_processor_tokens_only( - self, - prompt_tokens: list[int], - ) -> list[int]: - hf_processor = self.info.get_hf_processor() - tokenizer: BartTokenizer = hf_processor.tokenizer - prompt_text = tokenizer.decode(prompt_tokens) - # convert task tokens to prompt - prompt_text = hf_processor._construct_prompts([prompt_text])[0] - prompt_tokens = tokenizer.encode(prompt_text, add_special_tokens=False) - return prompt_tokens - - def _call_hf_processor( - self, - prompt: str, - mm_data: 
Mapping[str, object], - mm_kwargs: Mapping[str, object], - tok_kwargs: Mapping[str, object], - ) -> BatchFeature: - if mm_data: - processed_outputs = super()._call_hf_processor( - prompt, mm_data, mm_kwargs, tok_kwargs) - else: - hf_processor = self.info.get_hf_processor() - tokenizer = hf_processor.tokenizer - prompt = hf_processor._construct_prompts([prompt])[0] - processed_outputs = tokenizer(prompt, - add_special_tokens=True, - return_tensors="pt") - return processed_outputs - - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - return dict(pixel_values=MultiModalFieldConfig.batched("image")) - - def _get_prompt_updates( - self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargsItems, - ) -> Sequence[PromptUpdate]: - hf_config = self.info.get_hf_config() - pad_token_id = hf_config.pad_token_id - num_image_tokens = self.info.get_num_image_tokens() - image_tokens = [pad_token_id] * num_image_tokens - - return [ - PromptInsertion( - modality="image", - target=PromptIndexTargets.start(), - insertion=image_tokens, - ) - ] - - -@MULTIMODAL_REGISTRY.register_processor( - Florence2MultiModalProcessor, - info=Florence2ProcessingInfo, - dummy_inputs=Florence2DummyInputsBuilder) -class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsV0Only): - - @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: - if modality.startswith("image"): - return None - - raise ValueError("Only image modality is supported") - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - processor_config = vllm_config.model_config.hf_image_processor_config - - self.config = config - self.vision_config = config.vision_config - self.processor_config = processor_config - assert config.vision_config.model_type == 'davit', ( - 'only DaViT is supported for now') - self.vision_tower = DaViT.from_config(config=config.vision_config) - self._build_image_projection_layers(config) - self.language_model = Florence2LanguageForConditionalGeneration( - vllm_config=vllm_config.with_hf_config(config.text_config), - prefix=f"{prefix}.language_model", - ) - self.pad_token_id = config.pad_token_id - - def _build_image_projection_layers(self, config: PretrainedConfig): - image_dim_out = config.vision_config.dim_embed[-1] - dim_projection = config.vision_config.projection_dim - self.image_projection = nn.Parameter( - torch.empty(image_dim_out, dim_projection)) - self.image_proj_norm = nn.LayerNorm(dim_projection) - image_pos_embed_config = config.vision_config.image_pos_embed - if image_pos_embed_config['type'] == 'learned_abs_2d': - self.image_pos_embed = LearnedAbsolutePositionEmbedding2D( - embedding_dim=image_dim_out, - num_pos=image_pos_embed_config['max_pos_embeddings']) - else: - raise NotImplementedError("Florence2 only supports learned_abs_2d " - "as image position embedding.") - - self.image_feature_source = config.vision_config.image_feature_source - - # temporal embedding - visual_temporal_embedding_config = ( - self.vision_config.visual_temporal_embedding) - if visual_temporal_embedding_config['type'] == 'COSINE': - self.visual_temporal_embed = PositionalEmbeddingCosine1D( - embed_dim=image_dim_out, - max_seq_len=visual_temporal_embedding_config[ - 'max_temporal_embeddings']) - else: - raise NotImplementedError( - 'Florence2 only 
supports COSINE as temporal embedding.') - - def _parse_and_validate_image_input(self, **kwargs: object): - pixel_values: Optional[Union[list[list[torch.Tensor]], - list[torch.Tensor], - torch.Tensor]] = kwargs.pop( - "pixel_values", None) - image_embeds: Optional[Union[list[list[torch.Tensor]], - list[torch.Tensor], - torch.Tensor]] = kwargs.pop( - "image_embeds", None) - - if pixel_values is None and image_embeds is None: - return None - - if pixel_values is not None and image_embeds is not None: - raise ValueError( - "Both pixel values and image embeds are provided.") - - if pixel_values is not None: - size = self.processor_config["size"] - expected_h, expected_w = size["height"], size["width"] - - return Florence2ImagePixelInputs( - type="pixel_values", - data=flatten_bn(pixel_values, concat=True), - resolve_bindings={ - "h": expected_h, - "w": expected_w - }, - ) - - if image_embeds is not None: - raise NotImplementedError - - raise AssertionError("This line should be unreachable.") - - def _encode_image(self, pixel_values: torch.Tensor) -> torch.Tensor: - dtype = next(self.vision_tower.parameters()).dtype - pixel_values = pixel_values.to(dtype) - - batch_size, T = pixel_values.size(0), 1 - x = self.vision_tower.forward_features_unpool(pixel_values) - if self.image_pos_embed is not None: - x = x.view(batch_size * T, -1, x.shape[-1]) - num_tokens = x.shape[-2] - h, w = int(num_tokens**0.5), int(num_tokens**0.5) - assert h * w == num_tokens, ( - 'only support square feature maps for now') - x = x.view(batch_size * T, h, w, x.shape[-1]) - pos_embed = self.image_pos_embed(x) - x = x + pos_embed - x = x.view(batch_size, T * h * w, x.shape[-1]) - - if self.visual_temporal_embed is not None: - visual_temporal_embed = self.visual_temporal_embed( - x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]) - x = x.view(batch_size, T, -1, - x.shape[-1]) + visual_temporal_embed.view( - 1, T, 1, x.shape[-1]) - - x_feat_dict = {} - - spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2) - x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x - - temporal_avg_pool_x = x.view(batch_size, T, -1, - x.shape[-1]).mean(dim=1) - x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x - - x = x.view(batch_size, T, -1, x.shape[-1])[:, -1] - x_feat_dict['last_frame'] = x - - new_x = [] - for _image_feature_source in self.image_feature_source: - if _image_feature_source not in x_feat_dict: - raise ValueError('invalid image feature source: {}'.format( - _image_feature_source)) - new_x.append(x_feat_dict[_image_feature_source]) - - x = torch.cat(new_x, dim=1) - - x = x @ self.image_projection - x = self.image_proj_norm(x) - - return x - - def _process_image_input( - self, image_input: Florence2ImagePixelInputs) -> torch.Tensor: - assert image_input["type"] == "pixel_values" - pixel_values = image_input["data"] - return self._encode_image(pixel_values) - - def get_language_model(self) -> torch.nn.Module: - return self.language_model - - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: - image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is None: - return [] - vision_embeddings = self._process_image_input(image_input) - return vision_embeddings - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - 
inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.pad_token_id) - return inputs_embeds - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - *, - encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor, - **kwargs, - ) -> torch.Tensor: - r""" - Args: - input_ids - torch.Tensor of *decoder* input token ids. - positions - torch.Tensor of *decoder* position indices. - encoder_input_ids - torch.Tensor of *encoder* input token ids. - encoder_positions - torch.Tensor of *encoder* position indices - Returns: - Output torch.Tensor - """ - vision_embeddings = self.get_multimodal_embeddings(**kwargs) - if encoder_input_ids.numel() > 0 or vision_embeddings is not None: - inputs_embeds = self.get_input_embeddings(encoder_input_ids, - vision_embeddings) - else: - inputs_embeds = None - - hidden_states = self.language_model(input_ids, - positions, - encoder_input_ids, - encoder_positions, - inputs_embeds=inputs_embeds) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) - return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index 410c715d5241b..1263e3049a14a 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -24,7 +24,7 @@ import torch.nn.functional as F from torch import nn from transformers import Gemma3TextConfig -from vllm.attention import Attention +from vllm.attention import Attention, AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -44,6 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from ...attention.layers.encoder_only_attention import EncoderOnlyAttention from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter, @@ -169,16 +170,24 @@ class Gemma3Attention(nn.Module): rope_scaling=self.rope_scaling, ) - # Initialize the attention. 
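# NOTE: illustrative sketch, not part of this patch. The hunk below keys the
# Gemma3 attention class off an `is_causal` flag read from the config with
# getattr, so decoder-only checkpoints (which do not define it) keep causal
# attention while bidirectional/pooling conversions get EncoderOnlyAttention.
# The classes are the ones this diff imports (the relative
# EncoderOnlyAttention import is shown here as its resolved absolute path);
# the helper function itself is hypothetical.
from vllm.attention import Attention, AttentionType
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention


def pick_attention(config):
    # Plain Gemma3 text configs have no `is_causal`, so default to DECODER.
    if getattr(config, "is_causal", True):
        return Attention, AttentionType.DECODER
    return EncoderOnlyAttention, AttentionType.ENCODER_ONLY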
- self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - logits_soft_cap=attn_logits_soft_cap, - per_layer_sliding_window=sliding_window, - prefix=f"{prefix}.attn") + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + + attn_cls = (EncoderOnlyAttention + if attn_type == AttentionType.ENCODER_ONLY else Attention) + + self.attn = attn_cls(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + attn_type=attn_type, + logits_soft_cap=attn_logits_soft_cap, + per_layer_sliding_window=sliding_window, + prefix=f"{prefix}.attn") def forward( self, diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index f3dc7dde46bdf..bee9fbd2c084a 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -512,7 +512,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, architectures=["Gemma3ForCausalLM"], ) logit_scale = getattr(config, "logit_scale", 1.0) - self.language_model.logits_processor.scale *= logit_scale + + if hasattr(self.language_model, "logits_processor"): + # The logits processor can be unset if we're using + # automatic conversion to pooling model. + self.language_model.logits_processor.scale *= logit_scale self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) @@ -684,7 +688,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask) global_attn_masks.append(global_attn_mask) - if (sliding_window := self.config.sliding_window) is not None: + sliding_window = self.config.text_config.sliding_window + if sliding_window is not None: # Create a local causal mask with sliding window (1024). local_attn_mask = torch.ones_like(global_attn_mask) local_attn_mask = torch.tril(local_attn_mask, diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py index c25bbcd420c39..8d3079aee0dfb 100644 --- a/vllm/model_executor/models/gemma3n_mm.py +++ b/vllm/model_executor/models/gemma3n_mm.py @@ -5,6 +5,7 @@ from typing import Any, Literal, Optional, TypedDict, Union, cast import numpy as np import torch +# yapf: disable from torch import nn from transformers import AutoModel, BatchFeature from transformers.models.gemma3n import (Gemma3nAudioConfig, @@ -30,7 +31,6 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems) from vllm.multimodal.parse import (ImageProcessorItems, MultiModalDataItems, MultiModalDataParser) -# yapf: disable from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, MultiModalPromptUpdates, @@ -62,7 +62,8 @@ class Gemma3nImagePixelInputs(TypedDict): class Gemma3nAudioInputs(TypedDict): - input_features: torch.Tensor + input_features: Union[torch.Tensor, list[torch.Tensor]] + input_features_padded: torch.Tensor """Shape: `(batch_size * num_audio, seq_length, num_features)`""" input_features_mask: torch.Tensor """Shape: `(batch_size * num_audio, seq_length)`""" @@ -178,7 +179,7 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo] ) -> BatchFeature: # HF Transformers audio processor no longer accepts `audios` key. 
- # We pop `audios` and replace it with `audio` key to surpress + # We pop `audios` and replace it with `audio` key to suppress # the warning. if 'audios' in mm_data: mm_data['audio'] = mm_data.pop('audios') @@ -188,8 +189,13 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo] mm_kwargs, tok_kwargs, ) + if 'input_features' in processed_outputs: - # Avoid padding since we need the output of each item to be + # Padding enables audio_tower to run in batched mode + processed_outputs["input_features_padded"] = \ + processed_outputs["input_features"] + + # Unpad features here since we need the output of each item to be # independent of other items for the cache to work correctly unpadded_features = [ f[mask] for f, mask in zip( @@ -206,9 +212,11 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo] hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return dict(pixel_values=MultiModalFieldConfig.batched("image"), - input_features=MultiModalFieldConfig.batched("audio"), - input_features_mask=MultiModalFieldConfig.batched("audio")) + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + input_features=MultiModalFieldConfig.batched("audio"), + input_features_padded=MultiModalFieldConfig.batched("audio"), + input_features_mask=MultiModalFieldConfig.batched("audio")) def _get_prompt_updates( self, @@ -453,9 +461,6 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal, self.multimodal_config = multimodal_config self.vocab_size = config.text_config.vocab_size - self.sliding_window = getattr(config.text_config, - "interleaved_sliding_window", None) - self.vision_tower = AutoModel.from_config(config=config.vision_config) self.audio_tower = AutoModel.from_config(config=config.audio_config) self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config, @@ -516,9 +521,14 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal, if input_features_mask is None: return None + input_features_padded = kwargs.pop("input_features_padded", None) + if input_features_padded is None: + return None + return Gemma3nAudioInputs( input_features=input_features, input_features_mask=input_features_mask, + input_features_padded=input_features_padded, ) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: @@ -564,7 +574,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal, audio_input: Gemma3nAudioInputs, ) -> list[torch.Tensor]: assert self.audio_tower is not None - input_features = audio_input["input_features"].squeeze(1) + # Run on padded features to enable batching + input_features = audio_input["input_features_padded"].squeeze(1) input_features_mask = audio_input["input_features_mask"].squeeze(1) audio_outputs, audio_mask = self.audio_tower(input_features, ~input_features_mask) @@ -572,10 +583,10 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal, # ruff: noqa # The Gemma3nProcessor expects all audio will be 30s in length and inserts 188 audio soft tokens into the - # text to account for this. However, the audio preprocessing and encoder do not gurarantee they will + # text to account for this. However, the audio preprocessing and encoder do not guarantee they will # produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens # depending on the length of the longest audio input in the batch. 
When we encounter this situation, we pad - # the audio feature out to 188 soft tokens with the emebedding of the last token in the embed_audio vocab. + # the audio feature out to 188 soft tokens with the embedding of the last token in the embed_audio vocab. # TODO precompute and cache padding audio_padding_toks = torch.tensor([[self.vocab_size - 1]], dtype=torch.long, diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index f9fd5163d66b4..cbf327ce02b6b 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -36,7 +36,9 @@ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange +from packaging.version import Version from transformers import BatchFeature +from transformers import __version__ as TRANSFORMERS_VERSION from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig from transformers.models.glm4v.image_processing_glm4v import ( Glm4vImageProcessor, smart_resize) @@ -44,6 +46,7 @@ from transformers.models.glm4v.video_processing_glm4v import ( Glm4vVideoProcessor) from transformers.video_utils import VideoMetadata +from vllm.attention.layer import check_upstream_fa_availability from vllm.config import VllmConfig from vllm.distributed import (get_tensor_model_parallel_world_size, parallel_state) @@ -51,14 +54,10 @@ from vllm.distributed import utils as dist_utils from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.layernorm import RMSNorm -# yapf: disable from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, - MergedReplicatedLinear, QKVParallelLinear, - ReplicatedLinear, RowParallelLinear) -# yapf: enable from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys @@ -174,20 +173,22 @@ class Glm4vVisionMLP(nn.Module): use_data_parallel: bool = False, ): super().__init__() - cls_gate_up = (MergedReplicatedLinear - if use_data_parallel else MergedColumnParallelLinear) - self.gate_up_proj = cls_gate_up(input_size=in_features, - output_sizes=[hidden_features] * 2, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - cls_down = (ReplicatedLinear - if use_data_parallel else RowParallelLinear) - self.down_proj = cls_down(hidden_features, - in_features, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + self.gate_up_proj = MergedColumnParallelLinear( + input_size=in_features, + output_sizes=[hidden_features] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + disable_tp=use_data_parallel, + ) + self.down_proj = RowParallelLinear( + hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + disable_tp=use_data_parallel, + ) self.act_fn = SiluAndMul() def forward(self, x: torch.Tensor): @@ -234,51 +235,43 @@ class Glm4vVisionAttention(nn.Module): # Per attention head and per partition values. 
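# NOTE: illustrative sketch, not part of this patch. Throughout the GLM-4.1V
# vision tower this diff removes the "ReplicatedLinear when data-parallel,
# *ParallelLinear otherwise" class switch and instead builds a single parallel
# layer with `disable_tp=use_data_parallel`. The layer classes and the
# disable_tp keyword are taken from this diff; the toy module below is a
# hypothetical, hedged sketch of the same pattern.
from torch import nn

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)


class ToyGatedMLP(nn.Module):

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 use_data_parallel: bool = False,
                 prefix: str = ""):
        super().__init__()
        # With disable_tp=True both projections behave like replicated
        # linears, so no separate data-parallel code path is needed.
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=False,
            prefix=f"{prefix}.gate_up_proj",
            disable_tp=use_data_parallel,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            prefix=f"{prefix}.down_proj",
            disable_tp=use_data_parallel,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        out, _ = self.down_proj(self.act_fn(gate_up))
        return out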
self.tp_size = (1 if use_data_parallel else get_tensor_model_parallel_world_size()) - self.tp_rank = parallel_state.get_tensor_model_parallel_rank() + self.tp_rank = (0 if use_data_parallel else + parallel_state.get_tensor_model_parallel_rank()) self.hidden_size_per_attention_head = dist_utils.divide( projection_size, num_heads) self.num_attention_heads_per_partition = dist_utils.divide( num_heads, self.tp_size) - if use_data_parallel: - self.qkv = ReplicatedLinear( - input_size=embed_dim, - output_size=3 * projection_size, - bias=False, - quant_config=quant_config, - # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg - prefix=f"{prefix}.qkv_proj" - if quant_config else f"{prefix}.qkv", - ) - self.proj = ReplicatedLinear( - input_size=projection_size, - output_size=embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.proj", - bias=False, - ) - else: - self.qkv = QKVParallelLinear( - hidden_size=embed_dim, - head_size=self.hidden_size_per_attention_head, - total_num_heads=num_heads, - total_num_kv_heads=num_heads, - bias=False, - quant_config=quant_config, - # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg - prefix=f"{prefix}.qkv_proj" - if quant_config else f"{prefix}.qkv", - ) - self.proj = RowParallelLinear( - input_size=projection_size, - output_size=embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.proj", - bias=False, - ) + self.qkv = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.hidden_size_per_attention_head, + total_num_heads=num_heads, + total_num_kv_heads=num_heads, + bias=False, + quant_config=quant_config, + # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg + prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv", + disable_tp=use_data_parallel, + ) + self.proj = RowParallelLinear( + input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + bias=False, + disable_tp=use_data_parallel, + ) # Detect attention implementation. 
- self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=self.hidden_size_per_attention_head, + dtype=torch.get_default_dtype()) + self.use_upstream_fa = False + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability(torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN + self.use_upstream_fa = True + if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, @@ -290,23 +283,10 @@ class Glm4vVisionAttention(nn.Module): def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # [s, b, 3 * head * head_dim] seq_len, bs, _ = qkv.shape - if self.tp_size > 1: - qkv = all_gather_interleave(qkv, self.qkv.hidden_size, - self.tp_size) # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] q, k, v = qkv.chunk(3, dim=2) - # 3 * [s, b, head * head_dim] - if self.tp_size > 1: - splitter = partial( - dist_utils.split_tensor_along_last_dim, - num_partitions=self.tp_size, - ) - q = splitter(q)[self.tp_rank] - k = splitter(k)[self.tp_rank] - v = splitter(v)[self.tp_rank] - # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] new_shape = ( seq_len, @@ -341,7 +321,10 @@ class Glm4vVisionAttention(nn.Module): if self.attn_backend == _Backend.FLASH_ATTN: # from vllm_flash_attn.flash_attn_interface import ( # flash_attn_varlen_func) - from flash_attn import flash_attn_varlen_func + if self.use_upstream_fa: + from flash_attn import flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) @@ -438,15 +421,16 @@ class Glm4vVisionBlock(nn.Module): max_seqlen: Optional[int] = None, # Only used for Flash Attention seqlens: Optional[list[int]] = None, # Only used for xFormers ) -> torch.Tensor: - x = x + self.attn( + x_attn = self.attn( self.norm1(x), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, max_seqlen=max_seqlen, seqlens=seqlens, ) + x_fused_norm, residual = self.norm2(x, residual=x_attn) + x = residual + self.mlp(x_fused_norm) - x = x + self.mlp(self.norm2(x)) return x @@ -494,41 +478,31 @@ class Glm4vPatchMerger(nn.Module): ) -> None: super().__init__() self.hidden_size = d_model - if use_data_parallel: - self.proj = ReplicatedLinear( - input_size=self.hidden_size, - output_size=self.hidden_size, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.proj", - ) - else: - self.proj = ColumnParallelLinear( - self.hidden_size, - self.hidden_size, - bias=bias, - gather_output=True, - quant_config=quant_config, - prefix=f"{prefix}.proj", - ) + self.proj = ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=bias, + gather_output=True, + quant_config=quant_config, + prefix=f"{prefix}.proj", + disable_tp=use_data_parallel, + ) self.post_projection_norm = nn.LayerNorm(self.hidden_size) - cls_gate_up = (MergedReplicatedLinear - if use_data_parallel else MergedColumnParallelLinear) - self.gate_up_proj = cls_gate_up( + self.gate_up_proj = MergedColumnParallelLinear( input_size=self.hidden_size, output_sizes=[context_dim] * 2, bias=bias, quant_config=quant_config, prefix=f"{prefix}.gate_up_proj", + disable_tp=use_data_parallel, ) - cls_down = (ReplicatedLinear - if use_data_parallel else RowParallelLinear) - self.down_proj = cls_down( + self.down_proj = RowParallelLinear( context_dim, self.hidden_size, bias=bias, quant_config=quant_config, prefix=f"{prefix}.down_proj", + disable_tp=use_data_parallel, ) self.act_fn = SiluAndMul() 
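# NOTE: illustrative sketch, not part of this patch. The Glm4vVisionBlock
# hunk above rewrites
#     x = x + attn(norm1(x));  x = x + mlp(norm2(x))
# so the attention output flows through RMSNorm's fused add+norm path:
# norm2(x, residual=x_attn) adds the two tensors, normalizes the sum, and
# returns (normalized, summed) so the sum can serve as the next residual.
# Self-contained pure-torch stand-in for that (input, residual) -> (out,
# residual) contract; vLLM's RMSNorm kernel is not imported here.
import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight


def fused_add_rms_norm(x, residual, weight, eps: float = 1e-6):
    summed = x + residual                          # add first ...
    return rms_norm(summed, weight, eps), summed   # ... then normalize


if __name__ == "__main__":
    hidden = 8
    weight = torch.ones(hidden)
    x, x_attn = torch.randn(4, hidden), torch.randn(4, hidden)
    # fused path, as used by the patched block
    x_fused_norm, residual = fused_add_rms_norm(x, x_attn, weight)
    # unfused reference: residual add followed by a separate norm
    assert torch.allclose(x_fused_norm, rms_norm(x + x_attn, weight))
    assert torch.allclose(residual, x + x_attn)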
self.extra_activation_func = nn.GELU() @@ -756,7 +730,11 @@ class Glm4vVisionTransformer(nn.Module): self.post_layernorm = RMSNorm(vision_config.hidden_size, eps=vision_config.rms_norm_eps) - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, dtype=torch.get_default_dtype()) + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability(torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -1025,28 +1003,32 @@ class Glm4vProcessingInfo(BaseProcessingInfo): max_frame_idx = meta_frames - 1 duration = metadata.get("duration", round(max_frame_idx / video_fps) + 1) - if duration <= video_processor.max_duration: - n = int(math.floor(duration * video_processor.fps)) - frame_indices = [ - min( - max_frame_idx, - int(math.ceil(i * video_fps / video_processor.fps)), - ) for i in range(n) - ] + do_sample_frames = metadata["do_sample_frames"] + if not do_sample_frames: + frame_indices = metadata["frames_indices"] else: - num_samples = int(video_processor.max_duration * - video_processor.fps) - if num_samples >= meta_frames: - frame_indices = list(range(meta_frames)) - else: - target_seconds = np.linspace(0, - duration, - num_samples, - endpoint=True) + if duration <= video_processor.max_duration: + n = int(math.floor(duration * video_processor.fps)) frame_indices = [ - min(max_frame_idx, int(math.ceil(t * video_fps))) - for t in target_seconds + min( + max_frame_idx, + int(math.ceil(i * video_fps / video_processor.fps)), + ) for i in range(n) ] + else: + num_samples = int(video_processor.max_duration * + video_processor.fps) + if num_samples >= meta_frames: + frame_indices = list(range(meta_frames)) + else: + target_seconds = np.linspace(0, + duration, + num_samples, + endpoint=True) + frame_indices = [ + min(max_frame_idx, int(math.ceil(t * video_fps))) + for t in target_seconds + ] seen, uniq = set(), [] for idx in frame_indices: @@ -1064,6 +1046,43 @@ class Glm4vProcessingInfo(BaseProcessingInfo): selected_timestamps.append(timestamps_list[idx]) return selected_timestamps + def _construct_video_placeholder( + self, + video_array: np.ndarray, + metadata: dict[str, Any], + grid_thw: torch.Tensor, + ) -> str: + hf_processor = self.get_hf_processor() + tokenizer = self.get_tokenizer() + image_processor = hf_processor.image_processor + + hf_config = self.get_hf_config() + boi_token_id = hf_config.image_start_token_id + eoi_token_id = hf_config.image_end_token_id + bov_token_id = hf_config.video_start_token_id + eov_token_id = hf_config.video_end_token_id + merge_length = image_processor.merge_size**2 + + assert isinstance(grid_thw, torch.Tensor) + timestamps = self._get_video_second_idx(metadata, len(video_array)) + frames_idx_token = [ + tokenizer.encode(str(i), add_special_tokens=False) + for i in timestamps + ] + T, H, W = grid_thw + num_tokens_per_frame = int(H * W) // merge_length + placeholder = [] + placeholder.append(bov_token_id) + for frame_idx in frames_idx_token: + placeholder.append(boi_token_id) + placeholder.extend([hf_processor.video_token_id] * + num_tokens_per_frame) + placeholder.append(eoi_token_id) + placeholder.extend(frame_idx) + placeholder.append(eov_token_id) + + return placeholder + class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]): @@ -1126,7 +1145,9 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]): "fps": 2.0, "duration": num_frames / 2.0, "total_num_frames": 
num_frames, + "frames_indices": [i for i in range(num_frames)], "video_backend": "opencv", + "do_sample_frames": False, } video_item = (video.copy(), video_metadata) video_items.append(video_item) @@ -1159,48 +1180,56 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]): for item in mm_data.pop("videos", []): video_array, metadata = item - # FIXME(Isotr0py): Activate the below logic after we can disable - # resampling from video loader backend. - # assert metadata["total_num_frames"] == len(video_array), ( - # f"Total frames {metadata['total_num_frames']} does not " - # f"match the length of video array {len(video_array)}.") - - # NOTE: Temporary workaround for resampled videos. - # this can cause a divergence with HF implementation if - # the input video is resampled in advance. - - if metadata["total_num_frames"] != len(video_array): - logger.warning( - "Total frames in metadata " - "(%s) does not match the length of " - "video array %s. This can " - "be because the video is resampled " - "in advance. This may cause " - "a divergence with HF implementation.", - metadata["total_num_frames"], - len(video_array), - ) - metadata["total_num_frames"] = len(video_array) - metadata = VideoMetadata(**metadata) + # don't update mm_kwargs inplace + video_mm_kwargs = dict(**mm_kwargs) + video_mm_kwargs["do_sample_frames"] = metadata.get( + "do_sample_frames", True) video_mm_data = dict() video_mm_data["videos"] = [[video_array]] - video_mm_data["video_metadata"] = [[metadata]] + + # backward compatibility for Transformers 4.55 + unuse_metadata = ["do_sample_frames"] + if not hasattr( + VideoMetadata, + "frames_indices") and "frames_indices" in metadata: + unuse_metadata.append("frames_indices") + + video_mm_data["video_metadata"] = [[ + VideoMetadata( + **{ + k: metadata[k] + for k in metadata if k not in unuse_metadata + }) + ]] video_outputs = super()._call_hf_processor( prompt="<|begin_of_video|><|video|><|end_of_video|>", mm_data=video_mm_data, - mm_kwargs=mm_kwargs, + mm_kwargs=video_mm_kwargs, tok_kwargs=tok_kwargs, ) - input_ids = video_outputs.pop("input_ids") - input_ids[input_ids == processor.image_token_id] = ( - processor.video_token_id) - video_placeholder = processor.tokenizer.batch_decode( - input_ids)[0] + if not video_mm_kwargs["do_sample_frames"] and Version( + TRANSFORMERS_VERSION) < Version("4.56.0"): + # Transformers v4.55 has incorrect timestamps issue for + # skip sampling. We construct the placeholder manually to + # get placeholders with correct timestamps. 
+ placeholder = self.info._construct_video_placeholder( + video_array, + metadata, + video_outputs["video_grid_thw"].squeeze(0), + ) + video_placeholder = processor.tokenizer.decode(placeholder) + else: + input_ids = video_outputs.pop("input_ids") + input_ids[input_ids == processor.image_token_id] = ( + processor.video_token_id) + video_placeholder = processor.tokenizer.batch_decode( + input_ids)[0] prompt = prompt.replace( "<|begin_of_video|><|video|><|end_of_video|>", video_placeholder, + 1, ) video_grid_thw_lst.append(video_outputs["video_grid_thw"]) @@ -1243,14 +1272,6 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]): hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_processor = self.info.get_image_processor( **hf_processor_mm_kwargs) - tokenizer = self.info.get_tokenizer() - hf_config = self.info.get_hf_config() - - boi_token_id = hf_config.image_start_token_id - eoi_token_id = hf_config.image_end_token_id - - bov_token_id = hf_config.video_start_token_id - eov_token_id = hf_config.video_end_token_id merge_length = image_processor.merge_size**2 @@ -1268,21 +1289,8 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]): assert isinstance(grid_thw, torch.Tensor) video, metadata = mm_items["video"][item_idx] - timestamps = self.info._get_video_second_idx(metadata, len(video)) - frames_idx_token = [ - tokenizer.encode(str(i), add_special_tokens=False) - for i in timestamps - ] - num_tokens_per_frame = int(grid_thw[1:].prod()) // merge_length - placeholder = [] - placeholder.append(bov_token_id) - for frame_idx in frames_idx_token: - placeholder.append(boi_token_id) - placeholder.extend([hf_processor.video_token_id] * - num_tokens_per_frame) - placeholder.append(eoi_token_id) - placeholder.extend(frame_idx) - placeholder.append(eov_token_id) + placeholder = self.info._construct_video_placeholder( + video, metadata, grid_thw) return PromptUpdateDetails.select_token_id( placeholder, embed_token_id=hf_processor.video_token_id, @@ -1383,7 +1391,7 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, raise ValueError(f"{name} should be 2D or batched 3D tensor. 
" f"Got ndim: {mm_input.ndim} " f"(shape={mm_input.shape})") - return torch.concat(list(mm_input)) + return mm_input.reshape(-1, mm_input.shape[-1]) else: return torch.concat(mm_input) @@ -1457,6 +1465,7 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, self, image_input: Glm4vImageInputs) -> tuple[torch.Tensor, ...]: grid_thw = image_input["image_grid_thw"] assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() if image_input["type"] == "image_embeds": image_embeds = image_input["image_embeds"].type(self.visual.dtype) @@ -1471,13 +1480,15 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, image_embeds = self.visual(pixel_values, grid_thw=grid_thw.tolist()) merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size - return image_embeds.split(sizes.tolist()) + sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // + (merge_size * merge_size)).tolist() + return image_embeds.split(sizes) def _process_video_input( self, video_input: Glm4vVideoInputs) -> tuple[torch.Tensor, ...]: grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() if video_input["type"] == "video_embeds": video_embeds = video_input["video_embeds"].type(self.visual.dtype) @@ -1494,8 +1505,9 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, grid_thw=grid_thw.tolist()) # Split concatenated embeddings for each video item. merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size - return video_embeds.split(sizes.tolist()) + sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // + (merge_size * merge_size)).tolist() + return video_embeds.split(sizes) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: mm_input_by_modality = {} @@ -1524,7 +1536,7 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, return None # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). + # tensor corresponding to a multimodal data item (image or video). multimodal_embeddings: tuple[torch.Tensor, ...] = () # NOTE: It is important to iterate over the keys in this dictionary @@ -1600,17 +1612,10 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, **NOTE**: If mrope is enabled (default setting for GLM-4V opensource models), the shape will be `(3, seq_len)`, otherwise it will be `(seq_len,). - pixel_values: Pixel values to be fed to a model. - `None` if no images are passed. - image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. - `None` if no images are passed. - pixel_values_videos: Pixel values of videos to be fed to a model. - `None` if no videos are passed. - video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. - `None` if no videos are passed. - second_per_grid_ts: Tensor `(num_videos)` of video time interval ( - in seconds) for each grid along the temporal dimension in the - 3D position IDs. `None` if no videos are passed. + intermediate_tensors: Optional intermediate tensors for pipeline + parallelism. + inputs_embeds: Optional pre-computed input embeddings. + **kwargs: Additional keyword arguments. 
""" if intermediate_tensors is not None: inputs_embeds = None diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index 06ed453ec29f9..1acbd18091fb3 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -46,6 +46,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -146,24 +147,6 @@ class Glm4MoE(nn.Module): self.physical_expert_end = (self.physical_expert_start + self.n_local_physical_experts) - self.experts = FusedMoE( - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - prefix=f"{prefix}.experts", - scoring_func="sigmoid", - routed_scaling_factor=self.routed_scaling_factor, - e_score_correction_bias=self.gate.e_score_correction_bias, - enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts) - if config.n_shared_experts is not None: intermediate_size = (config.moe_intermediate_size * config.n_shared_experts) @@ -172,23 +155,68 @@ class Glm4MoE(nn.Module): intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=self.experts.must_reduce_shared_expert_outputs( - ), + reduce_results=False, prefix=f"{prefix}.shared_experts", ) + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func="sigmoid", + # we do scaling outside, set factor to 1.0 to avoid double mul + routed_scaling_factor=1.0, + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + ) + else: + self.experts = FusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func="sigmoid", + # we do scaling outside, set factor to 1.0 to avoid double mul + routed_scaling_factor=1.0, + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - if self.n_shared_experts is not None: - 
shared_output = self.shared_experts(hidden_states) + # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states.to(dtype=torch.float32)) - final_hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits) * self.routed_scaling_factor - if shared_output is not None: - final_hidden_states = final_hidden_states + shared_output + + fused_moe_out = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + + if self.shared_experts is not None: + shared_output, final_hidden_states = fused_moe_out + assert shared_output is not None + final_hidden_states = \ + final_hidden_states * self.routed_scaling_factor\ + + shared_output + else: + final_hidden_states = fused_moe_out * self.routed_scaling_factor + if self.tp_size > 1: final_hidden_states = ( self.experts.maybe_all_reduce_tensor_model_parallel( @@ -605,7 +633,9 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 4446b5ab181c1..0f6521e44e6be 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -339,7 +339,10 @@ class GPT2ForSequenceClassification(nn.Module): config = vllm_config.model_config.hf_config self.transformer = GPT2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "gpt2")) - self.score = nn.Linear(config.n_embd, config.num_labels, bias=False) + self.score = nn.Linear(config.n_embd, + config.num_labels, + bias=False, + dtype=vllm_config.model_config.head_dtype) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None @@ -348,7 +351,7 @@ class GPT2ForSequenceClassification(nn.Module): "encode": Pooler.for_encode(pooler_config), "classify": - Pooler.for_classify(pooler_config, classifier=None), + Pooler.for_classify(pooler_config, classifier=self.score), }) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): @@ -367,8 +370,7 @@ class GPT2ForSequenceClassification(nn.Module): position_ids=positions, inputs_embeds=inputs_embeds, intermediate_tensors=intermediate_tensors) - logits = self.score(hidden_states) - return logits + return hidden_states def _add_transformer_prefix( diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index d5c2604145eed..745d0b7759991 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -302,7 +302,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.lm_head = ParallelLMHead( self.transformer.vocab_size, self.transformer.embed_dim, - org_num_embeddings=self.config.vocab_size) + org_num_embeddings=self.config.vocab_size, + prefix=maybe_prefix(prefix, "lm_head")) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 584c7f5d8a2d1..77df6ae6f30c8 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -306,6 +306,7 @@ class GPTJForCausalLM(nn.Module, SupportsPP): config.n_embd, bias=True, quant_config=quant_config, + 
prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index e0b4df7728757..990a1d6d883a1 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -655,6 +655,7 @@ class GptOssForCausalLM(nn.Module, SupportsPP): self.lm_head = ParallelLMHead( self.config.vocab_size, self.config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor(self.config.vocab_size) self.make_empty_intermediate_tensors = ( diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index f8ba0229210a9..4f9cc2532bd8c 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -434,6 +434,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): # compatibility if not lora_config else lora_config.lora_vocab_padding_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 07ad75bcf1665..da16c72000c0e 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -487,6 +487,7 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): # compatibility if not lora_config else lora_config.lora_vocab_padding_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index 306775af68065..b42df3ad86508 100644 --- a/vllm/model_executor/models/h2ovl.py +++ b/vllm/model_executor/models/h2ovl.py @@ -17,7 +17,7 @@ from transformers import PretrainedConfig from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalKwargsItems +from vllm.multimodal.inputs import MultiModalKwargsItems, MultiModalUUIDDict from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, MultiModalDataItems) from vllm.multimodal.processing import (MultiModalProcessingInfo, @@ -479,7 +479,7 @@ class H2OVLMultiModalProcessor( mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - mm_hash_overrides: Optional[dict[str, list[str]]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: # The processor logic is different for len(images) <= 1 vs > 1 # Since the processing cache assumes that the processor output is @@ -491,7 +491,7 @@ class H2OVLMultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) return super()._cached_apply_hf_processor( @@ -499,7 +499,7 @@ class H2OVLMultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py index fbba849a76f23..4110c8a1fd08d 100644 
--- a/vllm/model_executor/models/hunyuan_v1.py +++ b/vllm/model_executor/models/hunyuan_v1.py @@ -23,7 +23,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only HunYuan model compatible with HuggingFace weights.""" -from collections.abc import Iterable +import typing +from collections.abc import Callable, Iterable from typing import Any, Optional, Union import regex as re @@ -33,8 +34,8 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionType from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, +from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config +from vllm.distributed import (get_ep_group, get_pp_group, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import SiluAndMul @@ -56,9 +57,9 @@ from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA +from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_layers) + make_layers, maybe_prefix) def _is_moe(config: PretrainedConfig) -> bool: @@ -355,10 +356,16 @@ class HunYuanSparseMoeBlock(nn.Module): quant_config: Optional[QuantizationConfig] = None, layer_id: int = -1, prefix: str = "", + enable_eplb: bool = False, ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + self.n_routed_experts = config.num_experts + if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " @@ -379,8 +386,23 @@ class HunYuanSparseMoeBlock(nn.Module): config.moe_intermediate_size, int) else config.moe_intermediate_size[layer_id]) + # Load balancing settings. 
+ vllm_config = get_current_vllm_config() + eplb_config = vllm_config.parallel_config.eplb_config + self.enable_eplb = enable_eplb + + self.n_logical_experts = self.n_routed_experts + self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_physical_experts = (self.n_logical_experts + + self.n_redundant_experts) + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + self.physical_expert_start = (self.ep_rank * + self.n_local_physical_experts) + self.physical_expert_end = (self.physical_expert_start + + self.n_local_physical_experts) + self.experts = FusedMoE( - num_experts=config.num_experts, + num_experts=self.n_routed_experts, top_k=top_k, hidden_size=config.hidden_size, intermediate_size=intermediate_size, @@ -388,6 +410,8 @@ class HunYuanSparseMoeBlock(nn.Module): renormalize=top_k > 1, quant_config=quant_config, prefix=f"{prefix}.experts", + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, ) self.gate = ReplicatedLinear(config.hidden_size, @@ -446,6 +470,7 @@ class HunYuanDecoderLayer(nn.Module): quant_config: Optional[QuantizationConfig] = None, prefix: str = "", layer_id: int = -1, + enable_eplb: bool = False, ) -> None: super().__init__() assert layer_id >= 0 @@ -509,6 +534,7 @@ class HunYuanDecoderLayer(nn.Module): quant_config=quant_config, layer_id=layer_id, prefix=f"{prefix}.mlp", + enable_eplb=enable_eplb, ) else: self.mlp = HunYuanMLP( @@ -562,6 +588,9 @@ class HunYuanModel(nn.Module): cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config + eplb_config = vllm_config.parallel_config.eplb_config + enable_eplb = vllm_config.parallel_config.enable_eplb + self.num_redundant_experts = eplb_config.num_redundant_experts self.config = config self.quant_config = quant_config @@ -588,6 +617,7 @@ class HunYuanModel(nn.Module): cache_config=cache_config, quant_config=quant_config, prefix=prefix, + enable_eplb=enable_eplb, ), prefix=f"{prefix}.layers", ) @@ -674,6 +704,7 @@ class HunYuanModel(nn.Module): ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.num_experts, + num_redundant_experts=self.num_redundant_experts, ) else: return [] @@ -803,25 +834,43 @@ class HunYuanModel(nn.Module): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + is_expert_weight = False for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue - name = name.replace(weight_name, param_name) - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): + # this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify `name` since the loop may continue here + # Instead, create a new variable + name_mapped = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name_mapped, self): continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader( + param = params_dict[name_mapped] + # We should ask the weight loader to return success or not + # here since otherwise we may skip experts with other + # available replicas. 
+ weight_loader = typing.cast(Callable[..., bool], + param.weight_loader) + success = weight_loader( param, loaded_weight, - name, + name_mapped, shard_id=shard_id, expert_id=expert_id, + return_success=True, ) - break + if success: + name = name_mapped + break else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue # Remapping the name of FP8 kv-scale. name = maybe_remap_kv_scale_name(name, params_dict) if name is None: @@ -841,7 +890,7 @@ class HunYuanModel(nn.Module): return loaded_params -class HunYuanV1Base(nn.Module, SupportsLoRA): +class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -871,6 +920,7 @@ class HunYuanV1Base(nn.Module, SupportsLoRA): org_num_embeddings=config.vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight @@ -882,6 +932,64 @@ class HunYuanV1Base(nn.Module, SupportsLoRA): else: self.lm_head = PPMissingLayer() + # Set MoE hyperparameters + self.expert_weights = [] + self.num_expert_groups = 1 + self.moe_layers: list[FusedMoE] = [] + example_layer = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance(layer, HunYuanDecoderLayer) + if isinstance(layer.mlp, HunYuanSparseMoeBlock): + example_layer = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_layer is None: + raise RuntimeError("No HunYuanMoE layer found in model.layers.") + + self.num_moe_layers = len(self.moe_layers) + self.num_logical_experts = example_layer.n_logical_experts + self.num_physical_experts = example_layer.n_physical_experts + self.num_local_physical_experts = example_layer.n_local_physical_experts + self.num_routed_experts = example_layer.n_routed_experts + self.num_redundant_experts = example_layer.n_redundant_experts + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + self.expert_weights.append(layer.get_expert_weights()) + # Register the expert weights. 
+ layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = (num_physical_experts - + self.num_logical_experts) + for layer in self.model.layers: + if isinstance(layer.mlp, HunYuanSparseMoeBlock): + moe = layer.mlp + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py index 0ca2e9e4bb688..76737a4428232 100644 --- a/vllm/model_executor/models/idefics2_vision_model.py +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -31,7 +31,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, - ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -139,37 +138,23 @@ class Idefics2VisionAttention(nn.Module): assert self.num_heads % tp_size == 0 self.num_heads_per_partition = self.num_heads // tp_size - if use_data_parallel: - self.q_size = self.num_heads * self.head_dim - self.qkv_proj = ReplicatedLinear( - self.embed_dim, - 3 * self.q_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - self.out_proj = ReplicatedLinear( - self.embed_dim, - self.embed_dim, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.out_proj", - ) - else: - self.qkv_proj = QKVParallelLinear( - self.embed_dim, - self.head_dim, - self.num_heads, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - self.out_proj = RowParallelLinear( - self.embed_dim, - self.embed_dim, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.out_proj", - ) + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + disable_tp=use_data_parallel, + ) + self.out_proj = RowParallelLinear( + self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + disable_tp=use_data_parallel, + ) + # Use unified MultiHeadAttention with Flash Attention support self.attn = MultiHeadAttention(self.num_heads_per_partition, self.head_dim, self.scale) @@ -181,6 +166,8 @@ class Idefics2VisionAttention(nn.Module): hidden_states ) # batch_size, q_len, 3 * num_heads_per_partition * head_dim query_states, key_states, value_states = qkv.chunk(3, dim=-1) + + # Use unified MultiHeadAttention implementation out = self.attn(query_states, key_states, value_states) attn_output, _ = self.out_proj(out) return attn_output @@ -198,23 +185,21 @@ class Idefics2VisionMLP(nn.Module): super().__init__() self.config = config self.activation_fn = get_act_fn(config.hidden_act) - cls_fc1 = (ReplicatedLinear - if 
use_data_parallel else ColumnParallelLinear) - self.fc1 = cls_fc1( + self.fc1 = ColumnParallelLinear( config.hidden_size, config.intermediate_size, bias=True, quant_config=quant_config, prefix=f"{prefix}.fc1", + disable_tp=use_data_parallel, ) - cls_fc2 = (ReplicatedLinear - if use_data_parallel else RowParallelLinear) - self.fc2 = cls_fc2( + self.fc2 = RowParallelLinear( config.intermediate_size, config.hidden_size, bias=True, quant_config=quant_config, prefix=f"{prefix}.fc2", + disable_tp=use_data_parallel, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -386,30 +371,6 @@ class Idefics2VisionTransformer(nn.Module): last_hidden_state = self.post_layernorm(encoder_outputs) return last_hidden_state - def _consolidate_qkv_weights( - self, weights: Iterable[tuple[str, torch.Tensor]] - ) -> Iterable[tuple[str, torch.Tensor]]: - qkv_idx_mappings = { - ".self_attn.q_proj": 0, - ".self_attn.k_proj": 1, - ".self_attn.v_proj": 2, - } - qkv_weights = {} - for name, loaded_weight in weights: - for weight_name, idx in qkv_idx_mappings.items(): - if weight_name not in name: - continue - new_name = name.replace(weight_name, ".self_attn.qkv_proj") - if new_name not in qkv_weights: - qkv_weights[new_name] = [None] * 3 - qkv_weights[new_name][idx] = loaded_weight - break - else: - yield name, loaded_weight - for key, weight in qkv_weights.items(): - qkv_weight = torch.cat(weight, dim=0) - yield key, qkv_weight - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ @@ -422,9 +383,6 @@ class Idefics2VisionTransformer(nn.Module): loaded_params: set[str] = set() layer_count = len(self.encoder.layers) - if self.use_data_parallel: - weights = self._consolidate_qkv_weights(weights) - for name, loaded_weight in weights: # skip pooling header if name.startswith("head."): diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 63307470d959b..9153a0e2c1e5a 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -606,6 +606,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, config.text_config.vocab_size, config.text_config.hidden_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) if self.config.text_config.tie_word_embeddings: self.lm_head.weight = self.model.text_model.wte.weight diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index d5b71b057831b..8f8e300c84d71 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -823,7 +823,7 @@ class SupportsEagle3(Protocol): Args: layers: Tuple of layer indices that should output auxiliary - hidden states. + hidden states. """ ... 
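
Editor's note: the HunYuan changes earlier in this diff wire HunYuanSparseMoeBlock into vLLM's expert load balancing (EPLB) by deriving each rank's slice of physical experts from the logical expert count, the redundant expert count, and the expert-parallel group (n_physical = n_logical + n_redundant; n_local = n_physical // ep_size; the rank owns [ep_rank * n_local, ep_rank * n_local + n_local)). The standalone sketch below only illustrates that partition arithmetic; the helper name and the numbers in the usage example are invented for illustration and are not part of vLLM's API.

# Minimal sketch of the physical-expert partitioning implied by the EPLB
# settings added in the HunYuan diff above. Assumes the physical expert
# count divides evenly across the expert-parallel group, mirroring the
# integer division used there.
def physical_expert_range(n_logical_experts: int, n_redundant_experts: int,
                          ep_rank: int, ep_size: int) -> tuple[int, int]:
    n_physical_experts = n_logical_experts + n_redundant_experts
    n_local_physical_experts = n_physical_experts // ep_size
    start = ep_rank * n_local_physical_experts
    end = start + n_local_physical_experts
    return start, end

# Hypothetical usage: 64 logical experts, no redundant replicas, EP size 4.
# Rank 0 owns physical experts [0, 16), rank 1 owns [16, 32), and so on.
assert physical_expert_range(64, 0, 1, 4) == (16, 32)
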
diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index 58e8163e0b26e..118cce810a1f2 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -25,9 +25,11 @@ from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.multimodal.utils import run_dp_sharded_vision_model NORM2FN = { 'rms_norm': RMSNorm, @@ -137,6 +139,7 @@ class InternParallelAttention(nn.Module): *, num_dummy_heads: int = 0, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() @@ -150,8 +153,10 @@ class InternParallelAttention(nn.Module): f'(got `embed_dim`: {self.embed_dim} and `num_heads`:' f' {self.num_heads}).') - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = (1 if use_data_parallel else + get_tensor_model_parallel_world_size()) + self.tp_rank = (0 if use_data_parallel else + get_tensor_model_parallel_rank()) # Additional dummy heads are used to enable TP for common GPU counts. self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim @@ -159,14 +164,23 @@ class InternParallelAttention(nn.Module): self.tp_size) self.scale = self.head_dim**-0.5 - self.qkv = QKVParallelLinear( - self.embed_dim, - self.head_dim, - num_dummy_heads + self.num_heads, - bias=config.qkv_bias, - quant_config=quant_config, - prefix=f"{prefix}.qkv", - ) + if use_data_parallel: + self.qkv = ReplicatedLinear( + self.embed_dim, + 3 * self.head_dim * self.num_heads, + bias=config.qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + ) + else: + self.qkv = QKVParallelLinear( + self.embed_dim, + self.head_dim, + num_dummy_heads + self.num_heads, + bias=config.qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + ) self.qk_normalization = config.qk_normalization @@ -178,12 +192,20 @@ class InternParallelAttention(nn.Module): eps=config.layer_norm_eps, var_hidden_size=self.embed_dim) - self.proj = RowParallelLinear( - self.dummy_dim, - self.embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.proj", - ) + if use_data_parallel: + self.proj = ReplicatedLinear( + self.dummy_dim, + self.embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + ) + else: + self.proj = RowParallelLinear( + self.dummy_dim, + self.embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + ) self.attn = MultiHeadAttention(self.num_heads_per_partition, self.head_dim, self.scale) @@ -255,6 +277,10 @@ class InternSdpaAttention(nn.Module): self.proj = nn.Linear(self.dummy_dim, self.embed_dim) + # Use unified MultiHeadAttention with automatic backend selection + self.attn = MultiHeadAttention(self.num_heads, self.head_dim, + self.scale) + def forward(self, x: torch.Tensor) -> torch.Tensor: B, N, C = x.shape qkv = self.qkv(x) @@ -268,12 +294,9 @@ class InternSdpaAttention(nn.Module): B_, N_, H_, D_ = q.shape q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_) k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_) - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - x = F.scaled_dot_product_attention(q, k, v, scale=self.scale) - x = x.transpose(1, 2).reshape(B, N, 
-1) + # Use unified MultiHeadAttention with automatic backend selection + x = self.attn(q, k, v) x = self.proj(x) return x @@ -286,21 +309,26 @@ class InternMLP(nn.Module): config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.config = config self.activation_fn = get_act_fn(config.hidden_act) - self.fc1 = ColumnParallelLinear(config.hidden_size, - config.intermediate_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.fc1") - self.fc2 = RowParallelLinear(config.intermediate_size, - config.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.fc2") + cls_fc1 = (ReplicatedLinear + if use_data_parallel else ColumnParallelLinear) + self.fc1 = cls_fc1(config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1") + cls_fc2 = (ReplicatedLinear + if use_data_parallel else RowParallelLinear) + self.fc2 = cls_fc2(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2") def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.fc1(hidden_states) @@ -319,6 +347,7 @@ class InternVisionEncoderLayer(nn.Module): *, num_dummy_heads: int = 0, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() @@ -329,11 +358,13 @@ class InternVisionEncoderLayer(nn.Module): self.attn = self._init_attn(config, quant_config, num_dummy_heads=num_dummy_heads, - prefix=f"{prefix}.attn") + prefix=f"{prefix}.attn", + use_data_parallel=use_data_parallel) self.mlp = InternMLP(config, quant_config=quant_config, - prefix=f"{prefix}.mlp") + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel) self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps) self.norm2 = NORM2FN[self.norm_type](self.embed_dim, @@ -351,16 +382,20 @@ class InternVisionEncoderLayer(nn.Module): *, num_dummy_heads: int, prefix: str = "", + use_data_parallel: bool = False, ): # fallback to sdpa attention if tp unavailable - tp_size = get_tensor_model_parallel_world_size() + # tp_size = get_tensor_model_parallel_world_size() + tp_size = (1 if use_data_parallel else + get_tensor_model_parallel_world_size()) num_heads = config.num_attention_heads if (num_heads + num_dummy_heads) % tp_size == 0: return InternParallelAttention(config, quant_config=quant_config, num_dummy_heads=num_dummy_heads, - prefix=prefix) + prefix=prefix, + use_data_parallel=use_data_parallel) return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads) @@ -387,6 +422,7 @@ class InternVisionEncoder(nn.Module): num_hidden_layers_override: Optional[int] = None, num_dummy_heads: int = 0, prefix: str = "", + use_data_parallel: bool = False, ): super().__init__() @@ -401,7 +437,8 @@ class InternVisionEncoder(nn.Module): InternVisionEncoderLayer(config, quant_config, num_dummy_heads=num_dummy_heads, - prefix=f"{prefix}.layers.{layer_idx}") + prefix=f"{prefix}.layers.{layer_idx}", + use_data_parallel=use_data_parallel) for layer_idx in range(num_hidden_layers) ]) @@ -428,10 +465,12 @@ class InternVisionModel(nn.Module): num_hidden_layers_override: Optional[int] = None, num_dummy_heads: int = 0, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.config = config + self.use_data_parallel = use_data_parallel self.embeddings = InternVisionEmbeddings(config) self.encoder = InternVisionEncoder( @@ -440,6 +479,7 @@ class 
InternVisionModel(nn.Module): num_hidden_layers_override=num_hidden_layers_override, num_dummy_heads=num_dummy_heads, prefix=f"{prefix}.encoder", + use_data_parallel=use_data_parallel, ) def get_input_embeddings(self): @@ -463,7 +503,11 @@ class InternVisionModel(nn.Module): raise ValueError( f'wrong pixel_values size: {pixel_values.shape}') - encoder_outputs = self.encoder(inputs_embeds=hidden_states) + if self.use_data_parallel: + encoder_outputs = run_dp_sharded_vision_model( + hidden_states, self.encoder) + else: + encoder_outputs = self.encoder(inputs_embeds=hidden_states) return encoder_outputs diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 320e8d9d480c3..ce94328797ed6 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -423,13 +423,15 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM): delattr(self, attr) config = vllm_config.model_config.hf_config - self.v_head = RowParallelLinear( - config.hidden_size, - 1, - bias=False, - input_is_parallel=False, - prefix=maybe_prefix(prefix, "v_head"), - ) + self.head_dtype = vllm_config.model_config.head_dtype + + self.v_head = RowParallelLinear(config.hidden_size, + 1, + bias=False, + input_is_parallel=False, + params_dtype=self.head_dtype, + prefix=maybe_prefix(prefix, "v_head"), + return_bias=False) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None @@ -446,5 +448,6 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM): ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) - logits, _ = self.v_head(hidden_states) + hidden_states = hidden_states.to(self.head_dtype) + logits = self.v_head(hidden_states) return logits diff --git a/vllm/model_executor/models/interns1.py b/vllm/model_executor/models/interns1.py index 26e358f9394c6..b59d1b88cf5ce 100644 --- a/vllm/model_executor/models/interns1.py +++ b/vllm/model_executor/models/interns1.py @@ -492,7 +492,7 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: - # transformers InternVLProcessor uses <IMG_CONTEXT> as the seperator + # transformers InternVLProcessor uses <IMG_CONTEXT> as the separator # refer to https://github.com/huggingface/transformers/blob/f90de364c2484c7c325bbe05befdcf487bd75b63/src/transformers/models/internvl/processing_internvl.py#L116 if modality.startswith("image"): return '<IMG_CONTEXT>' @@ -738,7 +738,7 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, return [] # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). + # tensor corresponding to a multimodal data item (image or video). multimodal_embeddings: tuple[torch.Tensor, ...] 
= () # NOTE: It is important to iterate over the keys in this dictionary diff --git a/vllm/model_executor/models/interns1_vit.py b/vllm/model_executor/models/interns1_vit.py index 300ed17ecaabc..eb6b685d03dc5 100644 --- a/vllm/model_executor/models/interns1_vit.py +++ b/vllm/model_executor/models/interns1_vit.py @@ -12,10 +12,10 @@ from typing import Optional import torch import torch.nn as nn -import torch.nn.functional as F from transformers import PretrainedConfig from transformers.utils import torch_int +from vllm.attention.layer import MultiHeadAttention from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -206,6 +206,10 @@ class InternSdpaAttention(nn.Module): self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim) + # Use unified MultiHeadAttention with automatic backend selection + self.attn = MultiHeadAttention(self.num_heads, self.head_dim, + self.scale) + def forward(self, x: torch.Tensor) -> torch.Tensor: B, N, C = x.shape @@ -213,20 +217,13 @@ class InternSdpaAttention(nn.Module): k = self.k_proj(x) v = self.v_proj(x) - q = q.view(B, N, self.num_heads, self.head_dim) - k = k.view(B, N, self.num_heads, self.head_dim) - v = v.view(B, N, self.num_heads, self.head_dim) - if self.qk_normalization: B_, N_, H_, D_ = q.shape q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_) k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_) - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - x = F.scaled_dot_product_attention(q, k, v, scale=self.scale) - x = x.transpose(1, 2).reshape(B, N, -1) + # Use unified MultiHeadAttention with automatic backend selection + x = self.attn(q, k, v) x = self.projection_layer(x) return x diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index b09ed7bbe72a3..6a5c565b52e85 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -7,6 +7,7 @@ # Copyright (c) 2023 OpenGVLab # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- +import os from abc import ABC, abstractmethod from collections.abc import Iterable, Mapping, Sequence from typing import Annotated, Any, Literal, Optional, TypeVar, Union @@ -37,6 +38,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import set_default_torch_num_threads from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, @@ -115,13 +117,26 @@ InternVLVideoInputs = Union[InternVLVideoPixelInputs, # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B def build_transform(input_size: int): MEAN, STD = IMAGENET_MEAN, IMAGENET_STD - return T.Compose([ + transform = T.Compose([ T.Lambda(lambda img: convert_image_mode(img, 'RGB')), T.Resize((input_size, input_size), interpolation=T.InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=MEAN, std=STD) ]) + # Image transformation operations (which include tensor computations + # on the CPU) can occupy a substantial number of CPU cores, introducing + # overhead due to CPU contention. This issue becomes particularly + # noticeable when deploying multiple vLLM instances on a single machine. 
+ # Therefore, it is necessary to limit the number of threads allocated to + # image transformation tasks. + num_threads = int(os.environ.get("OMP_NUM_THREADS", "1")) + + def apply(img): + with set_default_torch_num_threads(num_threads): + return transform(img) + + return apply # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B @@ -1020,6 +1035,8 @@ class InternVLMultiModalProcessor( class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): + supports_encoder_tp_data = True + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): @@ -1038,6 +1055,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, self.config = config self.multimodal_config = multimodal_config + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self._patch_quant_config(config, quant_config) image_size = config.force_image_size or config.vision_config.image_size @@ -1105,7 +1123,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, quant_config=quant_config, num_hidden_layers_override=num_hidden_layers, prefix=prefix, - ) + use_data_parallel=self.use_data_parallel) else: return InternVisionPatchModel(config.vision_config) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 91a06dd502474..4fee8c32fd581 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -302,7 +302,9 @@ class JAISLMHeadModel(nn.Module, SupportsPP): self.lm_head = self.transformer.wte else: self.lm_head = ParallelLMHead(self.config.vocab_size, - self.config.hidden_size) + self.config.hidden_size, + prefix=maybe_prefix( + prefix, "lm_head")) if hasattr(config, "width_scale"): self.output_logits_scale = config.width_scale else: diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index aebd2cbe2e999..5b8fbc7226866 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -502,6 +502,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, # We need bigger padding if using lora for kernel # compatibility if not lora_config else lora_config.lora_vocab_padding_size, + prefix=maybe_prefix(prefix, "lm_head"), ) # Used to track and store by the Mamba cache between steps. 
self.mamba_cache: Optional[MambaCacheManager] = None @@ -613,7 +614,7 @@ class JambaForSequenceClassification(JambaForCausalLM): config.hidden_size, num_labels, bias=score_bias, - dtype=torch.float32, + dtype=vllm_config.model_config.head_dtype, ) pooler_config = vllm_config.model_config.pooler_config diff --git a/vllm/model_executor/models/jina_vl.py b/vllm/model_executor/models/jina_vl.py index 8c64f636c6a0f..f8c2a1e507a74 100644 --- a/vllm/model_executor/models/jina_vl.py +++ b/vllm/model_executor/models/jina_vl.py @@ -5,9 +5,9 @@ from typing import Optional import torch import torch.nn as nn -from transformers import BatchFeature, PretrainedConfig +from transformers import BatchFeature -from vllm.config import VllmConfig +from vllm.config import ModelConfig, VllmConfig from vllm.inputs import TokensPrompt from vllm.logger import init_logger from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -28,13 +28,17 @@ logger = init_logger(__name__) class JinaVLScorer(nn.Module): - def __init__(self, config: PretrainedConfig): + def __init__(self, model_config: "ModelConfig"): super().__init__() + config = model_config.hf_config + head_dtype = model_config.head_dtype self.dense = ColumnParallelLinear(config.hidden_size, config.hidden_size, + params_dtype=head_dtype, bias=True) self.out_proj = RowParallelLinear(config.hidden_size, config.num_labels, + params_dtype=head_dtype, bias=True) def forward(self, x, **kwargs): @@ -88,21 +92,17 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration, def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "qwen2_vl")) - config = vllm_config.model_config.hf_config pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - # logit bias for sigmoid normalization - self.LOGIT_BIAS = 2.65 - - self.score = JinaVLScorer(config) + self.score = JinaVLScorer(vllm_config.model_config) self.pooler = DispatchPooler({ "encode": Pooler.for_encode(pooler_config), "classify": - Pooler.for_classify(pooler_config, classifier=None), + Pooler.for_classify(pooler_config, classifier=self.score), "score": - Pooler.for_classify(pooler_config, classifier=None), + Pooler.for_classify(pooler_config, classifier=self.score), }) @classmethod @@ -137,9 +137,7 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration, inputs_embeds=inputs_embeds, **kwargs, ) - - logits = self.score(hidden_states) - self.LOGIT_BIAS - return logits + return hidden_states def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 710b805acb3ea..afe33b4d4ad26 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -17,6 +17,7 @@ from transformers.modeling_outputs import (BaseModelOutput, BaseModelOutputWithPooling) from transformers.utils import torch_int +from vllm.attention.layer import check_upstream_fa_availability from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -374,7 +375,16 @@ class KeyeSiglipAttention(nn.Module): ) # Detect attention implementation. 
- self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=self.head_dim, dtype=torch.get_default_dtype()) + + self.use_upstream_fa = False + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability( + torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN + self.use_upstream_fa = True + if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}: raise RuntimeError( f"Keye-VL does not support {self.attn_backend} backend now.") @@ -428,7 +438,10 @@ class KeyeSiglipAttention(nn.Module): ) if self.attn_backend == _Backend.FLASH_ATTN: - from flash_attn import flash_attn_varlen_func + if self.use_upstream_fa: + from flash_attn import flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) @@ -1507,15 +1520,9 @@ class BaseKeyeModule(nn.Module): batch. **NOTE**: If mrope is enabled (default setting for Qwen2-VL opensource models), the shape will be `(3, seq_len)`, - otherwise it will be `(seq_len,). - pixel_values: Pixel values to be fed to a model. - `None` if no images are passed. - image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. - `None` if no images are passed. - pixel_values_videos: Pixel values of videos to be fed to a model. - `None` if no videos are passed. - video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. - `None` if no videos are passed. + otherwise it will be `(seq_len,)`. + intermediate_tensors: Intermediate tensors from prior forward pass. + inputs_embeds: Optional tensor of input embeddings. """ if intermediate_tensors is not None: inputs_embeds = None @@ -1598,12 +1605,12 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal, raise ValueError(f"{name} should be 2D or batched 3D tensor. " f"Got ndim: {mm_input.ndim} " f"(shape={mm_input.shape})") - return torch.concat(list(mm_input)) + return mm_input.reshape(-1, mm_input.shape[-1]) elif is_list_of(mm_input, torch.Tensor): if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2 for p in mm_input): return mm_input - return torch.concat(list(mm_input)) + return torch.concat(mm_input) def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[KeyeImageInputs]: diff --git a/vllm/model_executor/models/keye_vl1_5.py b/vllm/model_executor/models/keye_vl1_5.py index 605c6d3eaf643..93a3bf5f98f7b 100644 --- a/vllm/model_executor/models/keye_vl1_5.py +++ b/vllm/model_executor/models/keye_vl1_5.py @@ -58,17 +58,18 @@ def split_thw(grid_thw: torch.Tensor) -> torch.Tensor: return torch.cat([ones, h_w], dim=1).repeat_interleave(t, dim=0) -def get_num_patches(grid_thw: torch.Tensor, num_frames: Union[list[int], - torch.Tensor]): +def get_num_patches(grid_thw: torch.Tensor, + num_frames: Union[list[int], torch.Tensor]) -> list[int]: """ Return num_patches per video. Args: - t: tensor with shape [N, ...] 
where each item is a list/tensor - cu_seqlens: list indicating the boundaries of groups + grid_thw: Tensor with shape [N, 3] containing temporal, height, width + dimensions + num_frames: List or tensor indicating the number of frames per video Returns: - list of ints representing the sum of products for each group + List of ints representing the number of patches for each video Examples: >>> # Suppose there are 2 videos with a total of 3 grids @@ -491,14 +492,14 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal, if mm_input.ndim == expected_dim: return mm_input elif mm_input.ndim == expected_dim + 1: - return torch.concat(list(mm_input)) + return mm_input.reshape(-1, *mm_input.shape[2:]) else: raise ValueError( f"{name} should be {expected_dim}D or " f"batched {expected_dim}D tensor." f"Got ndim: {mm_input.ndim} (shape={mm_input.shape})") else: - return torch.concat(list(mm_input)) + return torch.concat(mm_input) def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[KeyeVL1_5ImageInputs]: diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index 4f76d4afdb20e..94a5933a61416 100644 --- a/vllm/model_executor/models/kimi_vl.py +++ b/vllm/model_executor/models/kimi_vl.py @@ -328,6 +328,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, config.text_config.hidden_size, org_num_embeddings=self.config.text_config.vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE, + prefix=maybe_prefix(prefix, "lm_head"), ) else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index a22bde194f5de..f8ea2111fed57 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -171,7 +171,22 @@ class LlamaAttention(nn.Module): sliding_window = None if layer_types := getattr(config, "layer_types", None): - is_sliding = layer_types[layer_idx] == "sliding_attention" + # Fix for Eagle3 compatibility: + # for draft models, subtract target layer count + # to get draft-relative layer index starting from 0 + if hasattr(config, 'target_layer_count'): + # This is a draft model, + # adjust layer_idx to be relative to draft layers + effective_layer_idx = layer_idx - config.target_layer_count + else: + # This is a target model, use layer_idx directly + effective_layer_idx = layer_idx + assert effective_layer_idx < len(layer_types), \ + f"effective_layer_idx: {effective_layer_idx} \ + is out of bounds for layer_types: {layer_types}" + + is_sliding = layer_types[ + effective_layer_idx] == "sliding_attention" if is_sliding: sliding_window = config.sliding_window @@ -611,9 +626,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): loaded_weight: torch.Tensor, ) -> tuple[str, torch.Tensor]: - def permute(w: torch.Tensor, n_heads: int): + def permute(w: torch.Tensor, n_heads: int, attn_out: int): attn_in = self.config.head_dim * n_heads - attn_out = self.config.hidden_size return w.view(n_heads, attn_in // n_heads // 2, 2, attn_out).transpose(1, 2).reshape(attn_in, attn_out) @@ -622,12 +636,24 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): modules = name.split(".") # rotary embeds should be sliced + # If using quantized model in mistral format, + # quantization scales (qscale_weight) also need to be sliced if "wk" in modules and modules[-1] == "weight": loaded_weight = permute(loaded_weight, - self.config.num_key_value_heads) + self.config.num_key_value_heads, + 
self.config.hidden_size) + elif "wk" in modules and modules[ + -1] == "qscale_weight" and loaded_weight.numel() > 1: + loaded_weight = permute(loaded_weight, + self.config.num_key_value_heads, 1) elif "wq" in modules and modules[-1] == "weight": loaded_weight = permute(loaded_weight, - self.config.num_attention_heads) + self.config.num_attention_heads, + self.config.hidden_size) + elif "wq" in modules and modules[ + -1] == "qscale_weight" and loaded_weight.numel() > 1: + loaded_weight = permute(loaded_weight, + self.config.num_attention_heads, 1) num_modules = len(modules) for i in range(num_modules): diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index ba08e6f81f7fe..ddd7e6a5936e3 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -36,6 +36,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) @@ -73,7 +74,18 @@ class Llama4MoE(nn.Module): quant_config=None, prefix=f"{prefix}.router") - self.experts = FusedMoE( + self.shared_expert = LlamaMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size_moe, + hidden_act="silu", + quant_config=quant_config, + bias=False, + prefix=f"{prefix}.shared_expert", + reduce_results=False, + ) + + self.experts = SharedFusedMoE( + shared_experts=self.shared_expert, num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, @@ -83,22 +95,13 @@ class Llama4MoE(nn.Module): reduce_results=False, renormalize=False, quant_config=quant_config, - prefix=f"{prefix}.experts") - - self.shared_expert = LlamaMLP( - hidden_size=config.hidden_size, - intermediate_size=intermediate_size_moe, - hidden_act="silu", - quant_config=quant_config, - bias=False, - prefix=f"{prefix}.shared_expert", - reduce_results=self.experts.must_reduce_shared_expert_outputs(), + prefix=f"{prefix}.experts", ) def forward(self, hidden_states): router_logits, _ = self.router(hidden_states) - shared_out = self.shared_expert(hidden_states) - routed_out = self.experts( + + shared_out, routed_out = self.experts( hidden_states=hidden_states, router_logits=router_logits, ) diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 572930c39a846..7027138dfcb17 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -9,7 +9,7 @@ import torch.nn as nn from transformers import LlamaConfig from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig +from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import QKVParallelLinear @@ -33,10 +33,14 @@ class LlamaDecoderLayer(LlamaDecoderLayer): def __init__( self, config: LlamaConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: - super().__init__(config, quant_config=quant_config, prefix=prefix) + super().__init__(config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix) # 
override qkv self.self_attn.qkv_proj = QKVParallelLinear( @@ -114,6 +118,8 @@ class LlamaModel(nn.Module): speculative_config.draft_model_config.hf_config self.vocab_size = self.config.vocab_size + current_vllm_config = get_current_vllm_config() + self.embed_tokens = VocabParallelEmbedding( self.config.vocab_size, self.config.hidden_size, @@ -123,6 +129,7 @@ class LlamaModel(nn.Module): self.layers = nn.ModuleList([ LlamaDecoderLayer( config=self.config, + cache_config=current_vllm_config.cache_config, prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"), ) ]) @@ -199,6 +206,10 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): speculative_config.draft_model_config.hf_config target_layer_num = vllm_config.model_config.get_num_layers( vllm_config.parallel_config) + + # Store target layer count in draft config for + # proper layer_types indexing in draft models + self.config.target_layer_count = target_layer_num self.model = LlamaModel(vllm_config=vllm_config, prefix="model", start_layer_id=target_layer_num) @@ -209,7 +220,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): self.config.hidden_size, org_num_embeddings=self.config.draft_vocab_size, padding_size=(DEFAULT_VOCAB_PADDING_SIZE), - prefix="") + prefix=maybe_prefix(prefix, "lm_head")) self.logits_processor = LogitsProcessor(self.config.draft_vocab_size, scale=logit_scale) self.draft_id_to_target_id = nn.Parameter( diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 8a847a6180f3a..9591deea06ce9 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -24,7 +24,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargsItems) + MultiModalInputs, MultiModalKwargsItems, + MultiModalUUIDDict) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -731,7 +732,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): Args: input_ids: Flattened (concatenated) input_ids corresponding to a batch. - pixel_values: The pixels in each input image. + positions: Position indices for the input tokens. + intermediate_tensors: Intermediate tensors from prior forward pass. + inputs_embeds: Optional tensor of input embeddings. 
Info: [LlavaImageInputs][] @@ -795,7 +798,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - mm_hash_overrides: Optional[dict[str, list[str]]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalInputs: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index @@ -810,7 +813,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): mm_data, hf_processor_mm_kwargs, tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides) + mm_uuids=mm_uuids) mm_items = self._to_mm_items(mm_data) mm_item_counts = mm_items.get_all_counts() diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index a63c18493df5e..5e82f9799e0fe 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -535,8 +535,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, Args: input_ids: Flattened (concatenated) input_ids corresponding to a batch. - pixel_values: The pixels in each grid patch for each input image. - image_sizes: The original `(height, width)` for each input image. + positions: Position indices for the input tokens. + intermediate_tensors: Intermediate tensors from prior forward pass. + inputs_embeds: Optional tensor of input embeddings. Info: [LlavaNextImageInputs][] diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index e4ac0cd919101..46d54452a52d8 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -216,12 +216,9 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): seq_len: int, mm_counts: Mapping[str, int], ) -> int: - max_images = mm_counts.get("image", 0) max_videos = mm_counts.get("video", 0) - max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = self._get_max_video_frames(seq_len - - max_image_tokens) + max_total_frames = self._get_max_video_frames(seq_len) max_frames_per_video = min(max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO) @@ -838,7 +835,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, return None # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). + # tensor corresponding to a multimodal data item (image or video). multimodal_embeddings: tuple[torch.Tensor, ...] = () # NOTE: It is important to iterate over the keys in this dictionary diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index f02499a4f96b5..9d1017dac8aa1 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -223,6 +223,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): # We need bigger padding if using lora for kernel # compatibility if not lora_config else lora_config.lora_vocab_padding_size, + prefix=maybe_prefix(prefix, "lm_head"), ) # Used to track and store by the Mamba cache between steps. 
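A recurring change across several hunks in this patch (kimi_vl, mamba, mamba2, medusa, minicpm, mixtral, mimo_mtp, llama_eagle3, and others) is threading a module prefix into ParallelLMHead via maybe_prefix(prefix, "lm_head"), presumably so the head's parameters carry fully qualified names (e.g. for per-layer quantization overrides and weight-name matching). A minimal illustrative sketch of the helper's assumed behaviour follows; the real implementation lives in vllm/model_executor/models/utils.py, and this only mirrors how the hunks here use it.

# Illustrative sketch (not part of the patch): assumed behaviour of the
# maybe_prefix helper used throughout these hunks.
def maybe_prefix(prefix: str, name: str) -> str:
    # Qualify `name` with `prefix` when a prefix is set, otherwise keep it bare.
    return f"{prefix}.{name}" if prefix else name

# Top-level modules pass an empty prefix, so the head keeps its bare name;
# nested modules produce dotted names such as "language_model.lm_head".
assert maybe_prefix("", "lm_head") == "lm_head"
assert maybe_prefix("language_model", "lm_head") == "language_model.lm_head"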
diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index 81b9a125380aa..b1a4138cb8f6c 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -278,6 +278,7 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): # We need bigger padding if using lora for kernel # compatibility if not lora_config else lora_config.lora_vocab_padding_size, + prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings) diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py index 709a5a993c6f7..6ba8ad372c95a 100644 --- a/vllm/model_executor/models/medusa.py +++ b/vllm/model_executor/models/medusa.py @@ -15,6 +15,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata +from .utils import maybe_prefix + class ResidualBlock(nn.Module): @@ -71,6 +73,7 @@ class Medusa(nn.Module): config.hidden_size, org_num_embeddings=self.truncated_vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE, + prefix=maybe_prefix(prefix, "lm_head"), ) self.lm_heads = [ self.lm_head for _ in range(self.config.num_heads) diff --git a/vllm/model_executor/models/midashenglm.py b/vllm/model_executor/models/midashenglm.py new file mode 100644 index 0000000000000..140800dd41c76 --- /dev/null +++ b/vllm/model_executor/models/midashenglm.py @@ -0,0 +1,794 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 Horizon team, Xiaomi MiLM Plus. +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only MiDashengLM model compatible with HuggingFace weights.""" +import collections +import collections.abc +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, Callable, Optional, TypedDict, Union, cast + +import numpy as np +import torch +import torch.nn as nn +import torchaudio.transforms as audio_transforms +from transformers import BatchFeature + +from vllm.attention.layer import MultiHeadAttention +from vllm.config import VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.utils import set_default_torch_dtype +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargsItems) +from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.midashenglm import DashengConfig + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .utils import (AutoWeightsLoader, init_vllm_registered_model, + maybe_prefix, merge_multimodal_embeddings) + +_Tuple2 = Union[int, tuple[int, int], Sequence[int]] + + +def _resolve_tuple2(x: _Tuple2) -> tuple[int, int]: + if isinstance(x, collections.abc.Sequence): + assert len(x) == 2, ( + f"Expected a sequence of length 2, got {x} with length {len(x)}") + return cast(tuple[int, int], tuple(x)) + return (x, x) + + +def calculate_mel_frames_dasheng( + audio_length_samples: int, + n_fft: int = 512, + hop_size: int = 160, + dasheng_subsampling: int = 4, + center=True, + model_subsampling: int = 5, +) -> int: + """Calculate the number of Mel-spectrogram frames.""" + if center: + audio_length_samples = audio_length_samples + n_fft + + return (int(1 + ((audio_length_samples - n_fft) / hop_size)) // + dasheng_subsampling // model_subsampling) + + +class AudioPatchEmbed(nn.Module): + + def __init__( + self, + input_size: _Tuple2 = 64, + patch_size: _Tuple2 = 16, + patch_stride: _Tuple2 = 16, + in_chans: int = 1, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten: bool = False, + ): + super().__init__() + self.input_size = _resolve_tuple2(input_size) + self.patch_size = _resolve_tuple2(patch_size) + self.patch_stride = _resolve_tuple2(patch_stride) + self.grid_size = ( + self.input_size[0] // self.patch_stride[0], + self.input_size[1] // self.patch_stride[1], + ) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=self.patch_size, + stride=self.patch_stride, + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + if self.flatten: + x = torch.permute(torch.flatten( + x, 2, 3), (0, 2, 1)) # rearrange(x, "b c f t -> b (f t) c") + x = self.norm(x) + return x + + +class LayerScale(nn.Module): + + def __init__(self, dim, init_values=1e-5, 
inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class DashengMlp(nn.Module): + + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = ColumnParallelLinear(input_size=in_features, + output_size=hidden_features, + quant_config=quant_config, + prefix=f"{prefix}.fc1") + self.act = get_act_fn("gelu") + self.fc2 = RowParallelLinear(input_size=hidden_features, + output_size=out_features, + quant_config=quant_config, + prefix=f"{prefix}.fc2") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.fc1(x) + x = self.act(x) + x, _ = self.fc2(x) + return x + + +class DashengAttention(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + causal: bool = False, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.embed_dim = dim + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + if self.total_num_heads >= tp_size: + # Number of heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_heads % tp_size == 0 + else: + # Number of heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_heads == 0 + self.num_kv_heads = max(1, self.total_num_heads // tp_size) + self.head_dim = self.embed_dim // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scale = self.head_dim**-0.5 + + self.qkv = QKVParallelLinear( + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + ) + self.attn = MultiHeadAttention( + self.num_heads, + self.head_dim, + self.scale, + num_kv_heads=self.num_kv_heads, + ) + self.proj = RowParallelLinear( + input_size=dim, + output_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + ) + self.causal = causal + + def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None): + B, N, C = x.shape + + qkv_out, _ = self.qkv(x) + q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size], + dim=-1) + + attn_out = self.attn(q, k, v) + C_local = attn_out.numel() // (B * N) # C_local for parallel + attn_out = attn_out.view(B, N, C_local) + + x, _ = self.proj(attn_out) + + return x + + +class DashengBlock(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + init_values: Optional[float] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.norm1 = nn.LayerNorm(dim, eps=1e-6) + self.attn = DashengAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + self.ls1 = (LayerScale(dim, init_values=init_values) + if init_values else nn.Identity()) + + self.norm2 = nn.LayerNorm(dim, eps=1e-6) + self.mlp = DashengMlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.ls2 = (LayerScale(dim, init_values=init_values) + if init_values else nn.Identity()) + + # Kwargs usually has a mask parameter that is passed to Attention + def forward( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + x = x + self.ls1(self.attn(self.norm1(x), mask)) + x = x + self.ls2(self.mlp(self.norm2(x))) + return x + + +class DashengAudioTransformer(nn.Module): + + def __init__( + self, + config: DashengConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + self.target_length = config.target_length + self.hop_length = config.hop_length + + self._init_front_end(config) + + self.init_bn = nn.BatchNorm2d(config.n_mels, momentum=0.01) + + self.patch_embed = AudioPatchEmbed( + input_size=(config.n_mels, config.target_length), + embed_dim=config.embed_dim, + in_chans=config.input_channels, + patch_size=config.patch_size, + flatten=False, + patch_stride=config.patch_stride, + ) + + self.time_pos_embed = nn.Parameter( + torch.empty(1, config.embed_dim, 1, self.patch_embed.grid_size[1])) + self.freq_pos_embed = nn.Parameter( + torch.empty(1, config.embed_dim, self.patch_embed.grid_size[0], 1)) + self.blocks = nn.ModuleList( + DashengBlock( + dim=config.embed_dim, + num_heads=config.num_heads, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + init_values=config.init_values, + quant_config=quant_config, + prefix=f"{prefix}.block{i}", + ) for i in range(config.depth)) + self.norm = nn.LayerNorm(config.embed_dim, eps=1e-6) + + def _init_front_end(self, config): + 
with set_default_torch_dtype(torch.float32): + self.front_end = nn.Sequential( + audio_transforms.MelSpectrogram( + f_min=config.f_min, + f_max=config.f_max, + center=config.center, + win_length=config.win_length, + hop_length=config.hop_length, + sample_rate=config.sample_rate, + n_fft=config.n_fft, + n_mels=config.n_mels, + ), + audio_transforms.AmplitudeToDB(top_db=120), + ) + + mel_spectrogram = self.front_end[0] + fb = mel_spectrogram.mel_scale.fb + win = mel_spectrogram.spectrogram.window + mel_spectrogram.mel_scale.fb = fb.to(torch.bfloat16).to( + torch.float32) + mel_spectrogram.spectrogram.window = win.to(torch.bfloat16).to( + torch.float32) + + def forward_features( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + t = x.shape[-1] + x = x + self.time_pos_embed[:, :, :, :t] + x = (x + self.freq_pos_embed[:, :, :, :] + ) # Just to support __getitem__ in posembed + x = torch.permute(torch.flatten(x, 2, 3), + (0, 2, 1)) # rearrange(x, "b c f t -> b (f t) c") + for block in self.blocks: + x = block(x, mask) + x = self.norm(x) + return x + + def _to_mask(self, lengths: torch.Tensor, max_length: int) -> torch.Tensor: + batch_size = len(lengths) + idx = torch.arange(max_length, device=lengths.device) + idx = idx.repeat(batch_size).view(batch_size, max_length) + mask = (idx < lengths.unsqueeze(-1)).bool() + return mask + + def forward( + self, + x: torch.Tensor, + x_length: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + x = self.front_end(x) + x = x.to(self.time_pos_embed.dtype) + target_length_in_patches = self.target_length // 4 + x = x.unsqueeze(1) + x = torch.permute(x, (0, 2, 1, 3)) + x = self.init_bn(x) + x = torch.permute(x, (0, 2, 1, 3)) + + x = self.patch_embed(x) + t = x.shape[-1] + + input_splits = x.split(target_length_in_patches, dim=-1) + + if x_length is not None: + assert len(x_length) == len(x), ( + "batchsizes of input x and x_length need to be same") + assert x_length.ndim == 1, "Lengths are of size (B,)" + scaled_lengths = (x_length / (self.hop_length * 4)).long() + mask = self._to_mask(max_length=t, lengths=scaled_lengths) + split_masks = mask.logical_not().split(target_length_in_patches, + dim=-1) + else: + mask = None + split_masks = [None] * len(input_splits) + + outputs = [] + + for split_x, split_mask in zip(input_splits, split_masks): + forward_kwargs = {} + forward_kwargs["mask"] = split_mask + split_x = self.forward_features(split_x, **forward_kwargs) + outputs.append(split_x) + x = torch.cat(outputs, dim=1) + return x, mask + + +class AudioProjectorSubsample(nn.Module): + + def __init__( + self, + in_dim: int, + out_dim: int, + downsample_rate=5, + dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.k = downsample_rate + self.net = nn.Sequential( + ColumnParallelLinear( + input_size=in_dim * self.k, + output_size=out_dim, + quant_config=quant_config, + prefix=f"{prefix}.net.0", + return_bias=False, + ), get_act_fn("gelu"), + RowParallelLinear( + input_size=out_dim, + output_size=out_dim, + quant_config=quant_config, + prefix=f"{prefix}.net.2", + return_bias=False, + )) + + def forward(self, x, mask=None): + batch_size, seq_len, dim = x.shape + num_frames_to_discard = seq_len % self.k + if num_frames_to_discard > 0: + x = x[:, :-num_frames_to_discard, :] + if mask is not None: + mask = mask[:, :-num_frames_to_discard] + if mask is None: + mask = torch.ones(x.shape[:-1], dtype=torch.long, 
device=x.device) + x = x.reshape(batch_size, -1, self.k * + dim) # rearrange(x, "b (s k) d -> b s (k d)", k=self.k) + for layer in self.net: + x = layer(x) + mask = mask.reshape( + batch_size, -1, + self.k) # rearrange(mask, "b (s k) -> b s k", k=self.k) + mask = mask.any(dim=-1).long() + return x, mask + + +# === Audio Inputs === # +class MiDashengLMAudioInputs(TypedDict): + input_values: torch.Tensor + """Shape: `(num_audios, num_sampling_points)`""" + audio_length: torch.Tensor + """Shape: `(num_audios, 1)`""" + + +class MiDashengLMProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_feature_extractor(self): + hf_processor = self.get_hf_processor() + feature_extractor = hf_processor.feature_extractor + return feature_extractor + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"audio": None} + + def get_min_audio_len(self): + return 3200 + + def get_max_audio_len(self): + return 160000 + + +class MiDashengLMDummyInputsBuilder( + BaseDummyInputsBuilder[MiDashengLMProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + + hf_processor = self.info.get_hf_processor() + audio_token = hf_processor.audio_token + audio_bos_token = hf_processor.audio_bos_token + audio_eos_token = hf_processor.audio_eos_token + + single_audio_text = f"{audio_bos_token}{audio_token}{audio_eos_token}" + return single_audio_text * num_audios + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_audios = mm_counts.get("audio", 0) + + return { + "audio": + self._get_dummy_audios(length=self.info.get_max_audio_len(), + num_audios=num_audios) + } + + +class MiDashengLMMultiModalProcessor( + BaseMultiModalProcessor[MiDashengLMProcessingInfo]): + + def _get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.info.get_feature_extractor() + return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, Any], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + audios = mm_data.pop("audios", []) + + # + Padding + min_audio_len = self.info.get_min_audio_len() + processed_audios = [ + np.pad(audio, (0, min_audio_len - audio.shape[-1]), + mode='constant', + constant_values=0) if isinstance(audio, np.ndarray) + and audio.shape[-1] < min_audio_len else audio for audio in audios + ] + + if processed_audios: + mm_data["audio"] = processed_audios + + if not mm_data.get("audio", []): + prompt_ids = self.info.get_tokenizer().encode(prompt) + prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + mm_kwargs = dict(**mm_kwargs, ) + + return super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + input_values=MultiModalFieldConfig.batched("audio"), + audio_length=MultiModalFieldConfig.batched("audio"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + tokenizer = 
self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + audio_token = getattr(processor, "audio_token", "<|AUDIO|>") + audio_token_id = vocab[audio_token] + + out_mm_data = out_mm_kwargs.get_data() + audio_length = out_mm_data.get("audio_length") + if audio_length is None: + audio_output_lengths = [] + else: + audio_length_np = audio_length.cpu().numpy() if isinstance( + audio_length, torch.Tensor) else audio_length + audio_output_lengths = [ + max(1, calculate_mel_frames_dasheng( + int(length))) # at least one frame + for length in audio_length_np + ] + + def get_replacement_midashenglm(item_idx: int): + num_features = audio_output_lengths[item_idx] + audio_tokens = [audio_token_id] * num_features + + return PromptUpdateDetails.select_token_id( + audio_tokens, + embed_token_id=audio_token_id, + ) + + return [ + PromptReplacement( + modality="audio", + target=audio_token, + replacement=get_replacement_midashenglm, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor( + MiDashengLMMultiModalProcessor, + info=MiDashengLMProcessingInfo, + dummy_inputs=MiDashengLMDummyInputsBuilder, +) +class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP): + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("audio"): + return "<|audio_bos|><|AUDIO|><|audio_eos|>" + + raise ValueError("Only audio modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + + # Initialize audio components + self.audio_encoder = DashengAudioTransformer( + config.audio_encoder_config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "audio_encoder"), + ) + self.audio_projector = AudioProjectorSubsample( + in_dim=config.audio_encoder_config.embed_dim, + out_dim=config.text_config.hidden_size, + downsample_rate=config.subsample_factor, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "audio_projector"), + ) + + # Initialize language model (decoder) + self.decoder = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "decoder"), + architectures=["Qwen2ForCausalLM"], + ) + + self.quant_config = quant_config + self.make_empty_intermediate_tensors = ( + self.decoder.make_empty_intermediate_tensors) + + def _validate_and_reshape_mm_tensor(self, mm_input: object, + name: str) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. 
" + f"Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + return mm_input.reshape(-1, *mm_input.shape[2:]) + + if name == "input_values": + max_length = max(tensor.shape[1] for tensor in mm_input) + padded_mm_input = [ + torch.nn.functional.pad(tensor, + (0, max_length - tensor.shape[1])) + if tensor.shape[1] < max_length else tensor + for tensor in mm_input + ] + return torch.concat(padded_mm_input) + + return torch.concat(mm_input) + + def _parse_and_validate_audio_input( + self, **kwargs: object) -> Optional[MiDashengLMAudioInputs]: + input_values = kwargs.pop("input_values", None) + audio_length = kwargs.pop("audio_length", None) + + if input_values is None: + return None + input_values = self._validate_and_reshape_mm_tensor( + input_values, "input_values") + audio_length = self._validate_and_reshape_mm_tensor( + audio_length, "audio_length") + if not isinstance(input_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio input features. " + f"Got type: {type(input_values)}") + + return MiDashengLMAudioInputs( + input_values=input_values, + audio_length=audio_length, + ) + + def _process_audio_input( + self, audio_input: MiDashengLMAudioInputs) -> torch.Tensor: + # Process audio through encoder and projector + input_values = audio_input["input_values"] + audio_length = audio_input["audio_length"] + + encoder_out, encoder_atts = self.audio_encoder(input_values, + audio_length) + audio_embeddings, _ = self.audio_projector(encoder_out, encoder_atts) + audio_embeddings = audio_embeddings.to( + audio_input["input_values"].dtype) + batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape + + audio_length_np = audio_length.cpu().numpy() if isinstance( + audio_length, torch.Tensor) else audio_length + audio_output_lengths = [ + max(1, calculate_mel_frames_dasheng( + int(length))) # at least one frame + for length in audio_length_np + ] + audio_output_lengths = torch.tensor(audio_output_lengths).to( + audio_embeddings.device) + + audio_feature_mask = (torch.arange( + max_audio_tokens, + device=audio_embeddings.device).unsqueeze(0).expand( + batch_size, max_audio_tokens) + < audio_output_lengths.unsqueeze(1)) + + masked_audio_features = audio_embeddings[audio_feature_mask].view( + -1, embed_dim) + + return torch.split(masked_audio_features, + audio_output_lengths.tolist()) + + def get_language_model(self) -> torch.nn.Module: + return self.decoder + + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: + audio_input = self._parse_and_validate_audio_input(**kwargs) + + if audio_input is None: + return [] + return self._process_audio_input(audio_input) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.decoder.get_input_embeddings(input_ids) + if multimodal_embeddings and len(multimodal_embeddings) > 0: + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + multimodal_embeddings, + self.config.audio_token_id, + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + elif inputs_embeds is None: + multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = 
self.get_input_embeddings(input_ids, + multimodal_embeddings) + input_ids = None + + return self.decoder.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.decoder.compute_logits(hidden_states, sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mimo_mtp.py b/vllm/model_executor/models/mimo_mtp.py index 5a2079bf5121a..09194e9f95d0e 100644 --- a/vllm/model_executor/models/mimo_mtp.py +++ b/vllm/model_executor/models/mimo_mtp.py @@ -158,7 +158,8 @@ class MiMoMTP(nn.Module): prefix=maybe_prefix( prefix, "model")) self.lm_head = ParallelLMHead(self.config.vocab_size, - self.config.hidden_size) + self.config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head")) def forward( self, @@ -240,6 +241,15 @@ class MiMoMTP(nn.Module): def map_model_name_to_mtp_param_name(self, name: str) -> str: import regex as re + + # append mtp_start_layer_idx + pattern = r"(model\.mtp_layers\.)(\d+)(\.)" + match = re.match(pattern, name) + if match: + original_num = int(match.group(2)) + new_num = original_num + self.config.num_hidden_layers + name = name.replace(match.group(), f"{match.group(1)}{new_num}.") + # check for early turn name_without_prefix = [ "token_layernorm", "hidden_layernorm", "input_proj", "final_layernorm" @@ -247,10 +257,11 @@ class MiMoMTP(nn.Module): for sub_name in name_without_prefix: if sub_name in name: return name - pattern = r"model.mtp_layers.(\d+)." - group = re.match(pattern, name) - if group is not None: - name = name.replace(group.group(), group.group() + "mtp_block.") + # add mtp_block + pattern = r"(model\.mtp_layers\.\d+\.)" + match = re.match(pattern, name) + if match: + name = name.replace(match.group(), match.group() + "mtp_block.") return name def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 5632f8c8cc4fb..240c23ea2b25d 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -39,7 +39,7 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, @@ -136,13 +136,18 @@ class MiniCPMMoE(nn.Module): hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = fused_moe(hidden_states, - self.ws, - self.w2s, - router_logits, - self.top_k, - renormalize=True, - inplace=True) + + topk_weights, topk_ids, _ = fused_topk(hidden_states, + router_logits, + self.top_k, + renormalize=True) + + final_hidden_states = fused_experts(hidden_states, + self.ws, + self.w2s, + topk_weights, + topk_ids, + inplace=True) if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( @@ -547,6 +552,7 @@ class MiniCPMForCausalLM(nn.Module, 
SupportsLoRA, SupportsPP): # compatibility if not lora_config else lora_config.lora_vocab_padding_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) diff --git a/vllm/model_executor/models/minicpm_eagle.py b/vllm/model_executor/models/minicpm_eagle.py index 06c2eb4e80afb..848a97b8bb2a0 100644 --- a/vllm/model_executor/models/minicpm_eagle.py +++ b/vllm/model_executor/models/minicpm_eagle.py @@ -338,6 +338,7 @@ class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): # compatibility if not lora_config else lora_config.lora_vocab_padding_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 04176c5589ed6..9b2d84e32151a 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -1117,7 +1117,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): def _process_multimodal_inputs(self, modalities: dict): # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). + # tensor corresponding to a multimodal data item (image or video). multimodal_embeddings: tuple[torch.Tensor, ...] = () # NOTE: It is important to iterate over the keys in this dictionary diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index ef1fe86c5b5c0..6ce883be0a83c 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -702,6 +702,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): self.config.hidden_size, org_num_embeddings=self.config.vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE, + prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index 08948960b275c..d15776a39362d 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -578,10 +578,12 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, Args: input_ids: Flattened (concatenated) input_ids corresponding to a batch. - pixel_values: The pixels in each input image. + positions: Position indices for the input tokens. + intermediate_tensors: Intermediate tensors from prior forward pass. + inputs_embeds: Optional tensor of input embeddings. Info: - [Mistral3ImagePixelInputs][] + [`Mistral3ImagePixelInputs`][vllm.model_executor.models.mistral3.Mistral3ImagePixelInputs] """ if intermediate_tensors is not None: inputs_embeds = None diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 52fcbbfc58be6..8b3474d809532 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -23,7 +23,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Mixtral model.""" -from collections.abc import Iterable +import typing +from collections.abc import Callable, Iterable from itertools import islice from typing import Optional, Union @@ -33,8 +34,9 @@ from transformers import MixtralConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config +from vllm.distributed import (get_ep_group, get_pp_group, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, @@ -50,8 +52,8 @@ from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, +from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -74,10 +76,32 @@ class MixtralMoE(nn.Module): quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, dp_size: Optional[int] = None, - prefix: str = ""): + prefix: str = "", + enable_eplb: bool = False): super().__init__() self.hidden_size = hidden_size + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + + # Expert Parallelism Load balancing settings. + vllm_config = get_current_vllm_config() + parallel_config = vllm_config.parallel_config + self.enable_eplb = enable_eplb + + self.n_routed_experts = num_experts + self.n_logical_experts = num_experts + self.n_redundant_experts = ( + parallel_config.eplb_config.num_redundant_experts) + self.n_physical_experts = (self.n_logical_experts + + self.n_redundant_experts) + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + self.physical_expert_start = (self.ep_rank * + self.n_local_physical_experts) + self.physical_expert_end = (self.physical_expert_start + + self.n_local_physical_experts) + # Gate always runs at half / full precision for now. self.gate = ReplicatedLinear(hidden_size, @@ -97,7 +121,9 @@ class MixtralMoE(nn.Module): quant_config=quant_config, tp_size=tp_size, dp_size=dp_size, - prefix=f"{prefix}.experts") + prefix=f"{prefix}.experts", + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. 
@@ -200,6 +226,7 @@ class MixtralDecoderLayer(nn.Module): cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + enable_eplb: bool = False, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -221,7 +248,8 @@ class MixtralDecoderLayer(nn.Module): hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, quant_config=quant_config, - prefix=f"{prefix}.block_sparse_moe") + prefix=f"{prefix}.block_sparse_moe", + enable_eplb=enable_eplb) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, @@ -262,6 +290,7 @@ class MixtralModel(nn.Module): cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config + parallel_config = vllm_config.parallel_config self.config = config self.quant_config = quant_config @@ -276,10 +305,18 @@ class MixtralModel(nn.Module): org_num_embeddings=config.vocab_size, ) + self.enable_eplb = parallel_config.enable_eplb + self.num_redundant_experts = ( + parallel_config.eplb_config.num_redundant_experts) + self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: MixtralDecoderLayer( - config, cache_config, quant_config=quant_config, prefix=prefix + config, + cache_config, + quant_config=quant_config, + prefix=prefix, + enable_eplb=self.enable_eplb, ), prefix=f"{prefix}.layers") @@ -325,7 +362,8 @@ class MixtralModel(nn.Module): ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", ckpt_up_proj_name="w3", - num_experts=self.config.num_local_experts) + num_experts=self.config.num_local_experts, + num_redundant_experts=self.num_redundant_experts) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: @@ -373,26 +411,40 @@ class MixtralModel(nn.Module): weight_loader(param, loaded_weight, shard_id) break else: + is_expert_weight = False for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: continue - name = name.replace(weight_name, param_name) + + is_expert_weight = True + name_mapped = name.replace(weight_name, param_name) + # Skip layers on other devices. - if is_pp_missing_parameter(name, self): + if is_pp_missing_parameter(name_mapped, self): continue - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + + if ((name_mapped.endswith(".bias") + or name_mapped.endswith("_bias")) + and name_mapped not in params_dict): continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) - break + + param = params_dict[name_mapped] + weight_loader = typing.cast(Callable[..., bool], + param.weight_loader) + success = weight_loader(param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True) + if success: + name = name_mapped + break else: + if is_expert_weight: + continue # Skip loading extra bias for GPTQ models. 
if ((name.endswith(".bias") or name.endswith("_bias")) and name not in params_dict): @@ -413,7 +465,8 @@ class MixtralModel(nn.Module): return loaded_params -class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): +class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, + MixtureOfExperts): fall_back_to_pt_during_load = False packed_modules_mapping = { @@ -454,6 +507,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): # compatibility if not lora_config else lora_config.lora_vocab_padding_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight @@ -462,6 +516,67 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + self.expert_weights = [] + self.moe_layers: list[FusedMoE] = [] + example_moe = None + + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + assert isinstance(layer, MixtralDecoderLayer) + if hasattr(layer, "block_sparse_moe") and isinstance( + layer.block_sparse_moe, MixtralMoE): + example_moe = layer.block_sparse_moe + self.moe_layers.append(layer.block_sparse_moe.experts) + + self.num_moe_layers = len(self.moe_layers) + + if example_moe is None: + raise RuntimeError("No MixtralMoE layer found in model.layers.") + + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_redundant_experts = example_moe.n_redundant_experts + self.num_expert_groups = 1 + self.num_shared_experts = 0 + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. 
+ self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = (num_physical_experts - + self.num_logical_experts) + for layer in self.model.layers: + if hasattr(layer, "block_sparse_moe") and isinstance( + layer.block_sparse_moe, MixtralMoE): + moe = layer.block_sparse_moe + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py deleted file mode 100644 index 692267b4d7271..0000000000000 --- a/vllm/model_executor/models/mixtral_quant.py +++ /dev/null @@ -1,454 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Adapted from -# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -# Copyright 2023 The vLLM team. -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Inference-only Mixtral model.""" -from collections.abc import Iterable -from itertools import islice -from typing import Optional, Union - -import numpy as np -import torch -import torch.nn.functional as F -from torch import nn -from transformers import MixtralConfig - -from vllm.attention import Attention -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors - -from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) - - -class MixtralMLP(nn.Module): - - def __init__( - self, - num_experts: int, - hidden_size: int, - intermediate_size: int, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - super().__init__() - self.num_experts = num_experts - self.ffn_dim = intermediate_size - self.hidden_dim = hidden_size - - self.w1 = ReplicatedLinear(self.hidden_dim, - self.ffn_dim, - bias=False, - quant_config=quant_config) - self.w2 = ReplicatedLinear(self.ffn_dim, - self.hidden_dim, - bias=False, - quant_config=quant_config) - self.w3 = ReplicatedLinear(self.hidden_dim, - self.ffn_dim, - bias=False, - quant_config=quant_config) - - # TODO: Use vllm's SiluAndMul - self.act_fn = nn.SiLU() - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - w1_out, _ = self.w1(hidden_states) - w1_out = self.act_fn(w1_out) - w3_out, _ = self.w3(hidden_states) - current_hidden_states = w1_out * w3_out - current_hidden_states, _ = self.w2(current_hidden_states) - return current_hidden_states - - -class MixtralMoE(nn.Module): - - def __init__( - self, - config: MixtralConfig, - quant_config: Optional[QuantizationConfig] = None, - ): - super().__init__() - self.config = config - self.rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() - self.num_total_experts = config.num_local_experts - self.top_k = config.num_experts_per_tok - if self.tp_size > self.num_total_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {self.num_total_experts}.") - # Split experts equally between ranks - self.expert_indices = np.array_split(range(self.num_total_experts), - self.tp_size)[self.rank].tolist() - if not self.expert_indices: - raise ValueError( - f"Rank {self.rank} has no experts assigned to it.") - - self.experts = nn.ModuleList([ - MixtralMLP(self.num_total_experts, - config.hidden_size, - config.intermediate_size, - quant_config=quant_config) - if idx in self.expert_indices else None - for idx in range(self.num_total_experts) - ]) - self.gate = ReplicatedLinear(config.hidden_size, - self.num_total_experts, - bias=False, - 
quant_config=None) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - num_tokens, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, - self.top_k, - dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - - final_hidden_states = None - for expert_idx in self.expert_indices: - expert_layer = self.experts[expert_idx] - expert_mask = (selected_experts == expert_idx) - expert_weights = (routing_weights * expert_mask).sum(dim=-1, - keepdim=True) - - current_hidden_states = expert_layer(hidden_states).mul_( - expert_weights) - if final_hidden_states is None: - final_hidden_states = current_hidden_states - else: - final_hidden_states.add_(current_hidden_states) - - return tensor_model_parallel_all_reduce(final_hidden_states).view( - num_tokens, hidden_dim) - - -class MixtralAttention(nn.Module): - - def __init__( - self, - config: MixtralConfig, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - quant_config: Optional[QuantizationConfig] = None, - cache_config: Optional[CacheConfig] = None, - prefix: str = "", - ) -> None: - super().__init__() - self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = num_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = num_kv_heads - if self.total_num_kv_heads >= tp_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. 
- assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - # MixtralConfig has an optional head_dim argument - self.head_dim = getattr(config, "head_dim", None) - if self.head_dim is None: - self.head_dim = self.hidden_size // self.total_num_heads - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 - self.rope_theta = rope_theta - - self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=False, - quant_config=quant_config, - ) - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - quant_config=quant_config, - ) - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position, - base=int(self.rope_theta), - is_neox_style=True, - ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v) - output, _ = self.o_proj(attn_output) - return output - - -class MixtralDecoderLayer(nn.Module): - - def __init__( - self, - config: MixtralConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__() - self.hidden_size = config.hidden_size - # Requires transformers > 4.32.0 - rope_theta = getattr(config, "rope_theta", 10000) - self.self_attn = MixtralAttention( - config=config, - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - max_position=config.max_position_embeddings, - num_kv_heads=config.num_key_value_heads, - rope_theta=rope_theta, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - self.block_sparse_moe = MixtralMoE(config=config, - quant_config=quant_config) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - ) -> torch.Tensor: - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - ) - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - hidden_states = self.block_sparse_moe(hidden_states) - return hidden_states, residual - - -class MixtralModel(nn.Module): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - - self.vocab_size = config.vocab_size - - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - ) - self.start_layer, self.end_layer, self.layers = make_layers( - 
config.num_hidden_layers, - lambda prefix: MixtralDecoderLayer( - config, cache_config, quant_config=quant_config, prefix=prefix - ), - prefix=f"{prefix}.layers") - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - for layer in islice(self.layers, self.start_layer, self.end_layer): - hidden_states, residual = layer(positions, hidden_states, residual) - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if name.endswith("scale"): - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Skip experts that are not assigned to this worker. - if ("block_sparse_moe.experts." 
in name - and name not in params_dict): - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -class MixtralForCausalLM(nn.Module, SupportsPP): - fall_back_to_pt_during_load = False - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - self.quant_config = quant_config - self.model = MixtralModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) - if self.config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor(config.vocab_size) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) - return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py deleted file mode 100644 index f441287a4d089..0000000000000 --- a/vllm/model_executor/models/mllama.py +++ /dev/null @@ -1,1703 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Copyright 2024 the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""PyTorch Mllama model.""" -import math -from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Literal, Optional, Union - -import numpy as np -import torch -import torch.nn.functional as F -import transformers.models.mllama.configuration_mllama as config_mllama -from PIL.Image import Image -from torch import nn -from transformers import BatchFeature, MllamaConfig -from transformers.modeling_outputs import (BaseModelOutput, - CausalLMOutputWithPast) -from transformers.models.mllama.image_processing_mllama import ( - get_optimal_tiled_canvas) -from transformers.models.mllama.processing_mllama import ( - MllamaProcessor, get_cross_attention_token_mask) - -import vllm.distributed.parallel_state as ps -from vllm.attention import Attention, AttentionMetadata, AttentionType -from vllm.attention.ops.paged_attn import PagedAttention -from vllm.attention.selector import _Backend -from vllm.config import VllmConfig -from vllm.distributed import get_pp_group, get_tp_group -from vllm.forward_context import get_forward_context -from vllm.logger import init_logger -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVCrossParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, - MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, - MultiModalDataItems) -from vllm.multimodal.processing import (BaseProcessingInfo, - EncDecMultiModalProcessor, - PromptReplacement, PromptUpdate) -from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.utils.tensor_schema import TensorSchema, TensorShape - -from .clip import CLIPMLP -from .interfaces import SupportsMultiModal, SupportsV0Only -from .llama import LlamaDecoderLayer, LlamaMLP -from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix - -logger = init_logger(__name__) - - -class MllamaImagePixelInputs(TensorSchema): - """ - Dimensions: - - batch_size: Batch size - - max_num_image: Max number of images - - max_num_chunk: Max number of chunks - - max_num_tiles: Max number of tiles per image - - num_channel: Number of channels - - height: Height - - width: Width - """ - - type: Literal["pixel_values"] = "pixel_values" - - data: Annotated[torch.Tensor, - TensorShape("batch_size", "max_num_image", "max_num_chunk", - "num_channel", "height", "width")] - - aspect_ratio_ids: Annotated[torch.Tensor, - TensorShape("batch_size", "max_num_image")] - - aspect_ratio_mask: Annotated[ - torch.Tensor, - TensorShape("batch_size", "max_num_image", "max_num_tiles")] - - -# TODO: support LlamaImageEmbeddingInputs - - -def calc_token_per_chunk(image_size: int) -> int: - assert image_size % 14 == 0, "chunk size should be multiple of 14" - token_per_chunk = (image_size // 14)**2 + 1 - return token_per_chunk - - -class 
MllamaProcessingInfo(BaseProcessingInfo): - - def get_hf_config(self) -> MllamaConfig: - return self.ctx.get_hf_config(MllamaConfig) - - def get_hf_processor(self, **kwargs: object) -> MllamaProcessor: - return self.ctx.get_hf_processor(MllamaProcessor, **kwargs) - - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": None} - - def get_token_per_chunk_from_config(self) -> int: - image_size = self.get_hf_config().vision_config.image_size - return calc_token_per_chunk(image_size) - - def get_num_tiles_per_image(self, image_height: int, - image_width: int) -> int: - vision_config = self.get_hf_config().vision_config - max_num_tiles = vision_config.max_num_tiles - image_size = vision_config.image_size - tiled_height, tiled_width = get_optimal_tiled_canvas( - image_height, - image_width, - max_num_tiles, - tile_size=image_size, - ) - num_tiles_height = tiled_height // image_size - num_tiles_width = tiled_width // image_size - return num_tiles_height * num_tiles_width - - def get_image_size_with_most_features(self) -> ImageSize: - vision_config = self.get_hf_config().vision_config - image_size = vision_config.image_size - max_num_tiles = vision_config.max_num_tiles - # Result in the max possible feature size (h:w = 16:1) - return ImageSize(height=max_num_tiles * image_size, width=image_size) - - -class MllamaDummyInputsBuilder(BaseDummyInputsBuilder[MllamaProcessingInfo]): - - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: - num_images = mm_counts.get("image", 0) - - processor = self.info.get_hf_processor() - image_token = processor.image_token - - return image_token * num_images - - def get_dummy_mm_data( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> MultiModalDataDict: - num_images = mm_counts.get("image", 0) - - target_width, target_height = \ - self.info.get_image_size_with_most_features() - - return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) - } - - -class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] - ): - - def apply( - self, - prompt: Union[str, list[int]], - mm_data: MultiModalDataDict, - hf_processor_mm_kwargs: Mapping[str, object], - tokenization_kwargs: Optional[Mapping[str, object]] = None, - mm_hash_overrides: Optional[dict[str, list[str]]] = None, - ) -> MultiModalEncDecInputs: - mm_inputs = super().apply(prompt, - mm_data, - hf_processor_mm_kwargs, - tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides) - - image_token_id = self.info.get_hf_config().image_token_index - # Check that the number of image tokens in the decoder prompt matches - # the number of images provided in mm_data - num_image_tokens = mm_inputs['prompt_token_ids'].count(image_token_id) - image_data = mm_data.get("image", []) - num_images = 1 if isinstance(image_data, Image) else len(image_data) - if num_image_tokens != num_images: - raise ValueError( - f"The number of image tokens ({num_image_tokens}) must be" - f" the same as the number of images ({num_images})") - - # Given prompt: <IMG0> P0 P1 <IMG1> <IMG2> P3 P4 D5 D6...., (P-prefill, D-decode) # noqa: E501 - # P0 & P1 do cross attention with placeholder of <IMG0> - # P3 P4 D5 D6 do cross attention with placeholder of <IMG1> and <IMG2> - # Example input to encoder and decoder: - # { - # 'encoder': { - # 'type': 'token', - # 'prompt_token_ids': [128256, 128256, ..., 128256], - # 'prompt': '<|image|><|image|>...<|image|>', - # 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB 
size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501 - # }, - # 'decoder': { - # 'type': 'token', - # 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501 - # 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501 - # 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501 - # }, - # } - - if mm_data: - hf_processor = self.info.get_hf_processor() - image_token: str = hf_processor.image_token - - # Since only the last group of consecutive images - # are attended by the decoded tokens, we only need to - # get the number of tokens for those images. - token_per_chunk = self.info.get_token_per_chunk_from_config() - num_decode_images = self._get_num_image_in_last_group( - mm_inputs["prompt_token_ids"]) - num_encode_images = num_images - num_decode_images - - # Set encoder prompt length based on the number of tiles. - # This tells the block manager to allocate correct number - # of slots for encoder tokens. - num_tiles = mm_inputs["mm_kwargs"].get_data()["num_tiles"] - decode_tiles = num_tiles[num_encode_images:num_images].sum().item() - num_tokens = decode_tiles * token_per_chunk - mm_inputs["encoder_prompt_token_ids"] = [image_token_id - ] * num_tokens - mm_inputs["encoder_prompt"] = image_token * num_tokens - - return mm_inputs - - def _get_num_image_in_last_group(self, prompt_token_ids: list[int]) -> int: - num_images = 0 - for token_id in prompt_token_ids[::-1]: - if token_id == self.info.get_hf_config().image_token_index: - num_images += 1 - elif num_images > 0: - break - return num_images - - def _call_hf_processor( - self, - prompt: str, - mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object], - tok_kwargs: Mapping[str, object], - ) -> BatchFeature: - tokenizer = self.info.get_tokenizer() - if mm_data: - num_tiles = [ - self.info.get_num_tiles_per_image(img.height, img.width) - for img in mm_data["images"] - ] - processed_outputs = super()._call_hf_processor( - prompt, mm_data, mm_kwargs, tok_kwargs) - processed_outputs["num_tiles"] = torch.tensor(num_tiles) - for k in ('pixel_values', 'aspect_ratio_ids', "aspect_ratio_mask"): - processed_outputs[k] = processed_outputs[k].squeeze(0) - - processed_token_ids = processed_outputs.pop("input_ids") - start_idx, end_idx = 0, processed_token_ids.size(1) - processed_prompt_text = tokenizer.decode(processed_token_ids[0]) - - hf_processor = self.info.get_hf_processor() - bos_token = hf_processor.bos_token - # Remove the bos_token from the start of prompt, - # because we all know there would be image_token. - if processed_prompt_text.startswith(bos_token): - start_idx += 1 - # Remove the bos_token from the end of prompt, - # because text is empty in this case. 
- if processed_prompt_text.endswith(bos_token): - end_idx -= 1 - processed_outputs[ - "input_ids"] = processed_token_ids[:, start_idx:end_idx] - else: - processed_outputs = tokenizer(prompt, - add_special_tokens=False, - return_tensors="pt") - return processed_outputs - - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - return dict( - pixel_values=MultiModalFieldConfig.batched("image"), - aspect_ratio_ids=MultiModalFieldConfig.batched("image"), - aspect_ratio_mask=MultiModalFieldConfig.batched("image"), - num_tiles=MultiModalFieldConfig.batched("image"), - ) - - def create_encoder_prompt( - self, - prompt: Union[str, list[int]], - mm_data: MultiModalDataDict, - ) -> Union[str, list[int]]: - data = mm_data.get("image", []) - num_images = 1 if isinstance(data, Image) else len(data) - image_token_id = self.info.get_hf_config().image_token_index - return [image_token_id] * num_images - - def _get_prompt_updates( - self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargsItems, - ) -> Sequence[PromptUpdate]: - token_per_chunk = self.info.get_token_per_chunk_from_config() - image_token_id = self.info.get_hf_config().image_token_index - - def get_replacement_mllama(item_idx): - images = mm_items.get_items("image", ImageProcessorItems) - image_size = images.get_image_size(item_idx) - num_tile = self.info.get_num_tiles_per_image( - image_height=image_size.height, - image_width=image_size.width, - ) - num_tokens = num_tile * token_per_chunk - return [image_token_id] * num_tokens - - return [ - PromptReplacement( - modality="image", - target=[image_token_id], - replacement=get_replacement_mllama, - ) - ] - - -def _prepare_aspect_ratio_attention_mask( - aspect_ratio_mask: torch.Tensor, - num_patches: int, - target_length: int, - dtype: torch.dtype, -) -> torch.Tensor: - # Expand aspect ratio mask to target_length - batch_size, max_num_tiles = aspect_ratio_mask.shape - attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, - 1).to(dtype) - attention_mask = attention_mask.repeat(1, 1, target_length, 1) - - # Mask padding patches - pad_patches = target_length - num_patches - attention_mask[:, :, -pad_patches:] = 0 - - # Invert the mask (0 -> 1, 1 -> 0) - attention_mask = 1 - attention_mask - - # Reshape to 2D and create 4D attention mask - # (batch_size, 1, max_num_tiles*target_length, max_num_tiles*target_length) - attention_mask = attention_mask.reshape(batch_size, - max_num_tiles * target_length, 1) - attention_mask = attention_mask @ attention_mask.transpose( - -1, -2) * torch.finfo(dtype).min - attention_mask = attention_mask.unsqueeze(1) - - return attention_mask - - -class ColumnParallelConv2dPatch(torch.nn.Module): - """Conv2D Patching layer with model parallelism. - Column parallel over unfolded input. - Arguments: - in_channels: Input channels. - out_channels: Output channels. - kernel_size: Size of convolution kernel. - stride (default 1): Stride for convolution. - bias (default False): Use bias in Conv2d. 
- Input: (bsz, in_channels, width, height) - Output: (bsz, num_tokens, out_channels) - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, tuple[int, int]], - stride: Union[int, tuple[int, int]], - bias: bool = False, - ) -> None: - super().__init__() - if isinstance(kernel_size, int): - kernel_size = (kernel_size, kernel_size) - self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride) - self._linear = ColumnParallelLinear( - in_channels * kernel_size[0] * kernel_size[1], - out_channels, - bias=bias, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self._unfold(x) - x = x.permute(0, 2, 1) - x, _ = self._linear(x) - return x - - -class MllamaPrecomputedAspectRatioEmbedding(nn.Module): - - def __init__(self, - config: config_mllama.MllamaVisionConfig, - is_gated: bool = True): - super().__init__() - self.max_num_tiles = config.max_num_tiles - self.hidden_size = config.hidden_size - self.max_aspect_ratio_id = config.max_aspect_ratio_id - self.is_gated = is_gated - - self.embedding = nn.Embedding(self.max_aspect_ratio_id + 1, - self.max_num_tiles * self.hidden_size) - if is_gated: - self.gate = nn.Parameter(torch.zeros(1)) - - def forward(self, hidden_state: torch.Tensor, - aspect_ratio_ids: torch.Tensor) -> torch.Tensor: - embeddings = self.embedding(aspect_ratio_ids) - embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, - self.hidden_size) - - if self.is_gated: - embeddings = embeddings * self.gate.tanh() - - hidden_state = hidden_state + embeddings - return hidden_state - - -class MllamaPrecomputedPositionEmbedding(nn.Module): - - def __init__(self, config: config_mllama.MllamaVisionConfig): - super().__init__() - self.max_num_tiles = config.max_num_tiles - self.max_aspect_ratio_id = config.max_aspect_ratio_id - self.num_patches = (config.image_size // config.patch_size)**2 + 1 - self.hidden_size = config.hidden_size - self.scale = config.hidden_size**-0.5 - - self.gate = nn.Parameter(torch.zeros(1)) - - # position embedding - position_embedding = torch.randn(self.num_patches, self.hidden_size) - self.embedding = nn.Parameter(self.scale * position_embedding) - - # tile position embedding - self.tile_embedding = nn.Embedding( - self.max_aspect_ratio_id + 1, - self.max_num_tiles * self.num_patches * self.hidden_size) - - def forward(self, hidden_state: torch.Tensor, - aspect_ratio_ids: torch.Tensor) -> torch.Tensor: - # position embeddings - gated_position_embedding = (1 - self.gate.tanh()) * self.embedding - hidden_state = hidden_state + gated_position_embedding.view( - 1, 1, self.num_patches, self.hidden_size) - - # precomputed tile position embeddings - tile_position_embedding = self.tile_embedding(aspect_ratio_ids) - batch_size = hidden_state.shape[0] - tile_position_embedding = tile_position_embedding.reshape( - batch_size, self.max_num_tiles, self.num_patches, self.hidden_size) - gated_tile_position_embedding = self.gate.tanh( - ) * tile_position_embedding - hidden_state = hidden_state + gated_tile_position_embedding - - return hidden_state - - -# TODO: support other attention backends for attention in vision model -class MllamaVisionSdpaAttention(nn.Module): - - def __init__(self, - config: config_mllama.MllamaVisionConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): - super().__init__() - - tensor_parallel_size = get_tp_group().world_size - self.embed_dim = config.hidden_size - self.num_heads = config.attention_heads - self.head_dim = config.hidden_size // 
config.attention_heads - self.num_local_heads = self.num_heads // tensor_parallel_size - self.q_size = self.num_local_heads * self.head_dim - self.kv_size = self.num_local_heads * self.head_dim - - self.qkv_proj = QKVParallelLinear( - self.embed_dim, - self.head_dim, - self.num_heads, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - self.o_proj = RowParallelLinear( - self.num_heads * self.head_dim, - self.embed_dim, - bias=False, - input_is_parallel=True, - quant_config=quant_config, - prefix=f"{prefix}.o_proj", - ) - - def forward( - self, - hidden_state: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_state) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q = q.view(q.shape[0], q.shape[1], self.num_local_heads, - self.head_dim).transpose(1, 2) - k = k.view(k.shape[0], k.shape[1], self.num_local_heads, - self.head_dim).transpose(1, 2) - v = v.view(v.shape[0], v.shape[1], self.num_local_heads, - self.head_dim).transpose(1, 2) - - # TODO: remove padding in image encoder - attn_output = F.scaled_dot_product_attention(q, - k, - v, - attn_mask=attention_mask, - dropout_p=0.0) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(attn_output.shape[0], - attn_output.shape[1], -1) - output, _ = self.o_proj(attn_output) - return output - - -class MllamaVisionEncoderLayer(nn.Module): - - def __init__( - self, - config: config_mllama.MllamaVisionConfig, - quant_config: Optional[QuantizationConfig], - prefix: str = "", - is_gated: bool = False, - ) -> None: - super().__init__() - - self.hidden_size = config.hidden_size - self.num_attention_heads = config.attention_heads - self.is_gated = is_gated - self.intermediate_size = config.intermediate_size - - self.self_attn = MllamaVisionSdpaAttention( - config, quant_config=quant_config, prefix=f"{prefix}.self_attn") - self.mlp = CLIPMLP(config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") - - self.input_layernorm = nn.LayerNorm(self.hidden_size, - eps=config.norm_eps) - self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, - eps=config.norm_eps) - - # there used to be an if else here, no code path - if is_gated: - self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4) - self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4) - - def forward( - self, - hidden_state: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - ): - # Self Attention - residual = hidden_state - hidden_state = self.input_layernorm(hidden_state) - hidden_state = self.self_attn(hidden_state, - attention_mask=attention_mask) - gate_attn = 1 if not self.is_gated else self.gate_attn.tanh() - hidden_state = residual + gate_attn * hidden_state - - # Feed forward - residual = hidden_state - hidden_state = self.post_attention_layernorm(hidden_state) - hidden_state = self.mlp(hidden_state) - gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh() - hidden_state = residual + gate_ffn * hidden_state - - return hidden_state - - -class MllamaVisionEncoder(nn.Module): - - def __init__( - self, - config: config_mllama.MllamaVisionConfig, - quant_config: Optional[QuantizationConfig], - num_layers: int = 32, - is_gated: bool = False, - output_hidden_states=None, - prefix: str = "", - ) -> None: - super().__init__() - self.config = config - self.layers = nn.ModuleList([ - MllamaVisionEncoderLayer(config, - quant_config=quant_config, - is_gated=is_gated, - prefix=f"{prefix}.layers.{layer_idx}") - for 
layer_idx in range(num_layers) - ]) - self.output_hidden_states = output_hidden_states or [] - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - ) -> Union[BaseModelOutput]: - encoder_states = () - - for i, encoder_layer in enumerate(self.layers): - if i in self.output_hidden_states: - encoder_states = encoder_states + (hidden_states, ) - hidden_states = encoder_layer( - hidden_states, - attention_mask, - ) - - if len(self.layers) - 1 in self.output_hidden_states: - encoder_states = encoder_states + (hidden_states, ) - - return hidden_states, encoder_states - - -class MllamaVisionModel(nn.Module): - - def __init__( - self, - config: config_mllama.MllamaVisionConfig, - quant_config: Optional[QuantizationConfig], - prefix: str = "", - ) -> None: - super().__init__() - - self.image_size = config.image_size - self.patch_size = config.patch_size - self.max_num_tiles = config.max_num_tiles - self.hidden_size = config.hidden_size - self.in_channels = config.num_channels - self.intermediate_layers_indices = config.intermediate_layers_indices - - self.num_patches = (self.image_size // self.patch_size)**2 + 1 - self.scale = config.hidden_size**-0.5 - - self.patch_embedding = ColumnParallelConv2dPatch( - in_channels=config.num_channels, - out_channels=self.hidden_size, - kernel_size=self.patch_size, - stride=self.patch_size, - bias=False, - ) - - self.class_embedding = nn.Parameter(self.scale * - torch.randn(self.hidden_size)) - self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding( - config) - - self.pre_tile_positional_embedding = \ - MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True) - self.post_tile_positional_embedding = \ - MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True) - - # layer norms - self.layernorm_pre = nn.LayerNorm(self.hidden_size) - self.layernorm_post = nn.LayerNorm(self.hidden_size) - - # encoders - self.transformer = MllamaVisionEncoder( - config, - quant_config, - config.num_hidden_layers, - is_gated=False, - output_hidden_states=config.intermediate_layers_indices, - prefix=f"{prefix}.transformer", - ) - self.global_transformer = MllamaVisionEncoder( - config, - quant_config, - config.num_global_layers, - is_gated=True, - prefix=f"{prefix}.global_transformer", - ) - - def apply_class_embedding(self, - hidden_state: torch.Tensor) -> torch.Tensor: - batch_size, _, hidden_size = hidden_state.shape - class_embedding = self.class_embedding.expand(batch_size, 1, - hidden_size) - hidden_state = torch.cat([class_embedding, hidden_state], dim=1) - return hidden_state - - def forward(self, pixel_values: torch.Tensor, - aspect_ratio_ids: torch.Tensor, - aspect_ratio_mask: torch.Tensor) -> torch.Tensor: - batch_size, num_concurrent_media, num_tiles, num_channels, \ - height, width = pixel_values.shape - - pixel_values = pixel_values.reshape( - batch_size * num_concurrent_media * num_tiles, num_channels, - height, width) - aspect_ratio_ids = aspect_ratio_ids.reshape( - batch_size * num_concurrent_media, -1) - - # patch embedding - patch_embeds = self.patch_embedding( - pixel_values.to(self.layernorm_pre.weight.dtype)) - hidden_state = patch_embeds - hidden_state = ps.get_tp_group().all_gather(hidden_state) - - # tile embeddings - _, num_patches, dim = hidden_state.shape - hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, - num_tiles, -1, dim) - hidden_state = self.pre_tile_positional_embedding( - hidden_state, aspect_ratio_ids) - - # apply cls token - hidden_state = 
hidden_state.reshape( - batch_size * num_concurrent_media * num_tiles, num_patches, dim) - hidden_state = self.apply_class_embedding(hidden_state) - num_patches += 1 - - # apply position embeddings - hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, - num_tiles, num_patches, dim) - hidden_state = self.gated_positional_embedding(hidden_state, - aspect_ratio_ids) - - # apply encoder - hidden_state = self.layernorm_pre(hidden_state) - - # Compute the number of tokens to pad - num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 - # Compute padding tuple for pad function - padding = ( - 0, 0, 0, num_padding_patches - ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) - # Pad the tensor - hidden_state = F.pad(hidden_state, padding, mode="constant", value=0) - slice_index = -num_padding_patches if num_padding_patches > 0 else None - - attention_mask = aspect_ratio_mask.reshape( - batch_size * num_concurrent_media, -1) - attention_mask = _prepare_aspect_ratio_attention_mask( - aspect_ratio_mask=attention_mask, - num_patches=self.num_patches, - target_length=hidden_state.shape[2], - dtype=self.layernorm_pre.weight.dtype, - ) - - hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, - dim) - output = self.transformer( - hidden_state, - attention_mask=attention_mask, - ) - hidden_state, intermediate_hidden_states = output[0], output[1] - intermediate_hidden_states = torch.stack(intermediate_hidden_states, - dim=-1) - - # apply global encoder - hidden_state = self.layernorm_post(hidden_state) - hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, - num_tiles, - num_patches + num_padding_patches, - dim) - hidden_state = self.post_tile_positional_embedding( - hidden_state, aspect_ratio_ids) - hidden_state = hidden_state.reshape( - batch_size * num_concurrent_media, - num_tiles * (num_patches + num_padding_patches), dim) - hidden_state = self.global_transformer( - hidden_state, attention_mask=attention_mask)[0] - hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, - num_tiles, - num_patches + num_padding_patches, - dim) - hidden_state = hidden_state[:, :, :slice_index] - - # adding intermediate layer outputs - hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, - num_tiles, num_patches, dim) - intermediate_hidden_states = intermediate_hidden_states.reshape( - batch_size * num_concurrent_media, num_tiles, - num_patches + num_padding_patches, -1) - intermediate_hidden_states = intermediate_hidden_states[:, :, : - slice_index] - intermediate_hidden_states = intermediate_hidden_states.reshape( - batch_size, num_concurrent_media, num_tiles, num_patches, -1) - hidden_state = torch.cat([hidden_state, intermediate_hidden_states], - dim=-1) - return hidden_state - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - ] - params_dict = dict(self.named_parameters()) - updated_params: set[str] = set() - for name, loaded_weight in weights: - if 'patch_embedding._linear.weight' in name: - loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1) - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - param = params_dict[name] - updated_params.add(name) - weight_loader = param.weight_loader - 
weight_loader(param, loaded_weight, shard_id) - break - else: - param = params_dict.pop(name) - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - updated_params.add(name) - return updated_params - - -class MllamaTextRMSNorm(nn.Module): - - def __init__(self, hidden_size, eps=1e-6): - """ - MllamaTextRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + - self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -class MllamaTextCrossAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - config: Optional[config_mllama.MllamaTextConfig] = None, - layer_idx: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.config = config - self.pipeline_parallel_rank = get_pp_group().rank_in_group - self.tensor_parallel_size = get_tp_group().world_size - self.num_heads = config.num_attention_heads - self.num_key_value_heads = config.num_key_value_heads - - self.num_local_heads = self.num_heads // self.tensor_parallel_size - self.num_local_key_value_heads = \ - self.num_key_value_heads // self.tensor_parallel_size - self.hidden_size = config.hidden_size - self.head_dim = config.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - - self.layer_idx = layer_idx - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.q_local_size = self.num_local_heads * self.head_dim - self.kv_local_size = self.num_local_key_value_heads * self.head_dim - - self.qkv_proj = QKVCrossParallelLinear( - self.hidden_size, - self.head_dim, - self.num_heads, - self.num_key_value_heads, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - - self.o_proj = RowParallelLinear( - self.num_heads * self.head_dim, - self.hidden_size, - bias=False, - input_is_parallel=True, - quant_config=quant_config, - prefix=f"{prefix}.o_proj", - ) - # vllm.model_executor.layers.layernorm.RMSNorm has precision issue, - # use huggingface's instead - self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.scaling = self.head_dim**-0.5 - - self.attn = Attention( - self.num_local_heads, - self.head_dim, - self.scaling, - self.num_local_key_value_heads, - prefix=f"{prefix}.attn", - attn_type=AttentionType.ENCODER_DECODER, - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor], - kv_range_for_decode: Optional[list[tuple[int, int]]], - cross_attention_states: Optional[torch.Tensor], - ) -> torch.Tensor: - q, k, v = self.qkv_proj(hidden_states, cross_attention_states) - if cross_attention_states is not None: - k = k.view(-1, self.num_local_key_value_heads, self.head_dim) - v = v.view(-1, self.num_local_key_value_heads, self.head_dim) - k = self.k_norm(k) - - q = q.view(-1, self.num_local_heads, self.head_dim) - q = self.q_norm(q) - - if attention_mask is not None: - output = 
self._attention_with_mask(q, k, v, attention_mask, - kv_range_for_decode) - else: - output = self.attn( - q.view(-1, self.num_local_heads * self.head_dim), k, v) - out, _ = self.o_proj(output) - return out - - def _attention_with_mask( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - attention_mask: torch.Tensor, - kv_range_for_decode: list[tuple[int, int]], - ) -> torch.Tensor: - kv_cache = self.attn.kv_cache[self.pipeline_parallel_rank] - attn_metadata: AttentionMetadata = get_forward_context().attn_metadata - # Skip writing kv-cache for the initial profiling run. - # TODO (NickLucche) replace with custom attn bias and use standard attn - if len(kv_cache.shape) > 1: - i = torch.ones(1, dtype=torch.float32) - if self.attn.backend in (_Backend.FLASH_ATTN, - _Backend.FLASH_ATTN_VLLM_V1): - cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) - cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode]) - torch.ops._C_cache_ops.reshape_and_cache_flash( - cached_k, - cached_v, - kv_cache[0], - kv_cache[1], - attn_metadata. - cross_slot_mapping, # type: ignore[union-attr] - "auto", - i, - i, - ) - elif self.attn.backend in (_Backend.XFORMERS, _Backend.ROCM_FLASH, - _Backend.TORCH_SDPA): - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_local_key_value_heads, self.head_dim) - cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) - cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode]) - PagedAttention.write_to_paged_cache( - cached_k, cached_v, key_cache, value_cache, - attn_metadata.cross_slot_mapping, "auto", i, i) - else: - raise ValueError( - f"Unsupported Attention backend {self.attn.backend} " - "enum found. Expected the Attention backend to be " - "FLASH_ATTN, FLASH_ATTN_VLLM_V1, " - "XFORMERS or TORCH_SDPA.") - - # We have to call torch.sdpa for prefill when using a - # custom cross-attention mask. Because the mask is not a - # standard causal mask, neither a block diagonal mask which - # can be optimized by xformers.BlockDiagonalMask. - # The mask is specially calculated for supporting multi - # images and interleaved images. 
- q_len = q.shape[0] - kv_len = k.shape[0] - q = q.transpose(0, 1).view(self.num_local_key_value_heads, - self.num_key_value_groups, q_len, - self.head_dim).contiguous() - k = k.transpose(0, - 1)[:, - None, :, :].expand(self.num_local_key_value_heads, - self.num_key_value_groups, - kv_len, - self.head_dim).contiguous() - v = v.transpose(0, - 1)[:, - None, :, :].expand(self.num_local_key_value_heads, - self.num_key_value_groups, - kv_len, - self.head_dim).contiguous() - attention_mask = attention_mask.view(1, 1, q_len, kv_len) - output = F.scaled_dot_product_attention(q, - k, - v, - attn_mask=attention_mask, - is_causal=False) - output = output.permute(2, 0, 1, 3).reshape( - q_len, self.num_local_heads * self.head_dim) - return output - - -class MllamaCrossAttentionDecoderLayer(torch.nn.Module): - """Cross-attention transformer block with tanh-gated attention - and feedforward.""" - - def __init__( - self, - config: config_mllama.MllamaTextConfig, - layer_idx: int, - quant_config: Optional[QuantizationConfig], - prefix: str = "", - ) -> None: - super().__init__() - - self.layer_idx = layer_idx - self.cross_attn = MllamaTextCrossAttention( - config=config, - layer_idx=layer_idx, - quant_config=quant_config, - prefix=f"{prefix}.cross_attn", - ) - - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1)) - - self.mlp = LlamaMLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1)) - - def forward( - self, - hidden_states: torch.Tensor, - cross_attention_states: torch.Tensor, - cross_attention_mask: torch.Tensor, - kv_range_for_decode: Optional[list[tuple[int, int]]], - full_text_row_masked_out_mask: torch.Tensor, - ) -> torch.Tensor: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - - hidden_states = self.cross_attn( - hidden_states=hidden_states, - attention_mask=cross_attention_mask, - kv_range_for_decode=kv_range_for_decode, - cross_attention_states=cross_attention_states, - ) - hidden_states = full_text_row_masked_out_mask * hidden_states - hidden_states = residual + self.cross_attn_attn_gate.tanh( - ) * hidden_states - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = full_text_row_masked_out_mask * hidden_states - hidden_states = residual + self.cross_attn_mlp_gate.tanh( - ) * hidden_states - return hidden_states - - -class MllamaTextModel(nn.Module): - config_class = config_mllama.MllamaTextConfig - base_model_prefix = "model" - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config.text_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - - self.vocab_size = config.vocab_size - self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8, - config.hidden_size) - self.cross_attention_layers = config.cross_attention_layers - - layers = [] - for layer_idx in range(config.num_hidden_layers): - if layer_idx in self.cross_attention_layers: - layers.append( - MllamaCrossAttentionDecoderLayer( - config, - layer_idx, - quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}", - 
)) - else: - # TODO: force LlamaDecoderLayer to config.attention_bias=False - layers.append( - LlamaDecoderLayer( - config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}", - )) - - self.layers = nn.ModuleList(layers) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - input_ids: torch.LongTensor, - positions: Optional[torch.LongTensor], - cross_attention_states: Optional[torch.LongTensor], - cross_attention_mask: Optional[torch.LongTensor], - kv_range_for_decode: Optional[list[tuple[int, int]]], - full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, - torch.Tensor]], - skip_cross_attention: bool, - ) -> torch.Tensor: - inputs_embeds = self.embed_tokens(input_ids) - hidden_states = inputs_embeds - - for idx, decoder_layer in enumerate(self.layers): - if idx in self.cross_attention_layers: - if not skip_cross_attention: - hidden_states = decoder_layer( - hidden_states=hidden_states, - cross_attention_states=cross_attention_states, - cross_attention_mask=cross_attention_mask, - kv_range_for_decode=kv_range_for_decode, - full_text_row_masked_out_mask= - full_text_row_masked_out_mask, - ) - else: - hidden_states, residual = decoder_layer( - positions=positions, - hidden_states=hidden_states, - residual=None, - ) - hidden_states = hidden_states + residual - hidden_states = self.norm(hidden_states) - return hidden_states - - -class MllamaForCausalLM(nn.Module): - config_class = config_mllama.MllamaTextConfig - base_model_prefix = "language_model" - _no_split_modules = [ - "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer" - ] - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config.text_config - quant_config = vllm_config.quant_config - self.quant_config = quant_config - - self.vocab_size = config.vocab_size - self.model = MllamaTextModel(vllm_config=vllm_config, - prefix=f"{prefix}.model") - self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, - quant_config=quant_config, - prefix=f"{prefix}.lm_head", - ) - - def forward( - self, - input_ids: torch.LongTensor, - positions: Optional[torch.LongTensor], - cross_attention_states: Optional[torch.LongTensor], - cross_attention_mask: Optional[torch.LongTensor], - kv_range_for_decode: Optional[list[tuple[int, int]]], - full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, - torch.Tensor]], - skip_cross_attention: bool, - ) -> torch.Tensor: - hidden_states = self.model( - input_ids=input_ids, - positions=positions, - cross_attention_states=cross_attention_states, - cross_attention_mask=cross_attention_mask, - kv_range_for_decode=kv_range_for_decode, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - skip_cross_attention=skip_cross_attention, - ) - return hidden_states - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - updated_params: set[str] = set() - for name, loaded_weight in weights: - if 'patch_embedding.weight' in name: - name = name.replace('patch_embedding.weight', - 'patch_embedding._linear.weight') - 
loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1) - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) - weight_loader(param, loaded_weight) - updated_params.add(scale_name) - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - param = params_dict[name] - updated_params.add(name) - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - orig_name = name - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - logger.debug("Missing name %s, orig name %s", name, - orig_name) - continue - - param = params_dict.pop(name) - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - updated_params.add(name) - return updated_params - - -@MULTIMODAL_REGISTRY.register_processor(MllamaMultiModalProcessor, - info=MllamaProcessingInfo, - dummy_inputs=MllamaDummyInputsBuilder) -class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsV0Only): - packed_modules_mapping = { - "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"] - } - - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={ - # mapping for new names in checkpoint saved after transformers v4.52 - "model.vision_model.": "vision_model.", - "model.multi_modal_projector.": "multi_modal_projector.", - "model.language_model.": "language_model.model.", - "lm_head.": "language_model.lm_head.", - }, - orig_to_new_suffix={ - "patch_embedding.weight": "patch_embedding._linear.weight", - }, - ) - - @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: - if modality.startswith("image"): - return "<|image|>" - - raise ValueError("Only image modality is supported") - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config: MllamaConfig = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - self.quant_config = quant_config - self.vocab_size = config.text_config.vocab_size - self.hidden_size = config.text_config.hidden_size - self.max_num_tiles = config.vision_config.max_num_tiles - self.vision_output_dim = config.vision_config.vision_output_dim - self.pad_token_id = \ - config.pad_token_id if config.pad_token_id is not None else -1 - self.image_size = config.vision_config.image_size - self.image_token_id = config.image_token_index - - self.vision_model = MllamaVisionModel(config.vision_config, - quant_config, - prefix=maybe_prefix( - prefix, "vision_model")) - self.language_model = MllamaForCausalLM( - vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model"), - ) - self.multi_modal_projector = ColumnParallelLinear( - config.vision_config.vision_output_dim, - config.text_config.hidden_size, - bias=True, - quant_config=quant_config, - gather_output=True, - prefix=maybe_prefix(prefix, "multi_modal_projector"), - ) - self.logits_processor = LogitsProcessor(config.output_hidden_states, - config.text_config.vocab_size) - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: 
- logits = self.logits_processor(self.language_model.lm_head, - hidden_states, sampling_metadata) - return logits - - def unpack_data(self, - image_data: Union[list[torch.Tensor], torch.Tensor], - padding_value=0) -> torch.Tensor: - if isinstance(image_data, torch.Tensor): - # torch.Tensor - return image_data - else: - assert isinstance( - image_data[0], - torch.Tensor), "Image data is not properly batched." - # list[torch.Tensor] - bsz = len(image_data) - max_length = max(t.size(0) for t in image_data) - trailing_dims = image_data[0].shape[1:] - for data in image_data: - cur_trailing_dims = data.shape[1:] - assert cur_trailing_dims == trailing_dims - output_tensor = torch.full((bsz, max_length, *trailing_dims), - padding_value, - dtype=image_data[0].dtype, - device=image_data[0].device) - for i, t in enumerate(image_data): - output_tensor[i, :t.size(0)] = t - return output_tensor - - def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[MllamaImagePixelInputs]: - # tensor with the same shape will be batched together by - # MultiModalKwargs.batch, so pixel_values here can be: - # - list[torch.Tensor]: - # with shape (num_image, num_tiles, 3, image_res, image_res) - # - torch.Tensor: - # with shape (bs, num_image, num_tiles, 3, image_res, image_res) - pixel_values: Optional[Union[list[list[torch.Tensor]], - list[torch.Tensor], - torch.Tensor]] = kwargs.pop( - "pixel_values", None) - image_embeds: Optional[Union[list[list[torch.Tensor]], - list[torch.Tensor], - torch.Tensor]] = kwargs.pop( - "image_embeds", None) - aspect_ratio_ids: Optional[Union[list[list[torch.Tensor]], - list[torch.Tensor], - torch.Tensor]] = kwargs.pop( - "aspect_ratio_ids", None) - aspect_ratio_mask: Optional[Union[list[list[torch.Tensor]], - list[torch.Tensor], - torch.Tensor]] = kwargs.pop( - "aspect_ratio_mask", None) - - if pixel_values is None and image_embeds is None: - return None - - if pixel_values is not None and image_embeds is not None: - raise ValueError( - "Both pixel values and image embeds are provided.") - - if pixel_values is not None: - assert aspect_ratio_ids is not None - assert aspect_ratio_mask is not None - - return MllamaImagePixelInputs( - type="pixel_values", - data=self.unpack_data(pixel_values), - aspect_ratio_ids=self.unpack_data(aspect_ratio_ids), - aspect_ratio_mask=self.unpack_data(aspect_ratio_mask)) - - if image_embeds is not None: - raise NotImplementedError - - raise AssertionError("This line should be unreachable.") - - def _get_and_validate_encoder_lens( - self, - encoder_seq_lens: list[int], - num_tiles: list[list[int]], - num_tokens_per_tile: int, - ) -> list[int]: - # Get the actual number of encoder tokens for each sample. - # Because attn_metadata.encoder_seq_lens only counts the last - # group of images for each sample, which is used to cheat the - # block manager to allocate blocks for those images only. - # See MllamaMultiModalProcessor for more details. 
- actual_encoder_seq_lens = [ - sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles - ] - - # remove 0 encoder len entries for text-only requests for these - # assertions - attn_metadata_lens = [x for x in encoder_seq_lens if x > 0] - assert len(actual_encoder_seq_lens) == len(attn_metadata_lens) - for actual_len, last_group_len in zip(actual_encoder_seq_lens, - attn_metadata_lens): - assert actual_len >= last_group_len - - return actual_encoder_seq_lens - - def flat_encoder_result(self, cross_attention_states: torch.Tensor, - attn_metadata: AttentionMetadata, - actual_encoder_seq_lens: list[int]): - - cross_attention_states_flat = torch.zeros( - sum(actual_encoder_seq_lens), - cross_attention_states.shape[-1], - device=cross_attention_states.device, - dtype=cross_attention_states.dtype) - start_pos = 0 - for seq_len, vision_token_in_batch in zip(actual_encoder_seq_lens, - cross_attention_states): - end_pos = start_pos + seq_len - cross_attention_states_flat[ - start_pos:end_pos] = vision_token_in_batch[:seq_len] - start_pos = end_pos - cross_attention_states = cross_attention_states_flat - return cross_attention_states - - def get_language_model(self) -> torch.nn.Module: - return self.language_model - - def get_cross_attention_states( - self, - image_inputs: MllamaImagePixelInputs, - attn_metadata: AttentionMetadata, - actual_encoder_seq_lens: list[int], - ) -> tuple[torch.Tensor]: - # NOTE: llama's reference implementation runs vision model on CPU - pixel_values = image_inputs['data'] - aspect_ratio_ids = image_inputs['aspect_ratio_ids'] - aspect_ratio_mask = image_inputs['aspect_ratio_mask'] - cross_attention_states = self.vision_model(pixel_values, - aspect_ratio_ids, - aspect_ratio_mask) - cross_attention_states, _ = self.multi_modal_projector( - cross_attention_states) - - bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape) - cross_attention_states = cross_attention_states.view( - bsz, -1, image_token_dim) - - cross_attention_states = self.flat_encoder_result( - cross_attention_states, attn_metadata, actual_encoder_seq_lens) - - return cross_attention_states - - def get_cross_attention_mask( - self, - input_ids: torch.Tensor, - attn_metadata: AttentionMetadata, - num_tiles: list[list[int]], - num_tokens_per_tile: int, - dtype: torch.dtype, - ) -> tuple[torch.Tensor, torch.Tensor]: - token_ids = input_ids.tolist() - start = 0 - batch_token_ids = [] - for seq_len in attn_metadata.seq_lens: - batch_token_ids.append(token_ids[start:start + seq_len]) - start += seq_len - sparse_mask = [ - get_cross_attention_token_mask(t, self.image_token_id) - for t in batch_token_ids - ] - - # Skip generating cross-attention mask if all samples - # are text-only or have only 1 leading image. 
- if skip_attention_mask(sparse_mask): - return None, None - - dense_mask, tile_range_for_decode = \ - convert_sparse_cross_attention_mask_to_dense( - sparse_mask, num_tiles, attn_metadata.seq_lens) - cross_attention_mask = \ - convert_dense_cross_attention_mask_to_tensor( - dense_mask, num_tokens_per_tile, input_ids.device, dtype) - kv_range_for_decode = [[ - t[0] * num_tokens_per_tile, t[1] * num_tokens_per_tile - ] for t in tile_range_for_decode] - - return cross_attention_mask, kv_range_for_decode - - def get_full_text_row_masked_out_mask( - self, - attn_metadata: AttentionMetadata, - device: torch.device, - ) -> torch.Tensor: - full_text_row_masked_out_mask = torch.ones( - (attn_metadata.num_prefill_tokens, 1), dtype=torch.bool) - start_pos = 0 - for seq_len, encoder_seq_len in zip(attn_metadata.seq_lens, - attn_metadata.encoder_seq_lens): - if encoder_seq_len == 0: - full_text_row_masked_out_mask[start_pos:start_pos + - seq_len] = False - start_pos += seq_len - full_text_row_masked_out_mask = full_text_row_masked_out_mask.to( - device) - return full_text_row_masked_out_mask - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - **kwargs: object, - ) -> Union[CausalLMOutputWithPast]: - attn_metadata = get_forward_context().attn_metadata - if attn_metadata.num_prefill_tokens > 0 and \ - attn_metadata.num_decode_tokens > 0: - raise ValueError("Chunk prefill not supported") - image_inputs = self._parse_and_validate_image_input(**kwargs) - cross_attention_states = None - cross_attention_mask = None - kv_range_for_decode = None - - # For 1) text-only prefill and decode, 2) image-present decode. - if image_inputs is None: - full_text_row_masked_out_mask = ( - attn_metadata.encoder_seq_lens_tensor - != 0).reshape(-1, 1).to(input_ids.device) - skip_cross_attention = attn_metadata.max_encoder_seq_len == 0 - - # For image-present prefill. - else: - skip_cross_attention = False - - num_tiles = [t.tolist() for t in kwargs.pop("num_tiles")] - num_tokens_per_tile = calc_token_per_chunk(self.image_size) - - actual_encoder_seq_lens = self._get_and_validate_encoder_lens( - attn_metadata.encoder_seq_lens, - num_tiles, - num_tokens_per_tile, - ) - - cross_attention_states = self.get_cross_attention_states( - image_inputs, attn_metadata, actual_encoder_seq_lens) - - full_text_row_masked_out_mask = \ - self.get_full_text_row_masked_out_mask( - attn_metadata, input_ids.device) - - cross_attention_mask, kv_range_for_decode = \ - self.get_cross_attention_mask( - input_ids, attn_metadata, num_tiles, - num_tokens_per_tile, cross_attention_states.dtype) - - outputs = self.language_model( - input_ids=input_ids, - positions=positions, - cross_attention_states=cross_attention_states, - cross_attention_mask=cross_attention_mask, - kv_range_for_decode=kv_range_for_decode, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - skip_cross_attention=skip_cross_attention, - ) - - return outputs - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) - return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) - - def get_mm_mapping(self) -> MultiModelKeys: - """ - Get the module prefix in multimodal models - """ - return MultiModelKeys.from_string_field( - language_model="language_model", - connector="multi_modal_projector", - tower_model="vision_model") - - -def skip_attention_mask(sparse_mask: list[list[int]]) -> bool: - for mask in sparse_mask: - # Skip text-only samples. 
- if len(mask) == 0: - continue - # If the sample contains more than 1 images, - # we can't skip mask. - if len(mask) != 1: - return False - # If the sample contains only 1 image, - # but the image is not the leading one, - # we can't skip mask. - if mask[0][0] != 0 or mask[0][1] != -1: - return False - return True - - -def convert_sparse_cross_attention_mask_to_dense( - sparse_mask: list[list[list[int]]], - num_tiles: list[list[int]], - lengths: list[int], -) -> tuple[np.ndarray, list[tuple[int, int]]]: - total_length = sum(lengths) - total_tiles = sum([sum(tiles) for tiles in num_tiles]) - dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64) - # A list of ranges, range[i] = [start, end] means that the i-th image will - # use tiles[start, end] for cross-attention decoding. - tile_range_for_decode = [] - - seq_start = 0 - tile_start = 0 - - # sparse_mask has an [] entry for each sequence that does not have images, - # but num_tiles does not have these entries... - num_tiles_idx = 0 - for masks, length in zip(sparse_mask, lengths): - if len(masks) == 0: - # Text only - continue - - tiles = num_tiles[num_tiles_idx] - num_tiles_idx += 1 - ts, td = -1, 0 - for mask, tile in zip(masks, tiles): - if len(mask) != 2: - continue - start, end = mask - end = min(end, length) - if end == -1: - end = length - if end == length: - if ts == -1: - ts = tile_start - td += tile - dense_mask[seq_start + start:seq_start + end, - tile_start:tile_start + tile] = 1 - tile_start += tile - assert ts != -1 - assert td != 0 - tile_range_for_decode.append((ts, ts + td)) - seq_start += length - assert num_tiles_idx == len(num_tiles) - - return dense_mask, tile_range_for_decode - - -def convert_dense_cross_attention_mask_to_tensor( - cross_attention_token_mask: np.ndarray, - num_tokens_per_tile: int, - device: torch.device, - dtype: torch.dtype, -) -> torch.Tensor: - mask = torch.tensor(cross_attention_token_mask, dtype=dtype, device=device) - mask = mask.repeat_interleave(num_tokens_per_tile, dim=1) - - mask = 1.0 - mask - mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(dtype).min) - - ninf = torch.finfo(dtype).min - full_text_mask = ((mask != ninf).any(dim=-1).type_as(mask)[..., None]) - mask *= full_text_mask - # (num_prompt_tokens, num_encoder_tokens) - return mask diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index ecbbb5f57bec8..2f0e8a2a5e575 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -106,22 +106,21 @@ class Llama4VisionMLP(nn.Module): use_data_parallel: bool = False, ): super().__init__() - cls_fc1 = (ReplicatedLinear - if use_data_parallel else ColumnParallelLinear) - self.fc1 = cls_fc1( + self.fc1 = ColumnParallelLinear( input_size=input_size, output_size=intermediate_size, bias=bias, quant_config=quant_config, prefix=f"{prefix}.fc1", + disable_tp=use_data_parallel, ) - cls_fc2 = ReplicatedLinear if use_data_parallel else RowParallelLinear - self.fc2 = cls_fc2( + self.fc2 = RowParallelLinear( input_size=intermediate_size, output_size=output_size, bias=bias, quant_config=quant_config, prefix=f"{prefix}.fc2", + disable_tp=use_data_parallel, ) self.activation_fn = nn.GELU() self.output_activation = output_activation @@ -388,11 +387,10 @@ class Llama4VisionEncoder(nn.Module): ) -> torch.Tensor: r""" Args: - inputs_embeds (`torch.FloatTensor` of shape - `(batch_size, sequence_length, hidden_size)`): - Optionally, instead of passing `input_ids` you can choose to - directly pass an embedded 
representation. This is useful if you - want more control over how to convert `input_ids` indices into + hidden_states: Input tensor of shape + (batch_size, sequence_length, hidden_size). + Hidden states from the model embeddings, representing + the input tokens. associated vectors than the model's internal embedding lookup matrix. """ @@ -419,20 +417,15 @@ class Llama4UnfoldConvolution(nn.Module): kernel_size = (kernel_size, kernel_size) self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size) - params = { - "input_size": - config.num_channels * kernel_size[0] * kernel_size[1], - "output_size": config.hidden_size, - "bias": False, - "quant_config": quant_config, - "prefix": f"{prefix}.linear", - } - if use_data_parallel: - cls = ReplicatedLinear - else: - cls = ColumnParallelLinear - params["gather_output"] = True - self.linear = cls(**params) + self.linear = ColumnParallelLinear( + input_size=config.num_channels * kernel_size[0] * kernel_size[1], + output_size=config.hidden_size, + bias=False, + gather_output=True, + quant_config=quant_config, + prefix=f"{prefix}.linear", + disable_tp=use_data_parallel, + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.unfold(hidden_states) diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 776287589808a..1d5da3139de92 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -306,7 +306,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): self.config = config self.model = ModernBertModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) + self.classifier = nn.Linear(config.hidden_size, + config.num_labels, + dtype=vllm_config.model_config.head_dtype) self.pooling = ModernBertPooler(config) pooler_config = vllm_config.model_config.pooler_config diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index b2fc7be1af224..2475fe1316097 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -76,20 +76,22 @@ class MolmoImageInputs(TensorSchema): """ Dimensions: - bn: Batch size * number of images - - nc: Number of crops + - nc: Number of crops (dynamic) - np: Number of patches + - tp: Token sequence positions - pd: Patch dimension """ images: Annotated[Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bn", "nc", "np", "pd")] + TensorShape("bn", "nc", "np", "pd", dynamic_dims={"nc"})] + # Number of crops may vary per batch and image, so pass it as a list. image_masks: Annotated[Optional[Union[torch.Tensor, list[torch.Tensor]]], - TensorShape("bn", "nc", "np")] + TensorShape("bn", "nc", "np", dynamic_dims={"nc"})] - feat_is_patch: Annotated[Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bn", "nc", "np")] + feat_is_patch: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "nc", "tp", dynamic_dims={"nc"})] # A boolean mask indicating which image features correspond to patch tokens. 
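Editorial note: the molmo.py hunk above marks the crop dimension as dynamic, so per-image tensors whose crop counts differ are carried as a list instead of being stacked into one `(bn, nc, np, pd)` tensor. A minimal sketch of the shapes this allows, with made-up dimensions:

```python
import torch

# Hypothetical shapes: two images whose crop counts differ, so "nc" is dynamic.
# Each entry is (nc_i, num_patches, patch_dim); only the first dim varies,
# which is why the field may be a list of tensors rather than one stacked tensor.
images = [
    torch.randn(5, 576, 588),   # image 0: 5 crops
    torch.randn(13, 576, 588),  # image 1: 13 crops
]
assert all(t.shape[1:] == images[0].shape[1:] for t in images)  # only nc varies
```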
- num_crops: Annotated[torch.Tensor, TensorShape("bn")] @@ -1401,6 +1403,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, config.embedding_size or config.vocab_size, config.hidden_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor(config.embedding_size diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py index 41a2c836b09f3..caa00763fc3d4 100644 --- a/vllm/model_executor/models/moonvit.py +++ b/vllm/model_executor/models/moonvit.py @@ -70,11 +70,15 @@ def multihead_attention( v: torch.Tensor, q_cu_seqlens: Optional[torch.Tensor] = None, k_cu_seqlens: Optional[torch.Tensor] = None, -): +) -> torch.Tensor: """Multi-head attention using flash attention 2. Args: - q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim), + q: Query tensor of shape (batch_size, seqlen, num_heads, head_dim), + or (tot_seqlens, num_heads, head_dim) if packing. + k: Key tensor of shape (batch_size, seqlen, num_heads, head_dim), + or (tot_seqlens, num_heads, head_dim) if packing. + v: Value tensor of shape (batch_size, seqlen, num_heads, head_dim), or (tot_seqlens, num_heads, head_dim) if packing. q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q. The first element should be 0 and the last element should be q.shape[0]. @@ -123,8 +127,14 @@ def sdpa_attention( """SDPA attention. Args: - q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim), + q: Query tensor of shape (batch_size, seqlen, num_heads, head_dim), or (tot_seqlens, num_heads, head_dim) if packing. + k: Key tensor of shape (batch_size, seqlen, num_heads, head_dim), + or (tot_seqlens, num_heads, head_dim) if packing. + v: Value tensor of shape (batch_size, seqlen, num_heads, head_dim), + or (tot_seqlens, num_heads, head_dim) if packing. + q_cu_seqlens: Optional cumulative sequence lengths of q. + k_cu_seqlens: Optional cumulative sequence lengths of k. """ seq_length = q.shape[0] attention_mask = torch.zeros([1, seq_length, seq_length], @@ -387,7 +397,7 @@ class MLP2(nn.Module): def __init__(self, dims: list[int], activation, - bias=True, + bias: bool = True, prefix: str = "", use_data_parallel: bool = False): super().__init__() diff --git a/vllm/model_executor/models/motif.py b/vllm/model_executor/models/motif.py new file mode 100644 index 0000000000000..153f36dcf1f55 --- /dev/null +++ b/vllm/model_executor/models/motif.py @@ -0,0 +1,345 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://huggingface.co/Motif-Technologies/Motif-2.6B/blob/main/modeling_motif.py +# Copyright (c) Alibaba Cloud. 
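Editorial note: the expanded moonvit.py docstrings describe the packed layout, where q/k/v are concatenated over all sequences and `q_cu_seqlens`/`k_cu_seqlens` carry cumulative token boundaries. A small sketch of how such offsets line up, using hypothetical lengths:

```python
import torch

# Hypothetical per-sequence token counts in a packed batch.
seq_lens = [7, 3, 12]

# Cumulative boundaries: first element 0, last element == total token count,
# as the updated multihead_attention docstring requires for q_cu_seqlens.
cu_seqlens = torch.tensor([0, 7, 10, 22], dtype=torch.int32)
assert cu_seqlens.tolist() == [0] + [sum(seq_lens[:i + 1]) for i in range(len(seq_lens))]

# A packed q tensor would then have shape (22, num_heads, head_dim).
```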
+# LICENSE: https://huggingface.co/Motif-Technologies/Motif-2.6B/blob/main/LICENSE +"""Inference-only Motif model compatible with HuggingFace weights.""" +import math +from typing import Any, Optional + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionType +from vllm.attention.selector import _Backend +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.layernorm import PolyNorm, RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.models.llama import LlamaForCausalLM + +from .adapters import as_seq_cls_model +from .interfaces import SupportsV0Only +from .utils import extract_layer_index + + +class MotifMLP(nn.Module): + """MLP for the language component of the Motif model, which contains a + MergedColumnParallelLinear merging 2 outputs via PolyNorm activation.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str = "poly_norm", + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + reduce_results: bool = True, + ): + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "poly_norm": + raise NotImplementedError(f"Unsupported activation: {hidden_act}. " + "Only poly_norm is supported for now.") + self.act_fn = PolyNorm() + self.intermediate_size = intermediate_size + tp_size = get_tensor_model_parallel_world_size() + if hidden_act == "poly_norm" and tp_size > 1: + raise NotImplementedError( + "Tensor parallelism for poly_norm is not supported yet. " + "Support will be added in the future.") + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn( + x[..., :self.intermediate_size]) * x[..., self.intermediate_size:] + x, _ = self.down_proj(x) + return x + + +class MotifAttention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + bias_o_proj: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + ) -> None: + super().__init__() + layer_idx = extract_layer_index(prefix) + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. 
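Editorial note: the comment above (the branch continues on the next hunk line) chooses between partitioning and replicating KV heads depending on the tensor-parallel size. A standalone sketch of that arithmetic, mirroring the logic rather than calling into vLLM:

```python
def kv_heads_per_rank(total_num_kv_heads: int, tp_size: int) -> int:
    # Sketch that mirrors MotifAttention.__init__: partition KV heads when there
    # are at least as many as TP ranks, otherwise replicate them across ranks.
    if total_num_kv_heads >= tp_size:
        assert total_num_kv_heads % tp_size == 0
    else:
        assert tp_size % total_num_kv_heads == 0
    return max(1, total_num_kv_heads // tp_size)


assert kv_heads_per_rank(8, 4) == 2  # partitioned: 2 KV heads per rank
assert kv_heads_per_rank(2, 8) == 1  # replicated: each rank keeps 1 KV head
```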
+ assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + head_dim = getattr(config, "head_dim", None) + if head_dim is None: + head_dim = self.hidden_size // self.total_num_heads + self.head_dim = head_dim + # Phi models introduced a partial_rotary_factor parameter in the config + self.partial_rotary_factor = getattr(config, "partial_rotary_factor", + 1) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + assert self.num_heads % 2 == 0, 'num_heads should be even' + assert self.num_kv_heads % 2 == 0, 'num_heads should be even' + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias_o_proj, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self._init_rotary_emb(config, + rope_scaling=rope_scaling, + quant_config=quant_config) + sliding_window = None + + self.lambda_init = self.lambda_init_fn(layer_idx) + self.lambda_q1 = nn.Parameter( + torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, + std=0.1)) + self.lambda_k1 = nn.Parameter( + torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, + std=0.1)) + self.lambda_q2 = nn.Parameter( + torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, + std=0.1)) + self.lambda_k2 = nn.Parameter( + torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, + std=0.1)) + self.subln = RMSNorm(2 * self.head_dim, eps=config.attn_rms_norm_eps) + + params = { + 'differential_flash_attention_config': { + 'lambda_init': self.lambda_init, + 'lambda_q1': self.lambda_q1, + 'lambda_k1': self.lambda_k1, + 'lambda_q2': self.lambda_q2, + 'lambda_k2': self.lambda_k2, + "subln": self.subln, + } + } + + diff_attn_err_msg = ( + 'Set VLLM_ATTENTION_BACKEND="DIFFERENTIAL_FLASH_ATTN" ' + 'to enable Differential Flash Attention.') + try: + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=sliding_window, + attn_type=attn_type, + prefix=f"{prefix}.attn", + **params, + ) + except TypeError as e: + raise ValueError(diff_attn_err_msg) from e + assert (self.attn.backend == _Backend.DIFFERENTIAL_FLASH_ATTN + ), diff_attn_err_msg + + def lambda_init_fn(self, depth): + return 0.8 - 0.6 * math.exp(-0.3 * (depth - 1)) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + def _init_rotary_emb(self, config: PretrainedConfig, + rope_scaling: Optional[dict[str, Any]], + quant_config: Optional[QuantizationConfig]) -> None: 
+ is_neox_style = True + is_gguf = quant_config and quant_config.get_name() == "gguf" + if is_gguf and config.model_type == "llama": + is_neox_style = False + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + partial_rotary_factor=self.partial_rotary_factor, + ) + + +class MotifDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "use_bias", False) + bias_o_proj = attention_bias + if hasattr(config, 'qkv_bias'): + attention_bias = config.qkv_bias + + # By default, Motif uses causal attention as it is a decoder-only model. + # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + # (e.g. parasail-ai/GritLM-7B-vllm) + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + + self.self_attn = MotifAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + bias_o_proj=bias_o_proj, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + attn_type=attn_type, + ) + self.mlp = MotifMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "use_bias", False), + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +# Motif model uses differential attention +# Only supported in v0 (no chunked prefill support) +class MotifForCausalLM(LlamaForCausalLM, SupportsV0Only): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = MotifDecoderLayer): + + # Prefix caching and chunked prefill is not supported for this model. 
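Editorial note: given the backend check in MotifAttention and the constraints asserted in the constructor just below, a plausible way to launch the model might look like the following sketch; the checkpoint name comes from the adaptation URL and the engine flags are assumptions, not taken from this patch:

```python
import os

# The error text in MotifAttention says this backend must be selected explicitly.
os.environ["VLLM_ATTENTION_BACKEND"] = "DIFFERENTIAL_FLASH_ATTN"

from vllm import LLM

# Hypothetical invocation; the flags mirror the asserts in MotifForCausalLM.__init__.
llm = LLM(
    model="Motif-Technologies/Motif-2.6B",  # assumed checkpoint name
    enable_prefix_caching=False,
    enable_chunked_prefill=False,
)
print(llm.generate("Hello, world!")[0].outputs[0].text)
```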
+ assert not vllm_config.cache_config.enable_prefix_caching, \ + "Motif currently does not support prefix caching" + assert not vllm_config.scheduler_config.chunked_prefill_enabled, \ + "Motif currently does not support chunked prefill" + + super().__init__(vllm_config=vllm_config, + prefix=prefix, + layer_type=layer_type) + + +MotifForSequenceClassification = as_seq_cls_model(MotifForCausalLM) diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py new file mode 100644 index 0000000000000..4f8652c006941 --- /dev/null +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -0,0 +1,1397 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# -------------------------------------------------------- +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/internvl.py +# under Apache-2.0 License +# LICENSE is in root directory. +# -------------------------------------------------------- + +import copy +import warnings +from abc import ABC, abstractmethod +from collections.abc import Iterable, Mapping, Sequence +from typing import Annotated, Any, Literal, Optional, TypedDict, TypeVar, Union + +import numpy.typing as npt +import torch +import torch.nn as nn +import torchvision.transforms as T +from PIL import Image +from transformers import (BatchEncoding, BatchFeature, PretrainedConfig, + TensorType) + +from vllm.config import VllmConfig +from vllm.model_executor.layers.activation import ReLUSquaredActivation +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, + MultiModalEmbeddings, + SupportsMultiModal) +from vllm.model_executor.models.internvl import (calculate_internvl_targets, + get_internvl_target_ratios) +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM +from vllm.model_executor.models.radio import RadioModel +from vllm.model_executor.models.utils import (flatten_bn, + init_vllm_registered_model, + maybe_prefix, + merge_multimodal_embeddings) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs, MultiModalKwargsItems, + NestedTensors) +from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, + ImageSize, MultiModalDataItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.radio import RadioConfig +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +# Configure PIL to handle large images without warnings +# This prevents DecompressionBombWarning for legitimate large images +Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely +# Alternative: Set a specific higher limit +# Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels + +IMG_START = "<img>" +IMG_END = "</img>" +IMG_CONTEXT = "<image>" + +# Profiling +MAX_FRAMES = 16 + + +class NanoNemotronVLImagePixelInputs(TypedDict): + type: 
Literal["pixel_values"] + pixel_values_flat: torch.Tensor + """ + Shape: + `(batch_size * num_images * (1 + num_patches), num_channels, height, width)` + """ + + num_patches: torch.Tensor + """Shape: `(batch_size * num_images)`""" + + +class NanoNemotronVLImageEmbeddinInputs(TypedDict): + type: Literal["image_embeds"] + data: Union[torch.Tensor, list[torch.Tensor]] + """ + A tensor of shape `(num_images, total_image_feature_size, hidden_size)` + or a list of tensors of shape `(total_image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. + """ + + +NanoNemotronVLImageInputs = Union[NanoNemotronVLImagePixelInputs, + NanoNemotronVLImageEmbeddinInputs] + + +class NanoNemotronVLVideoPixelInputs(TensorSchema): + """ + Dimensions: + - bvf: Batch size * number of videos * num_frames + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height of each video frame + - w: Width of each video frame + """ + type: Literal["pixel_values_videos"] + pixel_values_flat: Annotated[torch.Tensor, TensorShape("bvf", 3, "h", "w")] + num_patches: Annotated[torch.Tensor, TensorShape("bn")] + + +class NanoNemotronVLVideoEmbeddingInputs(TensorSchema): + """ + Dimensions: + - n: Number of videos + - f: Total video feature size + - h: Hidden size (must match the hidden size of language model backbone) + """ + type: Literal["video_embeds"] + data: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("n", "f", "h")] + + +NanoNemotronVLVideoInputs = Union[NanoNemotronVLVideoPixelInputs, + NanoNemotronVLVideoEmbeddingInputs] + + +def dynamic_preprocess(image, + *, + image_size=512, + max_num_tiles=12, + use_thumbnail=True, + idx=0): + orig_width, orig_height = image.size + + target_ratios = get_internvl_target_ratios(1, max_num_tiles) + + blocks, target_width, target_height = calculate_internvl_targets( + orig_width=orig_width, + orig_height=orig_height, + target_ratios=target_ratios, + image_size=image_size, + use_thumbnail=False) + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + + processed_images = [ + img.convert("RGB") if img.mode != "RGB" else img + for img in processed_images + ] + processed_images = [ + T.Resize((image_size, image_size), + interpolation=T.InterpolationMode.BICUBIC)(img) + for img in processed_images + ] + processed_images = [T.ToTensor()(img) for img in processed_images] + return processed_images + + +def image_to_pixel_values( + image: Image.Image, + *, + input_size: int, + max_num: int, + use_thumbnail: bool, + idx: int, +) -> torch.Tensor: + images = dynamic_preprocess( + image, + image_size=input_size, + max_num_tiles=max_num, + use_thumbnail=use_thumbnail, + idx=idx, + ) + + pixel_values = torch.stack(images) + return pixel_values + + +def video_to_pixel_values( + video: npt.NDArray, + *, + input_size: int, + max_num_tiles: int = 1, + use_thumbnail: bool, +) -> torch.Tensor: + # Convert each frame to a single resized tile 
tensor consistent + # with image path + frames_tensors: list[torch.Tensor] = [] + for frame in video: + pil_frame = dynamic_preprocess( + Image.fromarray(frame, mode="RGB"), + image_size=input_size, + max_num_tiles=max_num_tiles, + use_thumbnail=use_thumbnail, + idx=0, + ) + # dynamic_preprocess returns tensors already; take the single tile + assert len(pil_frame) >= 1 + frames_tensors.append(pil_frame[0]) + + return torch.stack(frames_tensors) + + +class BaseNanoNemotronVLProcessor(ABC): + """ + This model doesn't define its own HF processor, + so we implement our own one here. + + The code to insert image tokens is based on: + https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252 + """ + + def __init__(self, config: PretrainedConfig, tokenizer: AnyTokenizer, + *args, **kwargs) -> None: + super().__init__() + + self.config = config + self.tokenizer = tokenizer + + image_size: int = config.force_image_size + patch_size: int = config.patch_size + + self.num_image_token = int( + (image_size // patch_size)**2 * (config.downsample_ratio**2)) + self.image_size = image_size + self.use_thumbnail: bool = config.use_thumbnail + self.norm_mean = torch.Tensor(config.norm_mean).reshape(1, 3, 1, 1) + self.norm_std = torch.Tensor(config.norm_std).reshape(1, 3, 1, 1) + + @property + @abstractmethod + def image_token_id(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_image_repl( + self, + feature_size: int, + num_patches: Optional[int], + ) -> PromptUpdateDetails[str]: + raise NotImplementedError + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + max_num_tiles: int, + ) -> int: + target_ratios = get_internvl_target_ratios(1, max_num_tiles) + + num_patches, _, _ = calculate_internvl_targets( + orig_width=image_width, + orig_height=image_height, + target_ratios=target_ratios, + image_size=self.image_size, + use_thumbnail=self.use_thumbnail, + ) + + return num_patches * self.num_image_token + + def _images_to_pixel_values_lst( + self, + images: list[Image.Image], + max_num_tiles: int, + ) -> list[torch.Tensor]: + return [ + image_to_pixel_values( + image, + input_size=self.image_size, + max_num=max_num_tiles, + use_thumbnail=self.use_thumbnail, + idx=idx, + ) for idx, image in enumerate(images) + ] + + def _preprocess_image( + self, + text: list[str], + images: list[Image.Image], + max_num_tiles: int, + ) -> tuple[list[str], dict[str, torch.Tensor]]: + if len(images) == 0: + image_inputs = {} + else: + pixel_values_lst = self._images_to_pixel_values_lst( + images, max_num_tiles) + image_inputs: dict[str, NestedTensors] = { + "pixel_values_flat": + torch.cat(pixel_values_lst), + "image_num_patches": + torch.tensor([len(item) for item in pixel_values_lst]), + } + + for pixel_values in pixel_values_lst: + num_patches = pixel_values.shape[0] + feature_size = num_patches * self.num_image_token + image_repl = self.get_image_repl(feature_size, num_patches) + text = [t.replace('<image>', image_repl.full, 1) for t in text] + return text, image_inputs + + def _make_batch_input(self, + input_item: Optional[Union[Any, list[Any]]] = None): + if input_item is None: + input_item = [] + if not isinstance(input_item, list): + input_item = [input_item] + return input_item + + def __call__( + self, + text: Optional[Union[str, list[str]]] = None, + images: Optional[Union[Image.Image, list[Image.Image]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + max_num_tiles: Optional[int] = None, + ) -> Mapping[str, 
NestedTensors]: + # Use default if not provided + if max_num_tiles is None: + max_num_tiles = 12 + + text, images = [self._make_batch_input(x) for x in (text, images)] + + text, image_inputs = self._preprocess_image( + text=text, + images=images, + max_num_tiles=max_num_tiles, + ) + + text_inputs = self.tokenizer(text, add_special_tokens=False) + + return { + **BatchEncoding(text_inputs, tensor_type=return_tensors), + **image_inputs, + } + + +class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): + """ + HF Processor with extended video processing logic. + Code for video processing is adapted from video example: + https://huggingface.co/OpenGVLab/InternVL3-1B#inference-with-transformers + """ + + def __init__( + self, + config: PretrainedConfig, + tokenizer: AnyTokenizer, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + video_token: Optional[str] = None, + ) -> None: + super().__init__( + config=config, + tokenizer=tokenizer, + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + ) + # add extra video token for video processing + self.video_token = video_token + + @property + def supports_video(self) -> bool: + return self.video_token_id is not None + + @property + def video_token_id(self) -> Optional[int]: + if self.video_token is None: + return None + return self.tokenizer.get_vocab().get(self.video_token, None) + + @property + def image_token_id(self) -> int: + return self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT) + + def _videos_to_pixel_values_lst( + self, + videos: list[npt.NDArray], + max_num_tiles: int, + dynamic_image_size: Optional[bool] = None, + ) -> list[torch.Tensor]: + + return [ + video_to_pixel_values( + video, + input_size=self.image_size, + max_num_tiles=max_num_tiles, + use_thumbnail=self.use_thumbnail, + ) for video in videos + ] + + def _preprocess_video( + self, + text: list[str], + videos: list[npt.NDArray], + max_num_tiles: int, + dynamic_image_size: Optional[bool] = None, + ): + if len(videos) == 0 or not self.supports_video: + video_inputs = {} + else: + pixel_values_lst_video = self._videos_to_pixel_values_lst( + videos, + max_num_tiles=max_num_tiles, + dynamic_image_size=dynamic_image_size, + ) + + video_inputs: dict[str, NestedTensors] = { + "pixel_values_flat_video": + torch.cat(pixel_values_lst_video), + "video_num_patches": + torch.tensor([len(item) for item in pixel_values_lst_video]), + } + + for pixel_values in pixel_values_lst_video: + num_patches = pixel_values.shape[0] + + video_repl = self.get_video_repl(self.num_image_token, + num_patches, self.video_token) + text = [t.replace('<video>', video_repl.full, 1) for t in text] + return text, video_inputs + + def __call__( + self, + text: Optional[Union[str, list[str]]] = None, + images: Optional[Union[Image.Image, list[Image.Image]]] = None, + videos: Optional[Union[npt.NDArray, list[npt.NDArray]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + max_num_tiles: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + ) -> Mapping[str, NestedTensors]: + # Use default if not provided + if max_num_tiles is None: + max_num_tiles = 12 + + text, images, videos = [ + self._make_batch_input(x) for x in (text, images, videos) + ] + + text, image_inputs = self._preprocess_image( + text=text, + images=images, + max_num_tiles=max_num_tiles, + ) + + text, video_inputs = self._preprocess_video( + text=text, + videos=videos, + 
max_num_tiles=max_num_tiles, + dynamic_image_size=dynamic_image_size, + ) + + text_inputs = self.tokenizer(text, add_special_tokens=False) + + return BatchFeature({ + **BatchEncoding(text_inputs, tensor_type=return_tensors), + **image_inputs, + **video_inputs, + }) + + def get_image_repl( + self, + feature_size: int, + num_patches: Optional[int], + ) -> PromptUpdateDetails[str]: + repl_features = IMG_CONTEXT * feature_size + repl_full = IMG_START + repl_features + IMG_END + + return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT) + + def get_video_repl( + self, + feature_size: int, + num_patches: Optional[int] = None, + video_context_token: str = IMG_CONTEXT, + ) -> PromptUpdateDetails[str]: + repl_features = video_context_token * self.num_image_token + repl_features_with_sep = IMG_START + repl_features + IMG_END + # num_patches is equal to num_frames + repl_full = ''.join([ + f'Frame{i+1}: {repl_features_with_sep}' for i in range(num_patches) + ]) + + return PromptUpdateDetails.select_text(repl_full, video_context_token) + + +class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo): + """Basic image-only ProcessingInfo for InternVL-style models.""" + + @abstractmethod + def get_hf_processor( + self, + **kwargs: object, + ) -> BaseNanoNemotronVLProcessor: + raise NotImplementedError + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + max_num_tiles: int, + processor: Optional[BaseNanoNemotronVLProcessor], + ) -> int: + if processor is None: + processor = self.get_hf_processor() + + return processor.get_num_image_tokens( + image_width=image_width, + image_height=image_height, + max_num_tiles=max_num_tiles, + ) + + def get_image_size_with_most_features(self, + max_num_tiles: int) -> ImageSize: + processor = self.get_hf_processor() + + base_size = processor.image_size + target_ratios = get_internvl_target_ratios(1, max_num_tiles) + + largest_feature_size, largest_feature_pinpoint = 0, None + for wr, hr in target_ratios: + width, height = base_size * wr, base_size * hr + + feat_size = self.get_num_image_tokens( + image_width=width, + image_height=height, + max_num_tiles=max_num_tiles, + processor=processor, + ) + if feat_size > largest_feature_size: + largest_feature_size = feat_size + largest_feature_pinpoint = ImageSize(width=width, + height=height) + + if largest_feature_size == 0 or largest_feature_pinpoint is None: + raise ValueError("Cannot have a largest feature size of 0!") + + return largest_feature_pinpoint + + def get_max_image_tokens(self) -> int: + processor = self.get_hf_processor() + # Use default max_num_tiles for max tokens calculation + max_num_tiles = 12 + target_width, target_height = self.get_image_size_with_most_features( + max_num_tiles) + + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + max_num_tiles=max_num_tiles, + processor=processor, + ) + + +_I = TypeVar("_I", bound=BaseNanoNemotronVLProcessingInfo) + + +class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo): + """ ProcessingInfo extended for video processing""" + + @property + def supports_video(self): + return self.get_hf_processor().supports_video + + def get_supported_mm_limits(self): + video_limit = {"video": None} if self.supports_video else {} + return {**super().get_supported_mm_limits(), **video_limit} + + def get_video_token(self) -> Optional[str]: + return IMG_CONTEXT + + def 
get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + max_images = mm_counts.get("image", 0) + max_videos = mm_counts.get("video", 0) + + processor = self.get_hf_processor() # we get the CustomProcessor here + + max_image_tokens = self.get_max_image_tokens() * max_images + max_total_frames = (seq_len - + max_image_tokens) // processor.num_image_token + max_frames_per_video = max_total_frames // max(max_videos, 1) + + max_frames_per_video = min(max_frames_per_video, MAX_FRAMES) + return max(max_frames_per_video, 1) + + def get_hf_processor(self, **kwargs: object) -> NanoNemotronVLProcessor: + return self.ctx.init_processor( + NanoNemotronVLProcessor, + config=self.get_hf_config(), + tokenizer=self.get_tokenizer(), + video_token=self.get_video_token(), + **kwargs, + ) + + +class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]): + """Basic image-only MultiModalProcessor for InternVL-style models.""" + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> Mapping[str, NestedTensors]: + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + + hf_processor = self.info.get_hf_processor(**mm_kwargs) + image_token_id = hf_processor.image_token_id + + # Since there may be extra tokens in the feature placeholders, + # we need to pass the image token ID to the model to select the + # tokens to merge from the vision encoder outputs + processed_outputs["image_token_id"] = torch.tensor(image_token_id) + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: Mapping[str, NestedTensors], + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0)) + num_images = len(image_num_patches) + + return dict( + pixel_values_flat=MultiModalFieldConfig.flat_from_sizes( + "image", image_num_patches), + image_num_patches=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + image_token_id=MultiModalFieldConfig.shared("image", num_images), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + if "image_num_patches" in out_mm_kwargs: + image_num_patches = out_mm_kwargs["image_num_patches"] + assert isinstance(image_num_patches, torch.Tensor) + image_num_patches = image_num_patches.tolist() + elif "image_embeds" in out_mm_kwargs: + # to compute num_patches (similar to Qwen2-VL) + image_num_patches = [None] * len(out_mm_kwargs["image_embeds"]) + else: + image_num_patches = [] + + def get_replacement_custom(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems)) + + if isinstance(images, ImageEmbeddingItems): + feature_size = images.get_feature_size(item_idx) + else: + image_size = images.get_image_size(item_idx) + # Extract max_num_tiles from kwargs, default to 12 + max_num_tiles = hf_processor_mm_kwargs.get("max_num_tiles", 12) + feature_size = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + max_num_tiles=max_num_tiles, + processor=hf_processor, + ) + + num_patches = None + local_image_num_patches = 
image_num_patches + if isinstance(local_image_num_patches, torch.Tensor): + local_image_num_patches = local_image_num_patches.tolist() + if isinstance( + local_image_num_patches, + (list, tuple)) and item_idx < len(local_image_num_patches): + num_patches = int(local_image_num_patches[item_idx]) + + return hf_processor.get_image_repl(feature_size, num_patches) + + return [ + PromptReplacement( + modality="image", + target="<image>", + replacement=get_replacement_custom, + ) + ] + + +class NanoNemotronVLMultiModalProcessor( + NanoNemotronBaseVLMultiModalProcessor[NanoNemotronVLProcessingInfo]): + """MultiModalProcessor extended for video support""" + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> Mapping[str, NestedTensors]: + processed_outputs = super()._call_hf_processor(prompt, mm_data, + mm_kwargs, tok_kwargs) + + hf_processor = self.info.get_hf_processor(**mm_kwargs) + if self.info.supports_video and ( + video_token_id := hf_processor.video_token_id) is not None: + processed_outputs["video_token_id"] = torch.tensor(video_token_id) + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: Mapping[str, NestedTensors], + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + image_fields = super()._get_mm_fields_config(hf_inputs, + hf_processor_mm_kwargs) + if self.info.supports_video: + video_num_patches = hf_inputs.get("video_num_patches", + torch.empty(0)) + num_videos = len(video_num_patches) + video_fields = dict( + pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes( + "video", video_num_patches), + video_num_patches=MultiModalFieldConfig.batched("video"), + video_token_id=MultiModalFieldConfig.shared( + "video", num_videos)) + else: + video_fields = {} + + return image_fields | video_fields + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + prompt_repl = super()._get_prompt_updates( + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + out_mm_kwargs=out_mm_kwargs, + ) + + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + out_mm_data = out_mm_kwargs.get_data() + if "video_num_patches" in out_mm_data: + video_num_patches = out_mm_data["video_num_patches"] + assert isinstance(video_num_patches, torch.Tensor) + video_num_patches = video_num_patches.tolist() + else: + video_num_patches = [] + + def get_video_replacement_internvl(item_idx: int): + feature_size = hf_processor.num_image_token + num_patches = video_num_patches[item_idx] + if num_patches is not None: + assert isinstance(num_patches, int) + + return hf_processor.get_video_repl( + feature_size, + num_patches, + video_context_token=hf_processor.video_token) + + if self.info.supports_video: + prompt_repl = [ + *prompt_repl, + PromptReplacement( + modality="video", + target="<video>", + replacement=get_video_replacement_internvl, + ) + ] + + return prompt_repl + + +class NanoNemotronVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]): + """Basic image-only DummyInputsBuilder for InternVL-style models.""" + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + return "<image>" * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + # Use default 
max_num_tiles for dummy data generation + max_num_tiles = 12 + target_width, target_height = ( + self.info.get_image_size_with_most_features(max_num_tiles)) + num_images = mm_counts.get("image", 0) + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + +class NanoNemotronVLDummyInputsBuilder( + NanoNemotronVLDummyInputsBuilder[NanoNemotronVLProcessingInfo]): + """DummyInputsBuilder extended for video support""" + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_videos = mm_counts.get("video", 0) + + return super().get_dummy_text(mm_counts) + "<video>" * num_videos + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + dummy_image = super().get_dummy_mm_data(seq_len=seq_len, + mm_counts=mm_counts) + if self.info.supports_video: + config = self.info.get_hf_config() + image_size: int = config.force_image_size + target_num_frames = \ + self.info.get_num_frames_with_most_features(seq_len, mm_counts) + num_videos = mm_counts.get("video", 0) + dummy_video = { + "video": + self._get_dummy_videos(width=image_size, + height=image_size, + num_frames=target_num_frames, + num_videos=num_videos) + } + else: + dummy_video = {} + return {**dummy_image, **dummy_video} + + +@MULTIMODAL_REGISTRY.register_processor( + NanoNemotronVLMultiModalProcessor, + info=NanoNemotronVLProcessingInfo, + dummy_inputs=NanoNemotronVLDummyInputsBuilder, +) +class NemotronH_Nano_VL(nn.Module, HasInnerState, IsHybrid, + SupportsMultiModal): + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "<image>" + if modality.startswith("video"): + return "<video>" + return None + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + + image_size = config.force_image_size + patch_size = config.patch_size + self.patch_size = patch_size + self.template = config.template + self.num_image_token = int( + (image_size // patch_size)**2 * (config.downsample_ratio**2)) + self.downsample_ratio = config.downsample_ratio + self.ps_version = config.ps_version + self.image_tag_type = config.image_tag_type + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + self.vision_model = self.get_vit_model_from_radio_config(config).to( + self.language_model.config.torch_dtype) + + # Construct the vision projection. 
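Editorial note: the per-tile token count computed in this constructor follows from the patch grid and the pixel-shuffle downsample ratio. A quick check of the arithmetic with assumed config values (not read from any real checkpoint):

```python
# Hypothetical config values for illustration only.
image_size = 512        # config.force_image_size
patch_size = 16         # config.patch_size
downsample_ratio = 0.5  # config.downsample_ratio

# (512 // 16) ** 2 = 1024 patches; downsampling by 0.5 in each spatial dim
# keeps a quarter of them, giving 256 tokens per image tile.
num_image_token = int((image_size // patch_size) ** 2 * downsample_ratio ** 2)
assert num_image_token == 256
```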
+ vit_hidden_size = config.vit_hidden_size + vision_projection_hidden_size = config.projector_hidden_size + llm_hidden_size = config.text_config.hidden_size + + self.mlp1 = nn.Sequential( + RMSNorm(hidden_size=vit_hidden_size * + int(1 / self.downsample_ratio)**2, + eps=1e-5), + nn.Linear( + vit_hidden_size * int(1 / self.downsample_ratio)**2, + vision_projection_hidden_size, + bias=False, + ), + ReLUSquaredActivation(), + nn.Linear(vision_projection_hidden_size, + llm_hidden_size, + bias=False), + ) + self.mlp1 = self.mlp1.to(self.language_model.config.torch_dtype) + + self.img_context_token_id = None + self.video_context_token_id = None + self.config = config + + def pixel_shuffle(self, x, scale_factor=0.5): + n, w, h, c = x.size() + # N, W, H, C --> N, W, H * scale, C // scale + x = x.view( + n, + w, + int(h * scale_factor), + int(c / scale_factor), + ) + # N, W, H * scale, C // scale --> N, H * scale, W, C // scale + x = x.permute(0, 2, 1, 3).contiguous() + # N, H * scale, W, C // scale --> + # N, H * scale, W * scale, C // (scale ** 2) + x = x.view( + n, + int(h * scale_factor), + int(w * scale_factor), + int(c / (scale_factor * scale_factor)), + ) + if self.ps_version == "v1": + warnings.warn( + "In ps_version 'v1', the height and width have not " + "been swapped back, which results in a transposed image.", + stacklevel=2, + ) + else: + x = x.permute(0, 2, 1, 3).contiguous() + return x + + def extract_feature(self, pixel_values): + vit_embeds = self.vision_model(pixel_values) + vit_embeds = vit_embeds.to(dtype=torch.bfloat16) + h = w = int(vit_embeds.shape[1]**0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = self.pixel_shuffle(vit_embeds, + scale_factor=self.downsample_ratio) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, + vit_embeds.shape[-1]) + vit_embeds = self.mlp1(vit_embeds) + return vit_embeds + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[NanoNemotronVLImageInputs]: + pixel_values_flat = kwargs.pop("pixel_values_flat", None) + image_num_patches = kwargs.pop("image_num_patches", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values_flat is None and image_embeds is None: + return None + + if image_embeds is not None: + if not isinstance(image_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + + return NanoNemotronVLImageEmbeddinInputs( + type="image_embeds", + data=flatten_bn(image_embeds), + ) + + image_token_id = kwargs["image_token_id"] + assert isinstance(image_token_id, torch.Tensor) + self.img_context_token_id = image_token_id.flatten().unique().item() + + if pixel_values_flat is not None: + if not isinstance(pixel_values_flat, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values_flat)}") + + if not isinstance(image_num_patches, (torch.Tensor, list)): + raise ValueError("Incorrect type of image_num_patches. 
" + f"Got type: {type(image_num_patches)}") + + pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) + image_num_patches = flatten_bn(image_num_patches, concat=True) + + return NanoNemotronVLImagePixelInputs( + type="pixel_values", + pixel_values_flat=pixel_values_flat, + num_patches=image_num_patches, + ) + + raise AssertionError("This line should be unreachable.") + + def _process_image_input( + self, image_input: NanoNemotronVLImageInputs) -> torch.Tensor: + if image_input["type"] == "image_embeds": + return image_input["data"] + + assert self.vision_model is not None + + image_embeds = self.extract_feature(image_input["pixel_values_flat"]) + num_patches = image_input["num_patches"] + + # Only one image in the current batch + if len(num_patches) == 1: + return (image_embeds.view(-1, + self.config.text_config.hidden_size), ) + + # NOTE: Image embeddings are split into separate tensors for each image + # by the size of each embedding. + feature_size = image_embeds.shape[1] + image_embeds = image_embeds.view(-1, + self.config.text_config.hidden_size) + image_feature_sizes = [ + num_patches * feature_size for num_patches in num_patches + ] + return image_embeds.split(image_feature_sizes) + + def _parse_and_validate_video_input( + self, + **kwargs: object) -> Optional[NanoNemotronVLVideoPixelInputs]: + pixel_values_flat_video = kwargs.pop("pixel_values_flat_video", None) + video_num_patches = kwargs.pop("video_num_patches", None) + video_embeds = kwargs.pop("video_embeds", None) + + if pixel_values_flat_video is None and video_embeds is None: + return None + + if video_embeds is not None: + return NanoNemotronVLVideoEmbeddingInputs( + type="video_embeds", + data=flatten_bn(video_embeds), + ) + + video_token_id = kwargs["video_token_id"] + assert isinstance(video_token_id, torch.Tensor) + self.video_context_token_id = video_token_id.flatten().unique().item() + + if pixel_values_flat_video is not None: + if not isinstance(pixel_values_flat_video, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values_flat_video)}") + + if not isinstance(video_num_patches, (torch.Tensor, list)): + raise ValueError("Incorrect type of image_num_patches. " + f"Got type: {type(video_num_patches)}") + + pixel_values_flat_video = flatten_bn(pixel_values_flat_video, + concat=True) + video_num_patches = flatten_bn(video_num_patches, concat=True) + expected_h = expected_w = self.config.force_image_size + resolve_bindings = {"h": expected_h, "w": expected_w} + + return NanoNemotronVLVideoPixelInputs( + type="pixel_values_videos", + pixel_values_flat=pixel_values_flat_video, + num_patches=video_num_patches, + resolve_bindings=resolve_bindings, + ) + + raise AssertionError("This line should be unreachable.") + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. 
+ for input_key in kwargs: + if input_key in ("pixel_values_flat", + "image_embeds") and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + if input_key in ("pixel_values_flat_video", + ) and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input( + **kwargs) + + return modalities + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + # Validate the multimodal input keyword arguments + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if modalities is None: + return [] + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor corresponding to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += vision_embeddings + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_image_input(video_input) + multimodal_embeddings += video_embeddings + + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + + if (multimodal_embeddings is not None + and len(multimodal_embeddings) != 0): + context_token_ids = [ + token_id for token_id in (self.img_context_token_id, + self.video_context_token_id) + if token_id is not None + ] + assert len(context_token_ids) >= 1 + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + multimodal_embeddings, + context_token_ids, + ) + + return inputs_embeds + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
+ elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.language_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + return hidden_states + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="mlp1", + tower_model="vision_model", + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + adapter_dict = dict(self.mlp1.named_parameters()) + + def is_llm(name: str) -> bool: + return name.startswith("language_model") + + def is_adapter_weights(weight: tuple[str, torch.Tensor]): + return weight[0].startswith("mlp1") + + def is_vision_weights(name: str) -> bool: + return name.startswith("vision_model.radio_model.") + + # Separate weights by component + llm_weights = [] + vision_weights = [] + + for name, w in weights: + if is_llm(name): + # Strip 'language_model.' prefix for LLM weights + llm_weights.append((".".join(name.split(".")[1:]), w)) + elif is_adapter_weights((name, w)): + # Load vision-language adapter weights directly + trimmed_name = ".".join(name.split(".")[1:]) + param = adapter_dict[trimmed_name] + with torch.no_grad(): + default_weight_loader(param, w) + elif is_vision_weights(name): + # Convert: vision_model.radio_model.* → radio_model.* + hf_key = name[len( + "vision_model."):] # Remove "vision_model." prefix + vision_weights.append((hf_key, w)) + + self.language_model.load_weights(llm_weights) + self.vision_model.load_weights(vision_weights) + + def print_architecture(self, + detailed: bool = True, + save_to_file: str = None): + """ + Print model architecture with parameter names, shapes, and sizes. 
+ + Args: + detailed: If True, show detailed parameter breakdown + save_to_file: If provided, save output to this file path + """ + import sys + from io import StringIO + + # Capture output if saving to file + original_stdout = sys.stdout + if save_to_file: + sys.stdout = StringIO() + + try: + print("=" * 100) + print("NemotronH_Nano_VL Model Architecture") + print("=" * 100) + + total_params = 0 + param_groups = { + "language_model": [], + "vision_model": [], + "mlp1": [], + "other": [], + } + + for name, param in self.named_parameters(): + param_size = param.numel() + total_params += param_size + + # Group parameters by main component + if name.startswith("language_model"): + param_groups["language_model"].append( + (name, param.shape, param_size, param.dtype)) + elif name.startswith("vision_model"): + param_groups["vision_model"].append( + (name, param.shape, param_size, param.dtype)) + elif name.startswith("mlp1"): + param_groups["mlp1"].append( + (name, param.shape, param_size, param.dtype)) + else: + param_groups["other"].append( + (name, param.shape, param_size, param.dtype)) + + if detailed: + print(f"{name:<70} | Shape: {str(param.shape):<25} | " + f"Size: {param_size:>12,} | Dtype: {param.dtype}") + + print("=" * 100) + print("Summary by Component:") + print("-" * 60) + + for component, params in param_groups.items(): + if params: # Only show components that have parameters + component_total = sum(size for _, _, size, _ in params) + percentage = ((component_total / total_params) * + 100 if total_params > 0 else 0) + print(f"{component:<20} | Parameters: {len(params):>4} | " + f"Total Size: {component_total:>15,} | " + f"{percentage:>6.2f}%") + + print("-" * 60) + print(f"{'Total Parameters':<20} | {total_params:>15,}") + + # Estimate memory usage (assuming bfloat16 = 2 bytes per parameter) + memory_mb = total_params * 2 / (1024**2) + memory_gb = memory_mb / 1024 + print(f"{'Est. Memory (MB)':<20} | {memory_mb:>15.2f}") + print(f"{'Est. Memory (GB)':<20} | {memory_gb:>15.2f}") + print("=" * 100) + + # Save to file if requested + if save_to_file: + output = sys.stdout.getvalue() + sys.stdout = original_stdout + with open(save_to_file, "w") as f: + f.write(output) + print(f"Architecture saved to: {save_to_file}") + print(output) # Also print to console + + finally: + if save_to_file and sys.stdout != original_stdout: + sys.stdout = original_stdout + + def get_model_info(self): + """ + Get basic model information as a dictionary. 
+ """ + total_params = sum(p.numel() for p in self.parameters()) + + component_info = {} + for name, param in self.named_parameters(): + component = name.split(".")[0] + if component not in component_info: + component_info[component] = {"params": 0, "size": 0} + component_info[component]["params"] += 1 + component_info[component]["size"] += param.numel() + + return { + "model_name": "NemotronH_Nano_VL", + "total_parameters": total_params, + "memory_estimate_mb": total_params * 2 / (1024**2), # bfloat16 + "components": component_info, + "config": { + "image_size": getattr(self.config, "force_image_size", None), + "patch_size": getattr(self.config, "patch_size", None), + "num_image_token": self.num_image_token, + "downsample_ratio": self.downsample_ratio, + }, + } + + def get_vit_model_from_radio_config(self, hf_config): + hf_config_vision = hf_config.vision_config + model_name = hf_config_vision.args.get("model") + if model_name is None: + raise ValueError(f'Unsupported vit model type: {model_name}') + + preferred_resolution = getattr(hf_config_vision, + "preferred_resolution", None) + image_size = preferred_resolution[0] if preferred_resolution else 224 + patch_size = getattr(hf_config_vision, "patch_size", 16) + + radio_config = RadioConfig( + model_name=model_name, + image_size=image_size, + patch_size=patch_size, + norm_mean=hf_config.norm_mean, + norm_std=hf_config.norm_std, + reg_tokens=(hf_config_vision.args.get("register_multiple") + if hasattr(hf_config_vision, "args") + and isinstance(hf_config_vision.args, dict) else None), + ) + + return RadioModel(config=radio_config) + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.language_model.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return (self.language_model.mamba_cache. 
+ get_seqlen_agnostic_capture_inputs(batch_size)) + + @classmethod + def get_mamba_state_shape_from_config(cls, vllm_config: "VllmConfig"): + text_config = vllm_config.model_config.hf_config.text_config + temp_vllm_config = copy.deepcopy(vllm_config) + temp_vllm_config.model_config.hf_config = text_config + return NemotronHForCausalLM.get_mamba_state_shape_from_config( + temp_vllm_config) + + @classmethod + def get_mamba_state_dtype_from_config(cls, vllm_config: "VllmConfig"): + text_config = vllm_config.model_config.hf_config.text_config + temp_vllm_config = copy.deepcopy(vllm_config) + temp_vllm_config.model_config.hf_config = text_config + return NemotronHForCausalLM.get_mamba_state_dtype_from_config( + temp_vllm_config) diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index 10adc62d3de38..21f785e4b91af 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -466,6 +466,7 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): # compatibility if not lora_config else lora_config.lora_vocab_padding_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 8a563288cb4d6..1e1f0524bd063 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -44,15 +44,16 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant) from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) from vllm.model_executor.models.utils import ( - AutoWeightsLoader, make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) + AutoWeightsLoader, WeightsMapper, make_empty_intermediate_tensors_factory, + make_layers, maybe_prefix) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import NemotronHConfig @@ -426,38 +427,36 @@ class NemotronHModel(nn.Module): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - attb_params_mapping = { - "q_proj": "q", - "k_proj": "k", - "v_proj": "v", - } + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if "embeddings" in name: - name = name.replace("embeddings", "embed_tokens") + if "scale" in name: + # Remapping the name of FP8 kv-scale. 
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue - if "A_log" in name: - name = name.replace("A_log", "A") - loaded_weight = loaded_weight.to(torch.float32) + # load stacked params + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue - if "D" in name: - loaded_weight = loaded_weight.to(torch.float32) - - if "dt_bias" in name: - loaded_weight = loaded_weight.to(torch.float32) - - # load attn params - if any(proj in name for proj in ["q_proj", "k_proj", "v_proj"]): - weight_name = next(proj - for proj in ["q_proj", "k_proj", "v_proj"] - if proj in name) - name = name.replace(weight_name, "qkv_proj") param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, loaded_weight, - attb_params_mapping[weight_name]) + weight_loader(param, loaded_weight, shard_id) + break + # load other params else: param = params_dict[name] @@ -471,6 +470,14 @@ class NemotronHModel(nn.Module): class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={"backbone": "model"}, + orig_to_new_substr={ + "A_log": "A", + "embeddings": "embed_tokens" + }, + ) + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -558,6 +565,7 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, # We need bigger padding if using lora for kernel # compatibility if not lora_config else lora_config.lora_vocab_padding_size, + prefix=maybe_prefix(prefix, "lm_head"), ) # Used to track and store by the Mamba cache between steps. self.mamba_cache: Optional[MambaCacheManager] = None @@ -622,10 +630,5 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - # update name in weights before passing to loader - updated_weights = [] - for name, loaded_weight in weights: - name = name.replace("backbone", "model") - updated_weights.append((name, loaded_weight)) loader = AutoWeightsLoader(self) - return loader.load_weights(updated_weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/nemotron_vl.py b/vllm/model_executor/models/nemotron_vl.py index a9c7d8044e10c..acda2027401d9 100644 --- a/vllm/model_executor/models/nemotron_vl.py +++ b/vllm/model_executor/models/nemotron_vl.py @@ -560,7 +560,7 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, return [] # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image). + # tensor corresponding to a multimodal data item (image). multimodal_embeddings: tuple[torch.Tensor, ...] 
= () # NOTE: It is important to iterate over the keys in this dictionary diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 71575989565a8..7be3c16528b52 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -364,6 +364,7 @@ class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA): config.hidden_size, org_num_embeddings=config.vocab_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index bccd1b87043a5..3e4c580a11211 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -52,10 +52,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP from vllm.model_executor.models.utils import ( - AutoWeightsLoader, is_pp_missing_parameter, + AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import Olmo3Config class Olmo2Attention(nn.Module): @@ -68,7 +69,7 @@ class Olmo2Attention(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config - assert isinstance(self.config, Olmo2Config) + assert isinstance(self.config, (Olmo2Config, Olmo3Config)) hidden_size = self.config.hidden_size self.tp_size = get_tensor_model_parallel_world_size() @@ -111,14 +112,14 @@ class Olmo2Attention(nn.Module): self.q_norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) - # Rotary embeddings. - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=self.max_position_embeddings, - base=self.rope_theta, # type: ignore - ) self.scaling = self.head_dim**-0.5 + + layer_idx = extract_layer_index(prefix) + sliding_window = None + if ((layer_types := getattr(self.config, "layer_types", None)) + is not None and layer_types[layer_idx] == "sliding_attention"): + sliding_window = self.config.sliding_window + self.attn = Attention( self.num_heads, self.head_dim, @@ -126,7 +127,20 @@ class Olmo2Attention(nn.Module): num_kv_heads=self.num_kv_heads, cache_config=vllm_config.cache_config, quant_config=vllm_config.quant_config, - prefix=prefix, + per_layer_sliding_window=sliding_window, + prefix=f"{prefix}.attn", + ) + + # Rotary embeddings. Rope scaling is only applied on full attention + # layers. + self.rope_scaling = (self.config.rope_scaling + if sliding_window is None else None) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, # type: ignore + rope_scaling=self.rope_scaling, ) # Attention output projection. 
@@ -176,7 +190,7 @@ class Olmo2MLP(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - assert isinstance(config, Olmo2Config) + assert isinstance(config, (Olmo2Config, Olmo3Config)) hidden_size = config.hidden_size intermediate_size = config.intermediate_size @@ -221,7 +235,7 @@ class Olmo2DecoderLayer(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - assert isinstance(config, Olmo2Config) + assert isinstance(config, (Olmo2Config, Olmo3Config)) # Attention block. self.self_attn = Olmo2Attention(vllm_config=vllm_config, prefix=f"{prefix}.self_attn") @@ -261,7 +275,7 @@ class Olmo2Model(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config - assert isinstance(self.config, Olmo2Config) + assert isinstance(self.config, (Olmo2Config, Olmo3Config)) self.embed_tokens = VocabParallelEmbedding( self.config.vocab_size, @@ -376,7 +390,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - assert isinstance(config, Olmo2Config) + assert isinstance(config, (Olmo2Config, Olmo3Config)) self.config = config self.model = Olmo2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index 9b8525bfadece..892e967e4a21f 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -450,7 +450,8 @@ class OlmoeForCausalLM(nn.Module, SupportsPP): prefix=maybe_prefix(prefix, "model")) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head")) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index b92e586f0bf21..365aab205b211 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -375,7 +375,9 @@ class OPTForCausalLM(nn.Module, SupportsPP): self.lm_head = self.model.decoder.embed_tokens else: self.lm_head = ParallelLMHead(config.vocab_size, - config.word_embed_proj_dim) + config.word_embed_proj_dim, + prefix=maybe_prefix( + prefix, "lm_head")) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index add751ebf09cc..944a9151d75d3 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -314,7 +314,8 @@ class OrionForCausalLM(nn.Module, SupportsPP): prefix=maybe_prefix(prefix, "model")) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head")) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index b74a09ee92c33..d6eec77ebcee5 100644 --- a/vllm/model_executor/models/paligemma.py +++ 
b/vllm/model_executor/models/paligemma.py @@ -12,7 +12,8 @@ from vllm.logger import init_logger from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargsItems) + MultiModalInputs, MultiModalKwargsItems, + MultiModalUUIDDict) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -203,13 +204,13 @@ class PaliGemmaMultiModalProcessor( mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - mm_hash_overrides: Optional[dict[str, list[str]]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalInputs: mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs, tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides) + mm_uuids=mm_uuids) prompt_token_ids = mm_inputs["prompt_token_ids"] tokenizer = self.info.get_tokenizer() diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index 6bdd38d068800..3e854e4d561ff 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -307,7 +307,8 @@ class PersimmonForCausalLM(nn.Module, SupportsPP): prefix=maybe_prefix(prefix, "model")) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, - bias=False) + bias=False, + prefix=maybe_prefix(prefix, "lm_head")) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 789b24eb0f6be..6f39afbecf35b 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -322,7 +322,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, bias=True, - quant_config=quant_config) + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head")) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/phi4_multimodal.py b/vllm/model_executor/models/phi4_multimodal.py index 6d973a964de04..25df9e9261d91 100644 --- a/vllm/model_executor/models/phi4_multimodal.py +++ b/vllm/model_executor/models/phi4_multimodal.py @@ -374,8 +374,8 @@ class Phi4MMAudioMeanVarianceNormLayer(nn.Module): Typically used as a very first layer in a model. Args: - input_size: int - layer input size. + config: [Phi4MultimodalAudioConfig](https://huggingface.co/docs/transformers/model_doc/phi4_multimodal#transformers.Phi4MultimodalAudioConfig) + object containing model parameters. """ def __init__(self, config: Phi4MultimodalAudioConfig): @@ -1350,7 +1350,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): return None # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). + # tensor corresponding to a multimodal data item (image or video). multimodal_embeddings: tuple[torch.Tensor, ...] 
= () # NOTE: It is important to iterate over the keys in this dictionary diff --git a/vllm/model_executor/models/phi4flash.py b/vllm/model_executor/models/phi4flash.py index fcdfcb7bc1603..c4548ee168bd7 100644 --- a/vllm/model_executor/models/phi4flash.py +++ b/vllm/model_executor/models/phi4flash.py @@ -630,6 +630,7 @@ class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only): # compatibility if not lora_config else lora_config.lora_vocab_padding_size), quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) self.embedding_bias = None # Used to track and store by the Mamba cache between steps. diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index 352ae4064cc61..b3fc55dab6eca 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -989,6 +989,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): org_num_embeddings=config.vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) @@ -1154,7 +1155,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): return None # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). + # tensor corresponding to a multimodal data item (image or video). multimodal_embeddings: tuple[torch.Tensor, ...] = () # NOTE: It is important to iterate over the keys in this dictionary diff --git a/vllm/model_executor/models/phi4mm_audio.py b/vllm/model_executor/models/phi4mm_audio.py index b5e4d727bf210..a1c452053ddd2 100644 --- a/vllm/model_executor/models/phi4mm_audio.py +++ b/vllm/model_executor/models/phi4mm_audio.py @@ -7,7 +7,7 @@ #!/usr/bin/env python3 import abc import math -from typing import Literal, Optional +from typing import Any, Literal, Optional, Union import numpy as np import torch @@ -100,7 +100,7 @@ class ConformerEncoderLayer(nn.Module): activation function for glu used in the multihead attention, default "swish". activation_checkpointing: str, optional - a dictionarry of {"module","interval","offload"}, where + a dictionary of {"module","interval","offload"}, where "module": str accept ["transformer", "attention"] to select which module should do activation checkpointing. 
@@ -131,31 +131,31 @@ class ConformerEncoderLayer(nn.Module): def __init__( self, - d_model=512, - ext_pw_out_channel=0, - depthwise_seperable_out_channel=256, - depthwise_multiplier=1, - n_head=4, - d_ffn=2048, - ext_pw_kernel_size=1, - kernel_size=3, - dropout_rate=0.1, - causal=False, - batch_norm=False, - activation="relu", - chunk_se=0, - chunk_size=18, - conv_activation="relu", - conv_glu_type="sigmoid", - bias_in_glu=True, - linear_glu_in_convm=False, - attention_inner_dim=-1, - attention_glu_type="swish", - activation_checkpointing="", - export=False, - use_pt_scaled_dot_product_attention=False, + d_model: int = 512, + ext_pw_out_channel: int = 0, + depthwise_seperable_out_channel: int = 256, + depthwise_multiplier: int = 1, + n_head: int = 4, + d_ffn: int = 2048, + ext_pw_kernel_size: int = 1, + kernel_size: int = 3, + dropout_rate: float = 0.1, + causal: bool = False, + batch_norm: bool = False, + activation: str = "relu", + chunk_se: int = 0, + chunk_size: int = 18, + conv_activation: str = "relu", + conv_glu_type: str = "sigmoid", + bias_in_glu: bool = True, + linear_glu_in_convm: bool = False, + attention_inner_dim: int = -1, + attention_glu_type: str = "swish", + activation_checkpointing: str = "", + export: bool = False, + use_pt_scaled_dot_product_attention: bool = False, attn_group_sizes: int = 1, - ): + ) -> None: super().__init__() self.feed_forward_in = FeedForward( @@ -209,24 +209,21 @@ class ConformerEncoderLayer(nn.Module): def forward( self, - x, - pos_k, - pos_v, - mask, + x: torch.Tensor, + pos_k: torch.Tensor, + pos_v: torch.Tensor, + mask: torch.Tensor, relative_attention_bias: Optional[Tensor] = None, - ): + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ConformerEncoder forward. Args: - x: torch.Tensor - input feature of shape (batch, max_time_in, size) - pos_k: torch.Tensor - positional key embedding. - mask: torch.Tensor - mask for x (batch, max_time_in) - relative_attention_bias: Optional[torch.Tensor] - bias added to attention logits w.r.t. relative positions - (1, n_head, time1, time2) + x: input feature of shape (batch, max_time_in, size) + pos_k: positional key embedding. + pos_v: positional value embedding. + mask: mask for x (batch, max_time_in) + relative_attention_bias: bias added to attention logits w.r.t. 
+ relative positions (1, n_head, time1, time2) """ x = x + 0.5 * self.feed_forward_in(x) norm_x = self.layer_norm_att(x) @@ -323,25 +320,25 @@ class TransformerEncoderBase(abc.ABC, nn.Module): def __init__( self, - input_size, - chunk_size, - left_chunk, - attention_dim=256, - attention_heads=4, - input_layer="nemo_conv", - cnn_out=-1, - cnn_layer_norm=False, - time_reduction=4, - dropout_rate=0.0, - padding_idx=-1, - relative_attention_bias_args=None, - positional_dropout_rate=0.0, - nemo_conv_settings=None, + input_size: int, + chunk_size: Union[int, list[int]], + left_chunk: Union[int, list[int]], + attention_dim: int = 256, + attention_heads: int = 4, + input_layer: str = "nemo_conv", + cnn_out: int = -1, + cnn_layer_norm: bool = False, + time_reduction: int = 4, + dropout_rate: float = 0.0, + padding_idx: int = -1, + relative_attention_bias_args: Optional[dict[str, Any]] = None, + positional_dropout_rate: float = 0.0, + nemo_conv_settings: Optional[dict[str, Any]] = None, conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none", - attention_group_size=1, - encoder_embedding_config=None, - ): + attention_group_size: int = 1, + encoder_embedding_config: Optional[dict[str, Any]] = None, + ) -> None: super().__init__() self.input_size = input_size self.input_layer = input_layer @@ -399,7 +396,10 @@ class TransformerEncoderBase(abc.ABC, nn.Module): self.encoder_embedding = MeanVarianceNormLayer( self.encoder_embedding_config["input_size"]) - def compute_lens_change(self, feature_lens): + def compute_lens_change( + self, + feature_lens: Union[int, + torch.Tensor]) -> Union[int, torch.Tensor]: """feature_lens: int return updated feature lens. @@ -433,10 +433,14 @@ class TransformerEncoderBase(abc.ABC, nn.Module): return ceil_func(feature_lens / self.time_reduction) @abc.abstractmethod - def forward(self): + def forward(self) -> Any: """Abstract forward method implementation.""" - def _chunk_size_selection(self, chunk_size=None, left_chunk=None): + def _chunk_size_selection( + self, + chunk_size: Optional[Union[int, list[int]]] = None, + left_chunk: Optional[Union[int, + list[int]]] = None) -> tuple[int, int]: """If chunk size is a list, we will randomly select a chunk size.""" if chunk_size is None: @@ -463,7 +467,7 @@ class TransformerEncoderBase(abc.ABC, nn.Module): return chunk_size_train_eff, left_chunk_train_eff - def _get_embed_class(self, embed): + def _get_embed_class(self, embed: nn.Module) -> nn.Module: # pylint: disable=protected-access is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper) is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel) @@ -474,13 +478,17 @@ class TransformerEncoderBase(abc.ABC, nn.Module): embed_class = embed.module return embed_class - def _forward_embeddings_core(self, input_tensor, masks): + def _forward_embeddings_core( + self, input_tensor: torch.Tensor, + masks: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: embed_class = self._get_embed_class(self.embed) assert isinstance(embed_class, NemoConvSubsampling) input_tensor, masks = self.embed(input_tensor, masks) return input_tensor, masks - def _position_embedding(self, input_tensor): + def _position_embedding( + self, input_tensor: torch.Tensor + ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: pos_k = None pos_v = None if self.relative_attention_bias_layer is None: @@ -488,7 +496,9 @@ class TransformerEncoderBase(abc.ABC, nn.Module): input_tensor) # default to add abs sinusoid embedding return pos_k, pos_v - def _streaming_mask(self, seq_len, 
batch_size, chunk_size, left_chunk): + def _streaming_mask(self, seq_len: int, batch_size: int, + chunk_size: Union[int, list[int]], + left_chunk: Union[int, list[int]]) -> torch.Tensor: chunk_size_train_eff, left_chunk_train_eff = \ self._chunk_size_selection(chunk_size, left_chunk) @@ -502,11 +512,17 @@ class TransformerEncoderBase(abc.ABC, nn.Module): [batch_size, -1, -1])) return enc_streaming_mask - def forward_embeddings(self, - xs_pad, - masks, - chunk_size_nc=None, - left_chunk_nc=None): + def forward_embeddings( + self, + xs_pad: torch.Tensor, + masks: torch.Tensor, + chunk_size_nc: Optional[Union[int, list[int]]] = None, + left_chunk_nc: Optional[Union[int, list[int]]] = None + ) -> Union[tuple[torch.Tensor, Optional[torch.Tensor], + Optional[torch.Tensor], torch.Tensor, torch.Tensor], + tuple[torch.Tensor, Optional[torch.Tensor], + Optional[torch.Tensor], torch.Tensor, torch.Tensor, + torch.Tensor]]: """Forwarding the inputs through the top embedding layers Args: @@ -569,7 +585,7 @@ class TransformerEncoderBase(abc.ABC, nn.Module): return input_tensor, pos_k, pos_v, hs_mask, masks return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc - def get_offset(self): + def get_offset(self) -> int: """Returns offset used when retaining inputs for decoding. This is essentially, how many additional frames have to be added to @@ -605,8 +621,6 @@ class ConformerEncoder(TransformerEncoderBase): Some examples for the 2 cases: left_chunk = 6 left_chunk = [12, 9, 6, 3] - left_chunk: int - number of chunks used for masking in streaming mode. num_lang: int This parameter is used to store the number of languages in the lang_dict, only used for multiseed/multilingual models. @@ -751,46 +765,46 @@ class ConformerEncoder(TransformerEncoderBase): def __init__( # pylint: disable-all self, - input_size, - chunk_size, - left_chunk, - num_lang=None, - attention_dim=256, - attention_heads=4, - linear_units=2048, - num_blocks=6, - dropout_rate=0.1, - input_layer="nemo_conv", - causal=True, - batch_norm=False, - cnn_out=-1, - cnn_layer_norm=False, - ext_pw_out_channel=0, - ext_pw_kernel_size=1, - depthwise_seperable_out_channel=256, - depthwise_multiplier=1, - chunk_se=0, - kernel_size=3, - activation="relu", - conv_activation="relu", - conv_glu_type="sigmoid", - bias_in_glu=True, - linear_glu_in_convm=False, - attention_glu_type="swish", - export=False, - extra_layer_output_idx=-1, - extra_multi_layer_output_idxs=[], # noqa - activation_checkpointing="", - relative_attention_bias_args=None, - time_reduction=4, - use_pt_scaled_dot_product_attention=False, - nemo_conv_settings=None, + input_size: int, + chunk_size: Union[int, list[int]], + left_chunk: Union[int, list[int]], + num_lang: Optional[int] = None, + attention_dim: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + input_layer: str = "nemo_conv", + causal: bool = True, + batch_norm: bool = False, + cnn_out: int = -1, + cnn_layer_norm: bool = False, + ext_pw_out_channel: int = 0, + ext_pw_kernel_size: int = 1, + depthwise_seperable_out_channel: int = 256, + depthwise_multiplier: int = 1, + chunk_se: int = 0, + kernel_size: int = 3, + activation: str = "relu", + conv_activation: str = "relu", + conv_glu_type: str = "sigmoid", + bias_in_glu: bool = True, + linear_glu_in_convm: bool = False, + attention_glu_type: str = "swish", + export: bool = False, + extra_layer_output_idx: int = -1, + extra_multi_layer_output_idxs: list[int] = [], # noqa + activation_checkpointing: str = "", + 
relative_attention_bias_args: Optional[dict[str, Any]] = None, + time_reduction: int = 4, + use_pt_scaled_dot_product_attention: bool = False, + nemo_conv_settings: Optional[dict[str, Any]] = None, conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none", - replication_pad_for_subsample_embedding=False, - attention_group_size=1, - encoder_embedding_config=None, - ): + replication_pad_for_subsample_embedding: bool = False, + attention_group_size: int = 1, + encoder_embedding_config: Optional[dict[str, Any]] = None, + ) -> None: super().__init__( input_size, chunk_size, @@ -852,11 +866,13 @@ class ConformerEncoder(TransformerEncoderBase): # the device and the needed dtype: self.register_buffer("dev_type", torch.zeros(()), persistent=False) - def init_relative_attention_bias(self, input_tensor): + def init_relative_attention_bias( + self, input_tensor: torch.Tensor) -> Optional[torch.Tensor]: if self.relative_attention_bias_layer: return self.relative_attention_bias_layer(input_tensor) - def calculate_hs_mask(self, xs_pad, device, mask): + def calculate_hs_mask(self, xs_pad: torch.Tensor, device: torch.device, + mask: Optional[torch.Tensor]) -> torch.Tensor: max_audio_length = xs_pad.shape[1] batch_size = xs_pad.shape[0] enc_streaming_mask = self._streaming_mask(max_audio_length, batch_size, @@ -877,7 +893,8 @@ class ConformerEncoder(TransformerEncoderBase): return pad_mask @torch.jit.ignore - def forward(self, xs_pad, masks): + def forward(self, xs_pad: torch.Tensor, + masks: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Conformer Forward function Args: @@ -997,7 +1014,12 @@ class WindowQformer(nn.Module): if normalize_before else None) self.window_size = window_size - def forward(self, audio_embed, mask, embed_len=None): + def forward( + self, + audio_embed: torch.Tensor, + mask: Optional[torch.Tensor], + embed_len: Optional[int] = None + ) -> tuple[torch.Tensor, Optional[int]]: """forward decoder""" # audio_embed: N x T x D => N x D x T @@ -1042,7 +1064,7 @@ class WindowQformer(nn.Module): class AudioEmbedding(nn.Module): """Image embedding.""" - def __init__(self, config: PretrainedConfig, **kwargs) -> None: + def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None: super().__init__() self.config = config # n_embed or hidden_size for text LM @@ -1148,19 +1170,18 @@ class AudioEmbedding(nn.Module): self.input_embeds = None self.audio_embed_sizes = None - def set_audio_embeds(self, input_embeds: torch.FloatTensor) -> None: + def set_audio_embeds(self, input_embeds: torch.Tensor) -> None: self.input_embeds = input_embeds - def set_audio_embed_sizes(self, - audio_embed_sizes: torch.LongTensor) -> None: + def set_audio_embed_sizes(self, audio_embed_sizes: torch.Tensor) -> None: self.audio_embed_sizes = audio_embed_sizes def get_audio_features( self, - input_embeds: torch.FloatTensor, - audio_attention_mask: torch.Tensor = None, + input_embeds: torch.Tensor, + audio_attention_mask: Optional[torch.Tensor] = None, audio_projection_mode: str = "speech", - ) -> torch.FloatTensor: + ) -> torch.Tensor: """ arguments: input_embeds: audio features (B, T, D) B: num audios in a sequence @@ -1214,10 +1235,10 @@ class AudioEmbedding(nn.Module): def forward( self, - audio_features: torch.FloatTensor, - audio_attention_mask: torch.Tensor = None, + audio_features: torch.Tensor, + audio_attention_mask: Optional[torch.Tensor] = None, audio_projection_mode: str = "speech", - ) -> torch.FloatTensor: + ) -> torch.Tensor: """ arguments: audio_features: audio features (T, D) 
diff --git a/vllm/model_executor/models/phi4mm_utils.py b/vllm/model_executor/models/phi4mm_utils.py index 59535503822d4..6fbfca619a42f 100644 --- a/vllm/model_executor/models/phi4mm_utils.py +++ b/vllm/model_executor/models/phi4mm_utils.py @@ -16,13 +16,13 @@ from torch import Tensor, nn class BlockBase(nn.Module): """Block abstract module""" - def __init__(self, input_size, output_size): + def __init__(self, input_size: int, output_size: int) -> None: super().__init__() self.input_size = input_size self.output_size = output_size -def get_activation(name="relu"): +def get_activation(name: str = "relu") -> torch.nn.Module: """Select an activation function by name Args: @@ -43,15 +43,18 @@ def get_activation(name="relu"): return nn.Identity() -def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0): +def adaptive_enc_mask(x_len: int, + chunk_start_idx: list[int], + left_window: int = 0, + right_window: int = 0) -> torch.Tensor: """ The function is very important for Transformer Transducer Streaming mode Args: - xs_len (int): sequence length - chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48]. + x_len: sequence length + chunk_start_idx: first idx of each chunk, such as [0,18,36,48]. It also supports adaptive chunk size [0,10,15,45] - left_window (int): how many left chunks can be seen - right_window (int): how many right chunks can be seen. It is used for + left_window: how many left chunks can be seen + right_window: how many right chunks can be seen. It is used for chunk overlap model. Returns: mask (torch.Tensor): a mask tensor for streaming model @@ -172,13 +175,13 @@ class GLUPointWiseConv(nn.Module): def __init__( self, - input_dim, - output_dim, - kernel_size, - glu_type="sigmoid", - bias_in_glu=True, - causal=False, - ): + input_dim: int, + output_dim: int, + kernel_size: int, + glu_type: str = "sigmoid", + bias_in_glu: bool = True, + causal: bool = False, + ) -> None: super().__init__() self.glu_type = glu_type @@ -216,11 +219,10 @@ class GLUPointWiseConv(nn.Module): self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1)) self.b2 = nn.Parameter(torch.zeros(1, output_dim, 1)) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: """ Args: - x: torch.Tensor - input tensor + x: input tensor """ # to be consistent with GLULinear, we assume the input always has the # #channel (#dim) in the last dimension of the tensor, so need to @@ -272,12 +274,12 @@ class DepthWiseSeperableConv1d(nn.Module): def __init__( self, - input_dim, - depthwise_seperable_out_channel, - kernel_size, - depthwise_multiplier, - padding=0, - ): + input_dim: int, + depthwise_seperable_out_channel: int, + kernel_size: int, + depthwise_multiplier: int, + padding: int = 0, + ) -> None: super().__init__() self.dw_conv = nn.Conv1d( @@ -301,12 +303,11 @@ class DepthWiseSeperableConv1d(nn.Module): self.pw_conv = nn.Identity() self.depthwise_seperable_out_channel = depthwise_seperable_out_channel - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: """ Args: - x: torch.Tensor - input tensor + x: input tensor """ x = self.dw_conv(x) if self.depthwise_seperable_out_channel != 0: @@ -375,23 +376,23 @@ class ConvModule(nn.Module): def __init__( self, - input_dim, - ext_pw_out_channel, - depthwise_seperable_out_channel, - ext_pw_kernel_size, - kernel_size, - depthwise_multiplier, - dropout_rate, - causal=False, - batch_norm=False, - chunk_se=0, - chunk_size=18, - activation="relu", - glu_type="sigmoid", - bias_in_glu=True, - linear_glu_in_convm=False, - export=False, 
- ): + input_dim: int, + ext_pw_out_channel: int, + depthwise_seperable_out_channel: int, + ext_pw_kernel_size: int, + kernel_size: int, + depthwise_multiplier: int, + dropout_rate: float, + causal: bool = False, + batch_norm: bool = False, + chunk_se: int = 0, + chunk_size: int = 18, + activation: str = "relu", + glu_type: str = "sigmoid", + bias_in_glu: bool = True, + linear_glu_in_convm: bool = False, + export: bool = False, + ) -> None: super().__init__() self.layer_norm = nn.LayerNorm(input_dim) self.input_dim = input_dim @@ -437,7 +438,7 @@ class ConvModule(nn.Module): self.ln2 = nn.Linear(input_dim * depthwise_multiplier, input_dim) - def _add_ext_pw_layer(self): + def _add_ext_pw_layer(self) -> None: """ This function is an extension of __init__ function and dedicated to the convolution module creation @@ -497,12 +498,11 @@ class ConvModule(nn.Module): self.pw_conv_simplify_w = torch.nn.Parameter(torch.ones(3)) self.pw_conv_simplify_b = torch.nn.Parameter(torch.zeros(3)) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: """ConvModule Forward. Args: - x: torch.Tensor - input tensor. + x: input tensor. """ x = self.layer_norm(x) @@ -567,21 +567,20 @@ class GLULinear(nn.Module): def __init__( self, - input_dim, - output_dim, - glu_type="sigmoid", - bias_in_glu=True, - ): + input_dim: int, + output_dim: int, + glu_type: str = "sigmoid", + bias_in_glu: bool = True, + ) -> None: super().__init__() self.linear = nn.Linear(input_dim, output_dim * 2, bias_in_glu) self.glu_act = GLU(-1, glu_type) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: """GLULinear forward Args: - x: torch.Tensor - inpute tensor. + x: input tensor. """ x = self.linear(x) return self.glu_act(x) @@ -609,12 +608,12 @@ class FeedForward(nn.Module): def __init__( self, - d_model, - d_inner, - dropout_rate, - activation="sigmoid", - bias_in_glu=True, - ): + d_model: int, + d_inner: int, + dropout_rate: float, + activation: str = "sigmoid", + bias_in_glu: bool = True, + ) -> None: super().__init__() self.d_model = d_model self.d_inner = d_inner @@ -628,12 +627,11 @@ class FeedForward(nn.Module): nn.Dropout(dropout_rate), ) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: """FeedForward forward function. Args: - x: torch.Tensor - input tensor. + x: input tensor. """ out = self.net(self.layer_norm(x)) @@ -642,14 +640,14 @@ class FeedForward(nn.Module): #### positional encoding starts here def _pre_hook( - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, -): + state_dict: dict, + prefix: str, + local_metadata: dict, + strict: bool, + missing_keys: list[str], + unexpected_keys: list[str], + error_msgs: list[str], +) -> None: """Perform pre-hook in load_state_dict for backward compatibility. 
Note: @@ -708,10 +706,10 @@ class T5RelativeAttentionLogitBias(nn.Module): """ def __init__(self, - num_heads, - num_buckets=-1, - max_distance=1000, - symmetric=False): + num_heads: int, + num_buckets: int = -1, + max_distance: int = 1000, + symmetric: bool = False) -> None: super().__init__() self.num_heads = num_heads self.num_buckets = num_buckets @@ -727,7 +725,7 @@ class T5RelativeAttentionLogitBias(nn.Module): self.num_buckets *= 2 self.bias_values = nn.Embedding(self.num_buckets, self.num_heads) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: # instantiate bias compatible with shape of x maxpos = x.size(1) context_position = torch.arange(maxpos, @@ -760,7 +758,7 @@ class T5RelativeAttentionLogitBias(nn.Module): return t5_rel_att_bias - def _bucket_relative_position(self, relative_position): + def _bucket_relative_position(self, relative_position: Tensor) -> Tensor: # this is a placeholder (isn't tested, likely buggy) using HuggingFace # implem as a reference this also needs to be extended to support # asymmetric +/- ve positions @@ -810,7 +808,10 @@ class AbsolutePositionalEncoding(nn.Module): """ - def __init__(self, d_model, dropout_rate, max_len=5000): + def __init__(self, + d_model: int, + dropout_rate: float, + max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super().__init__() self.d_model = d_model @@ -820,11 +821,11 @@ class AbsolutePositionalEncoding(nn.Module): self.extend_pe(torch.tensor(0.0).expand(1, max_len)) self._register_load_state_dict_pre_hook(_pre_hook) - def extend_pe(self, x): + def extend_pe(self, x: torch.Tensor) -> None: """Reset the positional encodings. Args: - x: torch.Tensor + x: input tensor """ if self.pe is not None and self.pe.size(1) >= x.size(1): if self.pe.dtype != x.dtype or self.pe.device != x.device: @@ -840,15 +841,14 @@ class AbsolutePositionalEncoding(nn.Module): pe = pe.unsqueeze(0) self.pe = pe.to(device=x.device, dtype=x.dtype) - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: """Add positional encoding. Args: - x: torch.Tensor - Input tensor. shape is (batch, time, ...) + x: Input tensor. shape is (batch, time, ...) Returns: - torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) + Encoded tensor. Its shape is (batch, time, ...) """ self.extend_pe(x) @@ -868,7 +868,7 @@ class MeanVarianceNormLayer(nn.Module): layer input size. """ - def __init__(self, input_size): + def __init__(self, input_size: int) -> None: super().__init__() self.input_size = input_size self.global_mean = nn.Parameter(torch.zeros(input_size)) @@ -878,8 +878,7 @@ class MeanVarianceNormLayer(nn.Module): """MeanVarianceNormLayer Forward Args: - input_: torch.Tensor - input tensor. + input_: input tensor. 
""" return (input_ - self.global_mean) * self.global_invstd @@ -949,7 +948,10 @@ class CausalConv1D(nn.Conv1d): dtype=dtype, ) - def update_cache(self, x, cache=None): + def update_cache( + self, + x: Tensor, + cache: Optional[Tensor] = None) -> tuple[Tensor, Optional[Tensor]]: if cache is None: new_x = F.pad(x, pad=(self._left_padding, self._right_padding)) next_cache = cache @@ -963,7 +965,11 @@ class CausalConv1D(nn.Conv1d): next_cache = next_cache[:, :, -cache.size(-1):] return new_x, next_cache - def forward(self, x, cache=None): + def forward( + self, + x: Tensor, + cache: Optional[Tensor] = None + ) -> Union[Tensor, tuple[Tensor, Optional[Tensor]]]: x, cache = self.update_cache(x, cache=cache) x = super().forward(x) if cache is None: @@ -1017,8 +1023,8 @@ class CausalConv2D(nn.Conv2d): def forward( self, - x, - ): + x: Tensor, + ) -> Tensor: x = F.pad( x, pad=(self._left_padding, self._right_padding, 0, 0), @@ -1062,16 +1068,16 @@ class NemoConvSubsampling(torch.nn.Module): """ def __init__( - self, - feat_in, - feat_out, - subsampling_factor=4, - subsampling="dw_striding", - conv_channels=256, - subsampling_conv_chunking_factor=1, - activation=nn.ReLU(), # noqa: B008 - is_causal=False, - ): + self, + feat_in: int, + feat_out: int, + subsampling_factor: int = 4, + subsampling: str = "dw_striding", + conv_channels: int = 256, + subsampling_conv_chunking_factor: int = 1, + activation: torch.nn.Module = nn.ReLU(), # noqa: B008 + is_causal: bool = False, + ) -> None: super().__init__() self._subsampling = subsampling self._conv_channels = conv_channels @@ -1328,28 +1334,25 @@ class NemoConvSubsampling(torch.nn.Module): self.conv = torch.nn.Sequential(*layers) - def get_sampling_frames(self): + def get_sampling_frames(self) -> list[int]: return [1, self.subsampling_factor] - def get_streaming_cache_size(self): + def get_streaming_cache_size(self) -> list[int]: return [0, self.subsampling_factor + 1] - def forward(self, x, mask): + def forward(self, x: Tensor, + mask: Optional[Tensor]) -> tuple[Tensor, Optional[Tensor]]: """ Forward method for NeMo subsampling. 
Args: - x[Batch, Time, Filters]: torch.Tensor - input tensor - x_mask: torch.Tensor - input mask + x: input tensor + mask: input mask Returns: - x: torch.Tensor - Resulting tensor from subsampling (B, T // + x: Resulting tensor from subsampling (B, T // time_reduction_factor, feat_out) - pad_mask: torch.Tensor - tensor of padded hidden state sequences (B, 1, T // + pad_mask: tensor of padded hidden state sequences (B, 1, T // time_reduction_factor) """ x = x.unsqueeze(1) if self.conv2d_subsampling else x.transpose(1, 2) @@ -1403,7 +1406,7 @@ class NemoConvSubsampling(torch.nn.Module): padding_length.size(0), -1) < padding_length.unsqueeze(1) return x, pad_mask.unsqueeze(1) - def reset_parameters(self): + def reset_parameters(self) -> None: # initialize weights if self._subsampling == "dw_striding": with torch.no_grad(): @@ -1433,7 +1436,7 @@ class NemoConvSubsampling(torch.nn.Module): torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale) torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale) - def conv_split_by_batch(self, x): + def conv_split_by_batch(self, x: Tensor) -> tuple[Tensor, bool]: """Tries to split input by batch, run conv and concat results""" b, _, _, _ = x.size() if b == 1: # can't split if batch size is 1 @@ -1460,7 +1463,7 @@ class NemoConvSubsampling(torch.nn.Module): True, ) - def conv_split_by_channel(self, x): + def conv_split_by_channel(self, x: Tensor) -> Tensor: """For dw convs, tries to split input by time, run conv and concat results""" x = self.conv[0](x) # full conv2D @@ -1500,7 +1503,8 @@ class NemoConvSubsampling(torch.nn.Module): x = self.conv[i * 3 + 4](x) # activation return x - def channel_chunked_conv(self, conv, chunk_size, x): + def channel_chunked_conv(self, conv: torch.nn.Module, chunk_size: int, + x: Tensor) -> Tensor: """Performs channel chunked convolution""" ind = 0 @@ -1541,7 +1545,7 @@ class NemoConvSubsampling(torch.nn.Module): return torch.cat(out_chunks, 1) def change_subsampling_conv_chunking_factor( - self, subsampling_conv_chunking_factor: int): + self, subsampling_conv_chunking_factor: int) -> None: if (subsampling_conv_chunking_factor != -1 and subsampling_conv_chunking_factor != 1 and subsampling_conv_chunking_factor % 2 != 0): @@ -1552,12 +1556,12 @@ class NemoConvSubsampling(torch.nn.Module): self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor -def calc_length(lengths, - all_paddings, - kernel_size, - stride, - ceil_mode, - repeat_num=1): +def calc_length(lengths: Tensor, + all_paddings: int, + kernel_size: int, + stride: int, + ceil_mode: bool, + repeat_num: int = 1) -> Tensor: """Calculates the output length of a Tensor passed through a convolution or max pooling layer""" add_pad: float = all_paddings - kernel_size @@ -1573,11 +1577,11 @@ def calc_length(lengths, class AttModule(nn.Module): """Attention abstraction module""" - def __init__(self): + def __init__(self) -> None: super().__init__() self.export_mode = False - def set_export(self, mode=True): + def set_export(self, mode: bool = True) -> None: """set the export mode""" self.export_mode = mode @@ -1591,14 +1595,10 @@ class AttModule(nn.Module): """AttModule forward Args: - x: torch.Tensor - input tensor. - memory: torch.Tensor, optional - memory tensor. - pos_emb: torch.Tensor, optional - positional encoder embedding. - att_mask: torch.Tensor, optional - attention mask tensor. + x: input tensor. + memory: memory tensor. + pos_emb: positional encoder embedding. + att_mask: attention mask tensor. 
""" return x, memory, pos_emb, att_mask @@ -1606,15 +1606,15 @@ class AttModule(nn.Module): class AttBlock(BlockBase, AttModule): """Attention Block module to support both Attention and Block module.""" - def memory_dims(self, max_len=False): + def memory_dims(self, max_len: bool = False) -> tuple[int, int]: """memory dimensions""" return (1, self.input_size) def masked_softmax( - scores, + scores: Tensor, mask: Optional[Tensor], -): +) -> Tensor: if mask is not None: mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2) scores = scores.masked_fill(mask, -torch.inf) @@ -1636,10 +1636,6 @@ class MultiHeadedAttention(nn.Module): input size features. dropout_rate: float dropout rate. - use_LN: bool - apply layer norm or not - dropout_at_output: bool - whether to apply dropout at output attention_inner_dim: int, optional the attention dimension used in the class, it can be different from the input dimension n_feat. @@ -1666,16 +1662,16 @@ class MultiHeadedAttention(nn.Module): def __init__( self, - n_head, - n_feat, - dropout_rate, - attention_inner_dim=-1, - glu_type="swish", - bias_in_glu=True, - use_pt_scaled_dot_product_attention=False, - n_value=-1, + n_head: int, + n_feat: int, + dropout_rate: float, + attention_inner_dim: int = -1, + glu_type: str = "swish", + bias_in_glu: bool = True, + use_pt_scaled_dot_product_attention: bool = False, + n_value: int = -1, group_size: int = 1, - ): + ) -> None: super().__init__() if n_value == -1: n_value = n_feat @@ -1718,28 +1714,22 @@ class MultiHeadedAttention(nn.Module): query: Tensor, key: Tensor, value: Tensor, - pos_k: Tensor, - pos_v: Tensor, + pos_k: Optional[Tensor], + pos_v: Optional[Tensor], mask: Optional[Tensor], relative_attention_bias: Optional[Tensor] = None, - ): + ) -> Tensor: """Compute 'Scaled Dot Product Attention'. Args: - query: torch.Tensor - query tensor (batch, time1, size) - key: torch.Tensor - key tensor (batch, time2, size) - value: torch.Tensor - value tensor (batch, time1, size) - pos_k: torch.Tensor - key tensor used for relative positional embedding. - pos_v: torch.Tensor - value tensor used for relative positional embedding. - mask: torch.Tensor - mask tensor (batch, time1, time2) - relative_attention_bias: torch.Tensor - bias added to attention logits w.r.t. relative positions + query: query tensor (batch, time1, size) + key: key tensor (batch, time2, size) + value: value tensor (batch, time1, size) + pos_k: key tensor used for relative positional embedding. + pos_v: value tensor used for relative positional embedding. + mask: mask tensor (batch, time1, time2) + relative_attention_bias: bias added to attention logits w.r.t. + relative positions (1, n_head, time1, time2) """ n_batch = query.size(0) @@ -1832,20 +1822,20 @@ class MultiSequential(torch.nn.Sequential): """Multi-input multi-output torch.nn.Sequential""" @torch.jit.ignore - def forward(self, *args): + def forward(self, *args) -> tuple: """Forward method implementation.""" for m in self: args = m(*args) return args -def get_offset(input_layer: str, time_reduction: int): +def get_offset(input_layer: str, time_reduction: int) -> int: """Get an offset. We will use the offset for determining #frames of a subsampled feature. 
Args: - input_layer (str): Type of an input layer - time_reduction (int): time reduction factor for downsampling a feature + input_layer: Type of an input layer + time_reduction: time reduction factor for downsampling a feature Returns: int: offset """ @@ -1858,13 +1848,14 @@ def get_offset(input_layer: str, time_reduction: int): return 0 -def unfold_tensor(xs_pad, max_seq_len): +def unfold_tensor(xs_pad: Tensor, max_seq_len: int) -> Tensor: """ For a given tensor with shape of (N, T, D), if sequence length T is longer than max_seq_len, this function unfold it to a (NT', max_seq_len, D) where T' is T // max_seq_len. Args: - xs_pad: N, T, D + xs_pad: input tensor with shape (N, T, D) + max_seq_len: maximum sequence length """ _, _, D = xs_pad.shape xs_pad = xs_pad.transpose(-1, -2) # convert to N, D, T diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 15ae081a9f5fc..01d16f1f2c387 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -645,6 +645,7 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): if not lora_config else lora_config.lora_vocab_padding_size), quant_config=None, bias=True, + prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index e7f5799a80067..142d3251bc67a 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -35,7 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - NestedTensors) + MultiModalUUIDDict, NestedTensors) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -316,14 +316,14 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - mm_hash_overrides: Optional[dict[str, list[str]]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) # NOTE: The tokens are already inserted by the chat template diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 7f70e44b10a6d..ef96d272adfb5 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -3,19 +3,24 @@ """Inference-only PLaMo2 model.""" from collections.abc import Iterable from itertools import islice -from typing import Optional +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend import torch from torch import nn -from transformers import PretrainedConfig, PreTrainedModel +from transformers import PretrainedConfig +from vllm import envs from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention from 
vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import get_forward_context +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -23,8 +28,11 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba2_metadata import ( - Mamba2Metadata, prepare_mamba2_metadata) + Mamba2Metadata, prepare_mamba2_metadata, update_metadata) +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( @@ -39,7 +47,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import ( composed_weight_loader, default_weight_loader, sharded_weight_loader) from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, - SupportsPP, SupportsV0Only) + SupportsPP) from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) from vllm.model_executor.models.utils import ( @@ -47,8 +55,10 @@ from vllm.model_executor.models.utils import ( make_layers, maybe_prefix) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.utils import LayerBlockType +from vllm.utils import LayerBlockType, direct_register_custom_op +from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata # Only used for type hinting. 
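A minimal illustrative sketch of the cache geometry behind the new MambaStateShapeCalculator import: it assumes the calculator reproduces the hand-rolled _get_mamba_cache_shape helper removed further down in this file, and the config values plugged in at the end are made up for the example. SimpleNamespace stands in for the real PLaMo2 hf_config.

from types import SimpleNamespace


def sketch_plamo2_mamba2_state_shapes(hf_config, tp_world_size: int):
    # Fused SSM channel width before tensor-parallel sharding.
    intermediate_size = (hf_config.mamba_num_heads *
                         hf_config.hidden_size_per_head)
    # Rolling causal-conv window: (d_conv - 1) past steps per local channel.
    conv_state_shape = (intermediate_size // tp_world_size,
                        hf_config.mamba_d_conv - 1)
    # Recurrent SSM state: a (head_dim, d_state) block per TP-local head.
    temporal_state_shape = (hf_config.mamba_num_heads // tp_world_size,
                            hf_config.hidden_size_per_head,
                            hf_config.mamba_d_state)
    return conv_state_shape, temporal_state_shape


cfg = SimpleNamespace(mamba_num_heads=64, hidden_size_per_head=128,
                      mamba_d_conv=4, mamba_d_state=64)
print(sketch_plamo2_mamba2_state_shapes(cfg, tp_world_size=2))
# ((4096, 3), (32, 128, 64))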
@@ -73,20 +83,6 @@ class Plamo2Config(PretrainedConfig): # type: ignore vocab_size: int -class Plamo2PreTrainedModel(PreTrainedModel): # type: ignore - - def _init_weights(self, module: torch.nn.Module) -> None: - std = 0.02 - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def is_mamba(config: Plamo2Config, i: int) -> bool: assert config.mamba_step > 1 @@ -99,7 +95,8 @@ def is_mamba(config: Plamo2Config, i: int) -> bool: # Adapted from: # vllm.model_executor.layers.mamba.mamba_mixer2.MambaMixer2 # transformers.models.mamba.modeling_mamba.MambaMixer -class Plamo2MambaMixer(nn.Module): +@CustomOp.register(name="plamo2_mamba_mixer") +class Plamo2MambaMixer(MambaBase, CustomOp): def __init__(self, vllm_config: VllmConfig, @@ -108,6 +105,8 @@ class Plamo2MambaMixer(nn.Module): **kwargs) -> None: super().__init__() self.config = vllm_config.model_config.hf_config + self.cache_config = vllm_config.cache_config + self.model_config = vllm_config.model_config self.quant_config = vllm_config.quant_config self.hidden_size = self.config.hidden_size self.ssm_state_size = self.config.mamba_d_state @@ -115,8 +114,6 @@ class Plamo2MambaMixer(nn.Module): self.intermediate_size = (self.config.mamba_num_heads * self.config.hidden_size_per_head) self.tp_size = get_tensor_model_parallel_world_size() - self.intermediate_size_per_tp_worker = \ - self.intermediate_size // self.tp_size self.head_dim = self.config.hidden_size_per_head self.num_heads = self.config.mamba_num_heads self.time_step_rank = max(64, self.hidden_size // 16) @@ -197,6 +194,22 @@ class Plamo2MambaMixer(nn.Module): self.C_norm = RMSNorm(self.ssm_state_size, eps=self.config.rms_norm_eps) + self.chunk_size = self.config.mamba_chunk_size + + if envs.VLLM_USE_V1: + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + # The outer list is for v0 PP virtual engine. Though this code path + # only runs for v1, we have to do this to unify with the interface + # of Attention + v0 PP. 
+ # The inner tuple is (conv_state, ssm_state) + self.kv_cache = [(torch.tensor([]), torch.tensor([]))] + assert self.chunk_size != -1, "chunk_size must be set for v1" + + self.prefix = prefix + def _project_ssm_parameters(self, hidden_states): ssm_parameters = self.bcdt_proj(hidden_states) B, C, time_step = torch.split( @@ -212,25 +225,73 @@ class Plamo2MambaMixer(nn.Module): dt = self.dt_proj(time_step) return B, C, dt - def forward( + def forward_native( self, hidden_states: torch.Tensor, + output: torch.Tensor, mamba_cache_params: MambaCacheParams, mamba2_metadata: Mamba2Metadata, **kwargs, - ) -> torch.Tensor: + ): + pass + def forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + mamba_cache_params: MambaCacheParams, + mamba2_metadata: Mamba2Metadata, + **kwargs, + ): + if not envs.VLLM_USE_V1: + CustomOp.forward(self, hidden_states, output, mamba_cache_params, + mamba2_metadata) + else: + torch.ops.vllm.plamo2_mamba_mixer( + hidden_states, + output, + self.prefix, + ) + + def forward_cuda( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + mamba_cache_params: MambaCacheParams, + mamba2_metadata: Mamba2Metadata, + **kwargs, + ): + + forward_context = get_forward_context() # mamba2_metadata contains metadata necessary for the mamba2 triton # kernels to operate in continuous batching and in chunked prefill # modes; they are computed at top-level model forward since they # stay the same and reused for all mamba layers in the same iteration - attn_metadata: AttentionMetadata = get_forward_context().attn_metadata + attn_metadata: AttentionMetadata = forward_context.attn_metadata + if envs.VLLM_USE_V1: + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + mamba2_metadata = attn_metadata + assert isinstance(attn_metadata, Mamba2AttentionMetadata) + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + # conv_state = (..., dim, width-1) yet contiguous along 'dim' + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + state_indices_tensor = attn_metadata.state_indices_tensor + else: + conv_state = mamba_cache_params.conv_state + ssm_state = mamba_cache_params.ssm_state + state_indices_tensor = mamba_cache_params.state_indices_tensor - num_prefills = attn_metadata.num_prefills # request count - num_decodes = attn_metadata.num_decode_tokens # token count (=request) - num_prefill_tokens = attn_metadata.num_prefill_tokens # token count - has_prefill = num_prefills > 0 - has_decode = num_decodes > 0 + # Common members between V1 metadata and V0 metadata + if mamba2_metadata is not None: + has_initial_states_p = mamba2_metadata.has_initial_states_p + prep_initial_states = mamba2_metadata.prep_initial_states + chunk_size = mamba2_metadata.chunk_size + seq_idx_p = mamba2_metadata.seq_idx_p + chunk_indices_p = mamba2_metadata.chunk_indices_p + chunk_offsets_p = mamba2_metadata.chunk_offsets_p # 1. 
Gated MLP's linear projection projected_states = self.in_proj(hidden_states) @@ -240,23 +301,59 @@ class Plamo2MambaMixer(nn.Module): conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) + if envs.VLLM_USE_V1 and attn_metadata is None: + # V1 profile run + hidden_states = (hidden_states.transpose(0, 1).clone().transpose( + 0, 1)).contiguous() + output[:] = self.out_proj(hidden_states) + return + + num_prefills = attn_metadata.num_prefills # request count + num_decodes = attn_metadata.num_decode_tokens # token count (=request) + num_prefill_tokens = attn_metadata.num_prefill_tokens # token count + has_prefill = num_prefills > 0 + has_decode = num_decodes > 0 + num_actual_tokens = num_prefill_tokens + num_decodes + + # NOTE: V0 put prefill before decode, v1 puts decode before prefill # Separate prefill and decode by splitting varlen input # Split along token dimension - hidden_states_p, hidden_states_d = torch.split( - hidden_states, - [num_prefill_tokens, num_decodes], - dim=0, - ) - gate_p, gate_d = torch.split(gate, [num_prefill_tokens, num_decodes], - dim=0) - # Split along batch dimension - state_indices_tensor_p, state_indices_tensor_d = torch.split( - mamba_cache_params.state_indices_tensor, - [num_prefills, num_decodes], - dim=0, - ) - query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1] - if has_prefill else None) + if envs.VLLM_USE_V1: + hidden_states_d, hidden_states_p = torch.split( + hidden_states[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0, + ) + gate_d, gate_p = torch.split(gate[:num_actual_tokens], + [num_decodes, num_prefill_tokens], + dim=0) + # Split along batch dimension + state_indices_tensor_d, state_indices_tensor_p = torch.split( + state_indices_tensor, + [num_decodes, num_prefills], + dim=0, + ) + query_start_loc_p = ( + attn_metadata.query_start_loc[-num_prefills - 1:] - + num_decodes if has_prefill else None) + else: + hidden_states_p, hidden_states_d = torch.split( + hidden_states, + [num_prefill_tokens, num_decodes], + dim=0, + ) + gate_p, gate_d = torch.split(gate, + [num_prefill_tokens, num_decodes], + dim=0) + # Split along batch dimension + state_indices_tensor_p, state_indices_tensor_d = torch.split( + state_indices_tensor, + [num_prefills, num_decodes], + dim=0, + ) + query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + + 1] + if has_prefill else None) # Preallocate output tensor to avoid memcpy cost for merging prefill # and decode outputs @@ -268,25 +365,38 @@ class Plamo2MambaMixer(nn.Module): dtype=hidden_states.dtype, device=hidden_states.device, ) - preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split( - preallocated_ssm_out, - [num_prefill_tokens, num_decodes], - dim=0, - ) + if envs.VLLM_USE_V1: + preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split( + preallocated_ssm_out, + [num_decodes, num_prefill_tokens], + dim=0, + ) + else: + preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split( + preallocated_ssm_out, + [num_prefill_tokens, num_decodes], + dim=0, + ) # Process prefill requests if has_prefill: # 2. 
Convolution sequence transformation # - "cache_indices" updates the conv_state cache in positions - # pointed to by "mamba_cache_params.state_indices_tensor" + # pointed to by "state_indices_tensor" + x = hidden_states_p.transpose( + 0, 1) # this is the form that causal-conv see + if mamba2_metadata.cu_seqlen is None: + mamba2_metadata = update_metadata(x, query_start_loc_p, + mamba2_metadata) hidden_states_p = causal_conv1d_fn( - hidden_states_p.transpose(0, 1), + x, conv_weights, self.conv1d.bias, activation=self.activation, - conv_states=mamba_cache_params.conv_state, - has_initial_state=mamba2_metadata.has_initial_states, + conv_states=conv_state, + has_initial_state=has_initial_states_p, cache_indices=state_indices_tensor_p, + metadata=mamba2_metadata, query_start_loc=query_start_loc_p) hidden_states_p = hidden_states_p.transpose(0, 1) hidden_states_p = hidden_states_p[:num_prefill_tokens] @@ -299,12 +409,12 @@ class Plamo2MambaMixer(nn.Module): # 3. State Space Model sequence transformation initial_states = None - if (mamba2_metadata.has_initial_states is not None - and mamba2_metadata.prep_initial_states): + if has_initial_states_p is not None and prep_initial_states: # making a copy of the states initial_states = torch.where( - mamba2_metadata.has_initial_states[:, None, None, None], - mamba_cache_params.ssm_state[state_indices_tensor_p], 0) + has_initial_states_p[:, None, None, None], + ssm_state[state_indices_tensor_p], 0) + varlen_state = mamba_chunk_scan_combined( hidden_states_p.view(1, num_prefill_tokens, self.num_heads // self.tp_size, @@ -313,15 +423,15 @@ class Plamo2MambaMixer(nn.Module): self.A, B.view(1, num_prefill_tokens, 1, -1), C.view(1, num_prefill_tokens, 1, -1), - chunk_size=mamba2_metadata.chunk_size, + chunk_size=chunk_size, D=self.D, z=gate_p.view(1, num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim), dt_bias=self.dt_bias, - seq_idx=mamba2_metadata.seq_idx, - chunk_indices=mamba2_metadata.chunk_indices, - chunk_offsets=mamba2_metadata.chunk_offsets, - cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1], + seq_idx=seq_idx_p, + chunk_indices=chunk_indices_p, + chunk_offsets=chunk_offsets_p, + cu_seqlens=query_start_loc_p, initial_states=initial_states, return_varlen_states=True, return_final_states=False, @@ -329,18 +439,19 @@ class Plamo2MambaMixer(nn.Module): dt_limit=(0.0, float("inf")), out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1, self.head_dim), + state_dtype=ssm_state.dtype, ) # update ssm states # - varlen state is a (batch, nheads, headdim, dstate) tensor - mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state + ssm_state[state_indices_tensor_p] = varlen_state # Process decode requests if has_decode: # 2. Convolution sequence transformation hidden_states_d = causal_conv1d_update( hidden_states_d, - mamba_cache_params.conv_state, + conv_state, conv_weights, self.conv1d.bias, self.activation, @@ -363,8 +474,10 @@ class Plamo2MambaMixer(nn.Module): # - the hidden is reshaped into (bs, num_heads, head_dim) # - mamba_cache_params.ssm_state's slots will be selected # using state_indices_tensor_d + + # NOTE: final output is an in-place update of out tensor selective_state_update( - mamba_cache_params.ssm_state, + ssm_state, hidden_states_d, dt, A, @@ -378,11 +491,68 @@ class Plamo2MambaMixer(nn.Module): out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim), ) - assert self.num_heads % self.tp_size == 0 # 4. 
Final linear projection - out = self.out_proj(preallocated_ssm_out) - return out + output[:num_actual_tokens] = self.out_proj(preallocated_ssm_out) + + def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: + assert self.model_config is not None + assert self.cache_config is not None + return MambaStateDtypeCalculator.mamba2_state_dtype( + self.model_config.dtype, + self.cache_config.mamba_cache_dtype, + self.cache_config.mamba_ssm_cache_dtype, + ) + + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: + return MambaStateShapeCalculator.mamba2_state_shape( + intermediate_size=self.intermediate_size, + tp_world_size=get_tensor_model_parallel_world_size(), + n_groups=0, + num_heads=self.num_heads, + head_dim=self.head_dim, + state_size=self.ssm_state_size, + conv_kernel=self.conv_kernel_size, + ) + + @property + def mamba_type(self) -> str: + return "mamba2" + + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.mamba2_attn import ( + Mamba2AttentionBackend) + return Mamba2AttentionBackend + + +def plamo2_mamba_mixer( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self.forward_cuda(hidden_states=hidden_states, + output=output, + mamba_cache_params=None, + mamba2_metadata=None) + + +def plamo2_mamba_mixer_fake( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="plamo2_mamba_mixer", + op_func=plamo2_mamba_mixer, + mutates_args=["output"], + fake_impl=plamo2_mamba_mixer_fake, + dispatch_key=current_platform.dispatch_key, +) class DenseMLP(nn.Module): @@ -418,7 +588,6 @@ class DenseMLP(nn.Module): return self.down_proj(h) -@support_torch_compile class Plamo2AttentionMixer(nn.Module): def __init__(self, @@ -575,12 +744,24 @@ class Plamo2DecoderLayer(nn.Module): hidden_states, residual = self.pre_mixer_norm( hidden_states, residual) + if self.is_mamba: + # Plamo2MambaMixer writes output to this tensor + output = torch.empty_like(hidden_states) + mixer_kwargs = { + "output": output, + "mamba_cache_params": mamba_cache_params, + "mamba2_metadata": mamba2_metadata, + } + else: + mixer_kwargs = { + "positions": positions, + } hidden_states = self.mixer( - positions=positions, hidden_states=hidden_states, - mamba_cache_params=mamba_cache_params, - mamba2_metadata=mamba2_metadata, + **mixer_kwargs, ) + if self.is_mamba: + hidden_states = output hidden_states = self.post_mixer_norm(hidden_states) # Fully Connected hidden_states, residual = self.pre_mlp_norm(hidden_states, residual) @@ -591,7 +772,7 @@ class Plamo2DecoderLayer(nn.Module): class Plamo2Decoder(torch.nn.Module): - def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)} @@ -617,7 +798,7 @@ class Plamo2Decoder(torch.nn.Module): mamba_cache_index = 0 for layer in islice(self.layers, self.start_layer, self.end_layer): layer_mamba_cache_params = None - if layer.is_mamba: + if layer.is_mamba and mamba_cache_params is not None: layer_mamba_cache_params = mamba_cache_params.at_layer_idx( mamba_cache_index) mamba_cache_index += 1 @@ -632,10 +813,11 @@ class Plamo2Decoder(torch.nn.Module): return hidden_states, residual -class 
Plamo2Model(Plamo2PreTrainedModel): +@support_torch_compile +class Plamo2Model(torch.nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config.model_config.hf_config) + super().__init__() config = vllm_config.model_config.hf_config @@ -653,9 +835,9 @@ class Plamo2Model(Plamo2PreTrainedModel): self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) - self.layers = Plamo2Decoder(vllm_config, prefix=f"{prefix}.layers") + self.layers = Plamo2Decoder(vllm_config=vllm_config, + prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_init() def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -679,11 +861,16 @@ class Plamo2Model(Plamo2PreTrainedModel): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - attn_metadata: AttentionMetadata = get_forward_context().attn_metadata - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.mamba_chunk_size, - attn_metadata=attn_metadata, - ) + if not envs.VLLM_USE_V1: + attn_metadata: AttentionMetadata = get_forward_context( + ).attn_metadata + mamba2_metadata = prepare_mamba2_metadata( + chunk_size=self.config.mamba_chunk_size, + attn_metadata=attn_metadata, + ) + else: + # v1 get mamba2_metadata from forward_context + mamba2_metadata = None hidden_states, residual = self.layers( positions=positions, @@ -701,8 +888,7 @@ class Plamo2Model(Plamo2PreTrainedModel): return hidden_states -class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, - IsHybrid, SupportsV0Only): +class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -712,12 +898,10 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, } def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() config = vllm_config.model_config.hf_config scheduler_config = vllm_config.scheduler_config - assert not vllm_config.cache_config.enable_prefix_caching, \ - "PLaMo2 currently does not support prefix caching" - super().__init__(config) self.config = config self.vllm_config = vllm_config self.model_config = vllm_config.model_config @@ -751,8 +935,6 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - # Initialize weights and apply final processing - self.post_init() def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -763,19 +945,27 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - if self.mamba_cache is None: - num_mamba_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, LayerBlockType.mamba) + if not envs.VLLM_USE_V1: + if self.mamba_cache is None: + num_mamba_layers = ( + self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, + LayerBlockType.mamba)) - self.mamba_cache = MambaCacheManager( - self.vllm_config, - num_mamba_layers, - *self._get_mamba_cache_shape(), - self.lm_head.weight.dtype, - self.lm_head.weight.dtype, - ) + mamba_state_shape = 
self.get_mamba_state_shape_from_config( + self.vllm_config, use_v1=False) + mamba_state_dtype = \ + self.get_mamba_state_dtype_from_config( + self.vllm_config) + self.mamba_cache = MambaCacheManager(self.vllm_config, + num_mamba_layers, + *mamba_state_shape, + *mamba_state_dtype) - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + else: + # NOTE: mamba_cache_params is not needed for v1 + mamba_cache_params = None hidden_states = self.model(input_ids, positions, mamba_cache_params, intermediate_tensors, inputs_embeds) @@ -788,21 +978,48 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, SupportsPP, def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def _get_mamba_cache_shape( - self) -> tuple[tuple[int, int], tuple[int, int, int]]: - world_size = get_tensor_model_parallel_world_size() - hidden_size = (self.config.mamba_num_heads * - self.config.hidden_size_per_head) - conv_state_shape = ( - hidden_size // world_size, - self.config.mamba_d_conv - 1, + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba2_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, ) - temporal_state_shape = ( - divide(self.config.mamba_num_heads, world_size), - self.config.hidden_size_per_head, - self.config.mamba_d_state, + + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + use_v1: bool = True, + ) -> tuple[tuple[int, int], tuple[int, int, int]]: + """Calculate shapes for Mamba's convolutional and state caches. + Args: + vllm_config: vLLM config + use_v1: Get shapes for V1 (or V0) + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + - temporal_state_shape: Shape for state space model cache + """ + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + intermediate_size =\ + hf_config.mamba_num_heads * hf_config.hidden_size_per_head + + return MambaStateShapeCalculator.mamba2_state_shape( + intermediate_size=intermediate_size, + tp_world_size=parallel_config.tensor_parallel_size, + n_groups=0, + num_heads=hf_config.mamba_num_heads, + head_dim=hf_config.hidden_size_per_head, + state_size=hf_config.mamba_d_state, + conv_kernel=hf_config.mamba_d_conv, + use_v1=use_v1, ) - return conv_state_shape, temporal_state_shape def compute_logits( self, diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py deleted file mode 100644 index 2edc357d2df1b..0000000000000 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ /dev/null @@ -1,313 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Copyright 2025 The vLLM team. -# Copyright 2025 IBM. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Inference-only IBM/NASA Prithvi Geospatial model.""" - -from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Optional, Union - -import torch -import torch.nn as nn -from transformers import BatchFeature - -from vllm.config import VllmConfig -from vllm.model_executor.layers.pooler import DispatchPooler, Pooler -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.utils import AutoWeightsLoader -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (ImageItem, ModalityData, - MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargsItems, - PlaceholderRange) -from vllm.multimodal.parse import (DictEmbeddingItems, ModalityDataItems, - MultiModalDataItems, MultiModalDataParser) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptUpdate) -from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.sequence import IntermediateTensors - -from .interfaces import (IsAttentionFree, MultiModalEmbeddings, - SupportsMultiModal) -from .interfaces_base import default_pooling_type - - -def _prithvi_field_config(hf_inputs: Mapping[str, torch.Tensor]): - # This model receives in input a multi-dimensional tensor representing - # a single image patch and therefore it is not to be split - # into multiple elements, but rather to be considered a single one. - # Hence, the decision of using a MultiModalSharedField. - # The expected shape is (num_channels, width, height). - - # This model however allows the user to also submit multiple image - # patches as a batch, adding a further dimension to the above shape. - # At this stage we only support submitting one patch per request and - # batching is achieved via vLLM batching. - # TODO (christian-pinto): enable support for multi patch requests - # in tandem with vLLM batching. - return dict( - pixel_values=MultiModalFieldConfig.shared(batch_size=1, - modality="image"), - location_coords=MultiModalFieldConfig.shared(batch_size=1, - modality="image"), - ) - - -class PrithviGeoSpatialMAEMultiModalDataParser(MultiModalDataParser): - - def _parse_image_data( - self, - data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], - ) -> Optional[ModalityDataItems[Any, Any]]: - if isinstance(data, dict): - return DictEmbeddingItems( - data, - modality="image", - required_fields={"pixel_values", "location_coords"}, - fields_factory=_prithvi_field_config, - ) - - return super()._parse_image_data(data) - - -class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo): - - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": None} - - -class PrithviGeoSpatialMAEInputBuilder( - BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]): - - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: - return "" - - def get_dummy_mm_data( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> MultiModalDataDict: - # This model input is fixed and is in the form of a torch Tensor. 
- # The size of pixel_values might change in the cases where we resize - # the input but never exceeds the dimensions below. - image_data = { - "pixel_values": torch.full((6, 512, 512), 1.0, - dtype=torch.float16), - "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16), - } - - return {"image": image_data} - - -class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): - - def _get_data_parser(self) -> MultiModalDataParser: - return PrithviGeoSpatialMAEMultiModalDataParser() - - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - return _prithvi_field_config(hf_inputs) - - def _get_prompt_updates( - self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargsItems, - ) -> Sequence[PromptUpdate]: - return [] - - def apply( - self, - prompt: Union[str, list[int]], - mm_data: MultiModalDataDict, - hf_processor_mm_kwargs: Mapping[str, object], - tokenization_kwargs: Optional[Mapping[str, object]] = None, - mm_hash_overrides: Optional[dict[str, list[str]]] = None, - ) -> MultiModalInputs: - if "image" in mm_data: - image_data = mm_data["image"] - else: - image_data = mm_data - mm_data = {"image": mm_data} - - mm_items = self._to_mm_items(mm_data) - tokenization_kwargs = tokenization_kwargs or {} - mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else - self._hash_mm_items(mm_items, hf_processor_mm_kwargs, - tokenization_kwargs)) - mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]} - - mm_processed_data = BatchFeature(image_data) - - mm_kwargs = MultiModalKwargsItems.from_hf_inputs( - mm_processed_data, - self._get_mm_fields_config(mm_processed_data, - hf_processor_mm_kwargs), - ) - - return MultiModalInputs( - type="multimodal", - prompt=prompt, - prompt_token_ids=[1], - mm_kwargs=mm_kwargs, - mm_hashes=mm_hashes, - mm_placeholders=mm_placeholders, - ) - - -@default_pooling_type("All") -@MULTIMODAL_REGISTRY.register_processor( - PrithviGeoSpatialMAEMultiModalProcessor, - info=PrithviGeoSpatialMAEProcessingInfo, - dummy_inputs=PrithviGeoSpatialMAEInputBuilder, -) -class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal): - """Prithvi Masked Autoencoder""" - - supports_multimodal_raw_input_only = True - is_pooling_model = True - - @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: - if modality.startswith("image"): - return None - - raise ValueError("Only image modality is supported") - - def _instantiate_model(self, config: dict) -> Optional[nn.Module]: - # We might be able/need to support different tasks with this same model - if config["task_args"]["task"] == "SemanticSegmentationTask": - from terratorch.cli_tools import SemanticSegmentationTask - - task = SemanticSegmentationTask( - config["model_args"], - config["task_args"]["model_factory"], - loss=config["task_args"]["loss"], - lr=config["task_args"]["lr"], - ignore_index=config["task_args"]["ignore_index"], - optimizer=config["task_args"]["optimizer"], - optimizer_hparams=config["optimizer_params"], - scheduler=config["task_args"]["scheduler"], - scheduler_hparams=config["scheduler_params"], - plot_on_val=config["task_args"]["plot_on_val"], - freeze_decoder=config["task_args"]["freeze_decoder"], - freeze_backbone=config["task_args"]["freeze_backbone"], - ) - - return task.model - else: - return None - - def __init__(self, vllm_config: VllmConfig, prefix: str = ""): - 
super().__init__() - - # the actual model is dynamically instantiated using terratorch - # allowing us to perform changes to the model architecture - # at startup time (e.g., change the model decoder class.) - self.model = self._instantiate_model( - vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"]) - if self.model is None: - raise ValueError( - "Unsupported task. " - "Only SemanticSegmentationTask is supported for now " - "by PrithviGeospatialMAE.") - - pooler_config = vllm_config.model_config.pooler_config - assert pooler_config is not None - - self.pooler = DispatchPooler( - {"encode": Pooler.for_encode(pooler_config)}, ) - - def _parse_and_validate_multimodal_data( - self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - pixel_values = kwargs.pop("pixel_values", None) - if not isinstance(pixel_values, torch.Tensor): - raise ValueError(f"Incorrect type of pixel_values. " - f"Got type: {type(pixel_values)}") - - location_coords = kwargs.pop("location_coords", None) - if not isinstance(location_coords, torch.Tensor): - raise ValueError(f"Incorrect type of location_coords. " - f"Got type: {type(location_coords)}") - location_coords = torch.unbind(location_coords, dim=0)[0] - if location_coords.shape == torch.Size([0]): - location_coords = None - - return pixel_values, location_coords - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - # We do not really use any input tokens and therefore no embeddings - # to be calculated. However, due to the mandatory token ids in - # the input prompt we pass one token and the size of the dummy - # embedding tensors must reflect that. - return torch.empty((input_ids.shape[0], 0)) - - def forward( - self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: object, - ): - pixel_values, location_coords = ( - self._parse_and_validate_multimodal_data(**kwargs)) - model_output = self.model(pixel_values, - location_coords=location_coords) - - return model_output.output - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - params_list = [] - model_buffers = dict(self.named_buffers()) - loaded_buffers = [] - for key, value in weights: - if key == "state_dict": - weights_to_parse = value - for name, weight in weights_to_parse.items(): - if "pos_embed" in name: - continue - - if "_timm_module." in name: - name = name.replace("_timm_module.", "") - - # this model requires a couple of buffers to be loaded - # that are not loadable with the AutoWeightsLoader - if name in model_buffers: - if "_timm_module." 
in name: - name = name.replace("_timm_module.", "") - buffer = model_buffers[name] - weight_loader = getattr(buffer, "weight_loader", - default_weight_loader) - weight_loader(buffer, weight) - loaded_buffers.append(name) - else: - params_list.append((name, weight)) - break - - # Load the remaining model parameters - loader = AutoWeightsLoader(self) - autoloaded_weights = loader.load_weights(params_list) - - return autoloaded_weights.union(set(loaded_buffers)) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index e32dc51f00c09..7470948499005 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -271,7 +271,8 @@ class QWenBaseModel(nn.Module): prefix, "transformer")) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head")) if self.config.tie_word_embeddings: self.lm_head.weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 54dc0bebd9c5e..e13e87b93429d 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -285,7 +285,7 @@ class Qwen2Model(nn.Module): decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer): super().__init__() - config = vllm_config.model_config.hf_config + config = vllm_config.model_config.hf_config.get_text_config() cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index 5c64c81547e65..a7e71309b6074 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -25,7 +25,7 @@ from collections.abc import Iterable, Mapping, Sequence from copy import copy from functools import partial -from typing import Any, Callable, Optional, Union +from typing import Annotated, Any, Callable, Literal, Optional, Union import torch import torch.nn as nn @@ -41,14 +41,14 @@ from transformers.models.whisper import WhisperFeatureExtractor from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2_5_vl import ( Qwen2_5_VisionTransformer, Qwen2_5_VLImageEmbeddingInputs, Qwen2_5_VLImageInputs, Qwen2_5_VLImagePixelInputs, Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs, Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs) from vllm.model_executor.models.qwen2_audio import ( - Qwen2AudioFeatureInputs, Qwen2AudioProcessingInfo, - _get_feat_extract_output_lengths) + Qwen2AudioProcessingInfo, _get_feat_extract_output_lengths) from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -65,8 +65,10 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens +from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .interfaces import 
(MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) from .utils import (AutoWeightsLoader, WeightsMapper, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) @@ -79,6 +81,26 @@ except (ImportError, ModuleNotFoundError): logger = init_logger(__name__) +class Qwen2_5OmniAudioFeatureInputs(TensorSchema): + """ + Dimensions: + - na: Number of audios + - nmb: Number of mel bins + - msl: Maximum sequence length + - tsl: Total sequence length + """ + type: Literal["audio_features"] + input_features: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("nmb", "tsl"), + ] + + feature_attention_mask: Annotated[ + torch.Tensor, + TensorShape("na", "msl"), + ] + + def create_qwen2_5_omni_thinker_field_factory( spatial_merge_size: int ) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str, @@ -529,12 +551,14 @@ class Qwen2_5OmniConditionalGenerationMixin: raise ValueError(f"Incorrect type of {name}. " f"Got type: {type(mm_input)}") if isinstance(mm_input, torch.Tensor): + if dim == 0: + return mm_input.reshape(-1, *mm_input.shape[2:]) return torch.concat(list(mm_input), dim=dim) else: return torch.concat(mm_input, dim=dim) def _parse_and_validate_audio_input( - self, **kwargs: object) -> Optional[Qwen2AudioFeatureInputs]: + self, **kwargs: object) -> Optional[Qwen2_5OmniAudioFeatureInputs]: input_audio_features = kwargs.pop('input_audio_features', None) audio_feature_lengths = kwargs.pop('audio_feature_lengths', None) feature_attention_mask = kwargs.pop('feature_attention_mask', None) @@ -548,7 +572,8 @@ class Qwen2_5OmniConditionalGenerationMixin: if not isinstance(input_audio_features, (torch.Tensor, list)): raise ValueError("Incorrect type of audio input features. " f"Got type: {type(input_audio_features)}") - return Qwen2AudioFeatureInputs( + return Qwen2_5OmniAudioFeatureInputs( + type="audio_features", input_features=input_audio_features, audio_feature_lengths=audio_feature_lengths, feature_attention_mask=feature_attention_mask) @@ -631,7 +656,7 @@ class Qwen2_5OmniConditionalGenerationMixin: def _process_audio_input( self, - audio_input: Qwen2AudioFeatureInputs, + audio_input: Qwen2_5OmniAudioFeatureInputs, audio_hashes: list[str] = None, cached_audio_features: torch.Tensor = None, ) -> torch.Tensor: @@ -658,8 +683,8 @@ class Qwen2_5OmniConditionalGenerationMixin: feature_lens=audio_feature_lengths, aftercnn_lens=audio_feat_lengths, ) - audio_features = audio_outputs.last_hidden_state - return audio_features.split(audio_output_lengths.tolist()) + return audio_outputs.last_hidden_state.split( + audio_output_lengths.tolist()) def _process_image_input( self, @@ -705,7 +730,7 @@ class Qwen2_5OmniConditionalGenerationMixin: dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder, ) class Qwen2_5OmniThinkerForConditionalGeneration( - nn.Module, SupportsMultiModal, SupportsPP, + nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, Qwen2_5OmniConditionalGenerationMixin): hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ @@ -713,6 +738,22 @@ class Qwen2_5OmniThinkerForConditionalGeneration( "thinker.model.": "language_model.model.", "thinker.": "", }) + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "attn.qkv": [ + "attn.q", + "attn.k", + "attn.v", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -807,7 +848,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration( return [] # The result multimodal_embeddings is 
tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). + # tensor corresponding to a multimodal data item (image or video). multimodal_embeddings: tuple[torch.Tensor, ...] = () # NOTE: It is important to iterate over the keys in this dictionary @@ -834,7 +875,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration( if multimodal_embeddings is not None \ and len(multimodal_embeddings) != 0: - # TODO (ywang96): support overlapping modalitiy embeddings so that + # TODO (ywang96): support overlapping modality embeddings so that # `use_audio_in_video` will work on V1. inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, multimodal_embeddings, [ @@ -935,3 +976,12 @@ class Qwen2_5OmniThinkerForConditionalGeneration( mapper=self.hf_to_vllm_mapper) return loaded_weights + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="merger.", + tower_model=["visual.", "audio_tower."]) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index c8f7fc16b4e83..dbf486374bcf3 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -27,7 +27,7 @@ """Inference-only Qwen2.5-VL model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping from functools import lru_cache, partial -from typing import Callable, Literal, Optional, TypedDict, Union +from typing import Annotated, Callable, Literal, Optional, Union import torch import torch.nn as nn @@ -38,6 +38,7 @@ from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) +from vllm.attention.layer import check_upstream_fa_availability from vllm.config import VllmConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils @@ -48,9 +49,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm # yapf: disable from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, - MergedReplicatedLinear, QKVParallelLinear, - ReplicatedLinear, RowParallelLinear) # yapf: enable from vllm.model_executor.layers.quantization import QuantizationConfig @@ -65,6 +64,8 @@ from vllm.multimodal.utils import run_dp_sharded_mrope_vision_model from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope +from vllm.utils import is_pin_memory_available +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsQuant) @@ -81,84 +82,125 @@ logger = init_logger(__name__) # === Vision Inputs === # -class Qwen2_5_VLImagePixelInputs(TypedDict): +class Qwen2_5_VLImagePixelInputs(TensorSchema): + """ + Dimensions: + - np: Number of patches + - ni: Number of images + - cps: Number of channels * patch_size * patch_size + + Historical context: + - pixel_values shape: (num_patches, num_channels * patch_size * + patch_size) + - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) + format + """ type: Literal["pixel_values"] - pixel_values: torch.Tensor - """Shape: - `(num_patches, num_channels * patch_size * patch_size)` + + pixel_values: Annotated[ + torch.Tensor, +
TensorShape("np", "cps"), + ] + + image_grid_thw: Annotated[ + torch.Tensor, + TensorShape("ni", 3), + ] + + +class Qwen2_5_VLImageEmbeddingInputs(TensorSchema): """ - - image_grid_thw: torch.Tensor - """Shape: `(num_images, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. + Dimensions: + - nf: Number of image features + - hs: Hidden size + - ni: Number of images + + Historical context: + - image_embeds shape: (num_image_features, hidden_size) + - num_image_features varies based on the number and resolution of the + images. + - hidden_size must match the hidden size of language model backbone. + - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) + format """ - - -class Qwen2_5_VLImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] - image_embeds: torch.Tensor - """Supported types: - - list[`torch.Tensor`]: A list of tensors holding all images' features. - Each tensor holds an image's features. - - `torch.Tensor`: A tensor holding all images' features - (concatenation of all images' feature tensors). - Tensor shape: `(num_image_features, hidden_size)` - - `num_image_features` varies based on - the number and resolution of the images. - - `hidden_size` must match the hidden size of language model backbone. - """ + image_embeds: Annotated[ + torch.Tensor, + TensorShape("nf", "hs"), + ] - image_grid_thw: torch.Tensor - """Shape: `(num_images, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. - """ + image_grid_thw: Annotated[ + torch.Tensor, + TensorShape("ni", 3), + ] Qwen2_5_VLImageInputs = Union[Qwen2_5_VLImagePixelInputs, Qwen2_5_VLImageEmbeddingInputs] -class Qwen2_5_VLVideoPixelInputs(TypedDict): +class Qwen2_5_VLVideoPixelInputs(TensorSchema): + """ + Dimensions: + - np: Number of patches + - nv: Number of videos + - ctps: Number of channels * temporal_patch_size * patch_size * + patch_size + + Historical context: + - pixel_values_videos shape: (num_patches, num_channels * + temporal_patch_size * patch_size * patch_size) + - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) + format + - second_per_grid_ts: The video time interval (in seconds) for each + grid along the temporal dimension in the 3D position IDs. Returned + when `videos` is not `None`. + """ type: Literal["pixel_values_videos"] - pixel_values_videos: torch.Tensor - """Shape: - `(num_patches, - num_channels * temporal_patch_size * patch_size * patch_size)` + + pixel_values_videos: Annotated[ + torch.Tensor, + TensorShape("np", "ctps"), + ] + + video_grid_thw: Annotated[ + torch.Tensor, + TensorShape("nv", 3), + ] + + second_per_grid_ts: Annotated[ + Optional[torch.Tensor], + TensorShape("nv"), + ] + + +class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema): """ - - video_grid_thw: torch.Tensor - """Shape: `(num_videos, 3)` - - This should be in `(grid_t, grid_h, grid_w)` format. + Dimensions: + - nf: Number of video features + - hs: Hidden size + - nv: Number of videos + + Historical context: + - video_embeds shape: (num_video_features, hidden_size) + - num_video_features varies based on the number and resolution of the + videos. + - hidden_size must match the hidden size of language model backbone. + - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) + format """ - - second_per_grid_ts: torch.Tensor - """ - The video time interval (in seconds) for each grid along the temporal - dimension in the 3D position IDs. Returned when `videos` is not `None`. 
- """ - - -class Qwen2_5_VLVideoEmbeddingInputs(TypedDict): type: Literal["video_embeds"] - video_embeds: torch.Tensor - """Supported types: - - list[`torch.Tensor`]: A list of tensors holding all videos' features. - Each tensor holds an video's features. - - `torch.Tensor`: A tensor holding all videos' features - (concatenation of all videos' feature tensors). - Tensor shape: `(num_image_features, hidden_size)` - - `num_image_features` varies based on - the number and resolution of the videos. - - `hidden_size` must match the hidden size of language model backbone. - """ + video_embeds: Annotated[ + torch.Tensor, + TensorShape("nf", "hs"), + ] - video_grid_thw: torch.Tensor - """Shape: `(num_videos, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. - """ + video_grid_thw: Annotated[ + torch.Tensor, + TensorShape("nv", 3), + ] Qwen2_5_VLVideoInputs = Union[Qwen2_5_VLVideoPixelInputs, @@ -178,22 +220,20 @@ class Qwen2_5_VisionMLP(nn.Module): prefix: str = "", use_data_parallel: bool = False): super().__init__() - cls_gate_up_proj = (MergedReplicatedLinear if use_data_parallel else - MergedColumnParallelLinear) - self.gate_up_proj = cls_gate_up_proj( + self.gate_up_proj = MergedColumnParallelLinear( input_size=in_features, output_sizes=[hidden_features] * 2, # [gate_proj, up_proj] bias=bias, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") + prefix=f"{prefix}.gate_up_proj", + disable_tp=use_data_parallel) - cls_down_proj = (ReplicatedLinear - if use_data_parallel else RowParallelLinear) - self.down_proj = cls_down_proj(hidden_features, - in_features, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + self.down_proj = RowParallelLinear(hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + disable_tp=use_data_parallel) self.act_fn = act_fn def forward(self, x: torch.Tensor): @@ -243,33 +283,33 @@ class Qwen2_5_VisionAttention(nn.Module): self.num_attention_heads_per_partition = dist_utils.divide( num_heads, self.tp_size) - if use_data_parallel: - self.qkv = ReplicatedLinear(embed_dim, - self.hidden_size_per_attention_head * - 3 * num_heads, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.qkv") + self.qkv = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.hidden_size_per_attention_head, + total_num_heads=num_heads, + total_num_kv_heads=num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + disable_tp=use_data_parallel) - else: - self.qkv = QKVParallelLinear( - hidden_size=embed_dim, - head_size=self.hidden_size_per_attention_head, - total_num_heads=num_heads, - total_num_kv_heads=num_heads, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.qkv") - - cls_proj = (ReplicatedLinear - if use_data_parallel else RowParallelLinear) - self.proj = cls_proj(input_size=projection_size, - output_size=embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.proj") + self.proj = RowParallelLinear(input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + disable_tp=use_data_parallel) # Detect attention implementation. 
- self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=self.hidden_size_per_attention_head, + dtype=torch.get_default_dtype()) + self.use_upstream_fa = False + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability( + torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN + self.use_upstream_fa = True + if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.ROCM_AITER_FA @@ -323,14 +363,19 @@ class Qwen2_5_VisionAttention(nn.Module): q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) if rotary_pos_emb is not None: - q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) - k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) + # [2 * b, s, heads, head_dim] + qk_concat = torch.cat([q, k], dim=0) + qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) + q, k = torch.chunk(qk_rotated, 2, dim=0) if self.is_flash_attn_backend: if self.attn_backend == _Backend.ROCM_AITER_FA: from aiter import flash_attn_varlen_func else: - from flash_attn import flash_attn_varlen_func + if self.use_upstream_fa: + from flash_attn import flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) @@ -345,8 +390,8 @@ class Qwen2_5_VisionAttention(nn.Module): causal=False) context_layer = rearrange(output, - "(b s) ... -> b s ...", - b=batch_size) + "(b s) h d -> s b (h d)", + b=batch_size).contiguous() elif self.attn_backend == _Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. outputs = [] @@ -365,6 +410,8 @@ class Qwen2_5_VisionAttention(nn.Module): output_i = rearrange(output_i, "b h s d -> b s h d ") outputs.append(output_i) context_layer = torch.cat(outputs, dim=1) + context_layer = rearrange(context_layer, + "b s h d -> s b (h d)").contiguous() elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -375,8 +422,8 @@ class Qwen2_5_VisionAttention(nn.Module): context_layer = xops.memory_efficient_attention_forward( q, k, v, attn_bias=attn_bias, p=0, scale=None) - context_layer = rearrange(context_layer, - "b s h d -> s b (h d)").contiguous() + context_layer = rearrange(context_layer, + "b s h d -> s b (h d)").contiguous() output, _ = self.proj(context_layer) return output @@ -480,32 +527,32 @@ class Qwen2_5_VisionPatchMerger(nn.Module): norm_layer = partial(nn.LayerNorm, eps=1e-6) self.ln_q = norm_layer(context_dim) - cls_fc1 = (ReplicatedLinear - if use_data_parallel else ColumnParallelLinear) - cls_fc2 = (ReplicatedLinear - if use_data_parallel else RowParallelLinear) - self.mlp = nn.ModuleList([ - cls_fc1(self.hidden_size, - self.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp.0"), + self.mlp = nn.Sequential( + ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.0", + return_bias=False, + disable_tp=use_data_parallel, + ), nn.GELU(), - cls_fc2(self.hidden_size, - d_model, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp.2"), - ]) + RowParallelLinear( + self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.2", + return_bias=False, + disable_tp=use_data_parallel, + ), + ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.ln_q(x) x = x.view(-1, 
self.hidden_size) - - mlp_fc1, mlp_act, mlp_fc2 = self.mlp - x_parallel, _ = mlp_fc1(x) - x_parallel = mlp_act(x_parallel) - out, _ = mlp_fc2(x_parallel) + out = self.mlp(x) return out @@ -599,7 +646,12 @@ class Qwen2_5_VisionTransformer(nn.Module): prefix=f"{prefix}.merger", use_data_parallel=use_data_parallel, ) - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, dtype=torch.get_default_dtype()) + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability( + torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -687,6 +739,15 @@ class Qwen2_5_VisionTransformer(nn.Module): seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() return max_seqlen, seqlens + @staticmethod + def invert_permutation(perm: torch.Tensor) -> torch.Tensor: + # building the inverse permutation in O(n) time + inv = torch.empty_like(perm, pin_memory=is_pin_memory_available()) + inv[perm] = torch.arange(perm.numel(), + device=perm.device, + dtype=perm.dtype) + return inv + def forward( self, x: torch.Tensor, @@ -730,6 +791,8 @@ class Qwen2_5_VisionTransformer(nn.Module): rotary_pos_emb = torch.cat(rotary_pos_emb) window_index = torch.cat(window_index) + # compute reverse indices + reverse_indices = self.invert_permutation(window_index) cu_window_seqlens = torch.cat(cu_window_seqlens) cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) cu_seqlens = torch.cat(cu_seqlens) @@ -750,6 +813,8 @@ class Qwen2_5_VisionTransformer(nn.Module): non_blocking=True) window_index = window_index.to(device=hidden_states.device, non_blocking=True) + reverse_indices = reverse_indices.to(device=hidden_states.device, + non_blocking=True) hidden_states = hidden_states.reshape( seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) @@ -783,7 +848,6 @@ class Qwen2_5_VisionTransformer(nn.Module): # adapter hidden_states = self.merger(hidden_states) - reverse_indices = torch.argsort(window_index) hidden_states = hidden_states[reverse_indices, :] return hidden_states @@ -929,7 +993,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, raise ValueError(f"{name} should be 2D or batched 3D tensor. " f"Got ndim: {mm_input.ndim} " f"(shape={mm_input.shape})") - return torch.concat(list(mm_input)) + return mm_input.reshape(-1, mm_input.shape[-1]) else: return torch.concat(mm_input) @@ -948,10 +1012,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, image_grid_thw = self._validate_and_reshape_mm_tensor( image_grid_thw, "image grid_thw") - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of image pixel values. " - f"Got type: {type(pixel_values)}") - return Qwen2_5_VLImagePixelInputs(type="pixel_values", pixel_values=pixel_values, image_grid_thw=image_grid_thw) @@ -962,9 +1022,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, image_grid_thw = self._validate_and_reshape_mm_tensor( image_grid_thw, "image grid_thw") - if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeddings. 
" - f"Got type: {type(image_embeds)}") return Qwen2_5_VLImageEmbeddingInputs( type="image_embeds", image_embeds=image_embeds, @@ -985,7 +1042,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, pixel_values_videos, "video pixel values") video_grid_thw = self._validate_and_reshape_mm_tensor( video_grid_thw, "video grid_thw") - + if second_per_grid_ts is not None and second_per_grid_ts.ndim == 2: + second_per_grid_ts = second_per_grid_ts.squeeze(-1) return Qwen2_5_VLVideoPixelInputs( type="pixel_values_videos", pixel_values_videos=pixel_values_videos, @@ -999,9 +1057,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, video_grid_thw = self._validate_and_reshape_mm_tensor( video_grid_thw, "video grid_thw") - if not isinstance(video_embeds, torch.Tensor): - raise ValueError("Incorrect type of video embeddings. " - f"Got type: {type(video_embeds)}") return Qwen2_5_VLVideoEmbeddingInputs( type="video_embeds", video_embeds=video_embeds, @@ -1162,21 +1217,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, input_ids: Flattened (concatenated) input_ids corresponding to a batch. positions: Flattened (concatenated) position ids corresponding to a - batch. - **NOTE**: If mrope is enabled (default setting for Qwen2.5-VL - opensource models), the shape will be `(3, seq_len)`, + batch. **NOTE**: If mrope is enabled (default setting for + Qwen2.5-VL opensource models), the shape will be `(3, seq_len)`, otherwise it will be `(seq_len,). - pixel_values: Pixel values to be fed to a model. - `None` if no images are passed. - image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. - `None` if no images are passed. - pixel_values_videos: Pixel values of videos to be fed to a model. - `None` if no videos are passed. - video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. - `None` if no videos are passed. - second_per_grid_ts: Tensor `(num_videos)` of video time interval ( - in seconds) for each grid along the temporal dimension in the - 3D position IDs. `None` if no videos are passed. """ if intermediate_tensors is not None: diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 86b4a9a018c76..c797b71b5d2e1 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -23,7 +23,7 @@ # limitations under the License. 
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Literal, Optional, Union import torch import torch.nn as nn @@ -47,6 +47,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, init_vllm_registered_model, @@ -54,21 +55,38 @@ from .utils import (AutoWeightsLoader, init_vllm_registered_model, # # === Audio Inputs === # -class Qwen2AudioFeatureInputs(TypedDict): - type: Literal["audio_features"] - input_features: torch.Tensor - """Shape: `(num_audios, num_mel_bins, 3000)`""" - - feature_attention_mask: torch.Tensor - """Shape: `(num_audios, 3000)`""" - - -class Qwen2AudioEmbeddingInputs(TypedDict): - type: Literal["audio_embeds"] - audio_embeds: list[torch.Tensor] - """Shape: `(num_audio_features, hidden_size)` - `hidden_size` must match the hidden size of language model backbone. +class Qwen2AudioFeatureInputs(TensorSchema): """ + Dimensions: + - na: Number of audios + - nmb: Number of mel bins + """ + type: Literal["audio_features"] + input_features: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("na", "nmb", 3000), + ] + + feature_attention_mask: Annotated[ + torch.Tensor, + TensorShape("na", 3000), + ] + + +class Qwen2AudioEmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size + - naf: Number of audio features + - hs: Hidden size (must match the hidden size of language model + backbone) + """ + type: Literal["audio_embeds"] = "audio_embeds" + + audio_embeds: Annotated[ + list[torch.Tensor], + TensorShape("bn", "naf", "hs"), + ] Qwen2AudioInputs = Union[Qwen2AudioFeatureInputs, Qwen2AudioEmbeddingInputs] @@ -324,7 +342,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, raise ValueError(f"Incorrect type of {name}. " f"Got type: {type(mm_input)}") if isinstance(mm_input, torch.Tensor): - return torch.concat(list(mm_input)) + return mm_input.reshape(-1, *mm_input.shape[2:]) else: return torch.concat(mm_input) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 5551ad8c32329..6c6276a930453 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -72,17 +72,20 @@ class Qwen2MoeMLP(nn.Module): hidden_act: str, quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = True, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, quant_config=quant_config, - reduce_results=reduce_results) + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -123,7 +126,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module): self.gate = ReplicatedLinear(config.hidden_size, config.num_experts, bias=False, - quant_config=None) + quant_config=None, + prefix=f"{prefix}.gate") if config.shared_expert_intermediate_size > 0: self.shared_expert = Qwen2MoeMLP( hidden_size=config.hidden_size, @@ -132,6 +136,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module): quant_config=quant_config, reduce_results=self.experts.must_reduce_shared_expert_outputs( ), + prefix=f"{prefix}.shared_expert", ) else: self.shared_expert = None @@ -203,21 +208,19 @@ class Qwen2MoeAttention(nn.Module): self.max_position_embeddings = max_position_embeddings self.dual_chunk_attention_config = dual_chunk_attention_config - self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=True, - quant_config=quant_config, - ) + self.qkv_proj = QKVParallelLinear(hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj") - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - quant_config=quant_config, - ) + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") self.rotary_emb = get_rope( self.head_dim, @@ -296,12 +299,11 @@ class Qwen2MoeDecoderLayer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.mlp") else: - self.mlp = Qwen2MoeMLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - ) + self.mlp = Qwen2MoeMLP(hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp") self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, @@ -519,7 +521,8 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): prefix=maybe_prefix(prefix, "model")) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head")) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index 421b43563bade..2bd9d2b52628a 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -53,15 +53,18 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP): self.quant_config = quant_config self.model = Qwen2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + self.head_dtype = vllm_config.model_config.head_dtype self.score = nn.Sequential( ColumnParallelLinear(config.hidden_size, config.hidden_size, quant_config=quant_config, + params_dtype=self.head_dtype, return_bias=False), nn.ReLU(), RowParallelLinear(config.hidden_size, config.num_labels, + params_dtype=self.head_dtype, quant_config=quant_config, return_bias=False), ) @@ -80,6 +83,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP): ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) + hidden_states = 
hidden_states.to(self.head_dtype) logits = self.score(hidden_states) return logits diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index ae7a8d8d7a5b9..b6576b783b64a 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -26,7 +26,7 @@ """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping, Sequence from functools import partial -from typing import Any, Callable, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Callable, Literal, Optional, Union import torch import torch.nn as nn @@ -41,6 +41,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize from transformers.models.qwen2_vl.video_processing_qwen2_vl import ( Qwen2VLVideoProcessor) +from vllm.attention.layer import check_upstream_fa_availability from vllm.config import VllmConfig from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import utils as dist_utils @@ -70,6 +71,7 @@ from vllm.platforms import _Backend, current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) @@ -81,83 +83,124 @@ from .vision import get_vit_attn_backend logger = init_logger(__name__) # For profile run -_MAX_FRAMES_PER_VIDEO = 16 +_MAX_FRAMES_PER_VIDEO = 600 # === Vision Inputs === # -class Qwen2VLImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: torch.Tensor - """Shape: - `(num_patches, num_channels * patch_size * patch_size)` +class Qwen2VLImagePixelInputs(TensorSchema): """ - - image_grid_thw: torch.Tensor - """Shape: `(num_images, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. - """ - - -class Qwen2VLImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - image_embeds: torch.Tensor - """Supported types: - - list[`torch.Tensor`]: A list of tensors holding all images' features. - Each tensor holds an image's features. - - `torch.Tensor`: A tensor holding all images' features - (concatenation of all images' feature tensors). + Dimensions: + - np: The total number of patches over each image over each prompt in + the batch + - ni: Number of images + - cps: Number of channels * patch_size * patch_size - Tensor shape: `(num_image_features, hidden_size)` - - `num_image_features` varies based on - the number and resolution of the images. - - `hidden_size` must match the hidden size of language model backbone. + Historical context: + - pixel_values shape: (num_patches, num_channels * patch_size * + patch_size) + - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) + format """ + type: Literal["pixel_values"] - image_grid_thw: torch.Tensor - """Shape: `(num_images, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. + pixel_values: Annotated[ + torch.Tensor, + TensorShape("np", "cps"), + ] + + image_grid_thw: Annotated[ + torch.Tensor, + TensorShape("ni", 3), + ] + + +class Qwen2VLImageEmbeddingInputs(TensorSchema): """ + Dimensions: + - nf: Number of image features + - hs: Hidden size + - ni: Number of images + + Historical context: + - image_embeds shape: (num_image_features, hidden_size) + - num_image_features varies based on the number and resolution of the + images. 
+ - hidden_size must match the hidden size of language model backbone. + - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) + format + """ + type: Literal["image_embeds"] + + image_embeds: Annotated[ + torch.Tensor, + TensorShape("nf", "hs"), + ] + + image_grid_thw: Annotated[ + torch.Tensor, + TensorShape("ni", 3), + ] Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs, Qwen2VLImageEmbeddingInputs] -class Qwen2VLVideoPixelInputs(TypedDict): - type: Literal["pixel_values_videos"] - pixel_values_videos: torch.Tensor - """Shape: - `(num_patches, - num_channels * temporal_patch_size * patch_size * patch_size)` +class Qwen2VLVideoPixelInputs(TensorSchema): """ - - video_grid_thw: torch.Tensor - """Shape: `(num_videos, 3)` - - This should be in `(grid_t, grid_h, grid_w)` format. - """ - - -class Qwen2VLVideoEmbeddingInputs(TypedDict): - type: Literal["video_embeds"] - video_embeds: torch.Tensor - """Supported types: - - list[`torch.Tensor`]: A list of tensors holding all videos' features. - Each tensor holds an video's features. - - `torch.Tensor`: A tensor holding all videos' features - (concatenation of all videos' feature tensors). + Dimensions: + - np: The total number of patches over each video over each prompt in + the batch + - ctps: Number of channels * temporal_patch_size * patch_size * + patch_size + - nv: Number of videos - Tensor shape: `(num_image_features, hidden_size)` - - `num_image_features` varies based on - the number and resolution of the videos. - - `hidden_size` must match the hidden size of language model backbone. + Historical context: + - pixel_values_videos shape: (num_patches, num_channels * + temporal_patch_size * patch_size * patch_size) + - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) + format """ + type: Literal["pixel_values_videos"] - video_grid_thw: torch.Tensor - """Shape: `(num_videos, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. + pixel_values_videos: Annotated[ + torch.Tensor, + TensorShape("np", "ctps"), + ] + + video_grid_thw: Annotated[ + torch.Tensor, + TensorShape("nv", 3), + ] + + +class Qwen2VLVideoEmbeddingInputs(TensorSchema): """ + Dimensions: + - nf: Number of video features + - hs: Hidden size + - nv: Number of videos + + Historical context: + - video_embeds shape: (num_video_features, hidden_size) + - num_video_features varies based on the number and resolution of the + videos. + - hidden_size must match the hidden size of language model backbone. + - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) + format + """ + type: Literal["video_embeds"] + + video_embeds: Annotated[ + torch.Tensor, + TensorShape("nf", "hs"), + ] + + video_grid_thw: Annotated[ + torch.Tensor, + TensorShape("nv", 3), + ] Qwen2VLVideoInputs = Union[Qwen2VLVideoPixelInputs, @@ -272,7 +315,16 @@ class Qwen2VisionAttention(nn.Module): prefix=f"{prefix}.proj") # Detect attention implementation. 
- self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=self.hidden_size_per_attention_head, + dtype=torch.get_default_dtype()) + self.use_upstream_fa = False + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability( + torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN + self.use_upstream_fa = True + if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.ROCM_AITER_FA @@ -332,7 +384,10 @@ class Qwen2VisionAttention(nn.Module): if self.attn_backend == _Backend.ROCM_AITER_FA: from aiter import flash_attn_varlen_func else: - from flash_attn import flash_attn_varlen_func + if self.use_upstream_fa: + from flash_attn import flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) @@ -586,7 +641,12 @@ class Qwen2VisionTransformer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.merger", ) - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, dtype=torch.get_default_dtype()) + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability( + torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -915,12 +975,9 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): seq_len: int, mm_counts: Mapping[str, int], ) -> int: - max_images = mm_counts.get("image", 0) max_videos = mm_counts.get("video", 0) - max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = self._get_max_video_frames(seq_len - - max_image_tokens) + max_total_frames = self._get_max_video_frames(seq_len) max_frames_per_video = min(max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO) @@ -1110,7 +1167,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, raise ValueError(f"{name} should be 2D or batched 3D tensor. " f"Got ndim: {mm_input.ndim} " f"(shape={mm_input.shape})") - return torch.concat(list(mm_input)) + return mm_input.reshape(-1, mm_input.shape[-1]) else: return torch.concat(mm_input) @@ -1129,10 +1186,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, image_grid_thw = self._validate_and_reshape_mm_tensor( image_grid_thw, "image grid_thw") - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of image pixel values. " - f"Got type: {type(pixel_values)}") - return Qwen2VLImagePixelInputs(type="pixel_values", pixel_values=pixel_values, image_grid_thw=image_grid_thw) @@ -1143,9 +1196,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, image_grid_thw = self._validate_and_reshape_mm_tensor( image_grid_thw, "image grid_thw") - if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") return Qwen2VLImageEmbeddingInputs(type="image_embeds", image_embeds=image_embeds, image_grid_thw=image_grid_thw) @@ -1177,9 +1227,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, video_grid_thw = self._validate_and_reshape_mm_tensor( video_grid_thw, "video grid_thw") - if not isinstance(video_embeds, torch.Tensor): - raise ValueError("Incorrect type of video embeddings. 
" - f"Got type: {type(video_embeds)}") return Qwen2VLVideoEmbeddingInputs(type="video_embeds", video_embeds=video_embeds, video_grid_thw=video_grid_thw) @@ -1189,6 +1236,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, grid_thw = image_input["image_grid_thw"] assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() if image_input["type"] == "image_embeds": image_embeds = image_input["image_embeds"] @@ -1198,15 +1246,17 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, # Split concatenated embeddings for each image item. merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size + sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // + (merge_size * merge_size)).tolist() - return image_embeds.split(sizes.tolist()) + return image_embeds.split(sizes) def _process_video_input( self, video_input: Qwen2VLVideoInputs) -> tuple[torch.Tensor, ...]: grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() if video_input["type"] == "video_embeds": video_embeds = video_input["video_embeds"] @@ -1216,9 +1266,10 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, # Split concatenated embeddings for each video item. merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size + sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // + (merge_size * merge_size)).tolist() - return video_embeds.split(sizes.tolist()) + return video_embeds.split(sizes) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: modalities = {} @@ -1321,15 +1372,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, batch. **NOTE**: If mrope is enabled (default setting for Qwen2-VL opensource models), the shape will be `(3, seq_len)`, - otherwise it will be `(seq_len,). - pixel_values: Pixel values to be fed to a model. - `None` if no images are passed. - image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. - `None` if no images are passed. - pixel_values_videos: Pixel values of videos to be fed to a model. - `None` if no videos are passed. - video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. - `None` if no videos are passed. + otherwise it will be `(seq_len,)`. + intermediate_tensors: Intermediate tensors from prior forward pass. + inputs_embeds: Optional tensor of input embeddings. """ if intermediate_tensors is not None: diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index a7e0a00350e6b..029309c49efd4 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -170,8 +170,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module): return quant_config def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - # NOTE: hidden_states can have either 1D or 2D shape. 
- orig_shape = hidden_states.shape + assert hidden_states.dim( + ) <= 2, "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs" + is_input_1d = hidden_states.dim() == 1 hidden_dim = hidden_states.shape[-1] hidden_states = hidden_states.view(-1, hidden_dim) @@ -180,7 +181,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module): final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits) - return final_hidden_states.view(orig_shape) + # return to 1d if input is 1d + return final_hidden_states.squeeze(0) if is_input_1d else \ + final_hidden_states class Qwen3MoeAttention(nn.Module): @@ -375,7 +378,7 @@ class Qwen3MoeModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - config = vllm_config.model_config.hf_config + config = vllm_config.model_config.hf_config.get_text_config() cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config parallel_config = vllm_config.parallel_config @@ -602,7 +605,8 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, prefix=maybe_prefix(prefix, "model")) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head")) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) @@ -698,4 +702,4 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, return loader.load_weights(weights) def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - return self.model.get_expert_mapping() + return self.model.get_expert_mapping() \ No newline at end of file diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py new file mode 100644 index 0000000000000..0c974ee44eee2 --- /dev/null +++ b/vllm/model_executor/models/qwen3_next.py @@ -0,0 +1,1302 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only Qwen3Next model.""" +from collections.abc import Iterable +from itertools import islice +from typing import Optional + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn +from transformers.activations import ACT2FN + +from vllm import envs +from vllm.attention import Attention, AttentionBackend, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig, + VllmConfig, get_current_vllm_config) +from vllm.distributed import (divide, get_ep_group, get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.fla.ops import ( + RMSNormGated, chunk_gated_delta_rule, fused_recurrent_gated_delta_rule) +from vllm.model_executor.layers.fused_moe import FusedMoE +# yapf conflicts with isort for this block +# yapf: disable +from vllm.model_executor.layers.layernorm import ( + GemmaRMSNorm as Qwen3NextRMSNorm) +# yapf: enable +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.abstract import MambaBase +from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata 
+from vllm.model_executor.layers.mamba.mamba_mixer2 import ( + mamba_v2_sharded_weight_loader) +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateDtypeCalculator, MambaStateShapeCalculator) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, sharded_weight_loader) +from vllm.model_executor.models.mamba_cache import MambaCacheParams +from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import Qwen3NextConfig +from vllm.triton_utils import tl, triton +from vllm.utils import direct_register_custom_op +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata + +from .interfaces import (HasInnerState, IsHybrid, MixtureOfExperts, + SupportsLoRA, SupportsPP) +from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + +KVCache = tuple[torch.Tensor, torch.Tensor] + + +class Qwen3NextSparseMoeBlock(nn.Module): + + def __init__( + self, + config: Qwen3NextConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + enable_eplb: bool = False, + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + self.n_routed_experts = config.num_experts + + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}.") + + # Load balancing settings. 
+ vllm_config = get_current_vllm_config() + eplb_config = vllm_config.parallel_config.eplb_config + self.enable_eplb = enable_eplb + + self.n_logical_experts = self.n_routed_experts + self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_physical_experts = (self.n_logical_experts + + self.n_redundant_experts) + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = (self.ep_rank * + self.n_local_physical_experts) + self.physical_expert_end = (self.physical_expert_start + + self.n_local_physical_experts) + + self.experts = FusedMoE(num_experts=self.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=self._maybe_ignore_quant_config(quant_config), + prefix=f"{prefix}.gate") + + if config.shared_expert_intermediate_size > 0: + self.shared_expert = Qwen3NextMLP( + hidden_size=config.hidden_size, + intermediate_size=config.shared_expert_intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=self.experts.must_reduce_shared_expert_outputs( + ), + prefix=f"{prefix}.shared_expert", + ) + else: + self.shared_expert = None + self.shared_expert_gate = torch.nn.Linear(config.hidden_size, + 1, + bias=False) + + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + # GPTQ configs do not have a list of ignored modules, however AutoGPTQ + # seems to avoid gate quantization. + # See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4 + if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): + return None + return quant_config + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. 
+ orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + + shared_output = None + if self.shared_expert is not None: + shared_output = self.shared_expert(hidden_states) + if self.shared_expert_gate is not None: + shared_output = F.sigmoid( + self.shared_expert_gate(hidden_states)) * shared_output + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + if self.tp_size > 1: + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 + final_hidden_states) + + return final_hidden_states.view(orig_shape) + + +class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): + + @property + def mamba_type(self) -> str: + return "linear_attention" + + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend + return GDNAttentionBackend + + def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: + return MambaStateDtypeCalculator.gated_delta_net_state_dtype( + self.model_config.dtype, self.cache_config.mamba_cache_dtype) + + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: + return MambaStateShapeCalculator.gated_delta_net_state_shape( + self.tp_size, + self.num_k_heads, + self.num_v_heads, + self.head_k_dim, + self.head_v_dim, + self.conv_kernel_size, + self.num_spec, + use_v1=True) + + def __init__( + self, + config: Qwen3NextConfig, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + speculative_config: Optional[SpeculativeConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + + self.conv_kernel_size = config.linear_conv_kernel_dim + self.layer_idx = extract_layer_index(prefix) + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.layer_norm_epsilon = config.rms_norm_eps + self.prefix = prefix + + self.config = config + self.model_config = model_config + self.cache_config = cache_config + self.quant_config = quant_config + self.speculative_config = speculative_config + self.num_spec = (self.speculative_config.num_speculative_tokens + if self.speculative_config else 0) + + # QKV + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.conv_dim, + bias=False, + prefix=f"{prefix}.conv1d", + ) + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + # projection of the input hidden states + self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 + self.projection_size_ba = self.num_v_heads * 2 + self.in_proj_qkvz = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=self.projection_size_qkvz, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.in_proj_qkvz", + ) + # ba_proj doesn't 
support blockwise fp8 quantization. + self.in_proj_ba = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=self.projection_size_ba, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.in_proj_ba", + ) + + query_key_settings = (self.key_dim, 0, False) + value_settings = (self.value_dim, 0, False) + + delattr(self.conv1d.weight, "weight_loader") + set_weight_attrs( + self.conv1d.weight, { + "weight_loader": + mamba_v2_sharded_weight_loader([ + query_key_settings, + query_key_settings, + value_settings, + ], self.tp_size, self.tp_rank) + }) + + # selective projection used to make dt, B and C input dependent + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter( + torch.ones(self.num_v_heads // self.tp_size), ) + self.A_log = nn.Parameter( + torch.empty( + divide(self.num_v_heads, self.tp_size), + dtype=torch.float32, + )) + + set_weight_attrs(self.A_log, + {"weight_loader": sharded_weight_loader(0)}) + set_weight_attrs(self.dt_bias, + {"weight_loader": sharded_weight_loader(0)}) + + self.norm = RMSNormGated( + self.head_v_dim, + eps=self.layer_norm_epsilon, + group_size=None, + norm_before_gate=True, + device=torch.cuda.current_device(), + dtype=config.torch_dtype, + ) + + self.out_proj = RowParallelLinear(self.value_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj") + + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + def fix_query_key_value_ordering( + self, + mixed_qkvz, + mixed_ba, + ): + """ + Derives `query`, `key` and `value` tensors from `mixed_qkvzba`. 
+ """ + new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( + self.num_k_heads // self.tp_size, + (self.head_k_dim + self.head_k_dim + + (self.head_v_dim + self.head_v_dim) * self.num_v_heads // + self.num_k_heads), + ) + new_tensor_shape_ba = mixed_qkvz.size()[:-1] + ( + self.num_k_heads // self.tp_size, + 2 * self.num_v_heads // self.num_k_heads, + ) + + mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) + mixed_ba = mixed_ba.view(*new_tensor_shape_ba) + + split_arg_list_qkvz = [ + self.head_k_dim, + self.head_k_dim, + (self.num_v_heads // self.num_k_heads * self.head_v_dim), + (self.num_v_heads // self.num_k_heads * self.head_v_dim), + ] + split_arg_list_ba = [ + self.num_v_heads // self.num_k_heads, + self.num_v_heads // self.num_k_heads + ] + + # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)] + # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], + # [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng] + (query, key, value, z) = torch.split(mixed_qkvz, + split_arg_list_qkvz, + dim=2) + (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2) + + # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] + value = value.reshape(value.size(0), -1, self.head_v_dim) + z = z.reshape(z.size(0), -1, self.head_v_dim) + b = b.reshape(b.size(0), self.num_v_heads // self.tp_size) + a = a.reshape(a.size(0), self.num_v_heads // self.tp_size) + + return query, key, value, z, b, a + + def rearrange_mixed_qkv(self, mixed_qkv): + if mixed_qkv is None: + return None, None, None + query, key, value = torch.split( + mixed_qkv, + [ + self.key_dim // self.tp_size, + self.key_dim // self.tp_size, + self.value_dim // self.tp_size, + ], + dim=-1, + ) + query, key = map( + lambda x: rearrange(x, 'l (h d) -> 1 l h d', d=self.head_k_dim), + (query, key)) + value = rearrange(value, 'l (h d) -> 1 l h d', d=self.head_v_dim) + return query, key, value + + def forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + cache_params: Optional[MambaCacheParams] = None, + ): + return torch.ops.vllm.gdn_attention( + hidden_states, + output, + self.prefix, + ) + + def _forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + ): + forward_context = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + + if attn_metadata is None: + # V1 profile run + return + + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + conv_metadata = attn_metadata + assert isinstance(attn_metadata, GDNAttentionMetadata) + has_initial_state = attn_metadata.has_initial_state + spec_query_start_loc = attn_metadata.spec_query_start_loc + non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc + spec_sequence_masks = attn_metadata.spec_sequence_masks + spec_token_masks = attn_metadata.spec_token_masks + spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 + non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + num_actual_tokens = attn_metadata.num_actual_tokens + num_accepted_tokens = attn_metadata.num_accepted_tokens + if spec_token_masks is not None: + spec_token_masks = spec_token_masks[:num_actual_tokens] + + # 1. 
Set up dimensions for reshapes later + projected_states_qkvz, _ = self.in_proj_qkvz( + hidden_states[:num_actual_tokens]) + projected_states_ba, _ = self.in_proj_ba( + hidden_states[:num_actual_tokens]) + query, key, value, z, b, a = self.fix_query_key_value_ordering( + projected_states_qkvz, projected_states_ba) + query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'), + (query, key, value)) + mixed_qkv = torch.cat((query, key, value), dim=-1) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if spec_sequence_masks is not None: + if (attn_metadata.num_prefills == 0 + and attn_metadata.num_decodes == 0): + mixed_qkv_spec = mixed_qkv + mixed_qkv_non_spec = None + else: + mixed_qkv_spec = mixed_qkv[spec_token_masks] + mixed_qkv_non_spec = mixed_qkv[~spec_token_masks] + else: + mixed_qkv_spec = None + mixed_qkv_non_spec = mixed_qkv + + # 2.1: process the multi-query part + if spec_sequence_masks is not None: + mixed_qkv_spec = causal_conv1d_update( + mixed_qkv_spec, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=spec_state_indices_tensor[:, 0] + [:attn_metadata.num_spec_decodes], + num_accepted_tokens=num_accepted_tokens, + query_start_loc=spec_query_start_loc, + max_query_len=spec_state_indices_tensor.size(-1), + validate_data=False, + ) + + # 2.2: process the remaining part + if attn_metadata.num_prefills > 0: + mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1) + if conv_metadata.cu_seqlen is None: + conv_metadata = update_metadata(mixed_qkv_non_spec_T, + non_spec_query_start_loc, + conv_metadata) + # - "cache_indices" updates the conv_state cache in positions + # pointed to by "mamba_cache_params.state_indices_tensor" + mixed_qkv_non_spec = causal_conv1d_fn( + mixed_qkv_non_spec_T, + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=conv_state, + has_initial_state=has_initial_state, + cache_indices=non_spec_state_indices_tensor, + query_start_loc=non_spec_query_start_loc, + metadata=conv_metadata, + ).transpose(0, 1) + elif attn_metadata.num_decodes > 0: + mixed_qkv_non_spec = causal_conv1d_update( + mixed_qkv_non_spec, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=non_spec_state_indices_tensor[:attn_metadata + .num_decodes], + validate_data=True, + ) + else: + mixed_qkv_non_spec = None + + query_spec, key_spec, value_spec = self.rearrange_mixed_qkv( + mixed_qkv_spec) + query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv( + mixed_qkv_non_spec) + + beta = b.sigmoid() + # g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + g = fused_gdn_gating(self.A_log, a, self.dt_bias) + g, beta = map(lambda x: rearrange(x, 'l d -> 1 l d'), (g, beta)) + + if spec_sequence_masks is not None: + if (attn_metadata.num_prefills == 0 + and attn_metadata.num_decodes == 0): + g_spec = g + beta_spec = beta + g_non_spec = None + beta_non_spec = None + else: + g_spec = g[:, spec_token_masks] + beta_spec = beta[:, spec_token_masks] + g_non_spec = g[:, ~spec_token_masks] + beta_non_spec = beta[:, ~spec_token_masks] + else: + g_spec = None + beta_spec = None + g_non_spec = g + beta_non_spec = beta + + # 3. 
Recurrent attention + + # 3.1: process the multi-query part + if spec_sequence_masks is not None: + core_attn_out_spec, last_recurrent_state = ( + fused_recurrent_gated_delta_rule( + q=query_spec, + k=key_spec, + v=value_spec, + g=g_spec, + beta=beta_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=spec_query_start_loc[:attn_metadata. + num_spec_decodes + 1], + ssm_state_indices=spec_state_indices_tensor, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=True, + )) + else: + core_attn_out_spec, last_recurrent_state = None, None + + # 3.2: process the remaining part + if attn_metadata.num_prefills > 0: + initial_state = ssm_state[ + non_spec_state_indices_tensor].contiguous() + initial_state[~has_initial_state, ...] = 0 + ( + core_attn_out_non_spec, + last_recurrent_state, + ) = chunk_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=initial_state, + output_final_state=True, + cu_seqlens=non_spec_query_start_loc, + head_first=False, + use_qk_l2norm_in_kernel=True, + ) + # Init cache + ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( + ssm_state.dtype) + elif attn_metadata.num_decodes > 0: + core_attn_out_non_spec, last_recurrent_state = ( + fused_recurrent_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=non_spec_query_start_loc[:attn_metadata. + num_decodes + 1], + ssm_state_indices=non_spec_state_indices_tensor, + use_qk_l2norm_in_kernel=True, + )) + else: + core_attn_out_non_spec, last_recurrent_state = None, None + + # Merge core attention output + if (spec_sequence_masks is not None + and core_attn_out_non_spec is not None): + core_attn_out = torch.empty( + (1, num_actual_tokens, *core_attn_out_spec.shape[2:]), + dtype=core_attn_out_non_spec.dtype, + device=core_attn_out_non_spec.device, + ) + core_attn_out[:, spec_token_masks] = core_attn_out_spec + core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec + elif spec_sequence_masks is not None: + core_attn_out = core_attn_out_spec + else: + core_attn_out = core_attn_out_non_spec + + z_shape_og = z.shape + # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = rearrange(core_attn_out, '... h d -> ... (h d)') + + output[:num_actual_tokens], _ = self.out_proj(core_attn_out) + + +class Qwen3NextAttention(nn.Module): + + def __init__( + self, + config: Qwen3NextConfig, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. 
+ assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.head_dim or (self.hidden_size // self.num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.dual_chunk_attention_config = getattr( + config, "dual_chunk_attention_config", None) + self.attn_output_gate = getattr(config, "attn_output_gate", True) + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + self.total_num_heads * (1 + self.attn_output_gate), + self.total_num_kv_heads, + bias=getattr(config, "qkv_bias", False), + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=self.head_dim, + max_position=config.max_position_embeddings, + base=config.rope_theta, + rope_scaling=config.rope_scaling, + partial_rotary_factor=config.partial_rotary_factor, + dual_chunk_attention_config=self.dual_chunk_attention_config, + ) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + **{ + "layer_idx": extract_layer_index(prefix), + "dual_chunk_attention_config": + self.dual_chunk_attention_config, + } if self.dual_chunk_attention_config else {}, + ) + + self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + output: torch.Tensor, + hidden_states: torch.Tensor, + ): + qkv, _ = self.qkv_proj(hidden_states) + + if self.attn_output_gate: + q_gate, k, v = qkv.split( + [self.q_size * 2, self.kv_size, self.kv_size], dim=-1) + orig_shape = q_gate.shape[:-1] + q_gate = q_gate.view(*orig_shape, self.num_heads, -1) + q, gate = torch.chunk(q_gate, 2, dim=-1) + q = q.reshape(*orig_shape, -1) + gate = gate.reshape(*orig_shape, -1) + else: + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], + dim=-1) + + q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view( + -1, self.num_heads * self.head_dim) + k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view( + -1, self.num_kv_heads * self.head_dim) + + q, k = self.rotary_emb(positions, q, k) + + attn_output = self.attn(q, k, v) + + if self.attn_output_gate: + gate = torch.sigmoid(gate) + attn_output = attn_output * gate + + output[:], _ = self.o_proj(attn_output) + + +class Qwen3NextDecoderLayer(nn.Module): + + def __init__( + self, + config: Qwen3NextConfig, + layer_type: str, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + speculative_config: Optional[SpeculativeConfig] = None, + prefix: str = "", + enable_eplb: bool = False, + ) -> None: + super().__init__() + self.config = config + + self.layer_type = layer_type + self.layer_idx = extract_layer_index(prefix) + + if self.layer_type == "linear_attention": + self.linear_attn = Qwen3NextGatedDeltaNet( + config, + 
model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + speculative_config=speculative_config, + prefix=f'{prefix}.linear_attn') + elif self.layer_type == "full_attention": + self.self_attn = Qwen3NextAttention( + config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f'{prefix}.self_attn', + ) + else: + raise ValueError(f"Invalid layer_type {self.layer_type}") + + mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else + config.mlp_only_layers) + if (self.layer_idx not in mlp_only_layers) and ( + config.num_experts > 0 and + (self.layer_idx + 1) % config.decoder_sparse_step == 0): + self.mlp = Qwen3NextSparseMoeBlock( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + enable_eplb=enable_eplb, + ) + else: + self.mlp = Qwen3NextMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + + self.input_layernorm = Qwen3NextRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3NextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps) + + self.layer_scale = getattr(config, "layer_scale", False) + if self.layer_scale: + self.attn_layer_scale = torch.nn.Parameter( + torch.zeros( + 1, + 1, + self.config.hidden_size, + dtype=config.torch_dtype, + ), ) + self.ffn_layer_scale = torch.nn.Parameter( + torch.zeros( + 1, + 1, + self.config.hidden_size, + dtype=config.torch_dtype, + ), ) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + positions: torch.Tensor = None, + **kwargs: object, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + self_attention_output = torch.empty_like(hidden_states) + if self.layer_type == "linear_attention": + self.linear_attn( + hidden_states=hidden_states, + output=self_attention_output, + ) + elif self.layer_type == "full_attention": + self.self_attn( + hidden_states=hidden_states, + output=self_attention_output, + positions=positions, + ) + else: + raise ValueError("Invalid layer_type") + hidden_states = self_attention_output + + if self.layer_scale: + if len(hidden_states.shape) == 2: + hidden_states = hidden_states * ( + self.attn_layer_scale.to(hidden_states.dtype)[0] + 1) + else: + hidden_states = hidden_states * ( + self.attn_layer_scale.to(hidden_states.dtype) + 1) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + if self.layer_scale: + if len(hidden_states.shape) == 2: + hidden_states = hidden_states * ( + self.ffn_layer_scale.to(hidden_states.dtype)[0] + 1) + else: + assert len(hidden_states.shape) == len( + self.ffn_layer_scale.shape + ), f'shape must be the same {len(hidden_states.shape)}, {len(self.ffn_layer_scale.shape)}' # noqa: E501 + hidden_states = hidden_states * ( + self.ffn_layer_scale.to(hidden_states.dtype) + 1) + + return hidden_states, residual + + +@support_torch_compile +class Qwen3NextModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config: Qwen3NextConfig = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + parallel_config = 
vllm_config.parallel_config + lora_config = vllm_config.lora_config + speculative_config = vllm_config.speculative_config + enable_eplb = parallel_config.enable_eplb + eplb_config = parallel_config.eplb_config + self.num_redundant_experts = eplb_config.num_redundant_experts + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + def get_layer(prefix: str): + return Qwen3NextDecoderLayer( + config, + layer_type=config.layer_types[extract_layer_index(prefix)], + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + speculative_config=speculative_config, + prefix=prefix, + enable_eplb=enable_eplb, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + if get_pp_group().is_last_rank: + self.norm = Qwen3NextRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + num_redundant_experts=self.num_redundant_experts) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if name.startswith("mtp."): + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + if "mlp.experts" in name: + continue + + name = 
name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # name = apply_attn_prefix(name, params_dict) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, + MixtureOfExperts, IsHybrid): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": ["gate_proj", "up_proj"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert not cache_config.enable_prefix_caching, \ + "Qwen3Next currently does not support prefix caching" + assert envs.VLLM_USE_V1, "Qwen3Next requires VLLM_USE_V1" + self.quant_config = vllm_config.quant_config + + super().__init__() + self.config = config + self.scheduler_config = scheduler_config + self.model = Qwen3NextModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + prefix=maybe_prefix(prefix, "lm_head")) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + # Set MoE hyperparameters + self.expert_weights = [] + + self.moe_layers: list[FusedMoE] = [] + example_layer = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance(layer, Qwen3NextDecoderLayer) + if isinstance(layer.mlp, Qwen3NextSparseMoeBlock): + example_layer = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_layer is None: + raise RuntimeError("No Qwen3Next layer found in the model.layers.") + + self.num_moe_layers = 
len(self.moe_layers) + self.num_expert_groups = 1 + self.num_shared_experts = 0 + self.num_logical_experts = example_layer.n_logical_experts + self.num_physical_experts = example_layer.n_physical_experts + self.num_local_physical_experts = example_layer.n_local_physical_experts + self.num_routed_experts = example_layer.n_routed_experts + self.num_redundant_experts = example_layer.n_redundant_experts + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. + self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = (num_physical_experts - + self.num_logical_experts) + for layer in self.model.layers: + if isinstance(layer.mlp, Qwen3NextSparseMoeBlock): + moe = layer.mlp + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ): + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + + return hidden_states + + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + return MambaStateDtypeCalculator.gated_delta_net_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype) + + @classmethod + def get_mamba_state_shape_from_config( + cls, vllm_config: "VllmConfig" + ) -> tuple[tuple[int, int], tuple[int, int]]: + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + tp_size = parallel_config.tensor_parallel_size + num_spec = (vllm_config.speculative_config.num_speculative_tokens + if vllm_config.speculative_config else 0) + return MambaStateShapeCalculator.gated_delta_net_state_shape( + tp_size, + hf_config.linear_num_key_heads, + hf_config.linear_num_value_heads, + hf_config.linear_key_head_dim, + hf_config.linear_value_head_dim, + hf_config.linear_conv_kernel_dim, + num_spec, + use_v1=True) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=["mtp."], + ) + return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() + + +def gdn_attention( + 
hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self._forward(hidden_states=hidden_states, output=output) + + +def gdn_attention_fake( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="gdn_attention", + op_func=gdn_attention, + mutates_args=["output"], + fake_impl=gdn_attention_fake, + dispatch_key=current_platform.dispatch_key, +) + + +# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) +@triton.jit +def fused_gdn_gating_kernel( + g, + A_log, + a, + dt_bias, + seq_len, + NUM_HEADS: tl.constexpr, + beta: tl.constexpr, + threshold: tl.constexpr, + BLK_HEADS: tl.constexpr, +): + i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2) + head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) + off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off + mask = head_off < NUM_HEADS + blk_A_log = tl.load(A_log + head_off, mask=mask) + blk_a = tl.load(a + off, mask=mask) + blk_bias = tl.load(dt_bias + head_off, mask=mask) + # If the model is loaded in fp16, without the .float() here, A might be -inf + x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) + softplus_x = tl.where(beta * x <= threshold, + (1 / beta) * tl.log(1 + tl.exp(beta * x)), x) + blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x + tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) + + +def fused_gdn_gating( + A_log: torch.Tensor, + a: torch.Tensor, + dt_bias: torch.Tensor, + beta: float = 1.0, + threshold: float = 20.0, +) -> torch.Tensor: + batch, num_heads = a.shape + seq_len = 1 + grid = (batch, seq_len, triton.cdiv(num_heads, 8)) + g = torch.empty_like(a, dtype=torch.float32) + fused_gdn_gating_kernel[grid](g, + A_log, + a, + dt_bias, + seq_len, + num_heads, + beta, + threshold, + 8, + num_warps=1) + return g diff --git a/vllm/model_executor/models/qwen3_next_mtp.py b/vllm/model_executor/models/qwen3_next_mtp.py new file mode 100644 index 0000000000000..c755eeb9b4eaa --- /dev/null +++ b/vllm/model_executor/models/qwen3_next_mtp.py @@ -0,0 +1,288 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only Qwen3Next MTP model.""" +from collections.abc import Iterable +from typing import Optional + +import torch +from torch import nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed.parallel_state import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import ColumnParallelLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.qwen3_next import (Qwen3NextDecoderLayer, + Qwen3NextRMSNorm) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import Qwen3NextConfig + +from .interfaces import SupportsPP +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, maybe_prefix) + +logger = 
init_logger(__name__) + +KVCache = tuple[torch.Tensor, torch.Tensor] + + +@support_torch_compile +class Qwen3NextMultiTokenPredictor(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + config: Qwen3NextConfig = model_config.hf_config + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1) + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + self.fc = ColumnParallelLinear(self.config.hidden_size * 2, + self.config.hidden_size, + gather_output=True, + bias=False, + return_bias=False, + quant_config=quant_config, + prefix=f'{prefix}.fc') + + self.layers = torch.nn.ModuleList( + Qwen3NextDecoderLayer( + config, + layer_type="full_attention", + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f'{prefix}.layers.{idx}', + ) for idx in range(self.num_mtp_layers)) + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + self.norm = Qwen3NextRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_fc_norm_hidden = Qwen3NextRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_fc_norm_embedding = Qwen3NextRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + assert hidden_states.shape[-1] == inputs_embeds.shape[-1] + inputs_embeds = self.pre_fc_norm_embedding(inputs_embeds) + hidden_states = self.pre_fc_norm_hidden(hidden_states) + hidden_states = torch.cat([inputs_embeds, hidden_states], dim=-1) + hidden_states = self.fc(hidden_states) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + current_step_idx = (spec_step_idx % self.num_mtp_layers) + hidden_states, residual = self.layers[current_step_idx]( + positions=positions, + hidden_states=hidden_states, + residual=residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight 
scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + if "mlp.experts" in name: + continue + + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +@support_torch_compile +class Qwen3NextMTP(nn.Module, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": ["up_proj", "down_proj"] + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + self.vllm_config = vllm_config + cache_config = vllm_config.cache_config + assert not cache_config.enable_prefix_caching, \ + "Qwen3NextMTP currently does not support prefix caching" + + self.quant_config = vllm_config.quant_config + + super().__init__() + self.config = config + self.model = Qwen3NextMultiTokenPredictor(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "mtp")) + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead(self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + prefix=maybe_prefix(prefix, "lm_head")) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ): + hidden_states = self.model(input_ids, positions, hidden_states, + 
intermediate_tensors, inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> Optional[torch.Tensor]: + return self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + shared_weight_names = ["embed_tokens", "lm_head"] + + def remap_weight_names(weights): + for name, weight in weights: + if name.startswith("mtp."): + name = name.replace("mtp.", "model.") + elif not any(key in name for key in shared_weight_names): + continue + yield name, weight + + loader = AutoWeightsLoader(self) + return loader.load_weights(remap_weight_names(weights)) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py new file mode 100644 index 0000000000000..2c36dfbce7f67 --- /dev/null +++ b/vllm/model_executor/models/qwen3_vl.py @@ -0,0 +1,1534 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The vLLM team. +# Copyright 2025 The Qwen Team. +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Qwen3VL model compatible with HuggingFace weights.""" +from collections.abc import Iterable, Mapping, Sequence +from functools import partial +from typing import Any, Callable, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import BatchFeature +from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast +from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize +from transformers.models.qwen3_vl import (Qwen3VLProcessor, + Qwen3VLVideoProcessor) +from transformers.models.qwen3_vl.configuration_qwen3_vl import ( + Qwen3VLConfig, Qwen3VLVisionConfig) +from transformers.video_utils import VideoMetadata + +from vllm.attention.layer import check_upstream_fa_availability +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargsItem, + MultiModalKwargsItems, VideoItem) +from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, + MultiModalDataParser) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + PromptReplacement, PromptUpdate, + PromptUpdateDetails) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.platforms import _Backend +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.config import uses_mrope +from vllm.utils import is_list_of + +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) +from .qwen2_5_vl import (Qwen2_5_VisionAttention, + Qwen2_5_VisionRotaryEmbedding, + Qwen2_5_VLImageEmbeddingInputs, Qwen2_5_VLImageInputs, + Qwen2_5_VLImagePixelInputs, + Qwen2_5_VLVideoEmbeddingInputs, Qwen2_5_VLVideoInputs, + Qwen2_5_VLVideoPixelInputs) +from .qwen2_vl import Qwen2VLProcessingInfo +from .qwen3 import Qwen3ForCausalLM, Qwen3Model +from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, + maybe_prefix, merge_multimodal_embeddings) +from .vision import get_vit_attn_backend + +logger = init_logger(__name__) + + +class Qwen3_VisionPatchEmbed(nn.Module): + + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + hidden_size: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.hidden_size = hidden_size + + kernel_size = (temporal_patch_size, patch_size, patch_size) + self.proj = nn.Conv3d(in_channels, + hidden_size, + kernel_size=kernel_size, + stride=kernel_size, + bias=True) + + def forward(self, x: torch.Tensor) -> 
torch.Tensor: + L, C = x.shape + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, + self.patch_size) + x = self.proj(x).view(L, self.hidden_size) + return x + + +class Qwen3_VisionMLP(nn.Module): + + def __init__(self, + in_features: int, + hidden_features: int, + bias: bool = False, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False): + super().__init__() + self.linear_fc1 = ColumnParallelLinear(in_features, + hidden_features, + bias=bias, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.linear_fc1", + disable_tp=use_data_parallel) + self.linear_fc2 = RowParallelLinear(hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.linear_fc2", + disable_tp=use_data_parallel) + self.act_fn = act_fn + + def forward(self, x: torch.Tensor): + mlp_output = self.linear_fc2(self.act_fn(self.linear_fc1(x))) + return mlp_output + + +class Qwen3_VisionBlock(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_hidden_dim: int, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + self.attn = Qwen2_5_VisionAttention( + embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_data_parallel=use_data_parallel) + self.mlp = Qwen3_VisionMLP(dim, + mlp_hidden_dim, + act_fn=act_fn, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel) + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers + ) -> torch.Tensor: + x = x + self.attn(self.norm1(x), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens) + + x = x + self.mlp(self.norm2(x)) + return x + + +class Qwen3_VisionPatchMerger(nn.Module): + + def __init__( + self, + d_model: int, + context_dim: int, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + spatial_merge_size: int = 2, + use_postshuffle_norm: bool = False, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + + self.use_postshuffle_norm = use_postshuffle_norm + if self.use_postshuffle_norm: + context_dim = self.hidden_size + + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.use_postshuffle_norm = use_postshuffle_norm + self.norm = norm_layer( + self.hidden_size if use_postshuffle_norm else context_dim) + self.linear_fc1 = ColumnParallelLinear(self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.linear_fc1", + disable_tp=use_data_parallel) + self.act_fn = nn.GELU() + self.linear_fc2 = RowParallelLinear(self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.linear_fc2", + disable_tp=use_data_parallel) + + def forward(self, x: torch.Tensor) 
-> torch.Tensor: + if self.use_postshuffle_norm: + x = self.norm(x.view(-1, self.hidden_size)) + else: + x = self.norm(x).view(-1, self.hidden_size) + + x_parallel, _ = self.linear_fc1(x) + x_parallel = self.act_fn(x_parallel) + out, _ = self.linear_fc2(x_parallel) + return out + + +class Qwen3_VisionTransformer(nn.Module): + + def __init__( + self, + vision_config: Qwen3VLVisionConfig, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ) -> None: + super().__init__() + self.hidden_size = vision_config.hidden_size + self.num_heads = vision_config.num_heads + self.num_position_embeddings = vision_config.num_position_embeddings + self.patch_size = vision_config.patch_size + self.spatial_merge_size = vision_config.spatial_merge_size + self.spatial_merge_unit = self.spatial_merge_size**2 + self.temporal_patch_size = vision_config.temporal_patch_size + self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes + self.use_data_parallel = use_data_parallel + + # NOTE: This is used for creating empty tensor for all_gather for + # DP ViT. Here out_hidden_size is enlarged due to deepstack + self.out_hidden_size = (vision_config.out_hidden_size * + (1 + len(self.deepstack_visual_indexes))) + + self.patch_embed = Qwen3_VisionPatchEmbed( + patch_size=self.patch_size, + temporal_patch_size=self.temporal_patch_size, + in_channels=vision_config.in_channels, + hidden_size=self.hidden_size, + ) + + self.pos_embed = nn.Embedding(self.num_position_embeddings, + self.hidden_size) + + norm_layer = partial(nn.LayerNorm, eps=norm_eps) + head_dim = self.hidden_size // self.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([ + Qwen3_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + use_data_parallel=use_data_parallel) + for layer_idx in range(vision_config.depth) + ]) + + self.merger = Qwen3_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + norm_layer=norm_layer, + spatial_merge_size=self.spatial_merge_size, + quant_config=quant_config, + prefix=f"{prefix}.merger", + use_data_parallel=use_data_parallel, + ) + + self.deepstack_merger_list = nn.ModuleList([ + Qwen3_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + spatial_merge_size=self.spatial_merge_size, + use_postshuffle_norm=True, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.deepstack_merger_list.{layer_idx}", + use_data_parallel=use_data_parallel) + for layer_idx in range(len(self.deepstack_visual_indexes)) + ]) + + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, dtype=torch.get_default_dtype()) + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability( + torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN + + @property + def dtype(self) -> torch.dtype: + return self.patch_embed.proj.weight.dtype + + @property + def device(self) -> torch.device: + return self.patch_embed.proj.weight.device + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + # Support both Tensor and list inputs for DP path + if isinstance(grid_thw, list): + grid_list = grid_thw + max_grid_size = max(max(h, w) for _, h, w in grid_list) + else: + 
grid_list = grid_thw.tolist() + max_grid_size = int(grid_thw[:, 1:].max().item()) + for t, h, w in grid_list: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def fast_pos_embed_interpolate(self, grid_thw): + num_grid_per_side = int(self.num_position_embeddings**0.5) + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + + for t, h, w in grid_thw: + h_idxs = torch.linspace(0, + num_grid_per_side - 1, + h, + dtype=torch.float32) + w_idxs = torch.linspace(0, + num_grid_per_side - 1, + w, + dtype=torch.float32) + + h_idxs_floor = h_idxs.to(torch.long) + w_idxs_floor = w_idxs.to(torch.long) + h_idxs_ceil = torch.clamp(h_idxs.to(torch.long) + 1, + max=num_grid_per_side - 1) + w_idxs_ceil = torch.clamp(w_idxs.to(torch.long) + 1, + max=num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + idx_list[0].extend(((h_idxs_floor * num_grid_per_side)[None].T + + w_idxs_floor[None]).flatten().tolist() * t) + idx_list[1].extend(((h_idxs_floor * num_grid_per_side)[None].T + + w_idxs_ceil[None]).flatten().tolist() * t) + idx_list[2].extend(((h_idxs_ceil * num_grid_per_side)[None].T + + w_idxs_floor[None]).flatten().tolist() * t) + idx_list[3].extend(((h_idxs_ceil * num_grid_per_side)[None].T + + w_idxs_ceil[None]).flatten().tolist() * t) + + weight_list[0].extend( + ((1 - dh)[None].T * (1 - dw)[None]).flatten().tolist() * t) + weight_list[1].extend( + ((1 - dh)[None].T * dw[None]).flatten().tolist() * t) + weight_list[2].extend( + (dh[None].T * (1 - dw)[None]).flatten().tolist() * t) + weight_list[3].extend( + (dh[None].T * dw[None]).flatten().tolist() * t) + + device = self.pos_embed.weight.device + dtype = self.pos_embed.weight.dtype + + p0 = self.pos_embed( + torch.tensor( + idx_list[0], dtype=torch.long, device=device)) * torch.tensor( + weight_list[0], dtype=dtype, device=device)[:, None] + p1 = self.pos_embed( + torch.tensor( + idx_list[1], dtype=torch.long, device=device)) * torch.tensor( + weight_list[1], dtype=dtype, device=device)[:, None] + p2 = self.pos_embed( + torch.tensor( + idx_list[2], dtype=torch.long, device=device)) * torch.tensor( + weight_list[2], dtype=dtype, device=device)[:, None] + p3 = self.pos_embed( + torch.tensor( + idx_list[3], dtype=torch.long, device=device)) * torch.tensor( + weight_list[3], dtype=dtype, device=device)[:, None] + + patch_pos_embeds = p0 + p1 + p2 + p3 + patch_pos_embeds = patch_pos_embeds.split( + [t * h * w for t, h, w in grid_thw]) + patch_pos_embeds_permute = [] + m_size = self.spatial_merge_size + for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_thw): + pos_embed = pos_embed.view(t, h // m_size, m_size, w // m_size, + m_size, -1).permute(0, 1, 3, 2, 4, + 5).flatten(0, 4) + patch_pos_embeds_permute.append(pos_embed) + 
patch_pos_embeds = torch.cat(patch_pos_embeds_permute) + return patch_pos_embeds + + def compute_attn_mask_seqlen( + self, + cu_seqlens: torch.Tensor, + ) -> tuple[Optional[int], Optional[list[int]]]: + max_seqlen, seqlens = None, None + if self.attn_backend == _Backend.FLASH_ATTN: + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + elif self.attn_backend == _Backend.XFORMERS: + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + return max_seqlen, seqlens + + def forward( + self, + x: torch.Tensor, + grid_thw: list[list[int]], + ) -> torch.Tensor: + hidden_states = x.to(device=self.device, dtype=self.dtype) + hidden_states = self.patch_embed(hidden_states) + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + if isinstance(grid_thw, list): + grid_thw_tensor = torch.tensor(grid_thw, + device=hidden_states.device, + dtype=torch.int32) + else: + grid_thw_tensor = grid_thw + + cu_seqlens = torch.repeat_interleave( + grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], + grid_thw_tensor[:, 0]).cumsum( + dim=0, + dtype=grid_thw_tensor.dtype + if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + hidden_states = hidden_states.unsqueeze(1) + rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) + max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk(hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens) + if layer_num in self.deepstack_visual_indexes: + deepstack_merger_idx = self.deepstack_visual_indexes.index( + layer_num) + deepstack_feature = self.deepstack_merger_list[ + deepstack_merger_idx](hidden_states) + deepstack_feature_lists.append(deepstack_feature) + hidden_states = self.merger(hidden_states) + hidden_states = torch.cat( + [hidden_states] + deepstack_feature_lists, + dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)] + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("attn.qkv.", "attn.q.", "q"), + ("attn.qkv.", "attn.k.", "k"), + ("attn.qkv.", "attn.v.", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(Qwen3VLConfig) + + def get_hf_processor(self, **kwargs: object) -> Qwen3VLProcessor: + return self.ctx.get_hf_processor( + Qwen3VLProcessor, + use_fast=kwargs.pop("use_fast", True), + **kwargs, + ) + + def get_tokenizer(self): + return self.ctx.tokenizer + + def get_image_processor(self, + **kwargs: object) -> Qwen2VLImageProcessorFast: + return self.get_hf_processor(**kwargs).image_processor + + def get_video_processor(self, 
**kwargs: object) -> Qwen3VLVideoProcessor: + return self.get_hf_processor(**kwargs).video_processor + + def _get_vision_info( + self, + *, + image_width: int, + image_height: int, + num_frames: int = 2, + do_resize: bool = True, + image_processor: Optional[Qwen2VLImageProcessorFast], + ) -> tuple[ImageSize, int]: + if image_processor is None: + image_processor = self.get_image_processor() + + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + patch_size = vision_config.patch_size + merge_size = vision_config.spatial_merge_size + temporal_patch_size = vision_config.temporal_patch_size + + if do_resize: + resized_height, resized_width = smart_resize( + height=image_height, + width=image_width, + factor=patch_size * merge_size, + min_pixels=image_processor.size["shortest_edge"], + max_pixels=image_processor.size["longest_edge"], + ) + preprocessed_size = ImageSize(width=resized_width, + height=resized_height) + else: + preprocessed_size = ImageSize(width=image_width, + height=image_height) + + padded_num_frames = num_frames + num_frames % temporal_patch_size + + grid_t = max(padded_num_frames // temporal_patch_size, 1) + grid_h = preprocessed_size.height // patch_size + grid_w = preprocessed_size.width // patch_size + + num_patches = grid_t * grid_h * grid_w + num_vision_tokens = num_patches // (merge_size**2) + + return preprocessed_size, num_vision_tokens + + def _calculate_timestamps(self, indices: list[int] | torch.Tensor, + video_fps: float, merge_size: int): + if not isinstance(indices, list): + indices = indices.tolist() + if len(indices) % merge_size != 0: + # don't update metadata's frames_indices directly + indices = indices + [indices[-1] + ] * (merge_size - len(indices) % merge_size) + timestamps = [idx / video_fps for idx in indices] + timestamps = [(timestamps[i] + timestamps[i + merge_size - 1]) / 2 + for i in range(0, len(timestamps), merge_size)] + return timestamps + + def _get_video_second_idx( + self, + metadata: dict[str, Any], + out_item: MultiModalKwargsItem, + do_sample_frames: Optional[bool] = None, + sampled_fps: Optional[float] = None) -> list[int]: + video_processor = self.get_video_processor() + merge_size = video_processor.merge_size + indices = metadata["frames_indices"] + + # metadata["fps"] refers to the true fps of the input video. + video_fps = metadata["fps"] + if do_sample_frames is None: + do_sample_frames = metadata.get("do_sample_frames", False) + + # If video frames are sampled in HF processor (instead of vLLM + # video loader), we need to re-calculate the indices from original + # metadata. + if do_sample_frames: + # here video_fps is the fps of the sampled video, and + # metadata["fps"] refers to the fps of the original video. 
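+ # Descriptive note (added): re-derive the sampled frame indices from the original metadata by rescaling the frame count to the target fps, clamping it to the processor's min/max frame limits, and taking evenly spaced indices over the original video.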
+ video_fps = sampled_fps if sampled_fps else video_processor.fps + total_num_frames = metadata["total_num_frames"] + num_frames = int(total_num_frames / metadata["fps"] * video_fps) + num_frames = min( + min(max(num_frames, video_processor.min_frames), + video_processor.max_frames), total_num_frames) + indices = np.linspace(0, total_num_frames - 1, + num_frames).round().astype(int).tolist() + timestamps = self._calculate_timestamps(indices, video_fps, merge_size) + return timestamps + + +class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + image_token = "<|vision_start|><|image_pad|><|vision_end|>" + video_token = "<|vision_start|><|video_pad|><|vision_end|>" + + return image_token * num_images + video_token * num_videos + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + target_width, target_height = ( + self.info.get_image_size_with_most_features()) + target_num_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts) + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + "video": + self._get_dummy_videos( + width=target_width, + height=target_height, + num_frames=target_num_frames, + num_videos=num_videos, + ), + } + + def _get_dummy_videos( + self, + *, + width: int, + height: int, + num_frames: int, + num_videos: int, + ) -> list[VideoItem]: + num_frames = max(num_frames, 2) + video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8) + video_items = [] + for i in range(num_videos): + video_metadata = { + "fps": 2.0, + "duration": num_frames / 2.0, + "total_num_frames": num_frames, + "frames_indices": [i for i in range(num_frames)], + "video_backend": "opencv", + "do_sample_frames": False, + } + video_item = (video.copy(), video_metadata) + video_items.append(video_item) + return video_items + + +class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo] + ): + + def _get_data_parser(self) -> MultiModalDataParser: + return MultiModalDataParser(video_needs_metadata=True) + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + mm_data = dict(mm_data) + processor = self.info.get_hf_processor(**mm_kwargs) + + # Separate video processing from image processing, because the videos + # are processed into several image patches + if ("videos" in mm_data and isinstance(mm_data["videos"], list) + and len(mm_data["videos"]) > 0): + video_grid_thw_lst = [] + pixel_values_videos_lst = [] + + for item_idx, item in enumerate(mm_data.pop("videos", [])): + video_array, metadata = item + + # NOTE: @JJJYmmm new attr metadata.frames_indices indicates + # the sampled frames indices of pre-sampled videos, which is + # used to calculate the timestamps. Make sure that + # do_sample_frames in mm_kwargs is false for presampled videos. + + # NOTE: a copy is created to update do_sample_frames, + # otherwise mm_hash for the object will be incorrect. + video_mm_kwargs = dict(**mm_kwargs) + if "do_sample_frames" not in video_mm_kwargs: + # qwen_vl_utils already has "do_sample_frames" in + # mm_kwargs, don't overwrite it.
+ video_mm_kwargs["do_sample_frames"] = metadata.get( + "do_sample_frames", False) + + metadata = VideoMetadata(**{ + k: metadata[k] + for k in metadata if k != "do_sample_frames" + }) + + video_mm_data = dict() + video_mm_data["videos"] = [[video_array]] + video_mm_data["video_metadata"] = [[metadata]] + + video_outputs = super()._call_hf_processor( + prompt="<|vision_start|><|video_pad|><|vision_end|>", + mm_data=video_mm_data, + mm_kwargs=video_mm_kwargs, + tok_kwargs=tok_kwargs, + ) + input_ids = video_outputs.pop("input_ids") + video_placeholder = processor.tokenizer.batch_decode( + input_ids)[0] + prompt = prompt.replace( + "<|vision_start|><|video_pad|><|vision_end|>", + video_placeholder, + 1, + ) + + video_grid_thw_lst.append(video_outputs["video_grid_thw"]) + pixel_values_videos_lst.append( + video_outputs["pixel_values_videos"]) + video_outputs = dict( + pixel_values_videos=torch.cat(pixel_values_videos_lst), + video_grid_thw=torch.cat(video_grid_thw_lst), + ) + else: + video_outputs = dict() + + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + combined_outputs = dict( + processed_outputs, + **video_outputs, + ) + return BatchFeature(combined_outputs) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_grid_sizes = image_grid_thw.prod(-1) + + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_grid_sizes = video_grid_thw.prod(-1) + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_embeds=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_embeds=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_grid_thw=MultiModalFieldConfig.batched("video"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor( + **hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + hf_config = self.info.get_hf_config() + + video_token_id = hf_config.video_token_id + vision_start_token_id = hf_config.vision_start_token_id + vision_end_token_id = hf_config.vision_end_token_id + + merge_length = image_processor.merge_size**2 + + def get_image_replacement_qwen3vl(item_idx: int): + out_item = out_mm_kwargs["image"][item_idx] + grid_thw = out_item["image_grid_thw"].data + assert isinstance(grid_thw, torch.Tensor) + + num_tokens = int(grid_thw.prod()) // merge_length + return [hf_processor.image_token_id] * num_tokens + + def get_video_replacement_qwen3vl(item_idx: int): + out_item = out_mm_kwargs["video"][item_idx] + grid_thw = out_item["video_grid_thw"].data + assert isinstance(grid_thw, torch.Tensor) + + video, metadata = mm_items["video"][item_idx] + do_sample_frames = hf_processor_mm_kwargs.get("do_sample_frames") + sampled_fps = hf_processor_mm_kwargs.get("fps") + if is_list_of(sampled_fps, float): + sampled_fps = sampled_fps[item_idx] + timestamps = self.info._get_video_second_idx( + metadata, 
out_item, do_sample_frames, sampled_fps) + + assert len(timestamps) == grid_thw[0], ( + f"The timestamps length({len(timestamps)}) should be equal " + f"video length ({grid_thw[0]}).") + + frames_idx_token = [ + tokenizer.encode(f"<{curr_time:.1f} seconds>", + add_special_tokens=False) + for curr_time in timestamps + ] + num_tokens_per_frame = int(grid_thw[1:].prod()) // merge_length + placeholder = [] + for frame_idx in frames_idx_token: + placeholder.extend(frame_idx) + placeholder.extend([vision_start_token_id] + + [video_token_id] * num_tokens_per_frame + + [vision_end_token_id]) + return PromptUpdateDetails.select_token_id(placeholder, + video_token_id) + + return [ + PromptReplacement( + modality="image", + target=hf_processor.image_token, + replacement=get_image_replacement_qwen3vl, + ), + + # NOTE: We match string on purpose since searching sequence of + # token ids takes more time. + PromptReplacement( + modality="video", + target="<|vision_start|><|video_pad|><|vision_end|>", + replacement=get_video_replacement_qwen3vl, + ), + ] + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, + # otherwise (seq_len, ). + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + # the same shape as input_embeds + "deepstack_input_embeds": 0 + }) +class Qwen3LLMModel(Qwen3Model): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + if not get_pp_group().is_first_rank: + assert self.start_layer >= len( + vllm_config.model_config.hf_config.vision_config. + deepstack_visual_indexes), ( + "start_layer should be greater than or equal to " + "len(deepstack_visual_indexes)") + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + # args for deepstack + deepstack_input_embeds: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer_idx, layer in enumerate( + self.layers[self.start_layer:self.end_layer]): + layer_idx = layer_idx + self.start_layer + + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + + if deepstack_input_embeds is not None and \ + layer_idx in range(0, len(deepstack_input_embeds)): + hidden_states = hidden_states + deepstack_input_embeds[ + f"deepstack_input_embeds_{layer_idx}"] + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Qwen3LLMForCausalLM(Qwen3ForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super(Qwen3ForCausalLM, self).__init__() + config = vllm_config.model_config.hf_config.text_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Qwen3LLMModel(vllm_config=vllm_config, prefix=prefix) + + if get_pp_group().is_last_rank: 
+ if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix="lm_head") + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + +@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor, + info=Qwen3VLProcessingInfo, + dummy_inputs=Qwen3VLDummyInputsBuilder) +class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + supports_encoder_tp_data = True + + # To ensure correct weight loading and mapping. + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.visual.": "visual.", + "lm_head.": "language_model.lm_head.", + "model.language_model.": "language_model.model.", + }) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "<|vision_start|><|image_pad|><|vision_end|>" + if modality.startswith("video"): + return "<|vision_start|><|video_pad|><|vision_end|>" + + raise ValueError("Only image or video modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): + super().__init__() + config: Qwen3VLConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + + self.visual = Qwen3_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=self._maybe_ignore_quant_config(quant_config), + prefix=maybe_prefix(prefix, "visual"), + use_data_parallel=self.use_data_parallel, + ) + + self.language_model = Qwen3LLMForCausalLM(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, + "language_model")) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + self.use_deepstack = hasattr(config.vision_config, + 'deepstack_visual_indexes') + self.deepstack_num_level = len( + config.vision_config.deepstack_visual_indexes + ) if self.use_deepstack else 0 + # register buffer for deepstack + self.deepstack_input_embeds = [ + torch.zeros(vllm_config.scheduler_config.max_num_batched_tokens, + config.text_config.hidden_size) + for _ in range(self.deepstack_num_level) + ] if self.use_deepstack else None + + def _get_deepstack_input_embeds(self, + num_tokens: int) -> IntermediateTensors: + # get deepstack_input_embeds from buffer, and clear the buffer + return IntermediateTensors({ + f"deepstack_input_embeds_{idx}": + self.deepstack_input_embeds[idx][:num_tokens] + for idx in range(self.deepstack_num_level) + }) + + def _set_deepstack_input_embeds( + self, deepstack_input_embeds: torch.Tensor) -> None: + # set deepstack_input_embeds to buffer + num_tokens = deepstack_input_embeds.size(1) + if num_tokens > self.deepstack_input_embeds[0].size(0): + self.deepstack_input_embeds = [ + torch.zeros(num_tokens, + self.config.text_config.hidden_size, + device=self.deepstack_input_embeds[0].device, + dtype=self.deepstack_input_embeds[0].dtype) + for _ in range(self.deepstack_num_level) + ] + for idx in 
range(self.deepstack_num_level): + self.deepstack_input_embeds[idx][:num_tokens].copy_( + deepstack_input_embeds[idx]) + + def _clear_deepstack_input_embeds(self, num_tokens: int) -> None: + # clear deepstack_input_embeds in buffer + if num_tokens > 0: + for idx in range(self.deepstack_num_level): + self.deepstack_input_embeds[idx][:num_tokens].zero_() + + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + # GPTQ configs do not have a list of ignored modules, however AutoGPTQ + # seems to avoid vision encoder sections for some models. + if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): + return None + return quant_config + + def _validate_and_reshape_mm_tensor(self, mm_input: object, + name: str) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. " + f"Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + if mm_input.ndim == 2: + return mm_input + if mm_input.ndim != 3: + raise ValueError(f"{name} should be 2D or batched 3D tensor. " + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})") + return torch.concat(list(mm_input)) + else: + return torch.concat(mm_input) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Qwen2_5_VLImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + pixel_values = self._validate_and_reshape_mm_tensor( + pixel_values, "image pixel values") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of image pixel values. " + f"Got type: {type(pixel_values)}") + + return Qwen2_5_VLImagePixelInputs(type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw) + + if image_embeds is not None: + image_embeds = self._validate_and_reshape_mm_tensor( + image_embeds, "image embeds") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(image_embeds, torch.Tensor): + raise ValueError("Incorrect type of image embeddings. 
" + f"Got type: {type(image_embeds)}") + return Qwen2_5_VLImageEmbeddingInputs( + type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw) + + def _parse_and_validate_video_input( + self, **kwargs: object) -> Optional[Qwen2_5_VLVideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_embeds = kwargs.pop("video_embeds", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + second_per_grid_ts = kwargs.pop("second_per_grid_ts", None) + + if pixel_values_videos is None and video_embeds is None: + return None + + if pixel_values_videos is not None: + pixel_values_videos = self._validate_and_reshape_mm_tensor( + pixel_values_videos, "video pixel values") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + return Qwen2_5_VLVideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + ) + + if video_embeds is not None: + video_embeds = self._validate_and_reshape_mm_tensor( + video_embeds, "video embeds") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + if not isinstance(video_embeds, torch.Tensor): + raise ValueError("Incorrect type of video embeddings. " + f"Got type: {type(video_embeds)}") + return Qwen2_5_VLVideoEmbeddingInputs( + type="video_embeds", + video_embeds=video_embeds, + video_grid_thw=video_grid_thw) + + def _process_image_input( + self, + image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]: + + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() + + if image_input["type"] == "image_embeds": + image_embeds = image_input["image_embeds"].type(self.visual.dtype) + else: + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + if self.use_data_parallel: + from vllm.multimodal.utils import ( + run_dp_sharded_mrope_vision_model) + return run_dp_sharded_mrope_vision_model(self.visual, + pixel_values, + grid_thw_list, + rope_type="rope_3d") + else: + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + + # Split concatenated embeddings for each image item. + # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync + merge_size = self.visual.spatial_merge_size + sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // + (merge_size * merge_size)).tolist() + return image_embeds.split(sizes) + + def _process_video_input( + self, + video_input: Qwen2_5_VLVideoInputs) -> tuple[torch.Tensor, ...]: + + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() + + if video_input["type"] == "video_embeds": + video_embeds = video_input["video_embeds"].type(self.visual.dtype) + else: + pixel_values_videos = video_input["pixel_values_videos"].type( + self.visual.dtype) + if self.use_data_parallel: + from vllm.multimodal.utils import ( + run_dp_sharded_mrope_vision_model) + return run_dp_sharded_mrope_vision_model(self.visual, + pixel_values_videos, + grid_thw_list, + rope_type="rope_3d") + else: + video_embeds = self.visual(pixel_values_videos, + grid_thw=grid_thw) + + # Split concatenated embeddings for each video item. 
+ # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync + merge_size = self.visual.spatial_merge_size + sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // + (merge_size * merge_size)).tolist() + return video_embeds.split(sizes) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + mm_input_by_modality = {} + for input_key in kwargs: + if input_key in ("pixel_values", "image_embeds" + ) and "image" not in mm_input_by_modality: + mm_input_by_modality[ + "image"] = self._parse_and_validate_image_input(**kwargs) + if input_key in ("pixel_values_videos", "video_embeds" + ) and "video" not in mm_input_by_modality: + mm_input_by_modality[ + "video"] = self._parse_and_validate_video_input(**kwargs) + return mm_input_by_modality + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + + mm_input_by_modality = self._parse_and_validate_multimodal_inputs( + **kwargs) + if not mm_input_by_modality: + return None + + # The result multimodal_embeddings is a tuple of tensors, with each + # tensor corresponding to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in mm_input_by_modality: + multimodal_input = mm_input_by_modality[modality] + if modality == "image": + vision_embeddings = self._process_image_input(multimodal_input) + multimodal_embeddings += vision_embeddings + if modality == "video": + video_embeddings = self._process_video_input(multimodal_input) + multimodal_embeddings += video_embeddings + return multimodal_embeddings + + def _compute_deepstack_embeds( + self, input_ids: torch.Tensor, inputs_embeds: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings) -> torch.Tensor: + visual_lens = [ + x.shape[0] if isinstance(x, torch.Tensor) else len(x) + for x in multimodal_embeddings + ] + multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0) + + visual_dim = multimodal_embeddings_cat.shape[-1] // ( + self.deepstack_num_level + 1) + + main_dim, multi_dim = visual_dim, visual_dim * self.deepstack_num_level + multimodal_embeddings_main, multimodal_embeddings_multiscale = torch.split( # noqa:E501 + multimodal_embeddings_cat, [main_dim, multi_dim], + dim=-1) + + multimodal_embeddings = torch.split(multimodal_embeddings_main, + visual_lens, + dim=0) + multimodal_embeddings_multiscale = torch.split( + multimodal_embeddings_multiscale, visual_lens, dim=0) + + deepstack_input_embeds = inputs_embeds.new_zeros( + inputs_embeds.size(0), + self.deepstack_num_level * inputs_embeds.size(1)) + + deepstack_input_embeds = merge_multimodal_embeddings( + input_ids, + deepstack_input_embeds, + multimodal_embeddings_multiscale, + placeholder_token_id=[ + self.config.image_token_id, self.config.video_token_id + ], + ) + deepstack_input_embeds = deepstack_input_embeds.view( + inputs_embeds.shape[0], self.deepstack_num_level, + visual_dim).contiguous() + deepstack_input_embeds = deepstack_input_embeds.permute( + 1, 0, 2).contiguous() + return deepstack_input_embeds, multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + deepstack_input_embeds = None + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if
multimodal_embeddings is not None and self.use_deepstack: + deepstack_input_embeds, multimodal_embeddings = self._compute_deepstack_embeds( # noqa:E501 + input_ids, inputs_embeds, multimodal_embeddings) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + [self.config.image_token_id, self.config.video_token_id]) + + if self.use_deepstack: + if deepstack_input_embeds is None: + deepstack_input_embeds = torch.zeros_like( + inputs_embeds).unsqueeze(0).repeat( + self.deepstack_num_level, 1, 1).contiguous() + self._set_deepstack_input_embeds(deepstack_input_embeds) + + return inputs_embeds + + def get_input_embeddings_v0( + self, + input_ids: torch.Tensor, + image_input: Optional[Qwen2_5_VLImageInputs] = None, + video_input: Optional[Qwen2_5_VLVideoInputs] = None, + ) -> torch.Tensor: + inputs_embeds = self.get_input_embeddings(input_ids) + + if self.use_deepstack: + visual_dim = inputs_embeds.shape[-1] + deepstack_input_embeds = None + if image_input is not None or video_input is not None: + deepstack_input_embeds = torch.zeros_like( + inputs_embeds).unsqueeze(1).repeat( + 1, self.deepstack_num_level, 1).flatten(1) + + if image_input is not None: + image_embeds = self._process_image_input(image_input) + if self.use_deepstack: + image_embeds = torch.cat(image_embeds) + + image_embeds, image_embeds_multiscale = image_embeds.split( + [visual_dim, visual_dim * self.deepstack_num_level], + dim=-1) + + deepstack_input_embeds = merge_multimodal_embeddings( + input_ids, + deepstack_input_embeds, + image_embeds_multiscale, + placeholder_token_id=self.config.image_token_id, + ) + + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + image_embeds, + placeholder_token_id=self.config.image_token_id, + ) + + if video_input is not None: + video_embeds = self._process_video_input(video_input) + if self.use_deepstack: + video_embeds = torch.cat(video_embeds) + + video_embeds, video_embeds_multiscale = video_embeds.split( + [visual_dim, visual_dim * self.deepstack_num_level], + dim=-1) + + deepstack_input_embeds = merge_multimodal_embeddings( + input_ids, + deepstack_input_embeds, + video_embeds_multiscale, + placeholder_token_id=self.config.video_token_id, + ) + + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + video_embeds, + placeholder_token_id=self.config.video_token_id, + ) + + if self.use_deepstack and deepstack_input_embeds is not None: + deepstack_input_embeds = deepstack_input_embeds.view( + inputs_embeds.shape[0], self.deepstack_num_level, + visual_dim).permute(1, 0, 2).contiguous() + self._set_deepstack_input_embeds(deepstack_input_embeds) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + """Run forward pass for Qwen3VL. + + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + positions: Flattened (concatenated) position ids corresponding to a + batch. + **NOTE**: If mrope is enabled (default setting for Qwen3VL + opensource models), the shape will be `(3, seq_len)`, + otherwise it will be `(seq_len,). + pixel_values: Pixel values to be fed to a model. + `None` if no images are passed. + image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. + `None` if no images are passed. 
+ pixel_values_videos: Pixel values of videos to be fed to a model. + `None` if no videos are passed. + video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. + `None` if no videos are passed. + """ + + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner from + # `get_multimodal_embeddings` and `get_input_embeddings`, this + # condition is only for v0 compatibility. + elif inputs_embeds is None: + image_input = self._parse_and_validate_image_input(**kwargs) + video_input = self._parse_and_validate_video_input(**kwargs) + + if image_input is None and video_input is None: + inputs_embeds = None + else: + if uses_mrope(self.config): + assert positions.ndim == 2 and positions.size(0) == 3, ( + "multimodal section rotary embedding requires " + f"(3, seq_len) positions, but got {positions.size()}") + inputs_embeds = self.get_input_embeddings_v0( + input_ids, + image_input=image_input, + video_input=video_input) + input_ids = None + + if self.use_deepstack and inputs_embeds is not None and get_pp_group( + ).is_first_rank: + deepstack_input_embeds = self._get_deepstack_input_embeds( + inputs_embeds.size(0)) + else: + deepstack_input_embeds = None + + hidden_states = self.language_model.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + # args for deepstack + deepstack_input_embeds=deepstack_input_embeds, + ) + + if inputs_embeds is not None and get_pp_group().is_first_rank: + self._clear_deepstack_input_embeds(inputs_embeds.size(0)) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="model.visual.merger", + tower_model="model.visual.", + ) diff --git a/vllm/model_executor/models/qwen3_vl_moe.py b/vllm/model_executor/models/qwen3_vl_moe.py new file mode 100644 index 0000000000000..d25bc71dcb59b --- /dev/null +++ b/vllm/model_executor/models/qwen3_vl_moe.py @@ -0,0 +1,346 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The vLLM team. +# Copyright 2025 The Qwen Team. +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen3-VL-MoE model compatible with HuggingFace weights.""" +import typing +from collections.abc import Iterable +from typing import Callable, Optional, Union + +import torch +from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import ( + Qwen3VLMoeConfig) + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sequence import IntermediateTensors + +from .qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel +from .qwen3_vl import (Qwen3_VisionTransformer, Qwen3VLDummyInputsBuilder, + Qwen3VLForConditionalGeneration, + Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo) +from .utils import is_pp_missing_parameter, maybe_prefix + +logger = init_logger(__name__) + + +class Qwen3VLMoeProcessingInfo(Qwen3VLProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(Qwen3VLMoeConfig) + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, + # otherwise (seq_len, ). + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + # the same shape as input_embeds + "deepstack_input_embeds": 0 + }) +class Qwen3MoeLLMModel(Qwen3MoeModel): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + if not get_pp_group().is_first_rank: + assert self.start_layer >= len( + vllm_config.model_config.hf_config.vision_config. 
+ deepstack_visual_indexes), ( + "start_layer should be greater than or equal to " + "len(deepstack_visual_indexes)") + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + deepstack_input_embeds: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer_idx, layer in enumerate( + self.layers[self.start_layer:self.end_layer]): + layer_idx = layer_idx + self.start_layer + + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + + if deepstack_input_embeds is not None and \ + layer_idx in range(0, len(deepstack_input_embeds)): + hidden_states = hidden_states + deepstack_input_embeds[ + f"deepstack_input_embeds_{layer_idx}"] + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_fused_expert_weights(self, name: str, params_dict: dict, + loaded_weight: torch.Tensor, shard_id: str, + num_experts: int): + param = params_dict[name] + weight_loader = typing.cast(Callable[..., bool], param.weight_loader) + for expert_id in range(num_experts): + curr_expert_weight = loaded_weight[expert_id] + success = weight_loader(param, + curr_expert_weight, + name, + shard_id, + expert_id, + return_success=True) + if not success: + return False + return True + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + # Skip loading extra parameters for GPTQ/modelopt models. + ignore_suffixes = (".bias", "_bias", ".k_scale", "_k_scale", + ".v_scale", "_v_scale", ".weight_scale", + "_weight_scale", ".input_scale", "_input_scale") + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() + is_fused_expert = False + fused_expert_params_mapping = [ + ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"), + ("experts.w2_weight", "experts.down_proj", 0, "w2"), + ] + num_experts = self.config.num_experts + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if ("experts.gate_up_proj" in name + or "experts.down_proj" in name): + is_fused_expert = True + expert_params_mapping = fused_expert_params_mapping + + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. 
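+ # e.g. a checkpoint key for mlp.experts[0].gate_proj is skipped here and picked up by the expert mapping below.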
+ if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra parameters for GPTQ/modelopt models. + if name.endswith(ignore_suffixes) and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if name.endswith("scale"): + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if weight_loader == default_weight_loader: + weight_loader(param, loaded_weight) + else: + weight_loader(param, loaded_weight, shard_id) + break + else: + is_expert_weight = False + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + name_mapped = name.replace(weight_name, param_name) + if is_fused_expert: + loaded_weight = loaded_weight.transpose(-1, + -2) # no bias + if "experts.gate_up_proj" in name: + loaded_weight = loaded_weight.chunk(2, dim=-2) + success_w1 = self.load_fused_expert_weights( + name_mapped, params_dict, loaded_weight[0], + "w1", num_experts) + success_w3 = self.load_fused_expert_weights( + name_mapped, params_dict, loaded_weight[1], + "w3", num_experts) + success = success_w1 and success_w3 + else: + # down_proj + success = self.load_fused_expert_weights( + name_mapped, params_dict, loaded_weight, + shard_id, num_experts) + else: + if is_pp_missing_parameter(name_mapped, self): + continue + # Skip loading extra parameters for GPTQ/modelopt models + if name_mapped.endswith( + ignore_suffixes + ) and name_mapped not in params_dict: + continue + param = params_dict[name_mapped] + # We should ask the weight loader to return success or + # not here since otherwise we may skip experts with + # other available replicas. + weight_loader = typing.cast(Callable[..., bool], + param.weight_loader) + success = weight_loader(param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True) + if success: + name = name_mapped + break + else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + # Skip loading extra parameters for GPTQ/modelopt models. + if name.endswith( + ignore_suffixes) and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + logger.warning_once( + "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). 
kv-scale is not loaded.", # noqa: E501 + name, + remapped_kv_scale_name, + ) + continue + else: + name = remapped_kv_scale_name + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super(Qwen3MoeForCausalLM, self).__init__() + self.config = vllm_config.model_config.hf_config.text_config + self.quant_config = vllm_config.quant_config + self.model = Qwen3MoeLLMModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size, + quant_config=self.quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(self.config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + +@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor, + info=Qwen3VLMoeProcessingInfo, + dummy_inputs=Qwen3VLDummyInputsBuilder) +class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super(Qwen3VLForConditionalGeneration, self).__init__() + config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + + self.visual = Qwen3_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=self._maybe_ignore_quant_config(quant_config), + prefix=maybe_prefix(prefix, "visual"), + use_data_parallel=self.use_data_parallel, + ) + + self.language_model = Qwen3MoeLLMForCausalLM(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, + "language_model")) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + self.use_deepstack = hasattr(config.vision_config, + 'deepstack_visual_indexes') + self.deepstack_num_level = len( + config.vision_config.deepstack_visual_indexes + ) if self.use_deepstack else 0 + # register buffer for deepstack + self.deepstack_input_embeds = [ + torch.zeros(vllm_config.scheduler_config.max_num_batched_tokens, + config.text_config.hidden_size) + for _ in range(self.deepstack_num_level) + ] if self.use_deepstack else None diff --git a/vllm/model_executor/models/radio.py b/vllm/model_executor/models/radio.py new file mode 100644 index 0000000000000..9cbf844ae9f86 --- /dev/null +++ b/vllm/model_executor/models/radio.py @@ -0,0 +1,576 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
+ +import math +from collections.abc import Iterable +from itertools import repeat +from typing import Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from transformers import PretrainedConfig + +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.intern_vit import InternVisionEncoder + +input_dim_t = Union[int, tuple[int, int]] +norm_t = Union[tuple[float, float, float], torch.Tensor] + + +def _ntuple(n): + + def parse(x): + if isinstance(x, Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + +class InputConditioner(nn.Module): + + def __init__( + self, + input_scale: float, + norm_mean: norm_t, + norm_std: norm_t, + dtype: torch.dtype = None, + ): + super().__init__() + + self.dtype = dtype + + self.register_buffer("norm_mean", _to_tensor(norm_mean) / input_scale) + self.register_buffer("norm_std", _to_tensor(norm_std) / input_scale) + + def forward(self, x: torch.Tensor): + y = (x - self.norm_mean) / self.norm_std + if self.dtype is not None: + y = y.to(self.dtype) + return y + + +def _to_tensor(v: norm_t): + return torch.as_tensor(v, dtype=torch.float32).view(-1, 1, 1) + + +class ClsToken(nn.Module): + + def __init__( + self, + ndim: int, + num_tokens: int = 1, + enabled: bool = True, + register_multiple: Optional[int] = None, + num_registers: Optional[int] = None, + ): + super().__init__() + + self.ndim = ndim + self.enabled = enabled + self.num_registers = 0 + self.num_tokens = num_tokens + if enabled: + if num_registers: + self.num_registers = num_registers + elif register_multiple: + self.num_registers = register_multiple - (num_tokens % + register_multiple) + + scale = ndim**-0.5 + self.token = nn.Parameter( + torch.randn(num_tokens + self.num_registers, ndim) * scale) + + else: + self.token = None + + self.num_patches = self.num_tokens + self.num_registers + + def forward(self, x: torch.Tensor): + if self.token is None: + return x + + token = self.token.unsqueeze(0).expand(x.shape[0], -1, -1) + x = torch.cat([ + token, + x, + ], dim=1) + + return x + + +class ViTPatchGenerator(nn.Module): + + def __init__( + self, + # config: PretrainedConfig, + patch_size: int, + embed_dim: int, + input_dims: input_dim_t, + abs_pos: bool = True, + normalize_patches: bool = False, + cls_token: bool = False, + max_input_dims: Optional[input_dim_t] = None, + pos_dropout: float = 0.0, + return_pos_enc: bool = False, + num_cls_tokens: int = 1, + register_multiple: Optional[int] = None, + num_registers: Optional[int] = None, + patch_bias: bool = False, + device=None, + dtype=None, + ): + super().__init__() + if isinstance(input_dims, int): + input_dims = (input_dims, input_dims) + + if max_input_dims is None: + max_input_dims = input_dims + if isinstance(max_input_dims, int): + max_input_dims = (max_input_dims, max_input_dims) + + max_input_dims = tuple( + int(math.ceil(d / patch_size) * patch_size) + for d in max_input_dims) + + self.cpe_mode = max_input_dims != input_dims + self.pos_dropout = pos_dropout + self.return_pos_enc = return_pos_enc + + factory = dict(device=device, dtype=dtype) + + self.patch_size = patch_size + self.abs_pos = abs_pos + self.embed_dim = embed_dim + + self.num_rows = max_input_dims[0] // patch_size + 
self.num_cols = max_input_dims[1] // patch_size + self.input_dims = tuple(d // patch_size for d in input_dims) + self.num_patches = self.num_rows * self.num_cols + self.max_input_dims = max_input_dims + + self.im_to_patches = Im2Patches(patch_size) + self.embedder = ViTPatchLinear(patch_size, + embed_dim, + bias=patch_bias, + **factory) + + if abs_pos: + scale = embed_dim**-0.5 + self.pos_embed = nn.Parameter( + torch.randn(1, self.num_patches, embed_dim, **factory) * scale) + + self.cls_token = ClsToken( + embed_dim, + num_tokens=num_cls_tokens, + enabled=cls_token, + register_multiple=register_multiple, + num_registers=num_registers, + ) + + self.patch_normalizer = nn.LayerNorm( + embed_dim) if normalize_patches else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + patches = self.embed_patches(x) + patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:]) + patches = self.cls_token(patches) + patches = self.patch_normalizer(patches) + if self.return_pos_enc: + return patches, pos_enc + return patches + + @property + def apply_cls_token(self): + return self.cls_token.enabled + + @property + def num_cls_tokens(self): + return self.cls_token.num_tokens + + @property + def num_cls_patches(self): + return self.cls_token.num_patches + + @property + def num_registers(self): + return self.cls_token.num_registers + + @property + def num_skip(self): + return self.num_cls_tokens + self.num_registers + + def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter): + if src_embed.shape != targ_embed.shape: + src_size = int(math.sqrt(src_embed.shape[1])) + + assert src_size**2 == src_embed.shape[ + 1], 'Unable to interpolate non-square embedding' + + src_embed = rearrange(src_embed, + 'b (h w) c -> b c h w', + h=src_size, + w=src_size) + src_embed = F.interpolate(src_embed, + size=(self.num_rows, self.num_cols), + mode='bicubic', + align_corners=True, + antialias=False) + src_embed = rearrange(src_embed, 'b c h w -> b (h w) c') + targ_embed.data.copy_(src_embed) + + def _load_projection(self, src_proj_weight: torch.Tensor, + targ_proj_weight: torch.Tensor): + if src_proj_weight.shape != targ_proj_weight.shape: + src_patch_size = int(math.sqrt(src_proj_weight.shape[1] // 3)) + + assert (src_patch_size**2) * 3 == src_proj_weight.shape[ + 1], 'Unable to interpolate non-square patch size' + + src_proj_weight = rearrange(src_proj_weight, + 'b (c h w) -> b c h w', + c=3, + h=src_patch_size, + w=src_patch_size) + src_proj_weight = F.interpolate(src_proj_weight, + size=(self.patch_size, + self.patch_size), + mode='bicubic', + align_corners=True, + antialias=False) + src_proj_weight = rearrange(src_proj_weight, + 'b c h w -> b (c h w)') + targ_proj_weight.data.copy_(src_proj_weight) + + def embed_patches(self, x: torch.Tensor) -> torch.Tensor: + patches = self.im_to_patches(x) + patches = self.embedder(patches) + return patches + + def apply_pos_enc( + self, + patches: torch.Tensor, + patch_idxs: Optional[torch.Tensor] = None, + input_size: Optional[tuple[int, int]] = None, + ) -> torch.Tensor: + if not self.abs_pos: + return patches + + pos_enc = self.get_pos_enc(patches.shape[0], patch_idxs, input_size) + + if self.training and self.pos_dropout > 0: + keeps = torch.rand(patches.shape[0], + 1, + 1, + dtype=pos_enc.dtype, + device=pos_enc.device) > self.pos_dropout + pos_enc_drop = torch.where(keeps, pos_enc, 0) + else: + pos_enc_drop = pos_enc + + return patches + pos_enc_drop, pos_enc + + def get_pos_enc( + self, + batch_size: int, + patch_idxs: 
Optional[torch.Tensor] = None, + input_size: Optional[tuple[int, int]] = None, + ) -> torch.Tensor: + if input_size is None: + input_dims = self.input_dims + else: + input_dims = tuple(d // self.patch_size for d in input_size) + + pos_embed = self._get_pos_embeddings(batch_size, input_dims) + + if patch_idxs is None: + return pos_embed + + exp_patch_idxs = patch_idxs.unsqueeze(-1).expand( + -1, -1, pos_embed.shape[-1]) + + pos_embed = torch.gather(pos_embed.expand(patch_idxs.shape[0], -1, -1), + dim=1, + index=exp_patch_idxs) + return pos_embed + + def _get_pos_embeddings(self, batch_size: int, input_dims: tuple[int, + int]): + if (self.num_rows, self.num_cols) == input_dims: + return self.pos_embed + + pos_embed = self.pos_embed.reshape(1, self.num_rows, self.num_cols, + -1).permute(0, 3, 1, 2) + + def window_select(pos_embed): + if input_dims[0] < pos_embed.shape[-2]: + pos_embed = pos_embed[..., :input_dims[0], :] + if input_dims[1] < pos_embed.shape[-1]: + pos_embed = pos_embed[..., :, :input_dims[1]] + return pos_embed + + if self.cpe_mode: + if self.training: + min_scale = math.sqrt(0.1) + scale = torch.rand(batch_size, 1, 1, device=pos_embed.device + ) * (1 - min_scale) + min_scale + aspect_min = math.log(3 / 4) + aspect_max = -aspect_min + aspect = torch.exp( + torch.rand(batch_size, 1, 1, device=pos_embed.device) * + (aspect_max - aspect_min) + aspect_min) + + scale_x = scale * aspect + scale_y = scale * (1 / aspect) + scale_xy = torch.stack([scale_x, scale_y], dim=-1).clamp_(0, 1) + + pos_xy = torch.rand( + batch_size, 1, 1, 2, + device=pos_embed.device) * (1 - scale_xy) + + lin_x = torch.linspace( + 0, 1, steps=input_dims[1], + device=pos_embed.device)[None, None].expand( + batch_size, input_dims[0], -1) + lin_y = torch.linspace( + 0, 1, steps=input_dims[0], + device=pos_embed.device)[None, :, None].expand( + batch_size, -1, input_dims[1]) + + lin_xy = torch.stack([lin_x, lin_y], dim=-1) + + grid_xy = lin_xy * scale_xy + pos_xy + + # Convert to [-1, 1] range + grid_xy.mul_(2).sub_(1) + + pos_embed = F.grid_sample( + pos_embed.float().expand(batch_size, -1, -1, -1), + grid=grid_xy, + mode='bilinear', + padding_mode='zeros', + align_corners=True, + ).to(pos_embed.dtype) + else: + max_dim = max(input_dims) + pos_embed = F.interpolate(pos_embed.float(), + size=(max_dim, max_dim), + align_corners=True, + mode='bilinear').to(pos_embed.dtype) + + pos_embed = window_select(pos_embed) + else: + pos_embed = window_select(pos_embed) + + if pos_embed.shape[-2:] != input_dims: + pos_embed = F.interpolate(pos_embed.float(), + size=input_dims, + align_corners=True, + mode='bilinear').to(pos_embed.dtype) + + pos_embed = pos_embed.flatten(2).permute(0, 2, 1) + + return pos_embed + + +class Im2Patches(nn.Module): + + def __init__(self, patch_size: int): + super().__init__() + self.patch_size = patch_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.patch_size == 1: + patches = x.flatten(2) + patches = patches.permute(0, 2, 1) + return patches + + py = x.shape[-2] // self.patch_size + px = x.shape[-1] // self.patch_size + patches = rearrange( + x, + 'b c (py yy) (px xx) -> b (py px) (c yy xx)', + py=py, + yy=self.patch_size, + px=px, + xx=self.patch_size, + ) + return patches + + +class ViTPatchLinear(nn.Linear): + + def __init__(self, + patch_size: int, + embed_dim: int, + bias: bool = False, + **factory): + super().__init__(3 * (patch_size**2), embed_dim, bias=bias, **factory) + self.patch_size = patch_size + + +class RadioInternVisionModel(nn.Module): + 
packed_modules_mapping = { + "qkv": ["qkv"], + } + + def __init__( + self, + config: PretrainedConfig = None, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + num_dummy_heads: int = 0, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + self.img_size, self.grid_size, self.num_patches = self._init_img_size( + to_2tuple(config.patch_size), config.image_size) + max_img_size = int( + round(config.max_img_size / config.patch_size) * config.patch_size) + self.patch_generator = ViTPatchGenerator( + config.patch_size, + config.hidden_size, + input_dims=self.img_size, + max_input_dims=max_img_size, + cls_token=True, + register_multiple=config.reg_tokens) + + self.encoder = InternVisionEncoder( + config=config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + num_dummy_heads=num_dummy_heads, + prefix=f"{prefix}.encoder", + ) + + def _init_img_size(self, patch_size, img_size: Union[int, tuple[int, + int]]): + if img_size is None: + return None, None, None + img_size = to_2tuple(img_size) + grid_size = tuple([s // p for s, p in zip(img_size, patch_size)]) + num_patches = grid_size[0] * grid_size[1] + return img_size, grid_size, num_patches + + def get_input_embeddings(self): + return self.embeddings + + def forward(self, x: torch.Tensor) -> torch.FloatTensor: + assert self.patch_generator is not None + hidden_states = self.patch_generator(x) + encoder_outputs = self.encoder(inputs_embeds=hidden_states) + return encoder_outputs + + +class RadioModel(nn.Module): + packed_modules_mapping = { + "qkv": ["qkv"], + } + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + num_dummy_heads: int = 0, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + self.input_conditioner = InputConditioner( + input_scale=1.0, + norm_mean=config.norm_mean, + norm_std=config.norm_std, + ) + self.model = RadioInternVisionModel( + config=config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + num_dummy_heads=num_dummy_heads, + prefix=prefix) + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + pixel_embeds: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + x = self.input_conditioner(pixel_values) + y = self.model(x) + return self._extract_final(y) + + def load_weights(self, weights) -> set[str]: + loaded_params: set[str] = set() + params_dict = dict(self.named_parameters()) + + if isinstance(weights, dict): + weights_list = list(weights.items()) + else: + weights_list = list(weights) + + for name, weight in weights_list: + if not name.startswith("radio_model."): + # Skip non-radio weights + continue + + sub = name[len("radio_model."):] # drop "radio_model." prefix + + # Skip buffers not used in vLLM + if sub in {"summary_idxs"}: + continue + + vllm_key = None + if sub.startswith("model.patch_generator."): + vllm_key = f"model.patch_generator.{sub.split('.', 2)[-1]}" + elif sub.startswith("input_conditioner."): + vllm_key = f"input_conditioner.{sub.split('.', 1)[-1]}" + elif sub.startswith("model.blocks."): + # Encoder blocks: HF 'model.blocks.{i}.' -> + # vLLM 'model.encoder.layers.{i}.' 
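+ # e.g. 'model.blocks.3.<suffix>' becomes 'model.encoder.layers.3.<suffix>'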
+ parts = sub.split(".") + if len(parts) >= 4: + layer_idx = parts[2] + suffix = ".".join(parts[3:]) + # Skip layer-scale entries that vLLM doesn't use + if suffix in {"ls1", "ls2"} or suffix.startswith( + ("ls1.", "ls2.")): + continue + vllm_key = f"model.encoder.layers.{layer_idx}.{suffix}" + + if vllm_key and vllm_key in params_dict: + param = params_dict[vllm_key] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, weight) + loaded_params.add(vllm_key) + + return loaded_params + + def _extract_final(self, y: torch.Tensor): + # Remove CLS + REGISTERS tokens + patch_gen = getattr(self.model, "patch_generator", None) + if patch_gen is not None: + all_feat = y[:, patch_gen.num_skip:] + + return all_feat diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index f236040bb2341..1382fd9e93ea3 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -52,6 +52,7 @@ _TEXT_GENERATION_MODELS = { # baichuan-13b, lower case 'c' in the class name "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"), + "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"), "BambaForCausalLM": ("bamba", "BambaForCausalLM"), "BloomForCausalLM": ("bloom", "BloomForCausalLM"), "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), @@ -74,6 +75,7 @@ _TEXT_GENERATION_MODELS = { "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"), "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"), + "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"), "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"), @@ -110,7 +112,7 @@ _TEXT_GENERATION_MODELS = { "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), - "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), + "MotifForCausalLM": ("motif", "MotifForCausalLM"), # transformers's mpt class has lower case "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), @@ -119,6 +121,7 @@ _TEXT_GENERATION_MODELS = { "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"), "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"), + "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"), "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"), @@ -144,10 +147,6 @@ _TEXT_GENERATION_MODELS = { "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"), "XverseForCausalLM": ("llama", "LlamaForCausalLM"), "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"), - # [Encoder-decoder] - "BartModel": ("bart", "BartForConditionalGeneration"), - "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), - "MBartForConditionalGeneration": ("bart", "MBartForConditionalGeneration"), } _EMBEDDING_MODELS = { @@ -155,6 +154,7 @@ _EMBEDDING_MODELS = { "BertModel": ("bert", "BertEmbeddingModel"), "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"), "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), + "Gemma3TextModel": ("gemma3", "Gemma3Model"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), "GPT2ForSequenceClassification": ("gpt2", 
"GPT2ForSequenceClassification"), "GritLM": ("gritlm", "GritLM"), @@ -184,14 +184,16 @@ _EMBEDDING_MODELS = { "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 - # Technically PrithviGeoSpatialMAE is a model that works on images, both in - # input and output. I am adding it here because it piggybacks on embedding + # Technically Terratorch models work on images, both in + # input and output. I am adding it here because it piggy-backs on embedding # models for the time being. - "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"), + "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"), + "Terratorch": ("terratorch", "Terratorch"), } _CROSS_ENCODER_MODELS = { "BertForSequenceClassification": ("bert", "BertForSequenceClassification"), + "BertForTokenClassification": ("bert", "BertForTokenClassification"), "GteNewForSequenceClassification": ("bert_with_rope", "GteNewForSequenceClassification"), "ModernBertForSequenceClassification": ("modernbert", @@ -222,6 +224,7 @@ _MULTIMODAL_MODELS = { "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"), # noqa: E501 "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"), "InternVLChatModel": ("internvl", "InternVLChatModel"), + "NemotronH_Nano_VL": ("nano_nemotron_vl", "NemotronH_Nano_VL"), "InternS1ForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"), # noqa: E501 "InternVLForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"), # noqa: E501 "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"), @@ -231,11 +234,13 @@ _MULTIMODAL_MODELS = { "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"), "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501 "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"), + "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501 "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501 "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), # noqa: E501 "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"), # noqa: E501 + "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"), "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"), # noqa: E501 "MiniCPMO": ("minicpmo", "MiniCPMO"), "MiniCPMV": ("minicpmv", "MiniCPMV"), @@ -255,17 +260,15 @@ _MULTIMODAL_MODELS = { "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501 "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501 "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501 - "UltravoxModel": ("ultravox", "UltravoxModel"), + "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"), # noqa: E501 + "Qwen3VLMoeForConditionalGeneration": ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"), # noqa: E501 
+ "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"), "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"), # noqa: E501 "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501 "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501 + "UltravoxModel": ("ultravox", "UltravoxModel"), "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501 # [Encoder-decoder] - "DonutForConditionalGeneration": ("donut", "DonutForConditionalGeneration"), - "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 - "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 - "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501 - "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"), "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501 } @@ -275,13 +278,13 @@ _SPECULATIVE_DECODING_MODELS = { "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"), "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"), "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), - # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 - # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"), + "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"), "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"), "MedusaModel": ("medusa", "Medusa"), + "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"), # Temporarily disabled. # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1. # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), @@ -319,7 +322,17 @@ _SUBPROCESS_COMMAND = [ sys.executable, "-m", "vllm.model_executor.models.registry" ] -_PREVIOUSLY_SUPPORTED_MODELS = {"Phi3SmallForCausalLM": "0.9.2"} +_PREVIOUSLY_SUPPORTED_MODELS = { + "Phi3SmallForCausalLM": "0.9.2", + # encoder-decoder models except whisper + # have been removed for V0 deprecation. 
+ "BartModel": "0.10.2", + "BartForConditionalGeneration": "0.10.2", + "DonutForConditionalGeneration": "0.10.2", + "Florence2ForConditionalGeneration": "0.10.2", + "MBartForConditionalGeneration": "0.10.2", + "MllamaForConditionalGeneration": "0.10.2", +} @dataclass(frozen=True) @@ -638,6 +651,9 @@ class _ModelRegistry: model_info = self._try_inspect_model_cls(arch) if model_info is not None: return (model_info, arch) + elif model_config.model_impl == ModelImpl.TERRATORCH: + model_info = self._try_inspect_model_cls("Terratorch") + return (model_info, "Terratorch") # Fallback to transformers impl (after resolving convert_type) if (all(arch not in self.models for arch in architectures) @@ -686,6 +702,11 @@ class _ModelRegistry: model_cls = self._try_load_model_cls(arch) if model_cls is not None: return (model_cls, arch) + elif model_config.model_impl == ModelImpl.TERRATORCH: + arch = "Terratorch" + model_cls = self._try_load_model_cls(arch) + if model_cls is not None: + return (model_cls, arch) # Fallback to transformers impl (after resolving convert_type) if (all(arch not in self.models for arch in architectures) diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 2bfa51162910b..ba405be416876 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -8,7 +8,7 @@ import torch from torch import nn from transformers import RobertaConfig -from vllm.config import VllmConfig +from vllm.config import ModelConfig, VllmConfig from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool, DispatchPooler, Pooler) from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -73,10 +73,16 @@ class RobertaEmbedding(nn.Module): class RobertaClassificationHead(nn.Module): """Head for sentence-level classification tasks.""" - def __init__(self, config: RobertaConfig): + def __init__(self, model_config: "ModelConfig"): super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + config = model_config.hf_config + head_dtype = model_config.head_dtype + self.dense = nn.Linear(config.hidden_size, + config.hidden_size, + dtype=head_dtype) + self.out_proj = nn.Linear(config.hidden_size, + config.num_labels, + dtype=head_dtype) def forward(self, x: torch.Tensor) -> torch.Tensor: # CLSPool has already been applied in `pooling` @@ -184,7 +190,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): self.roberta = BertModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "bert"), embedding_class=RobertaEmbedding) - self.classifier = RobertaClassificationHead(config) + self.classifier = RobertaClassificationHead(vllm_config.model_config) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py index c6244fb3b3e6a..7d90d3a7ef128 100644 --- a/vllm/model_executor/models/siglip2navit.py +++ b/vllm/model_executor/models/siglip2navit.py @@ -13,6 +13,7 @@ from torch.nn import functional as F from transformers import Siglip2VisionConfig from transformers.configuration_utils import PretrainedConfig +from vllm.attention.layer import check_upstream_fa_availability from vllm.config import QuantizationConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn @@ -236,7 +237,15 @@ class 
Siglip2Attention(nn.Module): self.use_rope = config.use_rope # Detect attention implementation. - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=self.head_dim, dtype=torch.get_default_dtype()) + self.use_upstream_fa = False + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability( + torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN + self.use_upstream_fa = True + if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.ROCM_AITER_FA @@ -280,7 +289,10 @@ class Siglip2Attention(nn.Module): if self.attn_backend == _Backend.ROCM_AITER_FA: from aiter import flash_attn_varlen_func else: - from flash_attn import flash_attn_varlen_func + if self.use_upstream_fa: + from flash_attn import flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func attn_output = flash_attn_varlen_func( queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(seq_length, -1) @@ -378,12 +390,9 @@ class Siglip2EncoderLayer(nn.Module): position_embeddings: torch.Tensor) -> tuple[torch.FloatTensor]: """ Args: - hidden_states (`torch.FloatTensor`): - Input to the layer of shape `(batch, seq_len, embed_dim)`. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all - attention layers. See `attentions` under - returned tensors for more detail. + hidden_states: Input tensor of shape (batch, seq_len, embed_dim). + cu_seqlens: Cumulative sequence lengths tensor. + position_embeddings: Position embeddings tensor. """ residual = hidden_states @@ -522,19 +531,11 @@ class Siglip2Encoder(nn.Module): ) -> torch.Tensor: r""" Args: - inputs_embeds (`torch.FloatTensor` of shape - `(batch_size, sequence_length, hidden_size)`): - Optionally, instead of passing `input_ids` you can choose to - directly pass an embedded representation. This is useful if - you want more control over how to convert `input_ids` indices - into associated vectors than the model's internal embedding - lookup matrix. - grid_thws (`torch.LongTensor`): - grid shape (num_patches, 3) - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See - `hidden_states` under returned tensors for more detail. - return_dict (`bool`, *optional*): + inputs_embeds: Input tensor of shape + (batch_size, sequence_length, hidden_size). + Embedded representation of the input tokens. + grid_thws: Grid tensor of shape (num_patches, 3) + containing grid dimensions. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
""" diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 8dd52f1d204a5..94c862258b7ad 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -469,6 +469,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): # compatibility if not lora_config else lora_config.lora_vocab_padding_size, quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index 97611d3e140ec..b8733fa5e6129 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -35,7 +35,8 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP from .utils import (PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) logger = init_logger(__name__) @@ -386,6 +387,7 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): org_num_embeddings=config.vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE if not lora_config else lora_config.lora_vocab_padding_size, + prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index f379d2c15fb6c..2ba5f94ea3b88 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -16,12 +16,12 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode from transformers import BatchFeature, PretrainedConfig, TensorType +from vllm.attention.layer import MultiHeadAttention from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, - ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler @@ -667,39 +667,25 @@ class Step3VisionAttention(nn.Module): self.q_size = self.num_heads * self.head_dim - if use_data_parallel: - self.qkv_proj = ReplicatedLinear( - self.embed_dim, - 3 * self.q_size, - bias=True, - quant_config=quant_config, - prefix=prefix, - ) - self.out_proj = ReplicatedLinear( - self.total_num_heads * self.head_dim, - self.embed_dim, - bias=True, - quant_config=quant_config, - prefix=prefix, - ) - else: - self.qkv_proj = QKVParallelLinear( - self.embed_dim, - self.head_dim, - self.total_num_heads, - bias=True, - quant_config=quant_config, - prefix=prefix, - ) - self.out_proj = RowParallelLinear(self.embed_dim, - self.embed_dim, - bias=True, - quant_config=quant_config, - prefix=prefix) + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.total_num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + disable_tp=use_data_parallel, + ) + self.out_proj = RowParallelLinear(self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + disable_tp=use_data_parallel) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, - 
self.head_dim).transpose(1, 2).contiguous() + # Use unified MultiHeadAttention with automatic backend selection + self.attn = MultiHeadAttention(self.num_heads, self.head_dim, + self.scale) def forward( self, @@ -711,19 +697,9 @@ class Step3VisionAttention(nn.Module): # get query proj qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) - q = q.view(bsz, tgt_len, self.num_heads, self.head_dim) - k = k.view(bsz, tgt_len, self.num_heads, self.head_dim) - v = v.view(bsz, tgt_len, self.num_heads, self.head_dim) - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - attn_output = F.scaled_dot_product_attention(q, - k, - v, - scale=self.scale, - is_causal=False) - attn_output = attn_output.transpose(1, 2).reshape( - bsz, tgt_len, self.num_heads * self.head_dim) + + # Use unified MultiHeadAttention with automatic backend selection + attn_output = self.attn(q, k, v) attn_output, _ = self.out_proj(attn_output) @@ -740,20 +716,18 @@ class Step3VisionMLP(nn.Module): super().__init__() self.config = config self.activation_fn = get_act_fn(config.hidden_act) - cls_fc1 = (ReplicatedLinear - if use_data_parallel else ColumnParallelLinear) - self.fc1 = cls_fc1(config.hidden_size, - config.intermediate_size, - bias=True, - quant_config=quant_config, - prefix=prefix) - cls_fc2 = (ReplicatedLinear - if use_data_parallel else RowParallelLinear) - self.fc2 = cls_fc2(config.intermediate_size, - config.hidden_size, - bias=True, - quant_config=quant_config, - prefix=prefix) + self.fc1 = ColumnParallelLinear(config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + disable_tp=use_data_parallel) + self.fc2 = RowParallelLinear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + disable_tp=use_data_parallel) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.fc1(hidden_states) diff --git a/vllm/model_executor/models/terratorch.py b/vllm/model_executor/models/terratorch.py new file mode 100644 index 0000000000000..b9dfa8e9b6f51 --- /dev/null +++ b/vllm/model_executor/models/terratorch.py @@ -0,0 +1,294 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The vLLM team. +# Copyright 2025 IBM. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Wrapper around `Terratorch` models""" + +from collections import OrderedDict +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn as nn +from terratorch.vllm import (DummyDataGenerator, InferenceRunner, + InputDefinition, InputTypeEnum) +from transformers import BatchFeature + +from vllm.config import VllmConfig +from vllm.model_executor.layers.pooler import DispatchPooler, Pooler +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.utils import AutoWeightsLoader +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import MultiModalProcessorOnlyCache +from vllm.multimodal.inputs import (ImageItem, ModalityData, + MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputs, MultiModalKwargsItems, + MultiModalUUIDDict, PlaceholderRange) +from vllm.multimodal.parse import (DictEmbeddingItems, ModalityDataItems, + MultiModalDataItems, MultiModalDataParser) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors + +from .interfaces import (IsAttentionFree, MultiModalEmbeddings, + SupportsMultiModal) +from .interfaces_base import default_pooling_type + + +def _terratorch_field_names(pretrained_cfg: dict): + input_definition = InputDefinition(**pretrained_cfg["input"]) + return set(input_definition.data.keys()) + + +def _terratorch_field_factory( + pretrained_cfg: dict +) -> Callable[ + [Mapping[str, torch.Tensor]], + Mapping[str, MultiModalFieldConfig], +]: + + def _terratorch_field_config(hf_inputs: Mapping[str, torch.Tensor]): + input_definition = InputDefinition(**pretrained_cfg["input"]) + fields = {} + for input_name, input in input_definition.data.items(): + if input.type == InputTypeEnum.tensor: + fields[input_name] = "image" + + mm_fields_config = {} + for field_name, field_modality in fields.items(): + mm_fields_config[field_name] = MultiModalFieldConfig.shared( + batch_size=1, modality=field_modality) + return mm_fields_config + + return _terratorch_field_config + + +class TerratorchProcessingInfo(BaseProcessingInfo): + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + +class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]): + + def __init__(self, info: TerratorchProcessingInfo): + super().__init__(info) + self.dummy_data_generator = DummyDataGenerator( + self.info.get_hf_config().to_dict()["pretrained_cfg"]) + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + # Dummy data is generated based on the 'input' section + # defined in the HF configuration file + return self.dummy_data_generator.get_dummy_mm_data() + + +class TerratorchMultiModalDataParser(MultiModalDataParser): + + def __init__(self, pretrained_cfg: dict, *args, **kwargs): + self._pretrained_cfg = pretrained_cfg + super().__init__(*args, **kwargs) + + def _parse_image_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], + ) -> Optional[ModalityDataItems[Any, Any]]: + if isinstance(data, dict): + + terratorch_fields = _terratorch_field_names(self._pretrained_cfg) + + return DictEmbeddingItems( + data, + modality="image", + required_fields=terratorch_fields, + 
fields_factory=_terratorch_field_factory(self._pretrained_cfg), + ) + + return super()._parse_image_data(data) + + +class TerratorchMultiModalProcessor(BaseMultiModalProcessor): + + def __init__( + self, + info: TerratorchProcessingInfo, + dummy_inputs: "BaseDummyInputsBuilder[TerratorchProcessingInfo]", + *, + cache: Optional[MultiModalProcessorOnlyCache] = None) -> None: + + self.pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"] + super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache) + + def _get_data_parser(self) -> MultiModalDataParser: + return TerratorchMultiModalDataParser( + pretrained_cfg=self.pretrained_cfg) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return _terratorch_field_factory(self.pretrained_cfg)(hf_inputs) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + return [] + + def apply( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Optional[Mapping[str, object]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, + ) -> MultiModalInputs: + if "image" in mm_data: + image_data = mm_data["image"] + else: + image_data = mm_data + mm_data = {"image": mm_data} + + mm_items = self._to_mm_items(mm_data) + tokenization_kwargs = tokenization_kwargs or {} + mm_hashes = self._hash_mm_items(mm_items, + hf_processor_mm_kwargs, + tokenization_kwargs, + mm_uuids=mm_uuids) + mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]} + + mm_processed_data = BatchFeature(image_data) + + mm_kwargs = MultiModalKwargsItems.from_hf_inputs( + mm_processed_data, + self._get_mm_fields_config(mm_processed_data, + hf_processor_mm_kwargs), + ) + + return MultiModalInputs( + type="multimodal", + prompt=prompt, + prompt_token_ids=[1], + mm_kwargs=mm_kwargs, + mm_hashes=mm_hashes, + mm_placeholders=mm_placeholders, + ) + + +@default_pooling_type("All") +@MULTIMODAL_REGISTRY.register_processor( + TerratorchMultiModalProcessor, + info=TerratorchProcessingInfo, + dummy_inputs=TerratorchInputBuilder, +) +class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal): + supports_multimodal_raw_input_only = True + is_pooling_model = True + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return None + + raise ValueError("Only image modality is supported") + + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"] + + self.inference_runner = InferenceRunner(config) + self.model = self.inference_runner.model + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler( + {"encode": Pooler.for_encode(pooler_config)}, ) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + # We do not really use any input tokens and therefore no embeddings + # to be calculated. However, due to the mandatory token ids in + # the input prompt we pass one token and the size of the dummy + # embedding tensors must reflect that. 
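The new Terratorch processor above exposes every tensor input declared under pretrained_cfg["input"] as an "image" field via MultiModalFieldConfig.shared. A minimal sketch of that mapping, with made-up field names standing in for a real Terratorch configuration:

from vllm.multimodal.inputs import MultiModalFieldConfig

# Hypothetical tensor inputs a Terratorch checkpoint might declare; the real
# names come from InputDefinition(**pretrained_cfg["input"]).data.
field_names = ["pixel_values", "location_coords"]

mm_fields_config = {
    name: MultiModalFieldConfig.shared(batch_size=1, modality="image")
    for name in field_names
}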
+ return torch.empty((input_ids.shape[0], 0)) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ): + model_output = self.inference_runner.forward(**kwargs) + + return model_output.output + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + params_list = [] + model_buffers = dict(self.named_buffers()) + loaded_buffers = [] + for key, value in weights: + if isinstance(value, (dict, OrderedDict)): + if key == "state_dict": + weights_to_parse = value + for name, weight in weights_to_parse.items(): + name = f"inference_runner.{name}" + + if "pos_embed" in name: + continue + + if "_timm_module." in name: + name = name.replace("_timm_module.", "") + + # this model requires a couple of buffers to be loaded + # that are not loadable with the AutoWeightsLoader + if name in model_buffers: + if "_timm_module." in name: + name = name.replace("_timm_module.", "") + buffer = model_buffers[name] + weight_loader = getattr(buffer, "weight_loader", + default_weight_loader) + weight_loader(buffer, weight) + loaded_buffers.append(name) + else: + params_list.append((name, weight)) + break + + elif isinstance(value, torch.Tensor): + params_list.append((f"inference_runner.model.{key}", value)) + + # Load the remaining model parameters + loader = AutoWeightsLoader(self) + autoloaded_weights = loader.load_weights(params_list) + + return autoloaded_weights.union(set(loaded_buffers)) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 5ad0482330ecd..4f51441e28efa 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -44,7 +44,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, PlaceholderRange) + MultiModalInputs, MultiModalUUIDDict, + PlaceholderRange) from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo) @@ -228,7 +229,8 @@ class MultiModalProcessingInfo(BaseProcessingInfo): def get_max_image_tokens(self) -> int: width, height = self.get_max_image_size() processor = self.get_hf_processor() - mm_processor_kwargs = self.ctx.model_config.mm_processor_kwargs or {} + multimodal_config = self.ctx.model_config.multimodal_config + mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {} mm_tokens = processor._get_num_multimodal_tokens( image_sizes=([height, width], ), **mm_processor_kwargs) image_tokens = mm_tokens["num_image_tokens"][0] @@ -347,7 +349,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - mm_hash_overrides: Optional[dict[str, list[str]]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalInputs: """ Process multi-modal inputs to be used in vLLM. @@ -379,8 +381,8 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): # Below tested on Llava. 
Prompts and `mm_token_type_ids` are always bs=1 mm_positions = torch.where(mm_token_type_ids == 1)[1] images = mm_items.get_items("image", ImageProcessorItems) - mm_processor_kwargs = (self.info.ctx.model_config.mm_processor_kwargs - or {}) + multimodal_config = self.info.ctx.model_config.multimodal_config + mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {} image_sizes = [] for item_idx in range(len(images)): image_size = images.get_image_size(item_idx) @@ -415,9 +417,8 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): num_image_patches), ) # Use overrides if provided; fallback to data-dependent hashing. - mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else - self._hash_mm_items(mm_items, hf_processor_mm_kwargs, - tokenization_kwargs)) + mm_hashes = (mm_uuids if mm_uuids is not None else self._hash_mm_items( + mm_items, hf_processor_mm_kwargs, tokenization_kwargs)) return MultiModalInputs( type="multimodal", diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index f91c4ddb6e834..371ca817d5f92 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -4,7 +4,7 @@ # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py """PyTorch Ultravox model.""" from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Literal, Optional, Union import torch from torch import nn @@ -31,6 +31,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.ultravox import UltravoxConfig +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) @@ -43,26 +44,37 @@ _AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>" _MAX_ENCODER_BATCH_SIZE = 16 -class UltravoxAudioFeatureInputs(TypedDict): +class UltravoxAudioFeatureInputs(TensorSchema): + """ + Dimensions: + - b: batch size + - n: number of chunks + - t: Time frames (M) + - nmb: Number of mel bins + """ type: Literal["audio_features"] - data: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]] - """Shape: `(batch_size, num_chunks, 80, M)`""" - lens: Union[torch.Tensor, list[torch.Tensor]] - """ - Length of the audio frames. Used for attention mask in WhisperEncoder. - Shape: `(batch_size, num_chunks)` - """ - token_len: Union[torch.Tensor, list[torch.Tensor]] - """ - Length of the audio tokens. Used for flattening the audio features. - Shape: `(batch_size, num_chunks)` - """ + data: Annotated[Union[torch.Tensor, list[torch.Tensor], + list[list[torch.Tensor]]], + TensorShape("b", "n", "nmb", "t", dynamic_dims={"n"})] + lens: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("b", "n", dynamic_dims={"n"})] + """Length of the audio frames. Used for attention mask in WhisperEncoder.""" + token_len: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("b", "n", dynamic_dims={"n"})] + """Length of the audio tokens. 
Used for flattening the audio features.""" -class UltravoxAudioEmbeddingInputs(TypedDict): +class UltravoxAudioEmbeddingInputs(TensorSchema): + """ + Dimensions: + - b: batch size + - na: number of audios + - afs: audio feature size + - hs: hidden size + """ type: Literal["audio_embeds"] - data: NestedTensors - """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)`""" + data: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("b", "na", "afs", "hs")] UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs, @@ -264,7 +276,7 @@ class UltravoxProjector(nn.Module): else: self.act = get_act_fn(config.projector_act) - dim_out = config.text_hidden_size + dim_out = config.text_config.hidden_size self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False) # Ultravox v0.4.1 and below use layer_norm after the second linear layer @@ -406,7 +418,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - config = vllm_config.model_config.hf_config + config: UltravoxConfig = vllm_config.model_config.hf_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multi_modal_config = multimodal_config @@ -426,7 +438,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): self.multi_modal_projector = UltravoxProjector(config) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, - hf_config=config.text_config, + hf_config=config.wrapped_model_config, prefix=maybe_prefix(prefix, "language_model"), ) if config.text_model_id is not None: @@ -484,26 +496,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): return None if audio_features is not None: - if not isinstance(audio_features, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio features. " - f"Got type: {type(audio_features)}") - if not isinstance(audio_lens, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio_lens. " - f"Got type: {type(audio_features)}") - if not isinstance(audio_token_len, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio_token_len. " - f"Got type: {type(audio_features)}") - return UltravoxAudioFeatureInputs(type="audio_features", data=audio_features, lens=audio_lens, token_len=audio_token_len) if audio_embeds is not None: - if not isinstance(audio_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio embeds. " - f"Got type: {type(audio_embeds)}") - return UltravoxAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds) @@ -599,10 +597,11 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): with the `input_ids`. Args: - audio_features: A batch of audio input chunks [B, N, 80, M]. - audio_lens: Length of audio frames for each audio chunk [B]. - audio_token_len: Length of audio tokens for each audio chunk [B']. - Note: batch dim is different from batch dim in audio chunks. + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + positions: Position indices for the input tokens. + intermediate_tensors: Intermediate tensors from prior forward pass. + inputs_embeds: Optional tensor of input embeddings. 
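The Ultravox hunk above replaces the hand-written isinstance checks in _parse_and_validate_audio_input (removed further down) with TensorSchema classes whose Annotated/TensorShape metadata declares the expected shapes. A sketch of the same declaration pattern for a hypothetical schema; the class, field, and dimension names here are illustrative, not part of this patch:

from typing import Annotated, Literal, Union

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class DemoAudioFeatureInputs(TensorSchema):
    """
    Dimensions:
        - b: batch size
        - n: number of chunks (varies per item, hence dynamic)
        - nmb: number of mel bins
        - t: time frames
    """
    type: Literal["audio_features"]
    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
                    TensorShape("b", "n", "nmb", "t", dynamic_dims={"n"})]
    lens: Annotated[Union[torch.Tensor, list[torch.Tensor]],
                    TensorShape("b", "n", dynamic_dims={"n"})]


# Constructed with keyword arguments, the same way the model code builds its
# schema instances; shape/type validation is delegated to the schema machinery.
inputs = DemoAudioFeatureInputs(type="audio_features",
                                data=torch.zeros(1, 2, 80, 3000),
                                lens=torch.tensor([[3000, 1500]]))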
""" @@ -663,7 +662,7 @@ def pad_and_concat_to_dim3( max_len = max(f.shape[-1] for f in features) # Ensure all features have dim=3 features = [f.view(-1, *f.shape[-2:]) for f in features] - # Pad and oncatenate: + # Pad and concatenate: # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)] features = [F.pad(f, (0, max_len - f.shape[-1])) for f in features] return torch.cat(features) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 28cfefac30ddb..e716ec582baab 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -761,3 +761,10 @@ def fast_topk(values: torch.Tensor, topk: int, else: # Use topk for efficiency with larger k values return torch.topk(values, topk, dim=dim) + + +def get_model_hidden_size(hf_config: PretrainedConfig) -> int: + if hasattr(hf_config, "hidden_size"): + return hf_config.hidden_size + text_config = hf_config.get_text_config() + return text_config.hidden_size diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index de30509b1ccb4..81f86db7e1875 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -7,7 +7,6 @@ from typing import Final, Generic, Optional, Protocol, TypeVar, Union import torch from transformers import PretrainedConfig -from vllm.attention.selector import get_env_variable_attn_backend from vllm.logger import init_logger from vllm.platforms import _Backend, current_platform @@ -68,17 +67,18 @@ def get_vision_encoder_info( raise NotImplementedError(msg) -def get_vit_attn_backend(support_fa: bool = False) -> _Backend: +def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend: """ Get the available attention backend for Vision Transformer. """ - # TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn. 
+ # Lazy import to avoid circular dependency + from vllm.attention.selector import get_env_variable_attn_backend selected_backend: Optional[_Backend] = get_env_variable_attn_backend() if selected_backend is not None: return selected_backend - return current_platform.get_vit_attn_backend(support_fa) + return current_platform.get_vit_attn_backend(head_size, dtype) def resolve_visual_encoder_outputs( diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index f3731b389cfe0..16a97389cd21b 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -23,15 +23,18 @@ from transformers.tokenization_utils_base import TextInput from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.inputs.data import PromptType from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models import SupportsPP +from vllm.model_executor.models.module_mapping import MultiModelKeys # yapf: disable from vllm.model_executor.models.whisper import WhisperEncoder # yapf: enable from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) + MultiModalKwargsItems, MultiModalUUIDDict, + NestedTensors) from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -43,8 +46,8 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import (MistralTokenizer, cached_tokenizer_from_config) -from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, - SupportsTranscription) +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsTranscription) from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) @@ -290,14 +293,14 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo] mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - mm_hash_overrides: Optional[dict[str, list[str]]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) # NOTE: The tokens are already inserted by the chat template @@ -312,13 +315,25 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo] info=VoxtralProcessingInfo, dummy_inputs=VoxtralDummyInputsBuilder) class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP, SupportsTranscription): + SupportsPP, SupportsLoRA, + SupportsTranscription): supported_languages = ISO639_1_SUPPORTED_LANGS + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"] + } + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config) + # update quant config so that ignored module and
target module names + # match the vLLM model names + if hasattr(vllm_config, "quant_config"): + vllm_config.quant_config = self.maybe_update_quant_config( + vllm_config.quant_config) + config = vllm_config.model_config.hf_config self.config = config self.downsample_factor = self.config.audio_config.downsample_factor @@ -340,6 +355,14 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal, def get_language_model(self) -> torch.nn.Module: return self.language_model + def get_mm_mapping(self) -> MultiModelKeys: + """Get module prefix for multimodal models to filter LoRA modules.""" + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="audio_language_adapter", + tower_model=["whisper_encoder"], + ) + def forward( self, input_ids: torch.Tensor, @@ -542,6 +565,72 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal, return loaded_weights + def maybe_update_quant_config( + self, quant_config: QuantizationConfig) -> QuantizationConfig: + """ + Update quant config so that ignored module and target module names + match the vLLM model names. + Right now this is specific to the compressed-tensors format and + load_format mistral. + """ + remapping_rules = [ + (r"output", r"language_model.lm_head"), + (r"layers\.(\d+)\.attention\.wo", + r"language_model.model.layers.\1.self_attn.out_proj"), + (r"layers\.(\d+)\.attention\.w(.*)", + r"language_model.model.layers.\1.self_attn.\2_proj"), + (r"layers\.(\d+)\.feed_forward\.w1", + r"language_model.model.layers.\1.mlp.gate_proj"), + (r"layers\.(\d+)\.feed_forward\.w2", + r"language_model.model.layers.\1.mlp.down_proj"), + (r"layers\.(\d+)\.feed_forward\.w3", + r"language_model.model.layers.\1.mlp.up_proj"), + (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.w(.*)", + r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.\2_proj" + ), + (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.wo", + r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.out_proj" + ), + (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward.w(\d+)", + r"whisper_encoder.whisper_encoder.layers.\1.layers.mlp.fc\2"), + (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.0", + r"whisper_encoder.whisper_encoder.conv1"), + (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.1", + r"whisper_encoder.whisper_encoder.conv2"), + (r"mm_whisper_embeddings\.audio_language_projection\.0", + r"audio_language_adapter.w_in"), + (r"mm_whisper_embeddings\.audio_language_projection\.2", + r"audio_language_adapter.w_out"), + ] + + # Update ignore list + if hasattr(quant_config, "ignore"): + mistral_ignore = [] + for name in quant_config.ignore: + mistral_name = name + for pattern, repl in remapping_rules: + if re.fullmatch(pattern, name): + mistral_name = re.sub(pattern, repl, name) + mistral_ignore.append(mistral_name) + quant_config.ignore = mistral_ignore + + # Update target list + if hasattr(quant_config, "config_groups"): + config_groups = quant_config.config_groups + for group_name in config_groups: + if "targets" in config_groups[group_name]: + targets = [] + for name in config_groups[group_name]["targets"]: + mistral_name = name + for pattern, repl in remapping_rules: + if re.fullmatch(pattern, name): + mistral_name = re.sub(pattern, repl, name) + targets.append(mistral_name) + config_groups[group_name]["targets"] = targets + quant_config.config_groups = config_groups + + return quant_config + class
AudioLanguageAdapter(nn.Module): @@ -584,7 +673,6 @@ class VoxtralEncoderModel(nn.Module): self.whisper_encoder = WhisperEncoder(vllm_config=vllm_config, prefix=maybe_prefix( prefix, "whisper_encoder"), - is_standalone_encoder=True, init_in_fp32=True) mel_filters = mel_filter_bank( num_frequency_bins=1 + self.config.window_size // 2, diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 848b6e0f8093a..41ae7b129782d 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -4,7 +4,7 @@ import math from collections.abc import Iterable, Mapping, Sequence from contextlib import nullcontext -from typing import Literal, Optional, TypedDict, Union, cast +from typing import Annotated, Literal, Optional, Union, cast import numpy as np import torch @@ -15,6 +15,7 @@ from transformers.models.whisper.modeling_whisper import sinusoids from vllm.attention import Attention, AttentionType from vllm.attention.layer import MultiHeadAttention +from vllm.attention.layers.cross_attention import CrossAttention from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig) from vllm.distributed import get_tensor_model_parallel_world_size @@ -40,9 +41,10 @@ from vllm.multimodal.processing import (BaseProcessingInfo, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.transformers_utils.processor import cached_get_processor +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, - SupportsTranscription, SupportsV0Only) + SupportsTranscription) from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, make_layers) @@ -111,9 +113,44 @@ ISO639_1_SUPPORTED_LANGS = { } -class WhisperAudioInputs(TypedDict): - input_features: NestedTensors - """Shape: `(batch_size, 128, M)`""" +class WhisperAudioInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - nmb: Number of mel bins + - t: Time frames (M) + """ + + input_features: Annotated[Optional[NestedTensors], + TensorShape("b", "nmb", "t")] + + +class WhisperEncoderAttention(MultiHeadAttention): + """Multi-headed attention for Whisper encoder with 2D tensor support.""" + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + ) -> torch.Tensor: + """ + Input shape: batch_size x seq_len x hidden_size + or seq_len x hidden_size + """ + is_2d = query.dim() == 2 + if is_2d: + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + + # Call the parent forward method + out = super().forward(query, key, value) + + if is_2d: + out = out.squeeze(0) + + return out class WhisperPositionalEmbedding(nn.Embedding): @@ -136,7 +173,6 @@ class WhisperAttention(nn.Module): cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", - standalone_encoder: bool = False, ): super().__init__() self.embed_dim = embed_dim @@ -172,14 +208,25 @@ class WhisperAttention(nn.Module): quant_config=quant_config, prefix=f"{prefix}.out_proj", ) - if standalone_encoder: - self.attn = MultiHeadAttention( + if attn_type == AttentionType.ENCODER: + self.attn = WhisperEncoderAttention( self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, ) - else: + elif self.attn_type == AttentionType.ENCODER_DECODER: + self.attn = CrossAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + 
cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=self.attn_type, + ) + else: # AttentionType.DECODER (regular decoder self-attention) self.attn = Attention( self.num_heads, self.head_dim, @@ -324,11 +371,7 @@ class WhisperMLP(nn.Module): class WhisperEncoderLayer(nn.Module): - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - is_standalone_encoder: bool = False): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config @@ -342,7 +385,6 @@ class WhisperEncoderLayer(nn.Module): cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", - standalone_encoder=is_standalone_encoder, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.mlp = WhisperMLP( @@ -438,12 +480,10 @@ class WhisperEncoder(nn.Module): *, vllm_config: VllmConfig, prefix: str = "", - is_standalone_encoder: bool = False, init_in_fp32: bool = False): super().__init__() config = vllm_config.model_config.hf_config embed_dim = config.d_model - self.is_standalone_encoder = is_standalone_encoder self.num_mel_bins = config.num_mel_bins self.max_source_positions = config.max_source_positions self.embed_scale = (math.sqrt(embed_dim) @@ -461,9 +501,7 @@ class WhisperEncoder(nn.Module): self.start_layer, self.end_layer, self.layers = make_layers( config.encoder_layers, lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config, - prefix=f"{prefix}.layers", - is_standalone_encoder= - is_standalone_encoder), + prefix=f"{prefix}.layers"), prefix=f"{prefix}.layers", ) self.layer_norm = nn.LayerNorm(config.d_model) @@ -744,7 +782,7 @@ class WhisperMultiModalProcessor( info=WhisperProcessingInfo, dummy_inputs=WhisperDummyInputsBuilder) class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, - SupportsMultiModal, SupportsV0Only): + SupportsMultiModal): packed_modules_mapping = { "self_attn.qkv_proj": [ "self_attn.q_proj", @@ -872,19 +910,17 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: - # TODO: This method does not obey the interface for SupportsMultiModal. - # Refactor this once encoder/decoder support is implemented in V1. + # Required as part of SupportsMultiModal interface. audio_input = self._parse_and_validate_audio_input(**kwargs) - return self.model.get_encoder_outputs(audio_input["input_features"]) + return [self.model.get_encoder_outputs(audio_input["input_features"])] def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[NestedTensors] = None, ) -> torch.Tensor: - # TODO: This method just returns the decoder sequence embeddings since - # Whisper does not have encoder text tokens. Refactor this once - # encoder/decoder support is implemented in V1. + # This method just returns the decoder sequence embeddings since + # Whisper does not have encoder text tokens. return self.model.decoder.get_input_embeddings(input_ids) def _parse_and_validate_audio_input( diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index ed65944c109bd..e601bc3adb6e9 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -9,7 +9,7 @@ model alternates between state space model layers and attention-based layers. 
""" from collections.abc import Iterable from itertools import cycle -from typing import Optional, Union +from typing import Any, Optional, Union import torch from torch import nn @@ -528,8 +528,6 @@ class Zamba2MambaDecoderLayer(nn.Module): hidden_states: Input tensor [batch_size, seq_len, hidden_size] mamba_cache_params: Parameters for Mamba's state caches (one for conv, one for ssm) - sequence_idx: Index tensor for identifying sequences in batch - Required for proper chunked processing in prefill transformer_hidden_states: Optional output from transformer path Added to input if provided (used in hybrid architecture) positions: Optional position IDs (unused in Mamba) @@ -591,8 +589,6 @@ class Zamba2HybridLayer(nn.Module): Args: shared_transformer: Transformer decoder layer for attention pathway - linear: Linear projection for transformer output before Mamba - mamba: Mamba decoder layer for state space pathway """ super().__init__() self.block_idx = block_idx @@ -630,8 +626,6 @@ class Zamba2HybridLayer(nn.Module): positions: Position IDs for positional embeddings mamba_cache_params: Parameters for Mamba's state caches (one for conv, one for ssm) - sequence_idx: Indices for identifying sequences in batch, - required for proper chunked processing in prefill Returns: Output tensor combining transformer and Mamba representations @@ -915,8 +909,8 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): prefix: Optional prefix for parameter names Raises: - AssertionError: If prefix caching is enabled (not supported by - Mamba) + AssertionError: If prefix caching is enabled + (not supported by Mamba) """ config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config @@ -947,6 +941,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): # We need bigger padding if using lora for kernel # compatibility if not lora_config else lora_config.lora_vocab_padding_size, + prefix=maybe_prefix(prefix, "lm_head"), ) # Tie weights with input embeddings if using same dimensions self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) @@ -971,7 +966,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): input_ids: torch.Tensor, positions: torch.Tensor, inputs_embeds: Optional[torch.Tensor] = None, - **kwargs) -> torch.Tensor: + **kwargs: Any) -> torch.Tensor: """Forward pass through the model. Args: @@ -1012,9 +1007,9 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): return hidden_states - def copy_inputs_before_cuda_graphs(self, input_buffers: dict[str, - torch.Tensor], - **kwargs) -> dict[str, torch.Tensor]: + def copy_inputs_before_cuda_graphs( + self, input_buffers: dict[str, torch.Tensor], + **kwargs: Any) -> dict[str, torch.Tensor]: """Copy inputs before CUDA graph capture. 
Args: diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 9465308e94e65..221712ba9a338 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -57,6 +57,8 @@ class BasevLLMParameter(Parameter): weight_loader = _make_synced_weight_loader(weight_loader) self._weight_loader = weight_loader + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() @property def weight_loader(self): @@ -116,10 +118,10 @@ class _ColumnvLLMParameter(BasevLLMParameter): return self._output_dim def load_column_parallel_weight(self, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() shard_size = self.data.shape[self.output_dim] loaded_weight = loaded_weight.narrow(self.output_dim, - tp_rank * shard_size, shard_size) + self.tp_rank * shard_size, + shard_size) assert self.data.shape == loaded_weight.shape self.data.copy_(loaded_weight) @@ -127,6 +129,7 @@ class _ColumnvLLMParameter(BasevLLMParameter): shard_offset = kwargs.get("shard_offset") shard_size = kwargs.get("shard_size") + # TODO: move these to PackedColumnParameter and PackedvLLMParameter if isinstance( self, @@ -137,11 +140,11 @@ class _ColumnvLLMParameter(BasevLLMParameter): param_data = self.data - tp_rank = get_tensor_model_parallel_rank() param_data = param_data.narrow(self.output_dim, shard_offset, shard_size) loaded_weight = loaded_weight.narrow(self.output_dim, - tp_rank * shard_size, shard_size) + self.tp_rank * shard_size, + shard_size) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -161,8 +164,8 @@ class _ColumnvLLMParameter(BasevLLMParameter): shard_offset=shard_offset, shard_size=shard_size) param_data = self.data - tp_rank = get_tensor_model_parallel_rank() - shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads + shard_id = (self.tp_rank if shard_id == "q" else self.tp_rank // + num_heads) param_data = param_data.narrow(self.output_dim, shard_offset, shard_size) loaded_weight = loaded_weight.narrow(self.output_dim, @@ -189,10 +192,10 @@ class RowvLLMParameter(BasevLLMParameter): return self._input_dim def load_row_parallel_weight(self, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() shard_size = self.data.shape[self.input_dim] loaded_weight = loaded_weight.narrow(self.input_dim, - tp_rank * shard_size, shard_size) + self.tp_rank * shard_size, + shard_size) if len(loaded_weight.shape) == 0: loaded_weight = loaded_weight.reshape(1) @@ -414,9 +417,6 @@ class SharedWeightParameter(BasevLLMParameter): "weight_loader": self._fake_weight_loader } - self.tp_rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() - if self.tp_size > 1: raise NotImplementedError(f"{self.__class__.__name__} does not " "currently support tensor parallelism") diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 41ed0b09c5a2a..65436786f82ac 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -52,10 +52,11 @@ def set_weight_attrs( def _make_synced_weight_loader(original_weight_loader): def _synced_weight_loader(param, *args, **kwargs): - original_weight_loader(param, *args, **kwargs) + out = original_weight_loader(param, *args, **kwargs) # torch._sync doesn't support, is not needed for CPU tensors. 
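The parameter.py hunk above caches tp_rank and tp_size on the parameter but keeps the same narrow-based sharding when loading column-parallel weights. A standalone PyTorch sketch of that sharding arithmetic with toy sizes (no vLLM objects involved):

import torch

full_weight = torch.arange(8 * 4, dtype=torch.float32).reshape(8, 4)  # [out=8, in=4]
tp_size = 2
output_dim = 0
shard_size = full_weight.shape[output_dim] // tp_size  # 4 rows per rank

shards = [
    full_weight.narrow(output_dim, tp_rank * shard_size, shard_size)
    for tp_rank in range(tp_size)
]

# Re-concatenating the per-rank shards recovers the original checkpoint weight.
assert torch.equal(torch.cat(shards, dim=output_dim), full_weight)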
if param.device != torch.device("cpu"): torch._sync(param) + return out return _synced_weight_loader diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py index 74599fa44c88c..a636a714145cf 100644 --- a/vllm/model_executor/warmup/deep_gemm_warmup.py +++ b/vllm/model_executor/warmup/deep_gemm_warmup.py @@ -10,6 +10,7 @@ import torch from tqdm import tqdm import vllm.envs as envs +from vllm.distributed.parallel_state import get_dp_group from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( compute_aligned_M, deep_gemm_block_shape) @@ -80,9 +81,14 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool: def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool: - if not (isinstance(module, FusedMoE) - and module.moe_config.quant_dtype == torch.float8_e4m3fn - and module.moe_config.block_shape == deep_gemm_block_shape()): + if not isinstance(module, FusedMoE): + return False + + moe_quant_config = module.quant_method.get_fused_moe_quant_config(module) + + if (moe_quant_config is None + or moe_quant_config.quant_dtype != torch.float8_e4m3fn + or moe_quant_config.block_shape != deep_gemm_block_shape()): return False if not isinstance(module.quant_method.fused_experts, @@ -131,11 +137,9 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: set[torch.Size] = set() -def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor, - w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - num_topk: int): +def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup( + w1: torch.Tensor, w2: torch.Tensor, w1_scale: torch.Tensor, + w2_scale: torch.Tensor, num_topk: int, max_tokens: int): if (w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE): return @@ -147,9 +151,13 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor, num_experts = w1.size(0) device = w1.device + # Assumes all ranks have the same max_num_batched_tokens + max_tokens_across_dp = get_dp_group().world_size * max_tokens + max_tokens = min(max_tokens_across_dp, envs.VLLM_FUSED_MOE_CHUNK_SIZE) + # This is the maximum GroupedGemm M size that we expect to run # the grouped_gemm with. 
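In the DeepGEMM warmup change above, the grouped-GEMM warmup size is now bounded by the number of tokens a DP group can actually batch instead of always using VLLM_FUSED_MOE_CHUNK_SIZE. An illustrative calculation with made-up numbers (the real values come from the scheduler config and environment):

# Hypothetical values, for illustration only.
dp_world_size = 4                  # get_dp_group().world_size
max_num_batched_tokens = 8192      # per-rank scheduler limit
VLLM_FUSED_MOE_CHUNK_SIZE = 65536  # the environment default may differ

max_tokens_across_dp = dp_world_size * max_num_batched_tokens  # 32768
warmup_tokens = min(max_tokens_across_dp, VLLM_FUSED_MOE_CHUNK_SIZE)
print(warmup_tokens)  # 32768 -> a tighter warmup bound than the old chunk-size-only limit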
- MAX_M = compute_aligned_M(envs.VLLM_FUSED_MOE_CHUNK_SIZE, + MAX_M = compute_aligned_M(max_tokens, num_topk, num_experts, block_m, @@ -201,7 +209,8 @@ def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int): _deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens) -def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module): +def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module, + max_tokens: int): dg_modules = [ m for m in model.modules() if _fused_moe_grouped_gemm_may_use_deep_gemm(m) @@ -211,9 +220,9 @@ def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module): w13, w13_scale, w2, w2_scale, num_topk = ( _extract_data_from_fused_moe_module(dgm)) _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup( - w13, w2, w13_scale, w2_scale, num_topk) + w13, w2, w13_scale, w2_scale, num_topk, max_tokens) def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int): deepgemm_fp8_gemm_nt_warmup(model, max_tokens) - deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model) + deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens) diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index 761172e4d3616..89ce20308f447 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING import torch import vllm.envs as envs +from vllm.logger import init_logger from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup from vllm.platforms import current_platform from vllm.utils.deep_gemm import is_deep_gemm_supported @@ -19,6 +20,8 @@ if TYPE_CHECKING: from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_worker import Worker +logger = init_logger(__name__) + def kernel_warmup(worker: "Worker"): # Deep GEMM warmup @@ -30,10 +33,33 @@ def kernel_warmup(worker: "Worker"): max_tokens = worker.scheduler_config.max_num_batched_tokens deep_gemm_warmup(model, max_tokens) - # FlashInfer autotune for Blackwell (SM 10.0) GPUs - if has_flashinfer() and current_platform.is_device_capability(100): + # FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs + if has_flashinfer() and current_platform.has_device_capability(90): flashinfer_autotune(worker.model_runner) + # FlashInfer attention warmup + # Only warmup if the model has FlashInfer attention groups + # and is not a pooling model + def _is_flashinfer_backend(backend): + try: + return backend.get_name() == "FLASHINFER_VLLM_V1" + except NotImplementedError: + return False + + if not worker.model_runner.is_pooling_model and all( + _is_flashinfer_backend(group.backend) + for groups in worker.model_runner.attn_groups for group in groups): + logger.info("Warming up FlashInfer attention.") + # Warmup with mixed batch containing both prefill and decode tokens + # This is to warm up both prefill and decode attention kernels + worker.model_runner._dummy_run( + num_tokens=16, + skip_eplb=True, + is_profile=True, + force_attention=True, + create_mixed_batch=True, + ) + def flashinfer_autotune(runner: "GPUModelRunner") -> None: """ diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index f3b273eb41e8f..d7e9d402a1f97 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -106,7 +106,7 @@ class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]): def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]: return librosa.load(filepath, sr=None) - def encode_base64(self, 
media: tuple[npt.NDArray, float]) -> str: + def encode_base64(self, media: tuple[npt.NDArray, int]) -> str: audio, sr = media with BytesIO() as buffer: diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py index 35b743ed21d92..297b4c7fa7fbd 100644 --- a/vllm/multimodal/cache.py +++ b/vllm/multimodal/cache.py @@ -1,21 +1,27 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import operator import sys from abc import ABC, abstractmethod from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union +from multiprocessing.synchronize import Lock as LockType +from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union, cast import torch from typing_extensions import TypeAlias, override +from vllm.distributed.device_communicators.shm_object_storage import ( + MsgpackSerde, SingleWriterShmObjectStorage, SingleWriterShmRingBuffer) +from vllm.envs import VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME from vllm.logger import init_logger -from vllm.utils import GiB_bytes, LRUCache +from vllm.utils import GiB_bytes, LRUCache, MiB_bytes from vllm.utils.jsontree import (json_count_leaves, json_map_leaves, json_reduce_leaves) -from .inputs import (MultiModalFeatureSpec, MultiModalFieldElem, - MultiModalKwargs, MultiModalKwargsItem, - MultiModalKwargsItems, NestedTensors) +from .inputs import (MultiModalBatchedField, MultiModalFeatureSpec, + MultiModalFieldElem, MultiModalKwargs, + MultiModalKwargsItem, MultiModalKwargsItems, + NestedTensors) if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig @@ -86,26 +92,15 @@ _V = TypeVar("_V", bound=MultiModalCacheValue) class MultiModalCache: @classmethod - def get_leaf_size( - cls, - leaf: object, - *, - debug: bool = False, - ) -> int: + def get_leaf_size(cls, leaf: object) -> int: if isinstance(leaf, MultiModalProcessorCacheItem): return cls.get_leaf_size(leaf.item) if isinstance(leaf, MultiModalProcessorCacheItemMetadata): return leaf.item_size # These are not subclasses of dict - if isinstance(leaf, MultiModalKwargsItems): - return cls.get_item_size(leaf.data) # type: ignore - if isinstance(leaf, MultiModalKwargsItem): - return cls.get_item_size(leaf.data) # type: ignore - if isinstance(leaf, MultiModalKwargs): - return cls.get_item_size(leaf.data) # type: ignore - - if isinstance(leaf, MultiModalFieldElem): + if isinstance(leaf, (MultiModalKwargs, MultiModalKwargsItems, + MultiModalKwargsItem, MultiModalFieldElem)): return cls.get_item_size(leaf.data) # type: ignore # sys.getsizeof doesn't work for tensors @@ -121,11 +116,8 @@ class MultiModalCache: *, debug: bool = False, ) -> int: - size = json_reduce_leaves( - lambda a, b: a + b, - json_map_leaves(lambda x: cls.get_leaf_size(x, debug=debug), - value), - ) + size = json_reduce_leaves(operator.add, + json_map_leaves(cls.get_leaf_size, value)) if debug: leaf_count = json_count_leaves(value) @@ -389,6 +381,106 @@ class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache): self._cache.clear() +class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache): + """ + The cache which is used on P0 when IPC caching is enabled. + + How to update each item: + + - If the item is already in the cache, clear the input to avoid + unnecessary IPC. + + - If the item is not in the cache, store the data in shared memory. 
+ """ + + def __init__(self, vllm_config: "VllmConfig") -> None: + super().__init__() + + self.world_size = vllm_config.parallel_config.world_size + mm_config = vllm_config.model_config.get_multimodal_config() + + ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes), + name=VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME, + create=True, # sender is the writer + ) + self._shm_cache = SingleWriterShmObjectStorage( + max_object_size=mm_config.mm_shm_cache_max_object_size_mb * + MiB_bytes, + n_readers=self.world_size, + ring_buffer=ring_buffer, + serde_class=MsgpackSerde, + ) + # cache (prompt_updates, modality) for P0 only + self._p0_cache: dict[str, tuple[Sequence[ResolvedPromptUpdate], + str]] = {} + + @override + def is_cached_item(self, mm_hash: str) -> bool: + return self._shm_cache.is_cached(mm_hash) + + @override + def get_and_update_item( + self, + mm_item: MultiModalProcessorCacheInItem, + mm_hash: str, + ) -> MultiModalProcessorCacheOutItem: + + if self._shm_cache.is_cached(mm_hash): + address, monotonic_id = self._shm_cache.get_cached(mm_hash) + prompt_updates, modality = self._p0_cache[mm_hash] + return self.address_as_item(address, monotonic_id, + modality), prompt_updates + + assert mm_item is not None, f"Expected a cached item for {mm_hash=}" + + try: + address, monotonic_id = self._shm_cache.put(mm_hash, mm_item[0]) + # Try to remove dangling items if p0 cache is too large. + if len(self._p0_cache) >= 2 * len(self._shm_cache.key_index): + self.remove_dangling_items() + self._p0_cache[mm_hash] = mm_item[1], mm_item[0].modality + address_item = self.address_as_item(address, monotonic_id, + mm_item[0].modality) + return address_item, mm_item[1] + except (ValueError, MemoryError) as e: + # put may fail if the object is too large or + # the cache is full. + # In this case we log the error and keep the original mm_input. 
+ logger.debug("Failed to cache mm_input with hash %s: %s", mm_hash, + e) + return mm_item + + @override + def clear_cache(self) -> None: + self._shm_cache.clear() + self._p0_cache.clear() + + def remove_dangling_items(self) -> None: + """Remove items that are no longer in the shared memory cache.""" + cached_hashes = self._shm_cache.key_index.keys() + dangling_hashes = set(self._p0_cache.keys()) - cached_hashes + for mm_hash in dangling_hashes: + del self._p0_cache[mm_hash] + + def address_as_item(self, address: int, monotonic_id: int, + modality: str) -> MultiModalKwargsItem: + addr_elem = MultiModalFieldElem( + modality=modality, + key="address", + data=address, + field=MultiModalBatchedField(), + ) + id_elem = MultiModalFieldElem( + modality=modality, + key="monotonic_id", + data=monotonic_id, + field=MultiModalBatchedField(), + ) + mm_item = MultiModalKwargsItem.from_elems([addr_elem, id_elem]) + return mm_item + + def _enable_processor_cache( model_config: "ModelConfig", mm_registry: "MultiModalRegistry", @@ -408,6 +500,17 @@ def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool: return supports_ipc_cache +def _enable_mm_input_shm_cache(vllm_config: "VllmConfig") -> bool: + """Whether the shared memory based cache should be enabled.""" + + if not _enable_ipc_cache(vllm_config): + return False + + mm_config = vllm_config.model_config.get_multimodal_config() + + return mm_config.mm_processor_cache_type == "shm" + + def processor_cache_from_config( vllm_config: "VllmConfig", mm_registry: "MultiModalRegistry", @@ -421,7 +524,9 @@ def processor_cache_from_config( if not _enable_ipc_cache(vllm_config): return MultiModalProcessorOnlyCache(model_config) - return MultiModalProcessorSenderCache(model_config) + if not _enable_mm_input_shm_cache(vllm_config): + return MultiModalProcessorSenderCache(model_config) + return ShmObjectStoreSenderCache(vllm_config) def processor_only_cache_from_config( @@ -491,11 +596,68 @@ class MultiModalReceiverCache(BaseMultiModalReceiverCache): self._cache.clear() -def receiver_cache_from_config( +class ShmObjectStoreReceiverCache(BaseMultiModalReceiverCache): + """ + The cache which is used on P1 Worker Process when IPC caching is enabled. + + How to update each item: + + - If the item has an address, replace the input with the cached item. + - If not, return the input. 
+ """ + + def __init__( + self, + vllm_config: "VllmConfig", + shared_worker_lock: LockType, + ) -> None: + super().__init__() + + self.world_size = vllm_config.parallel_config.world_size + mm_config = vllm_config.model_config.get_multimodal_config() + + ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes), + name=VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME, + create=False, # Server is a reader + ) + self._shm_cache = SingleWriterShmObjectStorage( + max_object_size=mm_config.mm_shm_cache_max_object_size_mb * + MiB_bytes, + n_readers=self.world_size, + ring_buffer=ring_buffer, + serde_class=MsgpackSerde, + reader_lock=shared_worker_lock, + ) + + @override + def get_and_update_item( + self, + mm_item: Optional[MultiModalKwargsItem], + mm_hash: str, + ) -> MultiModalKwargsItem: + assert mm_item is not None, f"Expected an address item for {mm_hash=}" + if "address" in mm_item: + address = cast(int, mm_item["address"].data) + monotonic_id = cast(int, mm_item["monotonic_id"].data) + return self._shm_cache.get(address, monotonic_id) + + return mm_item + + @override + def clear_cache(self) -> None: + self._shm_cache.clear() + + +def engine_receiver_cache_from_config( vllm_config: "VllmConfig", mm_registry: "MultiModalRegistry", ) -> Optional[BaseMultiModalReceiverCache]: - """Return a `BaseMultiModalReceiverCache`, if enabled.""" + """ + This is used in the engine process. + Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and + mm_processor_cache_type=="lru". + """ model_config = vllm_config.model_config if not _enable_processor_cache(model_config, mm_registry): @@ -504,4 +666,31 @@ def receiver_cache_from_config( if not _enable_ipc_cache(vllm_config): return None - return MultiModalReceiverCache(model_config) + if not _enable_mm_input_shm_cache(vllm_config): + return MultiModalReceiverCache(model_config) + + return None + + +def worker_receiver_cache_from_config( + vllm_config: "VllmConfig", + mm_registry: "MultiModalRegistry", + shared_worker_lock: LockType, +) -> Optional[BaseMultiModalReceiverCache]: + """ + This is used in the worker process. + Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and + mm_processor_cache_type=="shm". 
+ """ + model_config = vllm_config.model_config + + if not _enable_processor_cache(model_config, mm_registry): + return None + + if not _enable_ipc_cache(vllm_config): + return None + + if not _enable_mm_input_shm_cache(vllm_config): + return None + + return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock) diff --git a/vllm/multimodal/hasher.py b/vllm/multimodal/hasher.py index da019d40a6fe4..df6c531d876ad 100644 --- a/vllm/multimodal/hasher.py +++ b/vllm/multimodal/hasher.py @@ -12,7 +12,6 @@ from blake3 import blake3 from PIL import Image from vllm.logger import init_logger -from vllm.multimodal.image import convert_image_mode logger = init_logger(__name__) @@ -20,23 +19,27 @@ logger = init_logger(__name__) class MultiModalHasher: @classmethod - def serialize_item(cls, obj: object) -> Union[bytes, memoryview]: + def serialize_item(cls, obj: object) -> Iterable[Union[bytes, memoryview]]: # Simple cases - if isinstance(obj, str): - return obj.encode("utf-8") if isinstance(obj, (bytes, memoryview)): - return obj + return (obj, ) + if isinstance(obj, str): + return (obj.encode("utf-8"), ) if isinstance(obj, (int, float)): - return np.array(obj).tobytes() + return (np.array(obj).tobytes(), ) if isinstance(obj, Image.Image): exif = obj.getexif() if Image.ExifTags.Base.ImageID in exif and isinstance( exif[Image.ExifTags.Base.ImageID], uuid.UUID): # If the image has exif ImageID tag, use that - return exif[Image.ExifTags.Base.ImageID].bytes - return cls.item_to_bytes( - "image", np.asarray(convert_image_mode(obj, "RGBA"))) + return (exif[Image.ExifTags.Base.ImageID].bytes, ) + data = {"mode": obj.mode, "data": np.asarray(obj)} + if obj.palette is not None: + data["palette"] = obj.palette.palette + if obj.palette.rawmode is not None: + data["palette_rawmode"] = obj.palette.rawmode + return cls.iter_item_to_bytes("image", data) if isinstance(obj, torch.Tensor): tensor_obj: torch.Tensor = obj.cpu() tensor_dtype = tensor_obj.dtype @@ -49,43 +52,34 @@ class MultiModalHasher: tensor_obj = tensor_obj.view( (tensor_obj.numel(), )).view(torch.uint8) - return cls.item_to_bytes( + return cls.iter_item_to_bytes( "tensor", { "original_dtype": str(tensor_dtype), "original_shape": tuple(tensor_shape), "data": tensor_obj.numpy(), }) - - return cls.item_to_bytes("tensor", tensor_obj.numpy()) + return cls.iter_item_to_bytes("tensor", tensor_obj.numpy()) if isinstance(obj, np.ndarray): # If the array is non-contiguous, we need to copy it first - arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes() - return cls.item_to_bytes("ndarray", { + arr_data = obj.view( + np.uint8).data if obj.flags.c_contiguous else obj.tobytes() + return cls.iter_item_to_bytes("ndarray", { "dtype": obj.dtype.str, "shape": obj.shape, "data": arr_data, }) - logger.warning( "No serialization method found for %s. 
" "Falling back to pickle.", type(obj)) - return pickle.dumps(obj) - - @classmethod - def item_to_bytes( - cls, - key: str, - obj: object, - ) -> bytes: - return b''.join(kb + vb for kb, vb in cls.iter_item_to_bytes(key, obj)) + return (pickle.dumps(obj), ) @classmethod def iter_item_to_bytes( cls, key: str, obj: object, - ) -> Iterable[tuple[bytes, Union[bytes, memoryview]]]: + ) -> Iterable[Union[bytes, memoryview]]: # Recursive cases if isinstance(obj, (list, tuple)): for i, elem in enumerate(obj): @@ -94,17 +88,15 @@ class MultiModalHasher: for k, v in obj.items(): yield from cls.iter_item_to_bytes(f"{key}.{k}", v) else: - key_bytes = key.encode("utf-8") - value_bytes = cls.serialize_item(obj) - yield key_bytes, value_bytes + yield key.encode("utf-8") + yield from cls.serialize_item(obj) @classmethod def hash_kwargs(cls, **kwargs: object) -> str: hasher = blake3() for k, v in kwargs.items(): - for k_bytes, v_bytes in cls.iter_item_to_bytes(k, v): - hasher.update(k_bytes) - hasher.update(v_bytes) + for bytes_ in cls.iter_item_to_bytes(k, v): + hasher.update(bytes_) return hasher.hexdigest() diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index f8ea3835f049d..240e34e139cfe 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -85,9 +85,10 @@ which are treated as audio embeddings; these are directly passed to the model without HF processing. """ -ModalityData: TypeAlias = Union[_T, list[_T]] +ModalityData: TypeAlias = Union[_T, list[Optional[_T]], None] """ -Either a single data item, or a list of data items. +Either a single data item, or a list of data items. Can only be None if UUID +is provided. The number of data items allowed per modality is restricted by `--limit-mm-per-prompt`. diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py index 88bb99529f200..493dd3560a516 100644 --- a/vllm/multimodal/parse.py +++ b/vllm/multimodal/parse.py @@ -36,7 +36,7 @@ class ModalityDataItems(ABC, Generic[_T, _I]): def __init__(self, data: _T, modality: str) -> None: super().__init__() - self.data = data + self.data: _T = data self.modality = modality def __repr__(self) -> str: @@ -177,7 +177,9 @@ class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor], class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]): - def __init__(self, data: Sequence[HfAudioItem]) -> None: + def __init__(self, data: Optional[Sequence[HfAudioItem]]) -> None: + if data is None: + data = [None] super().__init__(data, "audio") def get_audio_length(self, item_idx: int) -> int: @@ -198,7 +200,9 @@ class ImageSize(NamedTuple): class ImageProcessorItems(ProcessorBatchItems[HfImageItem]): - def __init__(self, data: Sequence[HfImageItem]) -> None: + def __init__(self, data: Optional[Sequence[HfImageItem]]) -> None: + if data is None: + data = [None] super().__init__(data, "image") def get_image_size(self, item_idx: int) -> ImageSize: @@ -223,10 +227,12 @@ class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]): def __init__( self, - data: Sequence[HfVideoItem], + data: Optional[Sequence[HfVideoItem]], metadata: Optional[Union[dict[str, Any], list[Optional[dict[str, Any]]]]] = None, ) -> None: + if data is None: + data = [None] super().__init__(data, "video") self.metadata = metadata @@ -385,6 +391,9 @@ class MultiModalDataParser: self, data: ModalityData[AudioItem], ) -> Optional[ModalityDataItems[Any, Any]]: + if data is None: + return AudioProcessorItems(None) + # also check single audio item with sampling rate if self._is_empty(data) or (isinstance(data, 
tuple) and self._is_empty(data[0])): @@ -420,6 +429,9 @@ class MultiModalDataParser: self, data: ModalityData[ImageItem], ) -> Optional[ModalityDataItems[Any, Any]]: + if data is None: + return ImageProcessorItems(None) + if self._is_empty(data): return None @@ -441,6 +453,9 @@ class MultiModalDataParser: self, data: ModalityData[VideoItem], ) -> Optional[ModalityDataItems[Any, Any]]: + if data is None: + return VideoProcessorItems(None) + if self._is_empty(data): return None diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 0531b7bd9f0a7..7471bfcb4d508 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1022,13 +1022,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalInputs: return self.apply(prompt, mm_data, hf_processor_mm_kwargs, - mm_hash_overrides=mm_hash_overrides) + mm_uuids=mm_uuids) def _get_data_parser(self) -> MultiModalDataParser: """ @@ -1076,7 +1075,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data]. """ mm_items = self.data_parser.parse_mm_data(mm_data) - for modality, items in mm_items.items(): self.validate_num_items(modality, len(items)) @@ -1364,8 +1362,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalHashes: """Create MM hashes to be returned (only used in V1). @@ -1376,30 +1373,30 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): model_id = self.info.model_id hashes: MultiModalHashes = {} - mm_hash_overrides = mm_hash_overrides or {} + mm_uuids = mm_uuids or {} for modality, items in mm_items.items(): - if modality in mm_hash_overrides: - mm_hashes = mm_hash_overrides[modality] - if isinstance(mm_hashes, str): - mm_hashes = [mm_hashes] + if modality in mm_uuids: + mm_uuids_per_modality = mm_uuids[modality] + if isinstance(mm_uuids_per_modality, str): + mm_uuids_per_modality = [mm_uuids_per_modality] # For None entries, compute a hash; otherwise, use provided ID. computed: list[str] = [] for i, item in enumerate(items): - mm_hash = mm_hashes[i] + item_uuid = mm_uuids_per_modality[i] - # NOTE: Even if a mm_hash is provided, we still compute a + # NOTE: Even if a item_uuid is provided, we still compute a # hash if `hf_processor_mm_kwargs` or `tokenization_kwargs` # are provided. This is because the processed multimodal # inputs can be different depending on the processor kwargs. - if mm_hash is None or \ + if item_uuid is None or \ hf_processor_mm_kwargs or \ tokenization_kwargs: # NOTE: use provided hash string to hash with kwargs # if available for better performance. 
- item = mm_hash if mm_hash is not None else item + item = item_uuid if item_uuid is not None else item computed.append( MultiModalHasher.hash_kwargs( model_id=model_id, @@ -1407,7 +1404,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): **hf_processor_mm_kwargs, **tokenization_kwargs)) else: - computed.append(mm_hash) + computed.append(item_uuid) hashes[modality] = computed else: hashes[modality] = [ @@ -1438,10 +1435,18 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ] for modality, items_is_cached in mm_is_cached.items() } - mm_missing_data = { - modality: [mm_data_items[modality][idx] for idx in idxs] - for modality, idxs in mm_missing_idxs.items() - } + mm_missing_data = {} + for modality, idxs in mm_missing_idxs.items(): + missing_modality_data = [] + for idx in idxs: + data = mm_data_items[modality][idx] + if data is None: + raise ValueError( + f"Cache miss for {modality} at index {idx} " + f"but data is not provided.") + else: + missing_modality_data.append(data) + mm_missing_data[modality] = missing_modality_data return self._to_mm_items(mm_missing_data) @@ -1514,8 +1519,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: ( prompt_ids, @@ -1539,7 +1543,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides) + mm_uuids=mm_uuids) mm_prompt_updates = self._get_mm_prompt_updates( mm_data_items, @@ -1562,8 +1566,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: """ Apply the HF processor on the full prompt text, @@ -1578,13 +1581,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides) + mm_uuids=mm_uuids) mm_missing_data_items = self._get_cache_missing_items( cache=cache, @@ -1785,8 +1788,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalInputs: """ Process multi-modal inputs to be used in vLLM. 
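# --- Hedged sketch of the UUID-override rule used by _hash_mm_items above: a
# caller-supplied uuid is only trusted verbatim when there are no extra
# processor/tokenization kwargs; otherwise the uuid (or the raw item) is
# hashed together with those kwargs, since they can change the processed
# output. hash_item below is a stand-in for MultiModalHasher.hash_kwargs.
import hashlib
from typing import Any, Optional


def hash_item(**kwargs: Any) -> str:
    return hashlib.sha256(repr(sorted(kwargs.items())).encode()).hexdigest()


def resolve_hash(
    item: Optional[bytes],
    item_uuid: Optional[str],
    processor_kwargs: dict[str, Any],
) -> str:
    if item_uuid is not None and not processor_kwargs:
        # Safe to reuse the caller-provided id directly.
        return item_uuid
    keyed = item_uuid if item_uuid is not None else item
    if keyed is None:
        # Mirrors the cache-miss guard above: uuid-only inputs cannot be
        # reprocessed when their data was never provided.
        raise ValueError("item has neither data nor a UUID")
    return hash_item(data=keyed, **processor_kwargs)


assert resolve_hash(None, "img-123", {}) == "img-123"
assert resolve_hash(None, "img-123", {"min_pixels": 256}) != "img-123"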
@@ -1815,7 +1817,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_items, hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) # NOTE: tokenization_kwargs are not required to init processor @@ -1901,8 +1903,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, *, - mm_hash_overrides: Optional[Union[dict[str, list[str]], - MultiModalUUIDDict]] = None, + mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalEncDecInputs: """ Process multi-modal inputs to be used in vLLM. @@ -1917,7 +1918,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): mm_data, hf_processor_mm_kwargs, tokenization_kwargs, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) return self._get_enc_dec_inputs( diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index ffc69a2db60a4..fbbc55d3524ca 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -209,7 +209,7 @@ class MultiModalProfiler(Generic[_I]): if processor.pad_dummy_encoder_prompt: num_tokens_to_pad = max(total_len, seq_len) - total_len encoder_prompt_token_ids.extend([0] * num_tokens_to_pad) - # NOTE: Whisper and Donut allows total_len > seq_len. + # NOTE: Whisper allows total_len > seq_len. elif total_len > seq_len and not envs.VLLM_USE_V1: # `max_num_batched_tokens` is defined by `SchedulerConfig` logger.warning_once( @@ -301,7 +301,7 @@ class MultiModalProfiler(Generic[_I]): Returns the maximum length of the multimodal (image placeholders+text) tokens, including any break/text tokens in-between image embeddings. - <im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end> + `<im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>` Returns 9, even when the number of image embeddings is 6. This is important to take into account when profiling and diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index ac967dcc4003e..b308366fca282 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -213,7 +213,7 @@ class MediaConnector: image_mode: str = "RGB", ) -> Image.Image: """ - Load a PIL image from a HTTP or base64 data URL. + Load a PIL image from an HTTP or base64 data URL. By default, the image is converted into RGB format. """ @@ -237,7 +237,7 @@ class MediaConnector: image_mode: str = "RGB", ) -> Image.Image: """ - Asynchronously load a PIL image from a HTTP or base64 data URL. + Asynchronously load a PIL image from an HTTP or base64 data URL. By default, the image is converted into RGB format. """ @@ -261,7 +261,7 @@ class MediaConnector: image_mode: str = "RGB", ) -> tuple[npt.NDArray, dict[str, Any]]: """ - Load video from a HTTP or base64 data URL. + Load video from an HTTP or base64 data URL. """ image_io = ImageMediaIO(image_mode=image_mode, **self.media_io_kwargs.get("image", {})) @@ -281,7 +281,7 @@ class MediaConnector: image_mode: str = "RGB", ) -> tuple[npt.NDArray, dict[str, Any]]: """ - Asynchronously load video from a HTTP or base64 data URL. + Asynchronously load video from an HTTP or base64 data URL. By default, the image is converted into RGB format. 
""" @@ -310,7 +310,7 @@ class MediaConnector: def encode_audio_base64( audio: np.ndarray, - sampling_rate: float, + sampling_rate: int, ) -> str: """Encode audio as base64.""" audio_io = AudioMediaIO() @@ -370,7 +370,7 @@ def group_mm_inputs_by_modality( def modality_group_func( mm_input: MultiModalKwargsItems) -> Union[str, int]: - # If the input has multiple modalities, return a id as the unique key + # If the input has multiple modalities, return an id as the unique key # for the mm_input input. if len(mm_input) > 1: return id(mm_input) @@ -378,10 +378,7 @@ def group_mm_inputs_by_modality( elif len(mm_input) == 1: return next(iter(mm_input.keys())) - # FIXME(Isotr0py): Modality of mm_input from legacy pipeline is empty, - # this is used to make InternVL with legacy pipeline still work with v1. - else: - return "" + raise AssertionError("This line should be unreachable.") return [ list(group) for _, group in groupby(mm_inputs, key=modality_group_func) diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index ef1380bdb614c..6981f2ce56239 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -1,12 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import base64 +import math from abc import abstractmethod from functools import partial from io import BytesIO from pathlib import Path -from typing import Any +from typing import Any, Union import numpy as np import numpy.typing as npt @@ -104,10 +104,12 @@ class OpenCVVideoBackend(VideoLoader): return api_pref @classmethod - def load_bytes(cls, - data: bytes, - num_frames: int = -1, - **kwargs) -> tuple[npt.NDArray, dict[str, Any]]: + def load_bytes( + cls, + data: bytes, + num_frames: int = -1, + **kwargs, + ) -> tuple[npt.NDArray, dict[str, Any]]: import cv2 backend = cls().get_cv2_video_api() @@ -119,6 +121,7 @@ class OpenCVVideoBackend(VideoLoader): original_fps = cap.get(cv2.CAP_PROP_FPS) duration = total_frames_num / original_fps if original_fps > 0 else 0 + # resample video to target num_frames full_read = num_frames == -1 or total_frames_num < num_frames if full_read: num_frames = total_frames_num @@ -148,12 +151,101 @@ class OpenCVVideoBackend(VideoLoader): assert i == num_frames, (f"Expected reading {num_frames} frames, " f"but only loaded {i} frames from video.") + # Use transformers transformers.video_utils.VideoMetadata format + # NOTE(Isotr0py): For models like Qwen3-VL/GLM4.5V, this metadata + # can cause incorrect timestamp calculation without num_frames=-1. 
+ metadata = { + "total_num_frames": num_frames, + "fps": num_frames / duration, + "duration": duration, + "video_backend": "opencv", + "frames_indices": list(range(num_frames)), + # extra field used to control hf processor's video + # sampling behavior + "do_sample_frames": num_frames == total_frames_num, + } + + return frames, metadata + + +@VIDEO_LOADER_REGISTRY.register("opencv_dynamic") +class OpenCVDynamicVideoBackend(OpenCVVideoBackend): + + @classmethod + def load_bytes( + cls, + data: bytes, + num_frames: int = -1, + fps: int = 2, + max_duration: int = 300, + **kwargs, + ) -> tuple[npt.NDArray, dict[str, Any]]: + import cv2 + + backend = cls().get_cv2_video_api() + cap = cv2.VideoCapture(BytesIO(data), backend, []) + if not cap.isOpened(): + raise ValueError("Could not open video stream") + + total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + original_fps = cap.get(cv2.CAP_PROP_FPS) + duration = total_frames_num / original_fps if original_fps > 0 else 0 + + # resample video to target num_frames + max_frame_idx = total_frames_num - 1 + duration = duration or round(max_frame_idx / original_fps) + 1 + + # Refer to: + # https://github.com/huggingface/transformers/blob/v4.55.4/src/transformers/models/glm4v/video_processing_glm4v.py#L103-L140 + frame_indices: Union[range, list[int]] + if duration <= max_duration: + n = int(math.floor(duration * fps)) + frame_indices = sorted({ + min(max_frame_idx, int(math.ceil(i * original_fps / fps))) + for i in range(n) + }) + else: + num_samples = int(max_duration * fps) + if num_samples >= total_frames_num: + frame_indices = range(total_frames_num) + else: + target_seconds = np.linspace(0, + duration, + num_samples, + endpoint=True) + frame_indices = sorted({ + min(max_frame_idx, int(math.ceil(t * original_fps))) + for t in target_seconds + }) + + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + frames = np.empty((len(frame_indices), height, width, 3), + dtype=np.uint8) + + i = 0 + for idx in range(total_frames_num): + ok = cap.grab() + if not ok: + break + if idx in frame_indices: + ret, frame = cap.retrieve() + if ret: + frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + i += 1 + + assert i == len(frame_indices), ( + f"Expected reading {len(frame_indices)} frames, " + f"but only loaded {i} frames from video.") + # Use transformers transformers.video_utils.VideoMetadata format metadata = { "total_num_frames": total_frames_num, "fps": original_fps, "duration": duration, - "video_backend": "opencv" + "video_backend": "opencv_dynamic", + "frames_indices": list(frame_indices), + "do_sample_frames": False, } return frames, metadata diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 56edb8629e45b..9b64817da648c 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -169,37 +169,12 @@ def cpu_platform_plugin() -> Optional[str]: return "vllm.platforms.cpu.CpuPlatform" if is_cpu else None -def neuron_platform_plugin() -> Optional[str]: - tnx_installed = False - nxd_installed = False - logger.debug("Checking if Neuron platform is available.") - try: - import transformers_neuronx # noqa: F401 - tnx_installed = True - logger.debug("Confirmed Neuron platform is available because" - " transformers_neuronx is found.") - except ImportError: - pass - - try: - import neuronx_distributed_inference # noqa: F401 - nxd_installed = True - logger.debug("Confirmed Neuron platform is available because" - " neuronx_distributed_inference is found.") - except 
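# --- Standalone sketch of the frame-index selection rule used by the new
# "opencv_dynamic" loader above (index math only, no OpenCV): sample roughly
# `fps` frames per second for clips up to max_duration seconds, otherwise
# spread a fixed budget of max_duration * fps samples evenly over the clip.
import math

import numpy as np


def dynamic_frame_indices(total_frames: int,
                          original_fps: float,
                          fps: int = 2,
                          max_duration: int = 300) -> list[int]:
    if original_fps <= 0:
        raise ValueError("original_fps must be positive")
    max_frame_idx = total_frames - 1
    duration = total_frames / original_fps
    if duration <= max_duration:
        n = int(math.floor(duration * fps))
        return sorted({
            min(max_frame_idx, int(math.ceil(i * original_fps / fps)))
            for i in range(n)
        })
    num_samples = int(max_duration * fps)
    if num_samples >= total_frames:
        return list(range(total_frames))
    target_seconds = np.linspace(0, duration, num_samples, endpoint=True)
    return sorted({
        min(max_frame_idx, int(math.ceil(t * original_fps)))
        for t in target_seconds
    })


# A 10 s clip at 30 fps sampled at 2 fps yields ~20 frame indices.
print(len(dynamic_frame_indices(total_frames=300, original_fps=30.0)))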
ImportError: - pass - - is_neuron = tnx_installed or nxd_installed - return "vllm.platforms.neuron.NeuronPlatform" if is_neuron else None - - builtin_platform_plugins = { 'tpu': tpu_platform_plugin, 'cuda': cuda_platform_plugin, 'rocm': rocm_platform_plugin, 'xpu': xpu_platform_plugin, 'cpu': cpu_platform_plugin, - 'neuron': neuron_platform_plugin, } diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 12d5e0bf08652..544e091491bf5 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -75,12 +75,12 @@ class CpuPlatform(Platform): def supported_dtypes(self) -> list[torch.dtype]: if self.get_cpu_architecture() == CpuArchEnum.POWERPC: return [torch.bfloat16, torch.float32] - elif sys.platform.startswith( - "darwin") and self.get_cpu_architecture() == CpuArchEnum.ARM: - # TODO: change this condition to check if the platform support bf16 - # instead of checking the OS. For instance M2 shall supports bf16 - # already. But we need to modify `cpu_extension.cmake` to activate - # the feature in the build. + elif (self.get_cpu_architecture() == CpuArchEnum.ARM + and sys.platform.startswith("darwin")): + if (subprocess.check_output( + ["sysctl -n hw.optional.arm.FEAT_BF16"], + shell=True).strip() == b"1"): + return [torch.bfloat16, torch.float16, torch.float32] return [torch.float16, torch.float32] # x86/aarch64 CPU has supported both bf16 and fp16 natively. return [torch.bfloat16, torch.float16, torch.float32] @@ -185,6 +185,11 @@ class CpuPlatform(Platform): parallel_config.distributed_executor_backend = "mp" if parallel_config.worker_cls == "auto": parallel_config.worker_cls = "vllm.v1.worker.cpu_worker.CPUWorker" + # Disable DBO + if parallel_config.enable_dbo: + logger.warning( + "Dual-Batch Overlap is not supported on CPU, disabled.") + parallel_config.enable_dbo = False # Note: workaround for v1 gpu_model_runner from vllm.config import CompilationLevel @@ -347,3 +352,7 @@ class CpuPlatform(Platform): @classmethod def opaque_attention_op(cls) -> bool: return True + + @classmethod + def support_hybrid_kv_cache(cls) -> bool: + return True diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 5cbb7346436ef..8e3436a9e73c5 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -64,8 +64,7 @@ class CudaPlatformBase(Platform): if self.has_device_capability(80): # Ampere and Hopper or later NVIDIA GPUs. return [torch.bfloat16, torch.float16, torch.float32] - elif (not self.has_device_capability(80) - ) and self.has_device_capability(60): + if self.has_device_capability(60): # Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported return [torch.float16, torch.float32] # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported, @@ -146,6 +145,7 @@ class CudaPlatformBase(Platform): # required block_size. 
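# --- Minimal sketch of the Apple-silicon bf16 probe added to CpuPlatform
# above: ask sysctl whether the CPU reports the ARM FEAT_BF16 extension and
# widen the supported dtype list accordingly. Only meaningful on macOS/arm64,
# hence the guard; the subprocess invocation here is the list form rather
# than the shell string used in the diff.
import platform
import subprocess
import sys

import torch


def cpu_supported_dtypes() -> list[torch.dtype]:
    if sys.platform.startswith("darwin") and platform.machine() == "arm64":
        out = subprocess.check_output(
            ["sysctl", "-n", "hw.optional.arm.FEAT_BF16"]).strip()
        if out == b"1":
            return [torch.bfloat16, torch.float16, torch.float32]
        return [torch.float16, torch.float32]
    # x86/aarch64 Linux: assume both bf16 and fp16 are handled natively.
    return [torch.bfloat16, torch.float16, torch.float32]


print(cpu_supported_dtypes())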
use_flashmla = False use_cutlass_mla = False + use_flashinfer_mla = False if envs.VLLM_ATTENTION_BACKEND is None: # Default case @@ -164,6 +164,8 @@ class CudaPlatformBase(Platform): use_flashmla = (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA") use_cutlass_mla = ( envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA") + use_flashinfer_mla = ( + envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA") from vllm.attention.ops.flashmla import is_flashmla_supported if use_flashmla and is_flashmla_supported()[0] \ @@ -177,22 +179,26 @@ class CudaPlatformBase(Platform): logger.info("Forcing kv cache block size to 128 for " "CUTLASS_MLA backend.") + if use_flashinfer_mla and cache_config.block_size not in [32, 64]: + cache_config.block_size = 64 + logger.info( + "Forcing kv cache block size to 64 for FlashInferMLA " + "backend.") + # lazy import to avoid circular import from vllm.config import CUDAGraphMode compilation_config = vllm_config.compilation_config if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" and parallel_config.data_parallel_size > 1 - and compilation_config.cudagraph_mode != CUDAGraphMode.NONE): + and compilation_config.cudagraph_mode + not in [CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE]): logger.info( - "Data Parallel: disabling cudagraphs since DP " - "with DeepEP high-throughput kernels are not CUDA Graph " - "compatible. The DeepEP low-latency kernels are CUDA Graph " - "compatible. Set the all_to_all backend to deepep_low_latency " - "to use those kernels instead.") - compilation_config.cudagraph_mode = CUDAGraphMode.NONE - if model_config is not None: - model_config.enforce_eager = True + "Data Parallel with DeepEP high-throughput: using PIECEWISE " + "CUDA graphs and excluding MoE ops from capture. Set " + "VLLM_ALL2ALL_BACKEND=deepep_low_latency if you need MoE " + "graphs captured as well.") + compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE @classmethod def get_current_memory_usage(cls, @@ -203,18 +209,24 @@ class CudaPlatformBase(Platform): return torch.cuda.max_memory_allocated(device) @classmethod - def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend: - if cls.has_device_capability(80) and support_fa: - from transformers.utils import is_flash_attn_2_available - if is_flash_attn_2_available(): + def get_vit_attn_backend(cls, head_size: int, + dtype: torch.dtype) -> _Backend: + if dtype not in (torch.float16, torch.bfloat16): + return _Backend.XFORMERS + + if cls.has_device_capability(80): + FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 + from vllm.attention.selector import is_attn_backend_supported + is_default_fa_supported = is_attn_backend_supported( + FLASH_ATTN_V1, head_size, dtype, allow_import_error=False) + if is_default_fa_supported: return _Backend.FLASH_ATTN - logger.warning_once( - "Current `vllm-flash-attn` has a bug inside vision " - "module, so we use xformers backend instead. 
You can " - "run `pip install flash-attn` to use flash-attention " - "backend.") - # Fallback for Volta/Turing GPUs or FA not supported - return _Backend.XFORMERS + else: + # Fallback to XFORMERS + return _Backend.XFORMERS + else: + # Fallback for Volta/Turing GPUs or FA not supported + return _Backend.XFORMERS @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, @@ -223,9 +235,33 @@ class CudaPlatformBase(Platform): if use_mla: # TODO(lucas): refactor to be more concise # we should probably consider factoring out V1 here - if selected_backend == _Backend.CUTLASS_MLA or ( - cls.is_device_capability(100) and selected_backend is None - and block_size == 128): + + from vllm.attention.ops.flashmla import is_flashmla_supported + from vllm.attention.utils.fa_utils import flash_attn_supports_mla + + use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or ( + selected_backend is None and cls.is_device_capability(100) + and block_size == 128) + use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or ( + selected_backend is None and cls.is_device_capability(100) + and block_size in [32, 64]) + use_flashmla = selected_backend in [ + _Backend.FLASHMLA, _Backend.FLASHMLA_VLLM_V1 + ] or (selected_backend is None and is_flashmla_supported()[0]) + use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or ( + selected_backend is None and flash_attn_supports_mla()) + use_triton = selected_backend == _Backend.TRITON_MLA or ( + selected_backend is None) + + def _get_version(name, import_suffix) -> str: + if use_v1: + logger.info_once(f"Using {name} backend on V1 engine.") + return f"vllm.v1.attention.backends.mla.{import_suffix}" + else: + logger.info_once(f"Using {name} backend.") + return f"vllm.attention.backends.{import_suffix}" + + if use_cutlassmla: if use_v1: logger.info_once("Using Cutlass MLA backend on V1 engine.") return ("vllm.v1.attention.backends.mla." @@ -233,36 +269,40 @@ class CudaPlatformBase(Platform): else: logger.warning( "Cutlass MLA backend is only supported on V1 engine") - if selected_backend == _Backend.TRITON_MLA or block_size != 64: + if use_flashinfermla: if use_v1: - logger.info_once("Using Triton MLA backend on V1 engine.") + from vllm.v1.attention.backends.utils import ( + set_kv_cache_layout) + set_kv_cache_layout("HND") + logger.info_once( + "Using FlashInfer MLA backend on V1 engine.") return ("vllm.v1.attention.backends.mla." - "triton_mla.TritonMLABackend") + "flashinfer_mla.FlashInferMLABackend") else: - logger.info("Using Triton MLA backend.") - return "vllm.attention.backends.triton_mla.TritonMLABackend" - else: - from vllm.attention.backends.flashmla import ( - is_flashmla_supported) - if not is_flashmla_supported()[0]: logger.warning( - "FlashMLA backend is not supported due to %s", - is_flashmla_supported()[1]) - elif block_size != 64: + "FlashInfer MLA backend is only supported on V1 engine" + ) + if use_flashmla: + if block_size != 64: logger.warning( "FlashMLA backend is not supported for block size %d" " (currently only supports block size 64).", block_size) else: - if use_v1: - logger.info_once( - "Using FlashMLA backend on V1 engine.") - return ("vllm.v1.attention.backends.mla." - "flashmla.FlashMLABackend") - else: - logger.info("Using FlashMLA backend.") - return ("vllm.attention.backends." 
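# --- Simplified, V1-only sketch of the MLA backend cascade introduced above:
# each candidate is eligible either because it was explicitly selected or
# because it is auto-detectable, and the first eligible one in priority order
# wins, with Triton MLA as the unconditional fallback. Capability and block
# size checks mirror the diff; the support probes are stubbed parameters.
from typing import Optional


def choose_mla_backend(selected: Optional[str],
                       device_capability: int,
                       block_size: int,
                       flashmla_supported: bool = True,
                       flash_attn_mla_supported: bool = True) -> str:
    auto = selected is None
    if selected == "CUTLASS_MLA" or (auto and device_capability == 100
                                     and block_size == 128):
        return "CUTLASS_MLA"
    if selected == "FLASHINFER_MLA" or (auto and device_capability == 100
                                        and block_size in (32, 64)):
        return "FLASHINFER_MLA"
    if (selected == "FLASHMLA" or (auto and flashmla_supported)) \
            and block_size == 64:
        return "FLASHMLA"
    if selected == "FLASH_ATTN_MLA" or (auto and flash_attn_mla_supported):
        return "FLASH_ATTN_MLA"
    return "TRITON_MLA"


assert choose_mla_backend(None, 100, 128) == "CUTLASS_MLA"
assert choose_mla_backend(None, 100, 64) == "FLASHINFER_MLA"
assert choose_mla_backend(None, 90, 64) == "FLASHMLA"
assert choose_mla_backend(None, 90, 16, flashmla_supported=False,
                          flash_attn_mla_supported=False) == "TRITON_MLA"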
- "flashmla.FlashMLABackend") + return _get_version("FlashMLA", "flashmla.FlashMLABackend") + if use_flashattn: + if use_v1: + logger.info_once( + "Using FlashAttention MLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla." + "flashattn_mla.FlashAttnMLABackend") + else: + logger.warning( + "FlashAttention MLA backend is only supported on V1 " + "engine.") + if use_triton: + return _get_version("Triton MLA", + "triton_mla.TritonMLABackend") if use_v1: FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501 FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 @@ -500,8 +540,10 @@ class CudaPlatformBase(Platform): else: attention_backend = "FLASHMLA" - # Only FlashMLA supports fp8 - if attention_backend == "FLASHMLA": + # Only FlashMLA and CUTLASS_MLA support fp8 + if attention_backend in [ + "FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA" + ]: supported = True else: supported = (not fp8_attention) @@ -520,6 +562,10 @@ class CudaPlatformBase(Platform): supported = flash_attn_supports_fp8() else: supported = True + elif attention_backend == "FLASHINFER": + supported = True + elif attention_backend == "TRITON_ATTN_VLLM_V1": + supported = cls.supports_fp8() return supported @classmethod @@ -542,6 +588,10 @@ class CudaPlatformBase(Platform): "You can use float16 instead by explicitly setting the " "`dtype` flag in CLI, for example: --dtype=half.") + @classmethod + def support_hybrid_kv_cache(cls) -> bool: + return True + # NVML utils # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index ad12f7f788cf8..054d08c3a85be 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -48,13 +48,16 @@ class _Backend(enum.Enum): ROCM_AITER_MLA_VLLM_V1 = enum.auto() ROCM_AITER_FA = enum.auto() # used for ViT attn backend TORCH_SDPA = enum.auto() + TORCH_SDPA_VLLM_V1 = enum.auto() FLASHINFER = enum.auto() FLASHINFER_VLLM_V1 = enum.auto() + FLASHINFER_MLA = enum.auto() TRITON_MLA = enum.auto() # Supported by V1 TRITON_MLA_VLLM_V1 = enum.auto() - FLASHMLA_VLLM_V1 = enum.auto() - FLASHMLA = enum.auto() # Supported by V1 CUTLASS_MLA = enum.auto() + FLASHMLA = enum.auto() # Supported by V1 + FLASHMLA_VLLM_V1 = enum.auto() + FLASH_ATTN_MLA = enum.auto() # Supported by V1 PALLAS = enum.auto() PALLAS_VLLM_V1 = enum.auto() IPEX = enum.auto() @@ -72,7 +75,6 @@ class PlatformEnum(enum.Enum): TPU = enum.auto() XPU = enum.auto() CPU = enum.auto() - NEURON = enum.auto() OOT = enum.auto() UNSPECIFIED = enum.auto() @@ -163,9 +165,6 @@ class Platform: def is_cpu(self) -> bool: return self._enum == PlatformEnum.CPU - def is_neuron(self) -> bool: - return self._enum == PlatformEnum.NEURON - def is_out_of_tree(self) -> bool: return self._enum == PlatformEnum.OOT @@ -193,7 +192,8 @@ class Platform: return device_id @classmethod - def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend: + def get_vit_attn_backend(cls, head_size: int, + dtype: torch.dtype) -> _Backend: return _Backend.TORCH_SDPA @classmethod @@ -587,6 +587,13 @@ class Platform: """ raise NotImplementedError + @classmethod + def support_hybrid_kv_cache(cls) -> bool: + """ + Returns if the hybrid kv cache is supported by the current platform. 
+ """ + return False + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py deleted file mode 100644 index cb8ac8db669fe..0000000000000 --- a/vllm/platforms/neuron.py +++ /dev/null @@ -1,151 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import enum -import os -from functools import lru_cache -from typing import TYPE_CHECKING, Optional - -from vllm import envs -from vllm.logger import init_logger -from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS - -from .interface import Platform, PlatformEnum - -if TYPE_CHECKING: - from vllm.config import VllmConfig -else: - VllmConfig = None - -logger = init_logger(__name__) - - -class NeuronFramework(enum.Enum): - TRANSFORMERS_NEURONX = "transformers-neuronx" - NEURONX_DISTRIBUTED_INFERENCE = "neuronx-distributed-inference" - - -class NeuronPlatform(Platform): - _enum = PlatformEnum.NEURON - device_name: str = "neuron" - device_type: str = "neuron" - ray_device_key: str = "neuron_cores" - supported_quantization: list[str] = ["neuron_quant", "fbgemm_fp8"] - dist_backend: str = "gloo" - device_control_env_var: str = "NEURON_RT_VISIBLE_CORES" - - @classmethod - def get_device_name(cls, device_id: int = 0) -> str: - return "neuron" - - @classmethod - def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - return False - - @classmethod - def check_and_update_config(cls, vllm_config: VllmConfig) -> None: - parallel_config = vllm_config.parallel_config - if parallel_config.worker_cls == "auto": - parallel_config.worker_cls = \ - "vllm.worker.neuron_worker.NeuronWorker" - - if parallel_config.world_size > 1: - parallel_config.distributed_executor_backend = "uni" - - if vllm_config.cache_config and vllm_config.model_config: - # neuron needs block_size = max_model_len - vllm_config.cache_config.block_size = \ - vllm_config.model_config.max_model_len # type: ignore - - if vllm_config.model_config and vllm_config.model_config.use_mla: - logger.info( - "MLA is enabled on a non-GPU platform; forcing chunked " - "prefill and prefix caching to be disabled.") - vllm_config.scheduler_config.enable_chunked_prefill = False - vllm_config.scheduler_config.chunked_prefill_enabled = False - vllm_config.scheduler_config.max_num_batched_tokens = max( - vllm_config.scheduler_config.max_model_len, - DEFAULT_MAX_NUM_BATCHED_TOKENS) - - @classmethod - def is_pin_memory_available(cls) -> bool: - logger.warning("Pin memory is not supported on Neuron.") - return False - - @classmethod - def get_device_communicator_cls(cls) -> str: - if envs.VLLM_USE_V1: - return "vllm.distributed.device_communicators.neuron_communicator.NeuronCommunicator" # noqa - else: - return Platform.get_device_communicator_cls() - - @classmethod - def use_all_gather(cls) -> bool: - return True - - @classmethod - @lru_cache - def is_neuronx_distributed_inference(cls) -> bool: - try: - import neuronx_distributed_inference - except ImportError: - neuronx_distributed_inference = None - return neuronx_distributed_inference is not None - - @classmethod - @lru_cache - def is_transformers_neuronx(cls) -> bool: - try: - import transformers_neuronx - except ImportError: - transformers_neuronx = None - return transformers_neuronx is not None - - def get_neuron_framework_to_use(self): - """Return the specified framework if corresponding installations are - available. - - If no framework is specified, use neuronx-distributed-inference by - default. 
- If that's unavailable, check and switch to transformers-neuronx. - """ - if not self.is_neuron(): - raise AssertionError( - f"Neuron Framework unavailable for platform: {self}") - - tnx_installed = self.is_transformers_neuronx() - nxd_installed = self.is_neuronx_distributed_inference() - - specified_framework = os.environ.get("VLLM_NEURON_FRAMEWORK") - tnx_framework = NeuronFramework.TRANSFORMERS_NEURONX.value - nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE.value - if specified_framework == tnx_framework and tnx_installed: - return self.TRANSFORMERS_NEURONX - - if ((specified_framework == nxd_framework and nxd_installed) - or (specified_framework is None and nxd_installed)): - return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE - - if specified_framework is None and tnx_installed: - return NeuronFramework.TRANSFORMERS_NEURONX - - return None - - def use_neuronx_distributed(self): - """ - Return True if the framework determined in get_neuron_framework_to_use() - is NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE, False otherwise. This - is used to select the Neuron model framework and framework-specific - configuration to apply during model compilation. - """ - nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE - return self.get_neuron_framework_to_use() == nxd_framework - - def use_transformers_neuronx(self): - """ - Return True if the framework determined in get_neuron_framework_to_use() - is NeuronFramework.TRANSFORMERS_NEURONX, False otherwise. This is used - to select the Neuron model framework and framework-specific - configuration to apply during model compilation. - """ - return self.get_neuron_framework_to_use( - ) == NeuronFramework.TRANSFORMERS_NEURONX diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index c6d14aa87c7f2..4f540fe965e22 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -171,19 +171,19 @@ class RocmPlatform(Platform): supported_quantization: list[str] = [ "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf", - "quark", "ptpc_fp8", "mxfp4", "petit_nvfp4" + "quark", "ptpc_fp8", "mxfp4", "petit_nvfp4", "torchao" ] @classmethod - def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend: - if support_fa: - if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA - and on_gfx9()): - # Note: AITER FA is only supported for Qwen-VL models. - # TODO: Add support for other VL models in their model class. - return _Backend.ROCM_AITER_FA - if on_gfx9(): - return _Backend.FLASH_ATTN + def get_vit_attn_backend(cls, head_size: int, + dtype: torch.dtype) -> _Backend: + if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA + and on_gfx9()): + # Note: AITER FA is only supported for Qwen-VL models. + # TODO: Add support for other VL models in their model class. 
+ return _Backend.ROCM_AITER_FA + if on_gfx9(): + return _Backend.FLASH_ATTN return _Backend.TORCH_SDPA @classmethod @@ -191,7 +191,7 @@ class RocmPlatform(Platform): kv_cache_dtype, block_size, use_v1, use_mla, has_sink) -> str: if use_mla: - from vllm.attention.backends.rocm_aiter_mla import ( + from vllm.v1.attention.backends.mla.rocm_aiter_mla import ( is_aiter_mla_enabled) if selected_backend is None: @@ -322,23 +322,35 @@ class RocmPlatform(Platform): @classmethod def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: + from vllm.config.compilation import CUDAGraphMode + cache_config = vllm_config.cache_config + compilation_config = vllm_config.compilation_config + parallel_config = vllm_config.parallel_config + is_eager_execution = compilation_config == CUDAGraphMode.NONE + + use_v1 = envs.VLLM_USE_V1 + use_aiter_rms_norm = envs.VLLM_ROCM_USE_AITER and \ + envs.VLLM_ROCM_USE_AITER_RMSNORM + if cache_config and cache_config.block_size is None: cache_config.block_size = 16 - parallel_config = vllm_config.parallel_config if parallel_config.worker_cls == "auto": if vllm_config.speculative_config: - if not envs.VLLM_USE_V1: + if not use_v1: raise NotImplementedError( "Speculative decoding is not supported on vLLM V0.") parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" else: - if envs.VLLM_USE_V1: + if use_v1: parallel_config.worker_cls = \ "vllm.v1.worker.gpu_worker.Worker" else: parallel_config.worker_cls = "vllm.worker.worker.Worker" + # Aiter rms norm perform best when CUDA Graph capture is enabled. + if use_v1 and use_aiter_rms_norm and not is_eager_execution: + compilation_config.custom_ops.append("+rms_norm") @classmethod def verify_model_arch(cls, model_arch: str) -> None: @@ -486,3 +498,7 @@ class RocmPlatform(Platform): f"Your {gpu_name} GPU {compute_str}. 
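# --- Hedged sketch of the new ViT attention-backend signature shown above:
# the choice now keys off head_size and dtype rather than a support_fa flag.
# The CUDA rule roughly reads "non fp16/bf16 dtypes -> xFormers, Ampere+ with
# a FlashAttention-compatible head size -> FlashAttention, otherwise
# xFormers"; the head-size probe below is a stand-in for
# is_attn_backend_supported(), not the real check.
import torch


def cuda_vit_attn_backend(head_size: int,
                          dtype: torch.dtype,
                          device_capability: int = 80) -> str:
    if dtype not in (torch.float16, torch.bfloat16):
        return "XFORMERS"
    fa_supported = (device_capability >= 80 and head_size % 8 == 0
                    and head_size <= 256)  # stand-in support probe
    return "FLASH_ATTN" if fa_supported else "XFORMERS"


assert cuda_vit_attn_backend(72, torch.bfloat16) == "FLASH_ATTN"
assert cuda_vit_attn_backend(72, torch.float32) == "XFORMERS"
assert cuda_vit_attn_backend(80, torch.float16,
                             device_capability=75) == "XFORMERS"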
" "You can use float16 instead by explicitly setting the " "`dtype` flag in CLI, for example: --dtype=half.") + + @classmethod + def support_hybrid_kv_cache(cls) -> bool: + return True diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index d7468d74b021f..6a061956d8141 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -200,6 +200,32 @@ class TpuPlatform(Platform): model_config: "ModelConfig") -> bool: return True + @classmethod + @torch.compile(backend="openxla") + def insert_blocks_to_device( + cls, + src_cache: torch.Tensor, + dst_cache: torch.Tensor, + src_block_indices: torch.Tensor, + dst_block_indices: torch.Tensor, + ) -> None: + torch.ops.xla.dynamo_set_buffer_donor_(dst_cache, True) + dst_cache[dst_block_indices] = src_cache[src_block_indices].to( + dst_cache.device) + + @classmethod + @torch.compile(backend="openxla") + def swap_out_blocks_to_host( + cls, + src_cache: torch.Tensor, + dst_cache: torch.Tensor, + src_block_indices: torch.Tensor, + dst_block_indices: torch.Tensor, + ) -> None: + """ tpu blocks to cpu blocks""" + torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True) + dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu() + try: from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index d61b921e19cfe..67ef058df10f1 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -37,14 +37,38 @@ class XPUPlatform(Platform): dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, use_v1: bool, use_mla: bool, has_sink: bool) -> str: - if selected_backend is not None and selected_backend != _Backend.IPEX: - logger.info("Cannot use %s backend on XPU.", selected_backend) use_v1 = envs.VLLM_USE_V1 if not use_v1: raise ValueError("XPU backend only supports V1.") + TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 + FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 + if selected_backend == _Backend.TRITON_ATTN_VLLM_V1: + logger.info_once("Using Triton backend on V1 engine.") + return TRITON_ATTN_VLLM_V1 + elif selected_backend == _Backend.FLASH_ATTN: + logger.info_once("Using Flash Attention backend on V1 engine.") + return FLASH_ATTN_V1 + elif selected_backend: + raise ValueError( + f"Invalid attention backend for {cls.device_name}, " + f"with use_v1: {use_v1} use_mla: {use_mla}") + logger.info("Using Flash Attention backend on V1 engine.") return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + @classmethod + def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, + model_config: "ModelConfig") -> bool: + """ + Check if the kv_cache_dtype is supported. + XPU only support fp8 kv cache with triton backend. + """ + if envs.is_set("VLLM_ATTENTION_BACKEND") and \ + envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN_VLLM_V1": + return kv_cache_dtype in ["fp8_e4m3", "fp8_e5m2", "fp8"] + + return False + @classmethod def set_device(cls, device: torch.device) -> None: """ @@ -91,7 +115,7 @@ class XPUPlatform(Platform): cache_config.block_size = 64 # lazy import to avoid circular import - from vllm.config import CUDAGraphMode + from vllm.config import CompilationLevel, CUDAGraphMode compilation_config = vllm_config.compilation_config if compilation_config.cudagraph_mode is None or \ compilation_config.cudagraph_mode.max_cudagraph_mode() \ @@ -100,6 +124,9 @@ class XPUPlatform(Platform): "cudagraphs. 
Fallback to cudagraph_mode=NONE") compilation_config.cudagraph_mode = CUDAGraphMode.NONE + if vllm_config.lora_config is not None: + compilation_config.level = CompilationLevel.NO_COMPILATION + # check and update parallel config parallel_config = vllm_config.parallel_config parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker" @@ -136,6 +163,15 @@ class XPUPlatform(Platform): vllm_config.scheduler_config.max_num_batched_tokens = max( vllm_config.scheduler_config.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS) + from vllm.v1.attention.backends.utils import set_kv_cache_layout + + set_kv_cache_layout("NHD") + logger.info("Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; " + "only NHD layout is supported by XPU attention kernels.") + + @classmethod + def support_hybrid_kv_cache(cls) -> bool: + return True @classmethod def is_pin_memory_available(cls): @@ -183,3 +219,27 @@ class XPUPlatform(Platform): @classmethod def opaque_attention_op(cls) -> bool: return True + + @classmethod + def insert_blocks_to_device( + cls, + src_cache: torch.Tensor, + dst_cache: torch.Tensor, + src_block_indices: torch.Tensor, + dst_block_indices: torch.Tensor, + ) -> None: + """Copy blocks from src_cache to dst_cache on XPU.""" + _src_cache = src_cache[:, src_block_indices] + dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device) + + @classmethod + def swap_out_blocks_to_host( + cls, + src_cache: torch.Tensor, + dst_cache: torch.Tensor, + src_block_indices: torch.Tensor, + dst_block_indices: torch.Tensor, + ) -> None: + """Copy blocks from XPU to host (CPU).""" + _src_cache = src_cache[:, src_block_indices] + dst_cache[:, dst_block_indices] = _src_cache.cpu() diff --git a/vllm/plugins/io_processors/__init__.py b/vllm/plugins/io_processors/__init__.py index c5c4f6f8d97c3..3b17211b1b835 100644 --- a/vllm/plugins/io_processors/__init__.py +++ b/vllm/plugins/io_processors/__init__.py @@ -33,7 +33,7 @@ def get_io_processor( model_plugin = config_plugin if model_plugin is None: - logger.info("No IOProcessor plugins requested by the model") + logger.debug("No IOProcessor plugins requested by the model") return None logger.debug("IOProcessor plugin to be loaded %s", model_plugin) diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 6672392b8d080..a6313367457a4 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -20,25 +20,33 @@ class PoolingParams( """API parameters for pooling models. Attributes: + truncate_prompt_tokens: Controls prompt truncation. + Set to -1 to use the model's default truncation size. + Set to k to keep only the last k tokens (left truncation). + Set to None to disable truncation. normalize: Whether to normalize the embeddings outputs. dimensions: Reduce the dimensions of embeddings - if model support matryoshka representation. + if model support matryoshka representation. activation: Whether to apply activation function to - the classification outputs. + the classification outputs. softmax: Whether to apply softmax to the reward outputs. """ + + # --8<-- [start:common-pooling-params] truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=-1)]] = None - """If set to -1, will use the truncation size supported by the model. If - set to an integer k, will use only the last k tokens from the prompt - (i.e., left truncation). 
If set to `None`, truncation is disabled.""" + # --8<-- [end:common-pooling-params] ## for embeddings models + # --8<-- [start:embedding-pooling-params] dimensions: Optional[int] = None normalize: Optional[bool] = None + # --8<-- [end:embedding-pooling-params] - ## for classification models + ## for classification, scoring and rerank + # --8<-- [start:classification-pooling-params] activation: Optional[bool] = None + # --8<-- [end:classification-pooling-params] ## for reward models softmax: Optional[bool] = None diff --git a/vllm/reasoning/gptoss_reasoning_parser.py b/vllm/reasoning/gptoss_reasoning_parser.py index 05a72ac23bf2e..3bd4d872ce22f 100644 --- a/vllm/reasoning/gptoss_reasoning_parser.py +++ b/vllm/reasoning/gptoss_reasoning_parser.py @@ -6,6 +6,7 @@ from typing import Optional, Union from transformers import PreTrainedTokenizerBase +from vllm.entrypoints.harmony_utils import parse_chat_output from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaMessage) from vllm.logger import init_logger @@ -14,7 +15,7 @@ from vllm.reasoning import ReasoningParser, ReasoningParserManager logger = init_logger(__name__) -@ReasoningParserManager.register_module("GptOss") +@ReasoningParserManager.register_module("openai_gptoss") class GptOssReasoningParser(ReasoningParser): """ Reasoning parser for GptOss model. @@ -39,9 +40,10 @@ class GptOssReasoningParser(ReasoningParser): return False def extract_content_ids(self, input_ids: list[int]) -> list[int]: - raise RuntimeError( - "GptOss model uses harmony to extract reasoning content. This " - "function should not be called.") + _, content, _ = parse_chat_output(input_ids) + if content is None: + return [] + return self.model_tokenizer.encode(content) def extract_reasoning_content_streaming( self, @@ -52,13 +54,34 @@ class GptOssReasoningParser(ReasoningParser): current_token_ids: Sequence[int], delta_token_ids: Sequence[int], ) -> Union[DeltaMessage, None]: - raise RuntimeError( - "GptOss model uses harmony to extract reasoning content. This " - "function should not be called.") + prev_reasoning, prev_content, _ = parse_chat_output( + list(previous_token_ids)) + cur_reasoning, cur_content, _ = parse_chat_output( + list(current_token_ids)) + reasoning_delta = None + content_delta = None + if cur_reasoning is not None: + prev_r = prev_reasoning or "" + if cur_reasoning.startswith(prev_r): + reasoning_delta = cur_reasoning[len(prev_r):] or None + else: + reasoning_delta = cur_reasoning + if cur_content is not None: + prev_c = prev_content or "" + if cur_content.startswith(prev_c): + content_delta = cur_content[len(prev_c):] or None + else: + content_delta = cur_content + if reasoning_delta is None and content_delta is None: + return None + return DeltaMessage(reasoning_content=reasoning_delta, + content=content_delta) def extract_reasoning_content( - self, model_output: str, request: ChatCompletionRequest + self, + model_output: str, + request: ChatCompletionRequest, ) -> tuple[Optional[str], Optional[str]]: - raise RuntimeError( - "GptOss model uses harmony to extract reasoning content. This " - "function should not be called.") + raise NotImplementedError( + "gpt-oss has a special branch for parsing reasoning in non-streaming mode. This method shouldn't be used." 
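# --- Sketch of the incremental-delta rule in the gpt-oss streaming path
# above: parse the reasoning/content strings for the previous and current
# token windows and emit only the suffix that is new, falling back to the
# full current string when it no longer extends the previous one.
from typing import Optional


def string_delta(prev: Optional[str], cur: Optional[str]) -> Optional[str]:
    if cur is None:
        return None
    prev = prev or ""
    if cur.startswith(prev):
        return cur[len(prev):] or None
    return cur


assert string_delta(None, "Let me think") == "Let me think"
assert string_delta("Let me think", "Let me think harder") == " harder"
assert string_delta("Let me think", "Let me think") is None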
# noqa: E501 + ) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index c7b4ba34c602e..efe70d019ccc6 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -2,13 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Sampling parameters for text generation.""" import copy -from dataclasses import dataclass +from dataclasses import field from enum import Enum, IntEnum from functools import cached_property from typing import Annotated, Any, Optional, Union import msgspec -from pydantic import BaseModel +from pydantic.dataclasses import dataclass from vllm.logger import init_logger from vllm.logits_process import LogitsProcessor @@ -28,60 +28,35 @@ class SamplingType(IntEnum): # maybe make msgspec? @dataclass -class GuidedDecodingParams: - """One of these fields will be used to build a logit processor.""" +class StructuredOutputsParams: + # One of these fields will be used to build a logit processor. json: Optional[Union[str, dict]] = None regex: Optional[str] = None choice: Optional[list[str]] = None grammar: Optional[str] = None json_object: Optional[bool] = None - """These are other options that can be set""" - backend: Optional[str] = None - backend_was_auto: bool = False + # These are other options that can be set. disable_fallback: bool = False disable_any_whitespace: bool = False disable_additional_properties: bool = False whitespace_pattern: Optional[str] = None structural_tag: Optional[str] = None - @staticmethod - def from_optional( - json: Optional[Union[dict, BaseModel, str]] = None, - regex: Optional[str] = None, - choice: Optional[list[str]] = None, - grammar: Optional[str] = None, - json_object: Optional[bool] = None, - backend: Optional[str] = None, - whitespace_pattern: Optional[str] = None, - structural_tag: Optional[str] = None, - ) -> Optional["GuidedDecodingParams"]: - if all(arg is None for arg in (json, regex, choice, grammar, - json_object, structural_tag)): - return None - # Extract json schemas from pydantic models - if isinstance(json, (BaseModel, type(BaseModel))): - json = json.model_json_schema() - return GuidedDecodingParams( - json=json, - regex=regex, - choice=choice, - grammar=grammar, - json_object=json_object, - backend=backend, - whitespace_pattern=whitespace_pattern, - structural_tag=structural_tag, - ) + _backend: Optional[str] = field(default=None, init=False) + """CAUTION: Should only be set by Processor._validate_structured_output""" + _backend_was_auto: bool = field(default=False, init=False) + """CAUTION: Should only be set by Processor._validate_structured_output""" def __post_init__(self): """Validate that some fields are mutually exclusive.""" - guide_count = sum([ + count = sum([ self.json is not None, self.regex is not None, self.choice is not None, self.grammar is not None, self.json_object is not None ]) - if guide_count > 1: + if count > 1: raise ValueError( - "You can only use one kind of guided decoding but multiple are " - f"specified: {self.__dict__}") + "You can only use one kind of structured outputs constraint " + f"but multiple are specified: {self.__dict__}") class RequestOutputKind(Enum): @@ -106,7 +81,13 @@ class SamplingParams( """ n: int = 1 - """Number of output sequences to return for the given prompt.""" + """Number of outputs to return for the given prompt request. + + NOTE: + `AsyncLLM` streams outputs by default. When `n > 1`, all `n` outputs + are generated and streamed cumulatively per request. 
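# --- Small usage sketch for the GuidedDecodingParams ->
# StructuredOutputsParams rename above: exactly one constraint kind
# (json / regex / choice / grammar / json_object) may be set, which
# __post_init__ enforces. The import path follows the diff; treat the whole
# snippet as illustrative rather than canonical API documentation.
from vllm.sampling_params import SamplingParams, StructuredOutputsParams

params = SamplingParams(
    temperature=0.0,
    structured_outputs=StructuredOutputsParams(
        json={"type": "object", "properties": {"name": {"type": "string"}}}),
)

try:
    StructuredOutputsParams(regex=r"\d+", choice=["a", "b"])
except ValueError as e:
    print("rejected:", e)  # only one constraint kind is allowed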
To see all `n` + outputs upon completion, use `output_kind=RequestOutputKind.FINAL_ONLY` + in `SamplingParams`.""" best_of: Optional[int] = None """Number of output sequences that are generated from the prompt. From these `best_of` sequences, the top `n` sequences are returned. `best_of` @@ -165,7 +146,8 @@ class SamplingParams( the sampled token, so there may be up to `logprobs+1` elements in the response. When set to -1, return all `vocab_size` log probabilities.""" prompt_logprobs: Optional[int] = None - """Number of log probabilities to return per prompt token.""" + """Number of log probabilities to return per prompt token. + When set to -1, return all `vocab_size` log probabilities.""" # NOTE: This parameter is only exposed at the engine level for now. # It is not exposed in the OpenAI API server, as the OpenAI API does # not support returning only a list of token IDs. @@ -195,9 +177,8 @@ class SamplingParams( _all_stop_token_ids: set[int] = msgspec.field(default_factory=set) # Fields used to construct logits processors - guided_decoding: Optional[GuidedDecodingParams] = None - """If provided, the engine will construct a guided decoding logits - processor from these parameters.""" + structured_outputs: Optional[StructuredOutputsParams] = None + """Parameters for configuring structured outputs.""" logit_bias: Optional[dict[int, float]] = None """If provided, the engine will construct a logits processor that applies these logit biases.""" @@ -245,7 +226,7 @@ class SamplingParams( msgspec.Meta( ge=-1)]] = None, output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE, - guided_decoding: Optional[GuidedDecodingParams] = None, + structured_outputs: Optional[StructuredOutputsParams] = None, logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None, allowed_token_ids: Optional[list[int]] = None, extra_args: Optional[dict[str, Any]] = None, @@ -287,7 +268,7 @@ class SamplingParams( logits_processors=logits_processors, truncate_prompt_tokens=truncate_prompt_tokens, output_kind=output_kind, - guided_decoding=guided_decoding, + structured_outputs=structured_outputs, logit_bias=logit_bias, allowed_token_ids=allowed_token_ids, extra_args=extra_args, @@ -409,9 +390,11 @@ class SamplingParams( and self.logprobs < 0): raise ValueError( f"logprobs must be non-negative or -1, got {self.logprobs}.") - if self.prompt_logprobs is not None and self.prompt_logprobs < 0: - raise ValueError(f"prompt_logprobs must be non-negative, got " - f"{self.prompt_logprobs}.") + if (self.prompt_logprobs is not None and self.prompt_logprobs != -1 + and self.prompt_logprobs < 0): + raise ValueError( + f"prompt_logprobs must be non-negative or -1, got " + f"{self.prompt_logprobs}.") if (self.truncate_prompt_tokens is not None and (self.truncate_prompt_tokens == 0 or self.truncate_prompt_tokens < -1)): @@ -556,7 +539,7 @@ class SamplingParams( "spaces_between_special_tokens=" f"{self.spaces_between_special_tokens}, " f"truncate_prompt_tokens={self.truncate_prompt_tokens}, " - f"guided_decoding={self.guided_decoding}, " + f"structured_outputs={self.structured_outputs}, " f"extra_args={self.extra_args})") diff --git a/vllm/test_utils.py b/vllm/test_utils.py index 23679b8228d6f..91dcc2fd84e17 100644 --- a/vllm/test_utils.py +++ b/vllm/test_utils.py @@ -36,7 +36,6 @@ MODELS_ON_S3 = [ "llava-hf/llava-v1.6-mistral-7b-hf", "llava-hf/LLaVA-NeXT-Video-7B-hf", # "meta-llama/Llama-2-7b-hf", - "meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.2-1B", "meta-llama/Llama-3.2-1B-Instruct", 
"meta-llama/Meta-Llama-3-8B", diff --git a/vllm/third_party/pynvml.py b/vllm/third_party/pynvml.py index c06aa567444d8..6aabbc217dd03 100644 --- a/vllm/third_party/pynvml.py +++ b/vllm/third_party/pynvml.py @@ -3533,7 +3533,7 @@ def nvmlDeviceGetMPSComputeRunningProcesses_v3(handle): return [] elif (ret == NVML_ERROR_INSUFFICIENT_SIZE): # typical case - # oversize the array incase more processes are created + # oversize the array in case more processes are created c_count.value = c_count.value * 2 + 5 proc_array = c_nvmlProcessInfo_v3_t * c_count.value c_procs = proc_array() diff --git a/vllm/tracing.py b/vllm/tracing.py index 6a287d82be5ff..7537e9901a044 100644 --- a/vllm/tracing.py +++ b/vllm/tracing.py @@ -119,6 +119,11 @@ class SpanAttributes: # forward, block/sync across workers, cpu-gpu sync time and sampling time. GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE = ( "gen_ai.latency.time_in_model_execute") + GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL = \ + "gen_ai.latency.time_in_model_prefill" + GEN_AI_LATENCY_TIME_IN_MODEL_DECODE = "gen_ai.latency.time_in_model_decode" + GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE = \ + "gen_ai.latency.time_in_model_inference" def contains_trace_headers(headers: Mapping[str, str]) -> bool: diff --git a/vllm/transformers_utils/chat_templates/registry.py b/vllm/transformers_utils/chat_templates/registry.py index d09c5fa924fb0..3a97f2c056181 100644 --- a/vllm/transformers_utils/chat_templates/registry.py +++ b/vllm/transformers_utils/chat_templates/registry.py @@ -35,7 +35,6 @@ _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = { "blip-2": CHAT_TEMPLATES_DIR / "template_blip2.jinja", "chameleon": CHAT_TEMPLATES_DIR / "template_basic.jinja", "deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja", - "florence2": CHAT_TEMPLATES_DIR / "template_basic.jinja", "fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja", "minicpmv": _get_minicpmv_chat_template_fallback, "paligemma": CHAT_TEMPLATES_DIR / "template_basic.jinja", diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index bec792465bfbb..cafc43f6b7673 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -1,13 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import enum import json import os import time from functools import cache, partial from pathlib import Path -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, Callable, Literal, Optional, TypeVar, Union import huggingface_hub from huggingface_hub import get_safetensors_metadata, hf_hub_download @@ -27,6 +26,7 @@ from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME from vllm import envs from vllm.logger import init_logger +from vllm.transformers_utils.config_parser_base import ConfigParserBase from vllm.transformers_utils.utils import check_gguf_file if envs.VLLM_USE_MODELSCOPE: @@ -71,14 +71,16 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( jais="JAISConfig", mlp_speculator="MLPSpeculatorConfig", medusa="MedusaConfig", + midashenglm="MiDashengLMConfig", eagle="EAGLEConfig", speculators="SpeculatorsConfig", nemotron="NemotronConfig", + olmo3="Olmo3Config", ovis="OvisConfig", ultravox="UltravoxConfig", step3_vl="Step3VLConfig", step3_text="Step3TextConfig", -) + qwen3_next="Qwen3NextConfig") _CONFIG_ATTRS_MAPPING: dict[str, str] = { "llm_config": "text_config", @@ -88,21 +90,169 @@ _AUTO_CONFIG_KWARGS_OVERRIDES: dict[str, dict[str, Any]] = { 
"internvl_chat": { "has_no_defaults_at_init": True }, - # transformers regards mllama as is_encoder_decoder=False - # vllm needs is_encoder_decoder=True to enable cross-attention - "mllama": { - "is_encoder_decoder": True - }, "NVLM_D": { "has_no_defaults_at_init": True }, } -class ConfigFormat(str, enum.Enum): - AUTO = "auto" - HF = "hf" - MISTRAL = "mistral" +class HFConfigParser(ConfigParserBase): + + def parse(self, + model: Union[str, Path], + trust_remote_code: bool, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + **kwargs) -> tuple[dict, PretrainedConfig]: + kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE + config_dict, _ = PretrainedConfig.get_config_dict( + model, + revision=revision, + code_revision=code_revision, + token=_get_hf_token(), + **kwargs, + ) + # Use custom model class if it's in our registry + model_type = config_dict.get("model_type") + if model_type is None: + model_type = "speculators" if config_dict.get( + "speculators_config") is not None else model_type + + if model_type in _CONFIG_REGISTRY: + config_class = _CONFIG_REGISTRY[model_type] + config = config_class.from_pretrained( + model, + revision=revision, + code_revision=code_revision, + token=_get_hf_token(), + **kwargs, + ) + else: + try: + kwargs = _maybe_update_auto_config_kwargs( + kwargs, model_type=model_type) + config = AutoConfig.from_pretrained( + model, + trust_remote_code=trust_remote_code, + revision=revision, + code_revision=code_revision, + token=_get_hf_token(), + **kwargs, + ) + except ValueError as e: + if (not trust_remote_code + and "requires you to execute the configuration file" + in str(e)): + err_msg = ( + "Failed to load the model config. If the model " + "is a custom model not yet available in the " + "HuggingFace transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + config = _maybe_remap_hf_config_attrs(config) + return config_dict, config + + +class MistralConfigParser(ConfigParserBase): + + def parse(self, + model: Union[str, Path], + trust_remote_code: bool, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + **kwargs) -> tuple[dict, PretrainedConfig]: + # This function loads a params.json config which + # should be used when loading models in mistral format + config_dict = _download_mistral_config_file(model, revision) + if (max_position_embeddings := + config_dict.get("max_position_embeddings")) is None: + max_position_embeddings = _maybe_retrieve_max_pos_from_hf( + model, revision, **kwargs) + config_dict["max_position_embeddings"] = max_position_embeddings + + from vllm.transformers_utils.configs.mistral import adapt_config_dict + + config = adapt_config_dict(config_dict) + + # Mistral configs may define sliding_window as list[int]. 
Convert it + # to int and add the layer_types list[str] to make it HF compatible + if ((sliding_window := getattr(config, "sliding_window", None)) + and isinstance(sliding_window, list)): + pattern_repeats = config.num_hidden_layers // len(sliding_window) + layer_types = sliding_window * pattern_repeats + config.layer_types = [ + "full_attention" if layer_type is None else "sliding_attention" + for layer_type in layer_types + ] + config.sliding_window = next(filter(None, sliding_window), None) + + return config_dict, config + + +_CONFIG_FORMAT_TO_CONFIG_PARSER: dict[str, type[ConfigParserBase]] = { + "hf": HFConfigParser, + "mistral": MistralConfigParser, +} + +ConfigFormat = Literal[ + "auto", + "hf", + "mistral", +] + + +def get_config_parser(config_format: str) -> ConfigParserBase: + """Get the config parser for a given config format.""" + if config_format not in _CONFIG_FORMAT_TO_CONFIG_PARSER: + raise ValueError(f"Unknown config format `{config_format}`.") + return _CONFIG_FORMAT_TO_CONFIG_PARSER[config_format]() + + +def register_config_parser(config_format: str): + + """Register a customized vllm config parser. + When a config format is not supported by vllm, you can register a customized + config parser to support it. + Args: + config_format (str): The config parser format name. + Examples: + + >>> from vllm.transformers_utils.config import (get_config_parser, + register_config_parser) + >>> from vllm.transformers_utils.config_parser_base import ConfigParserBase + >>> + >>> @register_config_parser("custom_config_parser") + ... class CustomConfigParser(ConfigParserBase): + ... def parse(self, + ... model: Union[str, Path], + ... trust_remote_code: bool, + ... revision: Optional[str] = None, + ... code_revision: Optional[str] = None, + ... **kwargs) -> tuple[dict, PretrainedConfig]: + ... raise NotImplementedError + >>> + >>> type(get_config_parser("custom_config_parser")) + <class 'CustomConfigParser'> + """ # noqa: E501 + + def _wrapper(config_parser_cls): + if config_format in _CONFIG_FORMAT_TO_CONFIG_PARSER: + logger.warning( + "Config format `%s` is already registered, and will be " + "overwritten by the new parser class `%s`.", config_format, + config_parser_cls) + if not issubclass(config_parser_cls, ConfigParserBase): + raise ValueError("The config parser must be a subclass of " + "`ConfigParserBase`.") + _CONFIG_FORMAT_TO_CONFIG_PARSER[config_format] = config_parser_cls + logger.info("Registered config parser `%s` with config format `%s`", + config_parser_cls, config_format) + return config_parser_cls + + return _wrapper _R = TypeVar("_R") @@ -349,7 +499,7 @@ def get_config( trust_remote_code: bool, revision: Optional[str] = None, code_revision: Optional[str] = None, - config_format: ConfigFormat = ConfigFormat.AUTO, + config_format: Union[str, ConfigFormat] = "auto", hf_overrides_kw: Optional[dict[str, Any]] = None, hf_overrides_fn: Optional[Callable[[PretrainedConfig], PretrainedConfig]] = None, @@ -362,20 +512,22 @@ def get_config( kwargs["gguf_file"] = Path(model).name model = Path(model).parent - if config_format == ConfigFormat.AUTO: + if config_format == "auto": try: if is_gguf or file_or_path_exists( model, HF_CONFIG_NAME, revision=revision): - config_format = ConfigFormat.HF + config_format = "hf" elif file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision): - config_format = ConfigFormat.MISTRAL + config_format = "mistral" else: raise ValueError( "Could not detect config format for no config file found. 
" - "Ensure your model has either config.json (HF format) " - "or params.json (Mistral format).") + "With config_format 'auto', ensure your model has either" + "config.json (HF format) or params.json (Mistral format)." + "Otherwise please specify your_custom_config_format" + "in engine args for customized config parser") except Exception as e: error_message = ( @@ -394,92 +546,14 @@ def get_config( raise ValueError(error_message) from e - if config_format == ConfigFormat.HF: - kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE - config_dict, _ = PretrainedConfig.get_config_dict( - model, - revision=revision, - code_revision=code_revision, - token=_get_hf_token(), - **kwargs, - ) - # Use custom model class if it's in our registry - model_type = config_dict.get("model_type") - if model_type is None: - model_type = "speculators" if config_dict.get( - "speculators_config") is not None else model_type - - if model_type in _CONFIG_REGISTRY: - config_class = _CONFIG_REGISTRY[model_type] - config = config_class.from_pretrained( - model, - revision=revision, - code_revision=code_revision, - token=_get_hf_token(), - **kwargs, - ) - else: - try: - kwargs = _maybe_update_auto_config_kwargs( - kwargs, model_type=model_type) - config = AutoConfig.from_pretrained( - model, - trust_remote_code=trust_remote_code, - revision=revision, - code_revision=code_revision, - token=_get_hf_token(), - **kwargs, - ) - except ValueError as e: - if (not trust_remote_code - and "requires you to execute the configuration file" - in str(e)): - err_msg = ( - "Failed to load the model config. If the model " - "is a custom model not yet available in the " - "HuggingFace transformers library, consider setting " - "`trust_remote_code=True` in LLM or using the " - "`--trust-remote-code` flag in the CLI.") - raise RuntimeError(err_msg) from e - else: - raise e - config = _maybe_remap_hf_config_attrs(config) - - elif config_format == ConfigFormat.MISTRAL: - # This function loads a params.json config which - # should be used when loading models in mistral format - config_dict = _download_mistral_config_file(model, revision) - if (max_position_embeddings := - config_dict.get("max_position_embeddings")) is None: - max_position_embeddings = _maybe_retrieve_max_pos_from_hf( - model, revision, **kwargs) - config_dict["max_position_embeddings"] = max_position_embeddings - - from vllm.transformers_utils.configs.mistral import adapt_config_dict - - config = adapt_config_dict(config_dict) - - # Mistral configs may define sliding_window as list[int]. Convert it - # to int and add the layer_types list[str] to make it HF compatible - if ((sliding_window := getattr(config, "sliding_window", None)) - and isinstance(sliding_window, list)): - pattern_repeats = config.num_hidden_layers // len(sliding_window) - layer_types = sliding_window * pattern_repeats - config.layer_types = [ - "full_attention" if layer_type is None else "sliding_attention" - for layer_type in layer_types - ] - config.sliding_window = next(filter(None, sliding_window), None) - else: - supported_formats = [ - fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO - ] - raise ValueError( - f"Unsupported config format: {config_format}. " - f"Supported formats are: {', '.join(supported_formats)}. 
" - f"Ensure your model uses one of these configuration formats " - f"or specify the correct format explicitly.") - + config_parser = get_config_parser(config_format) + config_dict, config = config_parser.parse( + model, + trust_remote_code=trust_remote_code, + revision=revision, + code_revision=code_revision, + **kwargs, + ) # Special architecture mapping check for GGUF models if is_gguf: if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: @@ -600,20 +674,21 @@ def get_hf_file_to_dict(file_name: str, @cache -def get_pooling_config(model: str, revision: Optional[str] = 'main'): +def get_pooling_config(model: str, + revision: Optional[str] = 'main') -> Optional[dict]: """ This function gets the pooling and normalize config from the model - only applies to sentence-transformers models. Args: - model (str): The name of the Hugging Face model. - revision (str, optional): The specific version - of the model to use. Defaults to 'main'. + model: The name of the Hugging Face model. + revision: The specific version of the model to use. + Defaults to 'main'. Returns: - dict: A dictionary containing the pooling - type and whether normalization is used. + A dictionary containing the pooling type and whether + normalization is used, or None if no pooling configuration is found. """ modules_file_name = "modules.json" @@ -913,7 +988,7 @@ def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int: hf_config = get_config(model=model, trust_remote_code=trust_remote_code_val, revision=revision, - config_format=ConfigFormat.HF) + config_format="hf") if hf_value := hf_config.get_text_config().max_position_embeddings: max_position_embeddings = hf_value except Exception as e: diff --git a/vllm/transformers_utils/config_parser_base.py b/vllm/transformers_utils/config_parser_base.py new file mode 100644 index 0000000000000..c27177f74d4ba --- /dev/null +++ b/vllm/transformers_utils/config_parser_base.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Optional, Union + +from transformers import PretrainedConfig + + +class ConfigParserBase(ABC): + + @abstractmethod + def parse(self, + model: Union[str, Path], + trust_remote_code: bool, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + **kwargs) -> tuple[dict, PretrainedConfig]: + raise NotImplementedError diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 8339c55bcf808..91bfeb8c55ee5 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -17,12 +17,16 @@ from vllm.transformers_utils.configs.falcon import RWConfig from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig from vllm.transformers_utils.configs.medusa import MedusaConfig +from vllm.transformers_utils.configs.midashenglm import MiDashengLMConfig from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig from vllm.transformers_utils.configs.moonvit import MoonViTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config +from vllm.transformers_utils.configs.olmo3 import Olmo3Config from vllm.transformers_utils.configs.ovis import OvisConfig 
+from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig +from vllm.transformers_utils.configs.radio import RadioConfig from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig, Step3VisionEncoderConfig, @@ -36,16 +40,20 @@ __all__ = [ "RWConfig", "JAISConfig", "MedusaConfig", + "MiDashengLMConfig", "MLPSpeculatorConfig", "MoonViTConfig", "KimiVLConfig", "NemotronConfig", "NemotronHConfig", "Nemotron_Nano_VL_Config", + "Olmo3Config", "OvisConfig", + "RadioConfig", "SpeculatorsConfig", "UltravoxConfig", "Step3VLConfig", "Step3VisionEncoderConfig", "Step3TextConfig", + "Qwen3NextConfig", ] diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py index 6aabf9e5262e6..444ed70de3d0c 100644 --- a/vllm/transformers_utils/configs/eagle.py +++ b/vllm/transformers_utils/configs/eagle.py @@ -46,6 +46,7 @@ class EAGLEConfig(PretrainedConfig): # Eagle model name should follow naming convention of # LlamaForCausalLM -> EagleLlamaForCausalLM # LlamaForCausalLM -> Eagle3LlamaForCausalLM + # LlamaForCausalLMEagle3 -> LlamaForCausalLMEagle3 if method == "eagle": assert self.model is not None, \ "model should not be None when method is eagle" @@ -53,6 +54,7 @@ class EAGLEConfig(PretrainedConfig): f"Eagle{arch}" if not arch.startswith("Eagle") \ else arch for arch in self.model.architectures ] + elif method == "eagle3": assert self.model is not None, \ "model should not be None when method is eagle3" diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py index 767c4ddae870d..d5ca2c7b4751a 100644 --- a/vllm/transformers_utils/configs/jais.py +++ b/vllm/transformers_utils/configs/jais.py @@ -74,10 +74,10 @@ class JAISConfig(PretrainedConfig): use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). - scale_attn_by_inverse_layer_idx (`bool`, *optional*, - defaults to `False`): - Whether to additionally scale attention weights by - `1 / layer_idx + 1`. + scale_attn_by_inverse_layer_idx + (`bool`, *optional*, defaults to `False`): + Whether to additionally scale attention weights + by `1 / layer_idx + 1`. reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): Whether to scale keys (K) prior to computing attention (dot-product) diff --git a/vllm/transformers_utils/configs/midashenglm.py b/vllm/transformers_utils/configs/midashenglm.py new file mode 100644 index 0000000000000..1c23202e23c8e --- /dev/null +++ b/vllm/transformers_utils/configs/midashenglm.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 Horizon team, Xiaomi MiLM Plus. +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Union + +from transformers import PretrainedConfig +from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( + Qwen2_5OmniTextConfig) + + +class DashengConfig(PretrainedConfig): + model_type = "midashenglm_dasheng_encoder" + + def __init__( + self, + embed_dim: int = 768, + outputdim: int = 527, + patch_size: Union[int, tuple[int, int]] = 16, + patch_stride: Union[int, tuple[int, int]] = 16, + input_channels: int = 1, + target_length: int = 1012, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + init_values: Optional[float] = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + f_min: float = 0.0, + f_max: float = 8000.0, + center: bool = True, + win_length: int = 512, + hop_length: int = 160, + sample_rate: int = 16000, + n_fft: int = 512, + n_mels: int = 64, + **kwargs, + ): + self.embed_dim = embed_dim + self.outputdim = outputdim + self.patch_size = patch_size + self.patch_stride = patch_stride + self.input_channels = input_channels + self.target_length = target_length + self.depth = depth + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.init_values = init_values + self.drop_rate = drop_rate + self.attn_drop_rate = attn_drop_rate + self.f_min = f_min + self.f_max = f_max + self.center = center + self.win_length = win_length + self.hop_length = hop_length + self.sample_rate = sample_rate + self.n_fft = n_fft + self.n_mels = n_mels + super().__init__(**kwargs) + + +class MiDashengLMConfig(PretrainedConfig): + model_type = "midashenglm" + + def __init__( + self, + audio_encoder_config: Optional[dict] = None, + subsample_factor: int = 5, + text_config: Optional[dict] = None, + audio_token_id: Optional[int] = None, + **kwargs, + ): + self.audio_encoder_config = DashengConfig( + **(audio_encoder_config or {})) + self.subsample_factor = subsample_factor + self.text_config = (Qwen2_5OmniTextConfig( + **text_config) if text_config else Qwen2_5OmniTextConfig()) + self.text_config.rope_scaling = None # uses_mrope is false + self.audio_token_id = audio_token_id + super().__init__(**kwargs) diff --git a/vllm/transformers_utils/configs/mistral.py b/vllm/transformers_utils/configs/mistral.py index 8a9c660b882fd..5d9206e188322 100644 --- a/vllm/transformers_utils/configs/mistral.py +++ b/vllm/transformers_utils/configs/mistral.py @@ -157,6 +157,7 @@ def _remap_mistral_audio_args(config: dict) -> dict: encoder_attention_heads=encoder_args["n_heads"], vocab_size=encoder_args["vocab_size"], max_source_positions=encoder_args["max_source_positions"], + is_encoder_decoder=False, # Override WhisperConfig default ) } if quant_config: diff --git a/vllm/transformers_utils/configs/olmo3.py b/vllm/transformers_utils/configs/olmo3.py new file mode 100644 index 0000000000000..874507db43a7f --- /dev/null +++ b/vllm/transformers_utils/configs/olmo3.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from transformers.configuration_utils import PretrainedConfig + + +class 
Olmo3Config(PretrainedConfig): + + model_type = "olmo3" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=50304, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + use_cache=True, + pad_token_id=1, + bos_token_id=None, + eos_token_id=50279, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + rms_norm_eps=1e-5, + sliding_window=4096, + layer_types=None, + **kwargs, + ): + # This model uses Olmo3ForCausalLM in transformers but Olmo2ForCausalLM + # in vLLM. + if "architectures" not in kwargs: + kwargs["architectures"] = ["Olmo2ForCausalLM"] + elif "Olmo3ForCausalLM" in kwargs["architectures"]: + kwargs["architectures"].remove("Olmo3ForCausalLM") + kwargs["architectures"].append("Olmo2ForCausalLM") + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + self.rms_norm_eps = rms_norm_eps + + self.sliding_window = sliding_window + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" if (i + 1) % 4 != 0 else "full_attention" + for i in range(self.num_hidden_layers) + ] diff --git a/vllm/transformers_utils/configs/qwen3_next.py b/vllm/transformers_utils/configs/qwen3_next.py new file mode 100644 index 0000000000000..c7af26acd1b9f --- /dev/null +++ b/vllm/transformers_utils/configs/qwen3_next.py @@ -0,0 +1,275 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen3-Next model configuration""" + +from transformers.configuration_utils import (PretrainedConfig, + layer_type_validation) +from transformers.modeling_rope_utils import rope_config_validation +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class Qwen3NextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3NextModel`]. 
It is used to instantiate a + Qwen3-Next model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of + Qwen3-Next-80B-A3B-Instruct [Qwen/Qwen3-Next-80B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the model. Defines the number of different tokens that can be represented by the + `inputs_ids`. + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 5632): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 48): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 2): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. 
+ `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + partial_rotary_factor (`float`, *optional*, defaults to 0.25): + Percentage of the query and keys which will have rotary embedding. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + head_dim (`int`, *optional*, defaults to 256): + Projection weights dimension in multi-head attention. + linear_conv_kernel_dim (`int`, *optional*, defaults to 4): + Kernel size of the convolution used in linear attention layers. + linear_key_head_dim (`int`, *optional*, defaults to 128): + Dimension of each key head in linear attention. + linear_value_head_dim (`int`, *optional*, defaults to 128): + Dimension of each value head in linear attention. + linear_num_key_heads (`int`, *optional*, defaults to 16): + Number of key heads used in linear attention layers. + linear_num_value_heads (`int`, *optional*, defaults to 32): + Number of value heads used in linear attention layers. + decoder_sparse_step (`int`, *optional*, defaults to 1): + The frequency of the MoE layer. + moe_intermediate_size (`int`, *optional*, defaults to 512): + Intermediate size of the routed expert. + shared_expert_intermediate_size (`int`, *optional*, defaults to 512): + Intermediate size of the shared expert. + num_experts_per_tok (`int`, *optional*, defaults to 10): + Number of selected experts. + num_experts (`int`, *optional*, defaults to 512): + Number of routed experts. + norm_topk_prob (`bool`, *optional*, defaults to `True`): + Whether to normalize the topk probabilities. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. 
Enabling this will also + allow the model to output the auxiliary loss, including load balancing loss and router z-loss. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + mlp_only_layers (`list[int]`, *optional*, defaults to `[]`): + Indicate which layers use Qwen3NextMLP rather than Qwen3NextSparseMoeBlock + The list contains layer index, from 0 to num_layers-1 if we have num_layers layers + If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity. + layer_types (`list[str]`, *optional*): + Types of each layer (attention or linear). + + ```python + >>> from transformers import Qwen3NextModel, Qwen3NextConfig + + >>> # Initializing a Qwen3Next style configuration + >>> configuration = Qwen3NextConfig() + + >>> # Initializing a model from the Qwen3-Next-80B-A3B style configuration + >>> model = Qwen3NextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ # noqa: E501 + + model_type = "qwen3_next" + keys_to_ignore_at_inference = ["past_key_values"] + + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.experts.*.gate_proj": "colwise", + "layers.*.mlp.experts.*.up_proj": "colwise", + "layers.*.mlp.experts.*.down_proj": "rowwise", + "layers.*.mlp.shared_experts.gate_proj": "colwise", + "layers.*.mlp.shared_experts.up_proj": "colwise", + "layers.*.mlp.shared_experts.down_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=151936, + hidden_size=2048, + intermediate_size=5632, + num_hidden_layers=48, + num_attention_heads=16, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + partial_rotary_factor=0.25, + attention_bias=False, + attention_dropout=0.0, + head_dim=256, + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=32, + decoder_sparse_step=1, + moe_intermediate_size=512, + shared_expert_intermediate_size=512, + num_experts_per_tok=10, + num_experts=512, + norm_topk_prob=True, + output_router_logits=False, + router_aux_loss_coef=0.001, + mlp_only_layers=None, + layer_types=None, + **kwargs, + ): + if mlp_only_layers is None: + mlp_only_layers = [] + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.partial_rotary_factor = partial_rotary_factor + self.attention_bias = attention_bias + 
self.attention_dropout = attention_dropout + self.head_dim = head_dim + rope_config_validation(self) + + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "linear_attention" if bool((i + 1) % 4) else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types) + + # linear attention part + self.linear_conv_kernel_dim = linear_conv_kernel_dim + self.linear_key_head_dim = linear_key_head_dim + self.linear_value_head_dim = linear_value_head_dim + self.linear_num_key_heads = linear_num_key_heads + self.linear_num_value_heads = linear_num_value_heads + + # MoE arguments + self.decoder_sparse_step = decoder_sparse_step + self.moe_intermediate_size = moe_intermediate_size + self.shared_expert_intermediate_size = shared_expert_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.norm_topk_prob = norm_topk_prob + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.mlp_only_layers = mlp_only_layers + + +__all__ = ["Qwen3NextConfig"] diff --git a/vllm/transformers_utils/configs/radio.py b/vllm/transformers_utils/configs/radio.py new file mode 100644 index 0000000000000..58ad7b8187bcd --- /dev/null +++ b/vllm/transformers_utils/configs/radio.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Radio vision model configuration""" + +from typing import Optional, Union + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +VIT_TIMM_DIM_BY_NAME: dict[str, tuple[int, int, int, int]] = { + "vit_small_patch16_224": (384, 12, 6, 1536), + "vit_base_patch16_224": (768, 12, 12, 3072), + "vit_large_patch16_224": (1024, 24, 16, 4096), + "vit_huge_patch16_224": (1280, 32, 16, 5120), +} + +OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711) + + +class RadioConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a Radio + vision model. It is used to instantiate a Radio model according to the + specified arguments, defining the model architecture. + + Args: + model_name (`str`, *optional*, defaults to "vit_base_patch16_224"): + Name of the vision transformer model (e.g., "vit_base_patch16_224"). + Used to determine architecture dimensions from + `VIT_TIMM_DIM_BY_NAME`. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + qkv_bias (`bool`, *optional*, defaults to True): + Whether to add a bias to the queries, keys and values. + qk_normalization (`bool`, *optional*, defaults to False): + Whether to apply normalization to queries and keys. + norm_type (`str`, *optional*, defaults to "layer_norm"): + The normalization type to use. + layer_norm_eps (`float`, *optional*, defaults to 1e-6): + The epsilon used by the layer normalization layers. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices. + hidden_act (`str`, *optional*, defaults to "gelu"): + The non-linear activation function in the encoder. + max_img_size (`int`, *optional*, defaults to 2048): + Maximum image size for position embeddings. 
+ norm_mean (`tuple` or `list`, *optional*, + defaults to (0.48145466, 0.4578275, 0.40821073)): + Mean values for image normalization (RGB channels). + norm_std (`tuple` or `list`, *optional*, + defaults to (0.26862954, 0.26130258, 0.27577711)): + Standard deviation values for image normalization (RGB channels). + reg_tokens (`int`, *optional*): + Number of register tokens to use. + """ + + model_type = "radio" + + def __init__( + self, + model_name: str, + image_size: int = 224, + patch_size: int = 16, + qkv_bias: bool = True, + qk_normalization: bool = False, + norm_type: str = "layer_norm", + layer_norm_eps: float = 1e-6, + initializer_factor: float = 1.0, + hidden_act: str = "gelu", + max_img_size: int = 2048, + norm_mean: Union[tuple[float, float, float], list] = OPENAI_CLIP_MEAN, + norm_std: Union[tuple[float, float, float], list] = OPENAI_CLIP_STD, + reg_tokens: Optional[int] = None, + **kwargs, + ): + self.model_name = model_name + ( + self.hidden_size, + self.num_hidden_layers, + self.num_attention_heads, + self.intermediate_size, + ) = VIT_TIMM_DIM_BY_NAME[model_name] + self.image_size = image_size + self.patch_size = patch_size + self.qkv_bias = qkv_bias + self.qk_normalization = qk_normalization + self.norm_type = norm_type + self.layer_norm_eps = layer_norm_eps + self.initializer_factor = initializer_factor + self.hidden_act = hidden_act + self.max_img_size = max_img_size + self.norm_mean = list(norm_mean) if isinstance(norm_mean, + (tuple, + list)) else norm_mean + self.norm_std = list(norm_std) if isinstance(norm_std, + (tuple, + list)) else norm_std + self.reg_tokens = reg_tokens + super().__init__(**kwargs) diff --git a/vllm/transformers_utils/configs/ultravox.py b/vllm/transformers_utils/configs/ultravox.py index 87064cc12deda..aaf31d84d0c1a 100644 --- a/vllm/transformers_utils/configs/ultravox.py +++ b/vllm/transformers_utils/configs/ultravox.py @@ -20,10 +20,13 @@ class UltravoxConfig(transformers.PretrainedConfig): Args: audio_config (`Union[AutoConfig, dict]`, *optional*): - Custom audio config or dict + Custom audio config or dict. text_config (`Union[AutoConfig, dict]`, *optional*): - The config object of the text backbone. Can be any of `LlamaConfig` - or `MistralConfig`. + The config object of the text backbone. + audio_model_id (`str`, *optional*): + The model ID of the audio backbone. + text_model_id (`str`, *optional*): + The model ID of the text backbone. ignore_index (`int`, *optional*, defaults to -100): The ignore index for the loss function. audio_token_index (`int`, *optional*, defaults to 32000): @@ -34,16 +37,12 @@ class UltravoxConfig(transformers.PretrainedConfig): The initialization value for the layer normalization. projector_act (`str`, *optional*, defaults to `"swiglu"`): The activation function used by the multimodal projector. - text_model_lora_config (`LoraConfigSimplified`, *optional*): - The LoRA configuration for finetuning the text model. - audio_model_lora_config (`LoraConfigSimplified`, *optional*): - The LoRA configuration for finetuning the audio model. projector_ln_mid (`bool`, *optional*, defaults to `False`): Whether to apply layer normalization at the middle of the projector or at the end. Versions v0.4.1 and below use `False`, but v0.5 and above use `True`. 
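A small sketch (not part of the patch) of how the new `RadioConfig` above derives its transformer dimensions from the timm-style model name via `VIT_TIMM_DIM_BY_NAME`:

from vllm.transformers_utils.configs.radio import RadioConfig

cfg = RadioConfig(model_name="vit_large_patch16_224")
# (hidden_size, num_hidden_layers, num_attention_heads, intermediate_size)
assert (cfg.hidden_size, cfg.num_hidden_layers, cfg.num_attention_heads,
        cfg.intermediate_size) == (1024, 24, 16, 4096)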
""" - + wrapped_model_config: transformers.PretrainedConfig model_type = "ultravox" audio_token = "<|audio|>" is_composition = False @@ -60,15 +59,10 @@ class UltravoxConfig(transformers.PretrainedConfig): stack_factor: int = 8, norm_init: float = 0.4, projector_act: str = "swiglu", - text_model_lora_config: Optional[dict[str, Any]] = None, - audio_model_lora_config: Optional[dict[str, Any]] = None, projector_ln_mid: bool = False, **kwargs, ): self.ignore_index = ignore_index - - self.audio_model_id = audio_model_id - self.text_model_id = text_model_id self.audio_token_index = audio_token_index self.hidden_size = hidden_size @@ -77,36 +71,46 @@ class UltravoxConfig(transformers.PretrainedConfig): self.projector_act = projector_act self.projector_ln_mid = projector_ln_mid - if text_model_id is not None: - # Avoid circular import - from vllm.transformers_utils.config import get_config - - text_config_obj = get_config(text_model_id, - trust_remote_code=False) - else: + # N.B. May set the wrapped_model_config below. + self.text_model_id = text_model_id + if text_model_id is None: text_config = text_config or {} - text_config_obj = transformers.CONFIG_MAPPING[text_config.get( - "model_type", "llama")](**text_config) + self.wrapped_model_config = transformers.CONFIG_MAPPING[ + text_config.get("model_type", "llama")](**text_config) - inner_text_config = text_config_obj.get_text_config() - - if audio_model_id is not None: - # Avoid circular import - from vllm.transformers_utils.config import get_config - - audio_config = get_config(audio_model_id, trust_remote_code=False) - else: + # N.B. May set the audio_config below. + self.audio_model_id = audio_model_id + if audio_model_id is None: + self.audio_model_id = None audio_config = audio_config or {} - audio_config = transformers.CONFIG_MAPPING[audio_config.get( + self.audio_config = transformers.CONFIG_MAPPING[audio_config.get( "model_type", "whisper")](**audio_config) - self.text_config = text_config_obj - self.audio_config = audio_config - self.text_model_lora_config = text_model_lora_config or {} - self.audio_model_lora_config = audio_model_lora_config or {} - - self.vocab_size = inner_text_config.vocab_size - self.initializer_range = inner_text_config.initializer_range - self.text_hidden_size = inner_text_config.hidden_size - super().__init__(**kwargs) + + def __setattr__(self, key, value): + # Since --hf-overrides are applied _after_ the UltravoxConfig is + # instantiated, load the configs implicitly when assigning text_model_id + # or audio_model_id. This allows: + # + # --hf-overrides.text_model_id=<quantized variant> + # + # to behave as intended. + if key == "text_model_id" and value is not None: + from vllm.transformers_utils.config import get_config + + self.wrapped_model_config = get_config(value, + trust_remote_code=False) + elif key == "audio_model_id" and value is not None: + from vllm.transformers_utils.config import get_config + + self.audio_config = get_config(value, trust_remote_code=False) + + return super().__setattr__(key, value) + + @property + def text_config(self) -> transformers.PretrainedConfig: + # When Ultravox wraps a multi-modal model (e.g. Gemma), we instantiate + # the full model, but the text config is the text config of the inner + # model. 
+ return self.wrapped_model_config.get_text_config() diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index 56b01ecf78c46..e2d2846a28073 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -10,18 +10,13 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, SamplingParams, Sequence, from .detokenizer_utils import (convert_prompt_ids_to_tokens, detokenize_incrementally) from .tokenizer import AnyTokenizer -from .tokenizer_group import TokenizerGroup class Detokenizer: """Provides methods to decode the output of a model into text.""" - def __init__(self, tokenizer_group: TokenizerGroup): - self.tokenizer_group = tokenizer_group - - def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer: - """Returns the HF tokenizer to use for a given sequence.""" - return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request) + def __init__(self, tokenizer: AnyTokenizer): + self.tokenizer = tokenizer def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup, prompt_logprobs: list[Optional[dict[ @@ -32,9 +27,9 @@ class Detokenizer: Args: seq_group: The sequence group to decode. prompt_logprobs: The logprobs to decode. - position_offset: Offset of the first index of the logprobs + position_offset: Offset of the first index of the logprobs relative to the start of the sequence (for chunked prefill). - + Returns: The prompt logprobs with the decoded tokens. """ @@ -46,7 +41,6 @@ class Detokenizer: # Only prompt, without the generated token. all_token_ids = seq.get_token_ids() prompt_token_ids = all_token_ids[:-1] - tokenizer = self.get_tokenizer_for_seq(seq) prefix_offset = 0 read_offset = 0 next_iter_prefix_offset = 0 @@ -70,7 +64,7 @@ class Detokenizer: prompt_token_ids[:token_position] + [token_id]) (new_tokens, new_text, new_prefix_offset, new_read_offset) = detokenize_incrementally( - tokenizer=tokenizer, + tokenizer=self.tokenizer, all_input_ids=prompt_token_ids_with_token, prev_tokens=prev_tokens, prefix_offset=prefix_offset, @@ -111,7 +105,6 @@ class Detokenizer: """ all_input_ids = seq.get_token_ids() token_id_generated_this_iteration = all_input_ids[-1] - tokenizer = self.get_tokenizer_for_seq(seq) # Convert prompt token IDs to tokens if necessary. # Do it here so that we don't have to repeat this @@ -119,14 +112,14 @@ class Detokenizer: if seq.tokens is None: (seq.tokens, seq.prefix_offset, seq.read_offset) = convert_prompt_ids_to_tokens( - tokenizer=tokenizer, + tokenizer=self.tokenizer, prompt_ids=all_input_ids[:-1], skip_special_tokens=prms.skip_special_tokens, ) (new_tokens, new_decoded_token_text, prefix_offset, read_offset) = detokenize_incrementally( - tokenizer=tokenizer, + tokenizer=self.tokenizer, all_input_ids=all_input_ids, prev_tokens=seq.tokens, prefix_offset=seq.prefix_offset, @@ -150,7 +143,7 @@ class Detokenizer: and token_id != VLLM_INVALID_TOKEN_ID): all_input_ids_with_logprob = previous_tokens + [token_id] (_, new_text, _, _) = detokenize_incrementally( - tokenizer=tokenizer, + tokenizer=self.tokenizer, all_input_ids=all_input_ids_with_logprob, prev_tokens=seq.tokens, prefix_offset=seq.prefix_offset, diff --git a/vllm/transformers_utils/processors/deepseek_vl2.py b/vllm/transformers_utils/processors/deepseek_vl2.py index 5896bde312657..d1d117b4e2cf4 100644 --- a/vllm/transformers_utils/processors/deepseek_vl2.py +++ b/vllm/transformers_utils/processors/deepseek_vl2.py @@ -25,6 +25,7 @@ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
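A usage sketch (not part of the patch) for the simplified `Detokenizer` above, which now wraps a single tokenizer instead of a `TokenizerGroup`; the model name is illustrative only.

from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import get_tokenizer

tokenizer = get_tokenizer("meta-llama/Llama-3.2-1B-Instruct")
detokenizer = Detokenizer(tokenizer)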
import math +from typing import Any import torch import torchvision.transforms as T @@ -178,17 +179,15 @@ class DeepseekVLV2Processor(ProcessorMixin): prompt: str, images: list[Image.Image], inference_mode: bool = True, - **kwargs, + **kwargs: Any, ): """ Args: prompt (str): the formatted prompt; - conversations (list[dict]): conversations with a list of messages; images (list[ImageType]): the list of images; inference_mode (bool): if True, then remove the last eos token; - system_prompt (str): the system prompt; - **kwargs: + **kwargs: Additional keyword arguments. Returns: outputs (BaseProcessorOutput): the output of the processor, @@ -259,7 +258,7 @@ class DeepseekVLV2Processor(ProcessorMixin): text: str, images: list[Image.Image], inference_mode: bool = True, - **kwargs, + **kwargs: Any, ): """ diff --git a/vllm/transformers_utils/runai_utils.py b/vllm/transformers_utils/runai_utils.py new file mode 100644 index 0000000000000..b7bee1974de5b --- /dev/null +++ b/vllm/transformers_utils/runai_utils.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +import shutil +import signal +import tempfile +from typing import Optional + +from vllm.logger import init_logger +from vllm.utils import PlaceholderModule + +logger = init_logger(__name__) + +SUPPORTED_SCHEMES = ['s3://', 'gs://'] + +try: + from runai_model_streamer import list_safetensors as runai_list_safetensors + from runai_model_streamer import pull_files as runai_pull_files +except (ImportError, OSError): + # see https://github.com/run-ai/runai-model-streamer/issues/26 + # OSError will be raised on arm64 platform + runai_model_streamer = PlaceholderModule( + "runai_model_streamer") # type: ignore[assignment] + runai_pull_files = runai_model_streamer.placeholder_attr("pull_files") + runai_list_safetensors = runai_model_streamer.placeholder_attr( + "list_safetensors") + + +def list_safetensors(path: str = "") -> list[str]: + """ + List full file names from object path and filter by allow pattern. + + Args: + path: The object storage path to list from. + + Returns: + list[str]: List of full object storage paths allowed by the pattern + """ + return runai_list_safetensors(path) + + +def is_runai_obj_uri(model_or_path: str) -> bool: + return model_or_path.lower().startswith(tuple(SUPPORTED_SCHEMES)) + + +class ObjectStorageModel: + """ + A class representing an ObjectStorage model mirrored into a + temporary directory. + + Attributes: + dir: The temporary created directory. + + Methods: + pull_files(): Pull model from object storage to the temporary directory. + """ + + def __init__(self) -> None: + for sig in (signal.SIGINT, signal.SIGTERM): + existing_handler = signal.getsignal(sig) + signal.signal(sig, self._close_by_signal(existing_handler)) + + self.dir = tempfile.mkdtemp() + + def __del__(self): + self._close() + + def _close(self) -> None: + if os.path.exists(self.dir): + shutil.rmtree(self.dir) + + def _close_by_signal(self, existing_handler=None): + + def new_handler(signum, frame): + self._close() + if existing_handler: + existing_handler(signum, frame) + + return new_handler + + def pull_files(self, + model_path: str = "", + allow_pattern: Optional[list[str]] = None, + ignore_pattern: Optional[list[str]] = None) -> None: + """ + Pull files from object storage into the temporary directory. + + Args: + model_path: The object storage path of the model. + allow_pattern: A list of patterns of which files to pull. 
+ ignore_pattern: A list of patterns of which files not to pull. + + """ + if not model_path.endswith("/"): + model_path = model_path + "/" + runai_pull_files(model_path, self.dir, allow_pattern, ignore_pattern) diff --git a/vllm/transformers_utils/s3_utils.py b/vllm/transformers_utils/s3_utils.py index f95aae7815e0b..b848898ff6dad 100644 --- a/vllm/transformers_utils/s3_utils.py +++ b/vllm/transformers_utils/s3_utils.py @@ -2,15 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import fnmatch -import os -import shutil -import signal -import tempfile -from pathlib import Path -from typing import Optional +from typing import TYPE_CHECKING, Optional from vllm.utils import PlaceholderModule +if TYPE_CHECKING: + from botocore.client import BaseClient + try: import boto3 except ImportError: @@ -31,7 +29,7 @@ def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]: ] -def glob(s3=None, +def glob(s3: Optional["BaseClient"] = None, path: str = "", allow_pattern: Optional[list[str]] = None) -> list[str]: """ @@ -56,7 +54,7 @@ def glob(s3=None, def list_files( - s3, + s3: "BaseClient", path: str, allow_pattern: Optional[list[str]] = None, ignore_pattern: Optional[list[str]] = None @@ -93,70 +91,3 @@ def list_files( paths = _filter_ignore(paths, ignore_pattern) return bucket_name, prefix, paths - - -class S3Model: - """ - A class representing a S3 model mirrored into a temporary directory. - - Attributes: - s3: S3 client. - dir: The temporary created directory. - - Methods: - pull_files(): Pull model from S3 to the temporary directory. - """ - - def __init__(self) -> None: - self.s3 = boto3.client('s3') - for sig in (signal.SIGINT, signal.SIGTERM): - existing_handler = signal.getsignal(sig) - signal.signal(sig, self._close_by_signal(existing_handler)) - - self.dir = tempfile.mkdtemp() - - def __del__(self): - self._close() - - def _close(self) -> None: - if os.path.exists(self.dir): - shutil.rmtree(self.dir) - - def _close_by_signal(self, existing_handler=None): - - def new_handler(signum, frame): - self._close() - if existing_handler: - existing_handler(signum, frame) - - return new_handler - - def pull_files(self, - s3_model_path: str = "", - allow_pattern: Optional[list[str]] = None, - ignore_pattern: Optional[list[str]] = None) -> None: - """ - Pull files from S3 storage into the temporary directory. - - Args: - s3_model_path: The S3 path of the model. - allow_pattern: A list of patterns of which files to pull. - ignore_pattern: A list of patterns of which files not to pull. 
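A usage sketch (not part of the patch) for the new `runai_utils` helpers; the bucket and path below are hypothetical, and `runai_model_streamer` must be installed for the calls to do real work.

from vllm.transformers_utils.runai_utils import (ObjectStorageModel,
                                                 is_runai_obj_uri)

model_uri = "s3://my-bucket/my-model"
if is_runai_obj_uri(model_uri):
    model = ObjectStorageModel()
    # Mirrors the remote files into the temporary directory at model.dir.
    model.pull_files(model_uri, allow_pattern=["*.safetensors", "*.json"])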
- - """ - if not s3_model_path.endswith("/"): - s3_model_path = s3_model_path + "/" - - bucket_name, base_dir, files = list_files(self.s3, s3_model_path, - allow_pattern, - ignore_pattern) - if len(files) == 0: - return - - for file in files: - destination_file = os.path.join( - self.dir, - file.removeprefix(base_dir).lstrip("/")) - local_dir = Path(destination_file).parent - os.makedirs(local_dir, exist_ok=True) - self.s3.download_file(bucket_name, file, destination_file) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index b3f1977f26cf4..9aaac66817394 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union import huggingface_hub from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) +from typing_extensions import assert_never from vllm import envs from vllm.logger import init_logger @@ -19,7 +20,6 @@ from vllm.transformers_utils.config import ( get_sentence_transformer_tokenizer_config) from vllm.transformers_utils.tokenizers import MistralTokenizer from vllm.transformers_utils.utils import check_gguf_file -from vllm.utils import make_async if TYPE_CHECKING: from vllm.config import ModelConfig @@ -274,20 +274,19 @@ def cached_tokenizer_from_config( ) -def get_lora_tokenizer(lora_request: LoRARequest, *args, - **kwargs) -> Optional[AnyTokenizer]: - if lora_request is None: - return None - try: - tokenizer = get_tokenizer(lora_request.lora_path, *args, **kwargs) - except Exception as e: - # No tokenizer was found in the LoRA folder, - # use base model tokenizer - logger.warning( - "No tokenizer found in %s, using base model tokenizer instead. " - "(Exception: %s)", lora_request.lora_path, e) - tokenizer = None - return tokenizer +def init_tokenizer_from_configs(model_config: ModelConfig): + runner_type = model_config.runner_type + if runner_type == "generate" or runner_type == "draft": + truncation_side = "left" + elif runner_type == "pooling": + truncation_side = "right" + else: + assert_never(runner_type) - -get_lora_tokenizer_async = make_async(get_lora_tokenizer) + return get_tokenizer( + model_config.tokenizer, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code, + revision=model_config.tokenizer_revision, + truncation_side=truncation_side, + ) diff --git a/vllm/transformers_utils/tokenizer_base.py b/vllm/transformers_utils/tokenizer_base.py index 20e5fea714e70..b1f84a023fc31 100644 --- a/vllm/transformers_utils/tokenizer_base.py +++ b/vllm/transformers_utils/tokenizer_base.py @@ -61,6 +61,11 @@ class TokenizerBase(ABC): def max_token_id(self) -> int: raise NotImplementedError() + @property + @abstractmethod + def truncation_side(self) -> str: + raise NotImplementedError() + def __len__(self) -> int: return self.vocab_size diff --git a/vllm/transformers_utils/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group.py deleted file mode 100644 index ae8220f9b9dc5..0000000000000 --- a/vllm/transformers_utils/tokenizer_group.py +++ /dev/null @@ -1,131 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Optional - -from typing_extensions import assert_never - -from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig -from vllm.lora.request import LoRARequest -from vllm.transformers_utils.tokenizer import (AnyTokenizer, encode_tokens, - get_lora_tokenizer, - 
get_lora_tokenizer_async, - get_tokenizer) -from vllm.utils import LRUCache - - -class TokenizerGroup: - """A group of tokenizers that can be used for LoRA adapters.""" - - def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, - max_input_length: Optional[int], **tokenizer_config): - self.tokenizer_id = tokenizer_id - self.tokenizer_config = tokenizer_config - self.enable_lora = enable_lora - self.max_input_length = max_input_length - self.truncation_side = tokenizer_config.get("truncation_side", "left") - self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) - max_loras = tokenizer_config.get("max_loras", 0) - self.lora_tokenizers = LRUCache[int, AnyTokenizer]( - capacity=max(max_loras, max_num_seqs) if enable_lora else 0) - - def get_max_input_len(self, - lora_request: Optional[LoRARequest] = None - ) -> Optional[int]: - """Get the maximum input length for the LoRA request.""" - return self.max_input_length - - def _raise_if_input_too_long(self, - encoded_tokens: list[int], - lora_request: Optional[LoRARequest] = None): - input_length = len(encoded_tokens) - if lora_request: - max_input_length = (lora_request.long_lora_max_len - or self.max_input_length) - else: - max_input_length = self.max_input_length - if max_input_length is not None and input_length > max_input_length: - raise ValueError("Input too long.", input_length, max_input_length) - - def encode(self, - prompt: str, - max_length: Optional[int] = None, - truncation: Optional[bool] = None, - lora_request: Optional[LoRARequest] = None, - add_special_tokens: Optional[bool] = None) -> list[int]: - - tokenizer = self.get_lora_tokenizer(lora_request) - ret = encode_tokens(tokenizer, - prompt, - max_length=max_length, - truncation=truncation, - add_special_tokens=add_special_tokens) - self._raise_if_input_too_long(ret, lora_request) - return ret - - async def encode_async( - self, - prompt: str, - max_length: Optional[int] = None, - truncation: Optional[bool] = None, - lora_request: Optional[LoRARequest] = None, - add_special_tokens: Optional[bool] = None) -> list[int]: - tokenizer = await self.get_lora_tokenizer_async(lora_request) - ret = encode_tokens(tokenizer, - prompt, - max_length=max_length, - truncation=truncation, - add_special_tokens=add_special_tokens) - self._raise_if_input_too_long(ret, lora_request) - return ret - - def get_lora_tokenizer( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - if not lora_request or not self.enable_lora: - return self.tokenizer - if lora_request.lora_int_id not in self.lora_tokenizers: - tokenizer = (get_lora_tokenizer( - lora_request, **self.tokenizer_config) or self.tokenizer) - self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) - return tokenizer - else: - return self.lora_tokenizers[lora_request.lora_int_id] - - async def get_lora_tokenizer_async( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - if not lora_request or not self.enable_lora: - return self.tokenizer - if lora_request.lora_int_id not in self.lora_tokenizers: - tokenizer = (await get_lora_tokenizer_async( - lora_request, **self.tokenizer_config) or self.tokenizer) - self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) - return tokenizer - else: - return self.lora_tokenizers[lora_request.lora_int_id] - - -def init_tokenizer_from_configs(model_config: ModelConfig, - scheduler_config: SchedulerConfig, - lora_config: Optional[LoRAConfig]): - runner_type = model_config.runner_type - if runner_type == "generate" or 
runner_type == "draft": - truncation_side = "left" - elif runner_type == "pooling": - truncation_side = "right" - else: - assert_never(runner_type) - - return TokenizerGroup( - tokenizer_id=model_config.tokenizer, - enable_lora=bool(lora_config), - max_num_seqs=scheduler_config.max_num_seqs, - max_loras=lora_config.max_loras if lora_config else 0, - max_input_length=None, - tokenizer_mode=model_config.tokenizer_mode, - trust_remote_code=model_config.trust_remote_code, - revision=model_config.tokenizer_revision, - truncation_side=truncation_side) diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index f545993a5a980..d8a8d19391cd0 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -274,7 +274,7 @@ class MistralTokenizer(TokenizerBase): return tokenizer_file # the following attributes are set to fit vLLM's design and are used - # by the guided structured output backends. + # by the structured output backends. @property def all_special_tokens_extended(self) -> list[str]: from mistral_common.tokens.tokenizers.base import SpecialTokens @@ -327,6 +327,10 @@ class MistralTokenizer(TokenizerBase): def max_token_id(self) -> int: return self._max_token_id + @property + def truncation_side(self) -> str: + raise NotImplementedError() + def __len__(self) -> int: return self.vocab_size @@ -459,9 +463,6 @@ class MistralTokenizer(TokenizerBase): return decoded - # WARN: Outlines logits processors can overwrite this method. - # See: guided_decoding/outlines_logits_processors.py::_adapt_tokenizer - # for more. def decode(self, ids: Union[list[int], int], skip_special_tokens: bool = True) -> str: diff --git a/vllm/triton_utils/__init__.py b/vllm/triton_utils/__init__.py index 0fcf5d15afd1d..828536e6408b1 100644 --- a/vllm/triton_utils/__init__.py +++ b/vllm/triton_utils/__init__.py @@ -7,8 +7,10 @@ from vllm.triton_utils.importing import (HAS_TRITON, TritonLanguagePlaceholder, if HAS_TRITON: import triton import triton.language as tl + import triton.language.extra.libdevice as tldevice else: triton = TritonPlaceholder() tl = TritonLanguagePlaceholder() + tldevice = TritonLanguagePlaceholder() -__all__ = ["HAS_TRITON", "triton", "tl"] +__all__ = ["HAS_TRITON", "triton", "tl", "tldevice"] diff --git a/vllm/triton_utils/importing.py b/vllm/triton_utils/importing.py index 372200027bf95..2a06a9b7d11e4 100644 --- a/vllm/triton_utils/importing.py +++ b/vllm/triton_utils/importing.py @@ -68,7 +68,7 @@ class TritonPlaceholder(types.ModuleType): def __init__(self): super().__init__("triton") - self.__version__ = "3.3.0" + self.__version__ = "3.4.0" self.jit = self._dummy_decorator("jit") self.autotune = self._dummy_decorator("autotune") self.heuristics = self._dummy_decorator("heuristics") diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 9c78e56d580e0..f13381ecd9ff3 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -78,6 +78,7 @@ if TYPE_CHECKING: from argparse import Namespace from vllm.config import ModelConfig, VllmConfig + from vllm.sequence import IntermediateTensors logger = init_logger(__name__) @@ -162,6 +163,12 @@ STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" STR_DUAL_CHUNK_FLASH_ATTN_VAL: str = "DUAL_CHUNK_FLASH_ATTN" STR_INVALID_VAL: str = "INVALID" +MB_bytes = 1_000_000 +"""The number of bytes in one megabyte (MB).""" + +MiB_bytes = 1 << 20 +"""The number of bytes in one mebibyte (MiB).""" + GB_bytes = 1_000_000_000 """The number of bytes in one gigabyte 
(GB).""" @@ -1472,7 +1479,8 @@ def current_stream() -> torch.cuda.Stream: # is hurting performance. Therefore creating a dedicated stream # per process if current_platform.is_rocm(): - _current_stream_tls.value = torch.cuda.Stream() + # torch.cuda.set_stream here is the alias of _patched_set_stream + torch.cuda.set_stream(torch.cuda.Stream()) elif current_platform.is_cpu(): _current_stream_tls.value = _StreamPlaceholder() else: @@ -2074,6 +2082,7 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, return await task(*args, **kwargs) +@lru_cache def supports_kw( callable: Callable[..., object], kw_name: str, @@ -2278,7 +2287,8 @@ def weak_ref_tensor(tensor: Any) -> Any: def weak_ref_tensors( - tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]] + tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor], + IntermediateTensors] ) -> Union[torch.Tensor, list[Any], tuple[Any], Any]: """ Convenience function to create weak references to tensors, @@ -2290,6 +2300,15 @@ def weak_ref_tensors( return [weak_ref_tensor(t) for t in tensors] if isinstance(tensors, tuple): return tuple(weak_ref_tensor(t) for t in tensors) + + # For IntermediateTensors used in pipeline parallelism + from vllm.sequence import IntermediateTensors + if isinstance(tensors, IntermediateTensors): + ret = IntermediateTensors({ + key: weak_ref_tensor(val) + for key, val in tensors.tensors.items() + }) + return ret raise ValueError("Invalid type for tensors") @@ -2779,7 +2798,10 @@ def memory_profiling( result.torch_peak_increase = diff_profile.torch_peak result.non_torch_increase = diff_from_create.non_torch_memory result.profile_time = diff_profile.timestamp - result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory # noqa + + non_torch_memory = result.non_torch_increase + peak_activation_memory = result.torch_peak_increase + result.non_kv_cache_memory = non_torch_memory + peak_activation_memory + result.weights_memory # noqa # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501 @@ -3249,7 +3271,7 @@ def check_use_alibi(model_config: ModelConfig) -> bool: and getattr(cfg.attn_config, "alibi", False))))) -def sha256(input) -> int: +def sha256(input) -> bytes: """Hash any picklable Python object using SHA-256. The input is serialized using pickle before hashing, which allows @@ -3260,16 +3282,15 @@ def sha256(input) -> int: input: Any picklable Python object. Returns: - An integer representing the SHA-256 hash of the serialized input. + Bytes representing the SHA-256 hash of the serialized input. """ input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) - return int.from_bytes(hashlib.sha256(input_bytes).digest(), - byteorder="big") + return hashlib.sha256(input_bytes).digest() -def sha256_cbor_64bit(input) -> int: +def sha256_cbor(input) -> bytes: """ - Hash objects using CBOR serialization and SHA-256, then truncate to 64bits. + Hash objects using CBOR serialization and SHA-256. This option is useful for non-Python-dependent serialization and hashing. @@ -3280,17 +3301,13 @@ def sha256_cbor_64bit(input) -> int: Custom classes must implement CBOR serialization methods. Returns: - An integer in the range [0, 2^64-1] representing the lower 64 bits - of the SHA-256 hash of the CBOR serialized input.
""" input_bytes = cbor2.dumps(input, canonical=True) - full_hash = int.from_bytes(hashlib.sha256(input_bytes).digest(), - byteorder="big") - - return full_hash & ((1 << 64) - 1) + return hashlib.sha256(input_bytes).digest() -def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], int]: +def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]: """Get a hash function by name, or raise an error if the function is not found. Args: @@ -3300,10 +3317,8 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], int]: """ if hash_fn_name == "sha256": return sha256 - if hash_fn_name == "sha256_cbor_64bit": - return sha256_cbor_64bit - if hash_fn_name == "builtin": - return hash + if hash_fn_name == "sha256_cbor": + return sha256_cbor raise ValueError(f"Unsupported hash function: {hash_fn_name}") @@ -3366,7 +3381,7 @@ def has_triton_kernels() -> bool: def set_process_title(name: str, suffix: str = "", - append: bool = False) -> None: + prefix: str = envs.VLLM_PROCESS_NAME_PREFIX) -> None: """ Set the current process title to a specific name with an optional suffix. @@ -3374,15 +3389,11 @@ def set_process_title(name: str, Args: name: The title to assign to the current process. suffix: An optional suffix to append to the base name. - append: Whether to append to the existing process title. + prefix: A prefix to prepend to the front separated by `::`. """ if suffix: name = f"{name}_{suffix}" - if append: - name = f"{setproctitle.getproctitle()}_{name}" - else: - name = f"{envs.VLLM_PROCESS_NAME_PREFIX}::{name}" - setproctitle.setproctitle(name) + setproctitle.setproctitle(f"{prefix}::{name}") def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 90cdd396209c7..38d92f01192b1 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -67,23 +67,6 @@ def _missing(*_: Any, **__: Any) -> NoReturn: "package to enable FP8 kernels.") -def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None: - """Return the *new* symbol if it exists, otherwise the *old* one.""" - if hasattr(module, new): - return getattr(module, new) - if hasattr(module, old): - # TODO(wentao): deprecate old symbol in the future. - logger.warning_once( - "Found legacy DeepGEMM symbol `%s`. Please upgrade the `deep_gemm` " - "package so that `%s` is available. 
Support for the legacy symbol " - "will be removed in a future vLLM release.", - old, - new, - ) - return getattr(module, old) - return None - - _fp8_gemm_nt_impl: Callable[..., Any] | None = None _grouped_impl: Callable[..., Any] | None = None _grouped_masked_impl: Callable[..., Any] | None = None @@ -109,14 +92,9 @@ def _lazy_init() -> None: _dg = importlib.import_module("deep_gemm") - _fp8_gemm_nt_impl = _resolve_symbol(_dg, "fp8_gemm_nt", - "gemm_fp8_fp8_bf16_nt") - _grouped_impl = _resolve_symbol( - _dg, "m_grouped_fp8_gemm_nt_contiguous", - "m_grouped_gemm_fp8_fp8_bf16_nt_contiguous") - _grouped_masked_impl = _resolve_symbol( - _dg, "fp8_m_grouped_gemm_nt_masked", - "m_grouped_gemm_fp8_fp8_bf16_nt_masked") + _fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None) + _grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None) + _grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None) def fp8_gemm_nt(*args, **kwargs): diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index fab134733d4fd..2179bddae2435 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -154,28 +154,31 @@ def has_nvidia_artifactory() -> bool: @functools.cache -def supports_trtllm_attention() -> tuple[bool, Optional[str]]: - """Cache result which only depends on the environment""" - # This is a lambda, call it once - env_value = envs.VLLM_USE_TRTLLM_ATTENTION - +def supports_trtllm_attention() -> bool: + """ + TRTLLM attention is supported if the platform is SM100 and + NVIDIA artifactory is accessible + """ # Requires SM100 and NVIDIA artifactory to be accessible to download cubins - if not (current_platform.is_device_capability(100) - and has_nvidia_artifactory()): - return False, env_value + return current_platform.is_device_capability( + 100) and has_nvidia_artifactory() + +@functools.cache +def _force_use_trtllm_attention(env_value: Optional[bool]) -> Optional[bool]: + """Cache the env value for VLLM_USE_TRTLLM_ATTENTION""" if env_value is not None: logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value) - # Environment variable is set - respect it - # Making the conditional check for zero because - # the path is automatically enabled if the batch size condition - # is satisfied. - use_trtllm = (env_value == "1") - if use_trtllm: - logger.info_once("Using TRTLLM attention.") - return use_trtllm, env_value + return env_value - return True, None + +def force_use_trtllm_attention() -> Optional[bool]: + """ + Return ``None`` if VLLM_USE_TRTLLM_ATTENTION is not set, + return ``True`` if TRTLLM attention is forced to be used, + return ``False`` if TRTLLM attention is forced to be not used. 
+ """ + return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION) def use_trtllm_attention( @@ -185,26 +188,41 @@ def use_trtllm_attention( max_seq_len: int, kv_cache_dtype: str, q_dtype: torch.dtype, - is_prefill: bool, has_sinks: bool = False, ) -> bool: - use_trtllm, env_value = supports_trtllm_attention() - if not use_trtllm: + """Return ``True`` if TRTLLM attention is used.""" + force_use_trtllm = force_use_trtllm_attention() + + # Environment variable is set to 0 - respect it + if force_use_trtllm is not None and not force_use_trtllm: return False + # The platform is not supported + if not supports_trtllm_attention(): + if force_use_trtllm: + logger.warning_once( + "TRTLLM attention is not supported on this platform, " + "but VLLM_USE_TRTLLM_ATTENTION is set to 1") + return False + + # The combination of query and key heads is not supported if num_qo_heads % num_kv_heads != 0: + if force_use_trtllm: + logger.warning_once( + "TRTLLM attention is not supported for this combination of " + "query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1" + ) return False # Must use TRTLLM attention if query is FP8 quantized if q_dtype == current_platform.fp8_dtype(): + if has_sinks: + raise RuntimeError( + "TRTLLM FP8-qkv kernel is not supported for attention sinks. " + "Use kv_cache_dtype=auto for now.") logger.info_once("Using TRTLLM attention (query is quantized).") return True - # TRTLLM prefill attention does not support FP8 kv cache with - # non-quantized query - if is_prefill and kv_cache_dtype.startswith("fp8"): - return False - # If sinks are being used, we must use TRTLLM attention as it's # the only backend that supports them if has_sinks: @@ -212,15 +230,17 @@ def use_trtllm_attention( "Using TRTLLM attention (required for attention sinks).") return True - if env_value is None: + if force_use_trtllm is None: # Environment variable not set - use auto-detection - use_trtllm = (num_tokens <= 256 and max_seq_len < 131072 + use_trtllm = (num_tokens <= 256 and max_seq_len <= 131072 and kv_cache_dtype == "auto") if use_trtllm: logger.warning_once("Using TRTLLM attention (auto-detected).") return use_trtllm # Environment variable is set to 1 - respect it + logger.info_once( + "Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)") return True @@ -353,6 +373,12 @@ def flashinfer_scaled_fp8_mm( return output +@functools.cache +def flashinfer_disable_q_quantization() -> bool: + """Cache result which only depends on the environment""" + return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION + + __all__ = [ "has_flashinfer", "flashinfer_trtllm_fp8_block_scale_moe", @@ -366,6 +392,7 @@ __all__ = [ "has_nvidia_artifactory", "supports_trtllm_attention", "use_trtllm_attention", + "flashinfer_disable_q_quantization", "flashinfer_scaled_fp4_mm", "flashinfer_scaled_fp8_mm", ] diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index ced8234a7b433..6627164c98798 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -317,8 +317,8 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]): def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device) -> None: - self.kv_cache_spec = kv_cache_spec - self.vllm_config = vllm_config + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.scheduler_config = vllm_config.scheduler_config # For reorder @@ -641,10 +641,6 @@ class 
TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): attn_metadata: TorchSDPAMetadata, attn_type: str = AttentionType.DECODER, ) -> None: - if self.num_kv_heads != self.num_heads: - key = key.repeat_interleave(self.num_queries_per_kv, dim=1) - value = value.repeat_interleave(self.num_queries_per_kv, dim=1) - attn_masks = attn_metadata.get_attn_bias(attn_type) if attn_masks is None: if self.alibi_slopes is not None: @@ -665,6 +661,10 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): key = key.movedim(0, key.dim() - 2) value = value.movedim(0, value.dim() - 2) + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=-3) + value = value.repeat_interleave(self.num_queries_per_kv, dim=-3) + causal_attn = (attn_type == AttentionType.DECODER) seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index dd2b956d4fa3d..20f1904b3be6f 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -167,7 +167,7 @@ class FlashAttentionMetadataBuilder( # work for mixed prefill-decode and uniform-decode. But for non-spec decodes # the graphs would not work for mixed prefill-decode; sorta the inverse # of UNIFORM_SINGLE_TOKEN_DECODE. - # Theres probably a better way to describe this using `AttentionCGSupport` + # There's probably a better way to describe this using `AttentionCGSupport` # but for now just set it to `UNIFORM_BATCH` to get use to drop down # to FULL_AND_PIECEWISE. # TODO(luka, lucas): audit FA2 as part of: @@ -177,12 +177,11 @@ class FlashAttentionMetadataBuilder( def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): - self.vllm_config = vllm_config + super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.model_config = vllm_config.model_config self.parallel_config = vllm_config.parallel_config self.cache_config = vllm_config.cache_config self.compilation_config = vllm_config.compilation_config - self.device = device self.num_heads_q = self.model_config.get_num_attention_heads( self.parallel_config) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 5fc3a1517b690..dda6dd4fbea7a 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -25,7 +25,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import cdiv, is_pin_memory_available -from vllm.utils.flashinfer import (supports_trtllm_attention, +from vllm.utils.flashinfer import (flashinfer_disable_q_quantization, + supports_trtllm_attention, use_trtllm_attention) from vllm.v1.attention.backends.flash_attn import use_cascade_attention # yapf conflicts with isort for this block @@ -48,8 +49,89 @@ FP4_DTYPE = torch.uint8 logger = init_logger(__name__) -class FlashInferBackend(AttentionBackend): +@triton.jit +def _trtllm_prefill_attn_kvfp8_dequant( + kv_cache_ptr, + block_tables_prefill_ptr, + block_table_stride, + mock_kv_cache_ptr, + k_scale_ptr, + v_scale_ptr, + K_CACHE_STRIDE: tl.constexpr, + KV_CACHE_STRIDE: tl.constexpr, +): + batch_idx = tl.program_id(0).to(tl.int64) + mock_block_table_idx = tl.program_id(1).to(tl.int64) + orig_page_num = tl.load(block_tables_prefill_ptr + + batch_idx * block_table_stride + + 
mock_block_table_idx).to(tl.int64) + if orig_page_num <= 0: + return + dequant_dtype = mock_kv_cache_ptr.dtype.element_ty + # Dequantize K + k_scale_val = tl.load(k_scale_ptr) + offset = orig_page_num * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) + fp8_vals = tl.load(kv_cache_ptr + offset) + dequantized_vals = fp8_vals.to(tl.float32) * k_scale_val + mock_cache_offset = (batch_idx * block_table_stride + mock_block_table_idx + + 1) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) + dequantized_vals = dequantized_vals.to(dequant_dtype) + tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals) + + # Dequantize V + v_scale_val = tl.load(v_scale_ptr) + offset = (orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE + + tl.arange(0, K_CACHE_STRIDE)) + fp8_vals = tl.load(kv_cache_ptr + offset) + dequantized_vals = fp8_vals.to(tl.float32) * v_scale_val + mock_cache_offset = ( + (batch_idx * block_table_stride + mock_block_table_idx + 1) * + KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)) + dequantized_vals = dequantized_vals.to(dequant_dtype) + tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals) + + +def trtllm_prefill_attn_kvfp8_dequant( + kv_cache: torch.Tensor, + block_tables_prefill: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + dequant_dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, num_of_page_per_token = block_tables_prefill.shape + s = kv_cache.shape + assert s[1] == 2 + assert dequant_dtype in (torch.bfloat16, torch.float16) + k_cache_stride = s[2] * s[3] * s[4] + kv_cache_stride = k_cache_stride * s[1] + new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4]) + # mock kv cache contains just the pages needed by this prefill + mock_kv_cache = torch.empty(new_s, + dtype=dequant_dtype, + device=kv_cache.device) + # we simply sequentially index the pages needed by this prefill + mock_block_table = torch.arange( + start=1, + end=batch_size * num_of_page_per_token + 1, + dtype=torch.int32, + device=block_tables_prefill.device, + ).reshape(batch_size, num_of_page_per_token) + grid = (batch_size, num_of_page_per_token) + _trtllm_prefill_attn_kvfp8_dequant[grid]( + kv_cache, + block_tables_prefill, + num_of_page_per_token, + mock_kv_cache, + k_scale, + v_scale, + k_cache_stride, + kv_cache_stride, + ) + return mock_kv_cache, mock_block_table + + +class FlashInferBackend(AttentionBackend): accept_output_buffer: bool = True @classmethod @@ -122,7 +204,6 @@ class FlashInferBackend(AttentionBackend): @dataclass class FlashInferMetadata: - num_actual_tokens: int # Number of tokens excluding padding. 
# The data type of the query @@ -163,11 +244,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): - self.device = device - self.vllm_config = vllm_config + super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.cache_config = vllm_config.cache_config self.model_config = vllm_config.model_config - self.kv_cache_spec = kv_cache_spec self._workspace_buffer = None self._prefill_wrapper = None # Wrapper for prefill/append self._decode_wrapper = None # Wrapper for decode (general shape) @@ -177,8 +256,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): self.kv_cache_spec.block_size) max_num_reqs = vllm_config.scheduler_config.max_num_seqs max_num_pages = max_num_reqs * max_num_pages_per_req - self.enable_cuda_graph = self.compilation_config.cudagraph_mode.\ - decode_mode() == CUDAGraphMode.FULL + self.enable_cuda_graph = (self.compilation_config.cudagraph_mode.\ + decode_mode() == CUDAGraphMode.FULL) if self.enable_cuda_graph: # For full cudagraph capture, one `decode_wrapper` for each batch # size is needed for FlashInfer. @@ -194,20 +273,25 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): FlashInferBackend.validate_head_size(self.head_dim) self.page_size = self.kv_cache_spec.block_size - self.enable_fusion = ( - self.compilation_config.pass_config.enable_attn_fusion) - self.q_data_type = self.model_config.dtype self.cache_dtype = self.cache_config.cache_dtype if self.cache_dtype.startswith("fp8"): self.kv_cache_dtype = ( FlashInferBackend.get_fp8_dtype_for_flashinfer( self.cache_dtype)) - # Insert FP8 quant for query if FP8 kv cache and attn fusion enabled - if self.enable_fusion: - self.q_data_type = self.kv_cache_dtype else: + assert self.kv_cache_spec.dtype == self.model_config.dtype self.kv_cache_dtype = self.kv_cache_spec.dtype + # Use model dtype as q dtype when TRTLLM attn is not supported, or + # VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. 
Otherwise, try to + # use fp8 q if kv cache is fp8, and will fall back to model dtype + # if TRTLLM attention kernel is not used when building attn metadata + if supports_trtllm_attention() and \ + not flashinfer_disable_q_quantization(): + self.q_data_type = self.kv_cache_dtype + else: + self.q_data_type = self.model_config.dtype + self._cascade_wrapper = None # Wrapper for cascade attention # Global hyperparameters shared by all attention layers @@ -218,7 +302,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): self.window_left = self.global_hyperparameters.window_left self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap self.has_sinks = self.global_hyperparameters.has_sinks - + if self.has_sinks and not supports_trtllm_attention(): + raise NotImplementedError( + "FlashInfer backend currently does not support attention " + "sinks, please use trtllm on blackwell or flash attention on " + "earlier GPUs.") # Preparing persistent buffers (device-side) self.paged_kv_indptr = torch.zeros(max_num_reqs + 1, dtype=torch.int32, @@ -291,7 +379,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): paged_kv_indices_buffer=paged_kv_indices, paged_kv_last_page_len_buffer=paged_kv_last_page_len, # Tensor cores are enabled by default because the perf would be - # atleast as good as cuda cores for all attention ops in latest + # at least as good as cuda cores for all attention ops in latest # gpus. use_tensor_cores=True, ) @@ -317,7 +405,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\ - split_decodes_and_prefills(common_attn_metadata) + split_decodes_and_prefills(common_attn_metadata, + decode_threshold=self.reorder_batch_threshold) page_size = self.page_size max_q_len = common_attn_metadata.max_query_len @@ -392,14 +481,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): paged_kv_last_page_len_np, ) - # Check if any layer uses sinks (requires TRTLLM attention) prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads, self.num_kv_heads, num_prefill_tokens, max_seq_len, self.cache_dtype, self.q_data_type, - is_prefill=True, has_sinks=self.has_sinks) decode_use_trtllm = use_trtllm_attention(self.num_qo_heads, self.num_kv_heads, @@ -407,8 +494,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): max_seq_len, self.cache_dtype, self.q_data_type, - is_prefill=False, has_sinks=self.has_sinks) + if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm): + raise NotImplementedError( + "FlashInfer backend currently does not support attention " + "sinks, please use trtllm on blackwell or flash attention on " + "earlier GPUs.") + + # If TRTLLM attention is not used, the q quantization is not supported. + # Fall back to use model dtype. + if not (prefill_use_trtllm and decode_use_trtllm): + self.q_data_type = self.model_config.dtype attn_metadata = FlashInferMetadata( num_actual_tokens=num_actual_tokens, @@ -542,22 +638,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ) return attn_metadata - def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata): - """ - This method builds the metadata for full cudagraph capture. - Currently, only decode is supported for full cudagraphs with FlashInfer. 
- """ - m = common_attn_metadata - - assert m.num_reqs == m.num_actual_tokens, \ - "FlashInfer only supports decode-only full CUDAGraph capture. " \ - "Make sure all cudagraph capture sizes <= max_num_seq." - - m.max_query_len = 1 # decode-only - - return self.build(0, m) - def use_cascade_attention(self, *args, **kwargs) -> bool: if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype: # TODO: The cascade wrapper currently does not support setting @@ -667,8 +747,6 @@ class FlashInferImpl(AttentionImpl): # The attn+quant fusion happens when output_scale is provided. if output_scale is None: - assert attn_metadata.q_data_type != FP8_DTYPE, \ - "Query can only be FP8 if output fusion happened." assert output_block_scale is None, "output_block_scale "\ "is not supported when fusion has not happened" else: @@ -686,7 +764,7 @@ class FlashInferImpl(AttentionImpl): else: raise ValueError(f"Unsupported output dtype: {output.dtype}") - # TRTLLM attn kernel requires o scale to pass as a host scalar, + # TRTLLM attn kernel requires o scale to be passed as a host scalar, # store the o scale as a host scalar in warmup run with cuda graph # not enabled if layer._o_scale_float is None: @@ -696,7 +774,8 @@ class FlashInferImpl(AttentionImpl): elif output.dtype == FP4_DTYPE: self.o_sf_scale = layer._o_scale_float - # Insert FP8 quant for query + # Insert FP8 quant for query + if attn_metadata.q_data_type == FP8_DTYPE: num_tokens, num_heads, head_size = query.shape query, _ = ops.scaled_fp8_quant( query.reshape( @@ -805,11 +884,29 @@ class FlashInferImpl(AttentionImpl): assert self.o_sf_scale is None out = output[num_decode_tokens:] + if attn_metadata.q_data_type != FP8_DTYPE \ + and self.kv_cache_dtype.startswith("fp8"): + # TRTLLM prefill attention does not support BF16 Q + # and fp8 kv cache.
So to enable prefill attention + # with fp8 kv cache, we can construct a mock block + # and mock kv cache with BF16 KV involved in the prefill + mock_kv_cache, mock_block_table = ( + trtllm_prefill_attn_kvfp8_dequant( + kv_cache_permute, + block_tables_prefill, + layer._k_scale, + layer._v_scale, + attn_metadata.q_data_type, + )) + else: + mock_kv_cache = kv_cache_permute + mock_block_table = block_tables_prefill + trtllm_batch_context_with_kv_cache( query=prefill_query, - kv_cache=kv_cache_permute, + kv_cache=mock_kv_cache, workspace_buffer=workspace_buffer, - block_tables=block_tables_prefill, + block_tables=mock_block_table, seq_lens=seq_lens_prefill, max_q_len=attn_metadata.max_q_len, max_kv_len=attn_metadata.max_seq_len, @@ -847,7 +944,7 @@ class FlashInferImpl(AttentionImpl): decode_query = decode_query.contiguous() workspace_buffer = decode_wrapper._float_workspace_buffer block_tables_decode = attn_metadata.\ - block_table_tensor[:num_decode_tokens] + block_table_tensor[:num_decode_tokens] seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens] # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index d5b1c15e68d0e..662d3984554ad 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -516,10 +516,11 @@ class FlexAttentionMetadataBuilder( def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.model_config = vllm_config.model_config self.parallel_config = vllm_config.parallel_config self.cache_config = vllm_config.cache_config - self.device = device self.num_heads_q = self.model_config.get_num_attention_heads( self.parallel_config) @@ -719,6 +720,15 @@ class FlexAttentionImpl(AttentionImpl): (query, key, value), ) + query = query[:, :, :num_actual_tokens, :] + if ((key_tensor.size(-2) > num_actual_tokens) + or (value_tensor.size(-2) > num_actual_tokens)): + # In the encoder-only model with torch.compile, + # qkv might be padded, which might cause exception. 
+ # see: https://github.com/vllm-project/vllm/pull/24872#discussion_r2353252290 + key_tensor = key_tensor[:, :, :num_actual_tokens, :] + value_tensor = value_tensor[:, :, :num_actual_tokens, :] + else: assert self.attn_type == AttentionType.DECODER key_cache, value_cache = kv_cache.unbind(0) @@ -743,7 +753,8 @@ class FlexAttentionImpl(AttentionImpl): (query, key_cache, value_cache), ) - query = query[:, :, :num_actual_tokens, :] + query = query[:, :, :num_actual_tokens, :] + # Doesn't work for now -> constraint violation # torch._dynamo.try_mark_dynamic(query, 2) diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py new file mode 100644 index 0000000000000..5dadc52d0fb1c --- /dev/null +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -0,0 +1,335 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Backend for GatedDeltaNet attention.""" +from dataclasses import dataclass +from typing import ClassVar, Optional + +import torch + +from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.config import VllmConfig +from vllm.v1.attention.backends.utils import (AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills) +from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec + + +class GDNAttentionBackend(AttentionBackend): + + @staticmethod + def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]: + return GDNAttentionMetadataBuilder + + +@dataclass +class GDNAttentionMetadata: + num_prefills: int + num_prefill_tokens: int + num_decodes: int + num_decode_tokens: int + num_spec_decodes: int + num_spec_decode_tokens: int + num_actual_tokens: int + + has_initial_state: Optional[torch.Tensor] = None + + spec_query_start_loc: Optional[ + torch.Tensor] = None # shape: [num_spec_decodes + 1,] + non_spec_query_start_loc: Optional[ + torch.Tensor] = None # shape: [batch - num_spec_decodes + 1,] + + spec_state_indices_tensor: Optional[ + torch.Tensor] = None # shape: [batch, num_spec] + non_spec_state_indices_tensor: Optional[ + torch.Tensor] = None # shape: [batch - num_spec_decodes,] + spec_sequence_masks: Optional[torch.Tensor] = None # shape: [batch,] + spec_token_masks: Optional[ + torch. 
+ Tensor] = None # shape: [num_prefill_tokens + num_decode_tokens,] + num_accepted_tokens: Optional[torch.Tensor] = None # shape: [batch,] + + # The following attributes are for triton implementation of causal_conv1d + nums_dict: Optional[dict] = None + cu_seqlen: Optional[int] = None + batch_ptr: Optional[torch.Tensor] = None + token_chunk_offset_ptr: Optional[torch.Tensor] = None + + +class GDNAttentionMetadataBuilder( + AttentionMetadataBuilder[GDNAttentionMetadata]): + + cudagraph_support = AttentionCGSupport.UNIFORM_BATCH + + reorder_batch_threshold: ClassVar[int] = 1 + + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], + vllm_config: VllmConfig, device: torch.device): + assert isinstance(kv_cache_spec, MambaSpec) + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + self.speculative_config = vllm_config.speculative_config + self.kv_cache_spec = kv_cache_spec + if self.speculative_config: + self.num_spec = self.speculative_config.num_speculative_tokens # noqa: E501 + else: + self.num_spec = 0 + self.use_spec_decode = self.num_spec > 0 + self.reorder_batch_threshold = self.num_spec + 1 # type: ignore[misc] + + self.use_full_cuda_graph = \ + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + self.decode_cudagraph_max_bs = min( + self.vllm_config.scheduler_config.max_num_seqs * + (self.num_spec + 1), self.compilation_config.max_capture_size) + + self.spec_state_indices_tensor = torch.empty( + (self.decode_cudagraph_max_bs, self.num_spec + 1), + dtype=torch.int32, + device=device, + ) + self.non_spec_state_indices_tensor = torch.empty( + (self.decode_cudagraph_max_bs, ), + dtype=torch.int32, + device=device, + ) + self.spec_sequence_masks = torch.empty( + (self.decode_cudagraph_max_bs, ), + dtype=torch.bool, + device=device, + ) + self.spec_token_masks = torch.empty( + (self.decode_cudagraph_max_bs * (self.num_spec + 1), ), + dtype=torch.bool, + device=device, + ) + self.spec_query_start_loc = torch.empty( + (self.decode_cudagraph_max_bs + 1, ), + dtype=torch.int32, + device=device, + ) + self.non_spec_query_start_loc = torch.empty( + (self.decode_cudagraph_max_bs + 1, ), + dtype=torch.int32, + device=device, + ) + self.num_accepted_tokens = torch.empty( + (self.decode_cudagraph_max_bs, ), + dtype=torch.int32, + device=device, + ) + + def build( # type: ignore[override] + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + num_accepted_tokens: Optional[torch.Tensor] = None, + num_draft_tokens: Optional[torch.Tensor] = None, + fast_build: bool = False, + ) -> GDNAttentionMetadata: + m = common_attn_metadata + + query_start_loc = m.query_start_loc + context_lens = m.num_computed_tokens_cpu + context_lens_tensor = context_lens.to(query_start_loc.device) + seq_lens_tensor = m.seq_lens + + if (not self.use_spec_decode or num_draft_tokens is None + or num_draft_tokens.sum().item() == 0): + spec_sequence_masks = None + else: + spec_sequence_masks = (num_draft_tokens > 0) & ( + context_lens_tensor + + (num_draft_tokens + 1) == seq_lens_tensor) + if spec_sequence_masks.sum().item() == 0: + spec_sequence_masks = None + + if spec_sequence_masks is None: + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills(m, decode_threshold=1)) + num_spec_decodes = 0 + num_spec_decode_tokens = 0 + spec_token_masks = None + spec_state_indices_tensor = None + non_spec_state_indices_tensor = m.block_table_tensor[:, 0] + spec_query_start_loc = None + 
non_spec_query_start_loc = query_start_loc + num_accepted_tokens = None + else: + num_spec_decodes = spec_sequence_masks.sum().item() + query_lens = query_start_loc[1:] - query_start_loc[:-1] + + non_spec_query_lens = query_lens[~spec_sequence_masks] + num_decodes = (non_spec_query_lens == 1).sum().item() + num_prefills = non_spec_query_lens.size(0) - num_decodes + num_decode_tokens = num_decodes + num_prefill_tokens = non_spec_query_lens.sum().item( + ) - num_decode_tokens + + if num_prefills == 0 and num_decodes == 0: + spec_token_masks = torch.ones( + (min(num_spec_decodes * + (self.num_spec + 1), query_start_loc[-1].item())), + dtype=torch.bool, + device=query_start_loc.device) + spec_state_indices_tensor = m.block_table_tensor[:, :self. + num_spec + 1] + non_spec_state_indices_tensor = None + spec_query_start_loc = query_start_loc + non_spec_query_start_loc = None + else: + spec_token_masks = torch.repeat_interleave( + spec_sequence_masks, query_lens) + spec_state_indices_tensor = m.block_table_tensor[ + spec_sequence_masks, :self.num_spec + 1] + non_spec_state_indices_tensor = \ + m.block_table_tensor[~spec_sequence_masks, 0] + + spec_query_start_loc = torch.zeros( + num_spec_decodes + 1, + dtype=torch.int32, + device=query_start_loc.device) + torch.cumsum(query_lens[spec_sequence_masks], + dim=0, + out=spec_query_start_loc[1:]) + non_spec_query_start_loc = torch.zeros( + query_lens.size(0) - num_spec_decodes + 1, + dtype=torch.int32, + device=query_start_loc.device) + torch.cumsum(query_lens[~spec_sequence_masks], + dim=0, + out=non_spec_query_start_loc[1:]) + + num_spec_decode_tokens = (query_lens.sum().item() - + num_prefill_tokens - num_decode_tokens) + assert num_accepted_tokens is not None + num_accepted_tokens = num_accepted_tokens[spec_sequence_masks] + + if num_prefills > 0: + has_initial_state = context_lens_tensor > 0 + if spec_sequence_masks is not None: + has_initial_state = has_initial_state[~spec_sequence_masks] + else: + has_initial_state = None + num_actual_tokens = num_prefill_tokens + num_decode_tokens + \ + num_spec_decode_tokens + + # prepare tensors for cudagraph + # + # With speculative decoding, the xgrammar backend may roll back tokens, + # causing some sequences to have fewer draft tokens than self.num_spec. + # + # In such cases, the max possible batch size for n tokens is + # min(n, cudagraph_max_bs).
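A minimal sketch (not part of the diff) of how the spec/non-spec split computed earlier in `build()` classifies requests; the tensor values below are made up for illustration, but the predicate mirrors the `spec_sequence_masks` expression used above.

```python
import torch

# Hypothetical per-request state (illustrative values only).
context_lens = torch.tensor([10, 7, 4])      # tokens already computed
seq_lens = torch.tensor([14, 8, 5])          # tokens scheduled this step
num_draft_tokens = torch.tensor([3, 0, 2])   # draft tokens proposed

# Same predicate as in build(): a request is a pure speculative-decode step
# only if it has drafts and its scheduled length equals
# context + (drafts + 1 bonus token).
spec_sequence_masks = (num_draft_tokens > 0) & (
    context_lens + (num_draft_tokens + 1) == seq_lens)
print(spec_sequence_masks)  # tensor([ True, False, False])
```

In this sketch only request 0 takes the speculative path; request 1 has no drafts and request 2 was scheduled fewer tokens than drafts + 1 (e.g. after a rollback), so both fall through to the regular prefill/decode handling.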
+ if (self.use_full_cuda_graph and num_prefills == 0 and num_decodes == 0 + and num_spec_decodes <= self.decode_cudagraph_max_bs + and num_spec_decode_tokens <= self.decode_cudagraph_max_bs): + num_actual_tokens = self.vllm_config.pad_for_cudagraph( + m.num_actual_tokens) + batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens) + + self.spec_state_indices_tensor[:num_spec_decodes].copy_( + spec_state_indices_tensor, non_blocking=True) + spec_state_indices_tensor = self.spec_state_indices_tensor[: + batch_size] + spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID) + + self.spec_sequence_masks[:num_spec_decodes].copy_( + spec_sequence_masks, non_blocking=True) + spec_sequence_masks = self.spec_sequence_masks[:batch_size] + spec_sequence_masks[num_spec_decodes:].fill_(False) + + assert spec_token_masks is not None + self.spec_token_masks[:spec_token_masks.size(0)].copy_( + spec_token_masks, non_blocking=True) + spec_token_masks = self.spec_token_masks[:num_actual_tokens] + spec_token_masks[spec_token_masks.size(0):].fill_(False) + + self.spec_query_start_loc[:num_spec_decodes + 1].copy_( + spec_query_start_loc, non_blocking=True) + spec_num_query_tokens = spec_query_start_loc[ + -1] # type: ignore[index] + spec_query_start_loc = self.spec_query_start_loc[:batch_size + 1] + spec_query_start_loc[num_spec_decodes + + 1:].fill_(spec_num_query_tokens) + + self.num_accepted_tokens[:num_spec_decodes].copy_( + num_accepted_tokens, non_blocking=True) + num_accepted_tokens = self.num_accepted_tokens[:batch_size] + num_accepted_tokens[num_spec_decodes:].fill_(1) + + if (self.use_full_cuda_graph and num_prefills == 0 + and num_spec_decodes == 0 + and num_decodes <= self.decode_cudagraph_max_bs): + num_actual_tokens = self.vllm_config.pad_for_cudagraph( + m.num_actual_tokens) + batch_size = num_actual_tokens + + self.non_spec_state_indices_tensor[:num_decodes].copy_( + non_spec_state_indices_tensor, non_blocking=True) + non_spec_state_indices_tensor = \ + self.non_spec_state_indices_tensor[:batch_size] + non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID) + + self.non_spec_query_start_loc[:num_decodes + 1].copy_( + non_spec_query_start_loc, non_blocking=True) + non_spec_num_query_tokens = non_spec_query_start_loc[ + -1] # type: ignore[index] + non_spec_query_start_loc = \ + self.non_spec_query_start_loc[:batch_size + 1] + non_spec_query_start_loc[num_decodes + + 1:].fill_(non_spec_num_query_tokens) + + attn_metadata = GDNAttentionMetadata( + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_spec_decodes=num_spec_decodes, + num_spec_decode_tokens=num_spec_decode_tokens, + num_actual_tokens=num_actual_tokens, + has_initial_state=has_initial_state, + spec_query_start_loc=spec_query_start_loc, + non_spec_query_start_loc=non_spec_query_start_loc, + spec_state_indices_tensor=spec_state_indices_tensor, + non_spec_state_indices_tensor=non_spec_state_indices_tensor, + spec_sequence_masks=spec_sequence_masks, + spec_token_masks=spec_token_masks, + num_accepted_tokens=num_accepted_tokens, + ) + return attn_metadata + + def build_for_cudagraph_capture( + self, common_attn_metadata: CommonAttentionMetadata): + """ + This method builds the metadata for full cudagraph capture. + Currently, only decode is supported for full cudagraphs with Mamba. 
+ """ + m = common_attn_metadata + + assert (m.num_reqs * (self.num_spec + 1) <= m.num_actual_tokens + and ((m.num_reqs + 1) * (self.num_spec + 1) + >= m.num_actual_tokens)), \ + "GDN only supports decode-only full CUDAGraph capture. " \ + "Make sure all cudagraph capture sizes <= max_num_seq." + + num_accepted_tokens = torch.full((m.num_reqs, ), + m.max_query_len, + dtype=torch.int32, + device=m.query_start_loc.device) + num_drafted_tokens = torch.full((m.num_reqs, ), + self.num_spec, + dtype=torch.int32, + device=m.query_start_loc.device) + + # Fixes query-start loc for spec-sequence-indices. + m.query_start_loc = torch.arange(0, + m.num_actual_tokens + 1, + step=m.max_query_len, + device=m.query_start_loc.device, + dtype=torch.int32) + m.num_computed_tokens_cpu = (m.seq_lens_cpu - torch.full( + (m.num_reqs, ), m.max_query_len, dtype=torch.int32, device='cpu')) + + return self.build(0, m, num_accepted_tokens, num_drafted_tokens) diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py index f08b6d7f177c7..3ff201d83a79b 100644 --- a/vllm/v1/attention/backends/linear_attn.py +++ b/vllm/v1/attention/backends/linear_attn.py @@ -39,8 +39,8 @@ class LinearAttentionMetadataBuilder( def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) assert isinstance(kv_cache_spec, MambaSpec) - self.kv_cache_spec = kv_cache_spec def build(self, common_prefix_len: int, @@ -52,8 +52,9 @@ class LinearAttentionMetadataBuilder( state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=1)) + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold)) attn_metadata = LinearAttentionMetadata( num_prefills=num_prefills, diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index 97a1aa86dda0d..7cbfa2c2c9a54 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -50,8 +50,9 @@ class Mamba1AttentionMetadataBuilder( query_start_loc.device) num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=1)) + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold)) has_initial_states = None padded_decodes = num_decodes diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index ed30884fdbc94..2fe1f14ca1db0 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -16,9 +16,58 @@ from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, from vllm.v1.kv_cache_interface import AttentionSpec -def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor, - chunk_size: int, - total_seqlens: int): +def _query_start_loc_to_chunk_indices_offsets( + query_start_loc: torch.Tensor, chunk_size: int, + total_seqlens: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + Args: + query_start_loc (torch.Tensor): 1D tensor of cumulative sequence + lengths, shape (num_seqs + 1,). + The first element should be 0. Each entry represents the starting + index of a sequence in the flattened token array. 
+ chunk_size (int): The size of each physical mamba chunk + (number of tokens per chunk). + total_seqlens (int): The total number of tokens in the batch. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - chunk_indices (torch.Tensor): 1D tensor of indices + indicating the physical chunk for each logical chunk. + - chunk_offsets (torch.Tensor): 1D tensor of offsets + indicating the starting index of each logical chunk within + its physical chunk. + + This function computes the chunk indices and offsets for the given + query_start_loc and chunk_size. Both are tensors of integers with length N, + where N is the number of logical (pseudo) chunks. + A logical chunk is a sequence of tokens that are all part of the same + sequence and are all in the same physical mamba chunk. + In other words, a logical chunk changes every time we cross a sequence + boundary or a physical mamba chunk boundary. + Logical chunks are needed to handle batched requests with initial states + (see _state_passing_fwd and _chunk_scan_fwd). + The chunk_indices tensor contains the index of the physical chunk for each + logical chunk. + The chunk_offsets tensor contains the offset (AKA starting index) of the + logical chunk in the physical chunk. + + Example: + query_start_loc = [0, 5, 10] + chunk_size = 8 + total_seqlens = 10 + -> chunk_indices = [0, 0, 1] + -> chunk_offsets = [0, 5, 0] + + In this example, we have 2 sequences, each with 5 tokens. The physical + chunk size is 8 tokens. + We have three logical chunks: + - the first logical chunk starts at token 0 in the first physical chunk + and contains all 5 tokens from the first sequence + - the second logical chunk starts at token 5 in the first physical chunk + and contains first 3 tokens from the second sequence + - the third logical chunk starts at token 0 in the second physical chunk + and contains the remaining 2 tokens from the second sequence + """ cu_seqlens = query_start_loc[1:] # remove prepended 0 @@ -83,8 +132,8 @@ class Mamba2AttentionMetadata: # The following attributes are for triton implementation of causal_conv1d nums_dict: Optional[dict] = None cu_seqlen: Optional[int] = None - batch_ptr: Optional[torch.tensor] = None - token_chunk_offset_ptr: Optional[torch.tensor] = None + batch_ptr: Optional[torch.Tensor] = None + token_chunk_offset_ptr: Optional[torch.Tensor] = None class Mamba2AttentionMetadataBuilder( @@ -115,8 +164,9 @@ class Mamba2AttentionMetadataBuilder( state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=1)) + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold)) # Compute seq_idx, chunk_indices and chunk_offsets for prefill only if num_prefills > 0: diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 07ef7cb69a160..9970331a6042c 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -22,12 +22,9 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): - assert isinstance(kv_cache_spec, MambaSpec) - self.kv_cache_spec = kv_cache_spec - self.device = device - self.vllm_config = vllm_config - self.layer_names = layer_names + super().__init__(kv_cache_spec, layer_names, 
vllm_config, device) + assert isinstance(kv_cache_spec, MambaSpec) self.compilation_config = vllm_config.compilation_config self.decode_cudagraph_max_bs = min( self.vllm_config.scheduler_config.max_num_seqs, @@ -52,4 +49,4 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): m.max_query_len = 1 # decode-only - return self.build(0, m) \ No newline at end of file + return self.build(0, m) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 9f93b50b075b4..a990cb2f1a972 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -201,10 +201,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, AttentionMetadata, MLAAttentionImpl) from vllm.attention.backends.utils import get_mla_dims +from vllm.attention.ops.common import cp_lse_ag_out_rs from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.config import VllmConfig -from vllm.distributed.parallel_state import is_global_first_rank +from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank from vllm.logger import init_logger from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase, @@ -323,6 +324,13 @@ class MLACommonPrefillMetadata: seq_lens: torch.Tensor workspace: torch.Tensor + # for mla DCP + cp_chunk_seq_lens: Optional[list[list[int]]] = None + origin_context_lens: Optional[list[int]] = None + cp_cu_seq_lens: Optional[torch.Tensor] = None + chunk_size: Optional[int] = None + cu_seq_lens_lst: Optional[list[list[int]]] = None + block_table: torch.Tensor query_start_loc: torch.Tensor max_query_len: int @@ -373,6 +381,7 @@ class MLACommonMetadata(Generic[D]): num_reqs: int max_query_len: int + max_seq_len: int num_actual_tokens: int # Number of tokens excluding padding. query_start_loc: torch.Tensor @@ -401,7 +410,7 @@ M = TypeVar("M", bound=MLACommonMetadata) def use_flashinfer_prefill() -> bool: - # For blackwell default to flashinfer prefill if its available since + # For blackwell default to flashinfer prefill if it's available since # it is faster than FA2. 
return (flashinfer_available and not envs.VLLM_USE_CUDNN_PREFILL and current_platform.is_device_capability(100)) @@ -435,17 +444,26 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): self.metadata_cls = metadata_cls \ if metadata_cls is not None else MLACommonMetadata self.kv_cache_spec = kv_cache_spec - self.device = device scheduler_config = vllm_config.scheduler_config self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config parallel_config = vllm_config.parallel_config + cache_config = vllm_config.cache_config + self.compilation_config = vllm_config.compilation_config + self.device = device + self.num_heads = self.model_config.get_num_attention_heads( parallel_config) self.mla_dims = get_mla_dims(self.model_config) self.aot_schedule = current_platform.is_cuda() + try: + self.dcp_world_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + except AssertionError: + # DCP might not be initialized in testing + self.dcp_world_size = 1 + self.dcp_rank = 0 - # Dont try to access the runner on AMD + # Don't try to access the runner on AMD if self.aot_schedule: self.page_size = self.kv_cache_spec.block_size @@ -465,12 +483,27 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): 128 * 1024) assert self.chunked_prefill_workspace_size >= \ scheduler_config.max_num_seqs * cache_config.block_size - self.chunked_prefill_workspace = torch.empty( - (self.chunked_prefill_workspace_size, - self.model_config.get_head_size()), - dtype=self.model_config.dtype, - device=device, - ) + if self.dcp_world_size > 1: + # Note(hc): The local kvcache is incomplete when DCP is triggered, + # an additional kvcache allgather across the DCP group is therefore + # required, so the workspace has to be enlarged by 1/DCP relative + # to the original TP allocation. + assert self.chunked_prefill_workspace_size % \ + self.dcp_world_size == 0 + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size + + self.chunked_prefill_workspace_size // self.dcp_world_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=device, + ) + else: + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=device, + ) self._use_cudnn_prefill = use_cudnn_prefill() self._use_fi_prefill = use_flashinfer_prefill() @@ -551,7 +584,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): window_left=self._global_hyperparameters.window_left, logits_soft_cap=self._global_hyperparameters.logits_soft_cap, q_data_type=self.model_config.dtype, - kv_data_type=self.kv_cache_spec.dtype, ) # Prepare context prefills @@ -572,17 +604,20 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): logits_soft_cap=self._global_hyperparameters. 
logits_soft_cap, q_data_type=self.model_config.dtype, - kv_data_type=self.kv_cache_spec.dtype, ) prefill.prefill_main = self._fi_prefill_main prefill.prefill_chunks = self._fi_prefill_chunks def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens: torch.Tensor): + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int) -> MLACommonDecodeMetadata: return MLACommonDecodeMetadata( block_table=block_table_tensor, - seq_lens=seq_lens, + seq_lens=seq_lens_device, ) def build_for_cudagraph_capture( @@ -592,11 +627,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): Currently, only decode is supported for full cudagraphs with MLA. """ m = common_attn_metadata - assert m.num_reqs == m.num_actual_tokens, \ + assert m.num_reqs <= (m.num_actual_tokens * + self.reorder_batch_threshold), \ "MLA only supports decode-only full CUDAGraph capture. " \ "Make sure all cudagraph capture sizes <= max_num_seq." - assert m.max_query_len == 1 # decode-only + assert m.max_query_len <= self.reorder_batch_threshold # decode only return self.build(0, m) @@ -607,6 +643,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): num_reqs = common_attn_metadata.num_reqs num_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len + max_seq_len = common_attn_metadata.max_seq_len # Note(simon): be careful about the CPU <> GPU memory movement in this # function. We should avoid GPU -> CPU sync as much as possible because @@ -618,6 +655,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): query_start_loc = common_attn_metadata.query_start_loc query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu seq_lens = common_attn_metadata.seq_lens + seq_lens_cpu = common_attn_metadata.seq_lens_cpu query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] @@ -625,7 +663,14 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): query_seq_lens_cpu) num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ - split_decodes_and_prefills(common_attn_metadata) + split_decodes_and_prefills(common_attn_metadata, + decode_threshold=self.reorder_batch_threshold) + + # Note(hc): update seq_lens of decode reqs under DCP. + if self.dcp_world_size > 1: + seq_lens[:num_decodes] = seq_lens[:num_decodes] \ + // self.dcp_world_size + (self.dcp_rank <= \ + (seq_lens[:num_decodes] - 1) % self.dcp_world_size) assert num_decodes + num_prefills == num_reqs assert num_decode_tokens + num_prefill_tokens == num_tokens @@ -635,6 +680,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): reqs_start = num_decodes # prefill_start context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] + # Note(hc): The context lengths in the perspective of dcp rank0. 
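The decode-side seq_lens update just above and the ceil on the next line follow the same per-rank split; a minimal standalone sketch of that arithmetic (the helper name is hypothetical, and it assumes tokens are stored round-robin across DCP ranks, token i on rank i % dcp_world_size):

def dcp_local_seq_len(seq_len: int, world_size: int, rank: int) -> int:
    # Number of cached tokens held by `rank` for a request of length `seq_len`.
    return seq_len // world_size + (1 if rank <= (seq_len - 1) % world_size else 0)

# Example: seq_len=10, world_size=4 -> per-rank lengths [3, 3, 2, 2].
lens = [dcp_local_seq_len(10, 4, r) for r in range(4)]
assert lens == [3, 3, 2, 2] and sum(lens) == 10
# Rank 0 always holds ceil(seq_len / world_size) tokens, which is what the
# cp_context_lens_cpu computed below captures for the prefill context.
assert lens[0] == -(-10 // 4)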
+ cp_context_lens_cpu = torch.ceil(context_lens_cpu.float() / + self.dcp_world_size).int() + origin_context_lens = context_lens_cpu.tolist() max_context_len_cpu = context_lens_cpu.max().item() num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() prefill_query_start_loc = query_start_loc[ @@ -687,20 +736,66 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32) + if self.dcp_world_size > 1: + # Note(hc): The above max_context_chunk already enforces + # block_size alignment; DCP only needs block_size to be + # divisible by dcp_world_size, because DCP uses + # cp_gather_cache, which does not require `cp_chunk_starts` + # to be aligned to page_size. + assert max_context_chunk % self.dcp_world_size == 0 + cp_max_context_chunk = max_context_chunk // \ + self.dcp_world_size + cp_chunk_starts = \ + torch.arange(num_chunks, dtype=torch.int32) \ + .unsqueeze(1).expand(-1, num_prefills) \ + * cp_max_context_chunk + cp_chunk_ends = torch.min( + cp_context_lens_cpu.unsqueeze(0), + cp_chunk_starts + cp_max_context_chunk) + cp_chunk_seq_lens = (cp_chunk_ends - + cp_chunk_starts).clamp(min=0) + + cp_cu_seq_lens_cpu = torch.zeros(num_chunks, + num_prefills + 1, + dtype=torch.int32, + pin_memory=True) + torch.cumsum(cp_chunk_seq_lens, + dim=1, + out=cp_cu_seq_lens_cpu[:, 1:], + dtype=torch.int32) + chunked_context_metadata_cls = \ CudnnPrefillMetadata.ChunkedContextMetadata \ if self._use_cudnn_prefill else \ MLACommonPrefillMetadata.ChunkedContextMetadata - - chunked_context_metadata = \ - chunked_context_metadata_cls( - cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), - starts=chunk_starts.to(device, non_blocking=True), - seq_tot=chunk_seq_lens.sum(dim=1).tolist(), - max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), - seq_lens=chunk_seq_lens, - workspace=self.chunked_prefill_workspace, - ) + if self.dcp_world_size > 1: + chunked_context_metadata = \ + chunked_context_metadata_cls( + cu_seq_lens=cu_seq_lens_cpu \ + .to(device, non_blocking=True), + starts=cp_chunk_starts.to(device, non_blocking=True), + seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, + workspace=self.chunked_prefill_workspace, + cp_chunk_seq_lens=cp_chunk_seq_lens.tolist(), + origin_context_lens=origin_context_lens, + cp_cu_seq_lens=cp_cu_seq_lens_cpu \ + .to(device, non_blocking=True), + chunk_size=max_context_chunk, + cu_seq_lens_lst=cu_seq_lens_cpu.tolist(), + ) + else: + chunked_context_metadata = \ + chunked_context_metadata_cls( + cu_seq_lens=cu_seq_lens_cpu \ + .to(device, non_blocking=True), + starts=chunk_starts.to(device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, + workspace=self.chunked_prefill_workspace, + ) if self._use_cudnn_prefill: chunked_context_metadata.seq_lens = chunk_seq_lens @@ -725,12 +820,17 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): if num_decodes > 0: decode_metadata = self._build_decode( block_table_tensor=block_table_tensor[:num_decodes, ...], - seq_lens=seq_lens[:num_decodes], + seq_lens_cpu=seq_lens_cpu[:num_decodes], + seq_lens_device=seq_lens[:num_decodes], + query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1], + query_start_loc_device=query_start_loc[:num_decodes + 1], + num_decode_tokens=num_decode_tokens, ) attn_metadata = self.metadata_cls( num_reqs=common_attn_metadata.num_reqs,
max_query_len=common_attn_metadata.max_query_len, + max_seq_len=max_seq_len, num_actual_tokens=num_tokens, query_start_loc=query_start_loc, slot_mapping=slot_mapping, @@ -750,6 +850,71 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): return attn_metadata +def reorg_kvcache( + allgatered_kv_c_normed: torch.Tensor, + allgatered_k_pe: torch.Tensor, + cp_chunk_seq_lens_lst: list[int], + origin_context_lens: list[int], + cp_world_size: int, + sum_seq_len: int, + max_seq_len: int, + chunk_size: int, + chunk_idx: int, + toks: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Reorganize the kvcache after the CP local gather into the TP layout + expected by the attention kernel. + + Args: + cp_chunk_seq_lens_lst: chunk context lengths under CP. + origin_context_lens: original full context lengths under CP. + cp_world_size: CP size. + sum_seq_len: the sum of cp_chunk_seq_lens_lst. + max_seq_len: the max value of cp_chunk_seq_lens_lst. + chunk_size: equal to max_context_chunk from + chunked_context_metadata building. + chunk_idx: chunk idx of chunked_prefill. + toks: the number of tokens for the local gather cache. + """ + kv_c_segments = [] + k_pe_segments = [] + src_token_idx = 0 + max_seq_len_check = 0 + for cp_chunk_seq_len, origin_context_len in zip(cp_chunk_seq_lens_lst, + origin_context_lens): + chunk_context_len = chunk_size + if cp_chunk_seq_len != 0: + chunk_context_len = min( + chunk_context_len, origin_context_len - chunk_size * chunk_idx) + cp_target_rank = (chunk_context_len - 1) % cp_world_size + cur_seq_len = 0 + for rank in range(cp_world_size): + if rank > cp_target_rank and cp_chunk_seq_len: + real_cp_chunk_seq_len = cp_chunk_seq_len - 1 + else: + real_cp_chunk_seq_len = cp_chunk_seq_len + if real_cp_chunk_seq_len: + kv_c_segment = allgatered_kv_c_normed[rank * toks + + src_token_idx:rank * + toks + src_token_idx + + real_cp_chunk_seq_len] + k_pe_segment = allgatered_k_pe[rank * toks + + src_token_idx:rank * toks + + src_token_idx + + real_cp_chunk_seq_len] + kv_c_segments.append(kv_c_segment) + k_pe_segments.append(k_pe_segment) + cur_seq_len += real_cp_chunk_seq_len + max_seq_len_check = max(max_seq_len_check, cur_seq_len) + src_token_idx += cp_chunk_seq_len + reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0) + reorganized_k_pe = torch.cat(k_pe_segments, dim=0) + assert reorganized_kv_c_normed.shape[0] == sum_seq_len + assert reorganized_k_pe.shape[0] == sum_seq_len + assert max_seq_len_check == max_seq_len + return reorganized_kv_c_normed, reorganized_k_pe + + class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): """ NOTE: Please read the comment at the top of the file before trying to @@ -829,6 +994,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): self.vllm_flash_attn_version == 3 and current_platform.get_device_capability()[0] == 9) + self.dcp_world_size: Optional[int] = None + def _flash_attn_varlen_diff_headdims(self, q, k, @@ -1011,7 +1178,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): return layer.weight # we currently do not have quantized bmm's which are needed for - # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform + # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform # the bmm's in 16-bit, the extra memory overhead of this is fairly low kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T assert kv_b_proj_weight.shape == ( @@ -1145,6 +1312,108 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): return output, output_lse + def _context_parallel_compute_prefill_context( + self, + q: torch.Tensor, +
kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, + dcp_world_size: int, + ): + assert k_scale is None, "DCP does not support scaled kvcache yet." + assert attn_metadata.prefill is not None + prefill_metadata = attn_metadata.prefill + assert prefill_metadata.chunked_context is not None + assert prefill_metadata.chunked_context.cp_chunk_seq_lens is not None + assert prefill_metadata.chunked_context.origin_context_lens is not None + assert prefill_metadata.chunked_context.cp_cu_seq_lens is not None + assert prefill_metadata.chunked_context.chunk_size is not None + assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None + + output = None + iters = len(prefill_metadata.chunked_context.seq_tot) + workspace = prefill_metadata.chunked_context.workspace + + for i in range(iters): + toks = prefill_metadata.chunked_context.seq_tot[i] + ops.cp_gather_cache( + src_cache=kv_c_and_k_pe_cache, + dst=workspace, + block_table=prefill_metadata.block_table, + cu_seq_lens=prefill_metadata.chunked_context.cp_cu_seq_lens[i], + batch_size=attn_metadata.num_prefills, + seq_starts=prefill_metadata.chunked_context.starts[i], + ) + # workspace + # |------- N tokens --------|--------- N*dcp_size tokens ----------| + # |<- use for local gather ->|<-------- use for allgather -------->| + allgather_offset = workspace.shape[0] // (dcp_world_size + 1) + assert allgather_offset * (dcp_world_size + + 1) == workspace.shape[0] + assert toks <= allgather_offset + local_gathered_kvcache = workspace[:toks] + cur_allgather_workspace = workspace[ + allgather_offset:allgather_offset * (1 + dcp_world_size)] + assert toks * dcp_world_size <= cur_allgather_workspace.shape[0] + cur_allgather_kvcache = cur_allgather_workspace[:toks * + dcp_world_size] + cur_allgather_kvcache.copy_(get_dcp_group().all_gather( + local_gathered_kvcache, dim=0)) + assert cur_allgather_kvcache.shape[ + -1] == self.kv_lora_rank + self.qk_rope_head_dim + allgatered_kv_c_normed, allgatered_k_pe = \ + cur_allgather_kvcache.unsqueeze( + 1).split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) +
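The workspace split used above can be sanity-checked with toy numbers; a rough sketch (the values of W and dcp_world_size are assumptions, with W standing for the chunked_prefill_workspace_size allocated in the metadata builder):

W, dcp_world_size = 128 * 1024, 4              # assumed sizes for illustration
total_rows = W + W // dcp_world_size           # rows allocated when DCP is on
allgather_offset = total_rows // (dcp_world_size + 1)
assert allgather_offset == W // dcp_world_size           # local-gather region
assert allgather_offset * dcp_world_size == W            # all-gather region
# Any chunk of `toks` locally gathered rows fits in the first region, and its
# dcp_world_size-way all-gather fits in the second, matching the asserts above.
toks = 1000
assert toks <= allgather_offset and toks * dcp_world_size <= W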
+ kv_c_normed, k_pe = reorg_kvcache( + allgatered_kv_c_normed, + allgatered_k_pe, + cp_chunk_seq_lens_lst=prefill_metadata.chunked_context. + cp_chunk_seq_lens[i], + origin_context_lens=prefill_metadata.chunked_context. + origin_context_lens, + cp_world_size=dcp_world_size, + sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i] + [-1], + max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i], + chunk_size=prefill_metadata.chunked_context.chunk_size, + chunk_idx=i, + toks=toks) + + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), + dim=-1) + + attn_output, attn_softmax_lse = self._run_prefill_context_chunk( + prefill=prefill_metadata, + chunk_idx=i, + q=q, + k=k, + v=v, + ) + + if output is None: + output = attn_output + output_lse = attn_softmax_lse + else: + output_tmp = torch.empty_like(output) + output_lse_tmp = torch.empty_like(output_lse) + merge_attn_states( + output=output_tmp, + output_lse=output_lse_tmp, + prefix_output=output, + prefix_lse=output_lse, + suffix_output=attn_output, + suffix_lse=attn_softmax_lse, + ) + output = output_tmp + output_lse = output_lse_tmp + + return output, output_lse + + def _forward_prefill( self, q: torch.Tensor, @@ -1155,6 +1424,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): k_scale: torch.Tensor, ) -> torch.Tensor: assert attn_metadata.prefill is not None + assert self.dcp_world_size is not None has_context = attn_metadata.prefill.chunked_context is not None kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ @@ -1174,8 +1444,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): if has_context: suffix_output, suffix_lse = output - context_output, context_lse = self._compute_prefill_context( \ - q, kv_c_and_k_pe_cache, attn_metadata, k_scale) + if self.dcp_world_size > 1: + context_output, context_lse = \ + self._context_parallel_compute_prefill_context( + q, kv_c_and_k_pe_cache, attn_metadata, + k_scale=None, dcp_world_size=self.dcp_world_size) + else: + context_output, context_lse = \ + self._compute_prefill_context( + q, kv_c_and_k_pe_cache, attn_metadata, k_scale) output = torch.empty_like(suffix_output) merge_attn_states( @@ -1195,12 +1472,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): @abstractmethod def _forward_decode( self, - ql_nope: torch.Tensor, - q_pe: torch.Tensor, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: M, layer: AttentionLayer, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: raise NotImplementedError def forward( @@ -1228,6 +1504,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): # same expert outputs. return output.fill_(0) + if self.dcp_world_size is None: + self.dcp_world_size = get_dcp_group().world_size + fp8_attention = self.kv_cache_dtype.startswith("fp8") num_actual_toks = attn_metadata.num_actual_tokens @@ -1306,7 +1585,22 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): layer._q_scale) decode_q_pe = decode_q_pe.reshape(q_pe_shape) - output[:num_decode_tokens] = self._forward_decode( - decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer) + decode_q = (decode_ql_nope, decode_q_pe) + if self.dcp_world_size > 1: + assert not fp8_attention, "DCP does not support fp8 kvcache yet." + # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P) + decode_q = torch.cat(decode_q, dim=-1)
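Both the chunk-by-chunk merge above (merge_attn_states) and the cross-rank correction below (cp_lse_ag_out_rs) rely on the same identity: softmax attention computed over disjoint key sets can be combined exactly through the log-sum-exp values. A minimal standalone sketch of that identity, not the patch's actual kernels:

import torch

def merge_partial_attn(o1: torch.Tensor, lse1: torch.Tensor,
                       o2: torch.Tensor, lse2: torch.Tensor):
    # o1/o2: [tokens, heads, head_dim] partial outputs over disjoint key sets;
    # lse1/lse2: [tokens, heads] log-sum-exp of the corresponding scores.
    lse = torch.logaddexp(lse1, lse2)           # normalizer of the union
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)    # weight of the first partial
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)    # weight of the second partial
    return w1 * o1 + w2 * o2, lse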
+ # decode_q does all-gather along the head dim. + decode_q = get_dcp_group().all_gather(decode_q, dim=1) + # call decode attn + attn_out, lse = self._forward_decode(decode_q, kv_cache, + attn_metadata, layer) + + # correct the dcp attn_out with the lse. + if self.dcp_world_size > 1: + attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group()) + + # v_up projection + output[:num_decode_tokens] = self._v_up_proj(attn_out) return output_padded diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index 8a17d3a492783..21be17a750df4 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -from typing import ClassVar, Optional +from typing import ClassVar, Optional, Union import torch @@ -76,6 +76,7 @@ g_sm100_workspace = SM100Workspace(128 * 1024 * 1024) # 128MB class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): + can_return_lse_for_decode: bool = True def __init__( self, @@ -108,16 +109,6 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): "are not implemented for " "CutlassMLAImpl") - if is_quantized_kv_cache(self.kv_cache_dtype): - raise NotImplementedError( - "CutlassMLA V1 with FP8 KV cache not yet supported") - - self._use_old_cutlass_mla = False - force_old_cutlass = os.environ.get("FORCE_OLD_CUTLASS_MLA", None) - if force_old_cutlass: - logger.warning_once("Forcing old cutlass mla kernel") - self._use_old_cutlass_mla = True - # TODO: Currently, num_kv_splits is limited to 16 to avoid hanging # issues. In case the code hangs, use: # FORCE_NUM_KV_SPLITS=1 @@ -142,7 +133,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): workspace: torch.Tensor, sm_scale: float, num_kv_splits: int, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor]: assert (q_nope.ndim == 3 ), f"q_nope must be a 3D tensor, but got {q_nope.ndim}" assert ( @@ -182,11 +173,10 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): > 0), f"block num must be greater than 0, got {block_num}" assert block_num % (128 / PAGE_SIZE) == 0 - # TODO(kaixih@nvidia): support fp8 assert q_nope.dtype in ( - torch.float16, - torch.bfloat16, - ), f"q_nope.dtype needs to be fp16 or bf16 but got {q_nope.dtype}." + torch.float16, torch.bfloat16, torch.float8_e4m3fn), ( + f"q_nope.dtype needs to be fp16 or bf16 or e4m3 but got " + f"{q_nope.dtype}.") assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype assert ( seq_lens.dtype == torch.int32 ), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}." assert ( page_table.dtype == torch.int32 ), f"page_table.dtype needs to be int32 but got {page_table.dtype}."
- out = q_nope.new_empty((B_q, MAX_HEADS, D_latent)) + dtype = (torch.bfloat16 if is_quantized_kv_cache(self.kv_cache_dtype) + else q_nope.dtype) + out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype) + lse = (torch.empty( + (B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device) + if self.need_to_return_lse_for_decode else torch.Tensor()) ops.sm100_cutlass_mla_decode( out, + lse, q_nope, q_pe, kv_c_and_k_pe_cache, @@ -208,83 +204,44 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): sm_scale, num_kv_splits, ) - return out[:, :H].contiguous() - def _sm100_forward_decode( + if H < MAX_HEADS: + # Extract the subsets of the outputs + returned_lse = lse[:, :H].contiguous( + ) if self.need_to_return_lse_for_decode else lse + out = out[:, :H] + + return out, returned_lse + + def _forward_decode( self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, - ) -> torch.Tensor: + layer: AttentionLayer, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None - if self.kv_cache_dtype.startswith("fp8"): - raise NotImplementedError("FP8 Cutlass MLA not yet supported") + if type(q) is tuple: + q_nope, q_pe = q + else: + q_nope, q_pe = torch.split( + q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) # Adjust workspace size (if necessary) self._workspace.ensure_size(attn_metadata, self._num_kv_splits) # Run MLA - # Clone q_nope and q_pe to make sure strides computation is correct. - # TODO: Check if we really need it - q_nope = q_nope.clone() - q_pe = q_pe.clone() + o, lse = self._sm100_cutlass_mla_decode( + q_nope, + q_pe, + kv_c_and_k_pe_cache, + attn_metadata.decode.seq_lens, + attn_metadata.decode.block_table, + self._workspace.get_buf(), + self.scale, + self._num_kv_splits, + ) - o = self._sm100_cutlass_mla_decode(q_nope, q_pe, kv_c_and_k_pe_cache, - attn_metadata.decode.seq_lens, - attn_metadata.decode.block_table, - self._workspace.get_buf(), - self.scale, self._num_kv_splits) - - return self._v_up_proj(o) - - # TODO: Currently we leave it here only for backup in case something is - # wrong with the new SM100 CUTLASS MLA kernel - def _old_forward_decode( - self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: MLACommonMetadata, - ) -> torch.Tensor: - assert kv_c_and_k_pe_cache.numel() > 0 - assert attn_metadata.decode is not None - - if self.kv_cache_dtype.startswith("fp8"): - raise NotImplementedError("FP8 Cutlass MLA not yet supported") - - B = q_nope.shape[0] - - o = torch.empty((B, self.num_heads, self.kv_lora_rank), - dtype=q_nope.dtype, - device=q_nope.device) - - # Run MLA - # Clone q_nope and q_pe to make sure strides computation is correct. 
- q_nope = q_nope.clone() - q_pe = q_pe.clone() - - ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache, - attn_metadata.decode.seq_lens, - attn_metadata.decode.block_table, self.scale) - - return self._v_up_proj(o) - - def _forward_decode( - self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: MLACommonMetadata, - layer: AttentionLayer, - ) -> torch.Tensor: - if self._use_old_cutlass_mla: - # TODO: Remove the old cutlass MLA kernel after more extensive - # testing - return self._old_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache, - attn_metadata) - - return self._sm100_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache, - attn_metadata) + return o, (lse if self.need_to_return_lse_for_decode else None) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py new file mode 100644 index 0000000000000..472095e13615b --- /dev/null +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -0,0 +1,273 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import ClassVar, Optional, Union + +import torch + +from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, + is_quantized_kv_cache) +from vllm.attention.utils.fa_utils import (flash_attn_supports_mla, + get_flash_attn_version) +from vllm.config import VllmConfig +from vllm.distributed.parallel_state import get_dcp_group +from vllm.logger import init_logger +from vllm.v1.attention.backends.mla.common import (MLACommonBackend, + MLACommonDecodeMetadata, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder) +from vllm.v1.attention.backends.utils import AttentionCGSupport +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata + +logger = init_logger(__name__) + +# NOTE(matt): This is an arbitrary number, copied from +# woosuk's implementation in standard FlashAttention backend +_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16 + + +class FlashAttnMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "FLASH_ATTN_MLA" + + @staticmethod + def get_metadata_cls() -> type["FlashAttnMLAMetadata"]: + return FlashAttnMLAMetadata + + @staticmethod + def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]: + return FlashAttnMLAMetadataBuilder + + @staticmethod + def get_impl_cls() -> type["FlashAttnMLAImpl"]: + return FlashAttnMLAImpl + + +@dataclass +class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata): + query_start_loc: torch.Tensor + max_query_len: int + max_seq_len: int + scheduler_metadata: Optional[torch.Tensor] = None + max_num_splits: int = 0 + + +@dataclass +class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]): + pass + + +class FlashAttnMLAMetadataBuilder( + MLACommonMetadataBuilder[FlashAttnMLAMetadata]): + cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.UNIFORM_BATCH + + reorder_batch_threshold: ClassVar[int] = 512 + + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], + vllm_config: VllmConfig, device: torch.device): + super().__init__(kv_cache_spec, layer_names, vllm_config, device, + FlashAttnMLAMetadata) + self.max_num_splits = 0 # No upper bound on the number of splits. 
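The builder above advertises reorder_batch_threshold = 512, so short multi-token (e.g. speculative) decode queries can still take the decode path; with DCP it is forced back to 1 further down. A rough sketch of how such a threshold partitions a decode-first batch, under the assumption that split_decodes_and_prefills treats any request with at most `decode_threshold` query tokens as a decode:

def split_by_threshold(query_lens: list[int], decode_threshold: int):
    # The batch is assumed to be reordered so decode-like requests come first.
    num_decodes = 0
    for qlen in query_lens:
        if qlen > decode_threshold:
            break
        num_decodes += 1
    return num_decodes, len(query_lens) - num_decodes

assert split_by_threshold([1, 2, 8, 700], 512) == (3, 1)
assert split_by_threshold([1, 2, 8, 700], 1) == (1, 3)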
+ self.fa_aot_schedule = (get_flash_attn_version() == 3) + + self.use_full_cuda_graph = \ + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + + if self.use_full_cuda_graph and self.fa_aot_schedule: + self.max_cudagraph_size = self.compilation_config.max_capture_size + + if self.max_cudagraph_size > 992: + # This condition derives from FA3's internal heuristic. + # TODO(woosuk): Support larger cudagraph sizes. + raise ValueError( + "Capture size larger than 992 is not supported for " + "full cuda graph.") + + self.scheduler_metadata = torch.zeros( + vllm_config.scheduler_config.max_num_seqs + 1, + dtype=torch.int32, + device=self.device, + ) + # When using cuda graph, we need to set the upper bound of the + # number of splits so that large enough intermediate buffers are + # pre-allocated during capture. + self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH + + # TODO(lucas): Until we add support for the DCP custom masking we need + # to restrict decodes to q_len == 1 when DCP is enabled. + self.__class__.reorder_batch_threshold = 1 \ + if get_dcp_group().world_size > 1 else self.reorder_batch_threshold + + def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens, + max_seq_len, causal): + if self.fa_aot_schedule: + return get_scheduler_metadata( + batch_size=num_reqs, + max_seqlen_q=max_query_len, + max_seqlen_k=max_seq_len, + num_heads_q=self.num_heads, + num_heads_kv=1, + headdim=self.mla_dims.qk_rope_head_dim, + cache_seqlens=seqlens, + qkv_dtype=self.kv_cache_spec.dtype, + headdim_v=self.mla_dims.kv_lora_rank, + page_size=self.page_size, + cu_seqlens_q=cu_query_lens, + causal=causal, + num_splits=self.max_num_splits, + ) + return None + + def _build_decode(self, block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int) -> FlashAttnMLADecodeMetadata: + query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]) + max_query_len = query_lens_cpu.max().item() + max_seq_len = seq_lens_cpu.max().item() + + scheduler_metadata = self._schedule_decode( + num_reqs=seq_lens_cpu.numel(), + cu_query_lens=query_start_loc_device, + max_query_len=max_query_len, + seqlens=seq_lens_device, + max_seq_len=max_seq_len, + causal=True, + ) + + # For FA3 + full cudagraph + max_num_splits = 0 + if self.use_full_cuda_graph and scheduler_metadata is not None: + n = scheduler_metadata.shape[0] + # Ensure the persistent buffer is large enough + assert n <= self.scheduler_metadata.shape[0], \ + f"Scheduler metadata size {n} exceeds buffer size " + \ + f"{self.scheduler_metadata.shape[0]}" + self.scheduler_metadata[:n] = scheduler_metadata + # NOTE(woosuk): We should zero out the rest of the scheduler + # metadata to guarantee the correctness. Otherwise, some thread + # blocks may use the invalid scheduler metadata and overwrite the + # output buffer. + self.scheduler_metadata[n:] = 0 + scheduler_metadata = self.scheduler_metadata[:n] + + if num_decode_tokens <= self.max_cudagraph_size: + # NOTE(woosuk): Setting num_splits > 1 may increase the memory + # usage, because the intermediate buffers of size [num_splits, + # num_heads, num_tokens, head_size] are allocated. Therefore, + # we only set num_splits when using cuda graphs. 
+ max_num_splits = self.max_num_splits + + return FlashAttnMLADecodeMetadata( + block_table=block_table_tensor, + seq_lens=seq_lens_device, + query_start_loc=query_start_loc_device, + max_query_len=max_query_len, + max_seq_len=max_seq_len, + scheduler_metadata=scheduler_metadata, + max_num_splits=max_num_splits, + ) + + +class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]): + can_return_lse_for_decode: bool = True + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + logits_soft_cap, attn_type, + kv_sharing_target_layer_name, **mla_args) + + assert flash_attn_supports_mla(), \ + "FlashAttnMLA is not supported on this device" + + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] + if any(unsupported_features): + raise NotImplementedError( + "FlashAttnMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashAttnMLAImpl") + + if is_quantized_kv_cache(self.kv_cache_dtype): + raise NotImplementedError( + "FlashAttnMLA V1 with FP8 KV cache not yet supported") + + def _forward_decode( + self, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: FlashAttnMLAMetadata, + layer: AttentionLayer, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + assert kv_c_and_k_pe_cache.numel() > 0 + assert attn_metadata.decode is not None + + if type(q) is tuple: + q_nope, q_pe = q + else: + q_nope, q_pe = torch.split( + q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + if self.kv_cache_dtype.startswith("fp8"): + raise NotImplementedError( + "FP8 FlashAttention MLA not yet supported") + + kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] + k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:] + + # NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the + # kernel uses this to calculate grid dimensions. Ensure it's at least 1 + # to prevent invalid grid configuration during graph capture. 
+ max_seqlen_q = max(attn_metadata.decode.max_query_len, 1) + + attn_out = flash_attn_varlen_func( + q=q_pe, + k=k_pe_cache.unsqueeze(-2), # Add head dim of 1 + v=kv_c_cache.unsqueeze(-2), # Add head dim of 1 + q_v=q_nope, + max_seqlen_q=max_seqlen_q, + cu_seqlens_q=attn_metadata.decode.query_start_loc, + max_seqlen_k=attn_metadata.decode.max_seq_len, + seqused_k=attn_metadata.decode.seq_lens, + block_table=attn_metadata.decode.block_table, + softmax_scale=self.scale, + causal=True, + return_softmax_lse=self.need_to_return_lse_for_decode, + fa_version=3, # only version 3 is supported + scheduler_metadata=attn_metadata.decode.scheduler_metadata, + num_splits=attn_metadata.decode.max_num_splits, + ) + + if self.need_to_return_lse_for_decode: + o, lse = attn_out + # FA returns LSE in shape [ H, B ] but DCP wants [ B, H ] + return o, lse.transpose(0, 1) # [ H, B ] -> [ B, H ] + else: + o = attn_out + return o, None diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py new file mode 100644 index 0000000000000..701248670f72e --- /dev/null +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional, Union + +import torch +from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla + +from vllm.attention.backends.abstract import AttentionLayer, AttentionType +from vllm.logger import init_logger +from vllm.v1.attention.backends.mla.common import (MLACommonBackend, + MLACommonImpl, + MLACommonMetadata) + +logger = init_logger(__name__) + +FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024 + + +class FlashInferMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "FLASHINFER_MLA" + + @staticmethod + def get_impl_cls() -> type["FlashInferMLAImpl"]: + return FlashInferMLAImpl + + +g_fi_workspace = torch.zeros( + FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device="cuda", +) + + +class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + logits_soft_cap, attn_type, + kv_sharing_target_layer_name, **mla_args) + + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] + if any(unsupported_features): + raise NotImplementedError( + "FlashInferMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashInferMLAImpl") + + self._workspace_buffer = g_fi_workspace + self.bmm1_scale: Optional[float] = None + self.bmm2_scale: Optional[float] = None + + def _forward_decode( + self, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + layer: AttentionLayer, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + assert kv_c_and_k_pe_cache.numel() > 0 + assert attn_metadata.decode is not None + + 
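The _forward_decode implementations now accept either the (q_nope, q_pe) pair or one pre-concatenated tensor (the form produced by the DCP all-gather over heads); the two forms round-trip through cat/split on the last dimension. A tiny sketch with assumed head dimensions:

import torch

kv_lora_rank, qk_rope_head_dim = 512, 64      # assumed MLA dims for the sketch
q_nope = torch.randn(4, 16, kv_lora_rank)     # [tokens, heads, latent dim]
q_pe = torch.randn(4, 16, qk_rope_head_dim)   # [tokens, heads, rope dim]

q = torch.cat((q_nope, q_pe), dim=-1)         # single-tensor form
back_nope, back_pe = torch.split(
    q, [kv_lora_rank, qk_rope_head_dim], dim=-1)  # back to the tuple form
assert torch.equal(back_nope, q_nope) and torch.equal(back_pe, q_pe)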
if isinstance(q, tuple): + q_nope, q_pe = q + q = torch.cat([q_nope, q_pe], dim=-1) + + # trtllm API requires extra dimension q_len_per_request for MTP + q = q.unsqueeze(1) + + if self.bmm1_scale is None: + self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float * + self.scale) + if self.bmm2_scale is None: + self.bmm2_scale = layer._v_scale_float + + o = trtllm_batch_decode_with_kv_cache_mla( + query=q, + kv_cache=kv_c_and_k_pe_cache.unsqueeze(1), + workspace_buffer=self._workspace_buffer, + qk_nope_head_dim=self.qk_nope_head_dim, + kv_lora_rank=self.kv_lora_rank, + qk_rope_head_dim=self.qk_rope_head_dim, + block_tables=attn_metadata.decode.block_table, + seq_lens=attn_metadata.decode.seq_lens, + max_seq_len=attn_metadata.max_seq_len, + bmm1_scale=self.bmm1_scale, + bmm2_scale=self.bmm2_scale, + ) + + # TODO: Return LSE pending support from Flashinfer API: + # https://github.com/flashinfer-ai/flashinfer/pull/1566 + return o, None diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 1c50144d47900..150e38553e4bb 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import ClassVar, Optional, Union import torch @@ -62,7 +62,6 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): super().__init__(kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata) - self.compilation_config = vllm_config.compilation_config self.num_q_heads = vllm_config.model_config.get_num_attention_heads( vllm_config.parallel_config) @@ -86,10 +85,14 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): dtype=torch.int32) def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens: torch.Tensor) -> FlashMLADecodeMetadata: + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int) -> FlashMLADecodeMetadata: tile_scheduler_metadata, num_splits = \ get_mla_metadata( - seq_lens, + seq_lens_device, self.num_q_heads, 1, # MQA for the decode path ) @@ -123,7 +126,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): return FlashMLADecodeMetadata( block_table=block_table_tensor, - seq_lens=seq_lens, + seq_lens=seq_lens_device, tile_scheduler_metadata=tile_scheduler_metadata, num_splits=num_splits, ) @@ -131,6 +134,8 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): + can_return_lse_for_decode: bool = True + def __init__( self, num_heads: int, @@ -150,8 +155,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): logits_soft_cap, attn_type, kv_sharing_target_layer_name, **mla_args) - assert is_flashmla_supported(), \ - "FlashMLA is not supported on this device" + is_supported, reason = is_flashmla_supported() + assert is_supported, reason unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): @@ -167,20 +172,20 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): def _forward_decode( self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: FlashMLAMetadata, layer: AttentionLayer, - ) -> torch.Tensor: + ) -> 
tuple[torch.Tensor, Optional[torch.Tensor]]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None - q = torch.cat([q_nope, q_pe], dim=-1)\ - .unsqueeze(1) # Add seqlen dim of 1 (decode) + if type(q) is tuple: + q = torch.cat(q, dim=-1) - o, _ = flash_mla_with_kvcache( - q=q, + assert isinstance(q, torch.Tensor) + o, lse = flash_mla_with_kvcache( + q=q.unsqueeze(1), # Add seqlen dim of 1 (decode) k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 block_table=attn_metadata.decode.block_table, cache_seqlens=attn_metadata.decode.seq_lens, @@ -194,4 +199,4 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): descale_k=layer._k_scale.reshape(1), ) - return self._v_up_proj(o) + return o, lse diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 870cc600388e7..db27a34d8959a 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import ClassVar, Optional, Union import torch @@ -105,11 +105,15 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): device=device) def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens: torch.Tensor) -> AiterMLADecodeMetadata: + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int) -> AiterMLADecodeMetadata: page_size = self.kv_cache_spec.block_size - block_table_bounds = (seq_lens + page_size - 1) // page_size + block_table_bounds = (seq_lens_device + page_size - 1) // page_size device = self.device - num_reqs = seq_lens.size(0) + num_reqs = seq_lens_device.size(0) mask = (torch.arange(block_table_tensor.size(1), dtype=block_table_tensor.dtype, @@ -117,7 +121,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): < block_table_bounds.unsqueeze(1)) paged_kv_indices = block_table_tensor[mask] - paged_kv_last_page_len = seq_lens % page_size + paged_kv_last_page_len = seq_lens_device % page_size paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, page_size, paged_kv_last_page_len) @@ -156,7 +160,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): attn_metadata = AiterMLADecodeMetadata( block_table=block_table_tensor, - seq_lens=seq_lens, + seq_lens=seq_lens_device, paged_kv_indptr=paged_kv_indptr, paged_kv_indices=paged_kv_indices, paged_kv_last_page_len=paged_kv_last_page_len, @@ -218,18 +222,19 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): def _forward_decode( self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: AiterMLAMetadata, layer: AttentionLayer, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None - B = q_nope.shape[0] + if type(q) is tuple: + q = torch.cat(q, dim=-1) - q = torch.cat([q_nope, q_pe], dim=-1) + assert isinstance(q, torch.Tensor) + B = q.shape[0] o = torch.zeros(B, self.num_heads, self.kv_lora_rank, @@ -247,4 +252,4 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): attn_metadata.decode.paged_kv_indices, attn_metadata.decode.paged_kv_last_page_len) - return self._v_up_proj(o) + return o, 
None diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index f2974ed668d99..d692b00d78b46 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Optional, Union import torch @@ -123,21 +123,22 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): def _forward_decode( self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, layer: AttentionLayer, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None if self.kv_cache_dtype.startswith("fp8"): raise NotImplementedError("FP8 Triton MLA not yet supported") - B = q_nope.shape[0] + if type(q) is tuple: + q = torch.cat(q, dim=-1) - q = torch.cat([q_nope, q_pe], dim=-1) + assert isinstance(q, torch.Tensor) + B = q.shape[0] o = torch.zeros(B, self.num_heads, self.kv_lora_rank, @@ -171,4 +172,4 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): attn_metadata.decode.seq_lens, attn_logits, num_kv_splits, self.scale, PAGE_SIZE) - return self._v_up_proj(o) + return o, None diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 173a0a255e491..afb2283c44d37 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -232,15 +232,15 @@ class AiterFlashAttentionMetadata: class AiterFlashAttentionMetadataBuilder( AttentionMetadataBuilder[AiterFlashAttentionMetadata]): - cudagraph_support = AttentionCGSupport.ALWAYS + cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): - self.vllm_config = vllm_config + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.model_config = vllm_config.model_config self.parallel_config = vllm_config.parallel_config self.cache_config = vllm_config.cache_config - self.device = device self.num_heads_q = self.model_config.get_num_attention_heads( self.parallel_config) @@ -248,7 +248,6 @@ class AiterFlashAttentionMetadataBuilder( self.parallel_config) self.headdim = self.model_config.get_head_size() self.block_size = kv_cache_spec.block_size - self.kv_cache_spec = kv_cache_spec # Sliding window size to be used with the AOT scheduler will be # populated on first build() call. 
self.aot_sliding_window: Optional[tuple[int, int]] = None @@ -480,8 +479,8 @@ class AiterFlashAttentionImpl(AttentionImpl): ) if self.kv_cache_dtype.startswith("fp8"): - key_cache = key_cache.view(torch.float8_e4m3fnuz) - value_cache = value_cache.view(torch.float8_e4m3fnuz) + key_cache = key_cache.view(current_platform.fp8_dtype()) + value_cache = value_cache.view(current_platform.fp8_dtype()) if not attn_metadata.use_cascade: cu_seqlens_q = attn_metadata.query_start_loc diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py index d80ced8ec876a..717c40b37ecfb 100644 --- a/vllm/v1/attention/backends/short_conv_attn.py +++ b/vllm/v1/attention/backends/short_conv_attn.py @@ -34,8 +34,8 @@ class ShortConvAttentionMetadata: # For causal_conv1d nums_dict: Optional[dict] = None cu_seqlen: Optional[int] = None - batch_ptr: Optional[torch.tensor] = None - token_chunk_offset_ptr: Optional[torch.tensor] = None + batch_ptr: Optional[torch.Tensor] = None + token_chunk_offset_ptr: Optional[torch.Tensor] = None class ShortConvAttentionMetadataBuilder( @@ -45,8 +45,8 @@ class ShortConvAttentionMetadataBuilder( def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) assert isinstance(kv_cache_spec, MambaSpec) - self.kv_cache_spec = kv_cache_spec def build(self, common_prefix_len: int, @@ -58,8 +58,9 @@ class ShortConvAttentionMetadataBuilder( state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=1)) + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold)) has_initial_states = None if num_prefills > 0: #[batch,] @@ -78,4 +79,4 @@ class ShortConvAttentionMetadataBuilder( has_initial_states=has_initial_states, state_indices_tensor=state_indices_tensor, ) - return attn_metadata \ No newline at end of file + return attn_metadata diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index b96d957a150b5..10238f36455d2 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -165,7 +165,8 @@ class TreeAttentionMetadataBuilder( vllm_config: VllmConfig, device: torch.device, ): - self.kv_cache_spec = kv_cache_spec + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.block_size = kv_cache_spec.block_size spec_config = vllm_config.speculative_config diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index a37a7f6811ef9..784912a122f68 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -7,7 +7,6 @@ from typing import ClassVar, Optional import torch -from vllm import _custom_ops as ops from vllm import envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) @@ -16,6 +15,8 @@ from vllm.attention.ops.chunked_prefill_paged_decode import ( from vllm.attention.ops.paged_attn import PagedAttention from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, kFp8StaticTensorSym) from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import 
FlashAttentionMetadata from vllm.v1.attention.backends.utils import (AttentionCGSupport, @@ -23,6 +24,11 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport, CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec +if current_platform.is_cuda_alike(): + from vllm import _custom_ops as ops +elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops as ops + logger = init_logger(__name__) @@ -62,9 +68,9 @@ class TritonAttentionMetadataBuilder( def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): - self.device = device + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.block_size = kv_cache_spec.block_size - self.kv_cache_spec = kv_cache_spec model_config = vllm_config.model_config self.num_heads_q = model_config.get_num_attention_heads( @@ -198,6 +204,9 @@ def use_aiter_unified_attention() -> bool: class TritonAttentionImpl(AttentionImpl): + def fused_output_quant_supported(self, quant_key: QuantKey): + return quant_key == kFp8StaticTensorSym + def __init__( self, num_heads: int, @@ -293,9 +302,9 @@ class TritonAttentionImpl(AttentionImpl): """ assert output is not None, "Output tensor must be provided." - if output_scale is not None or output_block_scale is not None: + if output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" + "fused block_scale output quantization is not yet supported" " for TritonAttentionImpl") if attn_metadata is None: @@ -337,7 +346,7 @@ class TritonAttentionImpl(AttentionImpl): layer._v_scale, ) else: - torch.ops._C_cache_ops.reshape_and_cache_flash( + ops.reshape_and_cache_flash( key, value, key_cache, @@ -352,11 +361,12 @@ class TritonAttentionImpl(AttentionImpl): key_cache = key_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) num_tokens, num_heads, head_size = query.shape - assert layer._q_scale == 1.0, \ + assert layer._q_scale_float == 1.0, \ "A non 1.0 q_scale is not currently supported." - if not current_platform.is_rocm(): - # Skip Q quantization on ROCm, since dequantizing back to - # f32 in the attention kernel is not supported. + if current_platform.is_cuda(): + # Skip Q quantization on ROCm and XPU, enable this on cuda + # only, since dequantizing back to f32 in the attention kernel + # is not supported. 
query, _ = ops.scaled_fp8_quant( query.reshape( (num_tokens, num_heads * head_size)).contiguous(), @@ -389,6 +399,7 @@ class TritonAttentionImpl(AttentionImpl): alibi_slopes=self.alibi_slopes, sliding_window=self.sliding_window[0], sm_scale=self.scale, + output_scale=output_scale, sinks=self.sinks, ) @@ -414,6 +425,6 @@ class TritonAttentionImpl(AttentionImpl): k_descale=layer._k_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape), sinks=self.sinks, - ) + output_scale=output_scale) return output diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 011a90ece01bd..63326d19194f0 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -5,8 +5,8 @@ import enum import functools from abc import abstractmethod from dataclasses import dataclass, fields, make_dataclass -from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Optional, Protocol, - TypeVar) +from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Literal, Optional, + Protocol, TypeVar, Union, get_args) import numpy as np import torch @@ -28,9 +28,15 @@ from vllm.distributed.kv_transfer.kv_connector.utils import ( get_kv_connector_cache_layout) from vllm.logger import init_logger from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.ubatch_utils import UBatchSlice logger = init_logger(__name__) -_KV_CACHE_LAYOUT_OVERRIDE = None +KVCacheLayoutType = Literal["NHD", "HND"] +_KV_CACHE_LAYOUT_OVERRIDE: Union[KVCacheLayoutType, None] = None + + +def is_valid_kv_cache_layout(value: str) -> bool: + return value in get_args(KVCacheLayoutType) @dataclass @@ -72,11 +78,8 @@ class CommonAttentionMetadata: logits_indices_padded: Optional[torch.Tensor] = None num_logits_indices: Optional[int] = None - -@dataclass -class UbatchSlice: - request_slice: slice - token_slice: slice + # Needed by CrossAttentionBuilder + encoder_seq_lens: Optional[np.ndarray] = None def slice_query_start_locs( @@ -95,7 +98,7 @@ def slice_query_start_locs( def _make_metadata_with_slice( - ubatch_slice: UbatchSlice, + ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata: """ This function creates a new CommonAttentionMetadata that corresponds to @@ -125,6 +128,11 @@ def _make_metadata_with_slice( torch.max(torch.abs(query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])).item()) + # This is to account for the case where we are in a dummy + # run and query_start_loc_cpu is full of 0s + if max_query_len == 0: + max_query_len = attn_metadata.max_query_len + block_table_tensor = attn_metadata.block_table_tensor[request_slice] slot_mapping = attn_metadata.slot_mapping[token_slice] @@ -144,12 +152,12 @@ def _make_metadata_with_slice( def split_attn_metadata( - ubatch_slices: list[UbatchSlice], + ubatch_slices: list[UBatchSlice], common_attn_metadata: CommonAttentionMetadata, ) -> list[CommonAttentionMetadata]: """ Creates a new CommonAttentionMetadata instance that corresponds to the - requests for each UbatchSlice in ubatch_slices. + requests for each UBatchSlice in ubatch_slices. 
Note: This function does not modify common_attn_metadata """ @@ -193,6 +201,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): self.kv_cache_spec = kv_cache_spec + self.layer_names = layer_names + self.vllm_config = vllm_config + self.device = device @abstractmethod def build(self, @@ -290,12 +301,13 @@ def get_kv_cache_layout(): if cache_layout is None: cache_layout = get_kv_connector_cache_layout() else: + assert is_valid_kv_cache_layout(cache_layout) logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \ "detected. Setting KV cache layout to %s.", cache_layout) return cache_layout -def set_kv_cache_layout(cache_layout: str): +def set_kv_cache_layout(cache_layout: KVCacheLayoutType): global _KV_CACHE_LAYOUT_OVERRIDE _KV_CACHE_LAYOUT_OVERRIDE = cache_layout @@ -542,7 +554,14 @@ def make_local_attention_virtual_batches( 1) batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32), local_blocks * pages_per_local_batch) - block_table_local = block_table[batch_indices, block_indices]\ + + # NOTE: https://github.com/pytorch/pytorch/pull/160256 causes performance + # regression when using numpy arrays (batch and block indices) to index into + # torch tensor (block_table). As a workaround, convert numpy arrays to torch + # tensor first, which recovers perf. + batch_indices_torch = torch.from_numpy(batch_indices) + block_indices_torch = torch.from_numpy(block_indices) + block_table_local = block_table[batch_indices_torch, block_indices_torch]\ .view(virtual_batches, -1) query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local) @@ -709,7 +728,7 @@ def reorder_batch_to_split_decodes_and_prefills( for i, req_id in enumerate(input_batch.req_ids): num_tokens = scheduler_output.num_scheduled_tokens[req_id] - # for now treat 1 scheduled token as "decode" even if its not, + # for now treat 1 scheduled token as "decode" even if it's not, # we should update this to something like < 8 in the future but # currently the TritonMLA._forward_decode only supports # num_tokens = 1 diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index 7f888c1135743..a6ca334912353 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -3,7 +3,7 @@ """Attention layer with XFormersAttention.""" from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, ClassVar, Optional import torch @@ -197,6 +197,8 @@ class XFormersAttentionMetadata: class XFormersAttentionMetadataBuilder( AttentionMetadataBuilder[XFormersAttentionMetadata]): + reorder_batch_threshold: ClassVar[int] = 1 + def __init__( self, kv_cache_spec: AttentionSpec, @@ -204,17 +206,19 @@ class XFormersAttentionMetadataBuilder( vllm_config: VllmConfig, device: torch.device, ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + assert XFORMERS_AVAILABLE - self.kv_cache_spec = kv_cache_spec self.block_size = kv_cache_spec.block_size self._num_decodes = 0 self._num_decode_tokens = 0 def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: - return reorder_batch_to_split_decodes_and_prefills(input_batch, - scheduler_output, - decode_threshold=1) + return reorder_batch_to_split_decodes_and_prefills( + input_batch, + scheduler_output, + decode_threshold=self.reorder_batch_threshold) def build( self, @@ -223,8 +227,9 @@ class 
XFormersAttentionMetadataBuilder( fast_build: bool = False, ) -> XFormersAttentionMetadata: num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=1)) + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold)) num_actual_tokens = common_attn_metadata.num_actual_tokens q_start_loc = common_attn_metadata.query_start_loc diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index b537cac8e1d72..d1e1c1c8d0382 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -9,7 +9,11 @@ from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared, KVCacheEvent) from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, - FreeKVCacheBlockQueue, KVCacheBlock) + ExternalBlockHash, + FreeKVCacheBlockQueue, KVCacheBlock, + get_block_hash, + make_block_hash_with_group_id, + maybe_convert_block_hash) from vllm.v1.request import Request logger = init_logger(__name__) @@ -84,8 +88,10 @@ class BlockPool: """ cached_blocks = [] for group_id in kv_cache_group_ids: + block_hash_with_group_id = make_block_hash_with_group_id( + block_hash, group_id) cached_blocks_one_group = self.cached_block_hash_to_block.get( - BlockHashWithGroupId(block_hash, group_id)) + block_hash_with_group_id) if not cached_blocks_one_group: return None first_block = next(iter(cached_blocks_one_group.values())) @@ -124,28 +130,29 @@ class BlockPool: assert len(request.block_hashes) >= num_full_blocks new_block_hashes = request.block_hashes[num_cached_blocks:] - new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events - else None) + new_hashes: Optional[list[ExternalBlockHash]] = ( + [] if self.enable_kv_cache_events else None) for i, blk in enumerate(new_full_blocks): assert blk.block_hash is None block_hash = new_block_hashes[i] # Update and added the full block to the cache. - block_hash_with_group_id = BlockHashWithGroupId( + block_hash_with_group_id = make_block_hash_with_group_id( block_hash, kv_cache_group_id) blk.block_hash = block_hash_with_group_id self.cached_block_hash_to_block[block_hash_with_group_id][ blk.block_id] = blk if new_hashes is not None: - new_hashes.append(block_hash.hash_value) + new_hashes.append(maybe_convert_block_hash(block_hash)) if self.enable_kv_cache_events: if num_cached_blocks == 0: - parent_block_hash = None + parent_block_hash: Optional[ExternalBlockHash] = None else: parent_block = blocks[num_cached_blocks - 1] assert parent_block.block_hash is not None - parent_block_hash = parent_block.block_hash.get_hash_value() + parent_block_hash = maybe_convert_block_hash( + get_block_hash(parent_block.block_hash)) self.kv_event_queue.append( BlockStored( @@ -220,7 +227,9 @@ class BlockPool: # we disable hybrid kv cache manager when kv cache event is # enabled, so there is only one group. 
self.kv_event_queue.append( - BlockRemoved(block_hashes=[block_hash.get_hash_value()], + BlockRemoved(block_hashes=[ + maybe_convert_block_hash(get_block_hash(block_hash)) + ], medium=MEDIUM_GPU)) return True diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index bd2ec036834b2..eadea15a2e5e3 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -86,7 +86,7 @@ class EncoderCacheManager: Returns: True if the encoder output for this input is already cached """ - mm_hash = request.mm_hashes[input_id] + mm_hash = request.mm_features[input_id].identifier # Not cached at all if mm_hash not in self.cached: return False @@ -167,7 +167,7 @@ class EncoderCacheManager: This method assumes can_allocate() returned True for the same input. """ - mm_hash = request.mm_hashes[input_id] + mm_hash = request.mm_features[input_id].identifier request_id = request.request_id if mm_hash not in self.cached: self.cached[mm_hash] = set() @@ -193,8 +193,8 @@ class EncoderCacheManager: """ return { input_id - for input_id in range(len(request.mm_hashes)) - if request.mm_hashes[input_id] in self.cached + for input_id in range(len(request.mm_features)) + if request.mm_features[input_id].identifier in self.cached } def free_encoder_input(self, request: Request, input_id: int) -> None: @@ -208,7 +208,7 @@ class EncoderCacheManager: `can_allocate`). """ req_id = request.request_id - mm_hash = request.mm_hashes[input_id] + mm_hash = request.mm_features[input_id].identifier # The mm_hash not in cache or the req_id set is empty if not self.cached.get(mm_hash, None): return diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 9421341f990c8..86771060c4099 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -24,6 +24,7 @@ class KVCacheCoordinator(ABC): use_eagle: bool, enable_caching: bool, enable_kv_cache_events: bool, + dcp_world_size: int, ): self.kv_cache_config = kv_cache_config self.max_model_len = max_model_len @@ -39,6 +40,7 @@ class KVCacheCoordinator(ABC): kv_cache_spec=kv_cache_group.kv_cache_spec, block_pool=self.block_pool, kv_cache_group_id=i, + dcp_world_size=dcp_world_size, ) for i, kv_cache_group in enumerate( self.kv_cache_config.kv_cache_groups)) @@ -197,9 +199,14 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator): """ def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, - use_eagle: bool, enable_kv_cache_events: bool): - super().__init__(kv_cache_config, max_model_len, use_eagle, False, - enable_kv_cache_events) + use_eagle: bool, enable_kv_cache_events: bool, + dcp_world_size: int): + super().__init__(kv_cache_config, + max_model_len, + use_eagle, + False, + enable_kv_cache_events, + dcp_world_size=dcp_world_size) self.num_single_type_manager = len(self.single_type_managers) def get_num_common_prefix_blocks(self, request_id: str, @@ -225,12 +232,19 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, enable_caching: bool, - enable_kv_cache_events: bool): - super().__init__(kv_cache_config, max_model_len, use_eagle, - enable_caching, enable_kv_cache_events) + enable_kv_cache_events: bool, dcp_world_size: int): + super().__init__(kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size=dcp_world_size) self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[ 0].kv_cache_spec 
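The EncoderCacheManager changes above replace the parallel mm_hashes list with per-feature identifiers. A rough sketch of that lookup shape, using a stand-in Feature type rather than vLLM's MultiModalFeatureSpec:

from dataclasses import dataclass

@dataclass
class Feature:                      # stand-in for MultiModalFeatureSpec
    identifier: str                 # the multimodal input's hash

class TinyEncoderCache:
    def __init__(self) -> None:
        # mm hash -> ids of requests that still need this encoder output
        self.cached: dict[str, set[str]] = {}

    def has_output(self, features: list[Feature], input_id: int) -> bool:
        return features[input_id].identifier in self.cached

    def allocate(self, req_id: str, features: list[Feature],
                 input_id: int) -> None:
        self.cached.setdefault(features[input_id].identifier, set()).add(req_id)

cache = TinyEncoderCache()
feats = [Feature("img-abc"), Feature("img-def")]
cache.allocate("req-0", feats, 0)
assert cache.has_output(feats, 0) and not cache.has_output(feats, 1)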
self.block_size = self.kv_cache_spec.block_size + self.dcp_world_size = dcp_world_size + if dcp_world_size > 1: + self.block_size *= dcp_world_size assert len(self.kv_cache_config.kv_cache_groups) == 1, ( "UnitaryKVCacheCoordinator assumes only one kv cache group") @@ -246,6 +260,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): block_pool=self.block_pool, kv_cache_spec=self.kv_cache_spec, use_eagle=self.use_eagle, + dcp_world_size=self.dcp_world_size, ) return hit_blocks, len(hit_blocks[0]) * self.block_size @@ -261,9 +276,14 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, enable_caching: bool, - enable_kv_cache_events: bool): - super().__init__(kv_cache_config, max_model_len, use_eagle, - enable_caching, enable_kv_cache_events) + enable_kv_cache_events: bool, dcp_world_size: int): + super().__init__(kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size=dcp_world_size) + assert dcp_world_size == 1, "DCP not support hybrid attn now." self.verify_and_split_kv_cache_groups() def verify_and_split_kv_cache_groups(self) -> None: @@ -394,17 +414,27 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): return hit_blocks, hit_length -def get_kv_cache_coordinator( - kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, - enable_caching: bool, - enable_kv_cache_events: bool) -> KVCacheCoordinator: +def get_kv_cache_coordinator(kv_cache_config: KVCacheConfig, + max_model_len: int, use_eagle: bool, + enable_caching: bool, + enable_kv_cache_events: bool, + dcp_world_size: int) -> KVCacheCoordinator: if not enable_caching: - return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len, + return KVCacheCoordinatorNoPrefixCache(kv_cache_config, + max_model_len, use_eagle, - enable_kv_cache_events) + enable_kv_cache_events, + dcp_world_size=dcp_world_size) if len(kv_cache_config.kv_cache_groups) == 1: - return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len, - use_eagle, enable_caching, - enable_kv_cache_events) - return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle, - enable_caching, enable_kv_cache_events) + return UnitaryKVCacheCoordinator(kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size=dcp_world_size) + return HybridKVCacheCoordinator(kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size=dcp_world_size) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 87a11fe58a048..401327f727a4a 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -24,8 +24,9 @@ class KVCacheBlocks: """ blocks: tuple[list[KVCacheBlock], ...] """ - blocks[i][j] refers to the i-th kv_cache_group and the j-th block of tokens. - We don't use block of tokens as the outer dimension because it assumes all + `blocks[i][j]` refers to the i-th kv_cache_group + and the j-th block of tokens.We don't use block of + tokens as the outer dimension because it assumes all kv_cache_groups have the same number of blocks, which is true for now but will be broken if we want to give different block_size to different kv_cache_groups in the future. 
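A quick arithmetic sketch of what the dcp_world_size scaling above means for prefix caching; the numbers are purely illustrative:

def effective_block_size(block_size: int, dcp_world_size: int) -> int:
    # With decode context parallelism, each cached block hash covers
    # block_size * dcp_world_size tokens of the original sequence.
    return block_size * dcp_world_size if dcp_world_size > 1 else block_size

def prefix_hit_tokens(num_hit_blocks: int, block_size: int,
                      dcp_world_size: int) -> int:
    return num_hit_blocks * effective_block_size(block_size, dcp_world_size)

# e.g. block_size=16 and dcp_world_size=4: a 3-block hit covers 192 tokens.
assert prefix_hit_tokens(3, 16, 4) == 192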
@@ -91,6 +92,7 @@ class KVCacheManager: use_eagle: bool = False, log_stats: bool = False, enable_kv_cache_events: bool = False, + dcp_world_size: int = 1, ) -> None: self.max_model_len = max_model_len @@ -109,12 +111,20 @@ class KVCacheManager: self.block_size = kv_cache_config.kv_cache_groups[ 0].kv_cache_spec.block_size + if dcp_world_size > 1: + assert len(kv_cache_config.kv_cache_groups) == 1 + # Note(hc): need revisit. When both DCP and any future + # PCP are enabled, the block_size may need to be scaled + # by a factor of dcp_size × pcp_size? + self.block_size *= dcp_world_size + self.coordinator = get_kv_cache_coordinator( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, use_eagle=self.use_eagle, enable_caching=self.enable_caching, enable_kv_cache_events=enable_kv_cache_events, + dcp_world_size=dcp_world_size, ) self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) self.block_pool = self.coordinator.block_pool diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 590baa6208d07..9fab36aba91b3 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -5,12 +5,13 @@ import os from collections import defaultdict, deque from collections.abc import Iterable, Sequence -from dataclasses import astuple, dataclass -from typing import Any, Callable, NamedTuple, Optional +from dataclasses import dataclass +from typing import Any, Callable, NewType, Optional, Union +from vllm import envs from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.utils import GiB_bytes, cdiv, sha256_cbor_64bit +from vllm.utils import GiB_bytes, cdiv, sha256_cbor from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, @@ -18,59 +19,78 @@ from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request -logger = init_logger(__name__) +# BlockHash represents the hash of a single KV-cache block used for +# prefix caching. Treating it as a distinct type from ``bytes`` helps +# catch accidental misuse when passing around raw byte strings. +BlockHash = NewType("BlockHash", bytes) + +# ``BlockHashWithGroupId`` combines a ``BlockHash`` with its KV cache group ID. +# It is represented as raw bytes for compactness and efficiency. The helper +# functions below pack/unpack the ``BlockHash`` and group id into/from the key. +BlockHashWithGroupId = NewType("BlockHashWithGroupId", bytes) + +# ExternalBlockHash is used for reproducible prefix-cache block hashing. +# It's a union of ``bytes`` and ``int`` to keep backward compatibility +# after we default block hashing to use sha256 bytes. +ExternalBlockHash = Union[bytes, int] -class BlockHash(NamedTuple): - """Hash value of a block (int), the token IDs in the block, and extra keys. - We keep a tuple of token IDs and extra keys to reduce the likelihood of - hash collisions when the hash value is the same. By using SHA256 however, - hash collisions are practically impossible. +def make_block_hash_with_group_id(block_hash: BlockHash, + group_id: int) -> BlockHashWithGroupId: + """Pack a ``BlockHash`` and group id into a ``BlockHashWithGroupId``. + + The group id is encoded using 4 bytes in big-endian order and appended to + the block hash bytes. This representation avoids creating tuples while + still allowing us to recover both components when needed. """ - # Hash value of the block in an integer. 
- hash_value: int - # Token IDs in the block. - token_ids: tuple[int, ...] - # Extra keys for the block. - extra_keys: Optional[Any] = None + return BlockHashWithGroupId(block_hash + + group_id.to_bytes(4, "big", signed=False)) -class BlockHashWithGroupId(NamedTuple): - # The hash value for the contents (e.g., token_ids) of a block without group - # ID. The value is the same for blocks representing the same tokens but for - # different groups. - block_hash: BlockHash - # The KV cache group ID. - group_id: int +def get_block_hash(key: BlockHashWithGroupId) -> BlockHash: + """Extract the ``BlockHash`` from a ``BlockHashWithGroupId``.""" + return BlockHash(key[:-4]) - def get_hash_value(self) -> int: - return self.block_hash.hash_value +def get_group_id(key: BlockHashWithGroupId) -> int: + """Extract the group id from a ``BlockHashWithGroupId``.""" + return int.from_bytes(key[-4:], "big", signed=False) + + +def maybe_convert_block_hash(hash_bytes: BlockHash) -> ExternalBlockHash: + if not envs.VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: + return hash_bytes + return int.from_bytes(hash_bytes, byteorder="big") & ((1 << 64) - 1) + + +logger = init_logger(__name__) # The hash seed for the first block of any prefix block sequence. # # We use a random value to avoid hash collisions or PYTHONHASHSEED environment -# variable if set such that processes can share the seed if needed. -# This aligns with the behavior of Python's hash() function, which also uses -# a random seed if PYTHONHASHSEED is not set. +# variable if set such that processes can share the seed if needed. This aligns +# with the behavior of Python's hash() function, which also uses a random seed +# if PYTHONHASHSEED is not set. # # The function `init_none_hash` initializes this variable globally. -NONE_HASH: int +NONE_HASH: BlockHash -def init_none_hash(hash_fn: Callable): +def init_none_hash(hash_fn: Callable[[Any], bytes]): global NONE_HASH hash_seed = os.getenv("PYTHONHASHSEED") - if hash_seed is None and hash_fn is sha256_cbor_64bit: + if hash_seed is None and hash_fn is sha256_cbor: logger.warning( "PYTHONHASHSEED is not set. This will lead to non-reproducible " - "block-hashes when using sha256_cbor_64bit as the hash function." + "block-hashes when using sha256_cbor as the hash function." "Consider setting PYTHONHASHSEED to a fixed value for " "reproducibility.") - NONE_HASH = (int.from_bytes(os.urandom(32), byteorder="big") - if hash_seed is None else hash_fn(hash_seed)) + if hash_seed is None: + NONE_HASH = BlockHash(os.urandom(32)) + else: + NONE_HASH = BlockHash(hash_fn(hash_seed)) class PrefixCachingMetrics: @@ -96,8 +116,8 @@ class PrefixCachingMetrics: This function is called with information gathered when new requests are being scheduled and are looking for computed blocks. - When there are more than `interval` requests, the oldest set of - requests are removed from the metrics. + When there are more than `max_recent_requests` requests, the oldest set + of requests are removed from the metrics. Args: stats: The prefix cache stats. @@ -142,8 +162,8 @@ class KVCacheBlock: block_id: int # Reference count. ref_cnt: int = 0 - # The hash of the block composed of (block hash, tuple of token IDs). - # It is only available when the block is full. + # The hash key (block hash + group id) of the block, only available + # when the block is full and cached. _block_hash: Optional[BlockHashWithGroupId] = None # Used to construct a doubly linked list for free blocks. 
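To make the new byte layout concrete, a self-contained restatement of the packing helpers defined above, with hashlib.sha256 standing in for the real block-hash function:

import hashlib

block_hash = hashlib.sha256(b"example block").digest()   # stand-in BlockHash
group_id = 3

# make_block_hash_with_group_id: append the 4-byte big-endian group id.
key = block_hash + group_id.to_bytes(4, "big", signed=False)

# get_block_hash / get_group_id recover both components from the raw bytes.
assert key[:-4] == block_hash
assert int.from_bytes(key[-4:], "big", signed=False) == group_id

# maybe_convert_block_hash (int mode): keep only the low 64 bits so KV cache
# events can carry a compact integer instead of the full digest.
as_int = int.from_bytes(block_hash, byteorder="big") & ((1 << 64) - 1)
assert 0 <= as_int < 1 << 64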
@@ -177,7 +197,7 @@ class KVCacheBlock: if self.next_free_block else None) return (f"KVCacheBlock(block_id={self.block_id}, " f"ref_cnt={self.ref_cnt}, " - f"_block_hash={self._block_hash}, " + f"_block_hash={self._block_hash!r}, " f"prev_free_block={prev_block_id}, " f"next_free_block={next_block_id})") @@ -217,7 +237,7 @@ class FreeKVCacheBlockQueue: # Create a fake head and a tail block for the doubly linked list to # reduce branching in the code # - # The implementation garenteed that the fake head and tail + # The implementation guaranteed that the fake head and tail # are NEVER got popped, so we could safely assume each real blocks # in the queue has prev and next blocks. self.fake_free_list_head = KVCacheBlock(block_id=-1) @@ -350,7 +370,6 @@ class FreeKVCacheBlockQueue: """ if len(blocks) == 0: return - self.num_free_blocks += len(blocks) last_block = self.fake_free_list_tail.prev_free_block assert last_block is not None, ( @@ -365,6 +384,8 @@ class FreeKVCacheBlockQueue: last_block.next_free_block = self.fake_free_list_tail self.fake_free_list_tail.prev_free_block = last_block + self.num_free_blocks += len(blocks) + def get_all_free_blocks(self) -> list[KVCacheBlock]: """Get all free blocks in the free list. Mainly used for testing. @@ -398,9 +419,9 @@ def need_extra_keys(request: Request) -> bool: # Multimodal requests need to include the MM hash. # LoRA requests need to include the LoRA ID. # Request with provided cache salt need to include the salt. - return bool(request.mm_hashes) or (request.lora_request - is not None) or (request.cache_salt - is not None) + return bool(request.mm_features) or (request.lora_request + is not None) or (request.cache_salt + is not None) def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, @@ -422,32 +443,28 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, """ extra_keys: list[Any] = [] - mm_positions, mm_hashes = request.mm_positions, request.mm_hashes - if not mm_positions: + mm_features = request.mm_features + if not mm_features: return extra_keys, start_mm_idx - if mm_positions and len(mm_positions) != len(mm_hashes): - raise ValueError( - "The number of multi-modal positions and hashes must match. This " - "is likely because you did not enable MM hashing. " - "Please set `mm_processor_cache_gb > 0`.") - - # Note that we assume mm_positions is sorted by offset. + # Note that we assume mm_features are sorted by mm_position.offset. # We do not need to check all mm inputs if the start token index is out of # range. This usually happens in the late prefill phase and decoding phase. - if mm_positions[-1].offset + mm_positions[-1].length < start_token_idx: + last_pos = mm_features[-1].mm_position + if last_pos.offset + last_pos.length < start_token_idx: return extra_keys, start_mm_idx # Support start_mm_idx == -1 to indicate the last mm input. 
if start_mm_idx < 0: - assert -start_mm_idx <= len(mm_positions) - start_mm_idx = len(mm_positions) + start_mm_idx + assert -start_mm_idx <= len(mm_features) + start_mm_idx = len(mm_features) + start_mm_idx curr_mm_idx = start_mm_idx - while mm_positions and curr_mm_idx < len(mm_positions): - assert mm_hashes[curr_mm_idx] is not None - offset = mm_positions[curr_mm_idx].offset - length = mm_positions[curr_mm_idx].length + while mm_features and curr_mm_idx < len(mm_features): + mm_feature = mm_features[curr_mm_idx] + assert mm_feature.identifier is not None + offset = mm_feature.mm_position.offset + length = mm_feature.mm_position.length if end_token_idx > offset: if start_token_idx > offset + length: # This block has passed the current mm input. @@ -455,7 +472,7 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, continue # The block contains the current mm input. - extra_keys.append(mm_hashes[curr_mm_idx]) + extra_keys.append(mm_feature.identifier) if end_token_idx >= offset + length: # If this block contains the end of the current mm input, @@ -517,15 +534,14 @@ def generate_block_hash_extra_keys( def hash_block_tokens( - hash_function: Callable, - parent_block_hash: Optional[int], + hash_function: Callable[[Any], bytes], + parent_block_hash: Optional[BlockHash], curr_block_token_ids: Sequence[int], extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHash: """Computes a hash value corresponding to the contents of a block and the contents of the preceding block(s). The hash value is used for prefix caching. We use LRU cache for this function to avoid recomputing hash values for the same block contents. - Args: hash_function: The hash function used to compute block hash. parent_block_hash: The hash of the parent block. None @@ -533,7 +549,6 @@ def hash_block_tokens( curr_block_token_ids: A list of token ids in the current block. The current block is assumed to be full. extra_keys: Extra keys for the block. - Returns: The hash value of the block and the token ids in the block. The entire tuple is used as the hash key of the block. @@ -544,26 +559,16 @@ def hash_block_tokens( curr_block_token_ids_tuple = tuple(curr_block_token_ids) return BlockHash( hash_function( - (parent_block_hash, curr_block_token_ids_tuple, extra_keys)), - curr_block_token_ids_tuple, extra_keys) + (parent_block_hash, curr_block_token_ids_tuple, extra_keys))) def get_request_block_hasher( block_size: int, - caching_hash_fn: Callable[[Any], - int]) -> Callable[[Request], list[BlockHash]]: + caching_hash_fn: Callable[[Any], bytes], +) -> Callable[[Request], list[BlockHash]]: """ Returns a function which computes the list of un-computed block hashes - of a request. - - Each request holds a list of its block hashes (request.block_hashes). - When a request is created, it calls the below function to compute - the hashes of all full blocks of the request's initial tokens. - The hashes are then stored in request.block_hashes. - Later, whenever new tokens are appended to the request, it calls - the below function again to compute any new full blocks of tokens. - The returned new hashes are appended to request.block_hashes. - """ + of a request.""" def request_block_hasher(request: Request) -> list[BlockHash]: start_token_idx = len(request.block_hashes) * block_size @@ -577,8 +582,8 @@ def get_request_block_hasher( # last mm input. 
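The hash_block_tokens and get_request_block_hasher changes above keep the same chained-hash structure, now over raw bytes. A minimal sketch of that chaining; the toy hash below is only a stand-in for sha256_cbor:

import hashlib
from typing import Any, Optional

def toy_hash(obj: Any) -> bytes:
    return hashlib.sha256(repr(obj).encode()).digest()

def toy_hash_block_tokens(parent: Optional[bytes],
                          token_ids: tuple[int, ...],
                          extra_keys: Any = None) -> bytes:
    # Folding the parent hash in means equal hashes imply an equal prefix of
    # tokens, not just an equal block.
    return toy_hash((parent, token_ids, extra_keys))

block_size = 4
tokens = list(range(10))
prev: Optional[bytes] = None
hashes: list[bytes] = []
for start in range(0, len(tokens) - block_size + 1, block_size):
    prev = toy_hash_block_tokens(prev, tuple(tokens[start:start + block_size]))
    hashes.append(prev)
assert len(hashes) == 2   # only full blocks of 4 tokens are hashed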
curr_mm_idx = -1 - prev_block_hash_value = request.block_hashes[-1].hash_value \ - if request.block_hashes else None + prev_block_hash_value = (request.block_hashes[-1] + if request.block_hashes else None) new_block_hashes: list[BlockHash] = [] while True: end_token_idx = start_token_idx + block_size @@ -598,7 +603,7 @@ def get_request_block_hasher( new_block_hashes.append(block_hash) start_token_idx += block_size - prev_block_hash_value = block_hash.hash_value + prev_block_hash_value = block_hash return new_block_hashes @@ -749,6 +754,10 @@ def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: True if all layers have the same type, False otherwise. """ + if not kv_cache_spec: + # Encoder-only models do not have KV cache, kv_cache_type can be + # regarded as uniform. + return True try: kv_cache_spec_values = list(kv_cache_spec.values()) _ = kv_cache_spec_values[0].merge(kv_cache_spec_values) @@ -807,53 +816,21 @@ def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int: return page_sizes.pop() -def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, - kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int) -> KVCacheConfig: +def _get_kv_cache_groups_uniform_type( + kv_cache_specs: dict[str, KVCacheSpec]) -> list[KVCacheGroupSpec]: """ Generates the KV cache configuration for a model with one type of KV cache. Divide the available memory equally among all layers. Args: - vllm_config: The global VllmConfig - kv_cache_spec: The kv cache spec of each attention layer in the model - available_memory: Memory available for KV cache in bytes. + kv_cache_specs: The kv cache spec of each attention layer in the model Returns: - The generated KVCacheConfig + The generated KVCacheGroupSpecs """ - page_size = get_uniform_page_size(kv_cache_spec) - num_blocks = get_num_blocks(vllm_config, len(kv_cache_spec), - available_memory, page_size) - - per_layer_size = page_size * num_blocks - # All layers have the same KV cache spec, so we create one kv cache group - # for all layers. - grouped_layer_names = [list(kv_cache_spec.keys())] - - # Each layer uses a separate Tensor to store its KV cache. 
- kv_cache_tensors = [ - KVCacheTensor(size=per_layer_size, shared_by=[layer_name]) - for layer_name in kv_cache_spec - ] - - kv_cache_config = KVCacheConfig( - num_blocks=num_blocks, - kv_cache_tensors=kv_cache_tensors, - kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec, - grouped_layer_names), - ) - - num_tokens = num_blocks * vllm_config.cache_config.block_size - num_tokens_str = f"{num_tokens:,}" - logger.info("GPU KV cache size: %s tokens", num_tokens_str) - max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" - max_concurrency = get_max_concurrency_for_kv_cache_config( - vllm_config, kv_cache_config) - logger.info("Maximum concurrency for %s tokens per request: %.2fx", - max_model_len_str, max_concurrency) - return kv_cache_config + return create_kv_cache_group_specs(kv_cache_specs, + [list(kv_cache_specs.keys())]) def is_kv_cache_page_size_uniform( @@ -878,11 +855,10 @@ def is_kv_cache_type_attention_free( return not kv_cache_spec -def _get_kv_cache_config_uniform_page_size( - vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int) -> KVCacheConfig: +def _get_kv_cache_groups_uniform_page_size( + kv_cache_spec: dict[str, KVCacheSpec]) -> list[KVCacheGroupSpec]: """ - Generates the KV cache configuration for hybrid models with multiple + Generates the KV cache groups for hybrid models with multiple attention types but still with a uniform page size (physical memory per block per layer) for all layers. @@ -939,11 +915,9 @@ def _get_kv_cache_config_uniform_page_size( memory per block is the same for all groups. Args: - vllm_config: The global VllmConfig kv_cache_spec: The KVCacheSpec of each attention layer in the model - available_memory: Memory available for KV cache in bytes. Returns: - The generated KVCacheConfig + The generated KVCacheGroupSpecs """ # Group all layers by kv_cache_spec. # E.g., 2 full attention layers and 3 sliding window attention layers, @@ -956,7 +930,7 @@ def _get_kv_cache_config_uniform_page_size( # group identical. Add padding to the last group of each type if necessary. # E.g., (full.0, full.1), (sw.0, sw.1, sw.2) # split to 3 groups with 2 layers each: - # (full.0, full.1), (sw.0, sw.1), (sw.2, padding). + # (full.0, full.1), (sw.0, sw.2), (sw.1, padding). # FIXME(Chen): At the moment of writing this code (2025-06-02), all # open-source hybrid model follows a n:1 pattern between different attention # types (e.g., Gemma3 5:1 between sw and full, LLaMA4 3:1 between local and @@ -974,19 +948,60 @@ def _get_kv_cache_config_uniform_page_size( num_padding_layers, num_padding_layers / len(layers) * 100, ) - for i in range(0, len(layers), group_size): - grouped_layers.append(layers[i:i + group_size]) - kv_cache_groups = create_kv_cache_group_specs(kv_cache_spec, - grouped_layers) + num_groups = cdiv(len(layers), group_size) + # In PP case, say if we have + # - stage 0: full.0, sw.0, sw.1 + # - stage 1: full.1, sw.2, sw.3 + # We should have 3 groups: (full.0, full.1), (sw.0, sw.2), (sw.1, sw.3) + # It can't be (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3) because + # the 3 groups in stage 0 will be (full.0), (sw.0, sw.1), (empty group) + # and it will be padded to (full.0, padding), (sw.0, sw.1), + # (padding, padding) to ensure the number of layers in each group is + # the same and will cause memory waste. 
+ # To avoid this, we assign layers[i::num_groups] to the i-th group + # instead of layers[i * group_size: (i + 1) * group_size] + for i in range(num_groups): + grouped_layers.append(layers[i::num_groups]) + return create_kv_cache_group_specs(kv_cache_spec, grouped_layers) + + +def get_kv_cache_config_from_groups(vllm_config: VllmConfig, + kv_cache_groups: list[KVCacheGroupSpec], + kv_cache_specs: dict[str, KVCacheSpec], + available_memory: int) -> KVCacheConfig: + """ + Generate the KV cache configuration from the KV cache groups and spec + of each layer. + + Args: + vllm_config: The global VllmConfig + kv_cache_groups: The KV cache groups + kv_cache_specs: The KV cache spec of each attention layer in the model + available_memory: Memory available for KV cache in bytes + Returns: + The generated KVCacheConfig + """ + if len(kv_cache_groups) == 0: + # Attention free models do not have KV cache. + # Return num_blocks=1 as BlockPool always needs a null_block. + return KVCacheConfig( + num_blocks=1, + kv_cache_tensors=[], + kv_cache_groups=kv_cache_groups, + ) # Determine how model runners should initialize the KV cache tensors. # We will have group_size memory pools, each is shared by one layer from # each group. As layers of different groups have different block table, # they will use different parts of the shared Tensor. - # The memory layout in the example will be: - # full.0, sw.0, sw.2: share a Tensor with size=available_memory//2 - # full.1, sw.1: share another Tensor with size=available_memory//2 - page_size = get_uniform_page_size(kv_cache_spec) + # The memory layout for 3 groups (full.0, full.1), (sw.0, sw.2), + # (sw.1, padding) will be: (group_size = 2) + # full.0, sw.0, sw.1: share a Tensor with size=available_memory//2 + # full.1, sw.2: share another Tensor with size=available_memory//2 + group_size = max(len(group.layer_names) for group in kv_cache_groups) + + page_size = get_uniform_page_size(kv_cache_specs) + assert group_size > 0, "group_size must be greater than 0" num_blocks = get_num_blocks(vllm_config, group_size, available_memory, page_size) per_memory_pool_size = page_size * num_blocks @@ -994,8 +1009,8 @@ def _get_kv_cache_config_uniform_page_size( for i in range(group_size): shared_by = [] for j in range(len(kv_cache_groups)): - if i < len(grouped_layers[j]): - shared_by.append(grouped_layers[j][i]) + if i < len(kv_cache_groups[j].layer_names): + shared_by.append(kv_cache_groups[j].layer_names[i]) kv_cache_tensors.append( KVCacheTensor(size=per_memory_pool_size, shared_by=shared_by)) @@ -1009,7 +1024,12 @@ def _get_kv_cache_config_uniform_page_size( [group.kv_cache_spec.block_size for group in kv_cache_groups]) # Print the KV cache size and maximum concurrency. 
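A small sketch contrasting the old contiguous chunking with the strided assignment adopted above; layer names are illustrative:

layers = ["sw.0", "sw.1", "sw.2", "sw.3"]
num_groups, group_size = 2, 2

# Old: contiguous chunks.
contiguous = [layers[i * group_size:(i + 1) * group_size]
              for i in range(num_groups)]

# New: strided assignment. A PP stage holding only sw.0 and sw.1 now
# contributes one layer to each group instead of filling one group and
# leaving the other to be padded.
strided = [layers[i::num_groups] for i in range(num_groups)]

assert contiguous == [["sw.0", "sw.1"], ["sw.2", "sw.3"]]
assert strided == [["sw.0", "sw.2"], ["sw.1", "sw.3"]]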
- num_tokens = num_blocks // len(grouped_layers) * min_block_size + num_tokens = num_blocks // len(kv_cache_groups) * min_block_size + if vllm_config.parallel_config.decode_context_parallel_size > 1: + num_tokens *= vllm_config.parallel_config.decode_context_parallel_size + logger.info( + "Multiplying the GPU KV cache size by the dcp_world_size %d.", + vllm_config.parallel_config.decode_context_parallel_size) num_tokens_str = f"{num_tokens:,}" logger.info("GPU KV cache size: %s tokens", num_tokens_str) max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" @@ -1020,10 +1040,6 @@ def _get_kv_cache_config_uniform_page_size( return kv_cache_config -def _get_kv_cache_config_attention_free() -> KVCacheConfig: - return KVCacheConfig(num_blocks=1, kv_cache_tensors=[], kv_cache_groups=[]) - - def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): """ This function tries to convert the KV cache specs to one type if the model @@ -1077,72 +1093,112 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): "convert the KV cache specs to one unified type.") -def get_kv_cache_config( - vllm_config: VllmConfig, - kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int, -) -> KVCacheConfig: +def get_kv_cache_groups( + vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec]) -> list[KVCacheGroupSpec]: """ - Generates the KV cache configuration for a model. + Split the layers in the model into groups with the same KV cache spec. Args: vllm_config: The global VllmConfig kv_cache_spec: The kv cache spec of each attention layer in the model - available_memory: Memory available for KV cache in bytes. Returns: - The generated KVCacheConfigs + The generated KVCacheGroups """ - check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory) if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager: unify_hybrid_kv_cache_specs(kv_cache_spec) if is_kv_cache_type_attention_free(kv_cache_spec): - # This returns a kv_cache config with 0 kv_cache groups and 1 block - # to allow for the KVCache manager to handle attention free models. - return _get_kv_cache_config_attention_free() + # This returns an empty list to allow for the KVCacheManager to handle + # attention free models. + return [] elif is_kv_cache_type_uniform(kv_cache_spec): # KV cache of all layers are the same, which is true for # most models. Allocate the same amount of memory for # each layer. - return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec, - available_memory) + return _get_kv_cache_groups_uniform_type(kv_cache_spec) elif is_kv_cache_page_size_uniform(kv_cache_spec): # Model contains multiple attention types, but KV cache of all layers # have the same physical memory per block per layer. Split the layers # into groups with the same number of layers, and thus same total page # size. - return _get_kv_cache_config_uniform_page_size(vllm_config, - kv_cache_spec, - available_memory) + return _get_kv_cache_groups_uniform_page_size(kv_cache_spec) raise NotImplementedError -def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]): +def get_kv_cache_configs(vllm_config: VllmConfig, + kv_cache_specs: list[dict[str, KVCacheSpec]], + available_memory: list[int]) -> list[KVCacheConfig]: """ - Make the KV cache configurations for each worker consistent, so that all - workers can be controlled by the same KVCacheManager. 
- This function verifies that the layer group of each worker are the same, - and changes the num_blocks of each worker to the smallest among all workers. + Generates the KV cache configurations for a model. + Since we use a shared centralized controller for all workers, we need the + `kv_cache_config` to be consistent across all workers to make sure + the KV cache allocation can be applied to all workers. However, different + workers may have different memory available, and different type of layers + (when pipeline parallel is enabled). To handle the difference between + workers, the current implementation is: + 1. Merge the KV cache specs of all workers to get the KVCacheSpecs for + the whole model. + 2. Generate the KV cache groups based on the layer ratio of the whole model. + 3. Generate the KV cache configs for each worker based on the KV cache + grouping strategy. (This is reasonable because the layer ratio of + different PP stages are similar.) + 4. Change the num_blocks of each worker to the smallest among all workers. Args: - kv_cache_configs: The KV cache configurations for each worker. Will be - in-place modified to make them consistent. + vllm_config: The global VllmConfig + kv_cache_specs: List of dict[layer_name, KVCacheSpec] for each worker. + available_memory: Memory available for KV cache in bytes for each + worker. + + Returns: + The generated KVCacheConfigs for each worker. """ - # Sort the kv cache groups by their KV cache spec. - # This can avoid the inconsistency caused by the order of groups. - for kv_cache_config in kv_cache_configs: - kv_cache_config.kv_cache_groups.sort(key=lambda x: (type( - x.kv_cache_spec).__name__, astuple(x.kv_cache_spec))) + # Check if the available memory is enough for each worker. + for kv_cache_spec_one_worker, available_memory_one_worker in zip( + kv_cache_specs, available_memory): + check_enough_kv_cache_memory(vllm_config, kv_cache_spec_one_worker, + available_memory_one_worker) - # Verify that the groups of each rank are the same. - for kv_cache_config in kv_cache_configs[1:]: - for group_rank_0, group_rank_i in zip( - kv_cache_configs[0].kv_cache_groups, - kv_cache_config.kv_cache_groups): - assert group_rank_0.kv_cache_spec == group_rank_i.kv_cache_spec + # Merge the KV cache specs of all workers. Different PP stages may have + # different layer names, and different TP ranks of the same PP stage should + # have the same KV cache spec. + merged_kv_cache_specs: dict[str, KVCacheSpec] = {} + for kv_cache_spec_one_worker in kv_cache_specs: + for layer_name, layer_spec in kv_cache_spec_one_worker.items(): + if layer_name not in merged_kv_cache_specs: + merged_kv_cache_specs[layer_name] = layer_spec + else: + assert merged_kv_cache_specs[layer_name] == layer_spec, ( + "The KV cache specs for the same layer are different " + "across workers. 
This is not supported yet.") + global_kv_cache_groups = get_kv_cache_groups(vllm_config, + merged_kv_cache_specs) + + kv_cache_configs: list[KVCacheConfig] = [] + for kv_cache_spec_one_worker, available_memory_one_worker in zip( + kv_cache_specs, available_memory): + kv_cache_groups_one_worker: list[KVCacheGroupSpec] = [] + for group in global_kv_cache_groups: + group_layer_names_one_worker = [ + layer_name for layer_name in group.layer_names + if layer_name in kv_cache_spec_one_worker + ] + kv_cache_groups_one_worker.append( + KVCacheGroupSpec(group_layer_names_one_worker, + group.kv_cache_spec)) + assert sum( + len(group.layer_names) for group in + kv_cache_groups_one_worker) == len(kv_cache_spec_one_worker), ( + "Some layers are not assigned to any group.") + kv_cache_configs.append( + get_kv_cache_config_from_groups(vllm_config, + kv_cache_groups_one_worker, + kv_cache_spec_one_worker, + available_memory_one_worker)) # Change the num_blocks of each rank to the smallest among all ranks. We # do not need to shrink the tensor size because it is valid to only use the diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index b5cd6c5c8af51..3ec5b91bf2860 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -6,6 +6,8 @@ from __future__ import annotations from dataclasses import dataclass from typing import TYPE_CHECKING, Optional +from vllm._bc_linter import bc_linter_include + if TYPE_CHECKING: import numpy as np import numpy.typing as npt @@ -13,20 +15,19 @@ if TYPE_CHECKING: from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorMetadata) from vllm.lora.request import LoRARequest - from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange + from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.v1.request import Request +@bc_linter_include @dataclass class NewRequestData: req_id: str prompt_token_ids: list[int] - mm_kwargs: list[MultiModalKwargsItem] - mm_hashes: list[str] - mm_positions: list[PlaceholderRange] + mm_features: list[MultiModalFeatureSpec] sampling_params: Optional[SamplingParams] pooling_params: Optional[PoolingParams] block_ids: tuple[list[int], ...] 
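A condensed sketch of the worker-handling strategy that get_kv_cache_configs now follows, with specs reduced to plain strings for illustration: merge per-worker specs, group once over the whole model, then keep only each worker's own layers per group.

def merge_specs(per_worker: list[dict[str, str]]) -> dict[str, str]:
    merged: dict[str, str] = {}
    for specs in per_worker:
        for layer, spec in specs.items():
            # The same layer must have the same spec on every worker.
            assert merged.setdefault(layer, spec) == spec
    return merged

workers = [{"full.0": "full", "sw.0": "sw"},   # PP stage 0
           {"full.1": "full", "sw.1": "sw"}]   # PP stage 1
merged = merge_specs(workers)

global_groups = [[l for l in merged if merged[l] == "full"],
                 [l for l in merged if merged[l] == "sw"]]
per_worker_groups = [[[l for l in grp if l in w] for grp in global_groups]
                     for w in workers]
assert per_worker_groups == [[["full.0"], ["sw.0"]], [["full.1"], ["sw.1"]]]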
@@ -42,9 +43,7 @@ class NewRequestData: return cls( req_id=request.request_id, prompt_token_ids=request.prompt_token_ids, - mm_kwargs=request.mm_kwargs, - mm_hashes=request.mm_hashes, - mm_positions=request.mm_positions, + mm_features=request.mm_features, sampling_params=request.sampling_params, pooling_params=request.pooling_params, block_ids=block_ids, @@ -56,9 +55,7 @@ class NewRequestData: return (f"NewRequestData(" f"req_id={self.req_id}," f"prompt_token_ids={self.prompt_token_ids}," - f"mm_kwargs={self.mm_kwargs}," - f"mm_hashes={self.mm_hashes}," - f"mm_positions={self.mm_positions}," + f"mm_features={self.mm_features}," f"sampling_params={self.sampling_params}," f"block_ids={self.block_ids}," f"num_computed_tokens={self.num_computed_tokens}," @@ -70,9 +67,7 @@ class NewRequestData: return (f"NewRequestData(" f"req_id={self.req_id}," f"prompt_token_ids_len={len(self.prompt_token_ids)}," - f"mm_kwargs={self.mm_kwargs}," - f"mm_hashes={self.mm_hashes}," - f"mm_positions={self.mm_positions}," + f"mm_features={self.mm_features}," f"sampling_params={self.sampling_params}," f"block_ids={self.block_ids}," f"num_computed_tokens={self.num_computed_tokens}," @@ -80,6 +75,7 @@ class NewRequestData: ")") +@bc_linter_include @dataclass class CachedRequestData: @@ -109,6 +105,7 @@ class CachedRequestData: ) +@bc_linter_include @dataclass class SchedulerOutput: diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 8322fa7335b69..85ca858ad7bd6 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -100,6 +100,15 @@ class Scheduler(SchedulerInterface): self.block_size = self.cache_config.block_size + self.dcp_world_size = \ + vllm_config.parallel_config.decode_context_parallel_size + # Note(hc): The scheduler’s block_size must be multiplied + # by dcp_world_size, since block hashes are computed on the + # original full token sequence at a granularity of + # original_block_size × dcp_world_size. + if self.dcp_world_size > 1: + self.block_size *= self.dcp_world_size + # req_id -> Request self.requests: dict[str, Request] = {} # Scheduling policy @@ -135,8 +144,8 @@ class Scheduler(SchedulerInterface): ) # NOTE(woosuk): Here, "encoder" includes the vision encoder (and - # projector if needed). Currently, we assume that the encoder also - # has the Transformer architecture (e.g., ViT). + # projector if needed) for MM models as well as encoder-decoder + # transformers. self.max_num_encoder_input_tokens = encoder_compute_budget # NOTE: For the models without encoder (e.g., text-only models), # the encoder cache will not be initialized because cache size is 0 @@ -161,6 +170,7 @@ class Scheduler(SchedulerInterface): use_eagle=self.use_eagle, log_stats=self.log_stats, enable_kv_cache_events=self.enable_kv_cache_events, + dcp_world_size=self.dcp_world_size, ) self.use_pp = self.parallel_config.pipeline_parallel_size > 1 @@ -377,6 +387,14 @@ class Scheduler(SchedulerInterface): self.connector.get_num_new_matched_tokens( request, num_new_local_computed_tokens)) + if num_external_computed_tokens is None: + # The request cannot be scheduled because + # the KVConnector couldn't determine + # the number of matched tokens. + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + # Total computed tokens (local + external). 
num_computed_tokens = (num_new_local_computed_tokens + num_external_computed_tokens) @@ -447,9 +465,8 @@ class Scheduler(SchedulerInterface): in self.vllm_config.model_config.model.lower()), ( "Whisper is the only supported " "encoder-decoder model.") - num_encoder_tokens = MULTIMODAL_REGISTRY.\ - get_encdec_max_encoder_len( - self.vllm_config.model_config) + num_encoder_tokens =\ + self.scheduler_config.max_num_encoder_input_tokens else: num_encoder_tokens = 0 @@ -718,18 +735,18 @@ class Scheduler(SchedulerInterface): if num_new_tokens == 0 or not request.has_encoder_inputs: return [], num_new_tokens, encoder_compute_budget encoder_inputs_to_schedule: list[int] = [] - mm_positions = request.mm_positions - assert mm_positions is not None - assert len(mm_positions) > 0 + mm_features = request.mm_features + assert mm_features is not None + assert len(mm_features) > 0 # NOTE: since scheduler operates on the request level (possibly with # multiple encoder inputs per request), we need to create temporary # trackers for accounting at the encoder input level. mm_hashes_to_schedule = set() num_tokens_to_schedule = 0 - for i, pos_info in enumerate(mm_positions): - start_pos = pos_info.offset - num_encoder_tokens = pos_info.length + for i, mm_feature in enumerate(mm_features): + start_pos = mm_feature.mm_position.offset + num_encoder_tokens = mm_feature.mm_position.length # The encoder output is needed if the two ranges overlap: # [num_computed_tokens, num_computed_tokens + num_new_tokens) and @@ -757,15 +774,19 @@ class Scheduler(SchedulerInterface): # in the decoder's KV cache. continue - # The same encoder input has already been scheduled in the current - # step. - if request.mm_hashes[i] in mm_hashes_to_schedule: - continue + if not self.is_encoder_decoder: + # We are not using the encoder cache for encoder-decoder models, + # yet. + if request.mm_features[i].identifier in mm_hashes_to_schedule: + # The same encoder input has already been scheduled in the + # current step. + continue - if self.encoder_cache_manager.check_and_update_cache(request, i): - # The encoder input is already computed and cached from a - # previous step. - continue + if self.encoder_cache_manager.check_and_update_cache( + request, i): + # The encoder input is already computed and cached from a + # previous step. + continue # If no encoder input chunking is allowed, we do not want to # partially schedule a multimodal item. If the scheduled range would @@ -798,7 +819,7 @@ class Scheduler(SchedulerInterface): num_tokens_to_schedule += num_encoder_tokens encoder_compute_budget -= num_encoder_tokens - mm_hashes_to_schedule.add(request.mm_hashes[i]) + mm_hashes_to_schedule.add(request.mm_features[i].identifier) encoder_inputs_to_schedule.append(i) return ( @@ -947,9 +968,9 @@ class Scheduler(SchedulerInterface): stop_reason=request.stop_reason, events=request.take_events(), kv_transfer_params=kv_transfer_params, + trace_headers=request.trace_headers, num_cached_tokens=request.num_cached_tokens, )) - else: # Invariant: EngineCore returns no partial prefill outputs. assert not prompt_logprobs_tensors @@ -1026,10 +1047,16 @@ class Scheduler(SchedulerInterface): # Here, we use list(set) to avoid modifying the set while iterating # over it. 
for input_id in list(cached_encoder_input_ids): - mm_positions = request.mm_positions[input_id] - start_pos = mm_positions.offset - num_tokens = mm_positions.length - if start_pos + num_tokens <= request.num_computed_tokens: + mm_feature = request.mm_features[input_id] + start_pos = mm_feature.mm_position.offset + num_tokens = mm_feature.mm_position.length + if self.is_encoder_decoder and request.num_computed_tokens > 0: + # With Whisper, as soon as we've generated a single token, + # we know we're done with the encoder input. Cross Attention + # KVs have been calculated and cached already. + self.encoder_cache_manager.free_encoder_input( + request, input_id) + elif start_pos + num_tokens <= request.num_computed_tokens: # The encoder output is already processed and stored # in the decoder's KV cache. self.encoder_cache_manager.free_encoder_input( @@ -1178,6 +1205,8 @@ class Scheduler(SchedulerInterface): def shutdown(self) -> None: if self.kv_event_publisher: self.kv_event_publisher.shutdown() + if self.connector is not None: + self.connector.shutdown() ######################################################################## # KV Connector Related Methods diff --git a/vllm/v1/core/sched/utils.py b/vllm/v1/core/sched/utils.py index 42d3e5c68b4c8..c431843de6baa 100644 --- a/vllm/v1/core/sched/utils.py +++ b/vllm/v1/core/sched/utils.py @@ -10,19 +10,19 @@ from vllm.v1.request import Request, RequestStatus def remove_all(lst: list, items_to_remove: set) -> list: """Remove all items from a list that are in the items_to_remove set. - + This method optimizes for the common case of removing a single item, falling back to list comprehension for multiple items. - + Args: lst: The list to remove items from items_to_remove: Set of items to remove - + Returns: Either the modified original list (for single item removal) or a new list (for multiple item removal). Callers should use the returned value. - + Note: For single item removal, this modifies the original list in-place and returns it. For multiple items, it creates and returns a new list. diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index f6affb3dab66f..d27239164b0db 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -25,6 +25,7 @@ class SingleTypeKVCacheManager(ABC): kv_cache_spec: KVCacheSpec, block_pool: BlockPool, kv_cache_group_id: int, + dcp_world_size: int = 1, ) -> None: """ Initializes the SingleTypeKVCacheManager. @@ -33,8 +34,10 @@ class SingleTypeKVCacheManager(ABC): block_pool: The block pool. kv_cache_group_id: The id of the kv cache group of this manager. 
""" - self.block_size = kv_cache_spec.block_size + self.dcp_world_size = dcp_world_size + if self.dcp_world_size > 1: + self.block_size *= dcp_world_size self.kv_cache_spec = kv_cache_spec self.block_pool = block_pool @@ -196,6 +199,7 @@ class SingleTypeKVCacheManager(ABC): block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: """ Get the longest cache hit prefix of the blocks that is not longer than @@ -253,6 +257,7 @@ class FullAttentionManager(SingleTypeKVCacheManager): block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance( kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec) @@ -260,7 +265,10 @@ class FullAttentionManager(SingleTypeKVCacheManager): "and chunked local attention groups" computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( [] for _ in range(len(kv_cache_group_ids))) - max_num_blocks = max_length // kv_cache_spec.block_size + block_size = kv_cache_spec.block_size + if dcp_world_size > 1: + block_size *= dcp_world_size + max_num_blocks = max_length // block_size for block_hash in itertools.islice(block_hashes, max_num_blocks): # block_hashes is a chain of block hashes. If a block hash is not # in the cached_block_hash_to_id, the following block hashes are @@ -310,9 +318,11 @@ class SlidingWindowManager(SingleTypeKVCacheManager): block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance(kv_cache_spec, SlidingWindowSpec), ( "SlidingWindowManager can only be used for sliding window groups") + assert dcp_world_size == 1, "DCP not support sliding window attn now." # The number of contiguous blocks needed for prefix cache hit. # -1 since the input token itself is also included in the window @@ -408,6 +418,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: """ For chunked local attention, we need to find the longest cache hit @@ -445,6 +456,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): "chunked local attention groups") assert use_eagle is False, ("Hybrid KV cache is not supported for " + "eagle + chunked local attention.") + assert dcp_world_size == 1, "DCP not support chunked local attn now." max_num_blocks = max_length // kv_cache_spec.block_size if max_length > 0: local_attention_start_idx = (max_length // @@ -525,10 +537,12 @@ class MambaManager(SingleTypeKVCacheManager): block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance( kv_cache_spec, MambaSpec), ("MambaManager can only be used for mamba groups") + assert dcp_world_size == 1, "DCP not support mamba now." # Prefix caching is not supported for mamba now. Always return empty # list. computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( @@ -545,12 +559,48 @@ class MambaManager(SingleTypeKVCacheManager): num_running_requests: int) -> int: return 0 + def get_num_blocks_to_allocate( + self, request_id: str, num_tokens: int, + new_computed_blocks: list[KVCacheBlock]) -> int: + """ + Get the number of blocks needed to be allocated for the request. + + Args: + request_id: The request ID. + num_tokens: The total number of tokens that need a slot (including + tokens that are already allocated). 
+ new_computed_blocks: The new computed blocks just hitting the + prefix caching. + + Returns: + The number of blocks + """ + + assert isinstance(self.kv_cache_spec, MambaSpec) + if self.kv_cache_spec.num_speculative_blocks > 0: + num_tokens += (self.kv_cache_spec.block_size * + self.kv_cache_spec.num_speculative_blocks) + num_required_blocks = cdiv(num_tokens, self.block_size) + num_new_blocks = (num_required_blocks - len(new_computed_blocks) - + len(self.req_to_blocks[request_id])) + # If a computed block of a request is an eviction candidate (in the + # free queue and ref_cnt == 0), it will be changed from a free block + # to a computed block when the request is allocated, so we also count + # it as needed to be allocated. + num_evictable_computed_blocks = sum( + blk.ref_cnt == 0 and not blk.is_null + for blk in new_computed_blocks) + return num_new_blocks + num_evictable_computed_blocks + def allocate_new_blocks(self, request_id: str, num_tokens: int) -> list[KVCacheBlock]: - new_blocks = super().allocate_new_blocks(request_id, num_tokens) - assert len(self.req_to_blocks[request_id]) == 1, ( - "MambaManager should only allocate 1 block for each request.") - return new_blocks + # Allocate extra `num_speculative_blocks` blocks for + # speculative decoding (MTP/EAGLE) with linear attention. + assert isinstance(self.kv_cache_spec, MambaSpec) + if self.kv_cache_spec.num_speculative_blocks > 0: + num_tokens += (self.kv_cache_spec.block_size * + self.kv_cache_spec.num_speculative_blocks) + return super().allocate_new_blocks(request_id, num_tokens) class CrossAttentionManager(SingleTypeKVCacheManager): @@ -583,6 +633,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager): block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance(kv_cache_spec, CrossAttentionSpec), ( "CrossAttentionManager can only be used for cross-attention groups" diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 5d8959a3cd3fe..dec4abec519bd 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -3,6 +3,7 @@ import enum import time +from collections.abc import Mapping from typing import Any, Optional, Union import msgspec @@ -66,6 +67,8 @@ class EngineCoreRequest( current_wave: int = 0 priority: int = 0 + trace_headers: Optional[Mapping[str, str]] = None + class EngineCoreEventType(enum.IntEnum): """The type of engine core request event.""" @@ -111,6 +114,7 @@ class EngineCoreOutput( events: Optional[list[EngineCoreEvent]] = None kv_transfer_params: Optional[dict[str, Any]] = None + trace_headers: Optional[Mapping[str, str]] = None # The number of tokens with prefix cache hits. num_cached_tokens: int = 0 @@ -144,7 +148,7 @@ class EngineCoreOutputs( omit_defaults=True, # type: ignore[call-arg] gc=False): # type: ignore[call-arg] - #NOTE(Nick): We could consider ways to make this more compact, + # NOTE(Nick): We could consider ways to make this more compact, # e.g. 
columnwise layout engine_index: int = 0 diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 2a9fa1fd9172c..73165c7e4c0ad 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -26,10 +26,11 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.tasks import SupportedTask +from vllm.tracing import init_tracer from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) -from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.transformers_utils.tokenizer import (AnyTokenizer, + init_tokenizer_from_configs) from vllm.usage.usage_lib import UsageContext from vllm.utils import (Device, as_list, cancel_task_threadsafe, cdiv, deprecate_kwargs) @@ -97,17 +98,21 @@ class AsyncLLM(EngineClient): self.model_config = vllm_config.model_config self.vllm_config = vllm_config + self.observability_config = vllm_config.observability_config self.log_requests = log_requests - self.log_stats = log_stats + + self.log_stats = log_stats or (stat_loggers is not None) + if not log_stats and stat_loggers is not None: + logger.info( + "AsyncLLM created with log_stats=False and non-empty custom " + "logger list; enabling logging without default stat loggers") if self.model_config.skip_tokenizer_init: self.tokenizer = None else: # Tokenizer (+ ensure liveness if running in another process). self.tokenizer = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - lora_config=vllm_config.lora_config) + model_config=vllm_config.model_config) # Processor (converts Inputs --> EngineCoreRequests). self.processor = Processor( @@ -119,6 +124,11 @@ class AsyncLLM(EngineClient): # OutputProcessor (converts EngineCoreOutputs --> RequestOutput). self.output_processor = OutputProcessor(self.tokenizer, log_stats=self.log_stats) + if self.observability_config.otlp_traces_endpoint is not None: + tracer = init_tracer( + "vllm.llm_engine", + self.observability_config.otlp_traces_endpoint) + self.output_processor.tracer = tracer # EngineCore (starts the engine in background process). self.engine_core = EngineCoreClient.make_async_mp_client( @@ -137,6 +147,8 @@ class AsyncLLM(EngineClient): vllm_config=vllm_config, engine_idxs=self.engine_core.engine_ranks_managed, custom_stat_loggers=stat_loggers, + enable_default_loggers=log_stats, + client_count=client_count, ) self.logger_manager.log_engine_initialized() @@ -163,9 +175,6 @@ class AsyncLLM(EngineClient): worker_name=worker_name, use_gzip=True)) else: - logger.info( - "Torch profiler disabled. AsyncLLM CPU traces will not be collected." 
# noqa: E501 - ) self.profiler = None @classmethod @@ -579,24 +588,18 @@ class AsyncLLM(EngineClient): async def get_model_config(self) -> ModelConfig: return self.model_config - async def get_decoding_config(self): - raise ValueError("Not Supported on V1 yet.") - async def get_input_preprocessor(self) -> InputPreprocessor: return self.processor.input_preprocessor - async def get_tokenizer( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: + async def get_tokenizer(self) -> AnyTokenizer: if self.tokenizer is None: raise ValueError("Unable to get tokenizer because " "skip_tokenizer_init is True") - return self.tokenizer.get_lora_tokenizer(lora_request) + return self.tokenizer async def is_tracing_enabled(self) -> bool: - return False + return self.observability_config.otlp_traces_endpoint is not None async def do_log_stats( self, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 922c06b44be88..a022e9c0d7058 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import gc import os import queue import signal @@ -22,16 +23,15 @@ from vllm.logger import init_logger from vllm.logging_utils.dump_input import dump_engine_exception from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.cache import receiver_cache_from_config +from vllm.multimodal.cache import engine_receiver_cache_from_config from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) from vllm.utils import (decorate_logs, get_hash_fn_by_name, make_zmq_socket, resolve_obj_by_qualname, set_process_title) -from vllm.v1.core.kv_cache_utils import (BlockHash, get_kv_cache_config, +from vllm.v1.core.kv_cache_utils import (BlockHash, get_kv_cache_configs, get_request_block_hasher, - init_none_hash, - unify_kv_cache_configs) + init_none_hash) from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler @@ -128,9 +128,12 @@ class EngineCore: log_stats=self.log_stats, ) self.use_spec_decode = vllm_config.speculative_config is not None + if self.scheduler.connector is not None: # type: ignore + self.model_executor.init_kv_output_aggregator( + self.scheduler.connector.get_finished_count()) # type: ignore self.mm_registry = mm_registry = MULTIMODAL_REGISTRY - self.mm_receiver_cache = receiver_cache_from_config( + self.mm_receiver_cache = engine_receiver_cache_from_config( vllm_config, mm_registry) # Setup batch queue for pipeline parallelism. 
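The AsyncLLM tracing changes above amount to a simple toggle: a tracer exists only when an OTLP endpoint is configured, and is_tracing_enabled() reports exactly that. A reduced sketch with the tracer creation stubbed out (the real code calls init_tracer("vllm.llm_engine", endpoint)):

from typing import Optional

class TracingToggle:
    def __init__(self, otlp_traces_endpoint: Optional[str]) -> None:
        self.otlp_traces_endpoint = otlp_traces_endpoint
        # Stub: the real code builds an OpenTelemetry tracer here.
        self.tracer = object() if otlp_traces_endpoint is not None else None

    async def is_tracing_enabled(self) -> bool:
        return self.otlp_traces_endpoint is not None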
@@ -158,6 +161,9 @@ class EngineCore: self.request_block_hasher = get_request_block_hasher( block_size, caching_hash_fn) + self.step_fn = (self.step if self.batch_queue is None else + self.step_with_batch_queue) + def _initialize_kv_caches( self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]: start = time.time() @@ -187,18 +193,9 @@ class EngineCore: available_gpu_memory = [0] * len(kv_cache_specs) assert len(kv_cache_specs) == len(available_gpu_memory) - # Get the kv cache tensor size - kv_cache_configs = [ - get_kv_cache_config(vllm_config, kv_cache_spec_one_worker, - available_gpu_memory_one_worker) - for kv_cache_spec_one_worker, available_gpu_memory_one_worker in - zip(kv_cache_specs, available_gpu_memory) - ] - # Since we use a shared centralized controller, we need the - # `kv_cache_config` to be consistent across all workers to make sure - # all the memory operators can be applied to all workers. - unify_kv_cache_configs(kv_cache_configs) + kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs, + available_gpu_memory) # All workers have the same kv_cache_config except layer names, so use # an arbitrary one to initialize the scheduler. @@ -223,7 +220,7 @@ class EngineCore: def add_request(self, request: Request, request_wave: int = 0): """Add request to the scheduler. - + `request_wave`: indicate which wave of requests this is expected to belong to in DP case """ @@ -330,7 +327,8 @@ class EngineCore: model_executed = False if self.scheduler.has_requests(): scheduler_output = self.scheduler.schedule() - future = self.model_executor.execute_model(scheduler_output) + future = self.model_executor.execute_model(scheduler_output, + non_block=True) batch_queue.appendleft( (future, scheduler_output)) # type: ignore[arg-type] @@ -432,13 +430,13 @@ class EngineCore: def preprocess_add_request( self, request: EngineCoreRequest) -> tuple[Request, int]: """Preprocess the request. - + This function could be directly used in input processing thread to allow request initialization running in parallel with Model forward """ # Note on thread safety: no race condition. # `mm_receiver_cache` is reset at the end of LLMEngine init, - # and will only accessed in the input processing thread afterwards. + # and will only be accessed in the input processing thread afterwards. if self.mm_receiver_cache is not None and request.mm_features: request.mm_features = ( self.mm_receiver_cache.get_and_update_features( @@ -533,8 +531,10 @@ class EngineCoreProc(EngineCore): assert addresses.coordinator_input is not None logger.info("Waiting for READY message from DP Coordinator...") - self.step_fn = (self.step if self.batch_queue is None else - self.step_with_batch_queue) + # Mark the startup heap as static so that it's ignored by GC. + # Reduces pause times of oldest generation collections. + gc.collect() + gc.freeze() @contextmanager def _perform_handshakes( @@ -691,7 +691,7 @@ class EngineCoreProc(EngineCore): parallel_config: ParallelConfig = kwargs[ "vllm_config"].parallel_config if parallel_config.data_parallel_size > 1 or dp_rank > 0: - set_process_title("DPEngineCore", str(dp_rank)) + set_process_title("EngineCore", f"DP{dp_rank}") decorate_logs() # Set data parallel rank for this engine process. 
parallel_config.data_parallel_rank = dp_rank diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 65f7abc97110c..bb0f37c6e0264 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -245,8 +245,8 @@ class InprocClient(EngineCoreClient): self.engine_core = EngineCore(*args, **kwargs) def get_output(self) -> EngineCoreOutputs: - outputs, _ = self.engine_core.step() - return outputs.get(0) or EngineCoreOutputs() + outputs, _ = self.engine_core.step_fn() + return outputs and outputs.get(0) or EngineCoreOutputs() def get_supported_tasks(self) -> tuple[SupportedTask, ...]: return self.engine_core.get_supported_tasks() @@ -347,8 +347,9 @@ class BackgroundResources: if isinstance(self.output_socket, zmq.asyncio.Socket): # Async case. - loop = self.output_socket._get_loop() - asyncio.get_running_loop() + loop = self.output_queue_task._loop \ + if self.output_queue_task else None + sockets = (self.output_socket, self.input_socket, self.first_req_send_socket, self.first_req_rcv_socket, self.stats_update_socket) @@ -359,11 +360,12 @@ class BackgroundResources: close_sockets(sockets) for task in tasks: if task is not None and not task.done(): - task.cancel() + with contextlib.suppress(Exception): + task.cancel() if in_loop(loop): close_sockets_and_tasks() - elif not loop.is_closed(): + elif loop and not loop.is_closed(): loop.call_soon_threadsafe(close_sockets_and_tasks) else: # Loop has been closed, try to clean up directly. diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 04ad51aae0a8c..cf4b06db843bd 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -121,12 +121,9 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): self.output_token_ids) <= self.min_tokens: stop_check_offset = len(self.output_text) - if stop_terminated: - if skipped_stop_token_id is not None: - # Cleanup after skipping detokenization. - self.token_ids.append(skipped_stop_token_id) - # Stop token triggered; skip stop string check. - return None + if skipped_stop_token_id is not None: + # Cleanup after skipping detokenization. + self.token_ids.append(skipped_stop_token_id) # 2) Evaluate stop strings. stop_string = None @@ -233,8 +230,13 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer): def _protected_step(self, next_token_id: int) -> Optional[str]: try: token = self.stream.step(self.tokenizer, next_token_id) + except OverflowError: + # Handle rare observed overflow, still to be diagnosed. + # See https://github.com/vllm-project/vllm/issues/21951. 
+ logger.exception("Encountered invalid token id: %d", next_token_id) + token = None except Exception as e: - if str(e) != INVALID_PREFIX_ERR_MSG: + if not str(e).startswith(INVALID_PREFIX_ERR_MSG): raise e # Recover from edge case where tokenizer can produce non-monotonic, # invalid UTF-8 output, which breaks the internal state of @@ -243,7 +245,8 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer): logger.warning( "Encountered invalid prefix detokenization error" " for request %s, resetting decode stream.", self.request_id) - self.stream = DecodeStream(self.skip_special_tokens) + self.stream = DecodeStream( + skip_special_tokens=self.skip_special_tokens) token = self.stream.step(self.tokenizer, next_token_id) return token diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 7130f666ef19f..c93bfc35f0aeb 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -19,8 +19,9 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.tasks import SupportedTask -from vllm.transformers_utils.tokenizer_group import ( - TokenizerGroup, init_tokenizer_from_configs) +from vllm.tracing import init_tracer +from vllm.transformers_utils.tokenizer import (AnyTokenizer, + init_tokenizer_from_configs) from vllm.usage.usage_lib import UsageContext from vllm.utils import Device from vllm.v1.engine.core_client import EngineCoreClient @@ -65,6 +66,7 @@ class LLMEngine: "Set VLLM_USE_V1=0 and file and issue on Github.") self.vllm_config = vllm_config + self.observability_config = vllm_config.observability_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config @@ -87,9 +89,7 @@ class LLMEngine: else: # Tokenizer (+ ensure liveness if running in another process). self.tokenizer = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - lora_config=vllm_config.lora_config) + model_config=vllm_config.model_config) # Processor (convert Inputs --> EngineCoreRequests) self.processor = Processor(vllm_config=vllm_config, @@ -99,6 +99,11 @@ class LLMEngine: # OutputProcessor (convert EngineCoreOutputs --> RequestOutput). 
self.output_processor = OutputProcessor(self.tokenizer, log_stats=self.log_stats) + if self.observability_config.otlp_traces_endpoint is not None: + tracer = init_tracer( + "vllm.llm_engine", + self.observability_config.otlp_traces_endpoint) + self.output_processor.tracer = tracer # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs) self.engine_core = EngineCoreClient.make_client( @@ -290,7 +295,7 @@ class LLMEngine: assert self.log_stats, "Stat logging disabled" return get_metrics_snapshot() - def get_tokenizer_group(self) -> TokenizerGroup: + def get_tokenizer(self) -> AnyTokenizer: if self.tokenizer is None: raise ValueError("Unable to get tokenizer because " "skip_tokenizer_init is True") diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 2ee55b585da6c..5dad63988daa4 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -11,8 +11,9 @@ import torch from vllm.outputs import (CompletionOutput, PoolingOutput, PoolingRequestOutput, RequestOutput) from vllm.sampling_params import RequestOutputKind +from vllm.tracing import (SpanAttributes, SpanKind, Tracer, + extract_trace_context) from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason from vllm.v1.engine.detokenizer import IncrementalDetokenizer from vllm.v1.engine.logprobs import LogprobsProcessor @@ -71,7 +72,6 @@ class RequestOutputCollector: @dataclass class OutputProcessorOutput: - request_outputs: list[Union[RequestOutput, PoolingRequestOutput]] reqs_to_abort: list[str] @@ -93,6 +93,9 @@ class RequestState: arrival_time: float, queue: Optional[RequestOutputCollector], log_stats: bool, + top_p: Optional[float] = None, + n: Optional[int] = None, + temperature: Optional[float] = None, ): self.request_id = request_id self.parent_req = parent_req @@ -105,6 +108,9 @@ class RequestState: self.logprobs_processor = logprobs_processor self.detokenizer = detokenizer self.max_tokens_param = max_tokens_param + self.top_p = top_p + self.n = n + self.temperature = temperature self.is_prefilling = True self.queue = queue self.num_cached_tokens = 0 @@ -137,10 +143,16 @@ class RequestState: request=request, ) max_tokens_param = sampling_params.max_tokens + top_p = sampling_params.top_p + n = sampling_params.n + temperature = sampling_params.temperature else: logprobs_processor = None detokenizer = None max_tokens_param = None + top_p = None + n = None + temperature = None assert request.pooling_params is not None output_kind = request.pooling_params.output_kind @@ -156,6 +168,9 @@ class RequestState: logprobs_processor=logprobs_processor, detokenizer=detokenizer, max_tokens_param=max_tokens_param, + top_p=top_p, + n=n, + temperature=temperature, arrival_time=request.arrival_time, queue=queue, log_stats=log_stats, @@ -274,16 +289,13 @@ class RequestState: class OutputProcessor: """Process EngineCoreOutputs into RequestOutputs.""" - def __init__( - self, - tokenizer: TokenizerGroup, - log_stats: bool, - ): + def __init__(self, tokenizer: AnyTokenizer, log_stats: bool): self.log_stats = log_stats self.tokenizer = tokenizer self.request_states: dict[str, RequestState] = {} self.parent_requests: dict[str, ParentRequest] = {} self.lora_states = LoRARequestStates() + self.tracer: Optional[Tracer] = None def get_num_unfinished_requests(self): return len(self.request_states) @@ -334,10 +346,7 @@ class OutputProcessor: if 
request_id in self.request_states: raise ValueError(f"Request id {request_id} already running.") - tokenizer = None if not self.tokenizer else \ - self.tokenizer.get_lora_tokenizer(request.lora_request) - - req_state = RequestState.from_new_request(tokenizer=tokenizer, + req_state = RequestState.from_new_request(tokenizer=self.tokenizer, request=request, prompt=prompt, parent_req=parent_req, @@ -360,17 +369,17 @@ class OutputProcessor: 1) Compute stats for logging 2) Detokenize 3) Create and handle RequestOutput objects: - * If there is a queue (for usage with AsyncLLM), + * If there is a queue (for usage with AsyncLLM), put the RequestOutput objects into the queue for handling by the per-request generate() tasks. - * If there is no queue (for usage with LLMEngine), + * If there is no queue (for usage with LLMEngine), return a list of RequestOutput objects. NOTE FOR DEVELOPERS vLLM V1 minimizes the number of python loops over the full - batch to ensure system overheads are minimized. This is the + batch to ensure system overheads are minimized. This is the only function that should loop over EngineCoreOutputs. If you need to touch every element of the batch, do it from @@ -441,7 +450,9 @@ class OutputProcessor: # Track per-request stats self._update_stats_from_finished(req_state, finish_reason, iteration_stats) - + if self.tracer: + self.do_tracing(engine_core_output, req_state, + iteration_stats) self.lora_states.update_iteration_stats(iteration_stats) return OutputProcessorOutput( @@ -449,6 +460,63 @@ class OutputProcessor: reqs_to_abort=reqs_to_abort, ) + def do_tracing(self, engine_core_output: EngineCoreOutput, + req_state: RequestState, + iteration_stats: Optional[IterationStats]) -> None: + assert req_state.stats is not None + assert iteration_stats is not None + assert self.tracer is not None + + arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9) + trace_context = extract_trace_context(engine_core_output.trace_headers) + with (self.tracer.start_as_current_span( + "llm_request", + kind=SpanKind.SERVER, + context=trace_context, + start_time=arrival_time_nano_seconds) as span): + metrics = req_state.stats + e2e_time = iteration_stats.iteration_timestamp - \ + metrics.arrival_time + queued_time = metrics.scheduled_ts - metrics.queued_ts + prefill_time = metrics.first_token_ts - metrics.scheduled_ts + decode_time = metrics.last_token_ts - metrics.first_token_ts + inference_time = metrics.last_token_ts - metrics.scheduled_ts + span.set_attribute( + SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, + metrics.first_token_latency) + span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time) + span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, + queued_time) + span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, + len(req_state.prompt_token_ids)) + span.set_attribute(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS, + metrics.num_generation_tokens) + span.set_attribute( + SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL, + prefill_time) + span.set_attribute( + SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE, + decode_time) + span.set_attribute( + SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE, + inference_time) + + # meta + span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, + req_state.request_id) + if req_state.top_p: + span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, + req_state.top_p) + if req_state.max_tokens_param: + span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, + req_state.max_tokens_param) + if 
req_state.temperature: + span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, + req_state.temperature) + if req_state.n: + span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, + req_state.n) + def _update_stats_from_output(self, req_state: RequestState, engine_core_output: EngineCoreOutput, engine_core_timestamp: Optional[float], diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 1aa117ded4ed8..71f539583a1be 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -9,15 +9,16 @@ from vllm.config import VllmConfig from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs from vllm.inputs.parse import split_enc_dec_inputs from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal.cache import processor_cache_from_config -from vllm.multimodal.inputs import MultiModalFeatureSpec +from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.multimodal.utils import argsort_mm_positions from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.v1.engine import EngineCoreRequest from vllm.v1.structured_output.backend_guidance import ( validate_guidance_grammar) @@ -28,13 +29,15 @@ from vllm.v1.structured_output.backend_outlines import ( from vllm.v1.structured_output.backend_xgrammar import ( validate_xgrammar_grammar) +logger = init_logger(__name__) + class Processor: def __init__( self, vllm_config: VllmConfig, - tokenizer: TokenizerGroup, + tokenizer: AnyTokenizer, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): @@ -42,7 +45,7 @@ class Processor: self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config self.lora_config = vllm_config.lora_config - self.decoding_config = vllm_config.decoding_config + self.structured_outputs_config = vllm_config.structured_outputs_config self.tokenizer = tokenizer self.generation_config_fields = ( @@ -65,24 +68,31 @@ class Processor: ) -> None: max_logprobs = self.model_config.max_logprobs if max_logprobs == -1: - return + max_logprobs = self.model_config.get_vocab_size() + # Validate sample logprobs. - if params.logprobs and (params.logprobs == -1 - or params.logprobs > max_logprobs): - raise ValueError( - f"Requested sample logprobs of {params.logprobs}, " - f"which is greater than max allowed: {max_logprobs}") + if params.logprobs: + num_logprobs = params.logprobs + if num_logprobs == -1: + num_logprobs = self.model_config.get_vocab_size() + if num_logprobs > max_logprobs: + raise ValueError( + f"Requested sample logprobs of {num_logprobs}, " + f"which is greater than max allowed: {max_logprobs}") # Validate prompt logprobs.
- if params.prompt_logprobs and params.prompt_logprobs > max_logprobs: - raise ValueError( - f"Requested prompt logprobs of {params.prompt_logprobs}, " - f"which is greater than max allowed: {max_logprobs}") + if params.prompt_logprobs: + num_prompt_logprobs = params.prompt_logprobs + if num_prompt_logprobs == -1: + num_prompt_logprobs = self.model_config.get_vocab_size() + if num_prompt_logprobs > max_logprobs: + raise ValueError( + f"Requested prompt logprobs of {num_prompt_logprobs}, " + f"which is greater than max allowed: {max_logprobs}") def _validate_sampling_params( self, params: SamplingParams, - lora_request: Optional[LoRARequest], ) -> None: self._validate_structured_output(params) self._validate_logit_bias(params) @@ -95,8 +105,7 @@ class Processor: # When skip_tokenizer_init=True, we can't validate token IDs # Skip validation and let the model handle invalid tokens return - tokenizer = self.tokenizer.get_lora_tokenizer(lora_request) - vocab_size = len(tokenizer) + vocab_size = len(self.tokenizer) if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids): raise ValueError( "allowed_token_ids contains out-of-vocab token id!") @@ -136,7 +145,6 @@ class Processor: def _validate_params( self, params: Union[SamplingParams, PoolingParams], - lora_request: Optional[LoRARequest], ): """ Validate supported SamplingParam. @@ -147,14 +155,14 @@ class Processor: return self._validate_logprobs(params) - self._validate_sampling_params(params, lora_request) + self._validate_sampling_params(params) self._validate_supported_sampling_params(params) def _validate_multi_modal_uuids(self, prompt: PromptType) -> None: """ Validate that user-provided multi_modal_uuids align with multi_modal_data in the incoming request prompt(s). - Only checks lengths; `None` entries are allowed and will be + Only checks lengths; `None` entries are allowed and will be auto-hashed downstream. """ @@ -194,85 +202,96 @@ class Processor: _validate_single_prompt(prompt) # type: ignore[arg-type] def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None: - if lora_request is not None and not self.lora_config: + if lora_request is None: + return + + # LoRA request passed in while LoRA is not enabled + if not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") + if self.tokenizer is not None: + logger.warning_once( + "vLLM has deprecated support for using different " + "tokenizers for different LoRAs. By default, vLLM uses the " + "base model's tokenizer. If you are using a LoRA " + "with its own tokenizer, consider specifying `--tokenizer " + "[lora_path]` to use the LoRA tokenizer.") + def _validate_structured_output(self, params: SamplingParams) -> None: - if not params.guided_decoding or not self.decoding_config: + if not params.structured_outputs or not self.structured_outputs_config: return - if self.model_config.skip_tokenizer_init and params.guided_decoding: + if self.model_config.skip_tokenizer_init and params.structured_outputs: raise ValueError( "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'" # noqa: E501 ) - engine_level_backend = self.decoding_config.backend - if params.guided_decoding.backend: - # Request-level backend selection is not supported in V1. + backend = self.structured_outputs_config.backend + if _backend := params.structured_outputs._backend: + # Request-level backend selection is not supported.
# The values may differ if `params` is reused and was set # to a specific backend based on `auto` behavior in a previous # request. We remember that it was set as a result of `auto` - # using the `_auto` option set on the backend in the params. - if (params.guided_decoding.backend != engine_level_backend - and not (engine_level_backend == "auto" - and params.guided_decoding.backend_was_auto)): + # using the `_backend_was_auto` field set in the params. + if (backend != _backend + and not (backend == "auto" + and params.structured_outputs._backend_was_auto)): raise ValueError( - "Request-level structured output backend selection is no " - "longer supported. The request specified " - f"'{params.guided_decoding.backend}', but vLLM was " - f"initialised with '{engine_level_backend}'. This error " - "can be resolved by removing backend selection from the " - "request.") + "Request-level structured output backend selection is not " + f"supported. The request specified '{_backend}', but vLLM " + f"was initialised with '{backend}'. This error can be " + "resolved by removing '_backend' from the request.") else: - params.guided_decoding.backend = engine_level_backend + params.structured_outputs._backend = backend # Request content validation - if (isinstance(params.guided_decoding.choice, list) - and not params.guided_decoding.choice): + if (isinstance(params.structured_outputs.choice, list) + and not params.structured_outputs.choice): # It is invalid for choice to be an empty list - raise ValueError(f"Choice '{params.guided_decoding.choice}' " - "cannot be an empty list") + raise ValueError( + f"Choice '{params.structured_outputs.choice}' cannot be an empty list" # noqa: E501 + ) - if engine_level_backend.startswith("xgrammar"): + if backend.startswith("xgrammar"): # xgrammar with no fallback validate_xgrammar_grammar(params) - elif engine_level_backend.startswith("guidance"): + elif backend.startswith("guidance"): # TODO: ideally we would have the LLTokenizer here as Lark syntax # allows <|special_token|> and similar, see # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens # Without tokenizer these are disallowed in grammars. validate_guidance_grammar(params, tokenizer=None) - elif engine_level_backend == "outlines": + elif backend == "outlines": # outlines backend validate_structured_output_request_outlines(params) - elif engine_level_backend == "lm-format-enforcer": + elif backend == "lm-format-enforcer": # lm format enforcer backend validate_structured_output_request_lm_format_enforcer(params) else: - # NOTE: engine_level_backend must be "auto" here, because we have + # NOTE: backend must be "auto" here, because we have # checked supported_backends above. - # "auto" is an opt-in to opinionated behavior where we try to - # choose a backend based on request contents. This is not the - # default as it is less predictable and subject to change - # between releases as feature support changes. + # In this mode, we set opinionated defaults based on what we think + # will satisfy the most use cases without having to worry about + # this setting. We include fallback behavior here, but not with any + # other setting where a specific backend was specified. try: validate_xgrammar_grammar(params) - params.guided_decoding.backend = "xgrammar" + params.structured_outputs._backend = "xgrammar" except ValueError: # The request either failed validation # or includes some jsonschema feature(s) that # are not supported in xgrammar. Fall back to guidance. 
validate_guidance_grammar(params, tokenizer=None) - params.guided_decoding.backend = "guidance" + params.structured_outputs._backend = "guidance" # Remember that this backend was set automatically - params.guided_decoding.backend_was_auto = True + params.structured_outputs._backend_was_auto = True - def _maybe_build_mm_hash_overrides( + def _maybe_build_mm_uuids( self, request_id: str, prompt: PromptType, - ) -> Optional[dict[str, list[str]]]: + ) -> Optional[MultiModalUUIDDict]: """Build per-item multimodal hash overrides when enabled. In this case, multimodal data items are identified by their request id, modality and index rather than their content. @@ -295,13 +314,13 @@ class Processor: if not mm_data: return None - overrides: dict[str, list[str]] = {} + mm_uuids: MultiModalUUIDDict = {} for modality, data in mm_data.items(): n = len(data) if isinstance(data, list) else 1 - overrides[modality] = [ + mm_uuids[modality] = [ f"{request_id}-{modality}-{i}" for i in range(n) ] - return overrides + return mm_uuids def process_inputs( self, @@ -317,11 +336,8 @@ class Processor: ) -> tuple[Optional[str], EngineCoreRequest]: # TODO(woosuk): Support pooling models. - # TODO(woosuk): Support encoder-decoder models. self._validate_lora(lora_request) - self._validate_params(params, lora_request) - if trace_headers is not None: - raise ValueError("V1 does not support tracing yet.") + self._validate_params(params) data_parallel_size = self.vllm_config.parallel_config.data_parallel_size if data_parallel_rank is not None and not (0 <= data_parallel_rank < @@ -343,16 +359,15 @@ class Processor: if (self.model_config.multimodal_config and self.model_config.multimodal_config.mm_processor_cache_gb == 0 and not self.cache_config.enable_prefix_caching): - mm_hash_overrides = self._maybe_build_mm_hash_overrides( - request_id, prompt) + mm_uuids = self._maybe_build_mm_uuids(request_id, prompt) else: # Otherwise, use user-provided uuids as multimodal hash overrides # if provided. self._validate_multi_modal_uuids(prompt) if isinstance(prompt, dict): - mm_hash_overrides = prompt.get("multi_modal_uuids") + mm_uuids = prompt.get("multi_modal_uuids") else: - mm_hash_overrides = None + mm_uuids = None # Process inputs, which includes: # 1. Tokenize text prompt, with LoRA request if one exists. 
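The validation rewrite above replaces the guided-decoding naming with `structured_outputs` and keeps the `auto` backend behavior: try xgrammar first, then fall back to guidance. A rough sketch of that fallback, assuming `params` is a SamplingParams with structured outputs requested (the validator imports come from this file's import block):

from vllm.v1.structured_output.backend_guidance import validate_guidance_grammar
from vllm.v1.structured_output.backend_xgrammar import validate_xgrammar_grammar


def resolve_auto_backend(params) -> str:
    # Sketch of the "auto" path only; the real code also sets
    # _backend_was_auto on the params so a reused params object stays valid.
    try:
        validate_xgrammar_grammar(params)
        return "xgrammar"
    except ValueError:
        # The grammar failed validation or uses jsonschema features that
        # xgrammar does not support; guidance is the fallback.
        validate_guidance_grammar(params, tokenizer=None)
        return "guidance"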
@@ -361,8 +376,7 @@ class Processor: processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess( prompt, tokenization_kwargs=tokenization_kwargs, - lora_request=lora_request, - mm_hash_overrides=mm_hash_overrides, + mm_uuids=mm_uuids, ) from vllm.platforms import current_platform current_platform.validate_request( @@ -371,16 +385,12 @@ class Processor: processed_inputs=processed_inputs, ) - eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) + eos_token_id = self.input_preprocessor.get_eos_token_id() - self._validate_model_inputs(processed_inputs, lora_request) + self._validate_model_inputs(processed_inputs) encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) - # TODO: Impl encoder-decoder - if encoder_inputs is not None: - raise NotImplementedError - sampling_params = None pooling_params = None if isinstance(params, SamplingParams): @@ -394,8 +404,7 @@ class Processor: sampling_params.update_from_generation_config( self.generation_config_fields, eos_token_id) if self.tokenizer is not None: - sampling_params.update_from_tokenizer( - self.tokenizer.get_lora_tokenizer(lora_request)) + sampling_params.update_from_tokenizer(self.tokenizer) else: pooling_params = params.clone() @@ -433,26 +442,20 @@ class Processor: cache_salt=decoder_inputs.get("cache_salt"), priority=priority, data_parallel_rank=data_parallel_rank, + trace_headers=trace_headers, ) - def _validate_model_inputs(self, - inputs: ProcessorInputs, - lora_request: Optional[LoRARequest] = None): + def _validate_model_inputs(self, inputs: ProcessorInputs): encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs) if encoder_inputs is not None: - self._validate_model_input(encoder_inputs, - lora_request, - prompt_type="encoder") + self._validate_model_input(encoder_inputs, prompt_type="encoder") - self._validate_model_input(decoder_inputs, - lora_request, - prompt_type="decoder") + self._validate_model_input(decoder_inputs, prompt_type="decoder") def _validate_model_input( self, prompt_inputs: SingletonInputs, - lora_request: Optional[LoRARequest], *, prompt_type: Literal["encoder", "decoder"], ): @@ -468,7 +471,7 @@ class Processor: if self.model_config.skip_tokenizer_init: tokenizer = None else: - tokenizer = self.tokenizer.get_lora_tokenizer(lora_request) + tokenizer = self.tokenizer max_input_id = max(prompt_ids, default=0) # NOTE: tokenizer.max_token_id is the tokenizer’s vocab size while @@ -497,7 +500,7 @@ class Processor: assert isinstance(mm_processor, EncDecMultiModalProcessor) if mm_processor.pad_dummy_encoder_prompt: - return # Skip encoder length check for Whisper and Donut + return # Skip encoder length check for Whisper if model_config.is_multimodal_model: suggestion = ( diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 56ef8477d267a..18ef25ceb6f5e 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -116,7 +116,7 @@ class CoreEngineProcManager: local_dp_ranks.append(local_index) self.processes.append( context.Process(target=target_fn, - name=f"EngineCore_{global_index}", + name=f"EngineCore_DP{global_index}", kwargs=common_kwargs | { "dp_rank": global_index, "local_dp_rank": local_index, @@ -315,7 +315,6 @@ class CoreEngineActorManager: import ray from ray._private.state import available_resources_per_node - from ray.util.state import list_nodes logger.info("Creating placement groups for data parallel") dp_master_ip = \ @@ -324,34 +323,33 @@ class CoreEngineActorManager: local_engine_count = \ 
vllm_config.parallel_config.data_parallel_size_local - nodes = sorted(list_nodes(filters=[("state", "=", "ALIVE")]), - key=lambda node: node.node_ip != dp_master_ip) - assert nodes[0].node_ip == dp_master_ip, ( - "The head node is missing or dead") - assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, ( - "There can only be one head node") - available_resources = available_resources_per_node() world_size = vllm_config.parallel_config.world_size placement_groups: list[PlacementGroup] = [] local_dp_ranks: list[int] = [] - - for node in nodes: - node_ip = node.node_ip - node_resources = available_resources[node.node_id] - if "GPU" not in node_resources: + dp_master_ip_key = f'node:{dp_master_ip}' + nodes = sorted(available_resources.values(), + key=lambda x: dp_master_ip_key not in x) + assert len(nodes) > 0, ( + "No nodes with resources found in Ray cluster.") + assert dp_master_ip_key in nodes[0], ( + f"The DP master node (ip: {dp_master_ip}) is missing or dead") + device_str = current_platform.ray_device_key + for node_resources in nodes: + if device_str not in node_resources: continue # For now, each DP rank can only be assigned to one node # TODO(rui): support allocating a single DP rank # to multiple nodes - available_engine_count = int(node_resources["GPU"]) // world_size - if node_ip == dp_master_ip: + available_engine_count = int( + node_resources[device_str]) // world_size + if dp_master_ip_key in node_resources: assert available_engine_count >= local_engine_count, ( "Not enough resources to allocate DP ranks " - f"on DP master node {node_ip}") + f"on DP master node {dp_master_ip}") for i in range(local_engine_count): bundles = [{ - "GPU": 1.0, + device_str: 1.0, "node:" + dp_master_ip: 0.001 }] * world_size + [{ "CPU": 1.0 @@ -367,7 +365,7 @@ class CoreEngineActorManager: for i in range(available_engine_count): if len(placement_groups) == num_pg_to_create: break - bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}] + bundles = [{device_str: 1.0}] * world_size + [{"CPU": 1.0}] pg = ray.util.placement_group( name=f"dp_rank_{len(placement_groups)}", strategy="STRICT_PACK", @@ -419,17 +417,18 @@ class CoreEngineActorManager: local_dp_ranks = [] num_pg_created = 0 + device_str = current_platform.ray_device_key for node in nodes: if num_pg_created >= num_pg_to_create: break node_ip = node.node_ip node_id = node.node_id - available_gpus = int(available_resources[node_id]["GPU"]) + available_gpus = int(available_resources[node_id][device_str]) # Get total GPUs on this node from the node's resources # Ray stores node resources with node ID as key - total_gpus = int(total_resources[node_id]["GPU"]) + total_gpus = int(total_resources[node_id][device_str]) # Calculate used GPUs and used engines on this node used_gpus = max(0, total_gpus - available_gpus) @@ -448,13 +447,13 @@ class CoreEngineActorManager: # Create bundles with node constraint for master node if node_ip == dp_master_ip: bundles = [{ - "GPU": 1.0, + device_str: 1.0, "node:" + dp_master_ip: 0.001 }] * world_size + [{ "CPU": 1.0 }] else: - bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}] + bundles = [{device_str: 1.0}] * world_size + [{"CPU": 1.0}] pg = ray.util.placement_group( name=f"dp_rank_{rank}", diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 68408a0b8a3d5..625017d52fff0 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from concurrent.futures import Future -from
typing import Callable, Optional, Union +from typing import Any, Callable, Optional, Union import torch import torch.distributed as dist @@ -14,6 +14,7 @@ from vllm.executor.uniproc_executor import ( # noqa from vllm.executor.uniproc_executor import ( # noqa UniProcExecutor as UniProcExecutorV0) from vllm.utils import resolve_obj_by_qualname +from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput @@ -86,12 +87,22 @@ class Executor(ExecutorBase): def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]: return self.collective_rpc("get_kv_cache_spec") + def collective_rpc(self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None, + non_block: bool = False) -> list[Any]: + raise NotImplementedError + def execute_model( self, - scheduler_output, + scheduler_output: SchedulerOutput, + non_block: bool = False, ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: output = self.collective_rpc("execute_model", - args=(scheduler_output, )) + args=(scheduler_output, ), + non_block=non_block) return output[0] def execute_dummy_batch(self) -> None: diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 12e79ff165f4e..3aa373f12b609 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import multiprocessing -import os import pickle +import queue import signal import threading import time @@ -11,9 +11,10 @@ import weakref from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass from enum import Enum, auto -from functools import partial +from functools import cached_property, partial from multiprocessing.connection import Connection from multiprocessing.process import BaseProcess +from multiprocessing.synchronize import Lock as LockType from threading import Thread from typing import Any, Callable, Optional, Union, cast @@ -25,15 +26,21 @@ from vllm.distributed import (destroy_distributed_environment, destroy_model_parallel) from vllm.distributed.device_communicators.shm_broadcast import (Handle, MessageQueue) -from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator +from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, + get_pp_group, get_tp_group) from vllm.executor.multiproc_worker_utils import ( set_multiprocessing_worker_envs) from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import worker_receiver_cache_from_config from vllm.utils import (decorate_logs, get_distributed_init_method, get_loopback_ip, get_mp_context, get_open_port, set_process_title) +from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.executor.abstract import Executor, FailureCallback -from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput +from vllm.v1.executor.utils import get_and_update_mm_cache +from vllm.v1.outputs import (AsyncModelRunnerOutput, DraftTokenIds, + ModelRunnerOutput) from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -78,6 +85,8 @@ class MultiprocExecutor(Executor): scheduler_output_handle = self.rpc_broadcast_mq.export_handle() # Create workers + context = get_mp_context() + shared_worker_lock = context.Lock() 
unready_workers: list[UnreadyWorkerProcHandle] = [] success = False try: @@ -89,6 +98,7 @@ class MultiprocExecutor(Executor): rank=rank, distributed_init_method=distributed_init_method, input_shm_handle=scheduler_output_handle, + shared_worker_lock=shared_worker_lock, )) # Workers must be created before wait_for_ready to avoid @@ -124,8 +134,6 @@ class MultiprocExecutor(Executor): self.output_rank = self._get_output_rank() self.has_connector = self.vllm_config.kv_transfer_config is not None - self.kv_output_aggregator = KVOutputAggregator( - self.parallel_config.world_size) def start_worker_monitor(self): workers = self.workers @@ -164,9 +172,9 @@ class MultiprocExecutor(Executor): def execute_model( self, - scheduler_output, + scheduler_output: SchedulerOutput, + non_block: bool = False, ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: - non_block = self.max_concurrent_batches > 1 if not self.has_connector: # get output only from a single worker (output_rank) @@ -253,7 +261,8 @@ class MultiprocExecutor(Executor): if not non_block: result = result.result() elif not non_block: - result = get_response(w, dequeue_timeout) + result = get_response(w, dequeue_timeout, + self.shutdown_event) else: raise RuntimeError("non_block can only be used when" " max_concurrent_batches > 1") @@ -295,12 +304,8 @@ class MultiprocExecutor(Executor): """Properly shut down the executor and its workers""" if not getattr(self, 'shutting_down', False): self.shutting_down = True - self.shutdown_event.set() - - if self.io_thread_pool is not None: - self.io_thread_pool.shutdown(wait=False, cancel_futures=True) - self.io_thread_pool = None + # Make sure all the worker processes are terminated first. if workers := getattr(self, 'workers', None): for w in workers: # Close death_writer to signal child processes to exit @@ -310,13 +315,18 @@ class MultiprocExecutor(Executor): w.worker_response_mq = None self._ensure_worker_termination([w.proc for w in workers]) + self.shutdown_event.set() + if self.io_thread_pool is not None: + self.io_thread_pool.shutdown(wait=False, cancel_futures=True) + del self.io_thread_pool + self.rpc_broadcast_mq = None def check_health(self) -> None: self.collective_rpc("check_health", timeout=10) return - @property + @cached_property def max_concurrent_batches(self) -> int: if self.scheduler_config.async_scheduling: return 2 @@ -375,6 +385,7 @@ class WorkerProc: rank: int, distributed_init_method: str, input_shm_handle: Handle, + shared_worker_lock: LockType, ): self.rank = rank wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank) @@ -394,17 +405,6 @@ class WorkerProc: wrapper.init_worker(all_kwargs) self.worker = wrapper - pp_size = vllm_config.parallel_config.pipeline_parallel_size - tp_size = vllm_config.parallel_config.tensor_parallel_size - pp_str = f"PP{rank // tp_size}" if pp_size > 1 else "" - tp_str = f"TP{rank % tp_size}" if tp_size > 1 else "" - suffix = f"{pp_str}{'_' if pp_str and tp_str else ''}{tp_str}" - process_name = "VllmWorker" - if suffix: - set_process_title(suffix, append=True) - process_name = f"{process_name} {suffix}" - decorate_logs(process_name) - # Initialize MessageQueue for receiving SchedulerOutput self.rpc_broadcast_mq = MessageQueue.create_from_handle( input_shm_handle, self.worker.rank) @@ -412,17 +412,38 @@ class WorkerProc: # Initializes a message queue for sending the model output self.worker_response_mq = MessageQueue(1, 1) - # Initialize device and loads weights + scheduler_config = vllm_config.scheduler_config + 
self.use_async_scheduling = scheduler_config.async_scheduling + if self.use_async_scheduling: + self.async_output_queue: queue.Queue = queue.Queue() + self.async_output_copy_thread = Thread( + target=self.async_output_busy_loop, + daemon=True, + name="WorkerAsyncOutputCopy") + self.async_output_copy_thread.start() + + # Initialize multimodal receiver cache if needed + self.mm_receiver_cache = worker_receiver_cache_from_config( + vllm_config, MULTIMODAL_REGISTRY, shared_worker_lock) + + # Initialize device self.worker.init_device() + + # Set process title and log prefix + self.setup_proc_title_and_log_prefix( + enable_ep=vllm_config.parallel_config.enable_expert_parallel) + + # Load model self.worker.load_model() @staticmethod def make_worker_process( - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - input_shm_handle, # Receive SchedulerOutput + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + input_shm_handle, # Receive SchedulerOutput + shared_worker_lock: LockType, ) -> UnreadyWorkerProcHandle: context = get_mp_context() # (reader, writer) @@ -439,6 +460,7 @@ class WorkerProc: "input_shm_handle": input_shm_handle, "ready_pipe": (reader, writer), "death_pipe": death_reader, + "shared_worker_lock": shared_worker_lock, } # Run EngineCore busy loop in background process. proc = context.Process(target=WorkerProc.worker_main, @@ -493,6 +515,7 @@ class WorkerProc: return cast(list[WorkerProcHandle], ready_proc_handles) def shutdown(self): + self.worker.shutdown() self.rpc_broadcast_mq = None self.worker_response_mq = None destroy_model_parallel() @@ -522,7 +545,7 @@ class WorkerProc: # tuple[Connection, Connection] reader, ready_writer = kwargs.pop("ready_pipe") death_pipe = kwargs.pop("death_pipe", None) - + shutdown_event = threading.Event() # Start death monitoring thread if death_pipe is provided if death_pipe is not None: @@ -534,7 +557,7 @@ class WorkerProc: # Parent process has exited, terminate this worker logger.info("Parent process exited, terminating worker") # Send signal to self to trigger clean shutdown - os.kill(os.getpid(), signal.SIGTERM) + shutdown_event.set() except Exception as e: logger.warning("Death monitoring error: %s", e) @@ -562,7 +585,7 @@ class WorkerProc: ready_writer.close() ready_writer = None - worker.worker_busy_loop() + worker.worker_busy_loop(cancel=shutdown_event) except Exception: # NOTE: if an Exception arises in busy_loop, we send @@ -572,6 +595,8 @@ class WorkerProc: if ready_writer is not None: logger.exception("WorkerProc failed to start.") + elif shutdown_event.is_set(): + logger.info("WorkerProc shutting down.") else: logger.exception("WorkerProc failed.") @@ -593,16 +618,51 @@ class WorkerProc: SUCCESS = auto() FAILURE = auto() - def worker_busy_loop(self): + def enqueue_output(self, output: Any): + """Prepares output from the worker and enqueues it to the + worker_response_mq. If the output is an Exception, it is + converted to a FAILURE response. + """ + if isinstance(output, AsyncModelRunnerOutput): + output = output.get_output() + + if isinstance(output, Exception): + result = (WorkerProc.ResponseStatus.FAILURE, str(output)) + else: + result = (WorkerProc.ResponseStatus.SUCCESS, output) + if (response_mq := self.worker_response_mq) is not None: + response_mq.enqueue(result) + + def handle_output(self, output: Any): + """Handles output from the worker. If async scheduling is enabled, + it is passed to the async_output_busy_loop thread. 
Otherwise, it is + enqueued directly to the worker_response_mq. + """ + if self.use_async_scheduling: + self.async_output_queue.put(output) + else: + self.enqueue_output(output) + + def async_output_busy_loop(self): + """Entrypoint for the thread which handles outputs asynchronously.""" + while True: + output = self.async_output_queue.get() + self.enqueue_output(output) + + def worker_busy_loop(self, cancel: Optional[threading.Event] = None): """Main busy loop for Multiprocessing Workers""" while True: - method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue() - + method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue( + cancel=cancel) try: if isinstance(method, str): func = getattr(self.worker, method) elif isinstance(method, bytes): func = partial(cloudpickle.loads(method), self.worker) + # retrieve from shm cache if available + if self.mm_receiver_cache is not None \ + and func.__name__ == "execute_model": + get_and_update_mm_cache(self.mm_receiver_cache, args) output = func(*args, **kwargs) except Exception as e: # Notes have been introduced in python 3.11 @@ -612,10 +672,29 @@ class WorkerProc: # exception might not be serializable, so we convert it to # string, only for logging purpose. if output_rank is None or self.rank == output_rank: - self.worker_response_mq.enqueue( - (WorkerProc.ResponseStatus.FAILURE, str(e))) + self.handle_output(e) continue if output_rank is None or self.rank == output_rank: - self.worker_response_mq.enqueue( - (WorkerProc.ResponseStatus.SUCCESS, output)) + self.handle_output(output) + + @staticmethod + def setup_proc_title_and_log_prefix(enable_ep: bool) -> None: + dp_size = get_dp_group().world_size + dp_rank = get_dp_group().rank_in_group + pp_size = get_pp_group().world_size + pp_rank = get_pp_group().rank_in_group + tp_size = get_tp_group().world_size + tp_rank = get_tp_group().rank_in_group + process_name = "Worker" + if dp_size > 1: + process_name += f"_DP{dp_rank}" + if pp_size > 1: + process_name += f"_PP{pp_rank}" + if tp_size > 1: + process_name += f"_TP{tp_rank}" + if enable_ep: + ep_rank = get_ep_group().rank_in_group + process_name += f"_EP{ep_rank}" + set_process_title(name=process_name) + decorate_logs(process_name) diff --git a/vllm/v1/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_distributed_executor.py index 8394ae788ab01..aadb5fd1dddd5 100644 --- a/vllm/v1/executor/ray_distributed_executor.py +++ b/vllm/v1/executor/ray_distributed_executor.py @@ -51,8 +51,6 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor): # KV connector setup self.has_connector = self.vllm_config.kv_transfer_config is not None - self.kv_output_aggregator = KVOutputAggregator( - self.parallel_config.world_size) @property def max_concurrent_batches(self) -> int: @@ -66,11 +64,13 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor): def execute_model( self, scheduler_output: SchedulerOutput, + non_block: bool = False, ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: """Execute the model on the Ray workers. Args: scheduler_output: The scheduler output to execute. + non_block: If True, the method will return a Future. Returns: The model runner output. @@ -84,7 +84,7 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor): if not self.has_connector: # Get output only from a single worker (output_rank) # When PP is not used, we block here until the result is available. 
- if self.max_concurrent_batches == 1: + if not non_block: return refs[0].get() # When PP is used, we return a FutureWrapper immediately so that @@ -92,7 +92,7 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor): return FutureWrapper(refs) # Get output from all workers when connector is present - if self.max_concurrent_batches == 1: + if not non_block: # Block and get results from all workers outputs = [ref.get() for ref in refs] return self.kv_output_aggregator.aggregate(outputs) @@ -106,4 +106,3 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor): if reconfig_request.new_data_parallel_rank == \ ReconfigureRankType.SHUTDOWN_CURRENT_RANK: self.shutdown() - return \ No newline at end of file diff --git a/vllm/v1/executor/utils.py b/vllm/v1/executor/utils.py new file mode 100644 index 0000000000000..1855bc9963817 --- /dev/null +++ b/vllm/v1/executor/utils.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.multimodal.cache import ShmObjectStoreReceiverCache +from vllm.v1.core.sched.output import SchedulerOutput + + +def get_and_update_mm_cache( + receiver_cache: ShmObjectStoreReceiverCache, + args: tuple[SchedulerOutput], +) -> None: + """ + For each MultiModalKwargsItem in SchedulerOutput, fetch from shared memory + cache as needed. + + Args: + receiver_cache: The receiver cache to update. + args: According to the collective_rpc call of execute_model method in + executor, args is a tuple of only one SchedulerOutput element. + """ + scheduler_output = args[0] + for request_data in scheduler_output.scheduled_new_reqs: + request_data.mm_features = receiver_cache.get_and_update_features( + request_data.mm_features) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index a3e4d393e4d20..0cf92a680a689 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -11,7 +11,6 @@ from typing_extensions import Self from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.utils import cdiv, get_dtype_size logger = init_logger(__name__) @@ -86,6 +85,12 @@ class FullAttentionSpec(AttentionSpec): def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_model_len = vllm_config.model_config.max_model_len + dcp_world_size = \ + vllm_config.parallel_config.decode_context_parallel_size + # Note(hc): each dcp rank only needs to save + # (max_model_len//dcp_world_size) tokens locally. + if dcp_world_size > 1: + max_model_len = cdiv(max_model_len, dcp_world_size) return cdiv(max_model_len, self.block_size) * self.page_size_bytes @classmethod @@ -162,6 +167,8 @@ class SlidingWindowSpec(AttentionSpec): assert not self.use_mla, "MLA is not supported for sliding window" def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + assert vllm_config.parallel_config.decode_context_parallel_size == 1, \ + "DCP does not support sliding window."
max_model_len = vllm_config.model_config.max_model_len max_num_batched_tokens = ( vllm_config.scheduler_config.max_num_batched_tokens) @@ -186,6 +193,7 @@ class MambaSpec(KVCacheSpec): dtypes: tuple[torch.dtype] page_size_padded: Optional[int] = None mamba_type: str = "mamba2" + num_speculative_blocks: int = 0 @property def page_size_bytes(self) -> int: @@ -221,8 +229,8 @@ class CrossAttentionSpec(AttentionSpec): def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: # For cross-attention, we need to cache encoder states # Get encoder length (e.g., 1500 for Whisper). - max_encoder_len = MULTIMODAL_REGISTRY.\ - get_encdec_max_encoder_len(vllm_config.model_config) + max_encoder_len = vllm_config.scheduler_config.\ + max_num_encoder_input_tokens return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 3b0616952babf..b30036a6f8e80 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -169,15 +169,11 @@ class PrometheusStatLogger(StatLoggerBase): model_name = vllm_config.model_config.served_model_name max_model_len = vllm_config.model_config.max_model_len - if (len(self.engine_indexes) > 1 - and vllm_config.speculative_config is not None): - raise NotImplementedError("Prometheus metrics with Spec Decoding " - "with >1 EngineCore per AsyncLLM is not " - "supported yet.") - spec_decode_labelvalues = [ - vllm_config.model_config.served_model_name, - str(self.engine_indexes[0]) - ] + spec_decode_labelvalues: dict[int, list[str]] = { + idx: [model_name, str(idx)] + for idx in engine_indexes + } + self.spec_decoding_prom = self._spec_decoding_cls( vllm_config.speculative_config, labelnames, spec_decode_labelvalues) @@ -206,40 +202,46 @@ class PrometheusStatLogger(StatLoggerBase): # # GPU cache # - # Deprecated in 0.9 - Renamed as vllm:kv_cache_usage_perc - # TODO: in 0.10, only enable if show_hidden_metrics=True - gauge_gpu_cache_usage = self._gauge_cls( - name="vllm:gpu_cache_usage_perc", - documentation=( - "GPU KV-cache usage. 1 means 100 percent usage." - "DEPRECATED: Use vllm:kv_cache_usage_perc instead."), - multiprocess_mode="mostrecent", - labelnames=labelnames) - self.gauge_gpu_cache_usage = make_per_engine(gauge_gpu_cache_usage, - engine_indexes, - model_name) + # Deprecated in 0.9.2 - Renamed as vllm:kv_cache_usage_perc + # With 0.11.x you can enable with --show-hidden-metrics-for-version=0.10 + # TODO: remove in 0.12.0 + if self.show_hidden_metrics: + gauge_gpu_cache_usage = self._gauge_cls( + name="vllm:gpu_cache_usage_perc", + documentation=( + "GPU KV-cache usage. 1 means 100 percent usage." + "DEPRECATED: Use vllm:kv_cache_usage_perc instead."), + multiprocess_mode="mostrecent", + labelnames=labelnames) + self.gauge_gpu_cache_usage = make_per_engine( + gauge_gpu_cache_usage, engine_indexes, model_name) - # Deprecated in 0.9 - Renamed as vllm:prefix_cache_queries - # TODO: in 0.10, only enable if show_hidden_metrics=True - counter_gpu_prefix_cache_queries = self._counter_cls( - name="vllm:gpu_prefix_cache_queries", - documentation=( - "GPU prefix cache queries, in terms of number of queried" - "tokens. 
DEPRECATED: Use vllm:prefix_cache_queries instead."), - labelnames=labelnames) - self.counter_gpu_prefix_cache_queries = make_per_engine( - counter_gpu_prefix_cache_queries, engine_indexes, model_name) + # Deprecated in 0.9.2 - Renamed as vllm:prefix_cache_queries + # With 0.11.x you can enable with --show-hidden-metrics-for-version=0.10 + # TODO: remove in 0.12.0 + if self.show_hidden_metrics: + counter_gpu_prefix_cache_queries = self._counter_cls( + name="vllm:gpu_prefix_cache_queries", + documentation=( + "GPU prefix cache queries, in terms of number of queried" + "tokens. DEPRECATED: Use vllm:prefix_cache_queries instead." + ), + labelnames=labelnames) + self.counter_gpu_prefix_cache_queries = make_per_engine( + counter_gpu_prefix_cache_queries, engine_indexes, model_name) - # Deprecated in 0.9 - Renamed as vllm:prefix_cache_hits - # TODO: in 0.10, only enable if show_hidden_metrics=True - counter_gpu_prefix_cache_hits = self._counter_cls( - name="vllm:gpu_prefix_cache_hits", - documentation=( - "GPU prefix cache hits, in terms of number of cached " - "tokens. DEPRECATED: Use vllm:prefix_cache_hits instead."), - labelnames=labelnames) - self.counter_gpu_prefix_cache_hits = make_per_engine( - counter_gpu_prefix_cache_hits, engine_indexes, model_name) + # Deprecated in 0.9.2 - Renamed as vllm:prefix_cache_hits + # With 0.11.x you can enable with --show-hidden-metrics-for-version=0.10 + # TODO: remove in 0.12.0 + if self.show_hidden_metrics: + counter_gpu_prefix_cache_hits = self._counter_cls( + name="vllm:gpu_prefix_cache_hits", + documentation=( + "GPU prefix cache hits, in terms of number of cached " + "tokens. DEPRECATED: Use vllm:prefix_cache_hits instead."), + labelnames=labelnames) + self.counter_gpu_prefix_cache_hits = make_per_engine( + counter_gpu_prefix_cache_hits, engine_indexes, model_name) gauge_kv_cache_usage = self._gauge_cls( name="vllm:kv_cache_usage_perc", @@ -377,9 +379,13 @@ class PrometheusStatLogger(StatLoggerBase): self.histogram_time_to_first_token = make_per_engine( histogram_time_to_first_token, engine_indexes, model_name) + # Deprecated in 0.11 - Renamed as vllm:inter_token_latency_seconds + # TODO: in 0.12, only enable if show_hidden_metrics=True histogram_time_per_output_token = self._histogram_cls( name="vllm:time_per_output_token_seconds", - documentation="Histogram of time per output token in seconds.", + documentation=( + "Histogram of time per output token in seconds." 
+ "DEPRECATED: Use vllm:inter_token_latency_seconds instead."), buckets=[ 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0 @@ -388,6 +394,17 @@ class PrometheusStatLogger(StatLoggerBase): self.histogram_time_per_output_token = make_per_engine( histogram_time_per_output_token, engine_indexes, model_name) + histogram_inter_token_latency = self._histogram_cls( + name="vllm:inter_token_latency_seconds", + documentation="Histogram of inter-token latency in seconds.", + buckets=[ + 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, + 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0 + ], + labelnames=labelnames) + self.histogram_inter_token_latency = make_per_engine( + histogram_inter_token_latency, engine_indexes, model_name) + request_latency_buckets = [ 0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0 @@ -498,15 +515,17 @@ class PrometheusStatLogger(StatLoggerBase): self.gauge_scheduler_waiting[engine_idx].set( scheduler_stats.num_waiting_reqs) - self.gauge_gpu_cache_usage[engine_idx].set( - scheduler_stats.kv_cache_usage) + if self.show_hidden_metrics: + self.gauge_gpu_cache_usage[engine_idx].set( + scheduler_stats.kv_cache_usage) self.gauge_kv_cache_usage[engine_idx].set( scheduler_stats.kv_cache_usage) - self.counter_gpu_prefix_cache_queries[engine_idx].inc( - scheduler_stats.prefix_cache_stats.queries) - self.counter_gpu_prefix_cache_hits[engine_idx].inc( - scheduler_stats.prefix_cache_stats.hits) + if self.show_hidden_metrics: + self.counter_gpu_prefix_cache_queries[engine_idx].inc( + scheduler_stats.prefix_cache_stats.queries) + self.counter_gpu_prefix_cache_hits[engine_idx].inc( + scheduler_stats.prefix_cache_stats.hits) self.counter_prefix_cache_queries[engine_idx].inc( scheduler_stats.prefix_cache_stats.queries) @@ -515,7 +534,7 @@ class PrometheusStatLogger(StatLoggerBase): if scheduler_stats.spec_decoding_stats is not None: self.spec_decoding_prom.observe( - scheduler_stats.spec_decoding_stats) + scheduler_stats.spec_decoding_stats, engine_idx) if iteration_stats is None: return @@ -537,8 +556,9 @@ class PrometheusStatLogger(StatLoggerBase): self.histogram_n_request[engine_idx].observe(n_param) for ttft in iteration_stats.time_to_first_tokens_iter: self.histogram_time_to_first_token[engine_idx].observe(ttft) - for tpot in iteration_stats.time_per_output_tokens_iter: - self.histogram_time_per_output_token[engine_idx].observe(tpot) + for itl in iteration_stats.inter_token_latencies_iter: + self.histogram_inter_token_latency[engine_idx].observe(itl) + self.histogram_time_per_output_token[engine_idx].observe(itl) for finished_request in iteration_stats.finished_requests: self.counter_request_success[ @@ -635,15 +655,21 @@ class StatLoggerManager: vllm_config: VllmConfig, engine_idxs: Optional[list[int]] = None, custom_stat_loggers: Optional[list[StatLoggerFactory]] = None, + enable_default_loggers: bool = True, + client_count: int = 1, ): self.engine_idxs = engine_idxs if engine_idxs else [0] - factories: list[StatLoggerFactory] + factories: list[StatLoggerFactory] = [] if custom_stat_loggers is not None: - factories = custom_stat_loggers - else: - factories = [] - if logger.isEnabledFor(logging.INFO): + factories.extend(custom_stat_loggers) + + if enable_default_loggers and logger.isEnabledFor(logging.INFO): + if client_count > 1: + logger.warning( + "AsyncLLM created with api_server_count more than 1; " + "disabling stats logging to avoid 
incomplete stats.") + else: factories.append(LoggingStatLogger) # engine_idx: StatLogger diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 95094bda65cde..e6c344d193df2 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -68,6 +68,9 @@ class RequestStateStats: first_token_ts: float = 0.0 last_token_ts: float = 0.0 + # first token latency + first_token_latency: float = 0.0 + @dataclass class FinishedRequestStats: @@ -96,7 +99,7 @@ class IterationStats: self.max_num_generation_tokens_iter: list[int] = [] self.n_params_iter: list[int] = [] self.time_to_first_tokens_iter: list[float] = [] - self.time_per_output_tokens_iter: list[float] = [] + self.inter_token_latencies_iter: list[float] = [] self.waiting_lora_adapters: dict[str, int] = {} self.running_lora_adapters: dict[str, int] = {} @@ -116,6 +119,7 @@ class IterationStats: first_token_latency = self._time_since(req_stats.arrival_time) self.time_to_first_tokens_iter.append(first_token_latency) + req_stats.first_token_latency = first_token_latency req_stats.num_generation_tokens += num_new_generation_tokens @@ -128,8 +132,8 @@ class IterationStats: if is_prefilling: req_stats.first_token_ts = engine_core_timestamp else: - tpot = engine_core_timestamp - req_stats.last_token_ts - self.time_per_output_tokens_iter.append(tpot) + itl = engine_core_timestamp - req_stats.last_token_ts + self.inter_token_latencies_iter.append(itl) req_stats.last_token_ts = engine_core_timestamp diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index f8d6b24702f3c..1b2da8addb19e 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod from dataclasses import dataclass from typing import NamedTuple, Optional @@ -114,6 +115,20 @@ class ModelRunnerOutput: num_nans_in_logits: Optional[dict[str, int]] = None +# ModelRunnerOutput wrapper for async scheduling. +class AsyncModelRunnerOutput(ABC): + + @abstractmethod + def get_output(self) -> ModelRunnerOutput: + """Get the ModelRunnerOutput for this async output. + + This is a blocking call that waits until the results are ready, which + might involve copying device tensors to the host. + This method should only be called once per AsyncModelRunnerOutput. + """ + pass + + @dataclass class DraftTokenIds: diff --git a/vllm/v1/request.py b/vllm/v1/request.py index ad7477241ebbd..145af788d2372 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -3,6 +3,7 @@ import enum import time +from collections.abc import Mapping from functools import partial from typing import TYPE_CHECKING, Any, Callable, Optional, Union @@ -35,6 +36,7 @@ class Request: structured_output_request: Optional["StructuredOutputRequest"] = None, cache_salt: Optional[str] = None, priority: int = 0, + trace_headers: Optional[Mapping[str, str]] = None, block_hasher: Optional[Callable[["Request"], list["BlockHash"]]] = None, ) -> None: @@ -65,7 +67,7 @@ class Request: # Generative models. 
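[Editor's note] The AsyncModelRunnerOutput base class added above only fixes the get_output() contract; how the deferred device-to-host transfer is arranged is left to the runner. A minimal sketch, assuming a pinned CPU tensor that is the target of a non_blocking copy and a CUDA event recorded after that copy (class and attribute names are invented for illustration):

    import torch

    from vllm.v1.outputs import AsyncModelRunnerOutput, ModelRunnerOutput


    class CopyOnGetModelRunnerOutput(AsyncModelRunnerOutput):
        """Illustrative wrapper: the D2H copy is enqueued elsewhere; we only wait here."""

        def __init__(self, output: ModelRunnerOutput,
                     sampled_token_ids_cpu: torch.Tensor,
                     copy_event: torch.cuda.Event):
            self._output = output
            # Pinned CPU tensor targeted by a non_blocking copy; copy_event was
            # recorded right after the copy was enqueued.
            self._sampled_token_ids_cpu = sampled_token_ids_cpu
            self._copy_event = copy_event

        def get_output(self) -> ModelRunnerOutput:
            # Block only when the consumer actually needs the result, and only once.
            self._copy_event.synchronize()
            self._output.sampled_token_ids = self._sampled_token_ids_cpu.tolist()
            return self._output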
assert sampling_params.max_tokens is not None self.max_tokens = sampling_params.max_tokens - if sampling_params.guided_decoding is not None: + if sampling_params.structured_outputs is not None: self.status = RequestStatus.WAITING_FOR_FSM self.use_structured_output = True @@ -89,18 +91,14 @@ class Request: self.mm_features = mm_features or [] self.num_encoder_inputs = len(self.mm_features) self.has_encoder_inputs = self.num_encoder_inputs > 0 - # TODO(sfeng33): Remove these legacy fields after clearing out all - # references in scheduler and model runner - self.mm_positions = [f.mm_position for f in self.mm_features] - self.mm_kwargs = [f.data for f in self.mm_features] - self.mm_hashes = [f.identifier for f in self.mm_features] # Read-only views # Prevent directly appending to these lists since # they should also be updated simultaneously. self.output_token_ids = ConstantList(self._output_token_ids) self.all_token_ids = ConstantList(self._all_token_ids) - + # trace_headers + self.trace_headers = trace_headers # State # The number of tokens with prefix cache hits. self.num_cached_tokens = -1 @@ -136,6 +134,7 @@ class Request: if request.sampling_params else None, cache_salt=request.cache_salt, priority=request.priority, + trace_headers=request.trace_headers, block_hasher=block_hasher, ) @@ -176,8 +175,8 @@ class Request: return RequestStatus.get_finished_reason(self.status) def get_num_encoder_tokens(self, input_id: int) -> int: - assert input_id < len(self.mm_positions) - num_tokens = self.mm_positions[input_id].length + assert input_id < len(self.mm_features) + num_tokens = self.mm_features[input_id].mm_position.length return num_tokens def record_event( diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py index 8220269162951..df944873bcaf3 100644 --- a/vllm/v1/sample/logits_processor/__init__.py +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -1,16 +1,22 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import importlib +import inspect import itertools +from abc import abstractmethod from collections.abc import Sequence +from functools import partial from typing import TYPE_CHECKING, Optional, Union import torch from vllm.logger import init_logger +from vllm.logits_process import LogitsProcessor as RequestLogitsProcessor +from vllm.sampling_params import SamplingParams from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor, MinPLogitsProcessor, - MinTokensLogitsProcessor) + MinTokensLogitsProcessor, + process_dict_updates) from vllm.v1.sample.logits_processor.interface import (BatchUpdate, LogitsProcessor, MoveDirectionality) @@ -177,9 +183,112 @@ def build_logitsprocs( BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes)) +class AdapterLogitsProcessor(LogitsProcessor): + """Wrapper for per-request logits processors + + To wrap a specific per-request logits processor, + * Subclass `AdapterLogitsProcessor` + * Implement `self.is_argmax_invariant()` base-class method + * Implement `self.new_req_logits_processor(params)` + + `self.__init__(vllm_config, device, is_pin_memory)` does not need to be + overridden in general. 
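[Editor's note] As a concrete illustration of the subclassing recipe described in the AdapterLogitsProcessor docstring above, a minimal adapter could look like the sketch below. The class name and the use of SamplingParams.extra_args to carry per-request settings are assumptions for the example, not part of the patch:

    import torch

    from vllm.sampling_params import SamplingParams
    from vllm.v1.sample.logits_processor import AdapterLogitsProcessor


    class BanTokensLogitsProcessor(AdapterLogitsProcessor):
        """Hypothetical adapter that bans a per-request set of token ids."""

        def is_argmax_invariant(self) -> bool:
            # Masking tokens can change the argmax, so this must report False.
            return False

        def new_req_logits_processor(self, params: SamplingParams):
            # Assumed convention: banned ids arrive via SamplingParams.extra_args.
            banned = (params.extra_args or {}).get("banned_token_ids")
            if not banned:
                return None  # this request does not use the processor

            # Two-argument form: invoked as fn(output_token_ids, logits).
            def _ban(output_token_ids: list[int],
                     logits: torch.Tensor) -> torch.Tensor:
                logits[list(banned)] = float("-inf")
                return logits

            return _ban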
However, to implement custom constructor behavior - + especially any logic which operates on or stores `vllm_config`, `device`, + or `is_pin_memory` - `self.__init__(vllm_config, device, is_pin_memory)` + must be overridden and the override must call + `super().__init__(vllm_config, device, is_pin_memory)` + """ + + def __init__(self, vllm_config: "VllmConfig", device: torch.device, + is_pin_memory: bool): + """Subclass must invoke + `super().__init__(vllm_config, device, is_pin_memory)`. + + Subclass constructor may find it useful to utilize the `vllm_config`, + `device` and `is_pin_memory` argument. However regardless of whether + these arguments are used, the vLLM logits processor interface requires + all three arguments to be present. + """ + + # Map req index -> logits processor state + # + # State representation is a partial[Tensor] comprising a request-level + # logits processor with the output token ids argument and (if required) + # the prompt token ids argument pre-populated + # + # Note that the partial carries a *reference* to output token ids, and + # will thus always operate on the list as it is currently, not as it + # was when the partial was created. + self.req_info: dict[int, partial[torch.Tensor]] = {} + + @abstractmethod + def new_req_logits_processor( + self, + params: SamplingParams, + ) -> Optional[RequestLogitsProcessor]: + """Consume request info; return a per-request logits processor. + + Return None if logits processor does not need to be applied to request + + Args: + params: request sampling params + + Returns: + None if logits processor should not be applied to request; otherwise + returns a `RequestLogitsProcessor` instance + + """ + raise NotImplementedError + + def _new_state( + self, + params: SamplingParams, + prompt_ids: list[int], + output_ids: list[int], + ) -> Optional[partial[torch.Tensor]]: + """Return state representation for new request + + Returns None if logits processor is not applicable to request + + Args: + params: request sampling params + prompt_ids: request prompt token ids + output_ids: decoded tokens so far for this request + + Returns: + logits processor partial[Tensor] or None + + """ + if req_lp := self.new_req_logits_processor(params): + args = [prompt_ids, output_ids] if (len( + inspect.signature(req_lp).parameters) == 3) else [output_ids] + return partial(req_lp, *args) + return None + + def update_state(self, batch_update: Optional[BatchUpdate]): + process_dict_updates( + self.req_info, + batch_update, + self._new_state, + ) + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + if self.req_info: + # Apply per-request logits processors to corresponding rows of + # logits tensor + for req_idx, req_lp in self.req_info.items(): + req_logits = logits[req_idx] + new_logits = req_lp(req_logits) + if new_logits is not req_logits: + # Modify logits tensor row in-place if necessary + logits[req_idx] = new_logits + return logits + + __all__ = [ "LogitsProcessor", "LogitBiasLogitsProcessor", "MinPLogitsProcessor", "MinTokensLogitsProcessor", "BatchUpdate", "BatchUpdateBuilder", "MoveDirectionality", "LogitsProcessors", "build_logitsprocs", - "STR_POOLING_REJECTS_LOGITSPROCS", "LOGITSPROCS_GROUP" + "STR_POOLING_REJECTS_LOGITSPROCS", "LOGITSPROCS_GROUP", + "AdapterLogitsProcessor" ] diff --git a/vllm/v1/sample/logits_processor/interface.py b/vllm/v1/sample/logits_processor/interface.py index 683fc7c00dfb2..04027359909a6 100644 --- a/vllm/v1/sample/logits_processor/interface.py +++ b/vllm/v1/sample/logits_processor/interface.py @@ -21,6 
+21,9 @@ class MoveDirectionality(Enum): SWAP = auto() +# Batch indices of any removed requests. +RemovedRequest = int + # (index, params, prompt_tok_ids, output_tok_ids) tuples for new # requests added to the batch. AddedRequest = tuple[int, SamplingParams, list[int], list[int]] @@ -29,9 +32,6 @@ AddedRequest = tuple[int, SamplingParams, list[int], list[int]] # one-way moves or two-way swaps of requests in batch MovedRequest = tuple[int, int, MoveDirectionality] -# Batch indices of any removed requests. -RemovedRequest = int - @dataclass(frozen=True) class BatchUpdate: diff --git a/vllm/v1/sample/logits_processor/state.py b/vllm/v1/sample/logits_processor/state.py index 31cece58c7db5..0a1196559d3e3 100644 --- a/vllm/v1/sample/logits_processor/state.py +++ b/vllm/v1/sample/logits_processor/state.py @@ -36,18 +36,18 @@ class BatchUpdateBuilder: _removed: list[RemovedRequest] _is_removed_sorted: bool - moved: list[MovedRequest] added: list[AddedRequest] + moved: list[MovedRequest] def __init__( self, removed: Optional[list[RemovedRequest]] = None, - moved: Optional[list[MovedRequest]] = None, added: Optional[list[AddedRequest]] = None, + moved: Optional[list[MovedRequest]] = None, ) -> None: self._removed = removed or [] - self.moved = moved or [] self.added = added or [] + self.moved = moved or [] self._is_removed_sorted = False # Used to track changes in the pooling case @@ -107,8 +107,8 @@ class BatchUpdateBuilder: """Returns True if there were any changes to the batch.""" self._is_removed_sorted = False self._removed.clear() - self.moved.clear() self.added.clear() + self.moved.clear() batch_changed = self.batch_changed self.batch_changed = False return batch_changed diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 7bd4a5a380ac0..cc5653b10ec1d 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -73,10 +73,8 @@ class TopKTopPSampler(nn.Module): self.forward = self.forward_native else: self.forward = self.forward_native - if current_platform.is_tpu(): - self.apply_top_k_top_p = apply_top_k_top_p_tpu - else: - self.apply_top_k_top_p = apply_top_k_top_p + + self.apply_top_k_top_p = apply_top_k_top_p def forward_native( self, @@ -125,53 +123,6 @@ class TopKTopPSampler(nn.Module): return flashinfer_sample(logits.contiguous(), k, p, generators), None -def apply_top_k_top_p_tpu( - logits: torch.Tensor, - k: torch.Tensor, - p: torch.Tensor, -) -> torch.Tensor: - """ - Apply top-k and top-p optimized for TPU. - - This algorithm avoids using torch.scatter which is extremely slow on TPU. - This is achieved by finding a "cut-off" element in the original logit, and - after thresholding the logit using this cut-off, the remaining elements - shall constitute the top-p set. - - Note: in the case of tie (i.e. multipple cut-off elements present in the - logit), all tie elements are included in the top-p set. In other words, - this function does not break ties. Instead, these tie tokens have equal - chance of being chosen during final sampling, so we can consider the tie - being broken then. - """ - probs = logits.softmax(dim=-1) - probs_sort, _ = probs.sort(dim=-1, descending=False) - - if k is not None: - top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, ) - top_k_count = top_k_count.unsqueeze(dim=1) - top_k_cutoff = probs_sort.gather(-1, top_k_count) - - # Make sure the no top-k rows are no-op. 
- no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1) - top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf")) - - elements_to_discard = probs < top_k_cutoff - logits.masked_fill_(elements_to_discard, -float("inf")) - - if p is not None: - cumprob = torch.cumsum(probs_sort, dim=-1) - top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1) - top_p_mask[:, -1] = False # at least one - - top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1) - top_p_cutoff = probs_sort.gather(-1, top_p_count) - elements_to_discard = probs < top_p_cutoff - logits.masked_fill_(elements_to_discard, -float("inf")) - - return logits - - def apply_top_k_top_p( logits: torch.Tensor, k: Optional[torch.Tensor], diff --git a/vllm/v1/sample/tpu/sampler.py b/vllm/v1/sample/tpu/sampler.py index e84136e3a6d07..17b83a4ba074c 100644 --- a/vllm/v1/sample/tpu/sampler.py +++ b/vllm/v1/sample/tpu/sampler.py @@ -2,11 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Sampler layer implementing TPU supported operations.""" +from typing import Optional + import torch import torch.nn as nn from vllm.v1.outputs import LogprobsTensors, SamplerOutput -from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata _SAMPLING_EPS = 1e-5 @@ -17,7 +18,6 @@ class Sampler(nn.Module): def __init__(self): # TODO(houseroad): Add support for logprobs_mode. super().__init__() - self.topk_topp_sampler = TopKTopPSampler() def forward( self, @@ -65,13 +65,17 @@ class Sampler(nn.Module): logits = self.apply_min_p(logits, sampling_metadata.min_p) # Apply top_k and/or top_p. - random_sampled, _ = self.topk_topp_sampler( + logits = apply_top_k_top_p( logits, - sampling_metadata.generators, sampling_metadata.top_k, sampling_metadata.top_p, ) + # Random sample. + probs = logits.softmax(dim=-1, dtype=torch.float32) + random_sampled = self.random_sample(probs, + sampling_metadata.generators) + sampled = torch.where(sampling_metadata.temperature < _SAMPLING_EPS, greedy_sampled, random_sampled) return sampled @@ -144,3 +148,66 @@ class Sampler(nn.Module): # Apply mask using boolean indexing (xla friendly) logits.masked_fill_(~valid_token_mask, -float("inf")) return logits + + def random_sample( + self, + probs: torch.Tensor, + generators: dict[int, torch.Generator], + ) -> torch.Tensor: + q = torch.empty_like(probs) + # NOTE(woosuk): To batch-process the requests without their own seeds, + # which is the common case, we first assume that every request does + # not have its own seed. Then, we overwrite the values for the requests + # that have their own seeds. + q.exponential_() + if generators: + for i, generator in generators.items(): + q[i].exponential_(generator=generator) + return probs.div_(q).argmax(dim=-1).view(-1) + + +def apply_top_k_top_p( + logits: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], +) -> torch.Tensor: + """ + Apply top-k and top-p optimized for TPU. + + This algorithm avoids using torch.scatter which is extremely slow on TPU. + This is achieved by finding a "cut-off" element in the original logit, and + after thresholding the logit using this cut-off, the remaining elements + shall constitute the top-p set. + + Note: in the case of tie (i.e. multipple cut-off elements present in the + logit), all tie elements are included in the top-p set. In other words, + this function does not break ties. Instead, these tie tokens have equal + chance of being chosen during final sampling, so we can consider the tie + being broken then. 
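[Editor's note] A toy worked example (values chosen for illustration, not taken from the patch) of the cut-off idea described above: the k-th largest probability becomes a threshold and everything at or above it survives, so ties at the threshold are kept instead of being broken with a scatter:

    import torch

    logits = torch.tensor([[2.0, 1.0, 1.0, 0.0]])
    k = torch.tensor([2])
    probs = logits.softmax(dim=-1)                        # ~[0.53, 0.20, 0.20, 0.07]
    probs_sort, _ = probs.sort(dim=-1, descending=False)
    top_k_count = (probs_sort.size(1) - k.to(torch.long)).unsqueeze(dim=1)
    top_k_cutoff = probs_sort.gather(-1, top_k_count)     # the k-th largest prob
    kept = probs >= top_k_cutoff
    print(kept)  # tensor([[ True,  True,  True, False]]) -> both tied values survive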
+ """ + probs = logits.softmax(dim=-1) + probs_sort, _ = probs.sort(dim=-1, descending=False) + + if k is not None: + top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, ) + top_k_count = top_k_count.unsqueeze(dim=1) + top_k_cutoff = probs_sort.gather(-1, top_k_count) + + # Make sure the no top-k rows are no-op. + no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1) + top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf")) + + elements_to_discard = probs < top_k_cutoff + logits.masked_fill_(elements_to_discard, -float("inf")) + + if p is not None: + cumprob = torch.cumsum(probs_sort, dim=-1) + top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1) + top_p_mask[:, -1] = False # at least one + + top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1) + top_p_cutoff = probs_sort.gather(-1, top_p_count) + elements_to_discard = probs < top_p_cutoff + logits.masked_fill_(elements_to_discard, -float("inf")) + + return logits diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index bf25c91d8390c..2a178ddf48777 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -27,6 +27,10 @@ from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.utils import CpuGpuBuffer +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.ubatching import dbo_current_ubatch_id logger = init_logger(__name__) @@ -93,20 +97,26 @@ class EagleProposer: dtype=self.dtype, device=device) + # We need +1 here because the arange is used to set query_start_loc, + # which has one more element than batch_size. max_batch_size = vllm_config.scheduler_config.max_num_seqs - self.arange = torch.arange( - # We need +1 here because the arange is used to set query_start_loc, - # which has one more element than batch_size. - max_batch_size + 1, - device=device, - dtype=torch.int32, - ) + max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens) + self.arange = torch.arange(max_num_slots_for_arange, + device=device, + dtype=torch.int32) self.inputs_embeds = torch.zeros( (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device) + self.backup_next_token_ids = CpuGpuBuffer( + max_batch_size, + dtype=torch.int32, + pin_memory=is_pin_memory_available(), + device=device, + with_numpy=True) + # Determine allowed attention backends once during initialization. self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...] 
if current_platform.is_rocm(): @@ -155,13 +165,16 @@ class EagleProposer: target_hidden_states: torch.Tensor, # [batch_size] next_token_ids: torch.Tensor, + last_token_indices: Optional[torch.Tensor], common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, mm_embeds: Optional[list[torch.Tensor]] = None, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] - last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 + + if last_token_indices is None: + last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 if self.method == "eagle3": assert isinstance(self.model, Eagle3LlamaForCausalLM) @@ -179,9 +192,11 @@ class EagleProposer: assert self.runner is not None # FIXME: need to consider multiple kv_cache_groups - attn_metadata = self.runner.attn_groups[0][0].metadata_builder\ - .build_for_drafting(common_attn_metadata=common_attn_metadata, - draft_index=0) + ubatch_id = dbo_current_ubatch_id() + attn_metadata_builder = \ + self.runner.attn_groups[0][0].metadata_builders[ubatch_id] + attn_metadata = attn_metadata_builder.build_for_drafting( + common_attn_metadata=common_attn_metadata, draft_index=0) # At this moment, we assume all eagle layers belong to the same KV # cache group, thus using the same attention metadata. @@ -218,13 +233,19 @@ class EagleProposer: hidden_states=self.hidden_states[:num_input_tokens], inputs_embeds=inputs_embeds, ) - if self.method in ("deepseek_mtp", "ernie_mtp"): + if self.method in ("deepseek_mtp", "ernie_mtp", "qwen3_next_mtp"): last_hidden_states = ret_hidden_states hidden_states = last_hidden_states else: last_hidden_states, hidden_states = ret_hidden_states sample_hidden_states = last_hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) + + # Early exit if there is only one draft token to be generated. + if self.num_speculative_tokens == 1: + draft_token_ids = logits.argmax(dim=-1) + return draft_token_ids.view(-1, 1) + positions = target_positions[last_token_indices] hidden_states = hidden_states[last_token_indices] @@ -242,15 +263,12 @@ class EagleProposer: draft_token_ids = logits.argmax(dim=-1) - # Early exit if there is only one draft token to be generated. - if self.num_speculative_tokens == 1: - # [batch_size, 1] - return draft_token_ids.view(-1, 1) - - # TODO: Currently, MTP module released by deepseek only has - # one layer. Adapt this code to support multiple layers once - # there's a multi-layer MTP module. - assert isinstance(attn_metadata, self.allowed_attn_types) + if not isinstance(attn_metadata, self.allowed_attn_types): + raise ValueError( + f"Unsupported attention metadata type for speculative " + "decoding with num_speculative_tokens > 1: " + f"{type(attn_metadata)}. Supported types are: " + f"{self.allowed_attn_types}") # Generate the remaining draft tokens. 
draft_token_ids_list = [draft_token_ids] @@ -260,10 +278,13 @@ class EagleProposer: input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) else: input_batch_size = batch_size - attn_metadata.num_actual_tokens = batch_size - attn_metadata.max_query_len = 1 - attn_metadata.query_start_loc = self.arange[:batch_size + 1] - for _ in range(self.num_speculative_tokens - 1): + + common_attn_metadata.num_actual_tokens = batch_size + common_attn_metadata.max_query_len = 1 + common_attn_metadata.query_start_loc = self.arange[:batch_size + 1] + common_attn_metadata.query_start_loc_cpu = torch.from_numpy( + self.token_arange_np[:batch_size + 1]).clone() + for token_index in range(self.num_speculative_tokens - 1): # Update the inputs. # cast to int32 is crucial when eagle model is compiled. # tensor.argmax() returns int64 by default. @@ -283,27 +304,38 @@ class EagleProposer: positions) # Increment the sequence lengths. - attn_metadata.max_seq_len += 1 - attn_metadata.seq_lens += 1 - # Consider max model length. - attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, - self.max_model_len) + common_attn_metadata.seq_lens += 1 + common_attn_metadata.seq_lens_cpu += 1 # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. - attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) + common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, + 1) + + common_attn_metadata.num_computed_tokens_cpu = \ + common_attn_metadata.seq_lens_cpu - 1 # Compute the slot mapping. block_numbers = clamped_positions // self.block_size - block_ids = attn_metadata.block_table.gather( + block_ids = common_attn_metadata.block_table_tensor.gather( dim=1, index=block_numbers.view(-1, 1)) block_ids = block_ids.view(-1) - attn_metadata.slot_mapping = (block_ids * self.block_size + - clamped_positions % self.block_size) + common_attn_metadata.slot_mapping = ( + block_ids * self.block_size + + clamped_positions % self.block_size) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. 
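[Editor's note] The slot-mapping arithmetic used in this drafting loop maps a token position to a physical cache slot: block_table[p // block_size] picks the block and p % block_size the offset inside it. A tiny illustration with made-up block ids:

    import torch

    block_size = 4
    block_table_row = torch.tensor([7, 3, 9])     # physical blocks of one request
    positions = torch.tensor([0, 5, 10])          # token positions being written
    block_ids = block_table_row[positions // block_size]
    slot_mapping = block_ids * block_size + positions % block_size
    print(slot_mapping)  # tensor([28, 13, 38])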
- attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len, - PADDING_SLOT_ID) + common_attn_metadata.slot_mapping.masked_fill_( + exceeds_max_model_len, PADDING_SLOT_ID) + + # Rebuild attention metadata + attn_metadata_builder = \ + self.runner.attn_groups[0][0].metadata_builders[ubatch_id] + attn_metadata = attn_metadata_builder\ + .build_for_drafting(common_attn_metadata=common_attn_metadata, + draft_index=token_index + 1) + for layer_name in self.attn_layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata # copy inputs to buffer for cudagraph self.input_ids[:batch_size] = input_ids @@ -322,12 +354,18 @@ class EagleProposer: with set_forward_context(per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size): - last_hidden_states, hidden_states = self.model( + ret_hidden_states = self.model( input_ids=input_ids, positions=self.positions[:input_batch_size], hidden_states=self.hidden_states[:input_batch_size], inputs_embeds=inputs_embeds, ) + if self.method in ("deepseek_mtp", "ernie_mtp", + "qwen3_next_mtp"): + last_hidden_states = ret_hidden_states + hidden_states = ret_hidden_states + else: + last_hidden_states, hidden_states = ret_hidden_states hidden_states = hidden_states[:batch_size] logits = self.model.compute_logits(last_hidden_states[:batch_size], None) @@ -338,6 +376,158 @@ class EagleProposer: draft_token_ids = torch.stack(draft_token_ids_list, dim=1) return draft_token_ids + def prepare_next_token_ids_cpu( + self, sampled_token_ids: list[list[int]], + requests: dict[str, + CachedRequestState], gpu_input_batch: InputBatch, + num_scheduled_tokens: dict[str, int]) -> torch.Tensor: + """ + This function is used to prepare the inputs for speculative decoding. + It calculates the next token ids for each request based on the sampled + token ids from the CPU. If a request has no sampled token ids (e.g., + during the initial decoding steps), it falls back to using the request + state to get the next token id. + """ + req_ids = gpu_input_batch.req_ids + next_token_ids: list[int] = [] + for i, token_ids in enumerate(sampled_token_ids): + if token_ids: + # Common case. + next_token_id = token_ids[-1] + else: + # Partial prefill (rare case). + # Get the next token id from the request state. + req_id = req_ids[i] + req_state = requests[req_id] + seq_len = (req_state.num_computed_tokens + + num_scheduled_tokens[req_id]) + next_token_id = req_state.get_token_id(seq_len) + next_token_ids.append(next_token_id) + next_token_ids = torch.tensor(next_token_ids, + dtype=torch.int32, + device=self.input_ids.device) + return next_token_ids + + def prepare_next_token_ids_padded(self, + common_attn_metadata: CommonAttentionMetadata, + sampled_token_ids: torch.Tensor, + requests: dict[str, CachedRequestState], + gpu_input_batch: InputBatch, + discard_request_indices: torch.Tensor, + num_discarded_requests: int) -> \ + tuple[torch.Tensor, torch.Tensor]: + """ + This function is used to prepare the inputs for speculative decoding. + It calculates the next token ids and the number of valid sampled tokens + for each request, considering the "discarded" requests whose next token + is not sampled and comes from `request.get_token_id()` instead. + It also accounts for the rejected tokens in `sampled_token_ids`. + This function must use device functions to operate on the inputs, and + should not introduce any blocking CPU-GPU synchronization. 
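[Editor's note] The selection this method performs can be pictured with a small CPU-side example (the real code below stays on the GPU and avoids any sync): take the last accepted token of each request, and fall back to a precomputed backup token when a request accepted none. Values here are made up, and the vocab-size check used in the real code is omitted for brevity:

    import torch

    sampled = torch.tensor([[5, 6, -1],       # request 0: 2 tokens accepted
                            [-1, -1, -1],     # request 1: all rejected/discarded
                            [7, 8, 9]])       # request 2: all 3 accepted
    backup = torch.tensor([100, 101, 102])    # from request.get_token_id(seq_len)
    valid_counts = (sampled != -1).sum(dim=1)
    last_idx = (valid_counts - 1).clamp(min=0)
    selected = sampled.gather(1, last_idx.unsqueeze(1)).squeeze(1)
    next_tokens = torch.where(valid_counts > 0, selected, backup)
    print(next_tokens)  # tensor([  6, 101,   9])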
+ """ + # TODO(Ben): Combine this into a custom fused kernel + + # Precompute get_token_id for when there is no valid next token + num_reqs = gpu_input_batch.num_reqs + self.backup_next_token_ids.np[:num_reqs] = np.array([ + requests[gpu_input_batch.req_ids[i]].get_token_id( + common_attn_metadata.seq_lens_cpu[i].item()) + for i in range(num_reqs) + ]) + self.backup_next_token_ids.copy_to_gpu(num_reqs) + + # Mask out the sampled tokens indices that should not be sampled. + discard_sampled_tokens_req_indices = \ + discard_request_indices[:num_discarded_requests] + + valid_sampled_token_ids_gpu = sampled_token_ids.clone() + valid_sampled_token_ids_gpu.index_fill_( + 0, discard_sampled_tokens_req_indices, -1) + + # Generate a mask for all valid tokens within those requests + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, + dtype=torch.bool) + else: + valid_mask = ( + (valid_sampled_token_ids_gpu != -1) & + (valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size)) + + # Count the number of valid tokens in each request + valid_sampled_tokens_count = valid_mask.sum(dim=1) + + # Get the rightmost valid index per row + last_valid_indices = valid_sampled_tokens_count - 1 + last_valid_indices_safe = torch.clamp(last_valid_indices, min=0) + + # Get last valid token from each row + # (assume undefined state where there is no valid token) + selected_tokens = torch.gather( + valid_sampled_token_ids_gpu, 1, + last_valid_indices_safe.unsqueeze(1)).squeeze(1) + + # Use last token if valid, pre-computed backup if not + batch_size = valid_sampled_token_ids_gpu.shape[0] + next_token_ids = torch.where( + last_valid_indices != -1, selected_tokens, + self.backup_next_token_ids.gpu[:batch_size]) + + return next_token_ids, valid_sampled_tokens_count + + def prepare_inputs_padded(self, + common_attn_metadata: CommonAttentionMetadata, + spec_decode_metadata: SpecDecodeMetadata, + valid_sampled_tokens_count: torch.Tensor) -> \ + tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]: + """ + This function is used to prepare the inputs for speculative decoding + It updates the common_attn_metadata for speculative decoding, + but does not consider the rejected tokens. Instead, all tokens + are included as inputs to the speculator, with the rejected tokens + used as padding and filtered out later by `token_indices_to_sample`. + No blocking CPU operations should be introduced in this function. + """ + num_draft_tokens_gpu = torch.cat([ + spec_decode_metadata.cu_num_draft_tokens[0:1], + spec_decode_metadata.cu_num_draft_tokens[1:] - + spec_decode_metadata.cu_num_draft_tokens[:-1] + ]) + + num_rejected_tokens_gpu = torch.where( + num_draft_tokens_gpu > 0, + num_draft_tokens_gpu + 1 - valid_sampled_tokens_count, + torch.zeros_like(num_draft_tokens_gpu)) + + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + + new_query_len_per_req = (query_start_loc_cpu[1:] - + query_start_loc_cpu[:-1]) + + total_num_tokens = query_start_loc_cpu[-1].item() + token_indices = self.arange[:total_num_tokens] + + spec_common_attn_metadata = CommonAttentionMetadata( + query_start_loc=common_attn_metadata.query_start_loc, + seq_lens=common_attn_metadata.seq_lens, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens_cpu=common_attn_metadata.seq_lens_cpu, + num_computed_tokens_cpu=common_attn_metadata. 
+ num_computed_tokens_cpu, + num_reqs=common_attn_metadata.num_reqs, + num_actual_tokens=total_num_tokens, + max_query_len=new_query_len_per_req.max().item(), + max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(), + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping[token_indices], + causal=True, + ) + + token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 \ + - num_rejected_tokens_gpu + + return spec_common_attn_metadata, token_indices, token_indices_to_sample + def propose_tree( self, batch_size: int, @@ -349,8 +539,9 @@ class EagleProposer: hidden_states: torch.Tensor, common_attn_metadata: CommonAttentionMetadata, ) -> list[torch.Tensor]: + ubatch_id = dbo_current_ubatch_id() tree_attn_metadata_builder = \ - self.runner.attn_groups[0][0].metadata_builder + self.runner.attn_groups[0][0].metadata_builders[ubatch_id] assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder) @@ -510,11 +701,11 @@ class EagleProposer: def prepare_inputs( self, common_attn_metadata: CommonAttentionMetadata, - # [batch_size] - num_rejected_tokens: torch.Tensor + sampled_token_ids: list[list[int]], + num_draft_tokens: list[int], ) -> tuple[CommonAttentionMetadata, torch.Tensor]: """ - This function is used to prepare the inputs for the spec decode. + This function is used to prepare the inputs for speculative decoding. It updates to the common_attn_metadata to account for the rejected tokens (and newly sampled tokens). It also returns the token indices of the tokens that should be fed to the speculator. @@ -535,6 +726,13 @@ class EagleProposer: # q1, q1 + 1, ..., q1 + q2 - n2 - 1, # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] + num_rejected_tokens = [ + n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 + for i, n in enumerate(num_draft_tokens) + ] + num_rejected_tokens = torch.tensor(num_rejected_tokens, + dtype=torch.int32) + device = common_attn_metadata.query_start_loc.device query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \ diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py index b4bc3058c570a..282e6f65e7abe 100644 --- a/vllm/v1/spec_decode/metrics.py +++ b/vllm/v1/spec_decode/metrics.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import time from dataclasses import dataclass, field from typing import Optional @@ -58,6 +59,7 @@ class SpecDecodingLogging: self.num_draft_tokens: list[int] = [] self.num_accepted_tokens: list[int] = [] self.accepted_tokens_per_pos_lists: list[list[int]] = [] + self.last_log_time = time.monotonic() def observe(self, spec_decoding_stats: SpecDecodingStats): self.num_drafts.append(spec_decoding_stats.num_drafts) @@ -73,6 +75,13 @@ class SpecDecodingLogging: num_drafts = np.sum(self.num_drafts) num_draft_tokens = np.sum(self.num_draft_tokens) num_accepted_tokens = np.sum(self.num_accepted_tokens) + draft_throughput = 0 + accepted_throughput = 0 + + elapsed_time = time.monotonic() - self.last_log_time + if elapsed_time > 0: + draft_throughput = num_draft_tokens / elapsed_time + accepted_throughput = num_accepted_tokens / elapsed_time draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens * 100 if num_draft_tokens > 0 else float("nan")) @@ -86,16 +95,20 @@ class SpecDecodingLogging: log_fn( "SpecDecoding metrics: " - "Draft acceptance rate: %.1f%%, " "Mean acceptance length: %.2f, " + "Accepted 
throughput: %.2f tokens/s, " + "Drafted throughput: %.2f tokens/s, " "Accepted: %d tokens, " "Drafted: %d tokens, " - "Per-position acceptance rate: %s", - draft_acceptance_rate, + "Per-position acceptance rate: %s, " + "Avg Draft acceptance rate: %.1f%%", mean_acceptance_length, + accepted_throughput, + draft_throughput, num_accepted_tokens, num_draft_tokens, rates_str, + draft_acceptance_rate, ) self.reset() @@ -127,27 +140,32 @@ class SpecDecodingProm: self, speculative_config: Optional[SpeculativeConfig], labelnames: list[str], - labelvalues: list[str], + per_engine_labelvalues: dict[int, list[str]], ): self.spec_decoding_enabled = speculative_config is not None if not self.spec_decoding_enabled: return - self.counter_spec_decode_num_drafts = \ - self._counter_cls( - name="vllm:spec_decode_num_drafts", - documentation="Number of spec decoding drafts.", - labelnames=labelnames).labels(*labelvalues) - self.counter_spec_decode_num_draft_tokens = \ - self._counter_cls( - name="vllm:spec_decode_num_draft_tokens", - documentation="Number of draft tokens.", - labelnames=labelnames,).labels(*labelvalues) - self.counter_spec_decode_num_accepted_tokens = \ - self._counter_cls( - name="vllm:spec_decode_num_accepted_tokens", - documentation="Number of accepted tokens.", - labelnames=labelnames).labels(*labelvalues) + counter_drafts = self._counter_cls( + name="vllm:spec_decode_num_drafts", + documentation="Number of spec decoding drafts.", + labelnames=labelnames) + self.counter_spec_decode_num_drafts = make_per_engine( + counter_drafts, per_engine_labelvalues) + + counter_draft_tokens = self._counter_cls( + name="vllm:spec_decode_num_draft_tokens", + documentation="Number of draft tokens.", + labelnames=labelnames) + self.counter_spec_decode_num_draft_tokens = make_per_engine( + counter_draft_tokens, per_engine_labelvalues) + + counter_accepted_tokens = self._counter_cls( + name="vllm:spec_decode_num_accepted_tokens", + documentation="Number of accepted tokens.", + labelnames=labelnames) + self.counter_spec_decode_num_accepted_tokens = make_per_engine( + counter_accepted_tokens, per_engine_labelvalues) assert speculative_config is not None num_spec_tokens = (speculative_config.num_speculative_tokens @@ -158,21 +176,36 @@ class SpecDecodingProm: documentation="Accepted tokens per draft position.", labelnames=pos_labelnames, ) - self.counter_spec_decode_num_accepted_tokens_per_pos: list[ - prometheus_client.Counter] = [] - for pos in range(num_spec_tokens): - pos_labelvalues = labelvalues + [str(pos)] - self.counter_spec_decode_num_accepted_tokens_per_pos.append( - base_counter.labels(*pos_labelvalues)) + self.counter_spec_decode_num_accepted_tokens_per_pos: dict[ + int, list[prometheus_client.Counter]] = { + idx: [ + base_counter.labels(*lv, str(pos)) + for pos in range(num_spec_tokens) + ] + for idx, lv in per_engine_labelvalues.items() + } - def observe(self, spec_decoding_stats: SpecDecodingStats): + def observe(self, + spec_decoding_stats: SpecDecodingStats, + engine_idx: int = 0): if not self.spec_decoding_enabled: return - self.counter_spec_decode_num_drafts.inc(spec_decoding_stats.num_drafts) - self.counter_spec_decode_num_draft_tokens.inc( + self.counter_spec_decode_num_drafts[engine_idx].inc( + spec_decoding_stats.num_drafts) + self.counter_spec_decode_num_draft_tokens[engine_idx].inc( spec_decoding_stats.num_draft_tokens) - self.counter_spec_decode_num_accepted_tokens.inc( + self.counter_spec_decode_num_accepted_tokens[engine_idx].inc( spec_decoding_stats.num_accepted_tokens) for pos, 
counter in enumerate( - self.counter_spec_decode_num_accepted_tokens_per_pos): + self. + counter_spec_decode_num_accepted_tokens_per_pos[engine_idx]): counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos]) + + +def make_per_engine(counter: prometheus_client.Counter, + per_engine_labelvalues: dict[int, list[str]]): + """Create a counter for each label value.""" + return { + idx: counter.labels(*labelvalues) + for idx, labelvalues in per_engine_labelvalues.items() + } diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 57854cc112041..13c33d3edf141 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Optional from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.reasoning import ReasoningParserManager -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs from vllm.utils import LazyLoader from vllm.v1.structured_output.backend_guidance import GuidanceBackend from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, @@ -60,15 +60,12 @@ class StructuredOutputManager: max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) self.executor = ThreadPoolExecutor(max_workers=max_workers) self.tokenizer = init_tokenizer_from_configs( - model_config=self.vllm_config.model_config, - scheduler_config=self.vllm_config.scheduler_config, - lora_config=self.vllm_config.lora_config, - ).get_lora_tokenizer(None) - reasoning_backend = \ - self.vllm_config.decoding_config.reasoning_backend - if reasoning_backend: + model_config=self.vllm_config.model_config) + reasoning_parser = \ + self.vllm_config.structured_outputs_config.reasoning_parser + if reasoning_parser: reasoner_cls = ReasoningParserManager.get_reasoning_parser( - reasoning_backend) + reasoning_parser) self.reasoner = reasoner_cls(tokenizer=self.tokenizer) def grammar_init(self, request: Request) -> None: @@ -77,15 +74,16 @@ class StructuredOutputManager: if TYPE_CHECKING: assert request.sampling_params is not None and \ - request.sampling_params.guided_decoding is not None + request.sampling_params.structured_outputs is not None # Initialize the backend the first time it is needed. # # NOTE: We only support a single backend. We do NOT support different # backends on a per-request basis in V1 (for now, anyway...). 
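[Editor's note] The make_per_engine helper added above pre-binds one labelled child counter per engine index so that observe() can look the series up by engine. A hypothetical usage sketch with plain prometheus_client (label values are illustrative):

    import prometheus_client

    from vllm.v1.spec_decode.metrics import make_per_engine

    counter = prometheus_client.Counter(
        "vllm:spec_decode_num_drafts",
        "Number of spec decoding drafts.",
        labelnames=["model_name", "engine"])
    per_engine_labelvalues = {0: ["my-model", "0"], 1: ["my-model", "1"]}
    per_engine = make_per_engine(counter, per_engine_labelvalues)
    per_engine[1].inc(3)  # bumps the series {model_name="my-model", engine="1"}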
+ # _backend is set in Processor._validate_structured_output if self.backend is None: assert request.sampling_params is not None - backend = request.sampling_params.guided_decoding.backend + backend = request.sampling_params.structured_outputs._backend vocab_size = self.vllm_config.model_config.get_vocab_size() if backend == "xgrammar": self.backend = XgrammarBackend( diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py index 02e7fc33f517d..e06ab6377de3a 100644 --- a/vllm/v1/structured_output/backend_guidance.py +++ b/vllm/v1/structured_output/backend_guidance.py @@ -60,9 +60,9 @@ class GuidanceBackend(StructuredOutputBackend): def __post_init__(self): self.disable_any_whitespace = \ - self.vllm_config.decoding_config.disable_any_whitespace + self.vllm_config.structured_outputs_config.disable_any_whitespace self.disable_additional_properties = \ - self.vllm_config.decoding_config.disable_additional_properties + self.vllm_config.structured_outputs_config.disable_additional_properties self.ll_tokenizer = llguidance_hf.from_tokenizer( self.tokenizer, self.vocab_size) diff --git a/vllm/v1/structured_output/backend_lm_format_enforcer.py b/vllm/v1/structured_output/backend_lm_format_enforcer.py index 2279a1c8c8a00..465b2428f8938 100644 --- a/vllm/v1/structured_output/backend_lm_format_enforcer.py +++ b/vllm/v1/structured_output/backend_lm_format_enforcer.py @@ -138,30 +138,30 @@ class LMFormatEnforcerBackend(StructuredOutputBackend): def validate_structured_output_request_lm_format_enforcer( params: SamplingParams): - if params.guided_decoding is None: + if params.structured_outputs is None: return - gd_params = params.guided_decoding + so_params = params.structured_outputs - if gd_params.regex: + if so_params.regex: return - elif gd_params.json: - if isinstance(gd_params.json, str): + elif so_params.json: + if isinstance(so_params.json, str): try: # make sure schema is valid json - json.loads(gd_params.json) + json.loads(so_params.json) except json.JSONDecodeError as e: raise ValueError("Invalid JSON grammar specification.") from e else: try: - json.dumps(gd_params.json) + json.dumps(so_params.json) except Exception as e: raise ValueError( - f"Error serializing guided decoding jsonschema: {e}" + f"Error serializing structured outputs jsonschema: {e}" ) from e return - elif gd_params.choice: + elif so_params.choice: return - elif gd_params.grammar: - raise ValueError("LM Format Enforcer guided decoding backend " + elif so_params.grammar: + raise ValueError("LM Format Enforcer structured outputs backend " "does not support grammar specifications") diff --git a/vllm/v1/structured_output/backend_outlines.py b/vllm/v1/structured_output/backend_outlines.py index 572e4984480fa..e5e638a6ad764 100644 --- a/vllm/v1/structured_output/backend_outlines.py +++ b/vllm/v1/structured_output/backend_outlines.py @@ -158,36 +158,36 @@ class OutlinesGrammar(StructuredOutputGrammar): def validate_structured_output_request_outlines(params: SamplingParams): - if params.guided_decoding is None: + if params.structured_outputs is None: return - gd_params = params.guided_decoding + so_params = params.structured_outputs - if gd_params.regex: - validate_regex_is_buildable(gd_params.regex) - elif gd_params.json: - if isinstance(gd_params.json, str): + if so_params.regex: + validate_regex_is_buildable(so_params.regex) + elif so_params.json: + if isinstance(so_params.json, str): try: # make sure schema is valid json - json.loads(gd_params.json) - schema = gd_params.json + 
json.loads(so_params.json) + schema = so_params.json except json.JSONDecodeError as e: raise ValueError("Invalid JSON grammar specification.") from e else: try: - schema = json.dumps(gd_params.json) + schema = json.dumps(so_params.json) except Exception as e: raise ValueError( - f"Error serializing guided decoding jsonschema: {e}" + f"Error serializing structured outputs jsonschema: {e}" ) from e pattern = json_schema.build_regex_from_schema(schema) validate_regex_is_buildable(pattern) - elif gd_params.choice: - choices = [regex_escape(str(choice)) for choice in gd_params.choice] + elif so_params.choice: + choices = [regex_escape(str(choice)) for choice in so_params.choice] regex = "(" + "|".join(choices) + ")" validate_regex_is_buildable(regex) - elif gd_params.grammar: - raise ValueError("Outlines guided decoding backend " + elif so_params.grammar: + raise ValueError("Outlines structured outputs backend " "does not support grammar specifications") @@ -306,7 +306,7 @@ def validate_regex_is_buildable(pattern: str) -> None: _check_unsupported(parsed) except ValueError as e: raise ValueError( - f"Regex uses unsupported feature for guided decoding: {e}. " + f"Regex uses unsupported feature for structured outputs: {e}. " "Only basic matching constructs are supported—lookarounds, " "backreferences, and unicode boundaries are not.") from e @@ -315,6 +315,6 @@ def validate_regex_is_buildable(pattern: str) -> None: "Regex does not have a anchored universal start state" "This means that the Regex uses anchors (^) or look-arounds " "in a way which requires context before any token is matched." - "Guided decoding needs regexes that can match without needing " + "structured outputs needs regexes that can match without needing " "that context. Try rewriting the pattern without using these " f"constructs. Pattern:\n{pattern}") diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index 5e00f63804162..55b4792fe010d 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -34,7 +34,7 @@ class XgrammarBackend(StructuredOutputBackend): def __post_init__(self): self.disable_any_whitespace = \ - self.vllm_config.decoding_config.disable_any_whitespace + self.vllm_config.structured_outputs_config.disable_any_whitespace if isinstance(self.tokenizer, MistralTokenizer): # NOTE: ideally, xgrammar should handle this accordingly. @@ -248,37 +248,37 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None: Raises ValueError if the request is not supported. 
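[Editor's note] Throughout these backends the request-level field is now sampling_params.structured_outputs rather than guided_decoding. A sketch of what a request looks like under the new name; StructuredOutputsParams is assumed to be the renamed params container and is not shown in this patch:

    from vllm.sampling_params import SamplingParams, StructuredOutputsParams

    params = SamplingParams(
        max_tokens=128,
        structured_outputs=StructuredOutputsParams(
            json='{"type": "object", "properties": {"name": {"type": "string"}}}'),
    )
    assert params.structured_outputs is not None
    assert params.structured_outputs.json is not None  # what the validate_* helpers inspect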
""" - if sampling_params.guided_decoding is None: + if sampling_params.structured_outputs is None: return - gd_params = sampling_params.guided_decoding + so_params = sampling_params.structured_outputs - if gd_params.regex: + if so_params.regex: try: - xgr.Grammar.from_regex(gd_params.regex) + xgr.Grammar.from_regex(so_params.regex) except Exception as err: raise ValueError("Failed to transform regex into a grammar: " f"{err}") from err - if gd_params.choice: - choice_grammar = choice_as_grammar(gd_params.choice) + if so_params.choice: + choice_grammar = choice_as_grammar(so_params.choice) try: xgr.Grammar.from_ebnf(choice_grammar) except Exception as err: raise ValueError("Failed to transform choices into a grammar: " "{err}") from err - gd_params.choice = None - gd_params.grammar = choice_grammar + so_params.choice = None + so_params.grammar = choice_grammar return - if gd_params.json: - if isinstance(gd_params.json, str): + if so_params.json: + if isinstance(so_params.json, str): try: - schema = json.loads(gd_params.json) + schema = json.loads(so_params.json) except json.JSONDecodeError as e: raise ValueError("Invalid JSON grammar specification.") from e else: - schema = gd_params.json + schema = so_params.json try: xgr.Grammar.from_json_schema(schema) @@ -291,11 +291,11 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None: "supported by xgrammar.") return - if gd_params.grammar: - if grammar_is_likely_lark(gd_params.grammar): + if so_params.grammar: + if grammar_is_likely_lark(so_params.grammar): # xgrammar supports EBNF grammars only try: - gd_params.grammar = convert_lark_to_ebnf(gd_params.grammar) + so_params.grammar = convert_lark_to_ebnf(so_params.grammar) except ValueError as e: raise ValueError( "Failed to convert the grammar from Lark to EBNF. ") from e @@ -303,14 +303,14 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None: # Test parsing EBNF grammar, possibly already converted from Lark try: # parse the grammar, but we aren't compiling it. - xgr.Grammar.from_ebnf(gd_params.grammar) + xgr.Grammar.from_ebnf(so_params.grammar) except Exception as e: raise ValueError("Invalid grammar specification.") from e return - if gd_params.structural_tag: + if so_params.structural_tag: try: - s_tag = json.loads(gd_params.structural_tag) + s_tag = json.loads(so_params.structural_tag) tags = [ xgr.StructuralTagItem( begin=s["begin"], diff --git a/vllm/v1/structured_output/request.py b/vllm/v1/structured_output/request.py index fc365f12573fc..99974ef46ecd5 100644 --- a/vllm/v1/structured_output/request.py +++ b/vllm/v1/structured_output/request.py @@ -60,7 +60,7 @@ class StructuredOutputRequest: def get_structured_output_key( sampling_params: SamplingParams) -> StructuredOutputKey: - params = sampling_params.guided_decoding + params = sampling_params.structured_outputs assert params is not None, "params can't be None." 
if params.json is not None: if not isinstance(params.json, str): diff --git a/vllm/v1/structured_output/utils.py b/vllm/v1/structured_output/utils.py index 95319831d5121..127c8876525b5 100644 --- a/vllm/v1/structured_output/utils.py +++ b/vllm/v1/structured_output/utils.py @@ -8,7 +8,9 @@ import importlib.metadata import os from typing import TYPE_CHECKING +import numpy as np import regex as re +import torch from cachetools import LRUCache from diskcache import Cache @@ -20,9 +22,13 @@ if TYPE_CHECKING: import outlines_core as oc import transformers.file_utils as file_utils import transformers.models.gpt2.tokenization_gpt2 as tokenization_gpt2 + import xgrammar as xgr from vllm.transformers_utils.tokenizer import AnyTokenizer + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch else: + xgr = LazyLoader("xgr", globals(), "xgrammar") oc = LazyLoader("oc", globals(), "outlines_core") file_utils = LazyLoader("file_utils", globals(), "transformers.file_utils") tokenization_gpt2 = LazyLoader( @@ -36,6 +42,80 @@ logger = init_logger(__name__) CACHE = None +def apply_grammar_bitmask( + scheduler_output: SchedulerOutput, + input_batch: InputBatch, + logits: torch.Tensor, + device: torch.device, +) -> None: + """ + Apply grammar bitmask to output logits of the model with xgrammar function. + + Args: + scheduler_output (SchedulerOutput): The result of engine scheduling. + input_batch (InputBatch): The input of model runner. + logits (torch.Tensor): The output logits of model forward. + device (torch.device): The device that model runner running on. + """ + grammar_bitmask = scheduler_output.grammar_bitmask + if grammar_bitmask is None: + return + + # We receive the structured output bitmask from the scheduler, + # compacted to contain bitmasks only for structured output requests. + # The order of the requests in the bitmask is not guaranteed to be the + # same as the order of the requests in the gpu runner's batch. We need + # to sort the bitmask to match the order of the requests used here. + + # Get the batch indices of the structured output requests. + # Keep track of the number of speculative tokens scheduled for every + # request in the batch, as the logit indices are offset by this amount. + struct_out_req_batch_indices: dict[str, int] = {} + cumulative_offset = 0 + seq = sorted(input_batch.req_id_to_index.items(), key=lambda x: x[1]) + for req_id, batch_index in seq: + logit_index = batch_index + cumulative_offset + cumulative_offset += len( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) + if req_id in scheduler_output.structured_output_request_ids: + struct_out_req_batch_indices[req_id] = logit_index + + out_indices = [] + + # Reorder the bitmask to match the order of the requests in the batch. 
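[Editor's note] The index bookkeeping above can be seen with a tiny hypothetical batch: each request owns 1 + num_spec_tokens rows of logits, so a request's logit index is its batch index plus the speculative tokens scheduled for the requests placed before it:

    req_id_to_index = {"req-a": 0, "req-b": 1, "req-c": 2}
    scheduled_spec_decode_tokens = {"req-a": [11, 12]}   # req-a has 2 draft tokens
    cumulative_offset = 0
    logit_index = {}
    for req_id, batch_index in sorted(req_id_to_index.items(), key=lambda x: x[1]):
        logit_index[req_id] = batch_index + cumulative_offset
        cumulative_offset += len(scheduled_spec_decode_tokens.get(req_id, []))
    print(logit_index)  # {'req-a': 0, 'req-b': 3, 'req-c': 4}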
+ sorted_bitmask = np.full(shape=(logits.shape[0], grammar_bitmask.shape[1]), + fill_value=-1, + dtype=grammar_bitmask.dtype) + cumulative_index = 0 + seq = sorted(scheduler_output.structured_output_request_ids.items(), + key=lambda x: x[1]) + for req_id, _ in seq: + logit_index = struct_out_req_batch_indices[req_id] + num_spec_tokens = len( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) + for i in range(1 + num_spec_tokens): + sorted_bitmask[logit_index + i] = \ + grammar_bitmask[cumulative_index + i] + out_indices.append(logit_index + i) + cumulative_index += 1 + num_spec_tokens + grammar_bitmask = sorted_bitmask + + # If the length of out indices and the logits have the same shape + # we don't need to pass indices to the kernel, + # since the bitmask is already aligned with the logits. + skip_out_indices = len(out_indices) == logits.shape[0] + + # Serialization of np.ndarray is much more efficient than a tensor, + # so we receive it in that format. + grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous() + + xgr.apply_token_bitmask_inplace( + logits, + grammar_bitmask.to(device, non_blocking=True), + indices=out_indices if not skip_out_indices else None, + ) + + class OutlinesVocabulary: """ Wrapper class for `outlines_core.Vocabulary`, @@ -65,9 +145,9 @@ def get_outlines_cache_path() -> str: elif xdg_cache_home: return os.path.join(xdg_cache_home, ".cache", "outlines") # If homedir is "/", we may be inside a container, and thus writing to - # root would be problematic, so we fallback to using a tempfile. + # root would be problematic, so we fall back to using a tempfile. # Also validate the path exists, since os.path.expanduser does - # not garuntee existence. + # not guarantee existence. elif os.path.isdir(home_dir) and home_dir != "/": # Default Unix fallback: ~/.cache/outlines return os.path.join(home_dir, ".cache", "outlines") diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 8f9face6fbf2e..fd84b4a111f58 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -1,17 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse +import contextlib import multiprocessing import time import weakref from collections.abc import Sequence +from contextlib import AbstractContextManager from multiprocessing import connection from multiprocessing.process import BaseProcess from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, Union, overload) import torch +from torch.autograd.profiler import record_function +import vllm.envs as envs from vllm.logger import init_logger from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) @@ -19,6 +23,8 @@ from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri, kill_process_tree) if TYPE_CHECKING: + import numpy as np + from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.utils import (CoreEngineActorManager, CoreEngineProcManager) @@ -97,20 +103,31 @@ class ConstantList(Generic[T], Sequence): class CpuGpuBuffer: + """Buffer to easily copy tensors between CPU and GPU.""" def __init__( self, - *args, + *size: Union[int, torch.SymInt], dtype: torch.dtype, device: torch.device, pin_memory: bool, - ): - self.cpu = torch.zeros(*args, + with_numpy: bool = True, + ) -> None: + self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=pin_memory) - self.np = self.cpu.numpy() self.gpu = self.cpu.to(device) + self.np: np.ndarray + # To keep type hints simple 
(avoiding generics and subclasses), we + # only conditionally create the numpy array attribute. This can cause + # AttributeError if `self.np` is accessed when `with_numpy=False`. + if with_numpy: + if dtype == torch.bfloat16: + raise ValueError( + "Bfloat16 torch tensors cannot be directly cast to a " + "numpy array, so call CpuGpuBuffer with with_numpy=False") + self.np = self.cpu.numpy() def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor: if n is None: @@ -142,7 +159,7 @@ def get_engine_client_zmq_addr(local_only: bool, class APIServerProcessManager: """Manages a group of API server processes. - + Handles creation, monitoring, and termination of API server worker processes. Also monitors extra processes to check if they are healthy. """ @@ -159,7 +176,7 @@ class APIServerProcessManager: stats_update_address: Optional[str] = None, ): """Initialize and start API server worker processes. - + Args: target_server_fn: Function to call for each API server process listen_address: Address to listen for client connections @@ -168,7 +185,7 @@ class APIServerProcessManager: num_servers: Number of API server processes to start input_addresses: Input addresses for each API server output_addresses: Output addresses for each API server - stats_update_address: Optional stats update address + stats_update_address: Optional stats update address """ self.listen_address = listen_address self.sock = sock @@ -212,7 +229,7 @@ def wait_for_completion_or_failure( "CoreEngineActorManager"]] = None, coordinator: Optional["DPCoordinator"] = None) -> None: """Wait for all processes to complete or detect if any fail. - + Raises an exception if any process exits with a non-zero status. Args: @@ -338,7 +355,8 @@ def report_usage_stats( vllm_config.cache_config.block_size, "gpu_memory_utilization": vllm_config.cache_config.gpu_memory_utilization, - + "kv_cache_memory_bytes": + vllm_config.cache_config.kv_cache_memory_bytes, # Quantization "quantization": vllm_config.model_config.quantization, @@ -355,3 +373,10 @@ def report_usage_stats( "disable_custom_all_reduce": vllm_config.parallel_config.disable_custom_all_reduce, }) + + +def record_function_or_nullcontext(name: str) -> AbstractContextManager: + if envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING: + return record_function(name) + else: + return contextlib.nullcontext() diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 6ab5ce2748a4a..82b6d1b514d50 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -1,11 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Union import numpy as np import torch +from vllm.distributed import get_dcp_group from vllm.logger import init_logger from vllm.utils import cdiv +from vllm.v1.utils import CpuGpuBuffer logger = init_logger(__name__) @@ -28,28 +31,20 @@ class BlockTable: self.pin_memory = pin_memory self.device = device - self.block_table = torch.zeros( - (max_num_reqs, max_num_blocks_per_req), - device=self.device, - dtype=torch.int32, - ) - self.block_table_cpu = torch.zeros( - (max_num_reqs, max_num_blocks_per_req), - device="cpu", - dtype=torch.int32, - pin_memory=pin_memory, - ) - self.block_table_np = self.block_table_cpu.numpy() + self.block_table = self._make_buffer(max_num_reqs, + max_num_blocks_per_req, + dtype=torch.int32) self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) - self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens, - dtype=torch.int64, - 
device="cpu", - pin_memory=self.pin_memory) - self.slot_mapping_np = self.slot_mapping_cpu.numpy() - self.slot_mapping = torch.zeros(self.max_num_batched_tokens, - dtype=torch.int64, - device=self.device) + self.slot_mapping = self._make_buffer(self.max_num_batched_tokens, + dtype=torch.int64) + try: + self.dcp_world_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + except AssertionError: + # DCP might not be initialized in testing + self.dcp_world_size = 1 + self.dcp_rank = 0 def append_row( self, @@ -61,7 +56,7 @@ class BlockTable: num_blocks = len(block_ids) start = self.num_blocks_per_row[row_idx] self.num_blocks_per_row[row_idx] += num_blocks - self.block_table_np[row_idx, start:start + num_blocks] = block_ids + self.block_table.np[row_idx, start:start + num_blocks] = block_ids def add_row(self, block_ids: list[int], row_idx: int) -> None: self.num_blocks_per_row[row_idx] = 0 @@ -69,17 +64,14 @@ class BlockTable: def move_row(self, src: int, tgt: int) -> None: num_blocks = self.num_blocks_per_row[src] - self.block_table_np[tgt, :num_blocks] = self.block_table_np[ - src, :num_blocks] + block_table_np = self.block_table.np + block_table_np[tgt, :num_blocks] = block_table_np[src, :num_blocks] self.num_blocks_per_row[tgt] = num_blocks def swap_row(self, src: int, tgt: int) -> None: - num_blocks_src = self.num_blocks_per_row[src] - num_blocks_tgt = self.num_blocks_per_row[tgt] - self.num_blocks_per_row[src] = num_blocks_tgt - self.num_blocks_per_row[tgt] = num_blocks_src - - self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]] + src_tgt, tgt_src = [src, tgt], [tgt, src] + self.num_blocks_per_row[src_tgt] = self.num_blocks_per_row[tgt_src] + self.block_table.np[src_tgt] = self.block_table.np[tgt_src] def compute_slot_mapping(self, req_indices: np.ndarray, positions: np.ndarray) -> None: @@ -89,50 +81,94 @@ class BlockTable: # NOTE(woosuk): We can't simply use `token_indices // block_size` # here because M (max_model_len) is not necessarily divisible by # block_size. - block_table_indices = (req_indices * self.max_num_blocks_per_req + - positions // self.block_size) - block_numbers = self.block_table_np.ravel()[block_table_indices] - block_offsets = positions % self.block_size - np.add(block_numbers * self.block_size, - block_offsets, - out=self.slot_mapping_np[:req_indices.shape[0]]) + if self.dcp_world_size > 1: + # Note(hc): The DCP implement store kvcache with an interleave + # style, the kvcache for the token whose token_idx is i is + # always stored on the GPU whose dcp_rank equals i % cp_world_size: + + # Use a "virtual block" which equals to world_size * block_size + # for block_table_indices calculation. + virtual_block_size = self.block_size * self.dcp_world_size + block_table_indices = (req_indices * self.max_num_blocks_per_req + + positions // virtual_block_size) + block_numbers = self.block_table.np.ravel()[block_table_indices] + # Use virtual_block_size for mask calculation, which marks local + # tokens. 
+ virtual_block_offsets = positions % virtual_block_size + mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank + # Calculate local block_offsets + block_offsets = virtual_block_offsets // self.dcp_world_size + # Calculate slot_mapping + slot_mapping = block_numbers * self.block_size + block_offsets + # Write final slots, use -1 for not-local + self.slot_mapping.np[:req_indices.shape[0]] = np.where( + mask, slot_mapping, -1) + else: + block_table_indices = (req_indices * self.max_num_blocks_per_req + + positions // self.block_size) + block_numbers = self.block_table.np.ravel()[block_table_indices] + block_offsets = positions % self.block_size + np.add(block_numbers * self.block_size, + block_offsets, + out=self.slot_mapping.np[:req_indices.shape[0]]) def commit_block_table(self, num_reqs: int) -> None: - self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs], - non_blocking=True) + self.block_table.copy_to_gpu(num_reqs) def commit_slot_mapping(self, num_tokens: int) -> None: - self.slot_mapping[:num_tokens].copy_( - self.slot_mapping_cpu[:num_tokens], non_blocking=True) + self.slot_mapping.copy_to_gpu(num_tokens) def clear(self) -> None: - self.block_table.fill_(0) - self.block_table_cpu.fill_(0) + self.block_table.gpu.fill_(0) + self.block_table.cpu.fill_(0) - def get_device_tensor(self) -> torch.Tensor: + def get_device_tensor(self, num_reqs: int) -> torch.Tensor: """Returns the device tensor of the block table.""" - return self.block_table + return self.block_table.gpu[:num_reqs] def get_cpu_tensor(self) -> torch.Tensor: """Returns the CPU tensor of the block table.""" - return self.block_table_cpu + return self.block_table.cpu def get_numpy_array(self) -> np.ndarray: """Returns the numpy array of the block table.""" - return self.block_table_np + return self.block_table.np + + def _make_buffer(self, *size: Union[int, torch.SymInt], + dtype: torch.dtype) -> CpuGpuBuffer: + return CpuGpuBuffer(*size, + dtype=dtype, + device=self.device, + pin_memory=self.pin_memory) class MultiGroupBlockTable: """The BlockTables for each KV cache group.""" - def __init__(self, max_num_reqs: int, max_model_len: int, - max_num_batched_tokens: int, pin_memory: bool, - device: torch.device, block_sizes: list[int]) -> None: + def __init__(self, + max_num_reqs: int, + max_model_len: int, + max_num_batched_tokens: int, + pin_memory: bool, + device: torch.device, + block_sizes: list[int], + num_speculative_tokens: int = 0) -> None: + # Note(hc): each dcp rank only store + # (max_model_len//dcp_world_size) tokens in kvcache, + # so the block_size which used for calc max_num_blocks_per_req + # must be multiplied by dcp_world_size. 
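The sizing effect described in the note above can be checked with simple arithmetic; the numbers below are illustrative only.

from math import ceil

max_model_len, block_size = 8192, 16
dcp_world_size, num_speculative_tokens = 2, 3
max_num_blocks_per_req = max(ceil(max_model_len / (block_size * dcp_world_size)),
                             1 + num_speculative_tokens)
print(max_num_blocks_per_req)   # 256 blocks per rank, half of the 512 needed without DCP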
+ try: + dcp_world_size = get_dcp_group().world_size + except AssertionError: + # DCP might not be initialized in testing + dcp_world_size = 1 + self.block_tables = [ - BlockTable(block_size, max_num_reqs, cdiv(max_model_len, - block_size), - max_num_batched_tokens, pin_memory, device) - for block_size in block_sizes + BlockTable( + block_size, max_num_reqs, + max(cdiv(max_model_len, block_size * dcp_world_size), + 1 + num_speculative_tokens), max_num_batched_tokens, + pin_memory, device) for block_size in block_sizes ] def append_row(self, block_ids: tuple[list[int], ...], diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py index feb49978d7518..cd0f0af43e7e7 100644 --- a/vllm/v1/worker/cpu_model_runner.py +++ b/vllm/v1/worker/cpu_model_runner.py @@ -55,11 +55,28 @@ class CPUModelRunner(GPUModelRunner): raise ValueError("Multiple KVCacheGroups is not" "currently supported with CPU model runner.") - assert type(self.attn_groups[0] - [0].metadata_builder) is TorchSDPAMetadataBuilderV1 + # Guard against encoder-only / pooling models where `attn_groups` + # may be empty or lack the expected metadata_builder. + # Without this check, accessing `attn_groups[0][0]` would trigger + # an AssertionError on CPU backend. + if not hasattr(self, "attn_groups") or not self.attn_groups: + return + if not self.attn_groups[0]: + return - self.attn_groups[0][0].metadata_builder.reorder_batch( - self.input_batch, scheduler_output) + mb = getattr(self.attn_groups[0][0], "metadata_builders", None) + if isinstance(mb, list): + if not isinstance(mb[0], TorchSDPAMetadataBuilderV1): + return + mb[0].reorder_batch(self.input_batch, scheduler_output) + return + elif not isinstance(mb, TorchSDPAMetadataBuilderV1): + # Encoder-only / rerank models do not benefit from reordering, + # so we safely skip here. 
+ return + + # Safe path for decoder/attention-heavy models + mb.reorder_batch(self.input_batch, scheduler_output) def _postprocess_tensors(self) -> None: # Note: replace device tensors with cpu tensors @@ -72,7 +89,7 @@ class CPUModelRunner(GPUModelRunner): assert isinstance(device_tensor, torch.Tensor) setattr(obj, device_attr_name, cpu_tensor) - for k, v in vars(self).items(): + for v in vars(self).values(): if isinstance(v, CpuGpuBuffer): v.gpu = v.cpu @@ -81,9 +98,9 @@ class CPUModelRunner(GPUModelRunner): replace_tensor(self.input_batch, k, k[:-11]) for block_table in self.input_batch.block_table.block_tables: - for k, v in vars(block_table).items(): - if k.endswith("_cpu") and isinstance(v, torch.Tensor): - replace_tensor(block_table, k, k[:-4]) + for v in vars(block_table).values(): + if isinstance(v, CpuGpuBuffer): + v.gpu = v.cpu def load_model(self, eep_scale_up: bool = False) -> None: logger.info("Starting to load model %s...", self.model_config.model) @@ -128,12 +145,20 @@ def _torch_cuda_wrapper(): self.record = lambda: None self.synchronize = lambda: None + class _StreamPlaceholder: + + def __init__(self, *args, **kwargs) -> None: + pass + cuda_event = torch.cuda.Event + cuda_stream = torch.cuda.Stream try: torch.cuda.Event = _EventPlaceholder + torch.cuda.Stream = _StreamPlaceholder yield finally: torch.cuda.Event = cuda_event + torch.cuda.Stream = cuda_stream @contextmanager diff --git a/vllm/v1/worker/cpu_worker.py b/vllm/v1/worker/cpu_worker.py index b87c4fe09bb90..daee91ec404fe 100644 --- a/vllm/v1/worker/cpu_worker.py +++ b/vllm/v1/worker/cpu_worker.py @@ -19,6 +19,7 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.worker.cpu_model_runner import CPUModelRunner from vllm.v1.worker.gpu_worker import (Worker, init_worker_distributed_environment) +from vllm.v1.worker.utils import is_residual_scattered_for_sp logger = init_logger(__name__) @@ -107,18 +108,29 @@ class CPUWorker(Worker): scheduler_output: "SchedulerOutput", ) -> Optional[ModelRunnerOutput]: intermediate_tensors = None + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + num_input_tokens = self.model_runner._get_num_input_tokens( + num_scheduled_tokens) + all_gather_tensors = { + "residual": + not is_residual_scattered_for_sp(self.vllm_config, + num_input_tokens) + } if not get_pp_group().is_first_rank: intermediate_tensors = IntermediateTensors( get_pp_group().recv_tensor_dict( - all_gather_group=get_tp_group())) + all_gather_group=get_tp_group(), + all_gather_tensors=all_gather_tensors)) output = self.model_runner.execute_model(scheduler_output, intermediate_tensors) if not get_pp_group().is_last_rank: assert isinstance(output, IntermediateTensors) - get_pp_group().send_tensor_dict(output.tensors, - all_gather_group=get_tp_group()) + get_pp_group().send_tensor_dict( + output.tensors, + all_gather_group=get_tp_group(), + all_gather_tensors=all_gather_tensors) return None assert isinstance(output, ModelRunnerOutput) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index ef5a7e39a5b16..6717622efb801 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -10,8 +10,7 @@ import torch from typing_extensions import deprecated from vllm.lora.request import LoRARequest -from vllm.multimodal.inputs import (MultiModalKwargsItem, - MultiModalKwargsItems, PlaceholderRange) +from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems from vllm.pooling_params import PoolingParams from vllm.sampling_params 
import SamplingParams, SamplingType from vllm.utils import swap_dict_values @@ -31,9 +30,7 @@ class CachedRequestState: req_id: str prompt_token_ids: list[int] - mm_kwargs: list[MultiModalKwargsItem] - mm_positions: list[PlaceholderRange] - mm_hashes: list[str] + mm_features: list[MultiModalFeatureSpec] sampling_params: Optional[SamplingParams] pooling_params: Optional[PoolingParams] generator: Optional[torch.Generator] @@ -60,13 +57,17 @@ class CachedRequestState: "removed in v0.13. Please use `mm_kwargs` instead.") def mm_inputs(self) -> list[MultiModalKwargsItems]: return [ - MultiModalKwargsItems.from_seq([item]) for item in self.mm_kwargs + MultiModalKwargsItems.from_seq([f.data]) for f in self.mm_features + if f.data is not None ] def get_token_id(self, idx: int) -> int: if idx < self.num_prompt_tokens: return self.prompt_token_ids[idx] - return self.output_token_ids[idx - self.num_prompt_tokens] + elif idx - self.num_prompt_tokens < len(self.output_token_ids): + return self.output_token_ids[idx - self.num_prompt_tokens] + else: + return -1 class InputBatch: @@ -83,6 +84,7 @@ class InputBatch: logitsprocs: Optional[LogitsProcessors] = None, is_spec_decode: bool = False, is_pooling_model: bool = False, + num_speculative_tokens: int = 0, ): self.is_pooling_model = is_pooling_model self.is_spec_decode = is_spec_decode @@ -127,6 +129,7 @@ class InputBatch: pin_memory=pin_memory, device=device, block_sizes=block_sizes, + num_speculative_tokens=num_speculative_tokens, ) # Sampling-related. @@ -202,6 +205,14 @@ class InputBatch: self.repetition_penalties_cpu_tensor.numpy() self.repetition_penalties_reqs: set[str] = set() + # Speculative decoding + self.num_accepted_tokens_cpu_tensor = torch.ones((max_num_reqs, ), + dtype=torch.int64, + device="cpu", + pin_memory=pin_memory) + self.num_accepted_tokens_cpu = \ + self.num_accepted_tokens_cpu_tensor.numpy() + # lora related self.request_lora_mapping = np.zeros((self.max_num_reqs, ), dtype=np.int32) @@ -250,6 +261,11 @@ class InputBatch: self.pooling_params: dict[str, PoolingParams] = {} + # Cached reference to the GPU tensor of previously sampled tokens + self.prev_sampled_token_ids: Optional[torch.Tensor] = None + self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None + self.prev_req_id_to_index: Optional[dict[str, int]] = None + @property def req_ids(self) -> list[str]: # None elements should only be present transiently @@ -355,8 +371,9 @@ class InputBatch: if sampling_params.logprobs == -1 else sampling_params.logprobs) if sampling_params.prompt_logprobs is not None: - self.num_prompt_logprobs[ - req_id] = sampling_params.prompt_logprobs + self.num_prompt_logprobs[req_id] = ( + self.vocab_size if sampling_params.prompt_logprobs == -1 + else sampling_params.prompt_logprobs) if sampling_params.allowed_token_ids: self.has_allowed_token_ids.add(req_id) @@ -388,6 +405,9 @@ class InputBatch: else: raise NotImplementedError("Unrecognized request type") + # Speculative decoding: by default 1 token is generated. 
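The default of one accepted token set here is later overwritten by the model runner once the sampler output is known. As a rough sketch (shapes and values invented, not the runner's exact code), the per-request count can be recovered by padding the sampled ids with -1 and taking the index of the first -1:

import torch

output_token_ids = torch.tensor([[5, 9, -1, -1],
                                 [7, -1, -1, -1],
                                 [3, 4, 8, 2]])
padded = torch.cat(
    [output_token_ids,
     torch.full((output_token_ids.size(0), 1), -1)], dim=1)
num_accepted = (padded == -1).int().argmax(-1)
print(num_accepted.tolist())   # [2, 1, 4]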
+ self.num_accepted_tokens_cpu[req_index] = 1 + # Add request lora ID if request.lora_request: lora_id = request.lora_request.lora_int_id @@ -509,6 +529,8 @@ class InputBatch: self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1] self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \ self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1] + self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] =\ + self.num_accepted_tokens_cpu[i2], self.num_accepted_tokens_cpu[i1] swap_dict_values(self.generators, i1, i2) swap_dict_values(self.bad_words_token_ids, i1, i2) @@ -584,7 +606,7 @@ class InputBatch: if self.is_pooling_model: last_req_index -= 1 - # Samping state not used by pooling models. + # Sampling state not used by pooling models. continue # Autoregressive models require detailed tracking of condense @@ -603,6 +625,8 @@ class InputBatch: empty_index] = self.presence_penalties_cpu[last_req_index] self.repetition_penalties_cpu[ empty_index] = self.repetition_penalties_cpu[last_req_index] + self.num_accepted_tokens_cpu[ + empty_index] = self.num_accepted_tokens_cpu[last_req_index] generator = self.generators.pop(last_req_index, None) if generator is not None: self.generators[empty_index] = generator diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 96dafd6add679..4873b586724ec 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -15,6 +15,7 @@ import torch import torch.distributed import torch.nn as nn from tqdm import tqdm +from typing_extensions import TypeAlias import vllm.envs as envs from vllm.attention import Attention, AttentionType @@ -28,6 +29,7 @@ from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig, from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) +from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks from vllm.distributed.parallel_state import ( get_pp_group, get_tp_group, graph_capture, is_global_first_rank, prepare_communication_buffer_for_model) @@ -52,22 +54,28 @@ from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, LazyLoader, cdiv, check_use_alibi, - get_dtype_size, is_pin_memory_available, round_up, - supports_dynamo) + GiB_bytes, check_use_alibi, get_dtype_size, + is_pin_memory_available, round_up, supports_dynamo) +from vllm.v1.attention.backends.flash_attn import AttentionMetadata +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, create_fast_prefill_custom_backend, - reorder_batch_to_split_decodes_and_prefills) + reorder_batch_to_split_decodes_and_prefills, split_attn_metadata) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher +# yapf conflicts with isort for this block +# yapf: disable from vllm.v1.kv_cache_interface import (AttentionSpec, ChunkedLocalAttentionSpec, + CrossAttentionSpec, EncoderOnlyAttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, MambaSpec, SlidingWindowSpec) -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds, - LogprobsTensors, ModelRunnerOutput) +# yapf: enable +from vllm.v1.outputs import 
(EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, + DraftTokenIds, LogprobsLists, LogprobsTensors, + ModelRunnerOutput, SamplerOutput) from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs from vllm.v1.sample.metadata import SamplingMetadata @@ -77,11 +85,16 @@ from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer -from vllm.v1.utils import CpuGpuBuffer +from vllm.v1.structured_output.utils import apply_grammar_bitmask +from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper from vllm.v1.worker.kv_connector_model_runner_mixin import ( - KVConnectorModelRunnerMixin, KVConnectorOutput) + KVConnectorModelRunnerMixin) from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin +from vllm.v1.worker.ubatch_splitting import get_dp_padding_ubatch, ubatch_split +from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices +from vllm.v1.worker.utils import is_residual_scattered_for_sp from .utils import (AttentionGroup, MultiModalBudget, add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache, @@ -89,19 +102,63 @@ from .utils import (AttentionGroup, MultiModalBudget, scatter_mm_placeholders) if TYPE_CHECKING: - import xgrammar as xgr - import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile # noqa: E501 - from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.v1.core.sched.output import SchedulerOutput -else: - xgr = LazyLoader("xgr", globals(), "xgrammar") - xgr_torch_compile = LazyLoader( - "xgr_torch_compile", globals(), - "xgrammar.kernels.apply_token_bitmask_inplace_torch_compile") logger = init_logger(__name__) +AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata] +# list when ubatching is enabled +PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict], + AttnMetadataDict] + + +# Wrapper for ModelRunnerOutput to support overlapped execution. +class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput): + + def __init__( + self, + model_runner_output: ModelRunnerOutput, + sampled_token_ids: torch.Tensor, + invalid_req_indices: list[int], + async_output_copy_stream: torch.cuda.Stream, + ): + self._model_runner_output = model_runner_output + self._invalid_req_indices = invalid_req_indices + + # Event on the copy stream so we can synchronize the non-blocking copy. + self._async_copy_ready_event = torch.cuda.Event() + + # Keep a reference to the device tensor to avoid it being + # deallocated until we finish copying it to the host. + self._sampled_token_ids = sampled_token_ids + + # Initiate the copy on a separate stream, but do not synchronize it. + default_stream = torch.cuda.current_stream() + with torch.cuda.stream(async_output_copy_stream): + async_output_copy_stream.wait_stream(default_stream) + self._sampled_token_ids_cpu = self._sampled_token_ids.to( + 'cpu', non_blocking=True) + self._async_copy_ready_event.record() + + def get_output(self) -> ModelRunnerOutput: + """Copy the device tensors to the host and return a ModelRunnerOutput. + + This function blocks until the copy is finished. 
+ """ + self._async_copy_ready_event.synchronize() + + # Release the device tensor once the copy has completed + del self._sampled_token_ids + + valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist() + for i in self._invalid_req_indices: + valid_sampled_token_ids[i].clear() + + output = self._model_runner_output + output.sampled_token_ids = valid_sampled_token_ids + return output + class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): @@ -143,9 +200,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): model_config.is_multimodal_raw_input_only_model) self.max_model_len = model_config.max_model_len + self.dcp_world_size = self.parallel_config.decode_context_parallel_size self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs + # Broadcast PP output for external_launcher (torchrun) + # to make sure we are synced across pp ranks + # TODO: Support overlapping mirco-batches + # https://github.com/vllm-project/vllm/issues/18019 + self.broadcast_pp_output = ( + self.parallel_config.distributed_executor_backend + == "external_launcher" and len(get_pp_group().ranks) > 0) + # Model-related. self.num_query_heads = model_config.get_num_attention_heads( parallel_config) @@ -162,6 +228,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( model_config) + if self.model_config.is_encoder_decoder: + # Maximum length of the encoder input, only for encoder-decoder + # models. + self.max_encoder_len = scheduler_config.\ + max_num_encoder_input_tokens + else: + self.max_encoder_len = 0 + # Sampler self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) @@ -207,6 +281,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Request states. self.requests: dict[str, CachedRequestState] = {} + self.comm_stream = torch.cuda.Stream() # Input Batch # NOTE(Chen): Ideally, we should initialize the input batch inside @@ -219,7 +294,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # the block_sizes in the kv cache config. self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, + # We need to use the encoder length for encoder-decoer + # because of KV cache for cross-attention. + max_model_len=max(self.max_model_len, self.max_encoder_len), max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, @@ -233,6 +310,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): is_pooling_model=self.is_pooling_model, ) + self.use_async_scheduling = self.scheduler_config.async_scheduling + self.async_output_copy_stream = torch.cuda.Stream() if \ + self.use_async_scheduling else None + # TODO(woosuk): Provide an option to tune the max cudagraph batch size. # The convention is different. # self.cudagraph_batch_sizes sorts in ascending order. @@ -253,10 +334,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.query_start_loc = self._make_buffer(self.max_num_reqs + 1, dtype=torch.int32) self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) - self.inputs_embeds = torch.zeros( - (self.max_num_tokens, self.hidden_size), - dtype=self.dtype, - device=self.device) + # Because inputs_embeds may be bfloat16 and we don't need a numpy + # version of this tensor, avoid a RuntimeError by not creating a + # numpy buffer. 
+ self.inputs_embeds = self._make_buffer(self.max_num_tokens, + self.hidden_size, + dtype=self.dtype, + numpy=False) + self.discard_request_indices = self._make_buffer(self.max_num_reqs, + dtype=torch.int64) + self.num_discarded_requests = 0 + + self.num_draft_tokens = self._make_buffer(self.max_num_reqs, + dtype=torch.int32) + self.num_accepted_tokens = self._make_buffer(self.max_num_reqs, + dtype=torch.int64) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -273,6 +365,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.mrope_positions = self._make_buffer( (3, self.max_num_tokens + 1), dtype=torch.int64) + # CUDA event to synchronize use of reused CPU tensors between steps + # when async scheduling is enabled. + self.prepare_inputs_event: Optional[torch.cuda.Event] = None + if self.use_async_scheduling: + self.prepare_inputs_event = torch.cuda.Event() + # Start in a completed state. + self.prepare_inputs_event.record(torch.cuda.default_stream()) + # None in the first PP rank. The rest are set after load_model. self.intermediate_tensors: Optional[IntermediateTensors] = None @@ -301,11 +401,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Cudagraph dispatcher for runtime cudagraph dispatching. self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) - self.mm_budget = (MultiModalBudget( + self.mm_budget = MultiModalBudget( self.model_config, self.scheduler_config, self.mm_registry, - ) if self.supports_mm_inputs else None) + ) if self.supports_mm_inputs else None self.reorder_batch_threshold: Optional[int] = None @@ -324,11 +424,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): device="cpu", pin_memory=self.pin_memory) - def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer: - return CpuGpuBuffer(*args, + def _make_buffer(self, + *size: Union[int, torch.SymInt], + dtype: torch.dtype, + numpy: bool = True) -> CpuGpuBuffer: + return CpuGpuBuffer(*size, dtype=dtype, device=self.device, - pin_memory=self.pin_memory) + pin_memory=self.pin_memory, + with_numpy=numpy) def _init_model_kwargs(self, num_tokens: int): model_kwargs = dict[str, Any]() @@ -380,6 +484,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return if self.reorder_batch_threshold is not None: + # NOTE(lucas): currently no backend supports the custom masking + # required for DCP with q_len > 1, so we assert here. Remove this + # assert once the custom mask is support is added to FA3. + if self.dcp_world_size > 1: + assert self.reorder_batch_threshold == 1, \ + "DCP not support reorder_batch_threshold > 1 now." reorder_batch_to_split_decodes_and_prefills( self.input_batch, scheduler_output, @@ -463,9 +573,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_state = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, - mm_kwargs=new_req_data.mm_kwargs, - mm_positions=new_req_data.mm_positions, - mm_hashes=new_req_data.mm_hashes, + mm_features=new_req_data.mm_features, sampling_params=sampling_params, pooling_params=pooling_params, generator=generator, @@ -575,13 +683,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Refresh batch metadata with any pending updates. self.input_batch.refresh_metadata() + def _update_states_after_model_execute( + self, output_token_ids: torch.Tensor) -> None: + """Update the cached states after model execution. 
+ + This is used for MTP/EAGLE for hybrid models, as in linear attention, + only the last token's state is kept. In MTP/EAGLE, for draft tokens + the state are kept util we decide how many tokens are accepted for + each sequence, and a shifting is done during the next iteration + based on the number of accepted tokens. + """ + if not self.model_config.is_hybrid or not self.speculative_config: + return + + # Find the number of accepted tokens for each sequence. + num_accepted_tokens = (torch.cat( + [ + output_token_ids, + torch.full((output_token_ids.size(0), 1), + -1, + device=output_token_ids.device), + ], + dim=1) == -1).int().argmax(-1).cpu().numpy() + for i, num_tokens in enumerate(num_accepted_tokens): + self.input_batch.num_accepted_tokens_cpu[i] = num_tokens + def _init_mrope_positions(self, req_state: CachedRequestState): image_grid_thw = [] video_grid_thw = [] second_per_grid_ts = [] audio_feature_lengths = [] use_audio_in_video = False - for mm_item in req_state.mm_kwargs: + for mm_feature in req_state.mm_features: + mm_item = mm_feature.data + if mm_item is None: + continue mm_input = mm_item.get_data() if (t := mm_input.get("image_grid_thw")) is not None: image_grid_thw.append(t.tolist()) @@ -614,7 +750,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): mm_kwargs = list[MultiModalKwargsItem]() for req in scheduler_output.scheduled_new_reqs: - mm_kwargs.extend(req.mm_kwargs) + for feature in req.mm_features: + if feature.data is not None: + mm_kwargs.append(feature.data) # Input all modalities at once mm_kwargs_combined: BatchedTensorInputs = {} @@ -657,11 +795,97 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return cu_num_tokens, arange - def _prepare_inputs( + def _prepare_input_ids(self, total_num_scheduled_tokens: int, + cu_num_tokens: np.ndarray) -> None: + """Prepare the input IDs for the current batch. + + Carefully handles the `prev_sampled_token_ids` which can be cached + from the previous engine iteration, in which case those tokens on the + GPU need to be copied into the corresponding slots into input_ids.""" + + if self.input_batch.prev_sampled_token_ids is None: + # Normal scheduling case + self.input_ids.copy_to_gpu(total_num_scheduled_tokens) + return + + # Async scheduling case, where some decode requests from the previous + # iteration won't have entries in input_ids_cpu and need to be copied + # on the GPU from prev_sampled_token_ids. + prev_req_id_to_index = self.input_batch.prev_req_id_to_index + assert prev_req_id_to_index is not None + flattened_indices = [] + prev_common_req_indices = [] + indices_match = True + max_flattened_index = -1 + for req_id, cur_index in self.input_batch.req_id_to_index.items(): + if (prev_index := prev_req_id_to_index.get(req_id)) is not None: + prev_common_req_indices.append(prev_index) + # We need to compute the flattened input_ids index of the + # last token in each common request. + flattened_index = cu_num_tokens[cur_index].item() - 1 + flattened_indices.append(flattened_index) + indices_match &= (prev_index == flattened_index) + max_flattened_index = max(max_flattened_index, flattened_index) + num_commmon_tokens = len(flattened_indices) + if num_commmon_tokens < total_num_scheduled_tokens: + # If not all requests are decodes from the last iteration, + # We need to copy the input_ids_cpu to the GPU first. 
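For the partial-overlap case handled below, the back-fill of input_ids from the previous step's on-device samples is essentially a scatter. A minimal standalone sketch with invented values (CPU tensors, no async copies):

import torch

input_ids = torch.tensor([10, 11, 12])                   # slots filled from the CPU buffer
prev_sampled_token_ids = torch.tensor([[101], [102]])    # sampled on device last step
flattened_indices = torch.tensor([0, 2])                 # last-token slot of each common request
prev_common_req_indices = torch.tensor([0, 1])           # their rows in prev_sampled_token_ids

input_ids.scatter_(0, flattened_indices,
                   prev_sampled_token_ids[prev_common_req_indices, 0])
print(input_ids.tolist())   # [101, 11, 102]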
+ self.input_ids.copy_to_gpu(total_num_scheduled_tokens) + if num_commmon_tokens == 0: + # No requests in common with the previous iteration + # So input_ids_cpu will have all the input ids. + return + if indices_match and max_flattened_index == (num_commmon_tokens - 1): + # Common-case optimization: the batch is unchanged + # and no reordering happened. + # The indices are both the same permutation of 0..N-1 so + # we can copy directly using a single slice. + self.input_ids.gpu[:num_commmon_tokens].copy_( + self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, + 0], + non_blocking=True) + return + # Upload the index tensors asynchronously + # so the scatter can be non-blocking. + input_ids_index_tensor = torch.tensor(flattened_indices, + dtype=torch.int64, + pin_memory=self.pin_memory).to( + self.device, + non_blocking=True) + prev_common_req_indices_tensor = torch.tensor( + prev_common_req_indices, + dtype=torch.int64, + pin_memory=self.pin_memory).to(self.device, non_blocking=True) + self.input_ids.gpu.scatter_( + dim=0, + index=input_ids_index_tensor, + src=self.input_batch.prev_sampled_token_ids[ + prev_common_req_indices_tensor, 0]) + + def _get_encoder_seq_lens( self, scheduler_output: "SchedulerOutput", - ) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata], - np.ndarray, Optional[CommonAttentionMetadata], int]: + kv_cache_spec: KVCacheSpec, + num_reqs: int, + ) -> Optional[np.ndarray]: + if not isinstance(kv_cache_spec, CrossAttentionSpec): + return None + + # Build encoder_seq_lens array mapping request indices to + # encoder lengths for inputs scheduled in this batch + encoder_seq_lens = np.zeros(num_reqs, dtype=np.int32) + for req_id in scheduler_output.scheduled_encoder_inputs: + req_index = self.input_batch.req_id_to_index[req_id] + encoder_seq_lens[req_index] = self.max_encoder_len + + return encoder_seq_lens + + def _prepare_inputs( + self, scheduler_output: "SchedulerOutput" + ) -> tuple[PerLayerAttnMetadata, torch.Tensor, + Optional[SpecDecodeMetadata], np.ndarray, + Optional[CommonAttentionMetadata], int, Optional[UBatchSlices], + Optional[torch.Tensor]]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, @@ -733,6 +957,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.query_start_loc.copy_to_gpu() query_start_loc = self.query_start_loc.gpu[:num_reqs + 1] + num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens + num_tokens_padded = num_tokens_unpadded + self.get_local_padding( + num_tokens_unpadded) + ubatch_slices, num_tokens_after_padding = \ + ubatch_split(max_num_scheduled_tokens, + num_tokens_unpadded, + num_tokens_padded, + self.vllm_config) + self.seq_lens.np[:num_reqs] = ( self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens) @@ -742,8 +975,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): seq_lens = self.seq_lens.gpu[:num_reqs] max_seq_len = self.seq_lens.np[:num_reqs].max().item() + num_tokens = [ + self.requests[r].num_tokens for r in self.input_batch.req_ids + ] + num_tokens_np = np.array(num_tokens, dtype=np.int32) + + # Record the index of requests that should not be sampled, + # so that we could clear the sampled tokens before returning + discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np + discard_request_indices = np.nonzero(discard_requests_mask)[0] + self.num_discarded_requests = len(discard_request_indices) + self.discard_request_indices.np[:self.num_discarded_requests] = ( + 
discard_request_indices) + + self.discard_request_indices.copy_to_gpu(self.num_discarded_requests) + # Copy the tensors to the GPU. - self.input_ids.copy_to_gpu(total_num_scheduled_tokens) + self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens) + if self.uses_mrope: # Only relevant for models using M-RoPE (e.g, Qwen2-VL) self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( @@ -762,6 +1011,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. logits_indices = query_start_loc[1:] - 1 + num_draft_tokens = None spec_decode_metadata = None else: # Get the number of draft tokens for each request. @@ -776,13 +1026,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): spec_decode_metadata = self._calc_spec_decode_metadata( num_draft_tokens, cu_num_tokens) logits_indices = spec_decode_metadata.logits_indices + self.num_draft_tokens.np[:num_reqs] = num_draft_tokens + self.num_draft_tokens.np[num_reqs:].fill(0) + self.num_draft_tokens.copy_to_gpu() logits_indices_padded = None if self.cache_config.kv_sharing_fast_prefill: logits_indices_padded = self._prepare_kv_sharing_fast_prefill( logits_indices) - attn_metadata: dict[str, Any] = {} + attn_metadata: PerLayerAttnMetadata = {} + if ubatch_slices is not None: + attn_metadata = [dict() for _ in range(len(ubatch_slices))] # Used in the below loop. query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1] @@ -790,11 +1045,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_computed_tokens_cpu = ( self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) spec_decode_common_attn_metadata = None + if use_spec_decode: + self.num_accepted_tokens.np[:num_reqs] = ( + self.input_batch.num_accepted_tokens_cpu[:num_reqs]) + self.num_accepted_tokens.np[num_reqs:].fill(1) + self.num_accepted_tokens.copy_to_gpu() # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): + encoder_seq_lens = self._get_encoder_seq_lens( + scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs) if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec): @@ -813,13 +1075,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_common_prefix_blocks = 0 else: blk_table = self.input_batch.block_table[kv_cache_group_id] - blk_table_tensor = blk_table.get_device_tensor()[:num_reqs] - slot_mapping = blk_table.slot_mapping[: - total_num_scheduled_tokens] + blk_table_tensor = blk_table.get_device_tensor(num_reqs) + slot_mapping = blk_table.slot_mapping.gpu[: + total_num_scheduled_tokens] # Fill unused with -1. Needed for reshape_and_cache in full cuda # graph mode. - blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1) + blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_( + -1) num_common_prefix_blocks = ( scheduler_output. 
num_common_prefix_blocks[kv_cache_group_id]) @@ -839,6 +1102,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): logits_indices_padded=logits_indices_padded, num_logits_indices=logits_indices.size(0), causal=True, + encoder_seq_lens=encoder_seq_lens, ) if self.speculative_config and \ @@ -848,7 +1112,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for attn_group in self.attn_groups[kv_cache_group_id]: # Prepare for cascade attention if enabled & beneficial. common_prefix_len = 0 - builder = attn_group.metadata_builder + builder = attn_group.get_metadata_builder() if self.cascade_attn_enabled: common_prefix_len = self._compute_cascade_attn_prefix_len( num_scheduled_tokens, @@ -857,13 +1121,36 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): builder, ) - attn_metadata_i = (builder.build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata, - )) + extra_attn_metadata_args = {} + if use_spec_decode and isinstance(builder, + GDNAttentionMetadataBuilder): + extra_attn_metadata_args = dict( + num_accepted_tokens=self.num_accepted_tokens. + gpu[:num_reqs], + num_draft_tokens=self.num_draft_tokens.gpu[:num_reqs], + ) - for layer_name in attn_group.layer_names: - attn_metadata[layer_name] = attn_metadata_i + if ubatch_slices is not None: + common_attn_metadata_list = split_attn_metadata( + ubatch_slices, common_attn_metadata) + for ubid, common_attn_metadata in enumerate( + common_attn_metadata_list): + assert common_attn_metadata.max_query_len == 1 + attn_metadata_i = (attn_group.get_metadata_builder( + ubatch_id=ubid).build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata)) + for layer_name in kv_cache_group_spec.layer_names: + assert type(attn_metadata) is list + attn_metadata[ubid][layer_name] = attn_metadata_i + else: + assert isinstance(attn_metadata, dict) + attn_metadata_i = builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + **extra_attn_metadata_args) + for layer_name in attn_group.layer_names: + attn_metadata[layer_name] = attn_metadata_i # Hot-Swap lora model if self.lora_config: @@ -871,7 +1158,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return (attn_metadata, logits_indices, spec_decode_metadata, num_scheduled_tokens, spec_decode_common_attn_metadata, - max_num_scheduled_tokens) + max_num_scheduled_tokens, ubatch_slices, + num_tokens_after_padding) def _compute_cascade_attn_prefix_len( self, @@ -1111,10 +1399,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]) return logits_indices_padded - def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): + def _batch_mm_kwargs_from_scheduler( + self, + scheduler_output: "SchedulerOutput", + ) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]: + """Batch multimodal kwargs from scheduled encoder inputs. + + Args: + scheduler_output: The scheduler output containing scheduled encoder + inputs. + + Returns: + A tuple of (mm_kwargs, req_ids_pos) where: + - mm_kwargs: List of multimodal kwargs items to be batched + - mm_hashes_pos: List of (mm_hash, position_info) tuples + """ scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs if not scheduled_encoder_inputs: - return + return [], [] # Batch the multi-modal inputs. 
mm_kwargs = list[MultiModalKwargsItem]() # list of tuple (mm_hash, position_info) @@ -1123,10 +1425,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_state = self.requests[req_id] for mm_input_id in encoder_input_ids: - mm_hash = req_state.mm_hashes[mm_input_id] - mm_kwargs.append(req_state.mm_kwargs[mm_input_id]) - mm_hashes_pos.append( - (mm_hash, req_state.mm_positions[mm_input_id])) + mm_feature = req_state.mm_features[mm_input_id] + mm_hash = mm_feature.identifier + mm_kwargs.append(mm_feature.data) + mm_hashes_pos.append((mm_hash, mm_feature.mm_position)) + + return mm_kwargs, mm_hashes_pos + + def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): + # Batch the multi-modal inputs using the helper method. + mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler( + scheduler_output) + + if not mm_kwargs: + return # Batch mm inputs as much as we can: if a request in the batch has # multiple modalities or a different modality than the previous one, @@ -1178,9 +1490,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_state = self.requests[req_id] num_computed_tokens = \ req_state.num_computed_tokens + shift_computed_tokens - mm_positions = req_state.mm_positions - mm_hashes = req_state.mm_hashes - for i, pos_info in enumerate(mm_positions): + for mm_feature in req_state.mm_features: + pos_info = mm_feature.mm_position start_pos = pos_info.offset num_encoder_tokens = pos_info.length @@ -1203,7 +1514,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) assert start_idx < end_idx - mm_hash = mm_hashes[i] + mm_hash = mm_feature.identifier encoder_output = self.encoder_cache.get(mm_hash, None) assert encoder_output is not None,\ f"Encoder cache miss for {mm_hash}." @@ -1218,9 +1529,38 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): mm_embeds.append(mm_embeds_item) return mm_embeds + def _extract_encoder_inputs( + self, + scheduler_output: "SchedulerOutput", + ) -> dict[str, torch.Tensor]: + """Extract encoder inputs for encoder-decoder models. + + This method extracts multimodal input features from scheduled encoder + inputs and formats them for the encoder-decoder model forward pass. + """ + # Batch the multi-modal inputs using the helper method. + mm_kwargs, _ = self._batch_mm_kwargs_from_scheduler(scheduler_output) + + if not mm_kwargs: + return {} + + # Group MM kwargs by modality and extract features + encoder_features = {} + for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + ): + # Add the grouped features to encoder_features dict + # This allows the model to receive them as kwargs (e.g., + # input_features=...) + encoder_features.update(mm_kwargs_group) + + return encoder_features + def get_model(self) -> nn.Module: # get raw model out of the cudagraph wrapper. - if isinstance(self.model, CUDAGraphWrapper): + if isinstance(self.model, (CUDAGraphWrapper, UBatchWrapper)): return self.model.unwrap() return self.model @@ -1274,74 +1614,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return tuple(tasks) - def apply_grammar_bitmask( - self, - scheduler_output: "SchedulerOutput", - logits: torch.Tensor, - ): - grammar_bitmask = scheduler_output.grammar_bitmask - if grammar_bitmask is None: - return - - # We receive the structured output bitmask from the scheduler, - # compacted to contain bitmasks only for structured output requests. 
- # The order of the requests in the bitmask is not guaranteed to be the - # same as the order of the requests in the gpu runner's batch. We need - # to sort the bitmask to match the order of the requests used here. - - # Get the batch indices of the structured output requests. - # Keep track of the number of speculative tokens scheduled for every - # request in the batch, as the logit indices are offset by this amount. - struct_out_req_batch_indices: dict[str, int] = {} - cumulative_offset = 0 - seq = sorted(self.input_batch.req_id_to_index.items(), - key=lambda x: x[1]) - for req_id, batch_index in seq: - logit_index = batch_index + cumulative_offset - cumulative_offset += len( - scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) - if req_id in scheduler_output.structured_output_request_ids: - struct_out_req_batch_indices[req_id] = logit_index - - out_indices = [] - - # Reorder the bitmask to match the order of the requests in the batch. - sorted_bitmask = np.full(shape=(logits.shape[0], - grammar_bitmask.shape[1]), - fill_value=-1, - dtype=grammar_bitmask.dtype) - cumulative_index = 0 - seq = sorted(scheduler_output.structured_output_request_ids.items(), - key=lambda x: x[1]) - for req_id, _ in seq: - logit_index = struct_out_req_batch_indices[req_id] - num_spec_tokens = len( - scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) - for i in range(1 + num_spec_tokens): - sorted_bitmask[logit_index + i] = \ - grammar_bitmask[cumulative_index + i] - out_indices.append(logit_index + i) - cumulative_index += 1 + num_spec_tokens - grammar_bitmask = sorted_bitmask - - # If the length of out indices and the logits have the same shape - # we don't need to pass indices to the kernel, - # since the bitmask is already aligned with the logits. - skip_out_indices = len(out_indices) == logits.shape[0] - - # Serialization of np.ndarray is much more efficient than a tensor, - # so we receive it in that format. - grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous() - - # Force use of the torch.compile implementation from xgrammar to work - # around issues with the Triton kernel in concurrent structured output - # scenarios. See PR #19565 and issues #19493, #18376 for details. - xgr_torch_compile.apply_token_bitmask_inplace_torch_compile( - logits, - grammar_bitmask.to(self.device, non_blocking=True), - indices=out_indices if not skip_out_indices else None, - ) - def sync_and_slice_intermediate_tensors( self, num_tokens: int, intermediate_tensors: IntermediateTensors, sync_self: bool) -> IntermediateTensors: @@ -1349,21 +1621,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): assert self.intermediate_tensors is not None tp = self.vllm_config.parallel_config.tensor_parallel_size - enabled_sp = self.compilation_config.pass_config. \ - enable_sequence_parallelism - if enabled_sp: - # When sequence parallelism is enabled, we always pad num_tokens - # to be a multiple of tensor_parallel_size (tp) earlier - assert num_tokens % tp == 0 - is_residual_scattered = tp > 1 and enabled_sp \ - and num_tokens % tp == 0 + is_rs = is_residual_scattered_for_sp(self.vllm_config, num_tokens) # When sequence parallelism is enabled, the "residual" tensor is sharded # across tensor parallel ranks, so each rank only needs its own slice. 
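In concrete terms (illustrative numbers only), when the residual is scattered for sequence parallelism only that one tensor's copy length shrinks, while other intermediate tensors keep the full token count:

num_tokens, tp = 1024, 4
is_rs = True   # residual scattered for sequence parallelism
copy_len = {"residual": num_tokens // tp if is_rs else num_tokens,
            "hidden_states": num_tokens}
print(copy_len)   # {'residual': 256, 'hidden_states': 1024}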
if sync_self: assert intermediate_tensors is not None for k, v in intermediate_tensors.items(): - is_scattered = k == "residual" and is_residual_scattered + is_scattered = k == "residual" and is_rs copy_len = num_tokens // tp if is_scattered else \ num_tokens self.intermediate_tensors[k][:copy_len].copy_( @@ -1371,8 +1636,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return IntermediateTensors({ k: - v[:num_tokens // tp] - if k == "residual" and is_residual_scattered else v[:num_tokens] + v[:num_tokens // + tp] if k == "residual" and is_rs else v[:num_tokens] for k, v in self.intermediate_tensors.items() }) @@ -1397,6 +1662,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def get_dp_padding(self, num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: + """ + Determines the total number of tokens that each rank will run. + All ranks will be padded out so that they run with the same number + of tokens + + Returns: tuple[ + num_pad_tokens: The number of tokens that will be added to the batch + num_tokens_after_padding: A tensor containing the total number of + tokens for each DP rank including padding. + ] + """ dp_size = self.vllm_config.parallel_config.data_parallel_size dp_rank = self.vllm_config.parallel_config.data_parallel_rank @@ -1420,12 +1696,44 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): dtype=torch.int32) return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding + def get_local_padding(self, num_tokens_unpadded: int) -> int: + + num_tokens_padded = num_tokens_unpadded + + if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]): + # Use piecewise CUDA graphs. + # Add padding to the batch size. + num_tokens_padded = self.vllm_config.pad_for_cudagraph( + num_tokens_unpadded) + else: + # Eager mode. + # Pad tokens to multiple of tensor_parallel_size when + # enabled collective fusion for SP + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + if self.vllm_config.compilation_config.pass_config. \ + enable_sequence_parallelism and tp_size > 1: + num_tokens_padded = round_up(num_tokens_unpadded, tp_size) + + num_pad_tokens = num_tokens_padded - num_tokens_unpadded + return num_pad_tokens + + # This is where the second ubatch is adjusted to account for the padding. + # Should be called after attention metadata creation. 
This just pads + # the second ubatch slice out to the total number of tokens + # (num_tokens + padding) + def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices, + num_total_tokens: int): + padded_second_ubatch_slice = slice(ubatch_slices[1].token_slice.start, + num_total_tokens) + ubatch_slices[1] = UBatchSlice(padded_second_ubatch_slice, + padded_second_ubatch_slice) + def _pool( self, hidden_states: torch.Tensor, num_scheduled_tokens: int, num_scheduled_tokens_np: np.ndarray, - kv_connector_output: Optional[KVConnectorOutput], ) -> ModelRunnerOutput: assert self.input_batch.num_reqs ==\ len(self.input_batch.pooling_params), \ @@ -1456,68 +1764,56 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): logprobs=None, prompt_logprobs_dict={}, pooler_output=pooler_output, - kv_connector_output=kv_connector_output, ) - @torch.inference_mode() - def execute_model( - self, - scheduler_output: "SchedulerOutput", - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> Union[ModelRunnerOutput, IntermediateTensors]: - self._update_states(scheduler_output) - if not scheduler_output.total_num_scheduled_tokens: - if not has_kv_transfer_group(): - # Return empty ModelRunnerOutput if there's no work to do. - return EMPTY_MODEL_RUNNER_OUTPUT - - return self.kv_connector_no_forward(scheduler_output, - self.vllm_config) - - if self.cache_config.kv_sharing_fast_prefill: - assert not self.input_batch.num_prompt_logprobs, ( - "--kv-sharing-fast-prefill produces incorrect logprobs for " - "prompt tokens, tokens, please disable it when the requests " - "need prompt logprobs") - - # Prepare the decoder inputs. - (attn_metadata, logits_indices, spec_decode_metadata, - num_scheduled_tokens_np, spec_decode_common_attn_metadata, - max_query_len) = self._prepare_inputs(scheduler_output) - - num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH + and hasattr(self, "cudagraph_batch_sizes") + and self.cudagraph_batch_sizes and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): # Use CUDA graphs. # Add padding to the batch size. - num_input_tokens = self.vllm_config.pad_for_cudagraph( - num_scheduled_tokens) - else: - # Eager mode. - # Pad tokens to multiple of tensor_parallel_size when - # enabled collective fusion for SP - tp_size = self.vllm_config.parallel_config.tensor_parallel_size - if self.compilation_config.pass_config. \ - enable_sequence_parallelism and tp_size > 1: - num_input_tokens = round_up(num_scheduled_tokens, tp_size) - else: - num_input_tokens = num_scheduled_tokens + return self.vllm_config.pad_for_cudagraph(num_scheduled_tokens) - # Padding for DP - num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens) - num_input_tokens += num_pad + # Eager mode. 
+ # Pad tokens to multiple of tensor_parallel_size when + # enabled collective fusion for SP + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + if (self.compilation_config.pass_config.enable_sequence_parallelism + and tp_size > 1): + return round_up(num_scheduled_tokens, tp_size) + return num_scheduled_tokens + + def _preprocess( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: Optional[IntermediateTensors] = None, + ubatch_slices: Optional[UBatchSlices] = None, + num_tokens_after_padding: Optional[torch.Tensor] = None, + ) -> tuple[int, int, Optional[torch.Tensor], Optional[torch.Tensor], + Optional[torch.Tensor], torch.Tensor, + Optional[IntermediateTensors], dict[str, Any]]: + + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + if ubatch_slices: + assert num_tokens_after_padding is not None + num_input_tokens = int(num_tokens_after_padding[0].item() * 2) + self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) + elif ubatch_slices is None: + num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens) + num_pad, num_tokens_after_padding = self.get_dp_padding( + num_input_tokens) + num_input_tokens += num_pad # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order - if self.supports_mm_inputs: + if (self.supports_mm_inputs and get_pp_group().is_first_rank + and not self.model_config.is_encoder_decoder): # Run the multimodal encoder if any. self._execute_mm_encoder(scheduler_output) mm_embeds = self._gather_mm_embeddings(scheduler_output) - else: - mm_embeds = [] - if self.supports_mm_inputs and get_pp_group().is_first_rank: # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. @@ -1527,11 +1823,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) # TODO(woosuk): Avoid the copy. Optimize. - self.inputs_embeds[:num_scheduled_tokens].copy_( + self.inputs_embeds.gpu[:num_scheduled_tokens].copy_( inputs_embeds_scheduled) input_ids = None - inputs_embeds = self.inputs_embeds[:num_input_tokens] + inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] model_kwargs = { **self._init_model_kwargs(num_scheduled_tokens), **self._extract_mm_kwargs(scheduler_output), @@ -1555,75 +1851,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_input_tokens, intermediate_tensors, True) - uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( - num_scheduled_tokens == self.input_batch.num_reqs * max_query_len) - batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, - uniform_decode=uniform_decode) - cudagraph_runtime_mode, batch_descriptor = \ - self.cudagraph_dispatcher.dispatch(batch_descriptor) + if (self.model_config.is_encoder_decoder + and scheduler_output.scheduled_encoder_inputs): + encoder_inputs = self._extract_encoder_inputs(scheduler_output) + model_kwargs.update(encoder_inputs) - # Run the model. - # Use persistent buffers for CUDA graphs. 
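The eager-mode branch above rounds the token count up to a multiple of the tensor-parallel size when sequence-parallel collective fusion is enabled. A quick sketch of that rounding (a local round_up is defined here for illustration; vLLM ships its own utility with this behavior):

def round_up(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple

tp_size = 8
assert round_up(130, tp_size) == 136  # padded so the tokens split evenly across TP ranks
assert round_up(128, tp_size) == 128  # already a multiple, no padding added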
- with set_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens, - num_tokens_across_dp=num_tokens_across_dp, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, - ), self.maybe_get_kv_connector_output( - scheduler_output) as kv_connector_output: - - model_output = self.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **model_kwargs, - ) - - if self.use_aux_hidden_state_outputs: - hidden_states, aux_hidden_states = model_output - else: - hidden_states = model_output - aux_hidden_states = None - - # Broadcast PP output for external_launcher (torchrun) - # to make sure we are synced across pp ranks - # TODO: Support overlapping mirco-batches - # https://github.com/vllm-project/vllm/issues/18019 - broadcast_pp_output = \ - self.parallel_config.distributed_executor_backend \ - == "external_launcher" and len(get_pp_group().ranks) > 0 - if not get_pp_group().is_last_rank: - # For mid-pipeline stages, return the hidden states. - assert isinstance(hidden_states, IntermediateTensors) - if not broadcast_pp_output: - hidden_states.kv_connector_output = kv_connector_output - return hidden_states - get_pp_group().send_tensor_dict(hidden_states.tensors, - all_gather_group=get_tp_group()) - logits = None - else: - if self.is_pooling_model: - return self._pool(hidden_states, num_scheduled_tokens, - num_scheduled_tokens_np, kv_connector_output) - - sample_hidden_states = hidden_states[logits_indices] - logits = self.model.compute_logits(sample_hidden_states, None) - if broadcast_pp_output: - model_output_broadcast_data = { - "logits": logits.contiguous(), - } if logits is not None else {} - model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( - model_output_broadcast_data, src=len(get_pp_group().ranks) - 1) - assert model_output_broadcast_data is not None - logits = model_output_broadcast_data["logits"] - - # Apply structured output bitmasks if present - if scheduler_output.grammar_bitmask is not None: - self.apply_grammar_bitmask(scheduler_output, logits) + return ( + num_scheduled_tokens, + num_input_tokens, + num_tokens_after_padding, + input_ids, + inputs_embeds, + positions, + intermediate_tensors, + model_kwargs, + ) + def _sample( + self, logits: Optional[torch.Tensor], + spec_decode_metadata: Optional[SpecDecodeMetadata] + ) -> SamplerOutput: # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata if spec_decode_metadata is None: @@ -1656,28 +1903,39 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): sampling_metadata, ) sampler_output.sampled_token_ids = output_token_ids + self._update_states_after_model_execute(output_token_ids) + return sampler_output + + def _bookkeeping_sync( + self, scheduler_output: "SchedulerOutput", + sampler_output: SamplerOutput, logits: Optional[torch.Tensor], + hidden_states: torch.Tensor, num_scheduled_tokens: int + ) -> tuple[ + dict[str, int], + Optional[LogprobsLists], + list[list[int]], + dict[str, Optional[LogprobsTensors]], + list[str], + dict[str, int], + list[int], + ]: num_nans_in_logits = {} if envs.VLLM_COMPUTE_NANS_IN_LOGITS: num_nans_in_logits = self._get_nans_in_logits(logits) - # TODO(woosuk): The following loop can be slow since it iterates over - # the requests one by one. Optimize. 
- discard_sampled_tokens_req_indices = [] - for i, req_id in enumerate(self.input_batch.req_ids): - req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - if seq_len < req_state.num_tokens: - # Ignore the sampled token for partial prefills. - # Rewind the generator state as if the token was not sampled. - # This relies on cuda-specific torch-internal impl details - generator = self.input_batch.generators.get(i) - if generator is not None: - generator.set_offset(generator.get_offset() - 4) - # Record the index of the request that should not be sampled, - # so that we could clear the sampled tokens before returning. - discard_sampled_tokens_req_indices.append(i) + discard_sampled_tokens_req_indices = \ + self.discard_request_indices.np[:self.num_discarded_requests] + for i in discard_sampled_tokens_req_indices: + gen = self.input_batch.generators.get(int(i)) + if gen is not None: + gen.set_offset(gen.get_offset() - 4) + + # Copy some objects so they don't get modified after returning. + # This is important when using async scheduling. + req_ids_output_copy = self.input_batch.req_ids.copy() + req_id_to_index_output_copy = \ + self.input_batch.req_id_to_index.copy() # NOTE: GPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. @@ -1691,21 +1949,42 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): scheduler_output.num_scheduled_tokens, ) - # Get the valid generated tokens. + num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] sampled_token_ids = sampler_output.sampled_token_ids - max_gen_len = sampled_token_ids.shape[-1] - if max_gen_len == 1: - # No spec decode tokens. - valid_sampled_token_ids = self._to_list(sampled_token_ids) + invalid_req_indices = [] + if not self.use_async_scheduling: + # Get the valid generated tokens. + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + # No spec decode tokens. + valid_sampled_token_ids = self._to_list(sampled_token_ids) + else: + # Includes spec decode tokens. + valid_sampled_token_ids = self.rejection_sampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + ) + # Mask out the sampled tokens that should not be sampled. + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[int(i)].clear() else: - # Includes spec decode tokens. - valid_sampled_token_ids = self.rejection_sampler.parse_output( - sampled_token_ids, - self.input_batch.vocab_size, - ) - # Mask out the sampled tokens that should not be sampled. - for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[i].clear() + valid_sampled_token_ids = [] + invalid_req_indices = discard_sampled_tokens_req_indices.tolist() + invalid_req_indices_set = set(invalid_req_indices) + assert sampled_token_ids.shape[-1] == 1 + + # Cache the sampled tokens on the GPU and avoid CPU sync. + # These will be copied into input_ids in the next step + # when preparing inputs. + self.input_batch.prev_sampled_token_ids = \ + sampled_token_ids + self.input_batch.prev_sampled_token_ids_invalid_indices = \ + invalid_req_indices_set + self.input_batch.prev_req_id_to_index = { + req_id: i + for i, req_id in enumerate(self.input_batch.req_ids) + if i not in invalid_req_indices_set + } # Cache the sampled tokens in the model runner, so that the scheduler # doesn't need to send them back. 
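The async-scheduling branch above keeps the sampled ids on the GPU and only materializes a req_id-to-row mapping (minus discarded rows) on the CPU. A small sketch of that bookkeeping, with made-up request ids:

req_ids = ["req-a", "req-b", "req-c", "req-d"]
invalid_req_indices_set = {1}  # e.g. a partial prefill whose sample is dropped
prev_req_id_to_index = {
    req_id: i
    for i, req_id in enumerate(req_ids)
    if i not in invalid_req_indices_set
}
assert prev_req_id_to_index == {"req-a": 0, "req-c": 2, "req-d": 3}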
@@ -1713,7 +1992,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # the sampled tokens back, because there's no direct communication # between the first-stage worker and the last-stage worker. req_ids = self.input_batch.req_ids - for req_idx, sampled_ids in enumerate(valid_sampled_token_ids): + for req_idx in range(num_sampled_tokens): + if self.use_async_scheduling: + sampled_ids = [-1] if \ + req_idx not in invalid_req_indices_set else None + else: + sampled_ids = valid_sampled_token_ids[req_idx] if not sampled_ids: continue @@ -1728,28 +2012,210 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): start_idx:end_idx] = sampled_ids self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx + req_id = req_ids[req_idx] req_state = self.requests[req_id] req_state.output_token_ids.extend(sampled_ids) - if self.speculative_config: - assert spec_decode_common_attn_metadata is not None - self._draft_token_ids = self.propose_draft_token_ids( - scheduler_output, - valid_sampled_token_ids, - sampling_metadata, - hidden_states, - sample_hidden_states, - aux_hidden_states, - spec_decode_metadata, - spec_decode_common_attn_metadata, + return ( + num_nans_in_logits, + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]: + with record_function_or_nullcontext("Preprocess"): + self._update_states(scheduler_output) + if not scheduler_output.total_num_scheduled_tokens: + if not has_kv_transfer_group(): + # Return empty ModelRunnerOutput if there's no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + return self.kv_connector_no_forward(scheduler_output, + self.vllm_config) + if self.cache_config.kv_sharing_fast_prefill: + assert not self.input_batch.num_prompt_logprobs, ( + "--kv-sharing-fast-prefill produces incorrect logprobs for " + "prompt tokens, tokens, please disable it when the requests" + " need prompt logprobs") + + if self.prepare_inputs_event is not None: + # Ensure prior step has finished with reused CPU tensors. + self.prepare_inputs_event.synchronize() + try: + # Prepare the decoder inputs. + (attn_metadata, logits_indices, spec_decode_metadata, + num_scheduled_tokens_np, spec_decode_common_attn_metadata, + max_query_len, ubatch_slices, num_tokens_after_padding + ) = self._prepare_inputs(scheduler_output) + + finally: + if self.prepare_inputs_event is not None: + self.prepare_inputs_event.record() + + ( + num_scheduled_tokens, + num_input_tokens, + num_tokens_across_dp, + input_ids, + inputs_embeds, + positions, + intermediate_tensors, + model_kwargs, + ) = self._preprocess(scheduler_output, intermediate_tensors, + ubatch_slices, num_tokens_after_padding) + + if ubatch_slices is not None: + num_input_tokens = num_input_tokens // 2 + + uniform_decode = (max_query_len + == self.uniform_decode_query_len) and ( + num_scheduled_tokens + == self.input_batch.num_reqs * max_query_len) + batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, + uniform_decode=uniform_decode) + cudagraph_runtime_mode, batch_descriptor = \ + self.cudagraph_dispatcher.dispatch(batch_descriptor) + + # Run the model. + # Use persistent buffers for CUDA graphs. 
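The uniform_decode flag computed above drives the CUDA graph dispatch: a batch counts as uniform decode only when every request contributes exactly uniform_decode_query_len tokens. Sketched as a standalone predicate (names simplified from the runner's attributes):

def is_uniform_decode(max_query_len: int, uniform_decode_query_len: int,
                      num_scheduled_tokens: int, num_reqs: int) -> bool:
    return (max_query_len == uniform_decode_query_len
            and num_scheduled_tokens == num_reqs * max_query_len)

assert is_uniform_decode(1, 1, 32, 32)       # 32 single-token decodes
assert not is_uniform_decode(7, 1, 64, 16)   # a prefill is mixed into the batch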
+ with (set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ubatch_slices=ubatch_slices, + ), record_function_or_nullcontext("Forward"), + self.maybe_get_kv_connector_output(scheduler_output) as + kv_connector_output): + model_output = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, ) - self.eplb_step() + with record_function_or_nullcontext("Postprocess"): + if self.use_aux_hidden_state_outputs: + # True when EAGLE 3 is used. + hidden_states, aux_hidden_states = model_output + else: + # Common case. + hidden_states = model_output + aux_hidden_states = None - return ModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, + if not self.broadcast_pp_output: + # Common case. + if not get_pp_group().is_last_rank: + # Return the intermediate tensors. + assert isinstance(hidden_states, IntermediateTensors) + hidden_states.kv_connector_output = kv_connector_output + return hidden_states + + if self.is_pooling_model: + # Return the pooling output. + output = self._pool(hidden_states, num_scheduled_tokens, + num_scheduled_tokens_np) + output.kv_connector_output = kv_connector_output + return output + + sample_hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states, None) + else: + # Rare case. + assert not self.is_pooling_model + + if not get_pp_group().is_last_rank: + all_gather_tensors = { + "residual": + not is_residual_scattered_for_sp( + self.vllm_config, num_input_tokens) + } + get_pp_group().send_tensor_dict( + hidden_states.tensors, + all_gather_group=get_tp_group(), + all_gather_tensors=all_gather_tensors) + logits = None + else: + sample_hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states, + None) + + model_output_broadcast_data = {} + if logits is not None: + model_output_broadcast_data["logits"] = logits.contiguous() + + model_output_broadcast_data = get_pp_group( + ).broadcast_tensor_dict(model_output_broadcast_data, + src=len(get_pp_group().ranks) - 1) + assert model_output_broadcast_data is not None + logits = model_output_broadcast_data["logits"] + + # Apply structured output bitmasks if present + if scheduler_output.grammar_bitmask is not None: + apply_grammar_bitmask(scheduler_output, self.input_batch, + logits, self.device) + + with record_function_or_nullcontext("Sample"): + sampler_output = self._sample(logits, spec_decode_metadata) + + def propose_draft_token_ids(sampled_token_ids): + assert spec_decode_common_attn_metadata is not None + with record_function_or_nullcontext("Draft"): + self._draft_token_ids = self.propose_draft_token_ids( + scheduler_output, + sampled_token_ids, + self.input_batch.sampling_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + spec_decode_metadata, + spec_decode_common_attn_metadata, + ) + + use_padded_batch_for_eagle = self.speculative_config and \ + self.speculative_config.use_eagle() and \ + not self.speculative_config.disable_padded_drafter_batch + if use_padded_batch_for_eagle: + # EAGLE speculative decoding can use the GPU sampled tokens + # as inputs, and does not need to wait for bookkeeping to finish. 
+ propose_draft_token_ids(sampler_output.sampled_token_ids) + + with record_function_or_nullcontext("Bookkeep"): + ( + num_nans_in_logits, + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) = self._bookkeeping_sync(scheduler_output, sampler_output, + logits, hidden_states, + num_scheduled_tokens) + + if self.speculative_config and not use_padded_batch_for_eagle: + # ngram and other speculative decoding methods use the sampled + # tokens on the CPU, so they are run after bookkeeping. + propose_draft_token_ids(valid_sampled_token_ids) + + with record_function_or_nullcontext("EPLB"): + self.eplb_step() + + output = ModelRunnerOutput( + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, sampled_token_ids=valid_sampled_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, @@ -1758,6 +2224,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_nans_in_logits=num_nans_in_logits, ) + if not self.use_async_scheduling: + return output + + return AsyncGPUModelRunnerOutput( + model_runner_output=output, + sampled_token_ids=sampler_output.sampled_token_ids, + invalid_req_indices=invalid_req_indices, + async_output_copy_stream=self.async_output_copy_stream, + ) + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: if self._draft_token_ids is None: return None @@ -1772,7 +2248,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def propose_draft_token_ids( self, scheduler_output: "SchedulerOutput", - sampled_token_ids: list[list[int]], + sampled_token_ids: Union[torch.Tensor, list[list[int]]], sampling_metadata: SamplingMetadata, hidden_states: torch.Tensor, sample_hidden_states: torch.Tensor, @@ -1782,11 +2258,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) -> Union[list[list[int]], torch.Tensor]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": + assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, NgramProposer) draft_token_ids = self.propose_ngram_draft_token_ids( sampled_token_ids) elif self.speculative_config.method == "medusa": + assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, MedusaProposer) + if sample_hidden_states.shape[0] == len(sampled_token_ids): # The input to the target model does not include draft tokens. hidden_states = sample_hidden_states @@ -1807,27 +2286,37 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) elif self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) - # TODO(woosuk): Refactor the loop. - req_ids = self.input_batch.req_ids - next_token_ids: list[int] = [] - for i, token_ids in enumerate(sampled_token_ids): - if token_ids: - # Common case. - next_token_id = token_ids[-1] - else: - # Partial prefill (rare case). - # Get the next token id from the request state. 
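The comments above establish the ordering rule for draft proposal: EAGLE with padded drafter batches can consume the GPU-side sampled ids before bookkeeping, while ngram and the other methods need the CPU-side token lists and therefore run after it. A trivial sketch of that decision:

def drafts_before_bookkeeping(use_eagle: bool, padded_drafter_batch: bool) -> bool:
    # Only EAGLE with padded drafter batches skips the wait for bookkeeping.
    return use_eagle and padded_drafter_batch

assert drafts_before_bookkeeping(use_eagle=True, padded_drafter_batch=True)
assert not drafts_before_bookkeeping(use_eagle=False, padded_drafter_batch=True)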
- req_id = req_ids[i] - req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - next_token_id = req_state.get_token_id(seq_len) - next_token_ids.append(next_token_id) - next_token_ids = torch.tensor(next_token_ids, - dtype=torch.int32, - device=self.device) + + if self.speculative_config.disable_padded_drafter_batch: + # When padded-batch is disabled, the sampled_token_ids should be + # the cpu-side list[list[int]] of valid sampled tokens for each + # request, with invalid requests having empty lists. + assert isinstance(sampled_token_ids, list), \ + "sampled_token_ids should be a python list when" \ + "padded-batch is disabled." + next_token_ids = self.drafter.prepare_next_token_ids_cpu( + sampled_token_ids, self.requests, self.input_batch, + scheduler_output.num_scheduled_tokens) + else: + # When using padded-batch, the sampled_token_ids should be + # the gpu tensor of sampled tokens for each request, of shape + # (num_reqs, num_spec_tokens + 1) with rejected tokens having + # value -1. + assert isinstance(sampled_token_ids, torch.Tensor), \ + "sampled_token_ids should be a torch.Tensor when" \ + "padded-batch is enabled." + next_token_ids, valid_sampled_tokens_count = \ + self.drafter.prepare_next_token_ids_padded( + common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_indices.gpu, + self.num_discarded_requests + ) if spec_decode_metadata is None: + token_indices_to_sample = None # input_ids can be None for multimodal models. target_token_ids = self.input_ids.gpu[:num_scheduled_tokens] # TODO(woosuk): Support M-RoPE. @@ -1839,17 +2328,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): else: target_hidden_states = hidden_states[:num_scheduled_tokens] else: - # TODO(woosuk): Refactor this. - num_draft_tokens = spec_decode_metadata.num_draft_tokens - num_rejected_tokens = [ - n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 - for i, n in enumerate(num_draft_tokens) - ] - num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens, - dtype=torch.int32) - common_attn_metadata, token_indices =\ - self.drafter.prepare_inputs( - common_attn_metadata, num_rejected_tokens_cpu) + if self.speculative_config.disable_padded_drafter_batch: + token_indices_to_sample = None + common_attn_metadata, token_indices =\ + self.drafter.prepare_inputs( + common_attn_metadata, + sampled_token_ids, + spec_decode_metadata.num_draft_tokens) + else: + common_attn_metadata, token_indices, \ + token_indices_to_sample =\ + self.drafter.prepare_inputs_padded( + common_attn_metadata, + spec_decode_metadata, + valid_sampled_tokens_count) target_token_ids = self.input_ids.gpu[token_indices] # TODO(woosuk): Support M-RoPE. @@ -1869,6 +2361,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): target_positions=target_positions, target_hidden_states=target_hidden_states, next_token_ids=next_token_ids, + last_token_indices=token_indices_to_sample, sampling_metadata=sampling_metadata, common_attn_metadata=common_attn_metadata, mm_embeds=mm_embeds, @@ -2012,10 +2505,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # CudagraphWraper and CudagraphDispatcher of vllm. # wrap the model with full cudagraph wrapper if needed. 
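When the padded drafter batch is disabled, the drafter's next-token preparation works from the CPU-side lists described above: take the last accepted token of each request and fall back to the request state when the request produced no valid sample this step. A simplified sketch (the prompt lookup list below is a stand-in for req_state.get_token_id(seq_len)):

sampled_token_ids = [[11, 12], [], [7]]   # per-request accepted tokens
next_prompt_token = [101, 102, 103]       # stand-in for req_state.get_token_id(seq_len)
next_token_ids = [
    toks[-1] if toks else next_prompt_token[i]
    for i, toks in enumerate(sampled_token_ids)
]
assert next_token_ids == [12, 102, 7]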
- if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + if self.compilation_config.cudagraph_mode.has_full_cudagraphs() \ + and not self.parallel_config.enable_dbo: self.model = CUDAGraphWrapper(self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL) + elif self.parallel_config.enable_dbo: + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + self.model = UBatchWrapper(self.model, self.vllm_config, + CUDAGraphMode.FULL, self.device) + else: + self.model = UBatchWrapper(self.model, self.vllm_config, + CUDAGraphMode.NONE, self.device) def reload_weights(self) -> None: assert getattr(self, "model", None) is not None, \ @@ -2213,8 +2714,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, force_attention: bool = False, uniform_decode: bool = False, + allow_microbatching: bool = False, skip_eplb: bool = False, is_profile: bool = False, + create_mixed_batch: bool = False, remove_lora: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -2233,18 +2736,38 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): uniform_decode: If True, the batch is a uniform decode batch. skip_eplb: If True, skip EPLB state update. is_profile: If True, this is a profile run. + create_mixed_batch: If True, create a mixed batch with both decode + (1 token) and prefill (multiple tokens) requests. remove_lora: If False, dummy LoRAs are not destroyed after the run """ + ubatch_enabled = self.parallel_config.enable_dbo + num_tokens_across_dp = None + num_pad = 0 + should_ubatch = False + if ubatch_enabled: + should_ubatch = num_tokens >= \ + self.parallel_config.dbo_decode_token_threshold and \ + allow_microbatching + + (should_ubatch, num_tokens_across_dp) = get_dp_padding_ubatch( + num_tokens, num_tokens, should_ubatch, self.vllm_config) + + # Currently the dummy run should only be ubatching during + # cuda graph capture, meaning all DP ranks should already + # have the same batch size + if num_tokens_across_dp is not None: + assert int(num_tokens_across_dp[0]) == num_tokens // 2 + assert cudagraph_runtime_mode in { CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL } - # Padding for DP - num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) + if not should_ubatch: + num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) num_tokens += num_pad # If cudagraph_mode.decode_mode() == FULL and - # cudagraph_mode.seperate_routine(). This means that we are using + # cudagraph_mode.separate_routine(). This means that we are using # different graphs and/or modes for mixed prefill-decode batches vs. # uniform decode batches. A uniform decode batch means that all # requests have identical query length, except a potential virtual @@ -2258,19 +2781,37 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # for GQA/MQA. max_query_len = self.uniform_decode_query_len if uniform_decode else \ num_tokens + if allow_microbatching: + assert self.uniform_decode_query_len == 1 + assert uniform_decode is True + assert max_query_len == 1 # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively # has num_tokens in total. 
assert num_tokens <= self.scheduler_config.max_num_batched_tokens max_num_reqs = self.scheduler_config.max_num_seqs - if uniform_decode: - num_reqs = cdiv(num_tokens, max_query_len) + if create_mixed_batch: + assert not uniform_decode + # Create mixed batch: + # first half decode tokens, second half one prefill + num_decode_tokens = num_tokens // 2 + num_prefill_tokens = num_tokens - num_decode_tokens + num_reqs = num_decode_tokens + 1 + + # Create decode requests (1 token each) followed by prefill request + num_scheduled_tokens_list = [1] * num_decode_tokens + [ + num_prefill_tokens + ] + # Note: Overriding max_query_len to be the prefill tokens + max_query_len = num_prefill_tokens + elif uniform_decode: + num_reqs = num_tokens // max_query_len assert num_reqs <= max_num_reqs, \ "Do not capture num_reqs > max_num_reqs for uniform batch" num_scheduled_tokens_list = [max_query_len] * num_reqs if num_tokens % max_query_len != 0: - num_scheduled_tokens_list[-1] = num_tokens % max_query_len + num_scheduled_tokens_list[-1] += num_tokens % max_query_len else: num_reqs = min(num_tokens, max_num_reqs) min_tokens_per_req = num_tokens // num_reqs @@ -2282,15 +2823,38 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) - attn_metadata: Optional[dict[str, Any]] = None + ubatch_slices = None + # We currently only microbatch if the number of tokens is + # over a certain threshold. + if should_ubatch: + # We only support decode-only cudagraphs + assert num_reqs == num_tokens + assert num_tokens % 2 == 0 + ubatch_slices = [ + UBatchSlice(slice(0, num_reqs // 2), slice(0, + num_tokens // 2)), + UBatchSlice(slice(num_reqs // 2, num_reqs), + slice(num_tokens // 2, num_tokens)) + ] + + attn_metadata: Optional[PerLayerAttnMetadata] = None # If force_attention is True, we always capture attention. Otherwise, # it only happens for cudagraph_runtime_mode=FULL. if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: attn_metadata = {} + if ubatch_slices is not None: + attn_metadata = [dict() for _ in range(len(ubatch_slices))] - # Make sure max_model_len is used at the graph capture time. - self.seq_lens.np[:num_reqs] = self.max_model_len + if create_mixed_batch: + # In the mixed batch mode (used for FI warmup), we use + # shorter sequence lengths to run faster. + # TODO(luka) better system for describing dummy batches + seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] + else: + # Make sure max_model_len is used at the graph capture time. + seq_lens = self.max_model_len + self.seq_lens.np[:num_reqs] = seq_lens self.seq_lens.np[num_reqs:] = 0 self.seq_lens.copy_to_gpu() @@ -2308,31 +2872,46 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_actual_tokens=num_tokens, max_query_len=max_query_len, max_seq_len=self.max_model_len, - block_table_tensor=self.input_batch.block_table[ - kv_cache_group_id].get_device_tensor()[:num_reqs], - slot_mapping=self.input_batch. - block_table[kv_cache_group_id].slot_mapping[:num_tokens], + block_table_tensor=self.input_batch. 
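The create_mixed_batch branch above builds a warmup batch of num_tokens // 2 single-token decodes followed by one prefill that takes the remaining tokens. In numbers:

num_tokens = 10
num_decode_tokens = num_tokens // 2
num_prefill_tokens = num_tokens - num_decode_tokens
num_scheduled_tokens_list = [1] * num_decode_tokens + [num_prefill_tokens]
assert num_scheduled_tokens_list == [1, 1, 1, 1, 1, 5]
assert sum(num_scheduled_tokens_list) == num_tokens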
+ block_table[kv_cache_group_id].get_device_tensor(num_reqs), + slot_mapping=self.input_batch.block_table[ + kv_cache_group_id].slot_mapping.gpu[:num_tokens], causal=True) - for attn_group in self.attn_groups[kv_cache_group_id]: - attn_metadata_i = attn_group.metadata_builder\ - .build_for_cudagraph_capture(common_attn_metadata) - for layer_name in kv_cache_group_spec.layer_names: - attn_metadata[layer_name] = attn_metadata_i + if ubatch_slices is not None: + common_attn_metadata_list = split_attn_metadata( + ubatch_slices, common_attn_metadata) + for ubid, common_attn_metadata in enumerate( + common_attn_metadata_list): + assert common_attn_metadata.max_query_len == 1 + attn_metadata_i = (attn_group\ + .get_metadata_builder(ubatch_id=ubid)\ + .build_for_cudagraph_capture(common_attn_metadata)) + for layer_name in kv_cache_group_spec.layer_names: + assert type(attn_metadata) is list + attn_metadata[ubid][ + layer_name] = attn_metadata_i + else: + assert type(attn_metadata) is dict + attn_metadata_i = attn_group.get_metadata_builder()\ + .build_for_cudagraph_capture(common_attn_metadata) + for layer_name in kv_cache_group_spec.layer_names: + attn_metadata[layer_name] = attn_metadata_i with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens, remove_lora): - if self.supports_mm_inputs: + model_kwargs = self._init_model_kwargs(num_tokens) + if (self.supports_mm_inputs + and not self.model_config.is_encoder_decoder): input_ids = None - inputs_embeds = self.inputs_embeds[:num_tokens] + inputs_embeds = self.inputs_embeds.gpu[:num_tokens] model_kwargs = { - **self._init_model_kwargs(num_tokens), + **model_kwargs, **self._dummy_mm_kwargs(num_reqs), } else: input_ids = self.input_ids.gpu[:num_tokens] inputs_embeds = None - model_kwargs = self._init_model_kwargs(num_tokens) if self.uses_mrope: positions = self.mrope_positions.gpu[:, :num_tokens] @@ -2364,13 +2943,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): f"Cudagraph runtime mode mismatch at dummy_run. " f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.") + if ubatch_slices is not None: + num_tokens = num_tokens // 2 with self.maybe_randomize_inputs(input_ids), set_forward_context( attn_metadata, self.vllm_config, num_tokens=num_tokens, num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor): + batch_descriptor=batch_descriptor, + ubatch_slices=ubatch_slices): outputs = self.model( input_ids=input_ids, positions=positions, @@ -2502,6 +3084,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): model = cast(VllmModelForPooling, self.get_model()) dummy_pooling_params = PoolingParams(task=task) + dummy_pooling_params.verify(task=task, model_config=self.model_config) to_update = model.pooler.get_pooling_updates(task) to_update.apply(dummy_pooling_params) @@ -2554,7 +3137,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): mm_budget = self.mm_budget assert mm_budget is not None - # TODO: handle encoder-decoder models once we support them. 
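With micro-batching enabled, the dummy run above turns attn_metadata into a list with one {layer_name: metadata} dict per ubatch, each built by that ubatch's metadata builder. Layout sketch with placeholder values in place of real metadata objects:

layer_names = ["model.layers.0.attn", "model.layers.1.attn"]
attn_metadata = [dict() for _ in range(2)]   # one dict per ubatch
for ubid in range(2):
    metadata_i = f"metadata(ubatch={ubid})"  # stand-in for the built metadata object
    for layer_name in layer_names:
        attn_metadata[ubid][layer_name] = metadata_i
assert attn_metadata[1]["model.layers.0.attn"] == "metadata(ubatch=1)"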
if (encoder_budget := mm_budget.get_encoder_budget()) > 0: # NOTE: Currently model is profiled with a single non-text # modality with the max possible input tokens even when @@ -2607,12 +3189,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.encoder_cache.clear() gc.collect() - def capture_model(self) -> None: + def capture_model(self) -> int: if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE: logger.warning( "Skipping CUDA graph capture. To turn on CUDA graph capture, " "ensure `cudagraph_mode` was not manually set to `NONE`") - return + return 0 else: self.initialize_cudagraph_capture() @@ -2635,6 +3217,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): finally: if should_freeze: gc.unfreeze() + gc.collect() # Trigger CUDA graph capture for specific shapes. # Capture the large shapes first so that the smaller shapes @@ -2642,6 +3225,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): set_cudagraph_capturing_enabled(True) with freeze_gc(), graph_capture(device=self.device): cudagraph_mode = self.compilation_config.cudagraph_mode + assert cudagraph_mode is not None if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: cudagraph_runtime_mode = cudagraph_mode.mixed_mode() @@ -2651,8 +3235,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): cudagraph_runtime_mode=cudagraph_runtime_mode, uniform_decode=False) - # Capture full cudagraph for uniform decode batches if we have - # dont already have full mixed prefill-decode cudagraphs + # Capture full cudagraph for uniform decode batches if we + # don't already have full mixed prefill-decode cudagraphs. if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \ cudagraph_mode.separate_routine(): max_num_tokens = self.scheduler_config.max_num_seqs * \ @@ -2671,7 +3255,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Disable cudagraph capturing globally, so any unexpected cudagraph # capturing will be detected and raise an error after here. # Note: We don't put it into graph_capture context manager because - # we may doing lazy capturing in future that still allows capturing + # we may do lazy capturing in future that still allows capturing # after here. set_cudagraph_capturing_enabled(False) @@ -2682,6 +3266,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # This usually takes 5~20 seconds. 
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, cuda_graph_size / (1 << 30)) + return cuda_graph_size def _capture_cudagraphs(self, compilation_cases: list[int], cudagraph_runtime_mode: CUDAGraphMode, @@ -2698,6 +3283,35 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): desc="Capturing CUDA graphs ({}, {})".format( "decode" if uniform_decode else "mixed prefill-decode", cudagraph_runtime_mode.name)) + enable_dbo = self.parallel_config.enable_dbo + # DBO Only supports running Full cudagraphs with uniform + # decode lengths + if enable_dbo and uniform_decode: + for num_tokens in compilation_cases: + # If the number of tokens is greater than the microbatching + # threshold, don't generate a microbatched cudagraph + if (num_tokens + < self.parallel_config.dbo_decode_token_threshold): + continue + + # Warmup + for _ in range( + self.compilation_config.cudagraph_num_of_warmups): + force_attention = ( + cudagraph_runtime_mode == CUDAGraphMode.FULL) + self._dummy_run(num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + force_attention=force_attention, + uniform_decode=True, + allow_microbatching=True, + skip_eplb=True) + + # Graph Capture + self._dummy_run(num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.FULL, + uniform_decode=True, + allow_microbatching=True, + skip_eplb=True) # We skip EPLB here since we don't want to record dummy metrics for num_tokens in compilation_cases: for _ in range(self.compilation_config.cudagraph_num_of_warmups): @@ -2764,14 +3378,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) -> list[AttentionGroup]: attn_groups: list[AttentionGroup] = [] for attn_backend, layer_names in attn_backends_map.items(): - attn_metadata_builder_i = attn_backend.get_builder_cls()( + attn_metadata_builders = [] + attn_metadata_builders.append(attn_backend.get_builder_cls()( kv_cache_spec, layer_names, self.vllm_config, self.device, - ) + )) + if self.parallel_config.enable_dbo: + attn_metadata_builders.append( + attn_backend.get_builder_cls()( + kv_cache_spec, + layer_names, + self.vllm_config, + self.device, + )) attn_group = AttentionGroup(attn_backend, - attn_metadata_builder_i, + attn_metadata_builders, layer_names) attn_groups.append(attn_group) return attn_groups @@ -2783,7 +3406,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.attn_groups.append( create_attn_groups(attn_backends, kv_cache_spec)) - # Calculate reorder batch threshold (if neeeded) + # Calculate reorder batch threshold (if needed) self.calculate_reorder_batch_threshold() def initialize_cudagraph_capture(self) -> None: @@ -2791,11 +3414,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): min_cg_builder_name = None for attn_group in self._attn_group_iterator(): - builder = attn_group.metadata_builder + builder = attn_group.get_metadata_builder() if builder.cudagraph_support.value < min_cg_support.value: min_cg_support = builder.cudagraph_support min_cg_builder_name = builder.__class__.__name__ - # Flexible resolve the cudagraph mode cudagraph_mode = self.compilation_config.cudagraph_mode # check cudagraph for mixed batch is supported @@ -2861,7 +3483,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): is compatible (e.g., decode threshold is the same) """ for group in self._attn_group_iterator(): - attn_metadata_builder_i = group.metadata_builder + attn_metadata_builder_i = group.get_metadata_builder() # check that if any backends reorder 
batches; that the reordering # is compatible (e.g., decode threshold is the same) @@ -2900,7 +3522,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): "for more details.") self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, + max_model_len=max(self.max_model_len, self.max_encoder_len), max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, @@ -2909,6 +3531,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=self.input_batch.logitsprocs, is_pooling_model=self.is_pooling_model, + num_speculative_tokens=( + self.vllm_config.speculative_config.num_speculative_tokens + if self.vllm_config.speculative_config else 0), ) def _allocate_kv_cache_tensors( @@ -3146,6 +3771,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if has_kv_transfer_group(): get_kv_transfer_group().register_kv_caches(kv_caches) + if self.device.type == 'xpu': + get_kv_transfer_group().set_host_xfer_buffer_ops( + copy_kv_blocks) + + if self.dcp_world_size > 1: + layer_names = self.attn_groups[0][0].layer_names + layers = get_layers_from_vllm_config(self.vllm_config, + AttentionLayerBase, + layer_names) + for layer in layers.values(): + assert layer.impl.need_to_return_lse_for_decode, ( + "DCP requires attention impls to return" + " the softmax lse for decode, but the impl " + f"{layer.impl.__class__.__name__} " + "does not return the softmax lse for decode.") def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: """ @@ -3158,7 +3798,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) for layer_name, attn_module in attn_layers.items(): if attn_module.attn_type == AttentionType.ENCODER_ONLY: - attn_spec = EncoderOnlyAttentionSpec( + attn_spec: AttentionSpec = EncoderOnlyAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, @@ -3200,7 +3840,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.shared_kv_cache_layers[layer_name] = kv_tgt_layer continue - # TODO: Support other attention modules, e.g., cross-attention # TODO(lucas): move the attention specs into the model layers like # the attention backends if attn_module.attn_type == AttentionType.DECODER: @@ -3228,19 +3867,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): head_size=attn_module.head_size, dtype=self.kv_cache_dtype, use_mla=use_mla) + elif attn_module.attn_type == AttentionType.ENCODER_DECODER: + kv_cache_spec[layer_name] = CrossAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + use_mla=use_mla) elif attn_module.attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY): # encoder-only attention does not need KV cache. 
continue - elif attn_module.attn_type == AttentionType.ENCODER_DECODER: - raise NotImplementedError else: raise ValueError( f"Unknown attention type: {attn_module.attn_type}") mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase) if len(mamba_layers) > 0: - if self.vllm_config.speculative_config is not None: + if (self.vllm_config.speculative_config is not None + and self.vllm_config.model_config.hf_config.model_type + not in ["qwen3_next"]): raise NotImplementedError( "Mamba with speculative decoding is not supported yet.") if self.vllm_config.cache_config.enable_prefix_caching: @@ -3259,7 +3905,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): dtypes=mamba_module.get_state_dtype(), block_size=max_model_len, page_size_padded=page_size_padded, - mamba_type=mamba_module.mamba_type) + mamba_type=mamba_module.mamba_type, + num_speculative_blocks=( + self.speculative_config.num_speculative_tokens + if self.speculative_config else 0), + ) return kv_cache_spec diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py new file mode 100644 index 0000000000000..5012ad0483c84 --- /dev/null +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -0,0 +1,303 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses +import threading +from typing import Any, Callable, Optional + +import torch + +from vllm.compilation.cuda_graph import CUDAGraphWrapper +from vllm.config import CUDAGraphMode, VllmConfig +from vllm.forward_context import (create_forward_context, get_forward_context, + override_forward_context) +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors +from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts + +logger = init_logger(__name__) + + +@dataclasses.dataclass +class UbatchMetadata: + context: UBatchContext + input_ids: torch.Tensor + positions: torch.Tensor + inputs_embeds: Optional[torch.Tensor] + intermediate_tensors: Optional[IntermediateTensors] + num_tokens: int + + +@dataclasses.dataclass +class CUDAGraphMetaData: + cudagraph: torch.cuda.CUDAGraph + ubatch_metadata: UbatchMetadata + outputs: Optional[Any] = None + + +class UBatchWrapper: + + def __init__(self, runnable: Callable, vllm_config: VllmConfig, + runtime_mode: CUDAGraphMode, device: torch.cuda.device): + self.runnable = runnable + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + self.comm_stream = torch.cuda.Stream(device=device) + # Two ubatch threads plus the main thread + self.ready_barrier = threading.Barrier(3) + + self.cudagraphs: dict[int, CUDAGraphMetaData] = {} + + self.cudagraph_wrapper = None + self.graph_pool = None + if runtime_mode is not CUDAGraphMode.NONE: + self.cudagraph_wrapper = CUDAGraphWrapper( + runnable, vllm_config, runtime_mode=runtime_mode) + self.graph_pool = current_platform.get_global_graph_pool() + + def __getattr__(self, key: str): + # allow accessing the attributes of the runnable. + if hasattr(self.runnable, key): + return getattr(self.runnable, key) + raise AttributeError(f"Attribute {key} not exists in the runnable of " + f"cudagraph wrapper: {self.runnable}") + + def unwrap(self) -> Callable: + # in case we need to access the original runnable. + return self.runnable + + def _capture_ubatches(self, ubatch_metadata, model) -> torch.Tensor: + """ + Capture a cudagraph for a microbatched run. 
+ + The logic here is somewhat complicated because we need to make sure that + each of the ubatch threads initialize the cuda context before we start + the graph capture. + + The flow is as follows: + 1. The main thread starts up each ubatch thread. Each thread will + initialize its cuda context (torch.cuda.current_blas_handle()) + before going to sleep upon entering the ubatch_context. + + 2. The main thread starts the graph capture and wakes up the first + ubatch thread. + + 3. Each ubatch thread runs the model to completion and returns the + completed output tensors back to the main thread. + + 4. The main thread stores the captured cudagraph along with its metadata + and returns + """ + + @torch.inference_mode() + def _capture_ubatch_thread(results, ubatch_metadata): + ubatch_context = ubatch_metadata.context + with torch.cuda.stream(ubatch_context.compute_stream): + _ = torch.cuda.current_blas_handle() + with torch.cuda.stream(ubatch_context.comm_stream): + _ = torch.cuda.current_blas_handle() + with ubatch_context: + model_output = model( + input_ids=ubatch_metadata.input_ids, + positions=ubatch_metadata.positions, + intermediate_tensors=ubatch_metadata.intermediate_tensors, + inputs_embeds=ubatch_metadata.inputs_embeds, + ) + + results.append((ubatch_metadata.context.id, model_output)) + + results: list[tuple[int, torch.Tensor]] = [] + compute_stream = ubatch_metadata[0].context.compute_stream + num_tokens = ubatch_metadata[0].num_tokens + \ + ubatch_metadata[1].num_tokens + + # Ubatches will manually manage the forward context, so we override + # it to None here so we can have it restored correctly later + with override_forward_context(None): + ubatch_threads = [] + for metadata in ubatch_metadata: + thread = threading.Thread(target=_capture_ubatch_thread, + args=( + results, + metadata, + )) + ubatch_threads.append(thread) + thread.start() + self.ready_barrier.wait() # Wait for both threads to be ready + + # Capture the cudagraph + cudagraph_metadata = \ + CUDAGraphMetaData( + cudagraph=torch.cuda.CUDAGraph(), + ubatch_metadata=ubatch_metadata, + ) + with torch.cuda.graph(cudagraph_metadata.cudagraph, + stream=compute_stream, + pool=self.graph_pool): + ubatch_metadata[0].context.cpu_wait_event.set() + for thread in ubatch_threads: + thread.join() + sorted_results = [value for position, value in sorted(results)] + result = torch.cat(sorted_results, dim=0) + cudagraph_metadata.outputs = result + self.cudagraphs[num_tokens] = cudagraph_metadata + return cudagraph_metadata.outputs + + def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor: + + @torch.inference_mode() + def _ubatch_thread(results, model, ubatch_metadata): + with ubatch_metadata.context: + model_output = model( + input_ids=ubatch_metadata.input_ids, + positions=ubatch_metadata.positions, + intermediate_tensors=ubatch_metadata.intermediate_tensors, + inputs_embeds=ubatch_metadata.inputs_embeds, + ) + results.append((ubatch_metadata.context.id, model_output)) + + results: list[tuple[int, torch.Tensor]] = [] + + # Ubatch threads will manually manage the forward context, so we + # override it to None here so we can have it restored correctly + # after both threads have finished + with override_forward_context(None): + ubatch_threads = [] + for metadata in ubatch_metadata: + thread = threading.Thread(target=_ubatch_thread, + args=( + results, + model, + metadata, + )) + ubatch_threads.append(thread) + thread.start() + self.ready_barrier.wait() # Wait for both threads to be ready + 
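The numbered flow in the docstring above boils down to a thread/barrier handshake between the main thread and the two ubatch threads. A minimal standard-library sketch of that rendezvous (no CUDA involved):

import threading

ready_barrier = threading.Barrier(3)   # two ubatch threads + the main thread
results = []

def ubatch_thread(ubatch_id: int) -> None:
    # per-thread CUDA/cuBLAS context setup would happen here in the real wrapper
    ready_barrier.wait()
    results.append(ubatch_id)

threads = [threading.Thread(target=ubatch_thread, args=(i,)) for i in range(2)]
for t in threads:
    t.start()
ready_barrier.wait()   # main thread completes the rendezvous, releasing both workers
for t in threads:
    t.join()
assert sorted(results) == [0, 1]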
ubatch_metadata[0].context.cpu_wait_event.set() + for thread in ubatch_threads: + thread.join() + sorted_results = [value for position, value in sorted(results)] + result = torch.cat(sorted_results, dim=0) + return result + + def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids, + positions, inputs_embeds, intermediate_tensors, + compute_stream, dp_metadata, batch_descriptor, + cudagraph_runtime_mode) -> list[UbatchMetadata]: + + # Create one forward context per ubatch + forward_contexts = [] + for i, ubatch_slice in enumerate(ubatch_slices): + forward_contexts.append( + create_forward_context( + attn_metadata[i] if attn_metadata is not None else None, + self.vllm_config, + dp_metadata=dp_metadata, + batch_descriptor=batch_descriptor, + cudagraph_runtime_mode=cudagraph_runtime_mode)) + + ubatch_ctxs = make_ubatch_contexts( + num_micro_batches=len(ubatch_slices), + comm_stream=self.comm_stream, + compute_stream=compute_stream, + forward_contexts=forward_contexts, + ready_barrier=self.ready_barrier) + + ubatch_metadata: list[UbatchMetadata] = [] + for i, ubatch_slice in enumerate(ubatch_slices): + sliced_input_ids, sliced_positions, sliced_inputs_embeds, \ + sliced_intermediate_tensors = \ + self._slice_model_inputs( + ubatch_slice.token_slice, input_ids, positions, + inputs_embeds, intermediate_tensors) + ubatch_metadata.append( + UbatchMetadata( + context=ubatch_ctxs[i], + input_ids=sliced_input_ids, + positions=sliced_positions, + inputs_embeds=sliced_inputs_embeds, + intermediate_tensors=sliced_intermediate_tensors, + num_tokens=ubatch_slice.token_slice.stop - + ubatch_slice.token_slice.start)) + + return ubatch_metadata + + def _slice_model_inputs(self, tokens_slice: slice, input_ids, positions, + inputs_embeds, intermediate_tensors): + sliced_input_ids = input_ids[tokens_slice] + # if we are using mrope. 
Mrope adds an additional dimension to the + # positions tensor + if positions.ndim == 2: + sliced_positions = positions[:, tokens_slice] + else: + sliced_positions = positions[tokens_slice] + sliced_inputs_embeds = inputs_embeds[ + tokens_slice] if inputs_embeds else None + sliced_intermediate_tensors = intermediate_tensors[ + tokens_slice] if intermediate_tensors else None + + return (sliced_input_ids, sliced_positions, sliced_inputs_embeds, + sliced_intermediate_tensors) + + def __call__(self, *args, **kwargs): + forward_context = get_forward_context() + batch_descriptor = forward_context.batch_descriptor + ubatch_slices = forward_context.ubatch_slices + cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode + + # If there's no ubatching, just run the runnable object + if ubatch_slices is None: + if cudagraph_runtime_mode in (CUDAGraphMode.NONE, + CUDAGraphMode.PIECEWISE): + return self.runnable(*args, **kwargs) + else: + assert self.cudagraph_wrapper is not None + return self.cudagraph_wrapper(*args, **kwargs) + + attn_metadata = forward_context.attn_metadata + num_tokens = (ubatch_slices[0].token_slice.stop - + ubatch_slices[0].token_slice.start) * 2 + input_ids = kwargs['input_ids'] + positions = kwargs['positions'] + intermediate_tensors = kwargs['intermediate_tensors'] + inputs_embeds = kwargs['inputs_embeds'] + compute_stream = torch.cuda.current_stream() + + dp_metadata = forward_context.dp_metadata + + # We shouldn't be here unless we are running with multiple DP ranks + assert dp_metadata is not None + + if num_tokens not in self.cudagraphs \ + and cudagraph_runtime_mode is CUDAGraphMode.FULL: + ubatch_metadata = self._make_ubatch_metadata( + ubatch_slices=ubatch_slices, + attn_metadata=attn_metadata, + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + compute_stream=compute_stream, + dp_metadata=dp_metadata, + batch_descriptor=batch_descriptor, + cudagraph_runtime_mode=CUDAGraphMode.NONE) + + return self._capture_ubatches(ubatch_metadata, self.model) + elif num_tokens in self.cudagraphs: + cudagraph_metadata = self.cudagraphs[num_tokens] + cudagraph_metadata.cudagraph.replay() + return cudagraph_metadata.outputs + else: + ubatch_metadata = self._make_ubatch_metadata( + ubatch_slices=ubatch_slices, + attn_metadata=attn_metadata, + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + compute_stream=compute_stream, + dp_metadata=dp_metadata, + batch_descriptor=batch_descriptor, + cudagraph_runtime_mode=CUDAGraphMode.NONE) + return self._run_ubatches(ubatch_metadata, self.model) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index f49f5bdd9703b..6855526583f04 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -5,7 +5,7 @@ import copy import gc import os from contextlib import AbstractContextManager, nullcontext -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union import torch import torch.distributed @@ -28,10 +28,11 @@ from vllm.tasks import SupportedTask from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds, - ModelRunnerOutput) +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, 
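The __call__ path above keys captured graphs by the padded token count: an unseen size triggers a capture, a known size replays the cached graph. The dispatch in miniature, with string placeholders instead of real CUDA graphs:

cudagraphs: dict[int, str] = {}

def run(num_tokens: int) -> str:
    if num_tokens not in cudagraphs:
        cudagraphs[num_tokens] = f"graph({num_tokens})"  # capture on first sight
        return "captured"
    return "replayed"

assert run(256) == "captured"
assert run(256) == "replayed"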
AsyncModelRunnerOutput, + DraftTokenIds, ModelRunnerOutput) from vllm.v1.utils import report_usage_stats from vllm.v1.worker.gpu_model_runner import GPUModelRunner +from vllm.v1.worker.utils import is_residual_scattered_for_sp from vllm.v1.worker.worker_base import WorkerBase logger = init_logger(__name__) @@ -231,18 +232,40 @@ class Worker(WorkerBase): You may limit the usage of GPU memory by adjusting the `gpu_memory_utilization` parameter. """ + GiB = lambda b: b / GiB_bytes + if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes: + # still need a profile run which compiles the model for + # max_num_batched_tokens + self.model_runner.profile_run() + + msg = ( + f"Initial free memory {GiB(self.init_snapshot.free_memory)} " + f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f}GiB memory for " + "KV Cache as specified by kv_cache_memory_bytes config and " + "skipped memory profiling. This does does not respect the " + "gpu_memory_utilization config. Only use kv_cache_memory_bytes " + "config when you want manual control of KV cache memory " + "size. If OOM'ed, check the difference of initial free " + "memory between the current run and the previous run " + "where kv_cache_memory_bytes is suggested and update it " + "correspondingly.") + logger.info(msg) + return kv_cache_memory_bytes + torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() - GiB = lambda b: b / GiB_bytes # Execute a forward pass with dummy inputs to profile the memory usage # of the model. with memory_profiling( self.init_snapshot, - weights_memory=int( - self.model_runner.model_memory_usage)) as profile_result: + weights_memory=int(self.model_runner.model_memory_usage), + ) as profile_result: self.model_runner.profile_run() + self.non_torch_memory = profile_result.non_torch_increase + self.peak_activation_memory = profile_result.torch_peak_increase + free_gpu_memory = profile_result.after_profile.free_memory # NOTE(woosuk): Here we assume that the other processes using the same # GPU did not change their memory usage during the profiling. @@ -254,7 +277,7 @@ class Worker(WorkerBase): "release GPU memory while vLLM is profiling during initialization. " "To fix this, ensure consistent GPU memory allocation or " "isolate vLLM in its own container.") - available_kv_cache_memory = self.requested_memory \ + self.available_kv_cache_memory_bytes = self.requested_memory \ - profile_result.non_kv_cache_memory unrequested_memory = self.init_snapshot.free_memory \ @@ -274,10 +297,10 @@ class Worker(WorkerBase): ) logger.debug(profile_result) logger.info("Available KV cache memory: %.2f GiB", - GiB(available_kv_cache_memory)) + GiB(self.available_kv_cache_memory_bytes)) gc.collect() - return int(available_kv_cache_memory) + return int(self.available_kv_cache_memory_bytes) def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec() @@ -317,8 +340,56 @@ class Worker(WorkerBase): # cuda graph capture. kernel_warmup(self) + cuda_graph_memory_bytes = 0 if not self.model_config.enforce_eager: - self.model_runner.capture_model() + cuda_graph_memory_bytes = self.model_runner.capture_model() + + if (self.cache_config.kv_cache_memory_bytes is None + and hasattr(self, "peak_activation_memory")): + # Suggests optimal kv cache memory size if we rely on + # memory_profiling to guess the kv cache memory size which + # provides peak_activation_memory and a few other memory + # consumption. `memory_profiling` does not consider + # CUDAGraph memory size and may not utilize all gpu memory. 
+ # Users may want fine-grained control to specify kv cache + # memory size. + GiB = lambda b: round(b / GiB_bytes, 2) + + # empirically observed that the memory profiling may + # slightly underestimate the memory consumption. + # So leave a small buffer (=150MiB) to avoid OOM. + redundancy_buffer_memory = 150 * (1 << 20) + non_kv_cache_memory = (self.model_runner.model_memory_usage + + self.peak_activation_memory + + self.non_torch_memory + + cuda_graph_memory_bytes) + kv_cache_memory_bytes_to_gpu_limit = ( + self.init_snapshot.free_memory - non_kv_cache_memory - + redundancy_buffer_memory) + kv_cache_memory_bytes_to_requested_limit = ( + int(self.requested_memory) - non_kv_cache_memory - + redundancy_buffer_memory) + + msg = ( + f"Free memory on device " + f"({GiB(self.init_snapshot.free_memory)}/" + f"{GiB(self.init_snapshot.total_memory)} GiB) on startup. " + f"Desired GPU memory utilization is " + f"({self.cache_config.gpu_memory_utilization}, " + f"{GiB(self.requested_memory)} GiB). " + f"Actual usage is {GiB(self.model_runner.model_memory_usage)} " + f"GiB for weight, {GiB(self.peak_activation_memory)} GiB " + f"for peak activation, {GiB(self.non_torch_memory)} GiB " + f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} " + f"GiB for CUDAGraph memory. Replace gpu_memory_utilization " + f"config with `--kv-cache-memory=" + f"{kv_cache_memory_bytes_to_requested_limit}` to fit into " + f"requested memory, or `--kv-cache-memory=" + f"{kv_cache_memory_bytes_to_gpu_limit}` to fully " + f"utilize gpu memory. Current kv cache memory in use is " + f"{int(self.available_kv_cache_memory_bytes)} bytes.") + + logger.info(msg) # Warm up sampler and preallocate memory buffer for logits and other # sampling related tensors of max possible shape to avoid memory @@ -355,17 +426,26 @@ class Worker(WorkerBase): def execute_model( self, scheduler_output: "SchedulerOutput", - ) -> Optional[ModelRunnerOutput]: + ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]: intermediate_tensors = None forward_pass = scheduler_output.total_num_scheduled_tokens > 0 + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + num_input_tokens = self.model_runner._get_num_input_tokens( + num_scheduled_tokens) + all_gather_tensors = { + "residual": + not is_residual_scattered_for_sp(self.vllm_config, + num_input_tokens) + } if forward_pass and not get_pp_group().is_first_rank: intermediate_tensors = IntermediateTensors( get_pp_group().recv_tensor_dict( - all_gather_group=get_tp_group())) + all_gather_group=get_tp_group(), + all_gather_tensors=all_gather_tensors)) output = self.model_runner.execute_model(scheduler_output, intermediate_tensors) - if isinstance(output, ModelRunnerOutput): + if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)): return output assert isinstance(output, IntermediateTensors) @@ -374,7 +454,8 @@ class Worker(WorkerBase): "external_launcher") and not get_pp_group().is_last_rank get_pp_group().send_tensor_dict(output.tensors, - all_gather_group=get_tp_group()) + all_gather_group=get_tp_group(), + all_gather_tensors=all_gather_tensors) kv_connector_output = output.kv_connector_output if not kv_connector_output: @@ -400,8 +481,10 @@ class Worker(WorkerBase): self.profiler.start() else: self.profiler.stop() - print(self.profiler.key_averages().table( - sort_by="self_cuda_time_total")) + # only print profiler results on rank 0 + if self.local_rank == 0: + print(self.profiler.key_averages().table( + sort_by="self_cuda_time_total")) def 
execute_dummy_batch(self) -> None: self.model_runner._dummy_run(1) @@ -498,7 +581,8 @@ class Worker(WorkerBase): parallel_config = self.vllm_config.parallel_config moe_modules = [ module for module in self.model_runner.model.modules() - if module.__class__.__name__ == "FusedMoE" + if (module.__class__.__name__ == "FusedMoE" + or module.__class__.__name__ == "SharedFusedMoE") ] num_local_experts = moe_modules[0].moe_config.num_local_experts assert all(module.moe_config.num_local_experts == num_local_experts @@ -598,6 +682,9 @@ class Worker(WorkerBase): self.model_runner.save_tensorized_model( tensorizer_config=tensorizer_config, ) + def shutdown(self) -> None: + self.model_runner.ensure_kv_transfer_shutdown() + def init_worker_distributed_environment( vllm_config: VllmConfig, @@ -613,7 +700,9 @@ def init_worker_distributed_environment( init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank, backend) - ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size) + ensure_model_parallel_initialized( + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size, + parallel_config.decode_context_parallel_size) ensure_kv_transfer_initialized(vllm_config) diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index a03ebe35d8e0a..3eb9f26e9f5b6 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -9,7 +9,8 @@ from typing import Generator # noqa: UP035 from typing import TYPE_CHECKING, Optional from vllm.config import VllmConfig -from vllm.distributed.kv_transfer import (get_kv_transfer_group, +from vllm.distributed.kv_transfer import (ensure_kv_transfer_shutdown, + get_kv_transfer_group, has_kv_transfer_group) from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase from vllm.forward_context import get_forward_context, set_forward_context @@ -42,6 +43,12 @@ class KVConnectorModelRunnerMixin: # Do this here to save a collective_rpc. kv_connector.start_load_kv(get_forward_context()) + @staticmethod + def ensure_kv_transfer_shutdown() -> None: + # has_kv_transfer_group can be None during interpreter shutdown. + if has_kv_transfer_group and has_kv_transfer_group(): + ensure_kv_transfer_shutdown() + @staticmethod def maybe_wait_for_kv_save() -> None: if has_kv_transfer_group(): @@ -82,7 +89,7 @@ class KVConnectorModelRunnerMixin: scheduler_output) if has_kv_transfer_group() else nullcontext() # This context manager must be used within an active forward context. 
- # It encapsulates the entire KV conector lifecycle within execute_model + # It encapsulates the entire KV connector lifecycle within execute_model @staticmethod @contextmanager def _get_kv_connector_output( diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index 4b5f27d27541b..01d5f0525c4e2 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -11,7 +11,8 @@ import numpy as np import torch import torch.nn as nn -from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig +from vllm.config import ModelConfig, SchedulerConfig +from vllm.config.lora import LoRAConfig from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest @@ -62,8 +63,7 @@ class LoRAModelRunnerMixin: def _set_active_loras(self, prompt_lora_mapping: tuple[int, ...], token_lora_mapping: tuple[int, ...], lora_requests: set[LoRARequest]) -> None: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") + self._ensure_lora_enabled() # Set is_prefill to True, so we always use the SGMV kernels on # non-cuda platforms. @@ -74,6 +74,11 @@ class LoRAModelRunnerMixin: is_prefill=True) self.lora_manager.set_active_adapters(lora_requests, lora_mapping) + def _ensure_lora_enabled(self) -> None: + if not hasattr(self, "lora_manager"): + raise RuntimeError( + "LoRA is not enabled. Use --enable-lora to enable LoRA.") + def set_active_loras(self, input_batch: InputBatch, num_scheduled_tokens: np.ndarray) -> None: @@ -171,21 +176,17 @@ class LoRAModelRunnerMixin: self.lora_manager.remove_all_adapters() def add_lora(self, lora_request: LoRARequest) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") + self._ensure_lora_enabled() return self.lora_manager.add_adapter(lora_request) def remove_lora(self, lora_id: int) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") + self._ensure_lora_enabled() return self.lora_manager.remove_adapter(lora_id) def pin_lora(self, lora_id: int) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") + self._ensure_lora_enabled() return self.lora_manager.pin_adapter(lora_id) def list_loras(self) -> set[int]: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") + self._ensure_lora_enabled() return self.lora_manager.list_adapters() diff --git a/vllm/v1/worker/tpu_input_batch.py b/vllm/v1/worker/tpu_input_batch.py index 81c798685cb3a..dfa54d0ad83b6 100644 --- a/vllm/v1/worker/tpu_input_batch.py +++ b/vllm/v1/worker/tpu_input_batch.py @@ -392,7 +392,7 @@ class InputBatch: # NOTE: the following is unsafe # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\ # self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...] - # instead, we need to temporiarily copy the data for one of the indices + # instead, we need to temporarily copy the data for one of the indices # TODO(lucas): optimize this by only copying valid indices tmp = self.token_ids_cpu[i1, ...].copy() self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...] 
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 985d5ba58c49c..43f12912707f1 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -3,7 +3,7 @@ import bisect import gc import time -from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, cast from unittest.mock import patch import numpy as np @@ -23,6 +23,7 @@ from vllm.config import (ParallelConfig, VllmConfig, get_layers_from_vllm_config, update_config) from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) +from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.lora.layers import BaseLayerWithLoRA @@ -386,9 +387,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.requests[req_id] = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, - mm_kwargs=new_req_data.mm_kwargs, - mm_positions=new_req_data.mm_positions, - mm_hashes=new_req_data.mm_hashes, + mm_features=new_req_data.mm_features, sampling_params=sampling_params, pooling_params=None, generator=None, @@ -821,10 +820,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_state = self.requests[req_id] for mm_input_id in encoder_input_ids: - mm_hash = req_state.mm_hashes[mm_input_id] - mm_kwargs.append(req_state.mm_kwargs[mm_input_id]) - mm_hashes_pos.append( - (mm_hash, req_state.mm_positions[mm_input_id])) + mm_feature = req_state.mm_features[mm_input_id] + mm_hash = mm_feature.identifier + mm_kwargs.append(mm_feature.data) + mm_hashes_pos.append((mm_hash, mm_feature.mm_position)) # Batch mm inputs as much as we can: if a request in the batch has # multiple modalities or a different modality than the previous one, @@ -882,13 +881,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_id] req_state = self.requests[req_id] num_computed_tokens = req_state.num_computed_tokens - mm_positions = req_state.mm_positions - mm_hashes = req_state.mm_hashes # TODO unroll loop and assume/enforce --disable_chunked_mm_input # NOTE (NickLucche) here we diverge from logic in other runners, as # we assume to only have whole mm items to process. Hence we avoid # the intrinsic dynamism that `gather_mm_placeholders` introduces. - for i, pos_info in enumerate(mm_positions): + for mm_feature in req_state.mm_features: + pos_info = mm_feature.mm_position start_pos = pos_info.offset num_encoder_tokens = pos_info.length @@ -903,8 +901,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # The encoder output is already processed and stored # in the decoder's KV cache. continue - - mm_hash = mm_hashes[i] + mm_hash = mm_feature.identifier encoder_output = self.encoder_cache.get(mm_hash, None) assert encoder_output is not None,\ f"Encoder cache miss for {mm_hash}." @@ -1768,28 +1765,22 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.grammar_bitmask_cpu.zero_() self.require_structured_out_cpu.zero_() - # We receive the structured output bitmask from the scheduler, but the - # indices of the requests in the batch may not match the indices of - # the bitmask since the scheduler doesn't know how the tpu runner is - # ordering the requests in the batch. 
We need to match the order of - # bitmask with the order of requests - struct_out_indices: list[int] = [] - mask_indices: list[int] = [] - for req_id in self.input_batch.req_ids: - mask_index = scheduler_output.structured_output_request_ids.get( - req_id) - if mask_index is None: + sorted_struct_requests = sorted( + scheduler_output.structured_output_request_ids.items(), + key=lambda item: item[1]) + cumulative_mask_idx = 0 + for req_id, _ in sorted_struct_requests: + if req_id not in self.input_batch.req_id_to_index: continue batch_index = self.input_batch.req_id_to_index[req_id] - struct_out_indices.append(batch_index) - mask_indices.append(mask_index) - self.grammar_bitmask_cpu[struct_out_indices] = torch.from_numpy( - grammar_bitmask[mask_indices]) - # It's not guaranteed that all requests in this batch require - # structured output, so create a bool tensor to represent - # the requests that need structured output. - struct_out_indices = torch.tensor(struct_out_indices, dtype=torch.long) - self.require_structured_out_cpu[struct_out_indices] = True + self.grammar_bitmask_cpu[batch_index] = torch.from_numpy( + grammar_bitmask[cumulative_mask_idx]) + # It's not guaranteed that all requests in this batch require + # structured output, so create a bool tensor to represent + # the requests that need structured output. + self.require_structured_out_cpu[batch_index] = True + cumulative_mask_idx += 1 + return self.require_structured_out_cpu[:num_reqs].to(logits.device), \ self.grammar_bitmask_cpu[:num_reqs].to(logits.device), \ self.structured_decode_arange.to(logits.device) @@ -1887,75 +1878,6 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int: return paddings[index] -def _make_src_and_dst_indices( - src_block_ids: list[int], - dst_block_ids: list[int], - src_device: Union[torch.device, str], - dst_device: Union[torch.device, str], -) -> tuple[torch.Tensor, torch.Tensor]: - src_indices = torch.tensor(src_block_ids, - device=src_device, - dtype=torch.int64) - dst_indices = torch.tensor(dst_block_ids, - device=dst_device, - dtype=torch.int64) - return src_indices, dst_indices - - -@torch.compile(backend="openxla") -def _insert_blocks_to_tpu( - cpu_cache: torch.Tensor, - tpu_cache: torch.Tensor, - cpu_block_indices: torch.Tensor, - tpu_block_indices: torch.Tensor, -) -> None: - torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True) - tpu_cache[tpu_block_indices] = cpu_cache[cpu_block_indices].to( - tpu_cache.device) - - -@torch.compile(backend="openxla") -def _swap_out_tpu_blocks( - tpu_cache: torch.Tensor, - cpu_cache: torch.Tensor, - tpu_block_indices: torch.Tensor, - cpu_block_indices: torch.Tensor, -) -> None: - """ tpu blocks to cpu blocks""" - torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True) - cpu_cache[cpu_block_indices] = tpu_cache[tpu_block_indices].cpu() - - -def copy_kv_blocks( - src_kv_caches: dict[str, torch.Tensor], - dst_kv_caches: dict[str, torch.Tensor], - src_block_ids: list[int], - dst_block_ids: list[int], - direction: Literal["h2d", "d2h"], -) -> None: - """Copy kv blocks between different buffers.""" - if not src_kv_caches or not dst_kv_caches or \ - not src_block_ids or not dst_block_ids or \ - len(src_block_ids) != len(dst_block_ids): - return - - src_device = next(iter(src_kv_caches.values())).device - dst_device = next(iter(dst_kv_caches.values())).device - - src_indices, dst_indices = _make_src_and_dst_indices( - src_block_ids=src_block_ids, - dst_block_ids=dst_block_ids, - src_device=src_device, - dst_device=dst_device) - - _copy_fn = 
_insert_blocks_to_tpu if direction == "h2d" else \ - _swap_out_tpu_blocks - for layer_name in src_kv_caches: - src_tensor = src_kv_caches[layer_name] - dst_tensor = dst_kv_caches[layer_name] - _copy_fn(src_tensor, dst_tensor, src_indices, dst_indices) - - def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int, page_size: int) -> int: """Calculates the padded number of KV cache update slices to avoid diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 9adf8a14213f3..fc72b954df9cf 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -250,7 +250,7 @@ class TPUWorker: scheduler_output: "SchedulerOutput", ) -> Optional[ModelRunnerOutput]: output = self.model_runner.execute_model(scheduler_output) - # every worker's output is needed when kv_transfer_group is setup + # every worker's output is needed when kv_transfer_group is set up return output if self.is_driver_worker or has_kv_transfer_group( ) else None @@ -330,6 +330,9 @@ class TPUWorker: ensure_kv_transfer_initialized(vllm_config) + def shutdown(self) -> None: + self.model_runner.ensure_kv_transfer_shutdown() + if USE_TPU_COMMONS: from tpu_commons.worker import TPUWorker as TPUCommonsWorker diff --git a/vllm/v1/worker/ubatch_splitting.py b/vllm/v1/worker/ubatch_splitting.py new file mode 100644 index 0000000000000..650f0ec5138db --- /dev/null +++ b/vllm/v1/worker/ubatch_splitting.py @@ -0,0 +1,155 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from vllm.config import VllmConfig +from vllm.forward_context import DPMetadata +from vllm.logger import init_logger +from vllm.utils import round_up +from vllm.v1.worker.ubatch_utils import (UBatchSlice, UBatchSlices, + is_second_ubatch_empty) + +logger = init_logger(__name__) + + +def should_ubatch_with_num_tokens( + should_ubatch: bool, + orig_num_tokens_per_ubatch: int, + padded_num_tokens_per_ubatch: int, + vllm_config: VllmConfig, +) -> tuple[bool, Optional[torch.Tensor]]: + dp_size = vllm_config.parallel_config.data_parallel_size + dp_rank = vllm_config.parallel_config.data_parallel_rank + return DPMetadata.should_ubatch_across_dp(should_ubatch, + orig_num_tokens_per_ubatch, + padded_num_tokens_per_ubatch, + dp_size, dp_rank) + + +def get_dp_padding_ubatch( + num_tokens_unpadded: int, num_tokens_padded: int, + should_attempt_ubatching: bool, + vllm_config: VllmConfig) -> tuple[bool, Optional[torch.Tensor]]: + """ + 1. Decides if each DP rank is going to microbatch. Either all ranks + run with microbatching or none of them do. If this function decides + not to run with microbatching, it will "abort", meaning that no padding + information will be returned to the caller. It will return (False, None). + + 2. Determines the total number of tokens that each rank will run. + All ranks will be padded out so that they run with the same number + of tokens. + + Returns: tuple[ + should_ubatch: Are all DP ranks going to microbatch + num_tokens_after_padding: A tensor containing the total number of + tokens per-microbatch for each DP rank including padding. Will be + None if should_ubatch is False + ] + + """ + assert num_tokens_padded >= num_tokens_unpadded + dp_size = vllm_config.parallel_config.data_parallel_size + if dp_size == 1: + # Early exit. 
+ return False, None + + # If this DP rank doesn't want to attempt microbatching + if not should_attempt_ubatching: + (should_ubatch, num_tokens_across_dp) = should_ubatch_with_num_tokens( + False, 0, 0, vllm_config) + assert should_ubatch is False + assert num_tokens_across_dp is None + return should_ubatch, num_tokens_across_dp + + # Round up to the next multiple of two for even divisibility + num_tokens_padded = round_up(num_tokens_padded, 2) + num_tokens_per_ubatch = num_tokens_padded // 2 + should_ubatch = True + + # Sanity Check that the existing padding isn't giving us an empty second + # ubatch. Abort if so + if is_second_ubatch_empty(num_tokens_unpadded, num_tokens_padded): + logger.debug( + "Empty second µbatch detected: unpadded tokens: %s, padded " + "tokens: %s", num_tokens_unpadded, num_tokens_padded) + should_ubatch = False + + # Note that we compute the number of padded tokens per ubatch + (should_ubatch, num_tokens_across_dp) = should_ubatch_with_num_tokens( + should_ubatch, num_tokens_unpadded // 2, num_tokens_per_ubatch, + vllm_config) + if not should_ubatch: + assert num_tokens_across_dp is None + return should_ubatch, num_tokens_across_dp + + assert num_tokens_across_dp is not None + + max_tokens_across_dp_cpu = int(torch.max(num_tokens_across_dp).item()) + num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] * + dp_size, + device="cpu", + dtype=torch.int32) + return should_ubatch, num_tokens_after_padding + + +def ubatch_split( + max_num_scheduled_tokens: int, + num_tokens_unpadded: int, + num_tokens_padded: int, + vllm_config: VllmConfig, +) -> tuple[Optional[UBatchSlices], Optional[torch.Tensor]]: + """ + Coordinates amongst all DP ranks to determine if and how the full batch + should be split into microbatches. + + Returns: tuple[ + ubatch_slices: if this is set then all DP ranks have agreed to + microbatch + num_tokens_after_padding: A tensor containing the total number of + tokens per-microbatch for each DP rank including padding. Will be + None if ubatch_slices is None + ] + + """ + parallel_config = vllm_config.parallel_config + # Don't bother with the should_ubatch handshaking unless microbatching + # is enabled + if not parallel_config.enable_dbo: + return (None, None) + + # Check preconditions for microbatching + should_attempt_ubatching = \ + parallel_config.enable_dbo and \ + num_tokens_unpadded >= \ + parallel_config.dbo_decode_token_threshold \ + and max_num_scheduled_tokens == 1 + + # Don't microbatch unless every other DP worker is also microbatching + num_tokens_after_padding = None + (should_ubatch, num_tokens_after_padding) = get_dp_padding_ubatch( + num_tokens_unpadded, num_tokens_padded, should_attempt_ubatching, + vllm_config) + if not should_ubatch: + return (None, None) + + # This doesn't actually pad the ubatch slices. 
It just initializes the + # split point to the padded value so that padding can be applied + # to the second ubatch in pad_out_ubatch_slice after attention + # metadata creation + assert num_tokens_after_padding is not None + total_num_tokens_per_ubatch = int(num_tokens_after_padding[0].item()) + padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch) + padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch, + num_tokens_unpadded) + + # Note there's an assumption here that there's 1 token per request + ubatch_slices = [ + UBatchSlice(padded_first_ubatch_slice, padded_first_ubatch_slice), + UBatchSlice(padded_second_ubatch_slice, padded_second_ubatch_slice) + ] + + return (ubatch_slices, num_tokens_after_padding) diff --git a/vllm/v1/worker/ubatch_utils.py b/vllm/v1/worker/ubatch_utils.py new file mode 100644 index 0000000000000..6716d171cc701 --- /dev/null +++ b/vllm/v1/worker/ubatch_utils.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass + +from typing_extensions import TypeAlias + + +@dataclass +class UBatchSlice: + request_slice: slice + token_slice: slice + + +UBatchSlices: TypeAlias = list[UBatchSlice] + + +def is_second_ubatch_empty(orig_num_tokens_per_ubatch: int, + padded_num_tokens_per_ubatch: int) -> bool: + return padded_num_tokens_per_ubatch >= 2 * orig_num_tokens_per_ubatch diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py new file mode 100644 index 0000000000000..9aeaa9909dc81 --- /dev/null +++ b/vllm/v1/worker/ubatching.py @@ -0,0 +1,211 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import threading +from typing import Optional + +import torch + +from vllm import forward_context +from vllm.forward_context import ForwardContext +from vllm.utils import current_stream + +_THREAD_ID_TO_CONTEXT: dict = {} +_CURRENT_CONTEXTS: list[Optional['UBatchContext']] = [None, None] + + +class UBatchContext: + """ + Context manager for micro-batching synchronization using threading events. 
+ """ + + def __init__(self, + id: int, + comm_stream: torch.cuda.Stream, + compute_stream: torch.cuda.Stream, + forward_context: ForwardContext, + ready_barrier: threading.Barrier, + cpu_wait_event: threading.Event, + cpu_signal_event: threading.Event, + gpu_comm_done_event: torch.cuda.Event, + gpu_compute_done_event: torch.cuda.Event, + schedule: str = "default"): + self.id = id + self.comm_stream = comm_stream + self.compute_stream = compute_stream + self.forward_context = forward_context + self.ready_barrier = ready_barrier + self.cpu_wait_event = cpu_wait_event + self.cpu_signal_event = cpu_signal_event + self.current_stream = compute_stream + self.gpu_comm_done_event = gpu_comm_done_event + self.gpu_compute_done_event = gpu_compute_done_event + self.schedule = schedule + self.recv_hook = None + + def __enter__(self): + global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT + _THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id + _CURRENT_CONTEXTS[self.id] = self + self.ready_barrier.wait() + + self.cpu_wait_event.wait() + self.cpu_wait_event.clear() + self._restore_context() + # Assume we start on the compute stream + assert current_stream() == self.compute_stream + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT + _CURRENT_CONTEXTS[self.id] = None + del _THREAD_ID_TO_CONTEXT[threading.get_ident()] + self.maybe_run_recv_hook() + self.cpu_signal_event.set() + self.cpu_wait_event.clear() + self.current_stream = self.compute_stream + torch.cuda.set_stream(self.current_stream) + return False + + def _restore_context(self): + forward_context._forward_context = self.forward_context + torch.cuda.set_stream(self.current_stream) + + def update_stream(self, stream): + self.current_stream = stream + torch.cuda.set_stream(self.current_stream) + + def _signal_comm_done(self): + self.gpu_comm_done_event.record(self.comm_stream) + + def _signal_compute_done(self): + self.gpu_compute_done_event.record(self.compute_stream) + + def _wait_compute_done(self): + self.comm_stream.wait_event(self.gpu_compute_done_event) + + def _wait_comm_done(self): + self.compute_stream.wait_event(self.gpu_comm_done_event) + + def _cpu_yield(self): + # It is critical for correctness that only one thread is running + # at a time. 
These asserts just make sure that this is the only + # thread running before waking the other one up and going to sleep + assert forward_context._forward_context == self.forward_context + assert current_stream() == self.current_stream + assert not self.cpu_wait_event.is_set() + + self.cpu_signal_event.set() + self.cpu_wait_event.wait() + self.cpu_wait_event.clear() + self._restore_context() + + def switch_to_comm_sync(self): + self._signal_compute_done() + self.update_stream(self.comm_stream) + self._wait_comm_done() + + def maybe_run_recv_hook(self): + if self.recv_hook is not None: + self.recv_hook() + self.recv_hook = None + + def yield_(self): + self.current_stream = current_stream() + self._cpu_yield() + if self.current_stream != current_stream(): + self.update_stream(self.current_stream) + + def yield_and_switch_from_compute_to_comm(self): + assert current_stream() == self.compute_stream + self._signal_compute_done() + self._cpu_yield() + assert self.current_stream == self.compute_stream + self.update_stream(self.comm_stream) + self._wait_compute_done() + + def yield_and_switch_from_comm_to_compute(self): + assert current_stream() == self.comm_stream + self._signal_comm_done() + self._cpu_yield() + assert self.current_stream == self.comm_stream + self.update_stream(self.compute_stream) + self._wait_comm_done() + + +def dbo_enabled() -> bool: + return len(_THREAD_ID_TO_CONTEXT) > 0 + + +def dbo_current_ubatch_id() -> int: + if len(_THREAD_ID_TO_CONTEXT) == 0: + return 0 + return _THREAD_ID_TO_CONTEXT[threading.get_ident()] + + +def _register_ubatch_function(func): + + def wrapper(*args, **kwargs): + if len(_THREAD_ID_TO_CONTEXT) > 0: + ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()] + ctx = _CURRENT_CONTEXTS[ctx_idx] + func(ctx, *args, **kwargs) + + return wrapper + + +dbo_yield_and_switch_from_compute_to_comm = _register_ubatch_function( + UBatchContext.yield_and_switch_from_compute_to_comm) +dbo_yield_and_switch_from_comm_to_compute = _register_ubatch_function( + UBatchContext.yield_and_switch_from_comm_to_compute) +dbo_yield = _register_ubatch_function(UBatchContext.yield_) +dbo_maybe_run_recv_hook = _register_ubatch_function( + UBatchContext.maybe_run_recv_hook) +dbo_switch_to_comm_sync = _register_ubatch_function( + UBatchContext.switch_to_comm_sync) + + +def dbo_register_recv_hook(recv_hook): + if len(_THREAD_ID_TO_CONTEXT) > 0: + ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()] + next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % 2] + next_ctx.recv_hook = recv_hook + + +def make_ubatch_contexts( + num_micro_batches: int, + compute_stream: torch.cuda.Stream, + comm_stream: torch.cuda.Stream, + forward_contexts: list[ForwardContext], + ready_barrier: threading.Barrier, + schedule: str = "default", +) -> list[UBatchContext]: + assert num_micro_batches == 2, "only been tested with 2 micro-batches" + """ + Create a context manager for micro-batching synchronization. 
+ """ + cpu_events = [threading.Event() for _ in range(num_micro_batches)] + gpu_comm_done_events = [ + torch.cuda.Event() for _ in range(num_micro_batches) + ] + gpu_compute_done_events = [ + torch.cuda.Event() for _ in range(num_micro_batches) + ] + + assert len(forward_contexts) == 2 + + ctxs = [] + for i in range(num_micro_batches): + ctx = UBatchContext(id=i, + compute_stream=compute_stream, + comm_stream=comm_stream, + forward_context=forward_contexts[i], + ready_barrier=ready_barrier, + cpu_wait_event=cpu_events[i], + cpu_signal_event=cpu_events[(i + 1) % + num_micro_batches], + gpu_comm_done_event=gpu_comm_done_events[i], + gpu_compute_done_event=gpu_compute_done_events[i], + schedule=schedule) + ctxs.append(ctx) + + return ctxs diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 6767804c71b9f..b76ac633892f3 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -7,11 +7,12 @@ from typing import TYPE_CHECKING, Optional import torch from vllm.attention.backends.abstract import AttentionBackend -from vllm.config import ModelConfig, SchedulerConfig +from vllm.config import ModelConfig, SchedulerConfig, VllmConfig from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.utils import extract_layer_index from vllm.multimodal.cache import processor_only_cache_from_config from vllm.multimodal.registry import MultiModalRegistry +from vllm.platforms import current_platform from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget from vllm.v1.kv_cache_interface import KVCacheGroupSpec @@ -129,9 +130,17 @@ class MultiModalBudget: @dataclass class AttentionGroup: backend: type[AttentionBackend] - metadata_builder: AttentionMetadataBuilder + metadata_builders: list[AttentionMetadataBuilder] layer_names: list[str] + def get_metadata_builder(self, + ubatch_id: Optional[int] = None + ) -> AttentionMetadataBuilder: + if ubatch_id is None: + return self.metadata_builders[0] + assert len(self.metadata_builders) > ubatch_id + return self.metadata_builders[ubatch_id] + def sanity_check_mm_encoder_outputs( mm_embeddings: MultiModalEmbeddings, @@ -269,7 +278,17 @@ def bind_kv_cache( # One typical case is encoder-decoder model, e.g., bart. # The cross attention and self attention in the same decoder layer # has different layer_name but the same layer_index. - raise NotImplementedError + + # TODO - analyze where runner_kv_caches is used and the right + # way to ensure it properly reflects multiple attention layers + # in the same decoder block. + if current_platform.is_cuda() or current_platform.is_xpu(): + # We know that the GPU runner is not impacted by this + # case. Some test code depends on runner_kv_caches, but + # not in a way that's impacted by ignoring this. + pass + else: + raise NotImplementedError layer_name = layer_names[0] runner_kv_caches.append(kv_caches[layer_name]) @@ -277,3 +296,28 @@ def bind_kv_cache( for layer_name, kv_cache in kv_caches.items(): # NOTE: Use list because of v0 PP virtual engine. forward_context[layer_name].kv_cache = [kv_cache] + + +def is_residual_scattered_for_sp(vllm_config: VllmConfig, + num_input_tokens: int) -> bool: + """Check if the residual tensor is scattered for sequence parallelism. + + The residual tensor is scattered across tensor parallel ranks when sequence + parallelism and tensor parallelism is enabled, and the number of + input tokens is one of the compilation sizes. 
+ """ + if not vllm_config.compilation_config.pass_config.\ + enable_sequence_parallelism: + return False + + tp = vllm_config.parallel_config.tensor_parallel_size + + if tp == 1: + return False + + # When sequence parallelism is enabled, we always pad num_input_tokens + # to be a multiple of tensor_parallel_size (tp) earlier. + assert num_input_tokens % tp == 0 + + # Currently, SP is only enabled for static size fx graphs. + return (num_input_tokens in vllm_config.compilation_config.compile_sizes) diff --git a/vllm/v1/worker/xpu_model_runner.py b/vllm/v1/worker/xpu_model_runner.py index fb892211f19db..7becdd392498f 100644 --- a/vllm/v1/worker/xpu_model_runner.py +++ b/vllm/v1/worker/xpu_model_runner.py @@ -45,8 +45,12 @@ def _torch_cuda_wrapper(): self.synchronize = lambda: None try: - # replace cuda Event with xpu Event, this should work by default + # replace cuda APIs with xpu APIs, this should work by default torch.cuda.Event = torch.xpu.Event + torch.cuda.Stream = torch.xpu.Stream + torch.cuda.default_stream = torch.xpu.current_stream + torch.cuda.current_stream = torch.xpu.current_stream + torch.cuda.stream = torch.xpu.stream yield finally: # if anything goes wrong, just patch it with a placeholder diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py index 17288cda8eccf..7355206f30f57 100644 --- a/vllm/v1/worker/xpu_worker.py +++ b/vllm/v1/worker/xpu_worker.py @@ -84,7 +84,7 @@ class XPUWorker(Worker): """Profiles the peak memory usage of the model to determine how many KV blocks may be allocated without OOMs. The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the maximum possible number of GPU and CPU blocks + Then, it calculates the maximum possible number of GPU and CPU blocks that can be allocated with the remaining free memory. .. 
tip:: You may limit the usage of GPU memory diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py deleted file mode 100644 index 12fd25f4de2ad..0000000000000 --- a/vllm/worker/enc_dec_model_runner.py +++ /dev/null @@ -1,553 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import dataclasses -import itertools -from typing import Any, Dict, List, Optional, Tuple, Type, cast - -import torch -import torch.distributed - -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata) -from vllm.attention.backends.utils import PAD_SLOT_ID -from vllm.attention.selector import (get_env_variable_attn_backend, - get_global_forced_attn_backend) -from vllm.config import VllmConfig -from vllm.forward_context import set_forward_context -from vllm.inputs import INPUT_REGISTRY, InputRegistry -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor import SamplingMetadata -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, - MultiModalRegistry) -from vllm.platforms import _Backend -from vllm.sampling_params import SamplingParams -from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad -from vllm.worker.model_runner import (GPUModelRunnerBase, - ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata) -from vllm.worker.model_runner_base import ( - _add_attn_metadata_broadcastable_dict, - _add_sampling_metadata_broadcastable_dict) -from vllm.worker.utils import assert_enc_dec_mr_supported_scenario - -logger = init_logger(__name__) -LORA_WARMUP_RANK = 8 - - -@dataclasses.dataclass(frozen=True) -class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata): - """ - Used by the EncoderDecoderModelRunner. 
- """ - encoder_input_tokens: Optional[torch.Tensor] = None - encoder_input_positions: Optional[torch.Tensor] = None - - def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: - tensor_dict = { - "input_tokens": self.input_tokens, - "inputs_embeds": self.inputs_embeds, - "input_positions": self.input_positions, - "encoder_input_tokens": self.encoder_input_tokens, - "encoder_input_positions": self.encoder_input_positions, - "virtual_engine": self.virtual_engine, - "request_ids_to_seq_ids": self.request_ids_to_seq_ids, - "finished_requests_ids": self.finished_requests_ids, - "multi_modal_kwargs": self.multi_modal_kwargs, - } - _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - _add_sampling_metadata_broadcastable_dict(tensor_dict, - self.sampling_metadata) - return tensor_dict - - @classmethod - def from_broadcasted_tensor_dict( - cls, - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> "EncoderDecoderModelInput": - return cast( - EncoderDecoderModelInput, - super().from_broadcasted_tensor_dict(tensor_dict, attn_backend)) - - -class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): - _model_input_cls: Type[EncoderDecoderModelInput] = ( - EncoderDecoderModelInput) - _builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder) - - def __init__( - self, - vllm_config: VllmConfig, - kv_cache_dtype: Optional[str] = "auto", - is_driver_worker: bool = False, - input_registry: InputRegistry = INPUT_REGISTRY, - mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, - ): - ''' - EncoderDecoderModelRunner constructor. - - `lora_config` is unused (since these features are not yet supported - for encoder/decoder models) but these arguments are present here for - compatibility with the base-class constructor. - ''' - self._maybe_force_supported_attention_backend() - - super().__init__( - vllm_config=vllm_config, - kv_cache_dtype=kv_cache_dtype, - is_driver_worker=is_driver_worker, - input_registry=input_registry, - mm_registry=mm_registry, - ) - - # Crash for unsupported encoder/scenarios - assert_enc_dec_mr_supported_scenario(self) - - def _maybe_force_supported_attention_backend(self): - ''' - Force vLLM to use the XFormers attention backend, - which is currently the only supported option. - ''' - - def raise_backend_err(): - # The user has specified an attention backend override - # which is invalid for encoder/decoder models - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_BACKEND) - - maybe_env_var_forced_backend = get_env_variable_attn_backend() - maybe_global_forced_backend = get_global_forced_attn_backend() - is_forced_by_global = maybe_global_forced_backend is not None - is_forced_by_env_var = maybe_env_var_forced_backend is not None - if is_forced_by_global: # noqa: SIM102 - # Backend override enforced by global variable takes - # precedence over vLLM backend environment variable. 
- if maybe_global_forced_backend not in\ - [_Backend.XFORMERS, _Backend.FLASH_ATTN]: - raise_backend_err() - elif is_forced_by_env_var: # noqa: SIM102 - # Backend override enforced by vLLM backend - # environment variable - if maybe_env_var_forced_backend not in\ - [_Backend.XFORMERS, _Backend.FLASH_ATTN]: - raise_backend_err() - - def _list_to_int32_tensor( - self, - _list: List[int], - ) -> torch.Tensor: - return torch.tensor(_list, dtype=torch.int32, device=self.device) - - def _list_to_long_tensor( - self, - _list: List[int], - ) -> torch.Tensor: - return torch.tensor(_list, dtype=torch.long, device=self.device) - - def _empty_int32_tensor(self) -> torch.Tensor: - return self._list_to_int32_tensor([]) - - def _empty_long_tensor(self) -> torch.Tensor: - return self._list_to_long_tensor([]) - - @torch.inference_mode() - def execute_model( - self, - model_input: EncoderDecoderModelInput, - kv_caches: List[torch.Tensor], - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - ) -> Optional[List[SamplerOutput]]: - if num_steps > 1: - raise ValueError("num_steps > 1 is not supported in " - "EncoderDecoderModelRunner") - if self.lora_config: - assert model_input.lora_requests is not None - assert model_input.lora_mapping is not None - self.set_active_loras(model_input.lora_requests, - model_input.lora_mapping) - if (model_input.attn_metadata is not None - and model_input.attn_metadata.prefill_metadata is None - and model_input.attn_metadata.decode_metadata.use_cuda_graph): - if model_input.inputs_embeds is None: - assert model_input.input_tokens is not None - graph_batch_size = model_input.input_tokens.shape[0] - model_executable = ( - self.graph_runners[model_input.virtual_engine][( - graph_batch_size, False)]) - else: - graph_batch_size = model_input.inputs_embeds.shape[0] - model_executable = ( - self.graph_runners[model_input.virtual_engine][( - graph_batch_size, True)]) - else: - model_executable = self.model - - seqlen_agnostic_kwargs = { - "finished_requests_ids": model_input.finished_requests_ids, - "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, - } if self.has_inner_state else {} - - multi_modal_kwargs = model_input.multi_modal_kwargs or {} - with set_forward_context(model_input.attn_metadata, self.vllm_config, - model_input.virtual_engine): - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - inputs_embeds=model_input.inputs_embeds, - positions=model_input.input_positions, - encoder_input_ids=model_input.encoder_input_tokens, - encoder_positions=model_input.encoder_input_positions, - intermediate_tensors=intermediate_tensors, - **MultiModalKwargs.as_kwargs( - multi_modal_kwargs, - device=self.device, - ), - **seqlen_agnostic_kwargs, - ) - - logits = self.model.compute_logits(hidden_or_intermediate_states, - model_input.sampling_metadata) - - if not self.is_driver_worker: - return [] - - if model_input.async_callback is not None: - model_input.async_callback() - - # Sample the next token. 
- output: SamplerOutput = self.sampler( - logits=logits, - sampling_metadata=model_input.sampling_metadata, - ) - - return [output] - - def make_model_input_from_broadcasted_tensor_dict( - self, tensor_dict: Dict[str, Any]) -> EncoderDecoderModelInput: - return EncoderDecoderModelInput.from_broadcasted_tensor_dict( - tensor_dict, - attn_backend=self.attn_backend, - ) - - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> EncoderDecoderModelInput: - """Prepare the model input based on a given sequence group, including - metadata for the sampling step. - - Since chunked prefill is not supported for encoder/decoder models, - `input_tokens` is assumed to be either entirely prefill tokens or - entirely decode tokens. - - """ - model_input = self._prepare_model_input_tensors( - seq_group_metadata_list, finished_requests_ids) - ( - attn_metadata, - encoder_input_tokens_tensor, - encoder_input_positions_tensor, - ) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list, - model_input)) - # Inject attn_metadata encoder/cross-attention fields & - # encoder input tokens/positions into model_input. - # Frozen dataclass fields cannot be modified, so use - # dataclasses.replace to construct a new model input - # instance. - model_input = dataclasses.replace( - model_input, - attn_metadata=attn_metadata, - encoder_input_tokens=encoder_input_tokens_tensor, - encoder_input_positions=encoder_input_positions_tensor, - ) - - generators = self.get_generators(finished_requests_ids) - sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, - model_input.seq_lens, - model_input.query_lens, - self.device, - self.pin_memory, - generators=generators) - is_prompt = (seq_group_metadata_list[0].is_prompt - if seq_group_metadata_list else None) - return dataclasses.replace(model_input, - sampling_metadata=sampling_metadata, - is_prompt=is_prompt, - virtual_engine=virtual_engine) - - @torch.inference_mode() - def profile_run(self) -> None: - # Enable top-k sampling to reflect the accurate memory usage. - sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) - max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens - max_num_seqs = self.scheduler_config.max_num_seqs - - # This represents the maximum number of different requests - # that will have unique loras, and therefore the max amount of - # memory consumption. Create dummy lora request copies from the - # lora request passed in, which contains a lora from the lora - # warmup path. - dummy_lora_requests: List[LoRARequest] = [] - dummy_lora_requests_per_seq: List[LoRARequest] = [] - if self.lora_config: - dummy_lora_requests = self._add_dummy_loras( - self.lora_config.max_loras) - assert len(dummy_lora_requests) == self.lora_config.max_loras - dummy_lora_requests_per_seq = [ - dummy_lora_requests[idx % len(dummy_lora_requests)] - for idx in range(max_num_seqs) - ] - - # Profile memory usage with max_num_sequences sequences and the total - # number of tokens equal to max_num_batched_tokens. 
- seqs: List[SequenceGroupMetadata] = [] - - max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( - self.model_config) - if max_mm_tokens > 0: - logger.info("Starting profile run for multi-modal models.") - - batch_size = 0 - for group_id in range(max_num_seqs): - seq_len = (max_num_batched_tokens // max_num_seqs + - (group_id < max_num_batched_tokens % max_num_seqs)) - batch_size += seq_len - - decoder_dummy_data = self.input_registry \ - .dummy_data_for_profiling(self.model_config, - seq_len, - self.mm_registry, - is_encoder_data=False) - encoder_dummy_data = self.input_registry \ - .dummy_data_for_profiling(self.model_config, - seq_len, - self.mm_registry, - is_encoder_data=True) - - # Having more tokens is over-conservative but otherwise fine - assert len( - decoder_dummy_data.seq_data.prompt_token_ids - ) >= seq_len, ( - f"Expected at least {seq_len} dummy tokens for profiling, " - f"but got: {len(decoder_dummy_data.seq_data.prompt_token_ids)}" - ) - - assert decoder_dummy_data.multi_modal_data is None or \ - encoder_dummy_data.multi_modal_data is None, ( - "Multi-modal data can't be provided in both encoder and decoder" - ) - - seq = SequenceGroupMetadata( - request_id=str(group_id), - is_prompt=True, - seq_data={group_id: decoder_dummy_data.seq_data}, - sampling_params=sampling_params, - block_tables=None, - encoder_seq_data=encoder_dummy_data.seq_data, - cross_block_table=None, - lora_request=dummy_lora_requests_per_seq[group_id] - if dummy_lora_requests_per_seq else None, - multi_modal_data=decoder_dummy_data.multi_modal_data - or encoder_dummy_data.multi_modal_data, - multi_modal_placeholders=decoder_dummy_data. - multi_modal_placeholders - or encoder_dummy_data.multi_modal_placeholders) - seqs.append(seq) - - finished_requests_ids = [seq.request_id for seq in seqs] - model_input = self.prepare_model_input( - seqs, finished_requests_ids=finished_requests_ids) - intermediate_tensors = None - self.execute_model(model_input, None, intermediate_tensors) - torch.cuda.synchronize() - return - - def _prepare_encoder_model_input_tensors( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - model_input: EncoderDecoderModelInput, - ) -> Tuple[AttentionMetadata, Optional[torch.Tensor], - Optional[torch.Tensor]]: - """Helper method to prepare the encoder- and cross-attn-related - model inputs based on a given sequence group. These additional inputs - are used to augment an already-computed `EncoderDecoderModelInput` - data structure which already has decoder-related model inputs - populated. - - Sets the following attn_metadata fields: - * `num_encoder_tokens` - * `encoder_seq_lens` - * `encoder_seq_lens_tensor` - * `max_encoder_seq_len` - * `cross_slot_mapping` - * `cross_block_tables` - - Constructs a new model inputs data structure, based on - (1) the existing fields in the `model_inputs` argument, - and (2) the following additional fields which are - computed (or in the case of `attn_metadata`, updated) - by this function: - * attn_metadata - * encoder_input_tokens - * encoder_input_positions - - Arguments: - - * seq_group_metadata_list: list of sequence groups for which to - compute inputs - * model_inputs: model inputs data structure with decoder-oriented - fields already computed. 
- - Return: - - * Updated model inputs data structure - """ - - if len(seq_group_metadata_list) == 0: - return (model_input.attn_metadata, None, None) - - # Since we are not supporting chunked prefill either the entire - # batch is prefill or it is decode - is_prompt = seq_group_metadata_list[0].is_prompt - - # Build encoder inputs - encoder_seq_lens: List[int] = [] - if is_prompt: - # Prefill phase. - cross_block_tables = self._empty_int32_tensor().view( - len(seq_group_metadata_list), -1) - - # Extract input tokens/positions, cross-attention slot-mapping, - # & seq len from each sequence group metadata - ( - encoder_input_tokens, - encoder_input_positions, - cross_slot_mapping, - ) = ( - [], - [], - [], - ) - for seq_group_metadata in seq_group_metadata_list: - # Build seq lens - seq_len = seq_group_metadata.encoder_seq_data.get_len() - token_ids = seq_group_metadata.encoder_seq_data.get_token_ids() - encoder_seq_lens.append(seq_len) - - # Build slot mapping - is_profile_run = (seq_group_metadata.block_tables is None) - if is_profile_run: - # During memory profiling, the block tables are not - # initialized yet. In this case, we just use a dummy - # slot mapping. - # In embeddings, the block tables are {seq_id: None}. - cross_slot_mapping.extend([PAD_SLOT_ID] * seq_len) - else: - for i in range(0, seq_len): - block_number = seq_group_metadata.cross_block_table[ - i // self.block_size] - block_offset = i % self.block_size - slot = block_number * self.block_size + block_offset - cross_slot_mapping.append(slot) - - # Build encoder input tokens - encoder_input_tokens.extend(token_ids) - encoder_input_positions.extend(list(range(0, seq_len))) - - # Convert tokens/positions & cross-attention - # slot-mapping to encoder input tensors - encoder_input_tokens_tensor = self._list_to_long_tensor( - encoder_input_tokens) - encoder_input_positions_tensor = self._list_to_long_tensor( - encoder_input_positions) - cross_slot_mapping_tensor = self._list_to_long_tensor( - cross_slot_mapping) - - else: - # Decode phase. - encoder_input_tokens_tensor = self._empty_long_tensor() - encoder_input_positions_tensor = self._empty_long_tensor() - cross_slot_mapping_tensor = self._empty_long_tensor() - # Extract cross-attention block tables & - # seq len from each sequence group metadata. - # Cross-attention block tables are empty - # during vLLM memory profiling. - cross_block_tables = [] - for seq_group_metadata in seq_group_metadata_list: - for _ in range(len(seq_group_metadata.seq_data)): - encoder_seq_lens.append( - seq_group_metadata.encoder_seq_data.get_len()) - cross_block_table = seq_group_metadata.cross_block_table - cross_block_tables.append([] if ( - cross_block_table is None) else cross_block_table) - - if (model_input.attn_metadata is not None - and model_input.attn_metadata.use_cuda_graph): - # We will be using CUDA graph replay for this decode. - max_len_of_block_table = self.get_max_block_per_batch() - batch_size = len(encoder_seq_lens) - graph_batch_size = self.vllm_config.pad_for_cudagraph( - batch_size) - assert graph_batch_size >= batch_size - cuda_graph_pad_size = graph_batch_size - batch_size - # extend the cross_block_tables and encoder_seq_lens to match - # the graph_batch_size. 
- cross_block_tables.extend([[] - for _ in range(cuda_graph_pad_size) - ]) - encoder_seq_lens.extend( - itertools.repeat(1, cuda_graph_pad_size)) - - else: - max_len_of_block_table = max( - len(block_table) for block_table in cross_block_tables) - - cross_block_tables = make_tensor_with_pad( - cross_block_tables, - max_len=max_len_of_block_table, - pad=0, - dtype=torch.int32, - device=self.device, - ) - - # Compute encoder sequence lengths & encoder - # sequence starting offset tensors - max_encoder_seq_len = max(encoder_seq_lens, default=0) - encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens) - encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + - 1, - dtype=torch.int32, - device=self.device) - torch.cumsum(encoder_seq_lens_tensor, - dim=0, - dtype=encoder_seq_start_loc.dtype, - out=encoder_seq_start_loc[1:]) - - # Update attention metadata with encoder-oriented attributes - attn_metadata = model_input.attn_metadata - assert attn_metadata is not None - ( - attn_metadata.num_encoder_tokens, - attn_metadata.encoder_seq_lens, - attn_metadata.encoder_seq_lens_tensor, - attn_metadata.max_encoder_seq_len, - attn_metadata.encoder_seq_start_loc, - attn_metadata.cross_slot_mapping, - attn_metadata.cross_block_tables, - ) = ( - sum(encoder_seq_lens), - encoder_seq_lens, - encoder_seq_lens_tensor, - max_encoder_seq_len, - encoder_seq_start_loc, - cross_slot_mapping_tensor, - cross_block_tables, - ) - - return (attn_metadata, encoder_input_tokens_tensor, - encoder_input_positions_tensor) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f05401fd01327..88f83c9dd7e6c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1337,8 +1337,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): return self.lora_manager.list_adapters() @torch.inference_mode() - def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: - """Cuda graph capture a model. + def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> int: + """Cuda graph capture a model and return cudagraph memory + consumption in bytes. Note that CUDA graph's performance gain is negligible if number of batched tokens are larger than 200. And since CUDA graph @@ -1505,6 +1506,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): # This usually takes < 10 seconds. 
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, cuda_graph_size / GiB_bytes) + return cuda_graph_size def _update_inputs_to_capture_for_enc_dec_model(self, capture_inputs: Dict[str, diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py deleted file mode 100644 index 8317b9abff0cd..0000000000000 --- a/vllm/worker/neuron_model_runner.py +++ /dev/null @@ -1,455 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union - -import torch -from torch import nn - -from vllm.config import DeviceConfig, VllmConfig -from vllm.logger import init_logger -from vllm.lora.layers import LoRAMapping -from vllm.lora.request import LoRARequest -from vllm.model_executor import SamplingMetadata -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.model_executor.model_loader.neuron import get_neuron_model -from vllm.multimodal import BatchedTensorInputs, MultiModalKwargs -from vllm.platforms import current_platform -from vllm.sampling_params import SamplingParams -from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.utils import is_pin_memory_available, make_tensor_with_pad -from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase - -if TYPE_CHECKING: - from vllm.attention.backends.abstract import AttentionBackend - -logger = init_logger(__name__) - - -@dataclass(frozen=True) -class ModelInputForNeuron(ModelRunnerInputBase): - """ - Used by the NeuronModelRunner. - """ - input_tokens: Optional[torch.Tensor] = None - input_positions: Optional[torch.Tensor] = None - input_block_ids: Optional[torch.Tensor] = None - sampling_metadata: SamplingMetadata = None - multi_modal_kwargs: BatchedTensorInputs = None - adapter_ids: Optional[str] = None - - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - return { - "input_tokens": self.input_tokens, - "input_positions": self.input_positions, - "input_block_ids": self.input_block_ids, - "sampling_metadata": self.sampling_metadata, - "multi_modal_kwargs": self.multi_modal_kwargs, - } - - @classmethod - def from_broadcasted_tensor_dict( - cls, - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> "ModelInputForNeuron": - return ModelInputForNeuron( - input_tokens=tensor_dict["input_tokens"], - input_positions=tensor_dict["input_positions"], - input_block_ids=tensor_dict["input_block_ids"], - sampling_metadata=tensor_dict["sampling_metadata"], - multi_modal_kwargs=tensor_dict["multi_modal_kwargs"], - ) - - -class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): - """A model runner for AWS Neuron hardware""" - - # NEURON has an upper limit on the top_k - _MAX_NEURON_SAMPLING_TOP_K = 256 - - def __init__( - self, - vllm_config: VllmConfig, - ): - ModelRunnerBase.__init__(self, vllm_config) - - if (self.model_config is not None - and self.model_config.get_sliding_window()): - logger.warning("Sliding window is not supported on Neuron. " - "The model will run without sliding window.") - self.device_config = (self.device_config if self.device_config - is not None else DeviceConfig()) - self.lora_config = vllm_config.lora_config - self.device = self.device_config.device - self.pin_memory = is_pin_memory_available() - - # Lazy initialization. 
- self.model: nn.Module # initialize after load_model. - - # Once NEURON_ON_DEVICE_SAMPLING_DISABLED is set to a non-zero value, - # turn off on-device sampling. - self._on_device_sampling_disabled = int( - os.getenv("NEURON_ON_DEVICE_SAMPLING_DISABLED", "0")) - - # NEURON needs to update sampling parameters when request IDs change - # across batches. This variable stores the previous batch's request IDs - # to determine if an update is needed. - self._previous_batch_request_ids: List[str] = [] - - if not self._on_device_sampling_disabled: - self._init_neuron_sampling() - - def _init_neuron_sampling(self) -> None: - if current_platform.use_transformers_neuronx(): - from transformers_neuronx.config import GenerationConfig - else: - from transformers import GenerationConfig - logger.warning( - "On-device sampling is turned on in Neuron by default, only " - "top_k, top_p, and temperature are current supported sampling " - "parameters. To turn off the on-device sampling, please set " - "the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1.") - self.model_config.neuron_sampling_params = GenerationConfig( - max_length=self.scheduler_config.max_model_len, - do_sample=True, - per_batch_line=True, - top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \ - * self.scheduler_config.max_num_seqs, - top_p=[1.0] * self.scheduler_config.max_num_seqs, - temperature=[1.0] * self.scheduler_config.max_num_seqs, - dynamic=True, - global_top_k=self._MAX_NEURON_SAMPLING_TOP_K) - - def load_model(self) -> None: - self.model = get_neuron_model(self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) - - def get_model(self) -> nn.Module: - return self.model - - def _prepare_prompt( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int], - BatchedTensorInputs]: - assert len(seq_group_metadata_list) > 0 - input_tokens: List[List[int]] = [] - input_positions: List[List[int]] = [] - input_block_ids: List[int] = [] - - seq_lens: List[int] = [] - multi_modal_kwargs_list: List[MultiModalKwargs] = [] - for seq_group_metadata in seq_group_metadata_list: - assert seq_group_metadata.is_prompt - seq_ids = list(seq_group_metadata.seq_data.keys()) - assert len(seq_ids) == 1 - seq_id = seq_ids[0] - - seq_data = seq_group_metadata.seq_data[seq_id] - prompt_tokens = seq_data.get_token_ids() - seq_len = len(prompt_tokens) - seq_lens.append(seq_len) - - input_tokens.append(prompt_tokens) - input_positions.append(list(range(seq_len))) - - assert seq_group_metadata.block_tables is not None - block_table = seq_group_metadata.block_tables[seq_id] - assert len(block_table) == 1 - input_block_ids.append(block_table[0]) - - mm_kwargs = seq_group_metadata.multi_modal_data - if mm_kwargs: - mm_kwargs = self.process_multi_modal_data_neuron(mm_kwargs) - multi_modal_kwargs_list.append(mm_kwargs) - - max_seq_len = max(seq_lens) - assert max_seq_len > 0 - input_tokens = make_tensor_with_pad(input_tokens, - pad=0, - max_len=max_seq_len, - dtype=torch.long, - device=self.device) - input_positions = make_tensor_with_pad(input_positions, - pad=0, - max_len=max_seq_len, - dtype=torch.long, - device=self.device) - input_block_ids = torch.tensor(input_block_ids, - dtype=torch.long, - device=self.device) - - multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) - - return (input_tokens, input_positions, input_block_ids, seq_lens, - multi_modal_kwargs) - - def _prepare_decode( - self, - seq_group_metadata_list: 
List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - assert len(seq_group_metadata_list) > 0 - input_tokens: List[List[int]] = [] - input_positions: List[List[int]] = [] - input_block_ids: List[int] = [] - context_lens: List[int] = [] - - for seq_group_metadata in seq_group_metadata_list: - assert not seq_group_metadata.is_prompt - - seq_ids = list(seq_group_metadata.seq_data.keys()) - - for seq_id in seq_ids: - seq_data = seq_group_metadata.seq_data[seq_id] - generation_token = seq_data.get_last_token_id() - input_tokens.append([generation_token]) - - seq_len = seq_data.get_len() - position = seq_len - 1 - input_positions.append([position]) - context_lens.append(seq_len) - - assert seq_group_metadata.block_tables is not None - block_table = seq_group_metadata.block_tables[seq_id] - assert len(block_table) == 1 - input_block_ids.append(block_table[0]) - - input_tokens = make_tensor_with_pad(input_tokens, - pad=0, - max_len=1, - dtype=torch.long, - device=self.device) - input_positions = make_tensor_with_pad(input_positions, - pad=0, - max_len=1, - dtype=torch.long, - device=self.device) - context_lens = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) - input_block_ids = torch.tensor(input_block_ids, - dtype=torch.long, - device=self.device) - - return input_tokens, input_positions, input_block_ids - - def make_model_input_from_broadcasted_tensor_dict( - self, tensor_dict: Dict[str, Any]) -> ModelInputForNeuron: - return ModelInputForNeuron.from_broadcasted_tensor_dict(tensor_dict) - - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> ModelInputForNeuron: - multi_modal_kwargs = None - # NOTE: We assume that all sequences in the group are all prompts or - # all decodes. - is_prompt = seq_group_metadata_list[0].is_prompt - # Prepare input tensors. - if is_prompt: - (input_tokens, input_positions, input_block_ids, seq_lens, - multi_modal_kwargs - ) = self._prepare_prompt(seq_group_metadata_list) - else: - (input_tokens, input_positions, - input_block_ids) = self._prepare_decode(seq_group_metadata_list) - seq_lens = None - - if not self._on_device_sampling_disabled: - for seq_group_metadata in seq_group_metadata_list: - sampling_params = seq_group_metadata.sampling_params - top_k, top_p, temperature = ( - self._convert_to_neuron_sampling_params(sampling_params)) - sampling_params.top_k = top_k - sampling_params.top_p = top_p - sampling_params.temperature = temperature - - # we need multi_modal_data for later tokens as well - multi_modal_kwargs_list: List[MultiModalKwargs] = [] - for seq_group_metadata in seq_group_metadata_list: - mm_data = seq_group_metadata.multi_modal_data - if mm_data: - multi_modal_kwargs_list.append(mm_data) - multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) - - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - # query_lens is not needed if chunked prefill is not - # supported. Since neuron worker doesn't support chunked prefill - # just use seq_lens instead. - seq_lens, - self.device, - self.pin_memory, - generators=self.get_generators(finished_requests_ids)) - - if current_platform.use_transformers_neuronx( - ) and not self._on_device_sampling_disabled: - # Once the request IDs are changed in current iteration, we will - # update the on-device sampling parameters. 
- current_batch_request_ids = [ - seq_group_meta_data.request_id - for seq_group_meta_data in seq_group_metadata_list - ] - if current_batch_request_ids != self._previous_batch_request_ids: - self._update_neuron_sampling_params(seq_group_metadata_list) - self._previous_batch_request_ids = current_batch_request_ids - - return ModelInputForNeuron(input_tokens=input_tokens, - input_positions=input_positions, - input_block_ids=input_block_ids, - sampling_metadata=sampling_metadata, - multi_modal_kwargs=multi_modal_kwargs) - - def _update_neuron_sampling_params( - self, seq_group_metadata_list: List[SequenceGroupMetadata]): - # Update Neuron sampling parameters (GenerationConfig in Neuron) - current_sampling_params = self.model_config.neuron_sampling_params - assert current_sampling_params is not None, ( - f"Failed to update sampling_params, " - f"current sampling params is {current_sampling_params}") - - is_update_needed = False - - top_k = current_sampling_params.top_k - top_p = current_sampling_params.top_p - temperature = current_sampling_params.temperature - - # The index of a sequence's sampling parameters in neuron is equal to - # its index in `input_block_ids`. - for seq_group_metadata in seq_group_metadata_list: - seq_ids = list(seq_group_metadata.seq_data.keys()) - sampling_params = seq_group_metadata.sampling_params - - seq_group_top_k = sampling_params.top_k - seq_group_top_p = sampling_params.top_p - seq_group_temperature = sampling_params.temperature - - for seq_id in seq_ids: - index = seq_group_metadata.block_tables[seq_id][0] - if (top_k[index] != seq_group_top_k - or top_p[index] != seq_group_top_p - or temperature[index] != seq_group_temperature): - is_update_needed = True - - top_k[index] = seq_group_top_k - top_p[index] = seq_group_top_p - temperature[index] = seq_group_temperature - - # update_generation_config is only available in transformers-neuronx - if is_update_needed and current_platform.use_transformers_neuronx(): - self.model.model.update_generation_config(current_sampling_params) - - def _convert_to_neuron_sampling_params( - self, sampling_params: SamplingParams) -> Tuple[int, float, float]: - # Returns the top_k, top_p and temperature parameters for neuron. 
- top_k = sampling_params.top_k - top_p = sampling_params.top_p - temperature = sampling_params.temperature - - if temperature == 0.0: - # Enable greedy sampling on zero temperature - return (1, 1.0, 1.0) - if top_k < 1 or top_k > self._MAX_NEURON_SAMPLING_TOP_K: - top_k = self._MAX_NEURON_SAMPLING_TOP_K - - return (top_k, top_p, temperature) - - @torch.inference_mode() - def execute_model( - self, - model_input: ModelInputForNeuron, - kv_caches: Optional[List[torch.Tensor]] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - ) -> Optional[List[SamplerOutput]]: - if num_steps > 1: - raise ValueError( - "NeuronModelRunner does not support multi-step execution.") - - # extract top_k, top_p and temperature from model_input for neuron - # forward call - sampling_params = (torch.tensor([[ - seq_group.sampling_params.top_k, seq_group.sampling_params.top_p, - seq_group.sampling_params.temperature - ] for seq_group in model_input.sampling_metadata.seq_groups])) - - if current_platform.use_neuronx_distributed(): - hidden_states = self.model( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - input_block_ids=model_input.input_block_ids, - sampling_params=sampling_params, - adapter_ids=model_input.adapter_ids, - **MultiModalKwargs.as_kwargs( - model_input.multi_modal_kwargs or {}, - device=self.device, - ), - ) - elif current_platform.use_transformers_neuronx(): - # [TODO] validate on-device sampling - # The model signature may need change for on-device sampling - hidden_states = self.model( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - input_block_ids=model_input.input_block_ids, - **MultiModalKwargs.as_kwargs( - model_input.multi_modal_kwargs or {}, - device=self.device, - ), - ) - - # Compute the logits only if the on-device sampling is turned off as - # on-device sampling outputs the token ids. - if self._on_device_sampling_disabled: - logits = self.model.compute_logits(hidden_states, - model_input.sampling_metadata) - else: - logits = hidden_states - - # Sample the next token. 
- output = self.model.sample( - logits=logits, - sampling_metadata=model_input.sampling_metadata, - ) - return [output] - - @property - def vocab_size(self) -> int: - return self.model_config.get_vocab_size() - - def process_multi_modal_data_neuron(self, mm_data): - # this is a no-op for NeuronModelRunner - return mm_data - - def remove_all_loras(self): - raise NotImplementedError( - "LoRAs are not supported for Transformers NeuronX framework") - - def set_active_loras(self, lora_requests: Set[LoRARequest], - lora_mapping: LoRAMapping) -> None: - raise NotImplementedError( - "LoRAs are not supported for Transformers NeuronX framework") - - def add_lora(self, lora_request: LoRARequest): - raise NotImplementedError( - "LoRAs are not supported for Transformers NeuronX framework") - - def remove_lora(self, lora_id: int) -> bool: - raise NotImplementedError( - "LoRAs are not supported for Transformers NeuronX framework") - - def pin_lora(self, lora_id: int) -> bool: - raise NotImplementedError( - "LoRAs are not supported for Transformers NeuronX framework") - - def list_loras(self) -> Set[int]: - raise NotImplementedError( - "LoRAs are not supported for Transformers NeuronX framework") diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py deleted file mode 100644 index 3e4512a639083..0000000000000 --- a/vllm/worker/neuron_worker.py +++ /dev/null @@ -1,189 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""A Neuron worker class.""" -import os -from typing import List, Optional, Set, Tuple - -import torch.distributed - -from vllm.config import VllmConfig -from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment) -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor import set_random_seed -from vllm.platforms import current_platform -from vllm.platforms.neuron import NeuronFramework -from vllm.sequence import ExecuteModelRequest -from vllm.worker.neuron_model_runner import NeuronModelRunner -from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, - WorkerInput) - -logger = init_logger(__name__) - - -class NeuronWorker(LocalOrDistributedWorkerBase): - """A worker class that executes the model on a group of neuron cores. - """ - - model_runner: NeuronModelRunner - - def __init__(self, - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - is_driver_worker: bool = False) -> None: - WorkerBase.__init__(self, vllm_config=vllm_config) - self.local_rank = local_rank - self.rank = rank - self.distributed_init_method = distributed_init_method - self.is_driver_worker = is_driver_worker - self.lora_config = vllm_config.lora_config - - if self.model_config.trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils import init_cached_hf_modules - init_cached_hf_modules() - - neuron_framework = current_platform.get_neuron_framework_to_use() - if neuron_framework == NeuronFramework.TRANSFORMERS_NEURONX: - self.model_runner = self.get_tnx_model_runner(vllm_config) - elif neuron_framework == NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE: - self.model_runner = self.get_neuronx_distributed_model_runner( - vllm_config) - else: - raise NotImplementedError( - "Specified framework" + - f" {os.environ.get('VLLM_NEURON_FRAMEWORK')}" + - " is either not installed or not supported." 
+ - " Supported frameworks: " + - "[transformers-neuronx, neuronx-distributed-inference]") - - def get_tnx_model_runner(self, vllm_config): - assert (self.lora_config - is None), ("LoRA is not supported for TransformersNeuronX " - "framework.") - if self.speculative_config is not None: - raise NotImplementedError( - "Speculative decoding is not supported for TransformersNeuronX" - ) - return NeuronModelRunner(vllm_config=vllm_config) - - def get_neuronx_distributed_model_runner(self, vllm_config): - from vllm.worker.neuronx_distributed_model_runner import ( - NeuronxDistributedModelRunner) - if self.speculative_config is not None: - assert (self.lora_config is None), ( - "LoRA is not supported for Speculative Decoding") - raise NotImplementedError( - "Speculative decoding is not supported for NeuronxDistributed") - return NeuronxDistributedModelRunner(vllm_config=vllm_config) - - def init_device(self) -> None: - self.init_distributed_environment() - - # Set random seed. - set_random_seed(self.model_config.seed) - - def load_model(self): - self.model_runner.load_model() - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available KV blocks. - - Swapping is not yet supported, so always return num_cpu_blocks=0. - - We configure num_gpu_blocks to be equal to max_num_seqs. - """ - # Set the number of GPU blocks to be the same as the maximum number of - # sequences that can be processed in a single batch. This is equivalent - # to schedule without PagedAttention. - num_gpu_blocks = self.scheduler_config.max_num_seqs + 1 - - # Swap not yet supported with Neuron backend. - num_cpu_blocks = 0 - - return num_gpu_blocks, num_cpu_blocks - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Initialize the KV cache. - """ - - # Different values are not tested. - assert num_cpu_blocks == 0 - assert num_gpu_blocks == self.scheduler_config.max_num_seqs + 1 - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - @property - def do_metadata_broadcast(self) -> bool: - return False - - @property - def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: - return None - - @torch.inference_mode() - def prepare_worker_input( - self, execute_model_req: ExecuteModelRequest) -> WorkerInput: - return WorkerInput(num_seq_groups=len( - execute_model_req.seq_group_metadata_list), ) - - def execute_worker(self, worker_input: WorkerInput) -> None: - pass - - def get_cache_block_size_bytes(self) -> int: - """Determine the size in bytes of a cache block. - - This is required for speculative decoding; it is not yet implemented. - """ - raise NotImplementedError - - def init_distributed_environment(self): - """Neuron uses transformers-neuronx for tensor parallelism. 
- - vLLM still needs the environment initialized when TP/PP > 1 - """ - init_distributed_environment( - world_size=1, - rank=self.rank, - local_rank=self.local_rank, - distributed_init_method=self.distributed_init_method, - backend=current_platform.dist_backend, - ) - - ensure_model_parallel_initialized( - 1, - 1, - ) - - def add_lora(self, lora_request: LoRARequest) -> bool: - if current_platform.use_transformers_neuronx(): - raise NotImplementedError( - f"{type(self)} does not support LoRA with Neuron Framework " - f"Transformers NeuronX") - return self.model_runner.add_lora(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - if current_platform.use_transformers_neuronx(): - raise NotImplementedError( - f"{type(self)} does not support LoRA with Neuron Framework " - f"Transformers NeuronX") - return self.model_runner.remove_lora(lora_id) - - def pin_lora(self, lora_id: int) -> bool: - if current_platform.use_transformers_neuronx(): - raise NotImplementedError( - f"{type(self)} does not support LoRA with Neuron Framework " - f"Transformers NeuronX") - return self.model_runner.pin_lora(lora_id) - - def list_loras(self) -> Set[int]: - if current_platform.use_transformers_neuronx(): - raise NotImplementedError( - f"{type(self)} does not support LoRA with Neuron Framework " - f"Transformers NeuronX") - return self.model_runner.list_loras() diff --git a/vllm/worker/neuronx_distributed_model_runner.py b/vllm/worker/neuronx_distributed_model_runner.py deleted file mode 100644 index 2a0f4e77c99e5..0000000000000 --- a/vllm/worker/neuronx_distributed_model_runner.py +++ /dev/null @@ -1,294 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import List, Optional, Set - -import torch -from neuronx_distributed_inference.models.mllama.aspect_ratio_utils import ( - get_all_supported_aspect_ratios) -from neuronx_distributed_inference.modules.generation.sampling import ( - prepare_sampling_params) -from neuronx_distributed_inference.modules.lora_serving import ( - LoraCheckpoint, LoraServingConfig) - -from vllm.config import VllmConfig -from vllm.entrypoints.openai.serving_models import LoRAModulePath -from vllm.logger import init_logger -from vllm.lora.layers import LoRAMapping -from vllm.lora.request import LoRARequest -from vllm.model_executor import SamplingMetadata -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.model_executor.model_loader.neuronx_distributed import ( - _get_model_architecture, get_neuron_model) -from vllm.multimodal import MultiModalKwargs -from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.worker.neuron_model_runner import (ModelInputForNeuron, - NeuronModelRunner) - -logger = init_logger(__name__) - - -class NeuronxDistributedModelRunner(NeuronModelRunner): - - def __init__( - self, - vllm_config: VllmConfig, - ): - super().__init__(vllm_config) - self.lora_checkpoint = None - self.model = None - self.lora_serving_config = None - - @staticmethod - def _get_lora_paths_strings(lora_modules: List[LoRAModulePath]): - if not lora_modules: - return None - return {_.get("name"): _.get("path") for _ in lora_modules} - - def _get_nxdi_lora_config(self): - override_neuron_config = self.model_config.override_neuron_config - lora_modules = override_neuron_config.pop("lora_modules", None) - target_modules = override_neuron_config.pop("target_modules", None) - lora_ckpt_paths = self._get_lora_paths_strings(lora_modules) - if 
self.lora_config.max_loras < len(lora_ckpt_paths): - raise ValueError( - "Number of LoRAs (%s) exceeds maximum " - "allowed (%s)", len(lora_ckpt_paths), - self.lora_config.max_loras) - - return LoraServingConfig( - max_loras=self.lora_config.max_loras, - max_lora_rank=self.lora_config.max_lora_rank, - target_modules=target_modules, - lora_ckpt_paths=lora_ckpt_paths, - ) - - def load_model(self) -> None: - # Update LoRA config - if self.lora_config is not None: - self.lora_serving_config = self._get_nxdi_lora_config() - self.lora_checkpoint = LoraCheckpoint(self.lora_serving_config) - self.model = get_neuron_model( - self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - lora_serving_config=self.lora_serving_config) - - def get_nxd_sampling_params(self, sampling_metadata): - if self.model.config.neuron_config.on_device_sampling_config: - max_topk = (self.model.config.neuron_config. - on_device_sampling_config.global_topk) - else: - max_topk = self.model.config.vocab_size - - top_k = [1] * self.scheduler_config.max_num_seqs - top_p = [1.0] * self.scheduler_config.max_num_seqs - temperature = [1.0] * self.scheduler_config.max_num_seqs - - for index, sequenceGroupToSample in enumerate( - sampling_metadata.seq_groups): - top_k[index] = (sequenceGroupToSample.sampling_params.top_k - if sequenceGroupToSample.sampling_params.top_k > 0 - else max_topk) - top_p[index] = sequenceGroupToSample.sampling_params.top_p - temperature[index] = ( - sequenceGroupToSample.sampling_params.temperature) - - sampling_params = prepare_sampling_params( - batch_size=self.scheduler_config.max_num_seqs, - top_k=top_k, - top_p=top_p, - temperature=temperature) - return sampling_params - - def get_multi_modal_data_neuron(self, input_images): - raise NotImplementedError("need to restore multi-modal support") - - @torch.inference_mode() - def execute_model( - self, - model_input: ModelInputForNeuron, - kv_caches: Optional[List[torch.Tensor]] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - ) -> Optional[List[SamplerOutput]]: - if num_steps > 1: - raise ValueError( - "NeuronModelRunner does not support multi-step execution.") - - if _get_model_architecture( - self.model.config) != "MllamaForConditionalGeneration": - return super().execute_model(model_input, kv_caches, - intermediate_tensors, num_steps) - - sampling_params = self.get_nxd_sampling_params( - model_input.sampling_metadata) - - if model_input.multi_modal_kwargs.get('pixel_values') is not None: - hidden_states = self.model( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - seq_ids=model_input.input_block_ids, - pixel_values=model_input.multi_modal_kwargs.get( - 'pixel_values'), - aspect_ratios=model_input.multi_modal_kwargs.get( - 'aspect_ratios'), - sampling_params=sampling_params, - num_chunks=model_input.multi_modal_kwargs.get('num_chunks'), - has_image=model_input.multi_modal_kwargs.get( - 'has_image').squeeze(1), - ) - else: - bs = model_input.input_tokens.shape[0] if (model_input.input_tokens - is not None) else 1 - empty_pixel_values = torch.zeros([bs, 1, 4, 3, 560, 560], - dtype=torch.bfloat16) - empty_aspect_ratios = torch.ones([bs, 1, 2], dtype=torch.int64) - num_chunks = torch.zeros((bs, 1), dtype=torch.int32) - has_image = torch.zeros([bs], dtype=torch.int32) - hidden_states = self.model( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - seq_ids=model_input.input_block_ids, - 
pixel_values=empty_pixel_values, - aspect_ratios=empty_aspect_ratios, - sampling_params=sampling_params, - num_chunks=num_chunks, - has_image=has_image, - ) - - output = self.model.sample( - hidden_states=hidden_states, - sampling_metadata=model_input.sampling_metadata, - ) - - return [output] - - def process_multi_modal_data_neuron(self, mm_data): - # Neuron uses aspect_ratios instead of aspect_ratio_ids - all_supported_aspect_ratios = get_all_supported_aspect_ratios( - self.model.config.vision_config.max_num_tiles) - aspect_ratio_ids = mm_data.get("aspect_ratio_ids") - mm_data["aspect_ratios"] = torch.tensor( - all_supported_aspect_ratios[aspect_ratio_ids]).unsqueeze(0) - - # Neuron's num_chunks is HF's num_tiles - mm_data["num_chunks"] = mm_data.get("num_tiles") - - # Input has an image if it has pixel_values - bs = mm_data["num_chunks"].shape[0] - pixel_values = mm_data.get("pixel_values") - if pixel_values is not None and not torch.all(pixel_values == 0): - mm_data["has_image"] = torch.ones(bs) - - else: - mm_data["has_image"] = torch.zeros(bs) - return mm_data - - def _get_lora_adapter_ids(self, seq_group_metadata_list): - # set LoRA adapter IDs for multi-lora serving - batch_size = len(seq_group_metadata_list) - if self.lora_checkpoint is not None: - # "0" indicates NxDI to use the base model for inference - adapter_ids = ["0"] * batch_size - for idx, seq_group_metadata in enumerate(seq_group_metadata_list): - if seq_group_metadata.lora_request is not None: - adapter_ids[ - idx] = seq_group_metadata.lora_request.lora_name - - # convert adapter_ids from strings to integers - adapter_ids = self.lora_checkpoint.convert_adapter_ids_to_indices( - adapter_ids, batch_size) - else: - adapter_ids = torch.zeros((batch_size), dtype=torch.int32) - - return adapter_ids - - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> ModelInputForNeuron: - # NOTE: We assume that all sequences in the group are all prompts or - # all decodes. - is_prompt = seq_group_metadata_list[0].is_prompt - # Prepare input tensors. - if is_prompt: - (input_tokens, input_positions, input_block_ids, seq_lens, - multi_modal_kwargs - ) = self._prepare_prompt(seq_group_metadata_list) - else: - (input_tokens, input_positions, - input_block_ids) = self._prepare_decode(seq_group_metadata_list) - seq_lens = None - - if not self._on_device_sampling_disabled: - for seq_group_metadata in seq_group_metadata_list: - sampling_params = seq_group_metadata.sampling_params - top_k, top_p, temperature = ( - self._convert_to_neuron_sampling_params(sampling_params)) - sampling_params.top_k = top_k - sampling_params.top_p = top_p - sampling_params.temperature = temperature - - # we need multi_modal_data for later tokens as well - multi_modal_kwargs_list: List[MultiModalKwargs] = [] - for seq_group_metadata in seq_group_metadata_list: - mm_data = seq_group_metadata.multi_modal_data - if mm_data: - multi_modal_kwargs_list.append(mm_data) - multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) - - lora_adapter_ids = self._get_lora_adapter_ids(seq_group_metadata_list) - - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - # query_lens is not needed if chunked prefill is not - # supported. Since neuron worker doesn't support chunked prefill - # just use seq_lens instead. 
- seq_lens, - self.device, - self.pin_memory, - generators=self.get_generators(finished_requests_ids)) - - return ModelInputForNeuron(input_tokens=input_tokens, - input_positions=input_positions, - input_block_ids=input_block_ids, - sampling_metadata=sampling_metadata, - multi_modal_kwargs=multi_modal_kwargs, - adapter_ids=lora_adapter_ids) - - def remove_all_loras(self): - raise NotImplementedError( - "Managing LoRAs is only supported through the " - "lora_modules parameter in override_neuron_config") - - def set_active_loras(self, lora_requests: Set[LoRARequest], - lora_mapping: LoRAMapping) -> None: - raise NotImplementedError( - "Managing LoRAs is only supported through the " - "lora_modules parameter in override_neuron_config") - - def add_lora(self, lora_request: LoRARequest): - logger.warning( - "Adding LoRAs is only supported through the " - "lora_modules parameter in override_neuron_config. If you supplied " - "the parameter, you can ignore this warning. Ignoring" - "lora request: ", lora_request) - - def remove_lora(self, lora_id: int) -> bool: - raise NotImplementedError( - "Managing LoRAs is only supported through the " - "lora_modules parameter in override_neuron_config") - - def pin_lora(self, lora_id: int) -> bool: - raise NotImplementedError( - "Managing LoRAs is only supported through the " - "lora_modules parameter in override_neuron_config") - - def list_loras(self) -> Set[int]: - raise NotImplementedError( - "Managing LoRAs is only supported through the " - "lora_modules parameter in override_neuron_config") diff --git a/vllm/worker/utils.py b/vllm/worker/utils.py deleted file mode 100644 index 512a1dca73701..0000000000000 --- a/vllm/worker/utils.py +++ /dev/null @@ -1,49 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -''' -Worker-related helper functions. -''' - -from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS -from vllm.worker.model_runner import GPUModelRunnerBase - - -def assert_enc_dec_mr_supported_scenario( - enc_dec_mr: GPUModelRunnerBase) -> None: - ''' - Asserted that the provided encoder/decoder model runner instance reflects - a supported scenario. 
- ''' - - # Reminder: Please update docs/features/compatibility_matrix.md - # If the feature combo become valid - - if enc_dec_mr.cache_config.enable_prefix_caching: - raise NotImplementedError( - STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE']) - - if enc_dec_mr.sliding_window is not None: - raise NotImplementedError( - STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SWA']) - - if enc_dec_mr.scheduler_config.chunked_prefill_enabled: - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[ - 'STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL']) - - if getattr(enc_dec_mr.model_config.hf_config, 'attn_logit_softcapping', - None) is not None: - raise NotImplementedError( - STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP'] - ) - - if enc_dec_mr.lora_config is not None: - raise NotImplementedError( - STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LORA']) - - if enc_dec_mr.parallel_config.pipeline_parallel_size > 1: - raise NotImplementedError( - STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP']) - - if enc_dec_mr.scheduler_config.num_lookahead_slots > 0: - raise NotImplementedError( - STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC']) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 2e20c89c632c5..12047bc390737 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -28,7 +28,6 @@ from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache, memory_profiling) from vllm.worker.cache_engine import CacheEngine -from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, WorkerInput) @@ -78,13 +77,11 @@ class Worker(LocalOrDistributedWorkerBase): "deepseek_mtp", "glm4_moe_mtp", "mimo_mtp", - "ernie_mtp")) \ + "ernie_mtp", + "qwen3_next_mtp")) \ else {"return_hidden_states": True} - ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner - if self.model_config.is_encoder_decoder: - ModelRunnerClass = EncoderDecoderModelRunner - self.model_runner: GPUModelRunnerBase = ModelRunnerClass( + self.model_runner: GPUModelRunnerBase = ModelRunner( vllm_config=self.vllm_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=is_driver_worker, @@ -128,8 +125,10 @@ class Worker(LocalOrDistributedWorkerBase): if self.profiler is None: raise RuntimeError("Profiler is not enabled.") self.profiler.stop() - print( - self.profiler.key_averages().table(sort_by="self_cuda_time_total")) + # only print profiler results on rank 0 + if self.local_rank == 0: + print(self.profiler.key_averages().table( + sort_by="self_cuda_time_total")) def sleep(self, level: int = 1) -> None: free_bytes_before_sleep = torch.cuda.mem_get_info()[0] @@ -226,13 +225,74 @@ class Worker(LocalOrDistributedWorkerBase): self.model_runner.save_tensorized_model( tensorizer_config=tensorizer_config, ) + @torch.inference_mode() + def determine_available_kv_cache_memory(self, + total_gpu_memory: int) -> float: + if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes: + # still need a profile run which compiles the model for + # max_num_batched_tokens + self.model_runner.profile_run() + + GiB = lambda b: b / GiB_bytes + msg = ( + f"Initial free memory " + f"{GiB(self.baseline_snapshot.free_memory):.2f} " + f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f}GiB memory for " + "KV Cache as specified by kv_cache_memory_bytes 
config and " + "skipped memory profiling. This does not respect the " + "gpu_memory_utilization config. Only use kv_cache_memory_bytes " + "config when you want manual control of KV cache memory " + "size. If OOM'ed, check the difference of initial free " + "memory between the current run and the previous run " + "where kv_cache_memory_bytes is suggested and update it " + "correspondingly.") + logger.info(msg) + return self.cache_config.kv_cache_memory_bytes + + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + with memory_profiling( + self.baseline_snapshot, + weights_memory=self.model_runner.model_memory_usage) as result: + self.model_runner.profile_run() + + self.non_torch_memory = result.non_torch_increase + self.peak_activation_memory = result.torch_peak_increase + + self._assert_memory_footprint_increased_during_profiling() + + self.requested_memory = total_gpu_memory * \ + self.cache_config.gpu_memory_utilization + + self.available_kv_cache_memory = (self.requested_memory - + result.non_kv_cache_memory) + + msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n" + "the current vLLM instance can use " + "total_gpu_memory " + f"({(total_gpu_memory / GiB_bytes):.2f}GiB)" + " x gpu_memory_utilization " + f"({self.cache_config.gpu_memory_utilization:.2f})" + f" = {(self.requested_memory / GiB_bytes):.2f}GiB\n" + "model weights take " + f"{(result.weights_memory / GiB_bytes):.2f}GiB;" + " non_torch_memory takes " + f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;" + " PyTorch activation peak memory takes " + f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;" + " the rest of the memory reserved for KV Cache is " + f"{(self.available_kv_cache_memory / GiB_bytes):.2f}GiB.") + + logger.info(msg) + return self.available_kv_cache_memory + + @torch.inference_mode() def determine_num_available_blocks(self) -> Tuple[int, int]: """Profiles the peak memory usage of the model to determine how many KV blocks may be allocated without OOMs. The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the maximum possible number of GPU and CPU blocks + Then, it calculates the maximum possible number of GPU and CPU blocks that can be allocated with the remaining free memory. Tip: @@ -245,20 +305,8 @@ class Worker(LocalOrDistributedWorkerBase): torch.cuda.reset_peak_memory_stats() free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info() - - # Execute a forward pass with dummy inputs to profile the memory usage - # of the model. - with memory_profiling( - self.baseline_snapshot, - weights_memory=self.model_runner.model_memory_usage) as result: - self.model_runner.profile_run() - - self._assert_memory_footprint_increased_during_profiling() - - memory_for_current_instance = total_gpu_memory * \ - self.cache_config.gpu_memory_utilization - available_kv_cache_memory = (memory_for_current_instance - - result.non_kv_cache_memory) + available_kv_cache_memory = self.determine_available_kv_cache_memory( + total_gpu_memory) # Calculate the number of blocks that can be allocated with the # profiled peak memory.
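For intuition, the budget computed by the new determine_available_kv_cache_memory() above reduces to simple arithmetic. The sketch below is not part of the patch; the helper name and all byte counts are hypothetical stand-ins for the real profiling results. It shows the two paths: an explicit kv_cache_memory_bytes value is returned as-is, otherwise the non-KV-cache usage measured during profiling is subtracted from the requested share of GPU memory.

from typing import Optional

GiB_bytes = 1 << 30  # same scale as vllm.utils.GiB_bytes

def kv_cache_budget(total_gpu_memory: int,
                    gpu_memory_utilization: float,
                    weights_memory: int,
                    peak_activation_memory: int,
                    non_torch_memory: int,
                    kv_cache_memory_bytes: Optional[int] = None) -> int:
    """Hypothetical stand-alone version of the budget arithmetic."""
    if kv_cache_memory_bytes is not None:
        # Explicit user-provided KV cache size: profiling is only needed to
        # compile the model, and the configured value is used verbatim.
        return kv_cache_memory_bytes
    # Profiling path: subtract everything that is not KV cache from the
    # requested share of GPU memory.
    requested_memory = int(total_gpu_memory * gpu_memory_utilization)
    non_kv_cache_memory = (weights_memory + peak_activation_memory +
                           non_torch_memory)
    return requested_memory - non_kv_cache_memory

# Made-up numbers for an 80 GiB device at 90% utilization:
budget = kv_cache_budget(80 * GiB_bytes, 0.9,
                         weights_memory=16 * GiB_bytes,
                         peak_activation_memory=4 * GiB_bytes,
                         non_torch_memory=2 * GiB_bytes)
print(f"{budget / GiB_bytes:.2f} GiB reserved for KV cache")  # 50.00 GiB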
@@ -273,23 +321,6 @@ class Worker(LocalOrDistributedWorkerBase): num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) - msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n" - "the current vLLM instance can use " - "total_gpu_memory " - f"({(total_gpu_memory / GiB_bytes):.2f}GiB)" - " x gpu_memory_utilization " - f"({self.cache_config.gpu_memory_utilization:.2f})" - f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n" - "model weights take " - f"{(result.weights_memory / GiB_bytes):.2f}GiB;" - " non_torch_memory takes " - f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;" - " PyTorch activation peak memory takes " - f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;" - " the rest of the memory reserved for KV Cache is " - f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.") - - logger.info(msg) # Final cleanup gc.collect() @@ -379,8 +410,58 @@ class Worker(LocalOrDistributedWorkerBase): for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) self.model_runner._dummy_run(size) + + cuda_graph_memory_bytes = 0 if not self.model_config.enforce_eager: - self.model_runner.capture_model(self.gpu_cache) + cuda_graph_memory_bytes = self.model_runner.capture_model( + self.gpu_cache) + + if (self.cache_config.kv_cache_memory_bytes is None + and hasattr(self, "peak_activation_memory")): + # Suggests optimal kv cache memory size if we rely on + # memory_profiling to guess the kv cache memory size which + # provides peak_activation_memory and a few other memory + # consumption. `memory_profiling` does not consider + # CUDAGraph memory size and may not utilize all gpu memory. + # Users may want fine-grained control to specify kv cache + # memory size. + GiB = lambda b: round(b / GiB_bytes, 2) + non_kv_cache_memory = (self.model_runner.model_memory_usage + + self.peak_activation_memory + + self.non_torch_memory + + cuda_graph_memory_bytes) + + # empirically observed that the memory profiling may + # slightly underestimate the memory consumption. + # So leave a small buffer (=150MiB) to avoid OOM. + redundancy_buffer_memory = 150 * (1 << 20) + kv_cache_memory_bytes_to_gpu_limit = ( + self.baseline_snapshot.free_memory - non_kv_cache_memory - + redundancy_buffer_memory) + kv_cache_memory_bytes_to_requested_limit = ( + int(self.requested_memory) - non_kv_cache_memory - + redundancy_buffer_memory) + + msg = ( + f"Free memory on device " + f"({GiB(self.baseline_snapshot.free_memory)}/" + f"{GiB(self.baseline_snapshot.total_memory)} GiB) on startup. " + f"Desired GPU memory utilization is " + f"({self.cache_config.gpu_memory_utilization}, " + f"{GiB(self.requested_memory)} GiB). " + f"Actual usage is {GiB(self.model_runner.model_memory_usage)} " + f"GiB for weight, {GiB(self.peak_activation_memory)} GiB " + f"for peak activation, {GiB(self.non_torch_memory)} GiB " + f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} " + f"GiB for CUDAGraph memory. Replace gpu_memory_utilization " + f"config with `--kv-cache-memory=" + f"{kv_cache_memory_bytes_to_requested_limit}` to fit into " + f"requested memory, or `--kv-cache-memory=" + f"{kv_cache_memory_bytes_to_gpu_limit}` to fully " + f"utilize gpu memory. Current kv cache memory in use is " + f"{int(self.available_kv_cache_memory)} bytes.") + logger.info(msg) + # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. 
set_random_seed(self.model_config.seed) @@ -537,8 +618,10 @@ def init_worker_distributed_environment( init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank, current_platform.dist_backend) - ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size) + ensure_model_parallel_initialized( + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size, + parallel_config.decode_context_parallel_size) ensure_kv_transfer_initialized(vllm_config) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index a1fa7f2cf7a2e..aa76d21f0fcaa 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -129,6 +129,10 @@ class WorkerBase: """Get vocabulary size from model configuration.""" return self.model_config.get_vocab_size() + def shutdown(self) -> None: + """Clean up resources held by the worker.""" + return + class DelegateWorkerBase(WorkerBase): """ @@ -519,6 +523,10 @@ class WorkerWrapperBase: from vllm.utils import init_cached_hf_modules init_cached_hf_modules() + def shutdown(self) -> None: + if self.worker is not None: + self.worker.shutdown() + def adjust_rank(self, rank_mapping: Dict[int, int]) -> None: """ Adjust the rpc_rank based on the given mapping.
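As a companion to the new suggestion log emitted after CUDA graph capture in worker.py, here is a small sketch of how the two suggested `--kv-cache-memory` values are derived once capture_model() has reported the CUDA graph footprint. This is not part of the patch; the helper name and every byte count below are hypothetical, and only the arithmetic mirrors the logged suggestion.

GiB_bytes = 1 << 30

def suggest_kv_cache_memory(free_memory: int,
                            requested_memory: int,
                            weights_memory: int,
                            peak_activation_memory: int,
                            non_torch_memory: int,
                            cuda_graph_memory_bytes: int) -> tuple:
    """Hypothetical helper mirroring the suggestion arithmetic."""
    # Everything the server needs besides the KV cache, now including the
    # CUDA graph pools measured by capture_model().
    non_kv_cache_memory = (weights_memory + peak_activation_memory +
                           non_torch_memory + cuda_graph_memory_bytes)
    # Profiling tends to slightly underestimate real usage, so keep a small
    # safety margin (the patch uses 150 MiB).
    buffer = 150 * (1 << 20)
    to_requested_limit = int(requested_memory) - non_kv_cache_memory - buffer
    to_gpu_limit = free_memory - non_kv_cache_memory - buffer
    return to_requested_limit, to_gpu_limit

# Made-up numbers: pass the first value as --kv-cache-memory to stay within
# the gpu_memory_utilization share, or the second to use nearly all free
# GPU memory for the KV cache.
fit_requested, fit_gpu = suggest_kv_cache_memory(
    free_memory=79 * GiB_bytes,
    requested_memory=72 * GiB_bytes,   # total_gpu_memory * utilization
    weights_memory=16 * GiB_bytes,
    peak_activation_memory=4 * GiB_bytes,
    non_torch_memory=2 * GiB_bytes,
    cuda_graph_memory_bytes=1 * GiB_bytes)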