diff --git a/.buildkite/ci_config.yaml b/.buildkite/ci_config.yaml new file mode 100644 index 0000000000000..199c33159fde3 --- /dev/null +++ b/.buildkite/ci_config.yaml @@ -0,0 +1,24 @@ +name: vllm_ci +job_dirs: + - ".buildkite/test_areas" + - ".buildkite/image_build" +run_all_patterns: + - "docker/Dockerfile" + - "CMakeLists.txt" + - "requirements/common.txt" + - "requirements/cuda.txt" + - "requirements/build.txt" + - "requirements/test.txt" + - "setup.py" + - "csrc/" + - "cmake/" +run_all_exclude_patterns: + - "docker/Dockerfile." + - "csrc/cpu/" + - "csrc/rocm/" + - "cmake/hipify.py" + - "cmake/cpu_extension.cmake" +registries: public.ecr.aws/q9t5s3a7 +repositories: + main: "vllm-ci-postmerge-repo" + premerge: "vllm-ci-test-repo" diff --git a/.buildkite/generate_index.py b/.buildkite/generate_index.py deleted file mode 100644 index bbed80ebe8476..0000000000000 --- a/.buildkite/generate_index.py +++ /dev/null @@ -1,46 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import argparse -import os - -template = """ - - -

Links for vLLM

- {x86_wheel}
- {arm_wheel}
- - -""" - -parser = argparse.ArgumentParser() -parser.add_argument("--wheel", help="The wheel path.", required=True) -args = parser.parse_args() - -filename = os.path.basename(args.wheel) - -with open("index.html", "w") as f: - print(f"Generated index.html for {args.wheel}") - # sync the abi tag with .buildkite/scripts/upload-wheels.sh - if "x86_64" in filename: - x86_wheel = filename - arm_wheel = filename.replace("x86_64", "aarch64").replace( - "manylinux1", "manylinux2014" - ) - elif "aarch64" in filename: - x86_wheel = filename.replace("aarch64", "x86_64").replace( - "manylinux2014", "manylinux1" - ) - arm_wheel = filename - else: - raise ValueError(f"Unsupported wheel: {filename}") - # cloudfront requires escaping the '+' character - f.write( - template.format( - x86_wheel=x86_wheel, - x86_wheel_html_escaped=x86_wheel.replace("+", "%2B"), - arm_wheel=arm_wheel, - arm_wheel_html_escaped=arm_wheel.replace("+", "%2B"), - ) - ) diff --git a/.buildkite/image_build/image_build.sh b/.buildkite/image_build/image_build.sh new file mode 100755 index 0000000000000..9a2384e524b63 --- /dev/null +++ b/.buildkite/image_build/image_build.sh @@ -0,0 +1,56 @@ +#!/bin/bash +set -e + +if [[ $# -lt 8 ]]; then + echo "Usage: $0 " + exit 1 +fi + +REGISTRY=$1 +REPO=$2 +BUILDKITE_COMMIT=$3 +BRANCH=$4 +VLLM_USE_PRECOMPILED=$5 +VLLM_MERGE_BASE_COMMIT=$6 +CACHE_FROM=$7 +CACHE_TO=$8 + +# authenticate with AWS ECR +aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin $REGISTRY +aws ecr get-login-password --region us-east-1 | docker login --username AWS --password-stdin 936637512419.dkr.ecr.us-east-1.amazonaws.com + +# docker buildx +docker buildx create --name vllm-builder --driver docker-container --use +docker buildx inspect --bootstrap +docker buildx ls + +# skip build if image already exists +if [[ -z $(docker manifest inspect $REGISTRY/$REPO:$BUILDKITE_COMMIT) ]]; then + echo "Image not found, proceeding with build..." +else + echo "Image found" + exit 0 +fi + +if [[ "${VLLM_USE_PRECOMPILED:-0}" == "1" ]]; then + merge_base_commit_build_args="--build-arg VLLM_MERGE_BASE_COMMIT=${VLLM_MERGE_BASE_COMMIT}" +else + merge_base_commit_build_args="" +fi + +# build +docker buildx build --file docker/Dockerfile \ + --build-arg max_jobs=16 \ + --build-arg buildkite_commit=$BUILDKITE_COMMIT \ + --build-arg USE_SCCACHE=1 \ + --build-arg TORCH_CUDA_ARCH_LIST="8.0 8.9 9.0 10.0" \ + --build-arg FI_TORCH_CUDA_ARCH_LIST="8.0 8.9 9.0a 10.0a" \ + --build-arg VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED:-0}" \ + ${merge_base_commit_build_args} \ + --cache-from type=registry,ref=${CACHE_FROM},mode=max \ + --cache-to type=registry,ref=${CACHE_TO},mode=max \ + --tag ${REGISTRY}/${REPO}:${BUILDKITE_COMMIT} \ + $( [[ "${BRANCH}" == "main" ]] && echo "--tag ${REGISTRY}/${REPO}:latest" ) \ + --push \ + --target test \ + --progress plain . diff --git a/.buildkite/image_build/image_build.yaml b/.buildkite/image_build/image_build.yaml new file mode 100644 index 0000000000000..d01c71dd9becf --- /dev/null +++ b/.buildkite/image_build/image_build.yaml @@ -0,0 +1,57 @@ +group: Abuild +steps: + - label: ":docker: Build image" + key: image-build + depends_on: [] + commands: + - .buildkite/image_build/image_build.sh $REGISTRY $REPO $BUILDKITE_COMMIT $BRANCH $VLLM_USE_PRECOMPILED $VLLM_MERGE_BASE_COMMIT $CACHE_FROM $CACHE_TO + retry: + automatic: + - exit_status: -1 # Agent was lost + limit: 2 + - exit_status: -10 # Agent was lost + limit: 2 + + - label: ":docker: Build CPU image" + key: image-build-cpu + depends_on: [] + commands: + - .buildkite/image_build/image_build_cpu.sh $REGISTRY $REPO $BUILDKITE_COMMIT + env: + DOCKER_BUILDKIT: "1" + retry: + automatic: + - exit_status: -1 # Agent was lost + limit: 2 + - exit_status: -10 # Agent was lost + limit: 2 + + - label: ":docker: Build HPU image" + soft_fail: true + depends_on: [] + key: image-build-hpu + commands: + - .buildkite/image_build/image_build_hpu.sh $REGISTRY $REPO $BUILDKITE_COMMIT + env: + DOCKER_BUILDKIT: "1" + retry: + automatic: + - exit_status: -1 # Agent was lost + limit: 2 + - exit_status: -10 # Agent was lost + limit: 2 + + - label: ":docker: Build CPU arm64 image" + key: cpu-arm64-image-build + depends_on: [] + optional: true + commands: + - .buildkite/image_build/image_build_cpu_arm64.sh $REGISTRY $REPO $BUILDKITE_COMMIT + env: + DOCKER_BUILDKIT: "1" + retry: + automatic: + - exit_status: -1 # Agent was lost + limit: 2 + - exit_status: -10 # Agent was lost + limit: 2 diff --git a/.buildkite/image_build/image_build_cpu.sh b/.buildkite/image_build/image_build_cpu.sh new file mode 100755 index 0000000000000..a69732f430985 --- /dev/null +++ b/.buildkite/image_build/image_build_cpu.sh @@ -0,0 +1,36 @@ +#!/bin/bash +set -e + +if [[ $# -lt 3 ]]; then + echo "Usage: $0 " + exit 1 +fi + +REGISTRY=$1 +REPO=$2 +BUILDKITE_COMMIT=$3 + +# authenticate with AWS ECR +aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin $REGISTRY + +# skip build if image already exists +if [[ -z $(docker manifest inspect $REGISTRY/$REPO:$BUILDKITE_COMMIT-cpu) ]]; then + echo "Image not found, proceeding with build..." +else + echo "Image found" + exit 0 +fi + +# build +docker build --file docker/Dockerfile.cpu \ + --build-arg max_jobs=16 \ + --build-arg buildkite_commit=$BUILDKITE_COMMIT \ + --build-arg VLLM_CPU_AVX512BF16=true \ + --build-arg VLLM_CPU_AVX512VNNI=true \ + --build-arg VLLM_CPU_AMXBF16=true \ + --tag $REGISTRY/$REPO:$BUILDKITE_COMMIT-cpu \ + --target vllm-test \ + --progress plain . + +# push +docker push $REGISTRY/$REPO:$BUILDKITE_COMMIT-cpu diff --git a/.buildkite/image_build/image_build_cpu_arm64.sh b/.buildkite/image_build/image_build_cpu_arm64.sh new file mode 100755 index 0000000000000..615298b6555bd --- /dev/null +++ b/.buildkite/image_build/image_build_cpu_arm64.sh @@ -0,0 +1,33 @@ +#!/bin/bash +set -e + +if [[ $# -lt 3 ]]; then + echo "Usage: $0 " + exit 1 +fi + +REGISTRY=$1 +REPO=$2 +BUILDKITE_COMMIT=$3 + +# authenticate with AWS ECR +aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin $REGISTRY + +# skip build if image already exists +if [[ -z $(docker manifest inspect $REGISTRY/$REPO:$BUILDKITE_COMMIT-cpu) ]]; then + echo "Image not found, proceeding with build..." +else + echo "Image found" + exit 0 +fi + +# build +docker build --file docker/Dockerfile.cpu \ + --build-arg max_jobs=16 \ + --build-arg buildkite_commit=$BUILDKITE_COMMIT \ + --tag $REGISTRY/$REPO:$BUILDKITE_COMMIT-cpu \ + --target vllm-test \ + --progress plain . + +# push +docker push $REGISTRY/$REPO:$BUILDKITE_COMMIT-cpu diff --git a/.buildkite/image_build/image_build_hpu.sh b/.buildkite/image_build/image_build_hpu.sh new file mode 100755 index 0000000000000..192447ef4577e --- /dev/null +++ b/.buildkite/image_build/image_build_hpu.sh @@ -0,0 +1,34 @@ +#!/bin/bash +set -e + +if [[ $# -lt 3 ]]; then + echo "Usage: $0 " + exit 1 +fi + +REGISTRY=$1 +REPO=$2 +BUILDKITE_COMMIT=$3 + +# authenticate with AWS ECR +aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin $REGISTRY + +# skip build if image already exists +if [[ -z $(docker manifest inspect $REGISTRY/$REPO:$BUILDKITE_COMMIT-hpu) ]]; then + echo "Image not found, proceeding with build..." +else + echo "Image found" + exit 0 +fi + +# build +docker build \ + --file tests/pytorch_ci_hud_benchmark/Dockerfile.hpu \ + --build-arg max_jobs=16 \ + --build-arg buildkite_commit=$BUILDKITE_COMMIT \ + --tag $REGISTRY/$REPO:$BUILDKITE_COMMIT-hpu \ + --progress plain \ + https://github.com/vllm-project/vllm-gaudi.git + +# push +docker push $REGISTRY/$REPO:$BUILDKITE_COMMIT-hpu diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-4-Maverick-17B-128E-Instruct-FP8.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-4-Maverick-17B-128E-Instruct-FP8.yaml index 46f1a9fbf6ff9..6c0b5540cbb6a 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-4-Maverick-17B-128E-Instruct-FP8.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-4-Maverick-17B-128E-Instruct-FP8.yaml @@ -8,3 +8,4 @@ tasks: value: 0.80 limit: 250 # will run on 250 * 14 subjects = 3500 samples num_fewshot: 5 +rtol: 0.05 diff --git a/.buildkite/lm-eval-harness/configs/models-large-rocm.txt b/.buildkite/lm-eval-harness/configs/models-large-rocm.txt new file mode 100644 index 0000000000000..4fb0b84bc4d81 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/models-large-rocm.txt @@ -0,0 +1 @@ +Meta-Llama-4-Maverick-17B-128E-Instruct-FP8.yaml diff --git a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py index 3627b760eddcf..f94d681197d2d 100644 --- a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py +++ b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py @@ -9,11 +9,40 @@ pytest -s -v test_lm_eval_correctness.py \ --tp-size=1 """ +import os +from contextlib import contextmanager + import lm_eval import numpy as np import yaml -RTOL = 0.08 +DEFAULT_RTOL = 0.08 + + +@contextmanager +def scoped_env_vars(new_env: dict[str, str]): + if not new_env: + # Fast path: nothing to do + yield + return + + old_values = {} + new_keys = [] + + try: + for key, value in new_env.items(): + if key in os.environ: + old_values[key] = os.environ[key] + else: + new_keys.append(key) + os.environ[key] = str(value) + yield + finally: + # Restore / clean up + for key, value in old_values.items(): + os.environ[key] = value + for key in new_keys: + os.environ.pop(key, None) def launch_lm_eval(eval_config, tp_size): @@ -32,23 +61,26 @@ def launch_lm_eval(eval_config, tp_size): f"trust_remote_code={trust_remote_code}," f"max_model_len={max_model_len}," ) - results = lm_eval.simple_evaluate( - model=backend, - model_args=model_args, - tasks=[task["name"] for task in eval_config["tasks"]], - num_fewshot=eval_config["num_fewshot"], - limit=eval_config["limit"], - # TODO(yeq): using chat template w/ fewshot_as_multiturn is supposed help - # text models. however, this is regressing measured strict-match for - # existing text models in CI, so only apply it for mm, or explicitly set - apply_chat_template=eval_config.get( - "apply_chat_template", backend == "vllm-vlm" - ), - fewshot_as_multiturn=eval_config.get("fewshot_as_multiturn", False), - # Forward decoding and early-stop controls (e.g., max_gen_toks, until=...) - gen_kwargs=eval_config.get("gen_kwargs"), - batch_size=batch_size, - ) + + env_vars = eval_config.get("env_vars", None) + with scoped_env_vars(env_vars): + results = lm_eval.simple_evaluate( + model=backend, + model_args=model_args, + tasks=[task["name"] for task in eval_config["tasks"]], + num_fewshot=eval_config["num_fewshot"], + limit=eval_config["limit"], + # TODO(yeq): using chat template w/ fewshot_as_multiturn is supposed help + # text models. however, this is regressing measured strict-match for + # existing text models in CI, so only apply it for mm, or explicitly set + apply_chat_template=eval_config.get( + "apply_chat_template", backend == "vllm-vlm" + ), + fewshot_as_multiturn=eval_config.get("fewshot_as_multiturn", False), + # Forward decoding and early-stop controls (e.g., max_gen_toks, until=...) + gen_kwargs=eval_config.get("gen_kwargs"), + batch_size=batch_size, + ) return results @@ -57,6 +89,8 @@ def test_lm_eval_correctness_param(config_filename, tp_size): results = launch_lm_eval(eval_config, tp_size) + rtol = eval_config.get("rtol", DEFAULT_RTOL) + success = True for task in eval_config["tasks"]: for metric in task["metrics"]: @@ -64,8 +98,9 @@ def test_lm_eval_correctness_param(config_filename, tp_size): measured_value = results["results"][task["name"]][metric["name"]] print( f"{task['name']} | {metric['name']}: " - f"ground_truth={ground_truth} | measured={measured_value}" + f"ground_truth={ground_truth:.3f} | " + f"measured={measured_value:.3f} | rtol={rtol}" ) - success = success and np.isclose(ground_truth, measured_value, rtol=RTOL) + success = success and np.isclose(ground_truth, measured_value, rtol=rtol) assert success diff --git a/.buildkite/performance-benchmarks/README.md b/.buildkite/performance-benchmarks/README.md index 6d494f64f14fa..015f48c2520d6 100644 --- a/.buildkite/performance-benchmarks/README.md +++ b/.buildkite/performance-benchmarks/README.md @@ -108,6 +108,65 @@ The number of this test is less stable compared to the delay and latency benchma WARNING: The benchmarking script will save json results by itself, so please do not configure `--save-results` or other results-saving-related parameters in `serving-tests.json`. +#### Default Parameters Field + +We can specify default parameters in a JSON field with key `defaults`. Parameters defined in the field are applied globally to all serving tests, and can be overridden in test case fields. Here is an example: + +
+ An Example of default parameters field + +```json +{ + "defaults": { + "qps_list": [ + "inf" + ], + "server_environment_variables": { + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1 + }, + "server_parameters": { + "tensor_parallel_size": 1, + "dtype": "bfloat16", + "block_size": 128, + "disable_log_stats": "", + "load_format": "dummy" + }, + "client_parameters": { + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "num_prompts": 200, + "ignore-eos": "" + } + }, + "tests": [ + { + "test_name": "serving_llama3B_tp2_random_128_128", + "server_parameters": { + "model": "meta-llama/Llama-3.2-3B-Instruct", + "tensor_parallel_size": 2, + }, + "client_parameters": { + "model": "meta-llama/Llama-3.2-3B-Instruct", + } + }, + { + "test_name": "serving_qwen3_tp4_random_128_128", + "server_parameters": { + "model": "Qwen/Qwen3-14B", + "tensor_parallel_size": 4, + }, + "client_parameters": { + "model": "Qwen/Qwen3-14B", + } + }, + ] +} +``` + +
+ ### Visualizing the results The `convert-results-json-to-markdown.py` helps you put the benchmarking results inside a markdown table, by formatting [descriptions.md](performance-benchmarks-descriptions.md) with real benchmarking results. diff --git a/.buildkite/performance-benchmarks/scripts/run-performance-benchmarks.sh b/.buildkite/performance-benchmarks/scripts/run-performance-benchmarks.sh index 99a5a5e334f8e..34ceefe0996f2 100644 --- a/.buildkite/performance-benchmarks/scripts/run-performance-benchmarks.sh +++ b/.buildkite/performance-benchmarks/scripts/run-performance-benchmarks.sh @@ -110,7 +110,8 @@ json2envs() { wait_for_server() { # wait for vllm server to start # return 1 if vllm server crashes - timeout 1200 bash -c ' + local timeout_val="1200" + timeout "$timeout_val" bash -c ' until curl -X POST localhost:8000/v1/completions; do sleep 1 done' && return 0 || return 1 @@ -316,12 +317,44 @@ run_throughput_tests() { run_serving_tests() { # run serving tests using `vllm bench serve` command # $1: a json file specifying serving test cases + # + # Supported JSON formats: + # 1) Plain format: top-level array + # [ { "test_name": "...", "server_parameters": {...}, ... }, ... ] + # + # 2) Default parameters field + plain format tests + # { + # "defaults": { ... }, + # "tests": [ { "test_name": "...", "server_parameters": {...}, ... }, ... ] + # } local serving_test_file serving_test_file=$1 # Iterate over serving tests - jq -c '.[]' "$serving_test_file" | while read -r params; do + jq -c ' + if type == "array" then + # Plain format: test cases array + .[] + elif (type == "object" and has("tests")) then + # merge the default parameters into each test cases + . as $root + | ($root.defaults // {}) as $d + | ($root.tests // [])[] + # default qps / max_concurrency from defaults if missing + | .qps_list = (.qps_list // $d.qps_list) + | .max_concurrency_list = (.max_concurrency_list // $d.max_concurrency_list) + # merge envs / params: test overrides defaults + | .server_environment_variables = + (($d.server_environment_variables // {}) + (.server_environment_variables // {})) + | .server_parameters = + (($d.server_parameters // {}) + (.server_parameters // {})) + | .client_parameters = + (($d.client_parameters // {}) + (.client_parameters // {})) + else + error("Unsupported serving test file format: must be array or object with .tests") + end + ' "$serving_test_file" | while read -r params; do # get the test name, and append the GPU type back to it. test_name=$(echo "$params" | jq -r '.test_name') if [[ ! "$test_name" =~ ^serving_ ]]; then @@ -335,20 +368,25 @@ run_serving_tests() { continue fi - # get client and server arguments + # get client and server arguments (after merged the default parameters) server_params=$(echo "$params" | jq -r '.server_parameters') server_envs=$(echo "$params" | jq -r '.server_environment_variables') client_params=$(echo "$params" | jq -r '.client_parameters') + server_args=$(json2args "$server_params") server_envs=$(json2envs "$server_envs") client_args=$(json2args "$client_params") + + # qps_list qps_list=$(echo "$params" | jq -r '.qps_list') qps_list=$(echo "$qps_list" | jq -r '.[] | @sh') echo "Running over qps list $qps_list" + + # max_concurrency_list (fallback to num_prompts if missing) max_concurrency_list=$(echo "$params" | jq -r '.max_concurrency_list') if [[ -z "$max_concurrency_list" || "$max_concurrency_list" == "null" ]]; then - num_prompts=$(echo "$client_params" | jq -r '.num_prompts') - max_concurrency_list="[$num_prompts]" + num_prompts=$(echo "$client_params" | jq -r '.num_prompts') + max_concurrency_list="[$num_prompts]" fi max_concurrency_list=$(echo "$max_concurrency_list" | jq -r '.[] | @sh') echo "Running over max concurrency list $max_concurrency_list" diff --git a/.buildkite/performance-benchmarks/tests/serving-tests-cpu-snc2.json b/.buildkite/performance-benchmarks/tests/serving-tests-cpu-snc2.json deleted file mode 100644 index f758097e098e4..0000000000000 --- a/.buildkite/performance-benchmarks/tests/serving-tests-cpu-snc2.json +++ /dev/null @@ -1,610 +0,0 @@ -[ - { - "test_name": "serving_llama8B_bf16_tp1_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_bf16_tp2_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_bf16_tp4_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_bf16_tp1_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_bf16_tp2_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_bf16_tp4_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int8_tp1_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int8_tp2_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int8_tp4_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int8_tp1_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int8_tp2_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int8_tp4_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int4_tp1_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int4_tp2_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int4_tp4_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int4_tp1_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int4_tp2_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int4_tp4_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - } -] diff --git a/.buildkite/performance-benchmarks/tests/serving-tests-cpu-snc3.json b/.buildkite/performance-benchmarks/tests/serving-tests-cpu-snc3.json deleted file mode 100644 index 0b1a42e790255..0000000000000 --- a/.buildkite/performance-benchmarks/tests/serving-tests-cpu-snc3.json +++ /dev/null @@ -1,1023 +0,0 @@ -[ - { - "test_name": "serving_llama8B_bf16_pp1_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "pipeline_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_bf16_tp2_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_bf16_pp3_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_bf16_tp4_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_bf16_tp2pp3_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_bf16_pp1_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "pipeline_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_bf16_tp2_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_bf16_pp3_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_bf16_tp4_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_bf16_tp2pp3_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int8_pp1_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "pipeline_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int8_tp2_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int8_pp3_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int8_tp4_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int8_tp2pp3_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 2, - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int8_pp1_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "pipeline_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int8_tp2_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int8_pp3_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int8_tp4_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int8_tp2pp3_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "tensor_parallel_size": 2, - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int4_pp1_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "pipeline_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int4_tp2_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int4_pp3_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int4_tp4_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int4_tp2pp3_sharegpt", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 2, - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 - } - }, - { - "test_name": "serving_llama8B_int4_pp1_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "pipeline_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int4_tp2_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int4_pp3_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int4_tp4_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 4, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - }, - { - "test_name": "serving_llama8B_int4_tp2pp3_random_128_128", - "qps_list": ["inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "quantization": "awq", - "tensor_parallel_size": 2, - "pipeline_parallel_size": 3, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 1000 - } - } -] diff --git a/.buildkite/performance-benchmarks/tests/serving-tests-cpu.json b/.buildkite/performance-benchmarks/tests/serving-tests-cpu.json index f792956f39472..8f7200862d20c 100644 --- a/.buildkite/performance-benchmarks/tests/serving-tests-cpu.json +++ b/.buildkite/performance-benchmarks/tests/serving-tests-cpu.json @@ -1,276 +1,246 @@ -[ - { - "test_name": "serving_llama8B_tp1_sharegpt", - "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [32], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 32 - } +{ + "defaults": { + "qps_list": [ + "inf" + ], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 }, - { - "test_name": "serving_llama8B_tp2_sharegpt", - "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [32], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 32 - } + "server_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "tensor_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" }, - { - "test_name": "serving_llama8B_tp1_random_128_128", - "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [32], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 32 - } - }, - { - "test_name": "serving_llama8B_tp2_random_128_128", - "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [32], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 32 - } - }, - { - "test_name": "serving_llama8B_tp1_random_128_2048", - "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [32], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 2048, - "ignore-eos": "", - "num_prompts": 32 - } - }, - { - "test_name": "serving_llama8B_tp2_random_128_2048", - "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [32], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 128, - "random-output-len": 2048, - "ignore-eos": "", - "num_prompts": 32 - } - }, - { - "test_name": "serving_llama8B_tp1_random_2048_128", - "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [32], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 1, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 2048, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 32 - } - }, - { - "test_name": "serving_llama8B_tp2_random_2048_128", - "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [32], - "server_environment_variables": { - "VLLM_RPC_TIMEOUT": 100000, - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, - "VLLM_CPU_SGL_KERNEL": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "server_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 2, - "dtype": "bfloat16", - "distributed_executor_backend": "mp", - "block_size": 128, - "trust_remote_code": "", - "enable_chunked_prefill": "", - "disable_log_stats": "", - "enforce_eager": "", - "max_num_batched_tokens": 2048, - "max_num_seqs": 256, - "load_format": "dummy" - }, - "client_parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "backend": "vllm", - "dataset_name": "random", - "random-input-len": 2048, - "random-output-len": 128, - "ignore-eos": "", - "num_prompts": 32 - } + "client_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "backend": "vllm", + "ignore-eos": "", + "num_prompts": 200 } -] + }, + "tests": [ + { + "test_name": "serving_llama8B_tp1_sharegpt", + "server_parameters": { + "tensor_parallel_size": 1 + }, + "client_parameters": { + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json" + } + }, + { + "test_name": "serving_llama8B_tp2_sharegpt", + "server_parameters": { + "tensor_parallel_size": 2 + }, + "client_parameters": { + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json" + } + }, + { + "test_name": "serving_llama8B_tp1_random_128_128", + "server_parameters": { + "tensor_parallel_size": 1 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } + }, + { + "test_name": "serving_llama8B_tp2_random_128_128", + "server_parameters": { + "tensor_parallel_size": 2 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } + }, + { + "test_name": "serving_llama8B_tp4_random_128_128", + "server_parameters": { + "tensor_parallel_size": 4 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } + }, + { + "test_name": "serving_llama8B_tp1_random_128_2048", + "server_parameters": { + "tensor_parallel_size": 1 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 2048 + } + }, + { + "test_name": "serving_llama8B_tp2_random_128_2048", + "server_parameters": { + "tensor_parallel_size": 2 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 2048 + } + }, + { + "test_name": "serving_llama8B_tp4_random_128_2048", + "server_parameters": { + "tensor_parallel_size": 4 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 2048 + } + }, + { + "test_name": "serving_llama8B_tp1_random_2048_128", + "server_parameters": { + "tensor_parallel_size": 1 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 2048, + "random-output-len": 128 + } + }, + { + "test_name": "serving_llama8B_tp2_random_2048_128", + "server_parameters": { + "tensor_parallel_size": 2 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 2048, + "random-output-len": 128 + } + }, + { + "test_name": "serving_llama8B_tp4_random_2048_128", + "server_parameters": { + "tensor_parallel_size": 4 + }, + "client_parameters": { + "dataset_name": "random", + "random-input-len": 2048, + "random-output-len": 128 + } + }, + { + "test_name": "serving_llama3B_tp1_random_128_128", + "server_parameters": { + "model": "meta-llama/Llama-3.2-3B-Instruct", + "tensor_parallel_size": 1 + }, + "client_parameters": { + "model": "meta-llama/Llama-3.2-3B-Instruct", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } + }, + { + "test_name": "serving_granite2B_tp1_random_128_128", + "server_parameters": { + "model": "ibm-granite/granite-3.2-2b-instruct", + "tensor_parallel_size": 1 + }, + "client_parameters": { + "model": "ibm-granite/granite-3.2-2b-instruct", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } + }, + { + "test_name": "serving_qwen1.7B_tp1_random_128_128", + "server_parameters": { + "model": "Qwen/Qwen3-1.7B", + "tensor_parallel_size": 1 + }, + "client_parameters": { + "model": "Qwen/Qwen3-1.7B", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } + }, + { + "test_name": "serving_qwen4B_tp1_random_128_128", + "server_parameters": { + "model": "Qwen/Qwen3-4B", + "tensor_parallel_size": 1 + }, + "client_parameters": { + "model": "Qwen/Qwen3-4B", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } + }, + { + "test_name": "serving_qwen8B_tp1_random_128_128", + "server_parameters": { + "model": "Qwen/Qwen3-8B", + "tensor_parallel_size": 1 + }, + "client_parameters": { + "model": "Qwen/Qwen3-8B", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } + }, + { + "test_name": "serving_glm9B_tp1_random_128_128", + "server_parameters": { + "model": "zai-org/glm-4-9b-hf", + "tensor_parallel_size": 1 + }, + "client_parameters": { + "model": "zai-org/glm-4-9b-hf", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } + }, + { + "test_name": "serving_gemma7B_tp1_random_128_128", + "server_parameters": { + "model": "google/gemma-7b", + "tensor_parallel_size": 1 + }, + "client_parameters": { + "model": "google/gemma-7b", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128 + } + } + ] +} diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 38c400ba1faf5..a9d51557bd9bb 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -8,13 +8,28 @@ steps: commands: # #NOTE: torch_cuda_arch_list is derived from upstream PyTorch build files here: # https://github.com/pytorch/pytorch/blob/main/.ci/aarch64_linux/aarch64_ci_build.sh#L7 - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg VLLM_MAIN_CUDA_VERSION=12.9 --build-arg torch_cuda_arch_list='8.7 8.9 9.0 10.0+PTX 12.0' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg torch_cuda_arch_list='8.7 8.9 9.0 10.0+PTX 12.0' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - "bash .buildkite/scripts/upload-wheels.sh" env: DOCKER_BUILDKIT: "1" + - label: "Build arm64 wheel - CUDA 13.0" + depends_on: ~ + id: build-wheel-arm64-cuda-13-0 + agents: + queue: arm64_cpu_queue_postmerge + commands: + # #NOTE: torch_cuda_arch_list is derived from upstream PyTorch build files here: + # https://github.com/pytorch/pytorch/blob/main/.ci/aarch64_linux/aarch64_ci_build.sh#L7 + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=13.0.1 --build-arg torch_cuda_arch_list='8.7 8.9 9.0 10.0+PTX 12.0' --build-arg BUILD_BASE_IMAGE=nvidia/cuda:13.0.1-devel-ubuntu22.04 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." + - "mkdir artifacts" + - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" + - "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35" + env: + DOCKER_BUILDKIT: "1" + # aarch64 build - label: "Build arm64 CPU wheel" depends_on: ~ @@ -25,24 +40,11 @@ steps: - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_BUILD_ACL=ON --tag vllm-ci:build-image --target vllm-build --progress plain -f docker/Dockerfile.cpu ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - - "bash .buildkite/scripts/upload-wheels.sh" + - "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35" env: DOCKER_BUILDKIT: "1" # x86 + CUDA builds - - label: "Build wheel - CUDA 12.8" - depends_on: ~ - id: build-wheel-cuda-12-8 - agents: - queue: cpu_queue_postmerge - commands: - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - - "mkdir artifacts" - - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - - "bash .buildkite/scripts/upload-wheels.sh" - env: - DOCKER_BUILDKIT: "1" - - label: "Build wheel - CUDA 12.9" depends_on: ~ id: build-wheel-cuda-12-9 @@ -52,7 +54,7 @@ steps: - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - - "bash .buildkite/scripts/upload-wheels.sh" + - "bash .buildkite/scripts/upload-wheels.sh manylinux_2_31" env: DOCKER_BUILDKIT: "1" @@ -65,7 +67,21 @@ steps: - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=13.0.1 --build-arg BUILD_BASE_IMAGE=nvidia/cuda:13.0.1-devel-ubuntu22.04 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - - "bash .buildkite/scripts/upload-wheels.sh" + - "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35" + env: + DOCKER_BUILDKIT: "1" + + # x86 CPU wheel build + - label: "Build x86 CPU wheel" + depends_on: ~ + id: build-wheel-x86-cpu + agents: + queue: cpu_queue_postmerge + commands: + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --build-arg VLLM_CPU_AMXBF16=true --tag vllm-ci:build-image --target vllm-build --progress plain -f docker/Dockerfile.cpu ." + - "mkdir artifacts" + - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" + - "bash .buildkite/scripts/upload-wheels.sh manylinux_2_35" env: DOCKER_BUILDKIT: "1" @@ -109,7 +125,6 @@ steps: - label: "Annotate release workflow" depends_on: - create-multi-arch-manifest - - build-wheel-cuda-12-8 id: annotate-release-workflow agents: queue: cpu_queue_postmerge diff --git a/.buildkite/scripts/generate-nightly-index.py b/.buildkite/scripts/generate-nightly-index.py new file mode 100644 index 0000000000000..d0965fbd56405 --- /dev/null +++ b/.buildkite/scripts/generate-nightly-index.py @@ -0,0 +1,400 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# do not complain about line length (for docstring) +# ruff: noqa: E501 + +import argparse +import json +import sys +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path +from typing import Any +from urllib.parse import quote + +import regex as re + +if not sys.version_info >= (3, 12): + raise RuntimeError("This script requires Python 3.12 or higher.") + +INDEX_HTML_TEMPLATE = """ + + + + +{items} + + +""" + + +@dataclass +class WheelFileInfo: + package_name: str + version: str + build_tag: str | None + python_tag: str + abi_tag: str + platform_tag: str + variant: str | None + filename: str + + +def parse_from_filename(file: str) -> WheelFileInfo: + """ + Parse wheel file name to extract metadata. + + The format of wheel names: + {package_name}-{version}(-{build_tag})?-{python_tag}-{abi_tag}-{platform_tag}.whl + All versions could contain a variant like '+cu129' or '.cpu' or `.rocm` (or not). + Example: + vllm-0.11.0-cp38-abi3-manylinux1_x86_64.whl + vllm-0.10.2rc2+cu129-cp38-abi3-manylinux2014_aarch64.whl + vllm-0.11.1rc8.dev14+gaa384b3c0-cp38-abi3-manylinux2014_aarch64.whl + vllm-0.11.1rc8.dev14+gaa384b3c0.cu130-cp38-abi3-manylinux1_x86_64.whl + """ + wheel_file_re = re.compile( + r"^(?P.+)-(?P[^-]+?)(-(?P[^-]+))?-(?P[^-]+)-(?P[^-]+)-(?P[^-]+)\.whl$" + ) + match = wheel_file_re.match(file) + if not match: + raise ValueError(f"Invalid wheel file name: {file}") + + package_name = match.group("package_name") + version = match.group("version") + build_tag = match.group("build_tag") + python_tag = match.group("python_tag") + abi_tag = match.group("abi_tag") + platform_tag = match.group("platform_tag") + + # extract variant from version + variant = None + if "dev" in version: + ver_after_dev = version.split("dev")[-1] + if "." in ver_after_dev: + variant = ver_after_dev.split(".")[-1] + version = version.removesuffix("." + variant) + else: + if "+" in version: + version, variant = version.split("+") + + return WheelFileInfo( + package_name=package_name, + version=version, + build_tag=build_tag, + python_tag=python_tag, + abi_tag=abi_tag, + platform_tag=platform_tag, + variant=variant, + filename=file, + ) + + +def generate_project_list(subdir_names: list[str], comment: str = "") -> str: + """ + Generate project list HTML content linking to each project & variant sub-directory. + """ + href_tags = [] + for name in sorted(subdir_names): + name = name.strip("/").strip(".") + href_tags.append(f' {name}/
') + return INDEX_HTML_TEMPLATE.format(items="\n".join(href_tags), comment=comment) + + +def generate_package_index_and_metadata( + wheel_files: list[WheelFileInfo], + wheel_base_dir: Path, + index_base_dir: Path, + comment: str = "", +) -> tuple[str, str]: + """ + Generate package index HTML content for a specific package, linking to actual wheel files. + """ + href_tags = [] + metadata = [] + for file in sorted(wheel_files, key=lambda x: x.filename): + relative_path = ( + wheel_base_dir.relative_to(index_base_dir, walk_up=True) / file.filename + ) + # handle with '+' in URL, and avoid double-encoding '/' and already-encoded '%2B' + # NOTE: this is AWS S3 specific behavior! + file_path_quoted = quote(relative_path.as_posix(), safe=":%/") + href_tags.append(f' {file.filename}
') + file_meta = asdict(file) + file_meta["path"] = file_path_quoted + metadata.append(file_meta) + index_str = INDEX_HTML_TEMPLATE.format(items="\n".join(href_tags), comment=comment) + metadata_str = json.dumps(metadata, indent=2) + return index_str, metadata_str + + +def generate_index_and_metadata( + whl_files: list[str], + wheel_base_dir: Path, + index_base_dir: Path, + default_variant: str | None = None, + alias_to_default: str | None = None, + comment: str = "", +): + """ + Generate index for all wheel files. + + Args: + whl_files (list[str]): List of wheel files (must be directly under `wheel_base_dir`). + wheel_base_dir (Path): Base directory for wheel files. + index_base_dir (Path): Base directory to store index files. + default_variant (str | None): The default variant name, if any. + alias_to_default (str | None): Alias variant name for the default variant, if any. + comment (str | None): Optional comment to include in the generated HTML files. + + First, parse all wheel files to extract metadata. + We need to collect all wheel files for each variant, and generate an index for it (in a sub-directory). + The index for the default variant (if any) is generated in the root index directory. + + If `default_variant` is provided, all wheels must have variant suffixes, and the default variant index + is purely a copy of the corresponding variant index, with only the links adjusted. + Otherwise, all wheels without variant suffixes are treated as the default variant. + + If `alias_to_default` is provided, an additional alias sub-directory is created, it has the same content + as the default variant index, but the links are adjusted accordingly. + + Index directory structure: + index_base_dir/ (hosted at wheels.vllm.ai/{nightly,$commit,$version}/) + index.html # project list, linking to "vllm/" and other packages, and all variant sub-directories + vllm/ + index.html # package index, pointing to actual files in wheel_base_dir (relative path) + metadata.json # machine-readable metadata for all wheels in this package + cpu/ # cpu variant sub-directory + index.html + vllm/ + index.html + metadata.json + cu129/ # cu129 is actually the alias to default variant + index.html + vllm/ + index.html + metadata.json + cu130/ # cu130 variant sub-directory + index.html + vllm/ + index.html + metadata.json + ... + + metadata.json stores a dump of all wheel files' metadata in a machine-readable format: + [ + { + "package_name": "vllm", + "version": "0.10.2rc2", + "build_tag": null, + "python_tag": "cp38", + "abi_tag": "abi3", + "platform_tag": "manylinux2014_aarch64", + "variant": "cu129", + "filename": "vllm-0.10.2rc2+cu129-cp38-abi3-manylinux2014_aarch64.whl", + "path": "../vllm-0.10.2rc2%2Bcu129-cp38-abi3-manylinux2014_aarch64.whl" # to be concatenated with the directory URL and URL-encoded + }, + ... + ] + """ + + parsed_files = [parse_from_filename(f) for f in whl_files] + + if not parsed_files: + print("No wheel files found, skipping index generation.") + return + + # Group by variant + variant_to_files: dict[str, list[WheelFileInfo]] = {} + for file in parsed_files: + variant = file.variant or "default" + if variant not in variant_to_files: + variant_to_files[variant] = [] + variant_to_files[variant].append(file) + + print(f"Found variants: {list(variant_to_files.keys())}") + + # sanity check for default variant + if default_variant: + if "default" in variant_to_files: + raise ValueError( + "All wheel files must have variant suffixes when `default_variant` is specified." + ) + if default_variant not in variant_to_files: + raise ValueError( + f"Default variant '{default_variant}' not found among wheel files." + ) + + if alias_to_default: + if "default" not in variant_to_files: + # e.g. only some wheels are uploaded to S3 currently + print( + "[WARN] Alias to default variant specified, but no default variant found." + ) + elif alias_to_default in variant_to_files: + raise ValueError( + f"Alias variant name '{alias_to_default}' already exists among wheel files." + ) + else: + variant_to_files[alias_to_default] = variant_to_files["default"].copy() + print(f"Alias variant '{alias_to_default}' created for default variant.") + + # Generate comment in HTML header + comment_str = f" ({comment})" if comment else "" + comment_tmpl = f"Generated on {datetime.now().isoformat()}{comment_str}" + + # Generate index for each variant + subdir_names = set() + for variant, files in variant_to_files.items(): + if variant == "default": + variant_dir = index_base_dir + else: + variant_dir = index_base_dir / variant + subdir_names.add(variant) + + variant_dir.mkdir(parents=True, exist_ok=True) + + # gather all package names in this variant + packages = set(f.package_name for f in files) + if variant == "default": + # these packages should also appear in the "project list" + # generate after all variants are processed + subdir_names = subdir_names.union(packages) + else: + # generate project list for this variant directly + project_list_str = generate_project_list(sorted(packages), comment_tmpl) + with open(variant_dir / "index.html", "w") as f: + f.write(project_list_str) + + for package in packages: + # filter files belonging to this package only + package_files = [f for f in files if f.package_name == package] + package_dir = variant_dir / package + package_dir.mkdir(parents=True, exist_ok=True) + index_str, metadata_str = generate_package_index_and_metadata( + package_files, wheel_base_dir, package_dir, comment + ) + with open(package_dir / "index.html", "w") as f: + f.write(index_str) + with open(package_dir / "metadata.json", "w") as f: + f.write(metadata_str) + + # Generate top-level project list index + project_list_str = generate_project_list(sorted(subdir_names), comment_tmpl) + with open(index_base_dir / "index.html", "w") as f: + f.write(project_list_str) + + +if __name__ == "__main__": + """ + Arguments: + --version : version string for the current build (e.g., commit hash) + --current-objects : path to JSON file containing current S3 objects listing in this version directory + --output-dir : directory to store generated index files + --alias-to-default : (optional) alias variant name for the default variant + --comment : (optional) comment string to include in generated HTML files + """ + + parser = argparse.ArgumentParser( + description="Process nightly build wheel files to generate indices." + ) + parser.add_argument( + "--version", + type=str, + required=True, + help="Version string for the current build (e.g., commit hash)", + ) + parser.add_argument( + "--current-objects", + type=str, + required=True, + help="Path to JSON file containing current S3 objects listing in this version directory", + ) + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Directory to store generated index files", + ) + parser.add_argument( + "--alias-to-default", + type=str, + default=None, + help="Alias variant name for the default variant", + ) + parser.add_argument( + "--comment", + type=str, + default="", + help="Optional comment string to include in generated HTML files", + ) + + args = parser.parse_args() + + version = args.version + if "/" in version or "\\" in version: + raise ValueError("Version string must not contain slashes.") + current_objects_path = Path(args.current_objects) + output_dir = Path(args.output_dir) + if not output_dir.exists(): + output_dir.mkdir(parents=True, exist_ok=True) + + # Read current objects JSON + with open(current_objects_path) as f: + current_objects: dict[str, list[dict[str, Any]]] = json.load(f) + + # current_objects looks like from list_objects_v2 S3 API: + """ + "Contents": [ + { + "Key": "e2f56c309d2a28899c68975a7e104502d56deb8f/vllm-0.11.2.dev363+ge2f56c309-cp38-abi3-manylinux1_x86_64.whl", + "LastModified": "2025-11-28T14:00:32+00:00", + "ETag": "\"37a38339c7cdb61ca737021b968075df-52\"", + "ChecksumAlgorithm": [ + "CRC64NVME" + ], + "ChecksumType": "FULL_OBJECT", + "Size": 435649349, + "StorageClass": "STANDARD" + }, + ... + ] + """ + + # Extract wheel file keys + wheel_files = [] + for item in current_objects.get("Contents", []): + key: str = item["Key"] + if key.endswith(".whl"): + wheel_files.append(key.split("/")[-1]) # only the filename is used + + print(f"Found {len(wheel_files)} wheel files for version {version}: {wheel_files}") + + # keep only "official" files for a non-nightly version (specifed by cli args) + PY_VERSION_RE = re.compile(r"^\d+\.\d+\.\d+([a-zA-Z0-9.+-]*)?$") + if PY_VERSION_RE.match(version): + # upload-wheels.sh ensures no "dev" is in args.version + wheel_files = list( + filter(lambda x: version in x and "dev" not in x, wheel_files) + ) + print(f"Non-nightly version detected, wheel files used: {wheel_files}") + else: + print("Nightly version detected, keeping all wheel files.") + + # Generate index and metadata, assuming wheels and indices are stored as: + # s3://vllm-wheels/{version}/ + # s3://vllm-wheels// + wheel_base_dir = Path(output_dir).parent / version + index_base_dir = Path(output_dir) + + generate_index_and_metadata( + whl_files=wheel_files, + wheel_base_dir=wheel_base_dir, + index_base_dir=index_base_dir, + default_variant=None, + alias_to_default=args.alias_to_default, + comment=args.comment.strip(), + ) + print(f"Successfully generated index and metadata in {output_dir}") diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test-arm.sh b/.buildkite/scripts/hardware_ci/run-cpu-test-arm.sh index b5f6b2494792f..b6274d698d01a 100755 --- a/.buildkite/scripts/hardware_ci/run-cpu-test-arm.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test-arm.sh @@ -36,11 +36,17 @@ function cpu_tests() { set -e python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m" + # Run model tests + docker exec cpu-test bash -c " + set -e + pytest -x -v -s tests/models/multimodal/generation/test_whisper.py -m cpu_model" + # Run kernel tests docker exec cpu-test bash -c " set -e pytest -x -v -s tests/kernels/test_onednn.py - pytest -x -v -s tests/kernels/attention/test_cpu_attn.py" + pytest -x -v -s tests/kernels/attention/test_cpu_attn.py + pytest -x -v -s tests/kernels/moe/test_moe.py -k test_cpu_fused_moe_basic" # basic online serving docker exec cpu-test bash -c ' diff --git a/.buildkite/scripts/hardware_ci/run-npu-test.sh b/.buildkite/scripts/hardware_ci/run-npu-test.sh index 29c8f5ed5a91a..0db1abe37ba11 100644 --- a/.buildkite/scripts/hardware_ci/run-npu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-npu-test.sh @@ -74,6 +74,7 @@ FROM ${BASE_IMAGE_NAME} # Define environments ENV DEBIAN_FRONTEND=noninteractive +ENV SOC_VERSION="ascend910b1" RUN pip config set global.index-url http://cache-service-vllm.nginx-pypi-cache.svc.cluster.local:${PYPI_CACHE_PORT}/pypi/simple && \ pip config set global.trusted-host cache-service-vllm.nginx-pypi-cache.svc.cluster.local && \ diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh index 4d163399cfc6c..dfc9db512d1e9 100644 --- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh @@ -38,6 +38,7 @@ docker run \ python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -cc.cudagraph_mode=NONE python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp + python3 examples/offline_inference/basic/generate.py --model Intel/Qwen2.5-0.5B-W4A16-G128-AutoRound-LLMC-TEST-ONLY --enforce-eager VLLM_ATTENTION_BACKEND=TRITON_ATTN python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager cd tests pytest -v -s v1/core @@ -46,6 +47,6 @@ docker run \ pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py pytest -v -s v1/structured_output pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_tree_attention.py --ignore=v1/spec_decode/test_speculators_eagle3.py - pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py --ignore=v1/kv_connector/unit/test_lmcache_integration.py + pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_example_connector.py --ignore=v1/kv_connector/unit/test_lmcache_integration.py pytest -v -s v1/test_serial_utils.py ' diff --git a/.buildkite/scripts/run-prime-rl-test.sh b/.buildkite/scripts/run-prime-rl-test.sh index 5b25c358fc4aa..3fb7c82c8d333 100755 --- a/.buildkite/scripts/run-prime-rl-test.sh +++ b/.buildkite/scripts/run-prime-rl-test.sh @@ -12,6 +12,11 @@ REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" PRIME_RL_REPO="https://github.com/PrimeIntellect-ai/prime-rl.git" PRIME_RL_DIR="${REPO_ROOT}/prime-rl" +if command -v rocm-smi &> /dev/null || command -v rocminfo &> /dev/null; then + echo "AMD GPU detected. Prime-RL currently only supports NVIDIA. Skipping..." + exit 0 +fi + echo "Setting up Prime-RL integration test environment..." # Clean up any existing Prime-RL directory diff --git a/.buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh b/.buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh new file mode 100644 index 0000000000000..937a43d1a3221 --- /dev/null +++ b/.buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh @@ -0,0 +1,74 @@ +#!/usr/bin/env bash +set -euxo pipefail + +# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT] +THRESHOLD=${1:-0.25} +NUM_Q=${2:-1319} +PORT=${3:-8040} +OUT_DIR=${OUT_DIR:-/tmp/vllm-scheduled} +mkdir -p "${OUT_DIR}" + +wait_for_server() { + local port=$1 + timeout 600 bash -c ' + until curl -sf "http://127.0.0.1:'"$port"'/health" > /dev/null; do + sleep 1 + done' +} + +MODEL="Qwen/Qwen3-Next-80B-A3B-Instruct" + +# Set BACKENDS based on platform +if command -v rocm-smi &> /dev/null || [[ -d /opt/rocm ]] || [[ -n "${ROCM_PATH:-}" ]]; then + # ROCm platform + BACKENDS=("allgather_reducescatter") + # Disable MOE padding for ROCm since it is causing eplb to fail + export VLLM_ROCM_MOE_PADDING=0 +else + # Non-ROCm platform (CUDA/other) + BACKENDS=("deepep_high_throughput" "deepep_low_latency") +fi + +cleanup() { + if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then + kill "${SERVER_PID}" 2>/dev/null || true + for _ in {1..20}; do + kill -0 "${SERVER_PID}" 2>/dev/null || break + sleep 0.5 + done + kill -9 "${SERVER_PID}" 2>/dev/null || true + fi +} +trap cleanup EXIT + +for BACK in "${BACKENDS[@]}"; do + VLLM_DEEP_GEMM_WARMUP=skip \ + VLLM_ALL2ALL_BACKEND=$BACK \ + vllm serve "$MODEL" \ + --enforce-eager \ + --tensor-parallel-size 4 \ + --enable-expert-parallel \ + --enable-eplb \ + --eplb-config '{"window_size":200,"step_interval":600,"use_async":true}' \ + --speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":1}' \ + --trust-remote-code \ + --max-model-len 2048 \ + --gpu-memory-utilization 0.9 \ + --port $PORT & + SERVER_PID=$! + wait_for_server $PORT + + TAG=$(echo "$MODEL" | tr '/: \\n' '_____') + OUT="${OUT_DIR}/${TAG}_${BACK}.json" + python3 tests/evals/gsm8k/gsm8k_eval.py --host http://127.0.0.1 --port $PORT --num-questions ${NUM_Q} --save-results ${OUT} + python3 - <= ${THRESHOLD}, f"${MODEL} ${BACK} accuracy {acc}" +PY + + cleanup + SERVER_PID= + sleep 1 + PORT=$((PORT+1)) +done diff --git a/.buildkite/scripts/upload-wheels.sh b/.buildkite/scripts/upload-wheels.sh index 945c5e48c0090..3a218a4bb2e6d 100644 --- a/.buildkite/scripts/upload-wheels.sh +++ b/.buildkite/scripts/upload-wheels.sh @@ -2,6 +2,28 @@ set -ex +# ======== part 0: setup ======== + +BUCKET="vllm-wheels" +INDICES_OUTPUT_DIR="indices" +DEFAULT_VARIANT_ALIAS="cu129" # align with vLLM_MAIN_CUDA_VERSION in vllm/envs.py +PYTHON=${PYTHON_PROG:=python3} # try to read from env var, otherwise use python3 +SUBPATH=$BUILDKITE_COMMIT +S3_COMMIT_PREFIX="s3://$BUCKET/$SUBPATH/" + +# detect if python3.10+ is available +has_new_python=$($PYTHON -c "print(1 if __import__('sys').version_info >= (3,12) else 0)") +if [[ "$has_new_python" -eq 0 ]]; then + # use new python from docker + docker pull python:3-slim + PYTHON="docker run --rm -v $(pwd):/app -w /app python:3-slim python3" +fi + +echo "Using python interpreter: $PYTHON" +echo "Python version: $($PYTHON --version)" + +# ========= part 1: collect, rename & upload the wheel ========== + # Assume wheels are in artifacts/dist/*.whl wheel_files=(artifacts/dist/*.whl) @@ -10,74 +32,76 @@ if [[ ${#wheel_files[@]} -ne 1 ]]; then echo "Error: Expected exactly one wheel file in artifacts/dist/, but found ${#wheel_files[@]}" exit 1 fi - -# Get the single wheel file wheel="${wheel_files[0]}" -# Detect architecture and rename 'linux' to appropriate manylinux version -arch=$(uname -m) -if [[ $arch == "x86_64" ]]; then - manylinux_version="manylinux1" -elif [[ $arch == "aarch64" ]]; then - manylinux_version="manylinux2014" -else - echo "Warning: Unknown architecture $arch, using manylinux1 as default" - manylinux_version="manylinux1" -fi +# default build image uses ubuntu 20.04, which corresponds to manylinux_2_31 +# we also accept params as manylinux tag +# refer to https://github.com/mayeut/pep600_compliance?tab=readme-ov-file#acceptable-distros-to-build-wheels +manylinux_version="${1:-manylinux_2_31}" # Rename 'linux' to the appropriate manylinux version in the wheel filename +if [[ "$wheel" != *"linux"* ]]; then + echo "Error: Wheel filename does not contain 'linux': $wheel" + exit 1 +fi new_wheel="${wheel/linux/$manylinux_version}" mv -- "$wheel" "$new_wheel" wheel="$new_wheel" +echo "Renamed wheel to: $wheel" # Extract the version from the wheel version=$(unzip -p "$wheel" '**/METADATA' | grep '^Version: ' | cut -d' ' -f2) -echo "Version: $version" +echo "Version in wheel: $version" +pure_version="${version%%+*}" +echo "Pure version (without variant): $pure_version" -normal_wheel="$wheel" # Save the original wheel filename +# copy wheel to its own bucket +aws s3 cp "$wheel" "$S3_COMMIT_PREFIX" -# If the version contains "dev", rename it to v1.0.0.dev for consistency -if [[ $version == *dev* ]]; then - suffix="${version##*.}" - if [[ $suffix == cu* ]]; then - new_version="1.0.0.dev+${suffix}" - else - new_version="1.0.0.dev" - fi - new_wheel="${wheel/$version/$new_version}" - # use cp to keep both files in the artifacts directory - cp -- "$wheel" "$new_wheel" - wheel="$new_wheel" - version="$new_version" -fi +# ========= part 2: generate and upload indices ========== +# generate indices for all existing wheels in the commit directory +# this script might be run multiple times if there are multiple variants being built +# so we need to guarantee there is little chance for "TOCTOU" issues +# i.e., one process is generating indices while another is uploading a new wheel +# so we need to ensure no time-consuming operations happen below -# Upload the wheel to S3 -python3 .buildkite/generate_index.py --wheel "$normal_wheel" +# list all wheels in the commit directory +echo "Existing wheels on S3:" +aws s3 ls "$S3_COMMIT_PREFIX" +obj_json="objects.json" +aws s3api list-objects-v2 --bucket "$BUCKET" --prefix "$SUBPATH/" --delimiter / --output json > "$obj_json" +mkdir -p "$INDICES_OUTPUT_DIR" -# generate index for this commit -aws s3 cp "$wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/" -aws s3 cp "$normal_wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/" - -if [[ $normal_wheel == *"cu129"* ]]; then - # only upload index.html for cu129 wheels (default wheels) as it - # is available on both x86 and arm64 - aws s3 cp index.html "s3://vllm-wheels/$BUILDKITE_COMMIT/vllm/index.html" - aws s3 cp "s3://vllm-wheels/nightly/index.html" "s3://vllm-wheels/$BUILDKITE_COMMIT/index.html" +# call script to generate indicies for all existing wheels +# this indices have relative paths that could work as long as it is next to the wheel directory in s3 +# i.e., the wheels are always in s3://vllm-wheels// +# and indices can be placed in //, or /nightly/, or // +if [[ ! -z "$DEFAULT_VARIANT_ALIAS" ]]; then + alias_arg="--alias-to-default $DEFAULT_VARIANT_ALIAS" else - echo "Skipping index files for non-cu129 wheels" + alias_arg="" fi -# generate index for nightly -aws s3 cp "$wheel" "s3://vllm-wheels/nightly/" -aws s3 cp "$normal_wheel" "s3://vllm-wheels/nightly/" +# HACK: we do not need regex module here, but it is required by pre-commit hook +# To avoid any external dependency, we simply replace it back to the stdlib re module +sed -i 's/import regex as re/import re/g' .buildkite/scripts/generate-nightly-index.py +$PYTHON .buildkite/scripts/generate-nightly-index.py --version "$SUBPATH" --current-objects "$obj_json" --output-dir "$INDICES_OUTPUT_DIR" --comment "commit $BUILDKITE_COMMIT" $alias_arg -if [[ $normal_wheel == *"cu129"* ]]; then - # only upload index.html for cu129 wheels (default wheels) as it - # is available on both x86 and arm64 - aws s3 cp index.html "s3://vllm-wheels/nightly/vllm/index.html" -else - echo "Skipping index files for non-cu129 wheels" +# copy indices to // unconditionally +echo "Uploading indices to $S3_COMMIT_PREFIX" +aws s3 cp --recursive "$INDICES_OUTPUT_DIR/" "$S3_COMMIT_PREFIX" + +# copy to /nightly/ only if it is on the main branch and not a PR +if [[ "$BUILDKITE_BRANCH" == "main" && "$BUILDKITE_PULL_REQUEST" == "false" ]]; then + echo "Uploading indices to overwrite /nightly/" + aws s3 cp --recursive "$INDICES_OUTPUT_DIR/" "s3://$BUCKET/nightly/" fi -aws s3 cp "$wheel" "s3://vllm-wheels/$version/" -aws s3 cp index.html "s3://vllm-wheels/$version/vllm/index.html" +# re-generate and copy to // only if it does not have "dev" in the version +if [[ "$version" != *"dev"* ]]; then + echo "Re-generating indices for /$pure_version/" + rm -rf "$INDICES_OUTPUT_DIR/*" + mkdir -p "$INDICES_OUTPUT_DIR" + $PYTHON .buildkite/scripts/generate-nightly-index.py --version "$pure_version" --current-objects "$obj_json" --output-dir "$INDICES_OUTPUT_DIR" --comment "version $pure_version" $alias_arg + aws s3 cp --recursive "$INDICES_OUTPUT_DIR/" "s3://$BUCKET/$pure_version/" +fi diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index 687b6b08507c7..3c9b8cbedcf06 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -39,9 +39,9 @@ steps: # if this test fails, it means the nightly torch version is not compatible with some # of the dependencies. Please check the error message and add the package to whitelist # in /vllm/tools/pre_commit/generate_nightly_torch_test.py - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction, amdtentative] agent_pool: mi325_1 - # grade: Blocking + grade: Blocking soft_fail: true source_file_dependencies: - requirements/nightly_torch_test.txt @@ -50,9 +50,9 @@ steps: - label: Async Engine, Inputs, Utils, Worker Test # 10min timeout_in_minutes: 15 - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental, amdproduction, amdtentative] agent_pool: mi325_1 - # grade: Blocking + grade: Blocking source_file_dependencies: - vllm/ - tests/multimodal @@ -61,11 +61,11 @@ steps: - pytest -v -s -m 'not cpu_test' multimodal - pytest -v -s utils_ -- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 15min - timeout_in_minutes: 20 - mirror_hardwares: [amdexperimental, amdproduction] +- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 20min + timeout_in_minutes: 30 + mirror_hardwares: [amdexperimental, amdproduction, amdtentative] agent_pool: mi325_1 - # grade: Blocking + grade: Blocking source_file_dependencies: - vllm/ - tests/test_inputs.py @@ -73,6 +73,7 @@ steps: - tests/multimodal - tests/standalone_tests/lazy_imports.py - tests/tokenizers_ + - tests/tool_parsers - tests/transformers_utils - tests/config no_gpu: true @@ -82,6 +83,7 @@ steps: - pytest -v -s test_outputs.py - pytest -v -s -m 'cpu_test' multimodal - pytest -v -s tokenizers_ + - pytest -v -s tool_parsers - pytest -v -s transformers_utils - pytest -v -s config @@ -115,9 +117,9 @@ steps: - pytest -v -s basic_correctness/test_cpu_offload.py - label: Entrypoints Unit Tests # 5min - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental, amdproduction, amdtentative] agent_pool: mi325_1 - # grade: Blocking + grade: Blocking timeout_in_minutes: 10 working_dir: "/vllm-workspace/tests" fast_check: true @@ -214,6 +216,7 @@ steps: # test with internal dp - python3 ../examples/offline_inference/data_parallel.py --enforce-eager - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py + - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_internal_lb_dp.py - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py @@ -252,9 +255,9 @@ steps: - torchrun --nproc-per-node=8 ../examples/offline_inference/torchrun_dp_example.py --tp-size=2 --pp-size=1 --dp-size=4 --enable-ep - label: EPLB Algorithm Test # 5min - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental, amdproduction, amdtentative] agent_pool: mi325_1 - # grade: Blocking + grade: Blocking timeout_in_minutes: 15 working_dir: "/vllm-workspace/tests" source_file_dependencies: @@ -325,10 +328,10 @@ steps: commands: - pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py -- label: V1 Test e2e + engine # 30min - timeout_in_minutes: 45 +- label: V1 Test e2e + engine # 65min + timeout_in_minutes: 90 mirror_hardwares: [amdexperimental] - agent_pool: mi325_1 + agent_pool: mi325_4 # grade: Blocking source_file_dependencies: - vllm/ @@ -341,9 +344,9 @@ steps: - label: V1 Test entrypoints # 35min timeout_in_minutes: 50 - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental, amdproduction, amdtentative] agent_pool: mi325_1 - # grade: Blocking + grade: Blocking source_file_dependencies: - vllm/ - tests/v1 @@ -391,6 +394,21 @@ steps: commands: - pytest -v -s v1/attention +- label: Batch Invariance Tests (H100) # 10min + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + timeout_in_minutes: 25 + gpu: h100 + source_file_dependencies: + - vllm/v1/attention + - vllm/model_executor/layers + - tests/v1/determinism/ + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pip install pytest-timeout pytest-forked + - pytest -v -s v1/determinism/test_batch_invariance.py + - pytest -v -s v1/determinism/test_rms_norm_batch_invariant.py + - label: V1 Test attention (B200) # 10min timeout_in_minutes: 30 gpu: b200 @@ -401,9 +419,9 @@ steps: - VLLM_DISABLE_FLASHINFER_PREFILL=1 pytest -v -s v1/attention # TODO: FI prefill is bugged and causes incorrectness, fix this - label: V1 Test others (CPU) # 5 mins - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental, amdproduction, amdtentative] agent_pool: mi325_1 - # grade: Blocking + grade: Blocking source_file_dependencies: - vllm/ - tests/v1 @@ -419,29 +437,34 @@ steps: - label: Examples Test # 30min timeout_in_minutes: 45 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking working_dir: "/vllm-workspace/examples" source_file_dependencies: - vllm/entrypoints + - vllm/multimodal - examples/ commands: - pip install tensorizer # for tensorizer test + # for basic + - python3 offline_inference/basic/chat.py - python3 offline_inference/basic/generate.py --model facebook/opt-125m - python3 offline_inference/basic/generate.py --model meta-llama/Llama-2-13b-chat-hf --cpu-offload-gb 10 - - python3 offline_inference/basic/chat.py - - python3 offline_inference/prefix_caching.py - - python3 offline_inference/llm_engine_example.py - - python3 offline_inference/audio_language.py --seed 0 - - python3 offline_inference/vision_language.py --seed 0 - - python3 offline_inference/vision_language_pooling.py --seed 0 - - python3 offline_inference/vision_language_multi_image.py --seed 0 - - python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 - python3 offline_inference/basic/classify.py - python3 offline_inference/basic/embed.py - python3 offline_inference/basic/score.py + # for multi-modal models + - python3 offline_inference/audio_language.py --seed 0 + - python3 offline_inference/vision_language.py --seed 0 + - python3 offline_inference/vision_language_multi_image.py --seed 0 + - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 + # for pooling models + - python3 pooling/pooling/vision_language_pooling.py --seed 0 + # for features demo + - python3 offline_inference/prefix_caching.py + - python3 offline_inference/llm_engine_example.py + - python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference/spec_decode.py --test --method eagle --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048 # https://github.com/vllm-project/vllm/pull/26682 uses slightly more memory in PyTorch 2.9+ causing this test to OOM in 1xL4 GPU - python3 offline_inference/spec_decode.py --test --method eagle3 --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 1536 @@ -495,7 +518,7 @@ steps: - label: PyTorch Compilation Unit Tests # 15min timeout_in_minutes: 30 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking torch_nightly: true @@ -512,7 +535,7 @@ steps: - label: PyTorch Fullgraph Smoke Test # 15min timeout_in_minutes: 30 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking torch_nightly: true @@ -568,7 +591,7 @@ steps: - label: Kernels Attention Test %N # 23min timeout_in_minutes: 35 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_8 # grade: Blocking source_file_dependencies: @@ -595,7 +618,7 @@ steps: - label: Kernels MoE Test %N # 40min timeout_in_minutes: 60 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_8 # grade: Blocking source_file_dependencies: @@ -622,6 +645,26 @@ steps: commands: - pytest -v -s kernels/mamba +- label: Kernels DeepGEMM Test (H100) # Nvidia-centric +# Not replicating for CUTLAS & CuTe + timeout_in_minutes: 45 + gpu: h100 + num_gpus: 1 + source_file_dependencies: + - tools/install_deepgemm.sh + - vllm/utils/deep_gemm.py + - vllm/model_executor/layers/fused_moe + - vllm/model_executor/layers/quantization + - tests/kernels/quantization/test_block_fp8.py + - tests/kernels/moe/test_deepgemm.py + - tests/kernels/moe/test_batched_deepgemm.py + - tests/kernels/attention/test_deepgemm_attention.py + commands: + - pytest -v -s kernels/quantization/test_block_fp8.py -k deep_gemm + - pytest -v -s kernels/moe/test_deepgemm.py + - pytest -v -s kernels/moe/test_batched_deepgemm.py + - pytest -v -s kernels/attention/test_deepgemm_attention.py + - label: Model Executor Test # 23min timeout_in_minutes: 35 torch_nightly: true @@ -680,16 +723,18 @@ steps: # we can only upgrade after this is resolved # TODO(jerryzh168): resolve the above comment - uv pip install --system torchao==0.13.0 + - uv pip install --system conch-triton-kernels - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py -- label: LM Eval Small Models # 15min - timeout_in_minutes: 20 - mirror_hardwares: [amdexperimental, amdproduction] +- label: LM Eval Small Models # 53min + timeout_in_minutes: 75 + mirror_hardwares: [amdexperimental] agent_pool: mi325_1 # grade: Blocking source_file_dependencies: - csrc/ - vllm/model_executor/layers/quantization + autorun_on_main: true commands: - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 @@ -702,7 +747,7 @@ steps: - csrc/ - vllm/entrypoints/openai/ - vllm/model_executor/models/whisper.py - commands: # LMEval + commands: # LMEval+Transcription WER check # Transcription WER check is skipped because encoder-decoder models are not supported on ROCm, see https://github.com/vllm-project/vllm/issues/27442 - pytest -s entrypoints/openai/correctness/ @@ -716,19 +761,7 @@ steps: - vllm/ - tests/tool_use commands: - - pytest -v -s -m 'not cpu_test' tool_use - -- label: OpenAI-Compatible Tool Use (CPU) # 5 mins - mirror_hardwares: [amdexperimental, amdproduction] - agent_pool: mi325_1 - # grade: Blocking - timeout_in_minutes: 10 - source_file_dependencies: - - vllm/ - - tests/tool_use - no_gpu: true - commands: - - pytest -v -s -m 'cpu_test' tool_use + - pytest -v -s tool_use ##### models test ##### @@ -899,6 +932,18 @@ steps: commands: - pytest -v -s models/language/pooling_mteb_test +- label: Multi-Modal Processor Test (CPU) + timeout_in_minutes: 60 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + source_file_dependencies: + - vllm/ + - tests/models/multimodal + no_gpu: true + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py + - label: Multi-Modal Processor Test # 44min timeout_in_minutes: 60 mirror_hardwares: [amdexperimental] @@ -926,8 +971,8 @@ steps: - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing - cd .. && VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work -- label: Multi-Modal Accuracy Eval (Small Models) # 10min - timeout_in_minutes: 70 +- label: Multi-Modal Accuracy Eval (Small Models) # 150min - 180min + timeout_in_minutes: 180 mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking @@ -939,7 +984,8 @@ steps: commands: - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-mm-small.txt --tp-size=1 -- label: Multi-Modal Models Test (Extended) 1 +- label: Multi-Modal Models Test (Extended) 1 # 60min + timeout_in_minutes: 120 mirror_hardwares: [amdexperimental] agent_pool: mi325_1 # grade: Blocking @@ -963,7 +1009,8 @@ steps: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model' -- label: Multi-Modal Models Test (Extended) 3 +- label: Multi-Modal Models Test (Extended) 3 # 75min + timeout_in_minutes: 150 mirror_hardwares: [amdexperimental] agent_pool: mi325_1 # grade: Blocking @@ -1055,6 +1102,7 @@ steps: - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py - pytest -v -s tests/kernels/moe/test_flashinfer.py + - pytest -v -s tests/kernels/moe/test_cutedsl_moe.py - label: Blackwell Fusion and Compile Tests # 30 min timeout_in_minutes: 40 @@ -1064,11 +1112,18 @@ steps: - csrc/quantization/fp4/ - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/v1/attention/backends/flashinfer.py + - vllm/v1/worker/ + - vllm/v1/cudagraph_dispatcher.py - vllm/compilation/ # can affect pattern matching - vllm/model_executor/layers/layernorm.py - vllm/model_executor/layers/activation.py - vllm/model_executor/layers/quantization/input_quant_fp8.py + - tests/compile/test_fusion_attn.py + - tests/compile/test_silu_mul_quant_fusion.py + - tests/compile/distributed/test_fusion_all_reduce.py + - tests/compile/distributed/test_fusions_e2e.py + - tests/compile/fullgraph/test_full_graph.py commands: - nvidia-smi - pytest -v -s tests/compile/test_fusion_attn.py @@ -1079,7 +1134,7 @@ steps: # Wrap with quotes to escape yaml - "pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and not +quant_fp8 and not +rms_norm'" # test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40) - - pytest -v -s tests/compile/distributed/test_full_graph.py::test_fp8_kv_scale_compile + - pytest -v -s tests/compile/fullgraph/test_full_graph.py::test_fp8_kv_scale_compile - label: Blackwell Fusion E2E Tests # 30 min timeout_in_minutes: 40 @@ -1097,17 +1152,15 @@ steps: - vllm/model_executor/layers/activation.py - vllm/model_executor/layers/quantization/input_quant_fp8.py - tests/compile/distributed/test_fusions_e2e.py - - tests/compile/fullgraph/test_full_graph.py commands: - nvidia-smi # Run all e2e fusion tests - - pytest -v -s tests/compile/test_fusions_e2e.py + - pytest -v -s tests/compile/distributed/test_fusions_e2e.py -- label: ROCm GPT-OSS Eval +- label: Blackwell GPT-OSS Eval timeout_in_minutes: 60 working_dir: "/vllm-workspace/" - agent_pool: mi325_1 - mirror_hardwares: [amdexperimental, amdproduction] + gpu: b200 optional: true # run on nightlies source_file_dependencies: - tests/evals/gpt_oss @@ -1116,7 +1169,7 @@ steps: - vllm/v1/attention/backends/flashinfer.py commands: - uv pip install --system 'gpt-oss[eval]==0.0.5' - - VLLM_ROCM_USE_AITER_MHA=0 VLLM_ROCM_USE_AITER=1 VLLM_USE_AITER_UNIFIED_ATTENTION=1 pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58 + - pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58 - label: Blackwell Quantized MoE Test timeout_in_minutes: 60 @@ -1216,6 +1269,7 @@ steps: - tests/v1/worker/test_worker_memory_snapshot.py commands: - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py + - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py - pytest -v -s entrypoints/llm/test_collective_rpc.py @@ -1251,7 +1305,7 @@ steps: - label: Plugin Tests (2 GPUs) # 40min timeout_in_minutes: 60 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_2 # grade: Blocking working_dir: "/vllm-workspace/tests" @@ -1320,14 +1374,14 @@ steps: - pytest -v -s -x lora/test_llm_with_multi_loras.py - pytest -v -s -x lora/test_olmoe_tp.py - # Disabled for now because MXFP4 backend on non-cuda platform + # Disabled for now because MXFP4 backend on non-cuda platform # doesn't support LoRA yet #- pytest -v -s -x lora/test_gptoss_tp.py - label: Weight Loading Multiple GPU Test # 33min timeout_in_minutes: 45 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_2 # grade: Blocking working_dir: "/vllm-workspace/tests" @@ -1386,7 +1440,83 @@ steps: - TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)' - pytest -v -s -x lora/test_mixtral.py + - label: LM Eval Large Models # optional + gpu: a100 + optional: true + mirror_hardwares: [amdexperimental] + agent_pool: mi325_4 + # grade: Blocking + num_gpus: 4 + working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4 + +##### H100 test ##### +- label: LM Eval Large Models (H100) # optional + gpu: h100 + optional: true + mirror_hardwares: [amdexperimental] + agent_pool: mi325_4 + # grade: Blocking + num_gpus: 4 + working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + commands: + - export VLLM_USE_DEEP_GEMM=0 # We found Triton is faster than DeepGEMM for H100 + - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large-hopper.txt --tp-size=4 + + +##### H200 test ##### +- label: Distributed Tests (H200) # optional + mirror_hardwares: [amdexperimental] + agent_pool: mi325_2 + # grade: Blocking + gpu: h200 + optional: true + working_dir: "/vllm-workspace/" + num_gpus: 2 + commands: + - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_async_tp.py + - pytest -v -s tests/compile/distributed/test_sequence_parallelism.py + - pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py + #- pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm + - "VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'" + - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/distributed/test_sequence_parallel.py + - pytest -v -s tests/distributed/test_context_parallel.py + - HIP_VISIBLE_DEVICES=0,1 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 + - pytest -v -s tests/v1/distributed/test_dbo.py + +##### B200 test ##### +- label: Distributed Tests (B200) # optional + gpu: b200 + optional: true + working_dir: "/vllm-workspace/" + num_gpus: 2 + commands: + - pytest -v -s tests/distributed/test_context_parallel.py + - pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py + - pytest -v -s tests/v1/distributed/test_dbo.py + +##### E2E Eval Tests ##### +- label: LM Eval Small Models (1 Card) # 15min + timeout_in_minutes: 20 + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + # grade: Blocking + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + commands: + - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 + +- label: LM Eval Large Models (4 Card) mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_4 # grade: Blocking @@ -1401,52 +1531,29 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4 -##### H100 test ##### -- label: LM Eval Large Models (H100) # optional - mirror_hardwares: [amdexperimental, amdproduction] - agent_pool: mi325_4 - # grade: Blocking - gpu: h100 - optional: true - num_gpus: 4 +- label: ROCm LM Eval Large Models (8 Card) + mirror_hardwares: [amdproduction] + agent_pool: mi325_8 + num_gpus: 8 working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large-rocm.txt --tp-size=8 + +- label: ROCm GPT-OSS Eval + timeout_in_minutes: 60 + working_dir: "/vllm-workspace/" + agent_pool: mi325_1 + mirror_hardwares: [amdexperimental, amdproduction] + optional: true # run on nightlies source_file_dependencies: - - csrc/ - - vllm/model_executor/layers/quantization + - tests/evals/gpt_oss + - vllm/model_executor/models/gpt_oss.py + - vllm/model_executor/layers/quantization/mxfp4.py + - vllm/v1/attention/backends/flashinfer.py commands: - - export VLLM_USE_DEEP_GEMM=0 # We found Triton is faster than DeepGEMM for H100 - - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large-hopper.txt --tp-size=4 - -##### H200 test ##### -- label: Distributed Tests (H200) # optional - mirror_hardwares: [amdexperimental] - agent_pool: mi325_2 - # grade: Blocking - gpu: h200 - optional: true - working_dir: "/vllm-workspace/" - num_gpus: 2 - commands: - - pytest -v -s tests/compile/distributed/test_async_tp.py - - pytest -v -s tests/compile/distributed/test_sequence_parallelism.py - - pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py - #- pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm - - "pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'" - - pytest -v -s tests/compile/distributed/test_sequence_parallel.py - - pytest -v -s tests/distributed/test_context_parallel.py - - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 - - pytest -v -s tests/v1/distributed/test_dbo.py - -##### B200 test ##### -- label: Distributed Tests (B200) # optional - gpu: b200 - optional: true - working_dir: "/vllm-workspace/" - num_gpus: 2 - commands: - - pytest -v -s tests/distributed/test_context_parallel.py - - pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py - - pytest -v -s tests/v1/distributed/test_dbo.py + - uv pip install --system 'gpt-oss[eval]==0.0.5' + - VLLM_ROCM_USE_AITER_MHA=0 VLLM_ROCM_USE_AITER=1 VLLM_USE_AITER_UNIFIED_ATTENTION=1 pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58 ##### RL Integration Tests ##### - label: Prime-RL Integration Test # 15min @@ -1462,9 +1569,8 @@ steps: - .buildkite/scripts/run-prime-rl-test.sh commands: - bash .buildkite/scripts/run-prime-rl-test.sh - - label: DeepSeek V2-Lite Accuracy - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_4 # grade: Blocking timeout_in_minutes: 60 @@ -1475,8 +1581,8 @@ steps: commands: - bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh 0.25 200 8010 -- label: Qwen3-30B-A3B-FP8-block Accuracy - mirror_hardwares: [amdexperimental] +- label: Qwen3-30B-A3B-FP8-block Accuracy (H100) + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_4 # grade: Blocking timeout_in_minutes: 60 @@ -1486,3 +1592,35 @@ steps: working_dir: "/vllm-workspace" commands: - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 + +- label: Qwen3-30B-A3B-FP8-block Accuracy (B200) + timeout_in_minutes: 60 + gpu: b200 + optional: true + num_gpus: 2 + working_dir: "/vllm-workspace" + commands: + - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1 + +- label: DeepSeek V2-Lite Async EPLB Accuracy + timeout_in_minutes: 60 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_4 + # grade: Blocking + gpu: h100 + optional: true + num_gpus: 4 + working_dir: "/vllm-workspace" + commands: + - bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_async_eplb.sh 0.25 1319 8030 + +- label: Qwen3-Next-80B-A3B-Instruct MTP Async EPLB Accuracy + timeout_in_minutes: 60 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_4 + # grade: Blocking + optional: true + num_gpus: 4 + working_dir: "/vllm-workspace" + commands: + - bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040 diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9f2107fb1e5ab..9d0b3fdd3a02c 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -57,8 +57,8 @@ steps: - pytest -v -s -m 'not cpu_test' multimodal - pytest -v -s utils_ -- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 15min - timeout_in_minutes: 20 +- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 20min + timeout_in_minutes: 30 source_file_dependencies: - vllm/ - tests/test_inputs.py @@ -66,6 +66,7 @@ steps: - tests/multimodal - tests/standalone_tests/lazy_imports.py - tests/tokenizers_ + - tests/tool_parsers - tests/transformers_utils - tests/config no_gpu: true @@ -75,6 +76,7 @@ steps: - pytest -v -s test_outputs.py - pytest -v -s -m 'cpu_test' multimodal - pytest -v -s tokenizers_ + - pytest -v -s tool_parsers - pytest -v -s transformers_utils - pytest -v -s config @@ -350,7 +352,8 @@ steps: timeout_in_minutes: 25 gpu: h100 source_file_dependencies: - - vllm/ + - vllm/v1/attention + - vllm/model_executor/layers - tests/v1/determinism/ commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn @@ -387,23 +390,28 @@ steps: working_dir: "/vllm-workspace/examples" source_file_dependencies: - vllm/entrypoints + - vllm/multimodal - examples/ commands: - pip install tensorizer # for tensorizer test + # for basic + - python3 offline_inference/basic/chat.py - python3 offline_inference/basic/generate.py --model facebook/opt-125m - python3 offline_inference/basic/generate.py --model meta-llama/Llama-2-13b-chat-hf --cpu-offload-gb 10 - - python3 offline_inference/basic/chat.py - - python3 offline_inference/prefix_caching.py - - python3 offline_inference/llm_engine_example.py - - python3 offline_inference/audio_language.py --seed 0 - - python3 offline_inference/vision_language.py --seed 0 - - python3 offline_inference/vision_language_pooling.py --seed 0 - - python3 offline_inference/vision_language_multi_image.py --seed 0 - - python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 - python3 offline_inference/basic/classify.py - python3 offline_inference/basic/embed.py - python3 offline_inference/basic/score.py + # for multi-modal models + - python3 offline_inference/audio_language.py --seed 0 + - python3 offline_inference/vision_language.py --seed 0 + - python3 offline_inference/vision_language_multi_image.py --seed 0 + - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 + # for pooling models + - python3 pooling/pooling/vision_language_pooling.py --seed 0 + # for features demo + - python3 offline_inference/prefix_caching.py + - python3 offline_inference/llm_engine_example.py + - python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference/spec_decode.py --test --method eagle --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048 # https://github.com/vllm-project/vllm/pull/26682 uses slightly more memory in PyTorch 2.9+ causing this test to OOM in 1xL4 GPU - python3 offline_inference/spec_decode.py --test --method eagle3 --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 1536 @@ -462,7 +470,9 @@ steps: # tests covered elsewhere. # Use `find` to launch multiple instances of pytest so that # they do not suffer from https://github.com/vllm-project/vllm/issues/28965 - - "find compile/ -maxdepth 1 -name 'test_*.py' -exec pytest -s -v {} \\\\;" + # However, find does not normally propagate error codes, so we combine it with xargs + # (using -0 for proper path handling) + - "find compile/ -maxdepth 1 -name 'test_*.py' -print0 | xargs -0 -n1 -I{} pytest -s -v '{}'" - label: PyTorch Fullgraph Smoke Test # 15min timeout_in_minutes: 30 @@ -476,7 +486,9 @@ steps: # as it is a heavy test that is covered in other steps. # Use `find` to launch multiple instances of pytest so that # they do not suffer from https://github.com/vllm-project/vllm/issues/28965 - - "find compile/fullgraph/ -name 'test_*.py' -not -name 'test_full_graph.py' -exec pytest -s -v {} \\\\;" + # However, find does not normally propagate error codes, so we combine it with xargs + # (using -0 for proper path handling) + - "find compile/fullgraph -maxdepth 1 -name 'test_*.py' -not -name 'test_full_graph.py' -print0 | xargs -0 -n1 -I{} pytest -s -v '{}'" - label: PyTorch Fullgraph Test # 27min timeout_in_minutes: 40 @@ -662,16 +674,7 @@ steps: - vllm/ - tests/tool_use commands: - - pytest -v -s -m 'not cpu_test' tool_use - -- label: OpenAI-Compatible Tool Use (CPU) # 5 mins - timeout_in_minutes: 10 - source_file_dependencies: - - vllm/ - - tests/tool_use - no_gpu: true - commands: - - pytest -v -s -m 'cpu_test' tool_use + - pytest -v -s tool_use ##### models test ##### @@ -682,6 +685,7 @@ steps: source_file_dependencies: - vllm/ - tests/models/test_initialization.py + - tests/models/registry.py commands: # Run a subset of model initialization tests - pytest -v -s models/test_initialization.py::test_can_initialize_small_subset @@ -694,6 +698,7 @@ steps: - vllm/model_executor/models/ - vllm/transformers_utils/ - tests/models/test_initialization.py + - tests/models/registry.py commands: # Only when vLLM model source is modified - test initialization of a large # subset of supported models (the complement of the small subset in the above @@ -826,7 +831,7 @@ steps: - tests/models/multimodal no_gpu: true commands: - - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - "pip install git+https://github.com/TIGER-AI-Lab/Mantis.git || echo 'Mantis installation skipped (decord not available on CPU-only environment)'" - pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py - label: Multi-Modal Processor Test @@ -1218,6 +1223,8 @@ steps: # FIXIT: find out which code initialize cuda before running the test # before the fix, we need to use spawn to test it - export VLLM_WORKER_MULTIPROC_METHOD=spawn + # Alot of these tests are on the edge of OOMing + - export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True # There is some Tensor Parallelism related processing logic in LoRA that # requires multi-GPU testing for validation. - pytest -v -s -x lora/test_chatglm3_tp.py @@ -1336,6 +1343,7 @@ steps: - label: Prime-RL Integration Test # 15min timeout_in_minutes: 30 optional: true + soft_fail: true num_gpus: 2 working_dir: "/vllm-workspace" source_file_dependencies: @@ -1369,4 +1377,4 @@ steps: num_gpus: 2 working_dir: "/vllm-workspace" commands: - - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1 \ No newline at end of file + - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1 diff --git a/.buildkite/test_areas/attention.yaml b/.buildkite/test_areas/attention.yaml new file mode 100644 index 0000000000000..6e444eae14c74 --- /dev/null +++ b/.buildkite/test_areas/attention.yaml @@ -0,0 +1,21 @@ +group: Attention +depends_on: + - image-build +steps: +- label: V1 attention (H100) + timeout_in_minutes: 30 + gpu: h100 + source_file_dependencies: + - vllm/v1/attention + - tests/v1/attention + commands: + - pytest -v -s v1/attention + +- label: V1 attention (B200) + timeout_in_minutes: 30 + gpu: b200 + source_file_dependencies: + - vllm/v1/attention + - tests/v1/attention + commands: + - VLLM_DISABLE_FLASHINFER_PREFILL=1 pytest -v -s v1/attention # TODO: FI prefill is bugged and causes incorrectness, fix this diff --git a/.buildkite/test_areas/basic_correctness.yaml b/.buildkite/test_areas/basic_correctness.yaml new file mode 100644 index 0000000000000..759d2b5358714 --- /dev/null +++ b/.buildkite/test_areas/basic_correctness.yaml @@ -0,0 +1,16 @@ +group: Basic Correctness +depends_on: + - image-build +steps: +- label: Basic Correctness + timeout_in_minutes: 30 + source_file_dependencies: + - vllm/ + - tests/basic_correctness/test_basic_correctness + - tests/basic_correctness/test_cpu_offload + - tests/basic_correctness/test_cumem.py + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -v -s basic_correctness/test_cumem.py + - pytest -v -s basic_correctness/test_basic_correctness.py + - pytest -v -s basic_correctness/test_cpu_offload.py diff --git a/.buildkite/test_areas/benchmarks.yaml b/.buildkite/test_areas/benchmarks.yaml new file mode 100644 index 0000000000000..574b642d407b0 --- /dev/null +++ b/.buildkite/test_areas/benchmarks.yaml @@ -0,0 +1,19 @@ +group: Benchmarks +depends_on: + - image-build +steps: +- label: Benchmarks + timeout_in_minutes: 20 + working_dir: "/vllm-workspace/.buildkite" + source_file_dependencies: + - benchmarks/ + commands: + - bash scripts/run-benchmarks.sh + +- label: Benchmarks CLI Test + timeout_in_minutes: 20 + source_file_dependencies: + - vllm/ + - tests/benchmarks/ + commands: + - pytest -v -s benchmarks/ diff --git a/.buildkite/test_areas/compile.yaml b/.buildkite/test_areas/compile.yaml new file mode 100644 index 0000000000000..0ba00925a4838 --- /dev/null +++ b/.buildkite/test_areas/compile.yaml @@ -0,0 +1,57 @@ +group: Compile +depends_on: + - image-build +steps: +- label: Fusion and Compile Tests (B200) + timeout_in_minutes: 40 + working_dir: "/vllm-workspace/" + gpu: b200 + source_file_dependencies: + - csrc/quantization/fp4/ + - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py + - vllm/v1/attention/backends/flashinfer.py + - vllm/v1/worker/ + - vllm/v1/cudagraph_dispatcher.py + - vllm/compilation/ + # can affect pattern matching + - vllm/model_executor/layers/layernorm.py + - vllm/model_executor/layers/activation.py + - vllm/model_executor/layers/quantization/input_quant_fp8.py + - tests/compile/test_fusion_attn.py + - tests/compile/test_silu_mul_quant_fusion.py + - tests/compile/distributed/test_fusion_all_reduce.py + - tests/compile/distributed/test_fusions_e2e.py + - tests/compile/fullgraph/test_full_graph.py + commands: + - nvidia-smi + - pytest -v -s tests/compile/test_fusion_attn.py + - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py + # this runner has 2 GPUs available even though num_gpus=2 is not set + - pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py + # Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time + # Wrap with quotes to escape yaml + - "pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and not +quant_fp8 and not +rms_norm'" + # test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40) + - pytest -v -s tests/compile/fullgraph/test_full_graph.py::test_fp8_kv_scale_compile + +- label: Fusion E2E (2 GPUs)(B200) + timeout_in_minutes: 40 + working_dir: "/vllm-workspace/" + gpu: b200 + optional: true + num_gpus: 2 + source_file_dependencies: + - csrc/quantization/fp4/ + - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py + - vllm/v1/attention/backends/flashinfer.py + - vllm/compilation/ + # can affect pattern matching + - vllm/model_executor/layers/layernorm.py + - vllm/model_executor/layers/activation.py + - vllm/model_executor/layers/quantization/input_quant_fp8.py + - tests/compile/distributed/test_fusions_e2e.py + commands: + - nvidia-smi + # Run all e2e fusion tests + - pytest -v -s tests/compile/distributed/test_fusions_e2e.py + diff --git a/.buildkite/test_areas/cuda.yaml b/.buildkite/test_areas/cuda.yaml new file mode 100644 index 0000000000000..50c0c338c2434 --- /dev/null +++ b/.buildkite/test_areas/cuda.yaml @@ -0,0 +1,22 @@ +group: CUDA +depends_on: + - image-build +steps: +- label: Platform Tests (CUDA) + timeout_in_minutes: 15 + source_file_dependencies: + - vllm/ + - tests/cuda + commands: + - pytest -v -s cuda/test_cuda_context.py + +- label: Cudagraph + timeout_in_minutes: 20 + source_file_dependencies: + - tests/v1/cudagraph + - vllm/v1/cudagraph_dispatcher.py + - vllm/config/compilation.py + - vllm/compilation + commands: + - pytest -v -s v1/cudagraph/test_cudagraph_dispatch.py + - pytest -v -s v1/cudagraph/test_cudagraph_mode.py \ No newline at end of file diff --git a/.buildkite/test_areas/distributed.yaml b/.buildkite/test_areas/distributed.yaml new file mode 100644 index 0000000000000..2cc90698d916a --- /dev/null +++ b/.buildkite/test_areas/distributed.yaml @@ -0,0 +1,199 @@ +group: Distributed +depends_on: + - image-build +steps: +- label: Distributed Comm Ops + timeout_in_minutes: 20 + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + source_file_dependencies: + - vllm/distributed + - tests/distributed + commands: + - pytest -v -s distributed/test_comm_ops.py + - pytest -v -s distributed/test_shm_broadcast.py + - pytest -v -s distributed/test_shm_buffer.py + - pytest -v -s distributed/test_shm_storage.py + +- label: Distributed (2 GPUs) + timeout_in_minutes: 90 + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + source_file_dependencies: + - vllm/compilation/ + - vllm/distributed/ + - vllm/engine/ + - vllm/executor/ + - vllm/worker/worker_base.py + - vllm/v1/engine/ + - vllm/v1/worker/ + - tests/compile/fullgraph/test_basic_correctness.py + - tests/compile/test_wrapper.py + - tests/distributed/ + - tests/entrypoints/llm/test_collective_rpc.py + - tests/v1/distributed + - tests/v1/entrypoints/openai/test_multi_api_servers.py + - tests/v1/shutdown + - tests/v1/worker/test_worker_memory_snapshot.py + commands: + # https://github.com/NVIDIA/nccl/issues/1838 + - export NCCL_CUMEM_HOST_ENABLE=0 + - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py + - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py + - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py + - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py + - pytest -v -s entrypoints/llm/test_collective_rpc.py + - pytest -v -s ./compile/fullgraph/test_basic_correctness.py + - pytest -v -s ./compile/test_wrapper.py + - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' + - VLLM_TEST_SAME_HOST=1 VLLM_TEST_WITH_DEFAULT_DEVICE_SET=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' + - pytest -v -s distributed/test_sequence_parallel.py + - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown + - pytest -v -s v1/worker/test_worker_memory_snapshot.py + +- label: Distributed Tests (4 GPUs) + timeout_in_minutes: 50 + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/distributed/ + - tests/distributed/test_utils + - tests/distributed/test_pynccl + - tests/distributed/test_events + - tests/compile/fullgraph/test_basic_correctness.py + - examples/offline_inference/rlhf.py + - examples/offline_inference/rlhf_colocate.py + - tests/examples/offline_inference/data_parallel.py + - tests/v1/distributed + - tests/v1/engine/test_engine_core_client.py + - tests/distributed/test_symm_mem_allreduce.py + commands: + # https://github.com/NVIDIA/nccl/issues/1838 + - export NCCL_CUMEM_HOST_ENABLE=0 + # test with torchrun tp=2 and external_dp=2 + - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py + # test with torchrun tp=2 and pp=2 + - PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py + # test with torchrun tp=4 and dp=1 + - TP_SIZE=4 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py + # test with torchrun tp=2, pp=2 and dp=1 + - PP_SIZE=2 TP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py + # test with torchrun tp=1 and dp=4 with ep + - DP_SIZE=4 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py + # test with torchrun tp=2 and dp=2 with ep + - TP_SIZE=2 DP_SIZE=2 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py + # test with internal dp + - python3 ../examples/offline_inference/data_parallel.py --enforce-eager + - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py + - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_eagle_dp.py + - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py + - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_internal_lb_dp.py + - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py + - pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp + - pytest -v -s distributed/test_utils.py + - pytest -v -s compile/fullgraph/test_basic_correctness.py + - pytest -v -s distributed/test_pynccl.py + - pytest -v -s distributed/test_events.py + - pytest -v -s distributed/test_symm_mem_allreduce.py + # TODO: create a dedicated test section for multi-GPU example tests + # when we have multiple distributed example tests + - cd ../examples/offline_inference + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py + +- label: Distributed Tests (8 GPUs)(H100) + timeout_in_minutes: 10 + gpu: h100 + num_gpus: 8 + working_dir: "/vllm-workspace/tests" + source_file_dependencies: + - examples/offline_inference/torchrun_dp_example.py + - vllm/config/parallel.py + - vllm/distributed/ + - vllm/v1/engine/llm_engine.py + - vllm/v1/executor/uniproc_executor.py + - vllm/v1/worker/gpu_worker.py + commands: + # https://github.com/NVIDIA/nccl/issues/1838 + - export NCCL_CUMEM_HOST_ENABLE=0 + # test with torchrun tp=2 and dp=4 with ep + - torchrun --nproc-per-node=8 ../examples/offline_inference/torchrun_dp_example.py --tp-size=2 --pp-size=1 --dp-size=4 --enable-ep + +- label: Distributed Tests (4 GPUs)(A100) + gpu: a100 + optional: true + num_gpus: 4 + source_file_dependencies: + - vllm/ + commands: + # NOTE: don't test llama model here, it seems hf implementation is buggy + # see https://github.com/vllm-project/vllm/pull/5689 for details + - pytest -v -s distributed/test_custom_all_reduce.py + - torchrun --nproc_per_node=2 distributed/test_ca_buffer_sharing.py + - TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)' + - pytest -v -s -x lora/test_mixtral.py + +- label: Distributed Tests (2 GPUs)(H200) + gpu: h200 + optional: true + working_dir: "/vllm-workspace/" + num_gpus: 2 + commands: + - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_async_tp.py + - pytest -v -s tests/compile/distributed/test_sequence_parallelism.py + - pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py + - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4' + - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/distributed/test_sequence_parallel.py + - pytest -v -s tests/distributed/test_context_parallel.py + - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 + - pytest -v -s tests/v1/distributed/test_dbo.py + +- label: Distributed Tests (2 GPUs)(B200) + gpu: b200 + optional: true + working_dir: "/vllm-workspace/" + num_gpus: 2 + commands: + - pytest -v -s tests/distributed/test_context_parallel.py + - pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py + - pytest -v -s tests/v1/distributed/test_dbo.py + +- label: 2 Node Test (4 GPUs) + timeout_in_minutes: 30 + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + num_nodes: 2 + source_file_dependencies: + - vllm/distributed/ + - vllm/engine/ + - vllm/executor/ + - vllm/model_executor/models/ + - tests/distributed/ + - tests/examples/offline_inference/data_parallel.py + commands: + - ./.buildkite/scripts/run-multi-node-test.sh /vllm-workspace/tests 2 2 public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:0bec63fa317e1fbd62e19b0fc31c43c81bf89077 "VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' && NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' && python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code && VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py && VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py" "VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' && NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' && python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code" + +- label: Distributed NixlConnector PD accuracy (4 GPUs) + timeout_in_minutes: 30 + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py + - tests/v1/kv_connector/nixl_integration/ + commands: + - uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt + - bash v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh + +- label: Pipeline + Context Parallelism (4 GPUs)) + timeout_in_minutes: 60 + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/distributed/ + - vllm/engine/ + - vllm/executor/ + - vllm/model_executor/models/ + - tests/distributed/ + commands: + - pytest -v -s distributed/test_pp_cudagraph.py + - pytest -v -s distributed/test_pipeline_parallel.py \ No newline at end of file diff --git a/.buildkite/test_areas/e2e_integration.yaml b/.buildkite/test_areas/e2e_integration.yaml new file mode 100644 index 0000000000000..93d389815edac --- /dev/null +++ b/.buildkite/test_areas/e2e_integration.yaml @@ -0,0 +1,59 @@ +group: E2E Integration +depends_on: + - image-build +steps: +- label: DeepSeek V2-Lite Accuracy + timeout_in_minutes: 60 + gpu: h100 + optional: true + num_gpus: 4 + working_dir: "/vllm-workspace" + commands: + - bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_eplb.sh 0.25 200 8010 + +- label: Qwen3-30B-A3B-FP8-block Accuracy + timeout_in_minutes: 60 + gpu: h100 + optional: true + num_gpus: 4 + working_dir: "/vllm-workspace" + commands: + - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 + +- label: Qwen3-30B-A3B-FP8-block Accuracy (B200) + timeout_in_minutes: 60 + gpu: b200 + optional: true + num_gpus: 2 + working_dir: "/vllm-workspace" + commands: + - bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1 + +- label: Prime-RL Integration (2 GPUs) + timeout_in_minutes: 30 + optional: true + num_gpus: 2 + working_dir: "/vllm-workspace" + source_file_dependencies: + - vllm/ + - .buildkite/scripts/run-prime-rl-test.sh + commands: + - bash .buildkite/scripts/run-prime-rl-test.sh + +- label: DeepSeek V2-Lite Async EPLB Accuracy + timeout_in_minutes: 60 + gpu: h100 + optional: true + num_gpus: 4 + working_dir: "/vllm-workspace" + commands: + - bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_async_eplb.sh 0.25 1319 8030 + +- label: Qwen3-Next-80B-A3B-Instruct MTP Async EPLB Accuracy + timeout_in_minutes: 60 + gpu: h100 + optional: true + num_gpus: 4 + working_dir: "/vllm-workspace" + commands: + - bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040 diff --git a/.buildkite/test_areas/engine.yaml b/.buildkite/test_areas/engine.yaml new file mode 100644 index 0000000000000..a028e0e4af4c1 --- /dev/null +++ b/.buildkite/test_areas/engine.yaml @@ -0,0 +1,26 @@ +group: Engine +depends_on: + - image-build +steps: +- label: Engine + timeout_in_minutes: 15 + source_file_dependencies: + - vllm/ + - tests/engine + - tests/test_sequence + - tests/test_config + - tests/test_logger + - tests/test_vllm_port + commands: + - pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py + +- label: V1 e2e + engine + timeout_in_minutes: 45 + source_file_dependencies: + - vllm/ + - tests/v1 + commands: + # TODO: accuracy does not match, whether setting + # VLLM_USE_FLASHINFER_SAMPLER or not on H100. + - pytest -v -s v1/e2e + - pytest -v -s v1/engine diff --git a/.buildkite/test_areas/entrypoints.yaml b/.buildkite/test_areas/entrypoints.yaml new file mode 100644 index 0000000000000..0a789be943f37 --- /dev/null +++ b/.buildkite/test_areas/entrypoints.yaml @@ -0,0 +1,68 @@ +group: Entrypoints +depends_on: + - image-build +steps: +- label: Entrypoints Unit Tests + timeout_in_minutes: 10 + working_dir: "/vllm-workspace/tests" + source_file_dependencies: + - vllm/entrypoints + - tests/entrypoints/ + commands: + - pytest -v -s entrypoints/openai/tool_parsers + - pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling + +- label: Entrypoints Integration (LLM) + timeout_in_minutes: 40 + working_dir: "/vllm-workspace/tests" + source_file_dependencies: + - vllm/ + - tests/entrypoints/llm + - tests/entrypoints/offline_mode + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py + - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process + - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests + +- label: Entrypoints Integration (API Server) + timeout_in_minutes: 130 + working_dir: "/vllm-workspace/tests" + source_file_dependencies: + - vllm/ + - tests/entrypoints/openai + - tests/entrypoints/test_chat_utils + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension + - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py --ignore=entrypoints/openai/tool_parsers/ + - pytest -v -s entrypoints/test_chat_utils.py + + +- label: Entrypoints Integration (Pooling) + timeout_in_minutes: 50 + working_dir: "/vllm-workspace/tests" + source_file_dependencies: + - vllm/ + - tests/entrypoints/pooling + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -v -s entrypoints/pooling + + +- label: Entrypoints V1 + timeout_in_minutes: 50 + source_file_dependencies: + - vllm/ + - tests/v1 + commands: + - pytest -v -s v1/entrypoints + +- label: OpenAI API Correctness + timeout_in_minutes: 30 + source_file_dependencies: + - csrc/ + - vllm/entrypoints/openai/ + - vllm/model_executor/models/whisper.py + commands: # LMEval+Transcription WER check + - pytest -s entrypoints/openai/correctness/ diff --git a/.buildkite/test_areas/expert_parallelism.yaml b/.buildkite/test_areas/expert_parallelism.yaml new file mode 100644 index 0000000000000..feb8252148c7f --- /dev/null +++ b/.buildkite/test_areas/expert_parallelism.yaml @@ -0,0 +1,23 @@ +group: Expert Parallelism +depends_on: + - image-build +steps: +- label: EPLB Algorithm + timeout_in_minutes: 15 + working_dir: "/vllm-workspace/tests" + source_file_dependencies: + - vllm/distributed/eplb + - tests/distributed/test_eplb_algo.py + commands: + - pytest -v -s distributed/test_eplb_algo.py + +- label: EPLB Execution + timeout_in_minutes: 20 + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/distributed/eplb + - tests/distributed/test_eplb_execute.py + commands: + - pytest -v -s distributed/test_eplb_execute.py + - pytest -v -s distributed/test_eplb_spec_decode.py \ No newline at end of file diff --git a/.buildkite/test_areas/kernels.yaml b/.buildkite/test_areas/kernels.yaml new file mode 100644 index 0000000000000..7ca099516d641 --- /dev/null +++ b/.buildkite/test_areas/kernels.yaml @@ -0,0 +1,117 @@ +group: Kernels +depends_on: + - image-build +steps: +- label: Kernels Core Operation Test + timeout_in_minutes: 75 + source_file_dependencies: + - csrc/ + - tests/kernels/core + - tests/kernels/test_top_k_per_row.py + commands: + - pytest -v -s kernels/core kernels/test_top_k_per_row.py + +- label: Kernels Attention Test %N + timeout_in_minutes: 35 + source_file_dependencies: + - csrc/attention/ + - vllm/attention + - vllm/v1/attention + - tests/kernels/attention + commands: + - pytest -v -s kernels/attention --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + parallelism: 2 + +- label: Kernels Quantization Test %N + timeout_in_minutes: 90 + source_file_dependencies: + - csrc/quantization/ + - vllm/model_executor/layers/quantization + - tests/kernels/quantization + commands: + - pytest -v -s kernels/quantization --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + parallelism: 2 + +- label: Kernels MoE Test %N + timeout_in_minutes: 60 + source_file_dependencies: + - csrc/quantization/cutlass_w8a8/moe/ + - csrc/moe/ + - tests/kernels/moe + - vllm/model_executor/layers/fused_moe/ + - vllm/distributed/device_communicators/ + - vllm/envs.py + - vllm/config + commands: + - pytest -v -s kernels/moe --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + parallelism: 2 + +- label: Kernels Mamba Test + timeout_in_minutes: 45 + source_file_dependencies: + - csrc/mamba/ + - tests/kernels/mamba + - vllm/model_executor/layers/mamba/ops + commands: + - pytest -v -s kernels/mamba + +- label: Kernels DeepGEMM Test (H100) + timeout_in_minutes: 45 + gpu: h100 + num_gpus: 1 + source_file_dependencies: + - tools/install_deepgemm.sh + - vllm/utils/deep_gemm.py + - vllm/model_executor/layers/fused_moe + - vllm/model_executor/layers/quantization + - tests/kernels/quantization/test_block_fp8.py + - tests/kernels/moe/test_deepgemm.py + - tests/kernels/moe/test_batched_deepgemm.py + - tests/kernels/attention/test_deepgemm_attention.py + commands: + - pytest -v -s kernels/quantization/test_block_fp8.py -k deep_gemm + - pytest -v -s kernels/moe/test_deepgemm.py + - pytest -v -s kernels/moe/test_batched_deepgemm.py + - pytest -v -s kernels/attention/test_deepgemm_attention.py + +- label: Kernels (B200) + timeout_in_minutes: 30 + working_dir: "/vllm-workspace/" + gpu: b200 + # optional: true + source_file_dependencies: + - csrc/quantization/fp4/ + - csrc/attention/mla/ + - csrc/quantization/cutlass_w8a8/moe/ + - vllm/model_executor/layers/fused_moe/cutlass_moe.py + - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py + - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py + - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py + - vllm/v1/attention/backends/flashinfer.py + - vllm/v1/attention/backends/mla/cutlass_mla.py + - vllm/v1/attention/backends/mla/flashinfer_mla.py + - vllm/platforms/cuda.py + - vllm/attention/selector.py + commands: + - nvidia-smi + - python3 examples/offline_inference/basic/chat.py + # Attention + # num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353 + - pytest -v -s tests/kernels/attention/test_attention_selector.py + - pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2' + - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py + - pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py + - pytest -v -s tests/kernels/attention/test_flashinfer_mla_decode.py + # Quantization + - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8' + - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py + - pytest -v -s tests/kernels/quantization/test_silu_mul_nvfp4_quant.py + - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py + - pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py + - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py + - pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py + - pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py + - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py + - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py + - pytest -v -s tests/kernels/moe/test_flashinfer.py + - pytest -v -s tests/kernels/moe/test_cutedsl_moe.py \ No newline at end of file diff --git a/.buildkite/test_areas/lm_eval.yaml b/.buildkite/test_areas/lm_eval.yaml new file mode 100644 index 0000000000000..9af43e0c375a8 --- /dev/null +++ b/.buildkite/test_areas/lm_eval.yaml @@ -0,0 +1,46 @@ +group: LM Eval +depends_on: + - image-build +steps: +- label: LM Eval Small Models + timeout_in_minutes: 75 + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + autorun_on_main: true + commands: + - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 + +- label: LM Eval Large Models (4 GPUs)(A100) + gpu: a100 + optional: true + num_gpus: 4 + working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4 + +- label: LM Eval Large Models (4 GPUs)(H100) + gpu: h100 + optional: true + num_gpus: 4 + working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + commands: + - export VLLM_USE_DEEP_GEMM=0 # We found Triton is faster than DeepGEMM for H100 + - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large-hopper.txt --tp-size=4 + +- label: LM Eval Small Models (B200) + timeout_in_minutes: 120 + gpu: b200 + optional: true + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + commands: + - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt --tp-size=1 diff --git a/.buildkite/test_areas/lora.yaml b/.buildkite/test_areas/lora.yaml new file mode 100644 index 0000000000000..809b4138f44ba --- /dev/null +++ b/.buildkite/test_areas/lora.yaml @@ -0,0 +1,31 @@ +group: LoRA +depends_on: + - image-build +steps: +- label: LoRA %N + timeout_in_minutes: 30 + source_file_dependencies: + - vllm/lora + - tests/lora + commands: + - pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_llm_with_multi_loras.py --ignore=lora/test_olmoe_tp.py --ignore=lora/test_deepseekv2_tp.py --ignore=lora/test_gptoss_tp.py --ignore=lora/test_qwen3moe_tp.py + parallelism: 4 + + +- label: LoRA TP (Distributed) + timeout_in_minutes: 30 + num_gpus: 4 + source_file_dependencies: + - vllm/lora + - tests/lora + commands: + # FIXIT: find out which code initialize cuda before running the test + # before the fix, we need to use spawn to test it + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + # There is some Tensor Parallelism related processing logic in LoRA that + # requires multi-GPU testing for validation. + - pytest -v -s -x lora/test_chatglm3_tp.py + - pytest -v -s -x lora/test_llama_tp.py + - pytest -v -s -x lora/test_llm_with_multi_loras.py + - pytest -v -s -x lora/test_olmoe_tp.py + - pytest -v -s -x lora/test_gptoss_tp.py \ No newline at end of file diff --git a/.buildkite/test_areas/misc.yaml b/.buildkite/test_areas/misc.yaml new file mode 100644 index 0000000000000..252af1e56a105 --- /dev/null +++ b/.buildkite/test_areas/misc.yaml @@ -0,0 +1,165 @@ +group: Miscellaneous +depends_on: + - image-build +steps: +- label: V1 Others + timeout_in_minutes: 60 + source_file_dependencies: + - vllm/ + - tests/v1 + commands: + - uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt + # split the test to avoid interference + - pytest -v -s -m 'not cpu_test' v1/core + - pytest -v -s v1/executor + - pytest -v -s v1/kv_offload + - pytest -v -s v1/sample + - pytest -v -s v1/logits_processors + - pytest -v -s v1/worker + - pytest -v -s v1/spec_decode + - pytest -v -s -m 'not cpu_test' v1/kv_connector/unit + - pytest -v -s -m 'not cpu_test' v1/metrics + - pytest -v -s v1/test_oracle.py + - pytest -v -s v1/test_request.py + - pytest -v -s v1/test_outputs.py + # Integration test for streaming correctness (requires special branch). + - pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api + - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine + +- label: V1 Others (CPU) + depends_on: ~ + source_file_dependencies: + - vllm/ + - tests/v1 + no_gpu: true + commands: + # split the test to avoid interference + - pytest -v -s -m 'cpu_test' v1/core + - pytest -v -s v1/structured_output + - pytest -v -s v1/test_serial_utils.py + - pytest -v -s -m 'cpu_test' v1/kv_connector/unit + - pytest -v -s -m 'cpu_test' v1/metrics + +- label: Regression + timeout_in_minutes: 20 + source_file_dependencies: + - vllm/ + - tests/test_regression + commands: + - pip install modelscope + - pytest -v -s test_regression.py + working_dir: "/vllm-workspace/tests" # optional + +- label: Examples + timeout_in_minutes: 45 + working_dir: "/vllm-workspace/examples" + source_file_dependencies: + - vllm/entrypoints + - vllm/multimodal + - examples/ + commands: + - pip install tensorizer # for tensorizer test + - python3 offline_inference/basic/chat.py # for basic + - python3 offline_inference/basic/generate.py --model facebook/opt-125m + - python3 offline_inference/basic/generate.py --model meta-llama/Llama-2-13b-chat-hf --cpu-offload-gb 10 + - python3 offline_inference/basic/classify.py + - python3 offline_inference/basic/embed.py + - python3 offline_inference/basic/score.py + # for multi-modal models + - python3 offline_inference/audio_language.py --seed 0 + - python3 offline_inference/vision_language.py --seed 0 + - python3 offline_inference/vision_language_multi_image.py --seed 0 + - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 + # for pooling models + - python3 pooling/pooling/vision_language_pooling.py --seed 0 + # for features demo + - python3 offline_inference/prefix_caching.py + - python3 offline_inference/llm_engine_example.py + - python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors + - python3 offline_inference/spec_decode.py --test --method eagle --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048 + # https://github.com/vllm-project/vllm/pull/26682 uses slightly more memory in PyTorch 2.9+ causing this test to OOM in 1xL4 GPU + - python3 offline_inference/spec_decode.py --test --method eagle3 --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 1536 + +- label: Metrics, Tracing (2 GPUs) + timeout_in_minutes: 20 + num_gpus: 2 + source_file_dependencies: + - vllm/ + - tests/v1/tracing + commands: + - "pip install \ + 'opentelemetry-sdk>=1.26.0' \ + 'opentelemetry-api>=1.26.0' \ + 'opentelemetry-exporter-otlp>=1.26.0' \ + 'opentelemetry-semantic-conventions-ai>=0.4.1'" + - pytest -v -s v1/tracing + +- label: Python-only Installation + depends_on: ~ + timeout_in_minutes: 20 + source_file_dependencies: + - tests/standalone_tests/python_only_compile.sh + - setup.py + commands: + - bash standalone_tests/python_only_compile.sh + +- label: Async Engine, Inputs, Utils, Worker + timeout_in_minutes: 50 + source_file_dependencies: + - vllm/ + - tests/multimodal + - tests/utils_ + commands: + - pytest -v -s -m 'not cpu_test' multimodal + - pytest -v -s utils_ + +- label: Async Engine, Inputs, Utils, Worker, Config (CPU) + depends_on: ~ + timeout_in_minutes: 30 + source_file_dependencies: + - vllm/ + - tests/test_inputs.py + - tests/test_outputs.py + - tests/multimodal + - tests/standalone_tests/lazy_imports.py + - tests/tokenizers_ + - tests/tool_parsers + - tests/transformers_utils + - tests/config + no_gpu: true + commands: + - python3 standalone_tests/lazy_imports.py + - pytest -v -s test_inputs.py + - pytest -v -s test_outputs.py + - pytest -v -s -m 'cpu_test' multimodal + - pytest -v -s tokenizers_ + - pytest -v -s tool_parsers + - pytest -v -s transformers_utils + - pytest -v -s config + +- label: GPT-OSS Eval (B200) + timeout_in_minutes: 60 + working_dir: "/vllm-workspace/" + gpu: b200 + optional: true + source_file_dependencies: + - tests/evals/gpt_oss + - vllm/model_executor/models/gpt_oss.py + - vllm/model_executor/layers/quantization/mxfp4.py + - vllm/v1/attention/backends/flashinfer.py + commands: + - uv pip install --system 'gpt-oss[eval]==0.0.5' + - pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58 + +- label: Batch Invariance (H100) + timeout_in_minutes: 25 + gpu: h100 + source_file_dependencies: + - vllm/v1/attention + - vllm/model_executor/layers + - tests/v1/determinism/ + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pip install pytest-timeout pytest-forked + - pytest -v -s v1/determinism/test_batch_invariance.py + - pytest -v -s v1/determinism/test_rms_norm_batch_invariant.py \ No newline at end of file diff --git a/.buildkite/test_areas/model_executor.yaml b/.buildkite/test_areas/model_executor.yaml new file mode 100644 index 0000000000000..996c8bb8b780a --- /dev/null +++ b/.buildkite/test_areas/model_executor.yaml @@ -0,0 +1,17 @@ +group: Model Executor +depends_on: + - image-build +steps: +- label: Model Executor + timeout_in_minutes: 35 + source_file_dependencies: + - vllm/engine/arg_utils.py + - vllm/config/model.py + - vllm/model_executor + - tests/model_executor + - tests/entrypoints/openai/test_tensorizer_entrypoint.py + commands: + - apt-get update && apt-get install -y curl libsodium23 + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -v -s model_executor + - pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py diff --git a/.buildkite/test_areas/models_basic.yaml b/.buildkite/test_areas/models_basic.yaml new file mode 100644 index 0000000000000..39a5d51c48833 --- /dev/null +++ b/.buildkite/test_areas/models_basic.yaml @@ -0,0 +1,62 @@ +group: Models - Basic +depends_on: + - image-build +steps: +- label: Basic Models Tests (Initialization) + timeout_in_minutes: 45 + mirror_hardwares: [amdexperimental] + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/models/test_initialization.py + commands: + # Run a subset of model initialization tests + - pytest -v -s models/test_initialization.py::test_can_initialize_small_subset + +- label: Basic Models Tests (Extra Initialization) %N + timeout_in_minutes: 45 + mirror_hardwares: [amdexperimental] + torch_nightly: true + source_file_dependencies: + - vllm/model_executor/models/ + - tests/models/test_initialization.py + commands: + # Only when vLLM model source is modified - test initialization of a large + # subset of supported models (the complement of the small subset in the above + # test.) Also run if model initialization test file is modified + - pytest -v -s models/test_initialization.py -k 'not test_can_initialize_small_subset' --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --shard-id=$$BUILDKITE_PARALLEL_JOB + parallelism: 2 + +- label: Basic Models Tests (Other) + timeout_in_minutes: 45 + source_file_dependencies: + - vllm/ + - tests/models/test_transformers.py + - tests/models/test_registry.py + commands: + - pytest -v -s models/test_transformers.py models/test_registry.py + +- label: Basic Models Test (Other CPU) # 5min + timeout_in_minutes: 10 + source_file_dependencies: + - vllm/ + - tests/models/test_utils.py + - tests/models/test_vision.py + no_gpu: true + commands: + - pytest -v -s models/test_utils.py models/test_vision.py + +- label: Transformers Nightly Models + working_dir: "/vllm-workspace/" + optional: true + soft_fail: true + commands: + - pip install --upgrade git+https://github.com/huggingface/transformers + - pytest -v -s tests/models/test_initialization.py + - pytest -v -s tests/models/test_transformers.py + - pytest -v -s tests/models/multimodal/processing/ + - pytest -v -s tests/models/multimodal/test_mapping.py + - python3 examples/offline_inference/basic/chat.py + - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl + # Whisper needs spawn method to avoid deadlock + - VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper diff --git a/.buildkite/test_areas/models_distributed.yaml b/.buildkite/test_areas/models_distributed.yaml new file mode 100644 index 0000000000000..b6bfbf2ddab47 --- /dev/null +++ b/.buildkite/test_areas/models_distributed.yaml @@ -0,0 +1,22 @@ +group: Models - Distributed +depends_on: + - image-build +steps: +- label: Distributed Model Tests (2 GPUs) + timeout_in_minutes: 50 + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + source_file_dependencies: + - vllm/model_executor/model_loader/sharded_state_loader.py + - vllm/model_executor/models/ + - tests/basic_correctness/ + - tests/model_executor/model_loader/test_sharded_state_loader.py + - tests/models/ + commands: + - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)' + - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py + # Avoid importing model tests that cause CUDA reinitialization error + - pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)' + - pytest models/language -v -s -m 'distributed(num_gpus=2)' + - pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py + - VLLM_WORKER_MULTIPROC_METHOD=spawn pytest models/multimodal/generation/test_whisper.py -v -s -m 'distributed(num_gpus=2)' diff --git a/.buildkite/test_areas/models_language.yaml b/.buildkite/test_areas/models_language.yaml new file mode 100644 index 0000000000000..f70192c4ebc0a --- /dev/null +++ b/.buildkite/test_areas/models_language.yaml @@ -0,0 +1,91 @@ +group: Models - Language +depends_on: + - image-build +steps: +- label: Language Models Tests (Standard) + timeout_in_minutes: 25 + mirror_hardwares: [amdexperimental] + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/models/language + commands: + # Test standard language models, excluding a subset of slow tests + - pip freeze | grep -E 'torch' + - pytest -v -s models/language -m 'core_model and (not slow_test)' + +- label: Language Models Tests (Extra Standard) %N + timeout_in_minutes: 45 + mirror_hardwares: [amdexperimental] + torch_nightly: true + source_file_dependencies: + - vllm/model_executor/models/ + - tests/models/language/pooling/test_embedding.py + - tests/models/language/generation/test_common.py + - tests/models/language/pooling/test_classification.py + commands: + # Shard slow subset of standard language models tests. Only run when model + # source is modified, or when specified test files are modified + - pip freeze | grep -E 'torch' + - pytest -v -s models/language -m 'core_model and slow_test' --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --shard-id=$$BUILDKITE_PARALLEL_JOB + parallelism: 2 + +- label: Language Models Tests (Hybrid) %N + timeout_in_minutes: 75 + mirror_hardwares: [amdexperimental] + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/models/language/generation + commands: + # Install fast path packages for testing against transformers + # Note: also needed to run plamo2 model in vLLM + - uv pip install --system --no-build-isolation 'git+https://github.com/state-spaces/mamba@v2.2.5' + - uv pip install --system --no-build-isolation 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.2' + # Shard hybrid language model tests + - pytest -v -s models/language/generation -m hybrid_model --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --shard-id=$$BUILDKITE_PARALLEL_JOB + parallelism: 2 + +- label: Language Models Test (Extended Generation) # 80min + timeout_in_minutes: 110 + mirror_hardwares: [amdexperimental] + optional: true + source_file_dependencies: + - vllm/ + - tests/models/language/generation + commands: + # Install fast path packages for testing against transformers + # Note: also needed to run plamo2 model in vLLM + - uv pip install --system --no-build-isolation 'git+https://github.com/state-spaces/mamba@v2.2.5' + - uv pip install --system --no-build-isolation 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.2' + - pytest -v -s models/language/generation -m '(not core_model) and (not hybrid_model)' + +- label: Language Models Test (PPL) + timeout_in_minutes: 110 + mirror_hardwares: [amdexperimental] + optional: true + source_file_dependencies: + - vllm/ + - tests/models/language/generation_ppl_test + commands: + - pytest -v -s models/language/generation_ppl_test + +- label: Language Models Test (Extended Pooling) # 36min + timeout_in_minutes: 50 + mirror_hardwares: [amdexperimental] + optional: true + source_file_dependencies: + - vllm/ + - tests/models/language/pooling + commands: + - pytest -v -s models/language/pooling -m 'not core_model' + +- label: Language Models Test (MTEB) + timeout_in_minutes: 110 + mirror_hardwares: [amdexperimental] + optional: true + source_file_dependencies: + - vllm/ + - tests/models/language/pooling_mteb_test + commands: + - pytest -v -s models/language/pooling_mteb_test diff --git a/.buildkite/test_areas/models_multimodal.yaml b/.buildkite/test_areas/models_multimodal.yaml new file mode 100644 index 0000000000000..fc24068c20a46 --- /dev/null +++ b/.buildkite/test_areas/models_multimodal.yaml @@ -0,0 +1,79 @@ +group: Models - Multimodal +depends_on: + - image-build +steps: +- label: Multi-Modal Models (Standard) # 60min + timeout_in_minutes: 80 + source_file_dependencies: + - vllm/ + - tests/models/multimodal + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pip freeze | grep -E 'torch' + - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing + - cd .. && VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work + +- label: Multi-Modal Processor Test (CPU) + timeout_in_minutes: 60 + source_file_dependencies: + - vllm/ + - tests/models/multimodal + no_gpu: true + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py + +- label: Multi-Modal Processor # 44min + timeout_in_minutes: 60 + source_file_dependencies: + - vllm/ + - tests/models/multimodal + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal/processing/test_tensor_schema.py + +- label: Multi-Modal Accuracy Eval (Small Models) # 50min + timeout_in_minutes: 70 + working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + source_file_dependencies: + - vllm/multimodal/ + - vllm/inputs/ + - vllm/v1/core/ + commands: + - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-mm-small.txt --tp-size=1 + +- label: Multi-Modal Models (Extended) 1 + optional: true + source_file_dependencies: + - vllm/ + - tests/models/multimodal + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal -m 'not core_model' --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing + +- label: Multi-Modal Models (Extended) 2 + optional: true + source_file_dependencies: + - vllm/ + - tests/models/multimodal + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model' + +- label: Multi-Modal Models (Extended) 3 + optional: true + source_file_dependencies: + - vllm/ + - tests/models/multimodal + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model' + +# This test is used only in PR development phase to test individual models and should never run on main +- label: Custom Models + optional: true + commands: + - echo 'Testing custom models...' + # PR authors can temporarily add commands below to test individual models + # e.g. pytest -v -s models/encoder_decoder/vision_language/test_mllama.py + # *To avoid merge conflicts, remember to REMOVE (not just comment out) them before merging the PR* diff --git a/.buildkite/test_areas/plugins.yaml b/.buildkite/test_areas/plugins.yaml new file mode 100644 index 0000000000000..60c179aa098e1 --- /dev/null +++ b/.buildkite/test_areas/plugins.yaml @@ -0,0 +1,34 @@ +group: Plugins +depends_on: + - image-build +steps: +- label: Plugin Tests (2 GPUs) + timeout_in_minutes: 60 + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + source_file_dependencies: + - vllm/plugins/ + - tests/plugins/ + commands: + # begin platform plugin and general plugin tests, all the code in-between runs on dummy platform + - pip install -e ./plugins/vllm_add_dummy_platform + - pytest -v -s plugins_tests/test_platform_plugins.py + - pip uninstall vllm_add_dummy_platform -y + # end platform plugin tests + # begin io_processor plugins test, all the code in between uses the prithvi_io_processor plugin + - pip install -e ./plugins/prithvi_io_processor_plugin + - pytest -v -s plugins_tests/test_io_processor_plugins.py + - pip uninstall prithvi_io_processor_plugin -y + # end io_processor plugins test + # begin stat_logger plugins test + - pip install -e ./plugins/vllm_add_dummy_stat_logger + - pytest -v -s plugins_tests/test_stats_logger_plugins.py + - pip uninstall dummy_stat_logger -y + # end stat_logger plugins test + # other tests continue here: + - pytest -v -s plugins_tests/test_scheduler_plugins.py + - pip install -e ./plugins/vllm_add_dummy_model + - pytest -v -s distributed/test_distributed_oot.py + - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process + - pytest -v -s models/test_oot_registration.py # it needs a clean process + - pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins diff --git a/.buildkite/test_areas/pytorch.yaml b/.buildkite/test_areas/pytorch.yaml new file mode 100644 index 0000000000000..703c82eb1a91b --- /dev/null +++ b/.buildkite/test_areas/pytorch.yaml @@ -0,0 +1,50 @@ +group: PyTorch +depends_on: + - image-build +steps: +- label: PyTorch Compilation Unit Tests + timeout_in_minutes: 30 + source_file_dependencies: + - vllm/ + - tests/compile + commands: + # Run unit tests defined directly under compile/, + # not including subdirectories, which are usually heavier + # tests covered elsewhere. + # Use `find` to launch multiple instances of pytest so that + # they do not suffer from https://github.com/vllm-project/vllm/issues/28965 + - "find compile/ -maxdepth 1 -name 'test_*.py' -exec pytest -s -v {} \\;" + +- label: PyTorch Fullgraph Smoke Test + timeout_in_minutes: 30 + source_file_dependencies: + - vllm/ + - tests/compile + commands: + # Run smoke tests under fullgraph directory, except test_full_graph.py + # as it is a heavy test that is covered in other steps. + # Use `find` to launch multiple instances of pytest so that + # they do not suffer from https://github.com/vllm-project/vllm/issues/28965 + - "find compile/fullgraph/ -name 'test_*.py' -not -name 'test_full_graph.py' -exec pytest -s -v {} \\;" + +- label: PyTorch Fullgraph + timeout_in_minutes: 40 + source_file_dependencies: + - vllm/ + - tests/compile + commands: + # fp8 kv scales not supported on sm89, tested on Blackwell instead + - pytest -v -s compile/fullgraph/test_full_graph.py -k 'not test_fp8_kv_scale_compile' + # Limit to no custom ops to reduce running time + # Wrap with quotes to escape yaml and avoid starting -k string with a - + - "pytest -v -s compile/distributed/test_fusions_e2e.py -k 'TRITON and not +quant_fp8 and not Llama-4'" + +- label: Pytorch Nightly Dependency Override Check # 2min + # if this test fails, it means the nightly torch version is not compatible with some + # of the dependencies. Please check the error message and add the package to whitelist + # in /vllm/tools/pre_commit/generate_nightly_torch_test.py + soft_fail: true + source_file_dependencies: + - requirements/nightly_torch_test.txt + commands: + - bash standalone_tests/pytorch_nightly_dependency.sh \ No newline at end of file diff --git a/.buildkite/test_areas/quantization.yaml b/.buildkite/test_areas/quantization.yaml new file mode 100644 index 0000000000000..6e89d6af3b8d1 --- /dev/null +++ b/.buildkite/test_areas/quantization.yaml @@ -0,0 +1,46 @@ +group: Quantization +depends_on: + - image-build +steps: +- label: Quantization + timeout_in_minutes: 90 + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + - tests/quantization + commands: + # temporary install here since we need nightly, will move to requirements/test.in + # after torchao 0.12 release, and pin a working version of torchao nightly here + + # since torchao nightly is only compatible with torch nightly currently + # https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now + # we can only upgrade after this is resolved + # TODO(jerryzh168): resolve the above comment + - uv pip install --system torchao==0.13.0 --index-url https://download.pytorch.org/whl/cu129 + - uv pip install --system conch-triton-kernels + - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py + +- label: Quantized MoE Test (B200) + timeout_in_minutes: 60 + working_dir: "/vllm-workspace/" + gpu: b200 + source_file_dependencies: + - tests/quantization/test_blackwell_moe.py + - vllm/model_executor/models/deepseek_v2.py + - vllm/model_executor/models/gpt_oss.py + - vllm/model_executor/models/llama4.py + - vllm/model_executor/layers/fused_moe + - vllm/model_executor/layers/quantization/compressed_tensors + - vllm/model_executor/layers/quantization/modelopt.py + - vllm/model_executor/layers/quantization/mxfp4.py + - vllm/v1/attention/backends/flashinfer.py + commands: + - pytest -s -v tests/quantization/test_blackwell_moe.py + +- label: Quantized Models Test + timeout_in_minutes: 60 + source_file_dependencies: + - vllm/model_executor/layers/quantization + - tests/models/quantization + commands: + - pytest -v -s models/quantization diff --git a/.buildkite/test_areas/samplers.yaml b/.buildkite/test_areas/samplers.yaml new file mode 100644 index 0000000000000..ad377148fd073 --- /dev/null +++ b/.buildkite/test_areas/samplers.yaml @@ -0,0 +1,14 @@ +group: Samplers +depends_on: + - image-build +steps: +- label: Samplers Test + timeout_in_minutes: 75 + source_file_dependencies: + - vllm/model_executor/layers + - vllm/sampling_metadata.py + - tests/samplers + - tests/conftest.py + commands: + - pytest -v -s samplers + - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers diff --git a/.buildkite/test_areas/tool_use.yaml b/.buildkite/test_areas/tool_use.yaml new file mode 100644 index 0000000000000..69527a1214229 --- /dev/null +++ b/.buildkite/test_areas/tool_use.yaml @@ -0,0 +1,13 @@ +group: Tool use +depends_on: + - image-build +steps: +- label: OpenAI-Compatible Tool Use + timeout_in_minutes: 35 + mirror_hardwares: [amdexperimental] + fast_check: false + source_file_dependencies: + - vllm/ + - tests/tool_use + commands: + - pytest -v -s tool_use diff --git a/.buildkite/test_areas/weight_loading.yaml b/.buildkite/test_areas/weight_loading.yaml new file mode 100644 index 0000000000000..cfc5bb20fe7ad --- /dev/null +++ b/.buildkite/test_areas/weight_loading.yaml @@ -0,0 +1,25 @@ +group: Weight Loading +depends_on: + - image-build +steps: +- label: Weight Loading Multiple GPU # 33min + timeout_in_minutes: 45 + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + optional: true + source_file_dependencies: + - vllm/ + - tests/weight_loading + commands: + - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt + +- label: Weight Loading Multiple GPU - Large Models # optional + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + gpu: a100 + optional: true + source_file_dependencies: + - vllm/ + - tests/weight_loading + commands: + - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index ecb10d1a450f3..d6447649cd89a 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -146,10 +146,10 @@ mkdocs.yaml @hmellor /requirements/kv_connectors.txt @NickLucche # Pooling models -/examples/*/pooling/ @noooop +/examples/pooling @noooop /tests/models/*/pooling* @noooop /tests/entrypoints/pooling @noooop -/vllm/entrypoints/pooling @aarnphm @chaunceyjiang @noooop +/vllm/entrypoints/pooling @noooop /vllm/config/pooler.py @noooop /vllm/pooling_params.py @noooop /vllm/model_executor/layers/pooler.py @noooop diff --git a/.github/mergify.yml b/.github/mergify.yml index 997a40e18e588..3ad79f93bc7ad 100644 --- a/.github/mergify.yml +++ b/.github/mergify.yml @@ -14,6 +14,52 @@ pull_request_rules: comment: message: "Documentation preview: https://vllm--{{number}}.org.readthedocs.build/en/{{number}}/" +- name: comment-pre-commit-failure + description: Comment on PR when pre-commit check fails + conditions: + - status-failure=pre-commit + - -closed + - -draft + actions: + comment: + message: | + Hi @{{author}}, the pre-commit checks have failed. Please run: + + ```bash + uv pip install pre-commit + pre-commit install + pre-commit run --all-files + ``` + + Then, commit the changes and push to your branch. + + For future commits, `pre-commit` will run automatically on changed files before each commit. + + > [!TIP] + >
+ > Is mypy or markdownlint failing? + >
+ > mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally: + > + > ```bash + > # For mypy (substitute "3.10" with the failing version if needed) + > pre-commit run --hook-stage manual mypy-3.10 + > # For markdownlint + > pre-commit run --hook-stage manual markdownlint + > ``` + >
+ +- name: comment-dco-failure + description: Comment on PR when DCO check fails + conditions: + - status-failure=dco + - -closed + - -draft + actions: + comment: + message: | + Hi @{{author}}, the DCO check has failed. Please click on DCO in the Checks section for instructions on how to resolve this. + - name: label-ci-build description: Automatically apply ci/build label conditions: @@ -140,7 +186,7 @@ pull_request_rules: - files~=^tests/entrypoints/test_context.py - files~=^vllm/model_executor/models/.*gpt[-_]?oss.*\.py - files~=^vllm/model_executor/layers/.*gpt[-_]?oss.*\.py - - files~=^vllm/entrypoints/harmony_utils.py + - files~=^vllm/entrypoints/openai/parser/harmony_utils.py - files~=^vllm/entrypoints/tool_server.py - files~=^vllm/entrypoints/tool.py - files~=^vllm/entrypoints/context.py @@ -358,4 +404,4 @@ pull_request_rules: actions: label: add: - - kv-connector \ No newline at end of file + - kv-connector diff --git a/.github/workflows/cleanup_pr_body.yml b/.github/workflows/cleanup_pr_body.yml index 861290ea43c87..df8910837715d 100644 --- a/.github/workflows/cleanup_pr_body.yml +++ b/.github/workflows/cleanup_pr_body.yml @@ -13,10 +13,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - name: Set up Python - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 + uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 with: python-version: '3.12' diff --git a/.github/workflows/macos-smoke-test.yml b/.github/workflows/macos-smoke-test.yml index 3a12c4b3a8300..e80a5c0cc80f9 100644 --- a/.github/workflows/macos-smoke-test.yml +++ b/.github/workflows/macos-smoke-test.yml @@ -12,7 +12,7 @@ jobs: timeout-minutes: 30 steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@v6.0.1 - uses: astral-sh/setup-uv@v7 with: diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index d5e70f30ef638..1041653c2f57e 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -16,8 +16,8 @@ jobs: pre-commit: runs-on: ubuntu-latest steps: - - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 - - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 with: python-version: "3.12" - run: echo "::add-matcher::.github/workflows/matchers/actionlint.json" diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index dca3089f496c9..44bf71db5e9de 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -7,13 +7,15 @@ on: jobs: close-issues-and-pull-requests: + # Prevents triggering on forks or other repos + if: github.repository == 'vllm-project/vllm' permissions: issues: write pull-requests: write actions: write runs-on: ubuntu-latest steps: - - uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # v10.1.0 + - uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1 with: # Increasing this value ensures that changes to this workflow # propagate to all issues and PRs in days rather than months diff --git a/CMakeLists.txt b/CMakeLists.txt index e09972fe71995..cd52df86e0346 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -384,7 +384,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH_AND_ARCH}) execute_process( COMMAND ${CMAKE_COMMAND} -E env - PYTHONPATH=$PYTHONPATH + PYTHONPATH=$ENV{PYTHONPATH} ${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR} RESULT_VARIABLE marlin_generation_result OUTPUT_VARIABLE marlin_generation_result @@ -822,7 +822,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") OR NOT $CACHE{MACHETE_GEN_SCRIPT_HASH} STREQUAL ${MACHETE_GEN_SCRIPT_HASH}) execute_process( COMMAND ${CMAKE_COMMAND} -E env - PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH + PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$ENV{PYTHONPATH} ${Python_EXECUTABLE} ${MACHETE_GEN_SCRIPT} RESULT_VARIABLE machete_generation_result OUTPUT_VARIABLE machete_generation_output @@ -874,7 +874,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS) set(SRCS - "csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu") + "csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu" + "csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu" + "csrc/quantization/cutlass_w4a8/w4a8_utils.cu" + ) set_gencode_flags_for_srcs( SRCS "${SRCS}" @@ -944,7 +947,6 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) set(VLLM_MOE_EXT_SRC "csrc/moe/torch_bindings.cpp" "csrc/moe/moe_align_sum_kernels.cu" - "csrc/moe/moe_lora_align_sum_kernels.cu" "csrc/moe/topk_softmax_kernels.cu") if(VLLM_GPU_LANG STREQUAL "CUDA") @@ -1002,7 +1004,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH_AND_ARCH}) execute_process( COMMAND ${CMAKE_COMMAND} -E env - PYTHONPATH=$PYTHONPATH + PYTHONPATH=$ENV{PYTHONPATH} ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} ${CUDA_ARCHS_STR} RESULT_VARIABLE moe_marlin_generation_result OUTPUT_VARIABLE moe_marlin_generation_output diff --git a/README.md b/README.md index abbb63158f166..26222b815370d 100644 --- a/README.md +++ b/README.md @@ -137,16 +137,19 @@ Compute Resources: - Alibaba Cloud - AMD - Anyscale +- Arm - AWS - Crusoe Cloud - Databricks - DeepInfra - Google Cloud +- IBM - Intel - Lambda Lab - Nebius - Novita AI - NVIDIA +- Red Hat - Replicate - Roblox - RunPod diff --git a/benchmarks/auto_tune/README.md b/benchmarks/auto_tune/README.md index d1bdb4c43f10b..9a9600e08dafe 100644 --- a/benchmarks/auto_tune/README.md +++ b/benchmarks/auto_tune/README.md @@ -83,7 +83,7 @@ MIN_CACHE_HIT_PCT=0 MAX_LATENCY_ALLOWED_MS=100000000000 # A very large number ``` -#### 2. Maximize Throughput with a Latency Requirement +### 2. Maximize Throughput with a Latency Requirement - **Goal**: Find the best server parameters when P99 end-to-end latency must be below 500ms. - **Configuration**: @@ -96,7 +96,7 @@ MIN_CACHE_HIT_PCT=0 MAX_LATENCY_ALLOWED_MS=500 ``` -#### 3. Maximize Throughput with Prefix Caching and Latency Requirements +### 3. Maximize Throughput with Prefix Caching and Latency Requirements - **Goal**: Find the best server parameters assuming a 60% prefix cache hit rate and a latency requirement of 500ms. - **Configuration**: diff --git a/benchmarks/auto_tune/auto_tune.sh b/benchmarks/auto_tune/auto_tune.sh index 56b721cbb4021..a245e2022e605 100644 --- a/benchmarks/auto_tune/auto_tune.sh +++ b/benchmarks/auto_tune/auto_tune.sh @@ -18,6 +18,11 @@ MIN_CACHE_HIT_PCT=${MIN_CACHE_HIT_PCT:-0} MAX_LATENCY_ALLOWED_MS=${MAX_LATENCY_ALLOWED_MS:-100000000000} NUM_SEQS_LIST=${NUM_SEQS_LIST:-"128 256"} NUM_BATCHED_TOKENS_LIST=${NUM_BATCHED_TOKENS_LIST:-"512 1024 2048 4096"} +HOSTNAME=$(hostname) +if [[ -z "$HOSTNAME" ]]; then + echo "Error: Failed to determine hostname." >&2 + exit 1 +fi LOG_FOLDER="$BASE/auto-benchmark/$TAG" RESULT="$LOG_FOLDER/result.txt" @@ -82,6 +87,7 @@ start_server() { "$MODEL" "--disable-log-requests" "--port" "8004" + "--host" "$HOSTNAME" "--gpu-memory-utilization" "$gpu_memory_utilization" "--max-num-seqs" "$max_num_seqs" "--max-num-batched-tokens" "$max_num_batched_tokens" @@ -96,8 +102,9 @@ start_server() { # This correctly passes each element as a separate argument. if [[ -n "$profile_dir" ]]; then # Start server with profiling enabled - VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir \ - vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 & + local profile_config_json="{\"profiler\": \"torch\", \"torch_profiler_dir\": \"$profile_dir\"}" + VLLM_SERVER_DEV_MODE=1 \ + vllm serve --profiler-config "$profile_config_json" "${common_args_array[@]}" > "$vllm_log" 2>&1 & else # Start server without profiling VLLM_SERVER_DEV_MODE=1 \ @@ -112,7 +119,7 @@ start_server() { # since that we should always have permission to send signal to the server process. kill -0 $server_pid 2> /dev/null || break - RESPONSE=$(curl -s -X GET "http://0.0.0.0:8004/health" -w "%{http_code}" -o /dev/stdout) + RESPONSE=$(curl -s -X GET "http://${HOSTNAME}:8004/health" -w "%{http_code}" -o /dev/stdout) STATUS_CODE=$(echo "$RESPONSE" | tail -n 1) if [[ "$STATUS_CODE" -eq 200 ]]; then server_started=1 @@ -172,6 +179,7 @@ run_benchmark() { --goodput e2el:$MAX_LATENCY_ALLOWED_MS \ --num-prompts 1000 \ --random-prefix-len $prefix_len \ + --host "$HOSTNAME" \ --port 8004 &> "$bm_log" throughput=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}') @@ -187,7 +195,7 @@ run_benchmark() { request_rate=$((${throughput%.*} + 1)) while ((request_rate > 0)); do # clear prefix cache - curl -X POST http://0.0.0.0:8004/reset_prefix_cache + curl -X POST http://${HOSTNAME}:8004/reset_prefix_cache sleep 5 bm_log="$LOG_FOLDER/bm_log_${max_num_seqs}_${max_num_batched_tokens}_requestrate_${request_rate}.txt" vllm bench serve \ @@ -203,6 +211,7 @@ run_benchmark() { --goodput e2el:$MAX_LATENCY_ALLOWED_MS \ --num-prompts 100 \ --random-prefix-len $prefix_len \ + --host "$HOSTNAME" \ --port 8004 &> "$bm_log" throughput=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}') @@ -303,6 +312,7 @@ if (( $(echo "$best_throughput > 0" | bc -l) )); then --goodput e2el:$MAX_LATENCY_ALLOWED_MS \ --num-prompts 100 \ --random-prefix-len $prefix_len \ + --host "$HOSTNAME" \ --port 8004 \ --profile &> "$bm_log" else diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index d69d74ca61f54..831b76b66e096 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -620,7 +620,7 @@ def get_tokenizer( kwargs["use_fast"] = False if tokenizer_mode == "mistral": try: - from vllm.tokenizers import MistralTokenizer + from vllm.tokenizers.mistral import MistralTokenizer except ImportError as e: raise ImportError( "MistralTokenizer requires vllm package.\n" diff --git a/benchmarks/benchmark_hash.py b/benchmarks/benchmark_hash.py new file mode 100644 index 0000000000000..08cdc012d6527 --- /dev/null +++ b/benchmarks/benchmark_hash.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Micro benchmark comparing built-in hash(), SHA-256, and xxHash. + +This focuses on a single test payload shaped like the prefix-cache hash input: + (32-byte bytes object, 32-int tuple) + +Usage: + python benchmarks/hash_micro_benchmark.py --iterations 20000 +""" + +from __future__ import annotations + +import argparse +import random +import statistics +import time +from collections.abc import Callable, Iterable + +from vllm.utils.hashing import sha256, xxhash + + +def _generate_test_data(seed: int) -> tuple[bytes, tuple[int, ...]]: + """Generate a deterministic test payload.""" + random.seed(seed) + bytes_data = bytes(random.getrandbits(8) for _ in range(32)) + int_tuple = tuple(random.randint(1, 1_000_000) for _ in range(32)) + return (bytes_data, int_tuple) + + +def _benchmark_func(func: Callable[[tuple], object], data: tuple, iterations: int): + """Return (avg_seconds, std_seconds) for hashing `data` `iterations` times.""" + times: list[float] = [] + + # Warm-up to avoid first-run noise. + for _ in range(200): + func(data) + + for _ in range(iterations): + start = time.perf_counter() + func(data) + end = time.perf_counter() + times.append(end - start) + + avg = statistics.mean(times) + std = statistics.stdev(times) if len(times) > 1 else 0.0 + return avg, std + + +def _run_benchmarks( + benchmarks: Iterable[tuple[str, Callable[[tuple], object]]], + data: tuple, + iterations: int, +): + """Yield (name, avg, std) for each benchmark, skipping unavailable ones.""" + for name, func in benchmarks: + try: + avg, std = _benchmark_func(func, data, iterations) + except ModuleNotFoundError as exc: + print(f"Skipping {name}: {exc}") + continue + yield name, avg, std + + +def builtin_hash(data: tuple) -> int: + """Wrapper for Python's built-in hash().""" + return hash(data) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--iterations", + type=int, + default=10_000, + help="Number of measured iterations per hash function.", + ) + parser.add_argument( + "--seed", type=int, default=42, help="Random seed for test payload." + ) + args = parser.parse_args() + + data = _generate_test_data(args.seed) + benchmarks = ( + ("SHA256 (pickle)", sha256), + ("xxHash (pickle)", xxhash), + ("built-in hash()", builtin_hash), + ) + + print("=" * 60) + print("HASH FUNCTION MICRO BENCHMARK") + print("=" * 60) + print("Test data: (32-byte bytes object, 32-int tuple)") + print(f"Iterations: {args.iterations:,}") + print("=" * 60) + + results = list(_run_benchmarks(benchmarks, data, args.iterations)) + builtin_entry = next((r for r in results if r[0] == "built-in hash()"), None) + + print("\nResults:") + for name, avg, std in results: + print(f" {name:16s}: {avg * 1e6:8.2f} ± {std * 1e6:6.2f} μs") + + if builtin_entry: + _, builtin_avg, _ = builtin_entry + print("\n" + "=" * 60) + print("SUMMARY (relative to built-in hash())") + print("=" * 60) + for name, avg, _ in results: + if name == "built-in hash()": + continue + speed_ratio = avg / builtin_avg + print(f"• {name} is {speed_ratio:.1f}x slower than built-in hash()") + else: + print("\nBuilt-in hash() result missing; cannot compute speed ratios.") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/benchmark_ngram_proposer.py b/benchmarks/benchmark_ngram_proposer.py index dedb564fffac8..b5373d383b548 100644 --- a/benchmarks/benchmark_ngram_proposer.py +++ b/benchmarks/benchmark_ngram_proposer.py @@ -32,12 +32,11 @@ def benchmark_propose(args): model_config = ModelConfig( model="facebook/opt-125m", - task="generate", max_model_len=args.num_token + args.num_spec_token, tokenizer="facebook/opt-125m", tokenizer_mode="auto", dtype="auto", - seed=None, + seed=0, trust_remote_code=False, ) proposer = NgramProposer( @@ -108,7 +107,10 @@ def benchmark_batched_propose(args): device_config=DeviceConfig(device=current_platform.device_type), parallel_config=ParallelConfig(), load_config=LoadConfig(), - scheduler_config=SchedulerConfig(), + scheduler_config=SchedulerConfig( + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, + ), ) # monkey patch vllm.v1.worker.gpu_model_runner.get_pp_group diff --git a/benchmarks/benchmark_prefix_block_hash.py b/benchmarks/benchmark_prefix_block_hash.py new file mode 100644 index 0000000000000..8bcd8af0d3102 --- /dev/null +++ b/benchmarks/benchmark_prefix_block_hash.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Simple benchmark to compare prefix-cache block hashing algorithms. + +Example: + python benchmark_prefix_block_hash.py --num-blocks 20000 --block-size 32 +""" + +from __future__ import annotations + +import argparse +import random +import statistics +import sys +import time +from collections.abc import Callable, Iterable, Sequence + +from vllm.utils.hashing import get_hash_fn_by_name +from vllm.v1.core.kv_cache_utils import BlockHash, hash_block_tokens, init_none_hash + +SUPPORTED_ALGOS = ("sha256", "sha256_cbor", "xxhash", "xxhash_cbor") + + +def _generate_blocks( + num_blocks: int, block_size: int, vocab_size: int, seed: int +) -> list[list[int]]: + rng = random.Random(seed) + return [ + [rng.randrange(vocab_size) for _ in range(block_size)] + for _ in range(num_blocks) + ] + + +def _hash_all_blocks( + hash_fn: Callable[[object], bytes], + blocks: Iterable[Sequence[int]], +) -> float: + parent_hash: BlockHash | None = None + start = time.perf_counter() + for block in blocks: + parent_hash = hash_block_tokens(hash_fn, parent_hash, block, extra_keys=None) + end = time.perf_counter() + return end - start + + +def _benchmark( + hash_algo: str, + blocks: list[list[int]], + trials: int, +) -> tuple[float, float, float] | None: + try: + hash_fn = get_hash_fn_by_name(hash_algo) + init_none_hash(hash_fn) + timings = [_hash_all_blocks(hash_fn, blocks) for _ in range(trials)] + except ModuleNotFoundError as exc: + print(f"Skipping {hash_algo}: {exc}", file=sys.stderr) + return None + + avg = statistics.mean(timings) + best = min(timings) + # throughput: tokens / second + tokens_hashed = len(blocks) * len(blocks[0]) + throughput = tokens_hashed / best + return avg, best, throughput + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--num-blocks", type=int, default=10000, help="Block count.") + parser.add_argument("--block-size", type=int, default=32, help="Tokens per block.") + parser.add_argument( + "--vocab-size", type=int, default=32000, help="Token id range [0, vocab_size)." + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed.") + parser.add_argument( + "--trials", type=int, default=5, help="Number of timed trials per algorithm." + ) + parser.add_argument( + "--algorithms", + nargs="+", + default=SUPPORTED_ALGOS, + choices=SUPPORTED_ALGOS, + help="Hash algorithms to benchmark.", + ) + args = parser.parse_args() + + blocks = _generate_blocks( + args.num_blocks, args.block_size, args.vocab_size, args.seed + ) + print( + f"Benchmarking {len(args.algorithms)} algorithms on " + f"{args.num_blocks} blocks (block size={args.block_size})." + ) + + for algo in args.algorithms: + result = _benchmark(algo, blocks, args.trials) + if result is None: + continue + + avg, best, throughput = result + print( + f"{algo:14s} avg: {avg:.6f}s best: {best:.6f}s " + f"throughput: {throughput / 1e6:.2f}M tokens/s" + ) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 28fc383a318dd..e6391134ff932 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -40,7 +40,7 @@ from vllm.engine.arg_utils import EngineArgs from vllm.utils.argparse_utils import FlexibleArgumentParser try: - from vllm.transformers_utils.tokenizer import get_tokenizer + from vllm.tokenizers import get_tokenizer except ImportError: from backend_request_func import get_tokenizer diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py index 55001cf3722a0..33aca831883aa 100644 --- a/benchmarks/benchmark_serving_structured_output.py +++ b/benchmarks/benchmark_serving_structured_output.py @@ -46,7 +46,7 @@ from tqdm.asyncio import tqdm from transformers import PreTrainedTokenizerBase try: - from vllm.transformers_utils.tokenizer import get_tokenizer + from vllm.tokenizers import get_tokenizer except ImportError: from backend_request_func import get_tokenizer @@ -574,7 +574,7 @@ async def benchmark( ) print( "{:<40} {:<10.2f}".format( - "Total Token throughput (tok/s):", metrics.total_token_throughput + "Total token throughput (tok/s):", metrics.total_token_throughput ) ) @@ -963,8 +963,7 @@ def create_argument_parser(): parser.add_argument( "--profile", action="store_true", - help="Use Torch Profiler. The endpoint must be launched with " - "VLLM_TORCH_PROFILER_DIR to enable profiler.", + help="Use vLLM Profiling. --profiler-config must be provided on the server.", ) parser.add_argument( "--result-dir", diff --git a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py index d809bf1db8cbc..fb3329975cee3 100644 --- a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py +++ b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py @@ -14,6 +14,9 @@ from tqdm import tqdm import vllm._custom_ops as ops from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8, +) @dataclass @@ -22,6 +25,7 @@ class bench_params_t: hidden_size: int add_residual: bool dtype: torch.dtype + group_size: list[int] def description(self): return ( @@ -29,6 +33,7 @@ class bench_params_t: f"x D {self.hidden_size} " f"x R {self.add_residual} " f"x DT {self.dtype}" + f"x GS {self.group_size}" ) @@ -38,10 +43,11 @@ def get_bench_params() -> list[bench_params_t]: HIDDEN_SIZES = list(range(1024, 8129, 1024)) ADD_RESIDUAL = [True, False] DTYPES = [torch.bfloat16, torch.float] + GROUP_SIZES = [[1, 64], [1, 128]] - combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES) + combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES, GROUP_SIZES) bench_params = list( - map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations) + map(lambda x: bench_params_t(x[0], x[1], x[2], x[3], x[4]), combinations) ) return bench_params @@ -52,6 +58,7 @@ def unfused_int8_impl( x: torch.Tensor, residual: torch.Tensor | None, quant_dtype: torch.dtype, + group_size: list[int], ): # Norm torch_out = None @@ -69,6 +76,7 @@ def unfused_fp8_impl( x: torch.Tensor, residual: torch.Tensor | None, quant_dtype: torch.dtype, + group_size: list[int], ): # Norm torch_out = None @@ -81,23 +89,63 @@ def unfused_fp8_impl( torch_out, _ = ops.scaled_fp8_quant(torch_out) +def unfused_groupwise_fp8_impl( + rms_norm_layer: RMSNorm, + x: torch.Tensor, + residual: torch.Tensor | None, + quant_dtype: torch.dtype, + group_size: list[int], +): + # Norm + torch_out = None + if residual is None: + torch_out = rms_norm_layer.forward_cuda(x, residual) + else: + torch_out, _ = rms_norm_layer.forward_cuda(x, residual) + + # Quant + torch_out, _ = per_token_group_quant_fp8( + torch_out, group_size=group_size[1], use_ue8m0=False + ) + + def fused_impl( rms_norm_layer: RMSNorm, # this stores the weights x: torch.Tensor, residual: torch.Tensor | None, quant_dtype: torch.dtype, + group_size: list[int], ): out, _ = ops.rms_norm_dynamic_per_token_quant( x, rms_norm_layer.weight, 1e-6, quant_dtype, residual=residual ) +def fused_groupwise_impl( + rms_norm_layer: RMSNorm, # this stores the weights + x: torch.Tensor, + residual: torch.Tensor | None, + quant_dtype: torch.dtype, + group_size: list[int], +): + out, _ = ops.rms_norm_per_block_quant( + x, + rms_norm_layer.weight, + 1e-6, + quant_dtype, + group_size, + residual=residual, + is_scale_transposed=True, + ) + + # Bench functions def bench_fn( rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor, quant_dtype: torch.dtype, + group_size: list[int], label: str, sub_label: str, fn: Callable, @@ -110,10 +158,11 @@ def bench_fn( "x": x, "residual": residual, "quant_dtype": quant_dtype, + "group_size": group_size, "fn": fn, } return TBenchmark.Timer( - stmt="fn(rms_norm_layer, x, residual, quant_dtype)", + stmt="fn(rms_norm_layer, x, residual, quant_dtype, group_size)", globals=globals, label=label, sub_label=sub_label, @@ -147,6 +196,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu x, residual, torch.int8, + params.group_size, label, sub_label, unfused_int8_impl, @@ -161,6 +211,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu x, residual, torch.float8_e4m3fn, + params.group_size, label, sub_label, unfused_fp8_impl, @@ -175,6 +226,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu x, residual, torch.int8, + params.group_size, label, sub_label, fused_impl, @@ -189,6 +241,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu x, residual, torch.float8_e4m3fn, + params.group_size, label, sub_label, fused_impl, @@ -196,6 +249,36 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu ) ) + # unfused groupwise fp8 impl. + timers.append( + bench_fn( + layer, + x, + residual, + torch.float8_e4m3fn, + params.group_size, + label, + sub_label, + unfused_groupwise_fp8_impl, + "unfused_groupwise_fp8_impl", + ) + ) + + # fused groupwise fp8 impl. + timers.append( + bench_fn( + layer, + x, + residual, + torch.float8_e4m3fn, + params.group_size, + label, + sub_label, + fused_groupwise_impl, + "fused_groupwise_fp8_impl", + ) + ) + print_timers(timers) return timers diff --git a/benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py b/benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py new file mode 100644 index 0000000000000..04921dafbdbea --- /dev/null +++ b/benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py @@ -0,0 +1,244 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from enum import Enum +from itertools import product +from typing import Any + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement + +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _per_token_group_quant_fp8_colmajor, + silu_mul_per_token_group_quant_fp8_colmajor, +) +from vllm.triton_utils import triton +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used + +from .utils import ArgPool, Bench, CudaGraphBenchParams + +GROUP_SIZE = 128 +FLOAT8_T = torch.float8_e4m3fn + + +def print_timers(timers: list[TMeasurement], cuda_graph_nops: int): + print( + f"Note : The timings reported above is for {cuda_graph_nops} " + "consecutive invocations of the benchmarking functions. " + f"Please divide by {cuda_graph_nops} for single invocation " + "timings." + ) + compare = TBenchmark.Compare(timers) + compare.print() + + +class ImplType(Enum): + SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR = 1 + REFERENCE = 2 + + def get_impl(self): + if self == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR: + return silu_mul_per_token_group_quant_fp8_colmajor + elif self == ImplType.REFERENCE: + return reference + raise ValueError(f"Unrecognized ImplType {self}") + + +@dataclass +class BenchmarkTensors: + input: torch.Tensor + output: torch.Tensor + + # Reference act output tensor + ref_act_out: torch.Tensor + ref_quant_out: torch.Tensor + + @staticmethod + def make(T: int, N: int) -> "BenchmarkTensors": + assert T % GROUP_SIZE == 0 + assert N % (GROUP_SIZE * 2) == 0 + + input = torch.rand((T, N), dtype=torch.bfloat16, device="cuda") + + # silu_mul_per_token_group_quant_fp8_colmajor output. + output = torch.rand((T, N // 2), dtype=torch.bfloat16, device="cuda").to( + FLOAT8_T + ) + + # reference output. + ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device="cuda") + ref_quant_out = torch.empty( + (T, N // 2), dtype=torch.bfloat16, device="cuda" + ).to(FLOAT8_T) + + return BenchmarkTensors( + input=input, + output=output, + ref_act_out=ref_act_out, + ref_quant_out=ref_quant_out, + ) + + @property + def T(self): + return self.input.size(0) + + @property + def N(self): + return self.input.size(1) + + def make_impl_kwargs(self, impl_type: ImplType) -> dict[str, Any]: + if impl_type == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR: + return { + "input": self.input, + "output": self.output, + "use_ue8m0": is_deep_gemm_e8m0_used(), + } + elif impl_type == ImplType.REFERENCE: + return { + "input": self.input, + "act_out": self.ref_act_out, + "quant_out": self.ref_quant_out, + "use_ue8m0": is_deep_gemm_e8m0_used(), + } + raise ValueError(f"Unrecognized impl_type {impl_type}") + + +def reference_quant(x: torch.Tensor, quant_out: torch.Tensor, use_ue8m0: bool): + """ + Reference triton quant kernel from, + vllm.model_executor.layers.quantization.utils.fp8_utils + """ + assert quant_out.size() == x.size() + # Allocate the scale tensor column-major format. + shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1] + x_q = quant_out + x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) + + M = x.numel() // GROUP_SIZE + N = GROUP_SIZE + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + + finfo = torch.finfo(FLOAT8_T) + fp8_min = finfo.min + fp8_max = finfo.max + + _per_token_group_quant_fp8_colmajor[(M,)]( + x, + x_q, + x_s, + GROUP_SIZE, + x.shape[1], + x.stride(0), + x_s.stride(1), + eps=1e-10, + fp8_min=fp8_min, + fp8_max=fp8_max, + use_ue8m0=use_ue8m0, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + return x_q, x_s + + +def reference( + input: torch.Tensor, + act_out: torch.Tensor, + quant_out: torch.Tensor, + use_ue8m0: bool, +) -> tuple[torch.Tensor, torch.Tensor]: + torch.ops._C.silu_and_mul(act_out, input) + return reference_quant(act_out, quant_out, use_ue8m0) + + +def bench_impl( + bench_tensors: list[BenchmarkTensors], impl_type: ImplType +) -> TMeasurement: + T = bench_tensors[0].T + N = bench_tensors[0].N + + arg_pool_size = len(bench_tensors) + kwargs_list = [bt.make_impl_kwargs(impl_type) for bt in bench_tensors] + + # warmup + for kwargs in kwargs_list: + impl_type.get_impl()(**kwargs) + torch.cuda.synchronize() + + # Merge into a single kwargs and qualify arguments as ArgPool + kwargs = {k: ArgPool([]) for k in kwargs_list[0]} + for _kwargs in kwargs_list: + for k, v in _kwargs.items(): + kwargs[k].values.append(v) + + cuda_graph_params = None + cuda_graph_params = CudaGraphBenchParams(arg_pool_size) + timer = None + with Bench( + cuda_graph_params, + "silu-mul-quant", + f"num_tokens={T}, N={N}", + impl_type.name, + impl_type.get_impl(), + **kwargs, + ) as bench: + timer = bench.run() + return timer + + +def test_correctness(T: int, N: int): + print(f"Testing num_tokens={T}, N={N} ...") + + bench_tensor = BenchmarkTensors.make(T, N) + + def output_from_impl(impl: ImplType) -> tuple[torch.Tensor, torch.Tensor]: + return impl.get_impl()(**bench_tensor.make_impl_kwargs(impl)) + + # reference output + ref_out_q, ref_out_s = output_from_impl(ImplType.REFERENCE) + + # test ouptut + out_q, out_s = output_from_impl( + ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR + ) + + torch.testing.assert_close(ref_out_q.to(torch.float32), out_q.to(torch.float32)) + torch.testing.assert_close(ref_out_s, out_s) + + +def run(Ts: list[int], Ns: list[int], arg_pool_size: int) -> list[TMeasurement]: + timers = [] + for N, T in product(Ns, Ts): + test_correctness(T, N) + + bench_tensors: list[BenchmarkTensors] = [ + BenchmarkTensors.make(T, N) for _ in range(arg_pool_size) + ] + + silu_mul_quant_timer = bench_impl( + bench_tensors, ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR + ) + timers.append(silu_mul_quant_timer) + reference_timer = bench_impl(bench_tensors, ImplType.REFERENCE) + timers.append(reference_timer) + + print_timers( + [silu_mul_quant_timer, reference_timer], cuda_graph_nops=arg_pool_size + ) + + print_timers(timers, cuda_graph_nops=arg_pool_size) + + return timers + + +if __name__ == "__main__": + T = [128 * i for i in range(1, 16)] + [2048 * i for i in range(1, 65)] + N = [2048, 4096, 8192] + + print(f"T = {T}, N = {N}") + run(T, N, arg_pool_size=8) diff --git a/benchmarks/kernels/benchmark_mla_k_concat.py b/benchmarks/kernels/benchmark_mla_k_concat.py new file mode 100644 index 0000000000000..fb3b6c8f12003 --- /dev/null +++ b/benchmarks/kernels/benchmark_mla_k_concat.py @@ -0,0 +1,150 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Benchmark script comparing torch.cat vs direct copy for k_nope/k_pe concatenation +in MLA (Multi-head Latent Attention) prefill. + +This validates that the optimization from commit 8d4142bd is beneficial across +various batch sizes, not just the originally tested batch size of 32768. +""" + +import time +from collections.abc import Callable + +import torch + +# DeepSeek-V3 MLA dimensions +NUM_HEADS = 128 +QK_NOPE_HEAD_DIM = 128 +PE_DIM = 64 + + +def cat_method(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor: + """Original torch.cat approach with expand.""" + return torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + +def direct_copy_method(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor: + """Optimized direct copy approach (avoids expand + cat overhead).""" + k = torch.empty( + (*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]), + dtype=k_nope.dtype, + device=k_nope.device, + ) + k[..., : k_nope.shape[-1]] = k_nope + k[..., k_nope.shape[-1] :] = k_pe + return k + + +def benchmark_method( + method: Callable, + k_nope: torch.Tensor, + k_pe: torch.Tensor, + num_warmup: int = 10, + num_iters: int = 100, +) -> float: + """Benchmark a concatenation method and return mean latency in ms.""" + # Warmup + for _ in range(num_warmup): + _ = method(k_nope, k_pe) + torch.cuda.synchronize() + + # Benchmark + start = time.perf_counter() + for _ in range(num_iters): + _ = method(k_nope, k_pe) + torch.cuda.synchronize() + end = time.perf_counter() + + return (end - start) / num_iters * 1000 # Convert to ms + + +@torch.inference_mode() +def run_benchmark(dtype: torch.dtype, dtype_name: str): + """Run benchmark for a specific dtype.""" + torch.set_default_device("cuda") + + # Batch sizes to test (powers of 2 from 32 to 65536) + batch_sizes = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536] + + print("=" * 80) + print("Benchmark: torch.cat vs direct copy for MLA k_nope/k_pe concatenation") + print("=" * 80) + print( + f"Tensor shapes: k_nope=[B, {NUM_HEADS}, {QK_NOPE_HEAD_DIM}], " + f"k_pe=[B, 1, {PE_DIM}]" + ) + print(f"dtype: {dtype_name}") + print() + print( + f"{'Batch Size':>12} | {'cat (ms)':>10} | {'direct (ms)':>12} | " + f"{'Speedup':>8} | {'Reduction':>10}" + ) + print("-" * 70) + + results = [] + for batch_size in batch_sizes: + # Create input tensors (generate in float32 then convert for FP8 compatibility) + k_nope = torch.randn( + batch_size, NUM_HEADS, QK_NOPE_HEAD_DIM, dtype=torch.float32, device="cuda" + ).to(dtype) + k_pe = torch.randn( + batch_size, 1, PE_DIM, dtype=torch.float32, device="cuda" + ).to(dtype) + + # Benchmark both methods + cat_time = benchmark_method(cat_method, k_nope, k_pe) + direct_time = benchmark_method(direct_copy_method, k_nope, k_pe) + + speedup = cat_time / direct_time + reduction = (1 - direct_time / cat_time) * 100 + + results.append((batch_size, cat_time, direct_time, speedup, reduction)) + + print( + f"{batch_size:>12} | {cat_time:>10.3f} | {direct_time:>12.3f} | " + f"{speedup:>7.2f}x | {reduction:>9.1f}%" + ) + + print("=" * 80) + + # Summary statistics + speedups = [r[3] for r in results] + print("\nSpeedup summary:") + print(f" Min: {min(speedups):.2f}x") + print(f" Max: {max(speedups):.2f}x") + print(f" Mean: {sum(speedups) / len(speedups):.2f}x") + + # Find crossover point + crossover_batch = None + for batch_size, _, _, speedup, _ in results: + if speedup >= 1.0: + crossover_batch = batch_size + break + + print("\nConclusion:") + if crossover_batch: + print(f" - Direct copy becomes beneficial at batch size >= {crossover_batch}") + # Filter for large batches (>= 512 which is typical for prefill) + large_batch_speedups = [r[3] for r in results if r[0] >= 512] + if large_batch_speedups: + avg_large = sum(large_batch_speedups) / len(large_batch_speedups) + print(f" - For batch sizes >= 512: avg speedup = {avg_large:.2f}x") + print(" - MLA prefill typically uses large batches, so optimization is effective") + + return results + + +@torch.inference_mode() +def main(): + # Test bfloat16 + print("\n") + run_benchmark(torch.bfloat16, "bfloat16") + + # Test float8_e4m3fn + print("\n") + run_benchmark(torch.float8_e4m3fn, "float8_e4m3fn") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/kernels/benchmark_moe_align_block_size.py b/benchmarks/kernels/benchmark_moe_align_block_size.py index f540cff6261a8..5f9a131f79b0e 100644 --- a/benchmarks/kernels/benchmark_moe_align_block_size.py +++ b/benchmarks/kernels/benchmark_moe_align_block_size.py @@ -24,12 +24,15 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor: num_tokens_range = [1, 16, 256, 4096] num_experts_range = [16, 64, 224, 256, 280, 512] topk_range = [1, 2, 8] -configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range)) +ep_size_range = [1, 8] +configs = list( + itertools.product(num_tokens_range, num_experts_range, topk_range, ep_size_range) +) @triton.testing.perf_report( triton.testing.Benchmark( - x_names=["num_tokens", "num_experts", "topk"], + x_names=["num_tokens", "num_experts", "topk", "ep_size"], x_vals=configs, line_arg="provider", line_vals=["vllm"], @@ -38,16 +41,26 @@ configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range args={}, ) ) -def benchmark(num_tokens, num_experts, topk, provider): +def benchmark(num_tokens, num_experts, topk, ep_size, provider): """Benchmark function for Triton.""" block_size = 256 + torch.cuda.manual_seed_all(0) topk_ids = get_topk_ids(num_tokens, num_experts, topk) + e_map = None + if ep_size != 1: + local_e = num_experts // ep_size + e_ids = torch.randperm(num_experts, device="cuda", dtype=torch.int32)[:local_e] + e_map = torch.full((num_experts,), -1, device="cuda", dtype=torch.int32) + e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32) + quantiles = [0.5, 0.2, 0.8] if provider == "vllm": ms, min_ms, max_ms = triton.testing.do_bench( - lambda: moe_align_block_size(topk_ids, block_size, num_experts), + lambda: moe_align_block_size( + topk_ids, block_size, num_experts, e_map, ignore_invalid_experts=True + ), quantiles=quantiles, ) diff --git a/benchmarks/kernels/benchmark_mrope.py b/benchmarks/kernels/benchmark_mrope.py index 83bd91917508f..09de5fa822f86 100644 --- a/benchmarks/kernels/benchmark_mrope.py +++ b/benchmarks/kernels/benchmark_mrope.py @@ -99,7 +99,6 @@ def benchmark_mrope( # the parameters to compute the q k v size based on tp_size mrope_helper_class = get_rope( head_size=head_dim, - rotary_dim=head_dim, max_position=max_position, is_neox_style=is_neox_style, rope_parameters=rope_parameters, diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index 074b7a440b612..7a1bc050bb33f 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -32,8 +32,8 @@ def get_benchmark(head_size, rotary_dim, is_neox_style, device): def benchmark(batch_size, seq_len, num_heads, provider): dtype = torch.bfloat16 max_position = 8192 - base = 10000 - rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style) + rope_parameters = {"partial_rotary_factor": rotary_dim / head_size} + rope = get_rope(head_size, max_position, is_neox_style, rope_parameters) rope = rope.to(dtype=dtype, device=device) cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device) diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index fbbb03c5ed465..85b286f8d8d0a 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -251,17 +251,6 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON endif() # Build ACL with CMake - set(ARM_COMPUTE_BUILD_SHARED_LIB "OFF") - set(CMAKE_BUILD_TYPE "Release") - set(ARM_COMPUTE_ARCH "armv8.2-a") - set(ARM_COMPUTE_ENABLE_ASSERTS "OFF") - set(ARM_COMPUTE_ENABLE_CPPTHREADS "OFF") - set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER") - set(ARM_COMPUTE_ENABLE_OPENMP "ON") - set(ARM_COMPUTE_ENABLE_WERROR "OFF") - set(ARM_COMPUTE_BUILD_EXAMPLES "OFF") - set(ARM_COMPUTE_BUILD_TESTING "OFF") - set(_cmake_config_cmd ${CMAKE_COMMAND} -G Ninja -B build -DARM_COMPUTE_BUILD_SHARED_LIB=OFF diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index 2cf3c1a755d3c..0d4f9b7aa07c8 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -35,16 +35,21 @@ message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}") # sm90a set(SUPPORT_ARCHS) -if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3) - list(APPEND SUPPORT_ARCHS 9.0a) +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3) + list(APPEND SUPPORT_ARCHS "9.0a") endif() -if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8) - list(APPEND SUPPORT_ARCHS 10.0a) +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.9) + # CUDA 12.9 has introduced "Family-Specific Architecture Features" + # this supports all compute_10x family + list(APPEND SUPPORT_ARCHS "10.0f") +elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) + list(APPEND SUPPORT_ARCHS "10.0a") endif() cuda_archs_loose_intersection(FLASH_MLA_ARCHS "${SUPPORT_ARCHS}" "${CUDA_ARCHS}") if(FLASH_MLA_ARCHS) + message(STATUS "FlashMLA CUDA architectures: ${FLASH_MLA_ARCHS}") set(VLLM_FLASHMLA_GPU_FLAGS ${VLLM_GPU_FLAGS}) list(APPEND VLLM_FLASHMLA_GPU_FLAGS "--expt-relaxed-constexpr" "--expt-extended-lambda" "--use_fast_math") @@ -126,7 +131,8 @@ if(FLASH_MLA_ARCHS) $<$:-UPy_LIMITED_API> $<$:-UPy_LIMITED_API>) else() - # Create empty targets for setup.py when not targeting sm90a systems + message(STATUS "FlashMLA will not compile: unsupported CUDA architecture ${CUDA_ARCHS}") + # Create empty targets for setup.py on unsupported systems add_custom_target(_flashmla_C) add_custom_target(_flashmla_extension_C) endif() diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 5047c354ff7d2..bdb2ba74d944d 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -140,16 +140,21 @@ function(vllm_prepare_torch_gomp_shim TORCH_GOMP_SHIM_DIR) run_python(_VLLM_TORCH_GOMP_PATH " import os, glob -try: - import torch - torch_pkg = os.path.dirname(torch.__file__) - site_root = os.path.dirname(torch_pkg) - torch_libs = os.path.join(site_root, 'torch.libs') - print(glob.glob(os.path.join(torch_libs, 'libgomp-*.so*'))[0]) -except: - print('') +import torch +torch_pkg = os.path.dirname(torch.__file__) +site_root = os.path.dirname(torch_pkg) + +# Search both torch.libs and torch/lib +roots = [os.path.join(site_root, 'torch.libs'), os.path.join(torch_pkg, 'lib')] +candidates = [] +for root in roots: + if not os.path.isdir(root): + continue + candidates.extend(glob.glob(os.path.join(root, 'libgomp*.so*'))) + +print(candidates[0] if candidates else '') " - "failed to probe torch.libs for libgomp") + "failed to probe for libgomp") if(_VLLM_TORCH_GOMP_PATH STREQUAL "" OR NOT EXISTS "${_VLLM_TORCH_GOMP_PATH}") return() diff --git a/csrc/cache.h b/csrc/cache.h index f2a5ec0acf5cd..cbe44c09eb624 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -58,6 +59,15 @@ void cp_gather_cache( torch::Tensor const& cu_seq_lens, // [BATCH+1] int64_t batch_size, std::optional seq_starts = std::nullopt); +// Gather and upconvert FP8 KV cache to BF16 workspace +void cp_gather_and_upconvert_fp8_kv_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656] + torch::Tensor const& dst, // [TOT_TOKENS, 576] + torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& seq_lens, // [BATCH] + torch::Tensor const& workspace_starts, // [BATCH] + int64_t batch_size); + // Indexer K quantization and cache function void indexer_k_quant_and_cache( torch::Tensor& k, // [num_tokens, head_dim] @@ -72,4 +82,4 @@ void cp_gather_indexer_k_quant_cache( torch::Tensor& dst_k, // [num_tokens, head_dim] torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4] const torch::Tensor& block_table, // [batch_size, num_blocks] - const torch::Tensor& cu_seq_lens); // [batch_size + 1] \ No newline at end of file + const torch::Tensor& cu_seq_lens); // [batch_size + 1] diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 8a5457206c706..f11c5f24c12ec 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -2,6 +2,7 @@ #include #include #include +#include #include "cuda_utils.h" #include "cuda_compat.h" @@ -514,7 +515,8 @@ __global__ void indexer_k_quant_and_cache_kernel( const int quant_block_size, // quantization block size const int cache_block_size, // cache block size const int cache_stride, // stride for each token in kv_cache - const bool use_ue8m0 // use ue8m0 scale format + + const bool use_ue8m0 // use ue8m0 scale format ) { constexpr int VEC_SIZE = 4; const int64_t token_idx = blockIdx.x; @@ -1061,6 +1063,82 @@ void gather_and_maybe_dequant_cache( } namespace vllm { + +// Gather and upconvert FP8 KV cache tokens to BF16 workspace +// Similar to cp_gather_cache but specifically for FP8->BF16 conversion +__global__ void cp_gather_and_upconvert_fp8_kv_cache( + const uint8_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656] + __nv_bfloat16* __restrict__ dst, // [TOT_TOKENS, 576] + const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES] + const int32_t* __restrict__ seq_lens, // [BATCH] + const int32_t* __restrict__ workspace_starts, // [BATCH] + const int32_t block_size, const int32_t head_dim, + const int64_t block_table_stride, const int64_t cache_block_stride, + const int64_t cache_entry_stride, const int64_t dst_entry_stride) { + const int64_t bid = blockIdx.x; // Batch ID + const int32_t num_splits = gridDim.y; + const int32_t split = blockIdx.y; + const int32_t seq_start = workspace_starts[bid]; + const int32_t seq_len = seq_lens[bid]; + const int32_t tot_slots = seq_len; + const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits); + + const int32_t split_start = split * split_slots; + const int32_t split_end = min((split + 1) * split_slots, tot_slots); + + const bool is_active_split = (split_start < tot_slots); + + if (!is_active_split) return; + + // Adjust the pointer for the block_table for this batch + const int32_t batch_offset = bid * block_table_stride; + int32_t offset = split_start; + int32_t offset_div = offset / block_size; + offset = offset % block_size; + const int32_t* batch_block_table = block_table + batch_offset; + + // Adjust dst pointer based on the cumulative sequence lengths + dst += seq_start * dst_entry_stride; + + const int tid = threadIdx.x; + + // Process each token in this split + for (int pid = split_start; pid < split_end; ++pid) { + auto block_id = batch_block_table[offset_div]; + const uint8_t* token_ptr = + src_cache + block_id * cache_block_stride + offset * cache_entry_stride; + __nv_bfloat16* dst_ptr = dst + pid * dst_entry_stride; + + // FP8 format: 512 bytes fp8 + 16 bytes scales + 128 bytes rope (64 bf16) + const uint8_t* no_pe_ptr = token_ptr; + const float* scales_ptr = reinterpret_cast(token_ptr + 512); + const __nv_bfloat16* rope_ptr = + reinterpret_cast(token_ptr + 512 + 16); + + // Parallelize fp8 dequant (512 elements) and rope copy (64 elements) + if (tid < 512) { + // FP8 dequantization + const int tile = tid >> 7; // each tile is 128 elements + const float scale = scales_ptr[tile]; + const uint8_t val = no_pe_ptr[tid]; + dst_ptr[tid] = + fp8::scaled_convert<__nv_bfloat16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3>(val, scale); + } else if (tid < 576) { + // Rope copy (64 bf16 elements) + const int rope_idx = tid - 512; + dst_ptr[512 + rope_idx] = rope_ptr[rope_idx]; + } + + // Move to next token + offset += 1; + if (offset == block_size) { + offset_div += 1; + offset = 0; + } + } +} + template // Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by // block_size. @@ -1202,6 +1280,57 @@ void cp_gather_cache( } } +void cp_gather_and_upconvert_fp8_kv_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656] + torch::Tensor const& dst, // [TOT_TOKENS, 576] + torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& seq_lens, // [BATCH] + torch::Tensor const& workspace_starts, // [BATCH] + int64_t batch_size) { + at::cuda::OptionalCUDAGuard device_guard(src_cache.device()); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int32_t block_size = src_cache.size(1); + int32_t head_dim = dst.size(1); + + TORCH_CHECK(block_table.dtype() == torch::kInt32, + "block_table must be int32"); + TORCH_CHECK(seq_lens.dtype() == torch::kInt32, "seq_lens must be int32"); + TORCH_CHECK(workspace_starts.dtype() == torch::kInt32, + "workspace_starts must be int32"); + + TORCH_CHECK(src_cache.device() == dst.device(), + "src_cache and dst must be on the same device"); + TORCH_CHECK(src_cache.device() == block_table.device(), + "src_cache and block_table must be on the same device"); + TORCH_CHECK(src_cache.device() == seq_lens.device(), + "src_cache and seq_lens must be on the same device"); + TORCH_CHECK(src_cache.device() == workspace_starts.device(), + "src_cache and workspace_starts must be on the same device"); + + TORCH_CHECK(src_cache.dtype() == torch::kUInt8, "src_cache must be uint8"); + TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16"); + TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA"); + + int64_t block_table_stride = block_table.stride(0); + int64_t cache_block_stride = src_cache.stride(0); + int64_t cache_entry_stride = src_cache.stride(1); + int64_t dst_entry_stride = dst.stride(0); + + // Decide on the number of splits based on the batch size + int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16; + dim3 grid(batch_size, num_splits); + dim3 block(576); + + vllm::cp_gather_and_upconvert_fp8_kv_cache<<>>( + src_cache.data_ptr(), + reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()), + block_table.data_ptr(), seq_lens.data_ptr(), + workspace_starts.data_ptr(), block_size, head_dim, + block_table_stride, cache_block_stride, cache_entry_stride, + dst_entry_stride); +} + // Macro to dispatch the kernel based on the data type. #define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ vllm::indexer_k_quant_and_cache_kernel \ diff --git a/csrc/cpu/cpu_attn.cpp b/csrc/cpu/cpu_attn.cpp index 92f8bee5a47a0..02c722ba031a4 100644 --- a/csrc/cpu/cpu_attn.cpp +++ b/csrc/cpu/cpu_attn.cpp @@ -117,7 +117,6 @@ torch::Tensor get_scheduler_metadata( input.casual = casual; input.isa = isa; input.enable_kv_split = enable_kv_split; - TORCH_CHECK(casual, "Only supports casual mask for now."); VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() { CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] { diff --git a/csrc/cpu/cpu_attn_impl.hpp b/csrc/cpu/cpu_attn_impl.hpp index 98f55d7c014be..e3e077b845f4f 100644 --- a/csrc/cpu/cpu_attn_impl.hpp +++ b/csrc/cpu/cpu_attn_impl.hpp @@ -186,7 +186,7 @@ struct AttentionMetadata { // - Intermediate outputs: q_tile_size * head_dim * output_buffer_elem_size + 2 // * q_tile_size * 4, partial output, max + sum (float) // Reduction scratchpad contains: -// - flags: bool array to indicate wether the split is finished +// - flags: bool array to indicate whether the split is finished // - outputs: split_num * q_tile_size * head_dim * output_buffer_elem_size // - max, sum: 2 * split_num * q_tile_size * 4 class AttentionScratchPad { @@ -1246,14 +1246,8 @@ class AttentionMainLoop { // rescale sum and partial outputs if (need_rescale) { // compute rescale factor -#ifdef DEFINE_FAST_EXP - vec_op::FP32Vec16 rescale_factor_vec(rescale_factor); - rescale_factor_vec = fast_exp(rescale_factor_vec); - rescale_factor = rescale_factor_vec.get_last_elem(); -#else rescale_factor = std::exp(rescale_factor); vec_op::FP32Vec16 rescale_factor_vec(rescale_factor); -#endif // rescale sum new_sum_val += rescale_factor * init_sum_val; @@ -1889,15 +1883,8 @@ class AttentionMainLoop { : curr_output_buffer; float rescale_factor = final_max > curr_max ? curr_max - final_max : final_max - curr_max; - -#ifdef DEFINE_FAST_EXP - vec_op::FP32Vec16 rescale_factor_vec(rescale_factor); - rescale_factor_vec = fast_exp(rescale_factor_vec); - rescale_factor = rescale_factor_vec.get_last_elem(); -#else rescale_factor = std::exp(rescale_factor); vec_op::FP32Vec16 rescale_factor_vec(rescale_factor); -#endif local_sum[head_idx] = final_max > curr_max ? final_sum + rescale_factor * curr_sum diff --git a/csrc/cpu/cpu_attn_macros.h b/csrc/cpu/cpu_attn_macros.h index 6458e43419370..35716a0790ab3 100644 --- a/csrc/cpu/cpu_attn_macros.h +++ b/csrc/cpu/cpu_attn_macros.h @@ -60,4 +60,54 @@ #endif +#ifdef __aarch64__ + // Implementation copied from Arm Optimized Routines (expf AdvSIMD) + // https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c + #include + #define DEFINE_FAST_EXP \ + const float32x4_t inv_ln2 = vdupq_n_f32(0x1.715476p+0f); \ + const float ln2_hi = 0x1.62e4p-1f; \ + const float ln2_lo = 0x1.7f7d1cp-20f; \ + const float c0 = 0x1.0e4020p-7f; \ + const float c2 = 0x1.555e66p-3f; \ + const float32x4_t ln2_c02 = {ln2_hi, ln2_lo, c0, c2}; \ + const uint32x4_t exponent_bias = vdupq_n_u32(0x3f800000); \ + const float32x4_t c1 = vdupq_n_f32(0x1.573e2ep-5f); \ + const float32x4_t c3 = vdupq_n_f32(0x1.fffdb6p-2f); \ + const float32x4_t c4 = vdupq_n_f32(0x1.ffffecp-1f); \ + const float32x4_t pos_special_bound = vdupq_n_f32(0x1.5d5e2ap+6f); \ + const float32x4_t neg_special_bound = vnegq_f32(pos_special_bound); \ + const float32x4_t inf = \ + vdupq_n_f32(std::numeric_limits::infinity()); \ + const float32x4_t zero = vdupq_n_f32(0.0f); \ + auto neon_expf = [&](float32x4_t values) __attribute__((always_inline)) { \ + float32x4_t n = vrndaq_f32(vmulq_f32(values, inv_ln2)); \ + float32x4_t r = vfmsq_laneq_f32(values, n, ln2_c02, 0); \ + r = vfmsq_laneq_f32(r, n, ln2_c02, 1); \ + uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_s32(vcvtq_s32_f32(n)), 23); \ + float32x4_t scale = vreinterpretq_f32_u32(vaddq_u32(e, exponent_bias)); \ + float32x4_t r2 = vmulq_f32(r, r); \ + float32x4_t p = vfmaq_laneq_f32(c1, r, ln2_c02, 2); \ + float32x4_t q = vfmaq_laneq_f32(c3, r, ln2_c02, 3); \ + q = vfmaq_f32(q, p, r2); \ + p = vmulq_f32(c4, r); \ + float32x4_t poly = vfmaq_f32(p, q, r2); \ + poly = vfmaq_f32(scale, poly, scale); \ + const uint32x4_t hi_mask = vcgeq_f32(values, pos_special_bound); \ + const uint32x4_t lo_mask = vcleq_f32(values, neg_special_bound); \ + poly = vbslq_f32(hi_mask, inf, poly); \ + return vbslq_f32(lo_mask, zero, poly); \ + }; \ + auto fast_exp = [&](vec_op::FP32Vec16& vec) \ + __attribute__((always_inline)) { \ + float32x4x4_t result; \ + result.val[0] = neon_expf(vec.reg.val[0]); \ + result.val[1] = neon_expf(vec.reg.val[1]); \ + result.val[2] = neon_expf(vec.reg.val[2]); \ + result.val[3] = neon_expf(vec.reg.val[3]); \ + return vec_op::FP32Vec16(result); \ + }; + +#endif // __aarch64__ + #endif \ No newline at end of file diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index e1d131e4a7851..de0c505b7a62f 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -118,6 +118,24 @@ } \ } +#define VLLM_DISPATCH_BOOL(expr, const_expr, ...) \ + if (expr) { \ + constexpr bool const_expr = true; \ + __VA_ARGS__(); \ + } else { \ + constexpr bool const_expr = false; \ + __VA_ARGS__(); \ + } + +#define VLLM_DISPATCH_GROUP_SIZE(group_size, const_group_size, ...) \ + if (group_size == 128) { \ + constexpr int const_group_size = 128; \ + __VA_ARGS__(); \ + } else if (group_size == 64) { \ + constexpr int const_group_size = 64; \ + __VA_ARGS__(); \ + } + #define VLLM_DISPATCH_RANK234(NUM_DIMS, ...) \ switch (NUM_DIMS) { \ case 2: { \ diff --git a/csrc/moe/dynamic_4bit_int_moe_cpu.cpp b/csrc/moe/dynamic_4bit_int_moe_cpu.cpp index df47bb8dd1d7d..58dc402016881 100644 --- a/csrc/moe/dynamic_4bit_int_moe_cpu.cpp +++ b/csrc/moe/dynamic_4bit_int_moe_cpu.cpp @@ -93,16 +93,16 @@ torch::Tensor dynamic_4bit_int_moe_cpu( } auto Y_all = at::empty({offsets[E], H}, x_c.options()); - at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) { + at::parallel_for(0, offsets[E], 0, [&](int64_t idx_begin, int64_t idx_end) { c10::InferenceMode guard; - for (int64_t e = e_begin; e < e_end; ++e) { - const int64_t te = counts[e]; - if (te == 0) { + for (int64_t e = 0; e < E; ++e) { + int64_t start = std::max(offsets[e], idx_begin); + int64_t end = std::min(offsets[e + 1], idx_end); + int64_t te = end - start; + if (te <= 0) { continue; } - const int64_t start = offsets[e]; - auto x_e = X_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te); auto w13_e = w13_packed.select(/*dim=*/0, e); diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu index 69b4c1fb11d1a..5fa367abd96f5 100644 --- a/csrc/moe/grouped_topk_kernels.cu +++ b/csrc/moe/grouped_topk_kernels.cu @@ -444,23 +444,27 @@ __device__ inline T apply_sigmoid(T val) { return cuda_cast(sigmoid_accurate(f)); } -template +template +__device__ inline T apply_scoring(T val) { + if constexpr (SF == SCORING_SIGMOID) { + return apply_sigmoid(val); + } else { + return val; + } +} + +template __device__ void topk_with_k2(T* output, T const* input, T const* bias, cg::thread_block_tile<32> const& tile, int32_t const lane_id, - int const num_experts_per_group, - int const scoring_func) { + int const num_experts_per_group) { // Get the top2 per thread T largest = neg_inf(); T second_largest = neg_inf(); if (num_experts_per_group > WARP_SIZE) { for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { - T value = input[i]; - // Apply scoring function if needed - if (scoring_func == SCORING_SIGMOID) { - value = apply_sigmoid(value); - } + T value = apply_scoring(input[i]); value = value + bias[i]; if (value > largest) { @@ -472,17 +476,11 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias, } } else { for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { - T value = input[i]; - // Apply scoring function if needed - if (scoring_func == SCORING_SIGMOID) { - value = apply_sigmoid(value); - } + T value = apply_scoring(input[i]); value = value + bias[i]; largest = value; } } - - __syncwarp(); // Ensure all threads have valid data before reduction // Get the top2 warpwise T max1 = cg::reduce(tile, largest, cg::greater()); @@ -501,13 +499,12 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias, } } -template +template __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias, int64_t const num_tokens, int64_t const num_cases, int64_t const n_group, - int64_t const num_experts_per_group, - int const scoring_func) { + int64_t const num_experts_per_group) { int32_t warp_id = threadIdx.x / WARP_SIZE; int32_t lane_id = threadIdx.x % WARP_SIZE; @@ -525,21 +522,21 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias, #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); #endif - topk_with_k2(output, input, group_bias, tile, lane_id, - num_experts_per_group, scoring_func); + topk_with_k2(output, input, group_bias, tile, lane_id, + num_experts_per_group); } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); #endif } -template +template __global__ void group_idx_and_topk_idx_kernel( T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices, T const* bias, int64_t const num_tokens, int64_t const n_group, int64_t const topk_group, int64_t const topk, int64_t const num_experts, int64_t const num_experts_per_group, bool renormalize, - double routed_scaling_factor, int scoring_func) { + double routed_scaling_factor) { int32_t warp_id = threadIdx.x / WARP_SIZE; int32_t lane_id = threadIdx.x % WARP_SIZE; int32_t case_id = @@ -549,6 +546,11 @@ __global__ void group_idx_and_topk_idx_kernel( topk_values += case_id * topk; topk_indices += case_id * topk; + constexpr bool kUseStaticNGroup = (NGroup > 0); + // use int32 to avoid implicit conversion + int32_t const n_group_i32 = + kUseStaticNGroup ? NGroup : static_cast(n_group); + int32_t align_num_experts_per_group = warp_topk::round_up_to_multiple_of(num_experts_per_group); @@ -574,17 +576,17 @@ __global__ void group_idx_and_topk_idx_kernel( if (case_id < num_tokens) { // calculate group_idx - int32_t target_num_min = WARP_SIZE - n_group + topk_group; + int32_t target_num_min = + WARP_SIZE - n_group_i32 + static_cast(topk_group); // The check is necessary to avoid abnormal input - if (lane_id < n_group && is_finite(group_scores[lane_id])) { + if (lane_id < n_group_i32 && is_finite(group_scores[lane_id])) { value = group_scores[lane_id]; } - int count_equal_to_top_value = WARP_SIZE - n_group; + int count_equal_to_top_value = WARP_SIZE - n_group_i32; int pre_count_equal_to_top_value = 0; // Use loop to find the largset top_group while (count_equal_to_top_value < target_num_min) { - __syncwarp(); // Ensure all threads have valid data before reduction topk_group_value = cg::reduce(tile, value, cg::greater()); if (value == topk_group_value) { value = neg_inf(); @@ -604,7 +606,7 @@ __global__ void group_idx_and_topk_idx_kernel( int count_equalto_topkth_group = 0; bool if_proceed_next_topk = topk_group_value != neg_inf(); if (case_id < num_tokens && if_proceed_next_topk) { - for (int i_group = 0; i_group < n_group; i_group++) { + auto process_group = [&](int i_group) { if ((group_scores[i_group] > topk_group_value) || ((group_scores[i_group] == topk_group_value) && (count_equalto_topkth_group < num_equalto_topkth_group))) { @@ -613,11 +615,10 @@ __global__ void group_idx_and_topk_idx_kernel( i += WARP_SIZE) { T candidates = neg_inf(); if (i < num_experts_per_group) { - // Apply scoring function (if any) and add bias + // apply scoring function (if any) and add bias T input = scores[offset + i]; if (is_finite(input)) { - T score = (scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input) - : input; + T score = apply_scoring(input); candidates = score + bias[offset + i]; } } @@ -627,12 +628,21 @@ __global__ void group_idx_and_topk_idx_kernel( count_equalto_topkth_group++; } } + }; + + if constexpr (kUseStaticNGroup) { +#pragma unroll + for (int i_group = 0; i_group < NGroup; ++i_group) { + process_group(i_group); + } + } else { + for (int i_group = 0; i_group < n_group_i32; ++i_group) { + process_group(i_group); + } } queue.done(); - __syncwarp(); // Get the topk_idx queue.dumpIdx(s_topk_idx); - __syncwarp(); } // Load the valid score value @@ -646,12 +656,13 @@ __global__ void group_idx_and_topk_idx_kernel( if (i < topk) { // Load the score value (without bias) for normalization T input = scores[s_topk_idx[i]]; - value = - (scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input) : input; + value = apply_scoring(input); s_topk_value[i] = value; } - topk_sum += - cg::reduce(tile, cuda_cast(value), cg::plus()); + if (renormalize) { + topk_sum += + cg::reduce(tile, cuda_cast(value), cg::plus()); + } } } @@ -660,13 +671,9 @@ __global__ void group_idx_and_topk_idx_kernel( if (case_id < num_tokens) { if (if_proceed_next_topk) { for (int i = lane_id; i < topk; i += WARP_SIZE) { - float value; - if (renormalize) { - value = cuda_cast(s_topk_value[i]) / topk_sum * - routed_scaling_factor; - } else { - value = cuda_cast(s_topk_value[i]) * routed_scaling_factor; - } + float base = cuda_cast(s_topk_value[i]); + float value = renormalize ? (base / topk_sum * routed_scaling_factor) + : (base * routed_scaling_factor); topk_indices[i] = s_topk_idx[i]; topk_values[i] = value; } @@ -684,6 +691,45 @@ __global__ void group_idx_and_topk_idx_kernel( #endif } +template +inline void launch_group_idx_and_topk_kernel( + cudaLaunchConfig_t const& config, T* scores, T* group_scores, + float* topk_values, IdxT* topk_indices, T const* bias, + int64_t const num_tokens, int64_t const n_group, int64_t const topk_group, + int64_t const topk, int64_t const num_experts, + int64_t const num_experts_per_group, bool const renormalize, + double const routed_scaling_factor) { + auto launch = [&](auto* kernel_instance2) { + cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, + topk_values, topk_indices, bias, num_tokens, n_group, + topk_group, topk, num_experts, num_experts_per_group, + renormalize, routed_scaling_factor); + }; + + switch (n_group) { + case 4: { + launch(&group_idx_and_topk_idx_kernel); + break; + } + case 8: { + launch(&group_idx_and_topk_idx_kernel); + break; + } + case 16: { + launch(&group_idx_and_topk_idx_kernel); + break; + } + case 32: { + launch(&group_idx_and_topk_idx_kernel); + break; + } + default: { + launch(&group_idx_and_topk_idx_kernel); + break; + } + } +} + template void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, IdxT* topk_indices, T const* bias, int64_t const num_tokens, @@ -694,7 +740,6 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, cudaStream_t const stream = 0) { int64_t num_cases = num_tokens * n_group; int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1; - auto* kernel_instance1 = &topk_with_k2_kernel; cudaLaunchConfig_t config; config.gridDim = topk_with_k2_num_blocks; config.blockDim = BLOCK_SIZE; @@ -705,16 +750,33 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; config.numAttrs = 1; config.attrs = attrs; - cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias, - num_tokens, num_cases, n_group, num_experts / n_group, - scoring_func); + auto const sf = static_cast(scoring_func); + int64_t const num_experts_per_group = num_experts / n_group; + auto launch_topk_with_k2 = [&](auto* kernel_instance1) { + cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias, + num_tokens, num_cases, n_group, num_experts_per_group); + }; + switch (sf) { + case SCORING_NONE: { + auto* kernel_instance1 = &topk_with_k2_kernel; + launch_topk_with_k2(kernel_instance1); + break; + } + case SCORING_SIGMOID: { + auto* kernel_instance1 = &topk_with_k2_kernel; + launch_topk_with_k2(kernel_instance1); + break; + } + default: + // should be guarded by higher level checks. + TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc"); + } int64_t topk_with_k_group_num_blocks = (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1; size_t dynamic_smem_in_bytes = warp_topk::calc_smem_size_for_block_wide(NUM_WARPS_PER_BLOCK, topk); - auto* kernel_instance2 = &group_idx_and_topk_idx_kernel; config.gridDim = topk_with_k_group_num_blocks; config.blockDim = BLOCK_SIZE; config.dynamicSmemBytes = dynamic_smem_in_bytes; @@ -723,10 +785,24 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; config.numAttrs = 1; config.attrs = attrs; - cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, - topk_values, topk_indices, bias, num_tokens, n_group, - topk_group, topk, num_experts, num_experts / n_group, - renormalize, routed_scaling_factor, scoring_func); + switch (sf) { + case SCORING_NONE: { + launch_group_idx_and_topk_kernel( + config, scores, group_scores, topk_values, topk_indices, bias, + num_tokens, n_group, topk_group, topk, num_experts, + num_experts_per_group, renormalize, routed_scaling_factor); + break; + } + case SCORING_SIGMOID: { + launch_group_idx_and_topk_kernel( + config, scores, group_scores, topk_values, topk_indices, bias, + num_tokens, n_group, topk_group, topk, num_experts, + num_experts_per_group, renormalize, routed_scaling_factor); + break; + } + default: + TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc"); + } } #define INSTANTIATE_NOAUX_TC(T, IdxT) \ diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu index 27b6ffaa67176..4fd8fc5c54202 100644 --- a/csrc/moe/marlin_moe_wna16/ops.cu +++ b/csrc/moe/marlin_moe_wna16/ops.cu @@ -860,4 +860,4 @@ torch::Tensor moe_wna16_marlin_gemm( TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm); -} +} \ No newline at end of file diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index b3d0c0aa58e9e..5c9e474024082 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -14,7 +14,6 @@ namespace vllm { namespace moe { - namespace batched_moe_align_block_size { // Note num_threads needs to be 1024 for BlockScan Reduction in the kernel. @@ -80,17 +79,32 @@ __global__ void batched_moe_align_block_size_kernel( } // namespace batched_moe_align_block_size template -__global__ void moe_align_block_size_kernel( +__device__ void _moe_align_block_size( const scalar_t* __restrict__ topk_ids, int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, - int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, + int32_t* __restrict__ total_tokens_post_pad, + int32_t* __restrict__ expert_map, int32_t num_experts, int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size, - size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded) { + size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded, + int32_t max_num_m_blocks, int32_t model_offset, int32_t inactive_expert_id, + int32_t topk_num, int32_t* token_mask, bool has_expert_map) { extern __shared__ int32_t shared_counts[]; - // Initialize sorted_token_ids with numel - for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) { - sorted_token_ids[it] = numel; + // Compute input buffer offsets. Typically these will all be 0, except when + // using Multi LoRA. + int sorted_token_ids_offset = max_num_tokens_padded * model_offset; + int expert_ids_offset = max_num_m_blocks * model_offset; + int cumsum_offset = (num_experts + 1) * model_offset; + + // Use separate threadblocks to fill sorted_token_ids. + // This is safe since the current kernel does not use sorted_token_ids. + if (blockIdx.x % 2) { + // Initialize sorted_token_ids with numel + for (size_t it = threadIdx.x; it < max_num_tokens_padded; + it += blockDim.x) { + sorted_token_ids[sorted_token_ids_offset + it] = numel; + } + return; } const int warp_id = threadIdx.x / WARP_SIZE; @@ -112,9 +126,16 @@ __global__ void moe_align_block_size_kernel( if (expert_id >= num_experts) { continue; } + if (has_expert_map) { + expert_id = expert_map[expert_id]; + // filter invalid experts + if (expert_id == -1) continue; + } int warp_idx = expert_id / experts_per_warp; int expert_offset = expert_id % experts_per_warp; - atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1); + int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num]; + atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], + mask); } __syncthreads(); @@ -135,48 +156,196 @@ __global__ void moe_align_block_size_kernel( int cumsum_val; BlockScan(temp_storage).ExclusiveSum(expert_count, cumsum_val); if (expert_id <= num_experts) { - cumsum[expert_id] = cumsum_val; + cumsum[cumsum_offset + expert_id] = cumsum_val; } if (expert_id == num_experts) { - *total_tokens_post_pad = cumsum_val; + total_tokens_post_pad[model_offset] = cumsum_val; } __syncthreads(); if (threadIdx.x < num_experts) { - for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; - i += block_size) { - expert_ids[i / block_size] = threadIdx.x; + for (int i = cumsum[cumsum_offset + threadIdx.x]; + i < cumsum[cumsum_offset + threadIdx.x + 1]; i += block_size) { + expert_ids[expert_ids_offset + i / block_size] = threadIdx.x; } } // Fill remaining expert_ids with 0 - const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x; - const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size); - for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) { - expert_ids[i] = 0; + const size_t fill_start_idx = + cumsum[cumsum_offset + num_experts] / block_size + threadIdx.x; + for (size_t i = fill_start_idx; i < max_num_m_blocks; i += blockDim.x) { + expert_ids[expert_ids_offset + i] = inactive_expert_id; + } +} + +template +__device__ void _moe_align_block_size_small_batch_expert( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, + int32_t* __restrict__ expert_map, int32_t num_experts, int32_t block_size, + size_t numel, int32_t max_num_tokens_padded, int32_t max_num_m_blocks, + int32_t inactive_expert_id, int32_t model_offset, int32_t topk_num, + int32_t* token_mask, bool has_expert_map) { + // Compute input buffer offsets. Typically these will all be 0, except when + // using Multi LoRA. + int sorted_token_ids_offset = max_num_tokens_padded * model_offset; + int expert_ids_offset = max_num_m_blocks * model_offset; + + // Use an additional group of threads to fill sorted_token_ids. + // Since the current kernel will use sorted_token_ids afterward, + // we fill sorted_token_ids within the same threadblock to make + // synchronization easier. + if (threadIdx.x < fill_threads) { + // Initialize sorted_token_ids with numel + for (size_t it = threadIdx.x; it < max_num_tokens_padded; + it += fill_threads) { + sorted_token_ids[sorted_token_ids_offset + it] = numel; + } + // Three __syncthreads() corresponding to the other threads + __syncthreads(); + __syncthreads(); + __syncthreads(); + return; + } + + const size_t tid = threadIdx.x - fill_threads; + const size_t stride = blockDim.x - fill_threads; + + extern __shared__ int32_t shared_mem[]; + int32_t* cumsum = shared_mem; + int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1); + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[(tid + 1) * num_experts + i] = 0; + } + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i]; + if (has_expert_map) { + expert_id = expert_map[expert_id]; + // filter invalid expert + if (expert_id == -1) continue; + } + int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num]; + tokens_cnts[(tid + 1) * num_experts + expert_id] += mask; + } + + __syncthreads(); + + if (tid < num_experts) { + tokens_cnts[tid] = 0; + for (int i = 1; i <= stride; ++i) { + tokens_cnts[i * num_experts + tid] += + tokens_cnts[(i - 1) * num_experts + tid]; + } + } + + __syncthreads(); + + if (tid == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = + cumsum[i - 1] + + CEILDIV(tokens_cnts[stride * num_experts + i - 1], block_size) * + block_size; + } + total_tokens_post_pad[model_offset] = + static_cast(cumsum[num_experts]); + } + + __syncthreads(); + + if (tid < num_experts) { + for (int i = cumsum[tid]; i < cumsum[tid + 1]; i += block_size) { + expert_ids[expert_ids_offset + i / block_size] = tid; + } + } + + // Fill remaining expert_ids with 0 + const size_t fill_start_idx = cumsum[num_experts] / block_size + tid; + for (size_t i = fill_start_idx; i < max_num_m_blocks; i += stride) { + expert_ids[expert_ids_offset + i] = inactive_expert_id; + } + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i]; + if (has_expert_map) { + expert_id = expert_map[expert_id]; + // filter invalid expert + if (expert_id == -1) continue; + } + int32_t rank_post_pad = + tokens_cnts[tid * num_experts + expert_id] + cumsum[expert_id]; + + if (token_mask == nullptr || token_mask[i / topk_num]) { + sorted_token_ids[sorted_token_ids_offset + rank_post_pad] = i; + ++tokens_cnts[tid * num_experts + expert_id]; + } } } template -__global__ void count_and_sort_expert_tokens_kernel( +__device__ void _count_and_sort_expert_tokens( const scalar_t* __restrict__ topk_ids, int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, - size_t numel, int32_t num_experts) { - const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; - const size_t stride = blockDim.x * gridDim.x; + int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts, + int32_t max_num_tokens_padded, int32_t* __restrict__ token_mask, + int32_t model_offset, int32_t topk_num, bool has_expert_map) { + const size_t tid = blockIdx.y * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.y; for (size_t i = tid; i < numel; i += stride) { int32_t expert_id = topk_ids[i]; if (expert_id >= num_experts) { continue; } - int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1); - sorted_token_ids[rank_post_pad] = i; + + if (has_expert_map) { + expert_id = expert_map[expert_id]; + // filter invalid experts + if (expert_id == -1) continue; + } + + if (token_mask == nullptr || token_mask[i / topk_num]) { + int32_t rank_post_pad = atomicAdd( + &cumsum_buffer[(model_offset * (num_experts + 1)) + expert_id], 1); + sorted_token_ids[max_num_tokens_padded * model_offset + rank_post_pad] = + i; + } } } +template +__global__ void moe_align_block_size_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, + int32_t* __restrict__ expert_map, int32_t num_experts, + int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size, + size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded, + int32_t topk_num, bool has_expert_map) { + _moe_align_block_size( + topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map, + num_experts, padded_num_experts, experts_per_warp, block_size, numel, + cumsum, max_num_tokens_padded, CEILDIV(max_num_tokens_padded, block_size), + 0, 0, topk_num, nullptr, has_expert_map); +} + +template +__global__ void count_and_sort_expert_tokens_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, + int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts, + int32_t max_num_tokens_padded, int32_t topk_num, bool has_expert_map) { + _count_and_sort_expert_tokens( + topk_ids, sorted_token_ids, cumsum_buffer, expert_map, numel, num_experts, + max_num_tokens_padded, nullptr, 0, topk_num, has_expert_map); +} + template __global__ void moe_sum_kernel( scalar_t* __restrict__ out, // [..., d] @@ -193,78 +362,111 @@ __global__ void moe_sum_kernel( } } -template +template __global__ void moe_align_block_size_small_batch_expert_kernel( const scalar_t* __restrict__ topk_ids, int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, - int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, - int32_t block_size, size_t numel, int32_t max_num_tokens_padded) { - // Initialize sorted_token_ids with numel - for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) { - sorted_token_ids[it] = numel; + int32_t* __restrict__ total_tokens_post_pad, + int32_t* __restrict__ expert_map, int32_t num_experts, int32_t block_size, + size_t numel, int32_t max_num_tokens_padded, int32_t topk_num, + bool has_expert_map) { + _moe_align_block_size_small_batch_expert( + topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map, + num_experts, block_size, numel, max_num_tokens_padded, + CEILDIV(max_num_tokens_padded, block_size), 0, 0, topk_num, nullptr, + has_expert_map); +} + +template +__global__ void moe_lora_align_block_size_kernel( + scalar_t* __restrict__ topk_ids, int32_t* __restrict__ token_lora_mapping, + int64_t block_size, int32_t* __restrict__ expert_map, int num_experts, + int max_loras, size_t numel, int max_num_tokens_padded, + int max_num_m_blocks, int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, int32_t topk_num, + int32_t* total_tokens_post_pad, int32_t* adapter_enabled, + int32_t* __restrict__ cumsum, int32_t experts_per_warp, + int32_t padded_num_experts, int32_t* lora_ids, + int32_t* __restrict__ token_mask, bool has_expert_map) { + int lora_idx = blockIdx.x / 2; + int lora_id = lora_ids[lora_idx]; + if (lora_id == -1 || adapter_enabled[lora_id] == 0) { + return; } - const size_t tid = threadIdx.x; - const size_t stride = blockDim.x; - - extern __shared__ int32_t shared_mem[]; - int32_t* cumsum = shared_mem; - int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1); - - for (int i = 0; i < num_experts; ++i) { - tokens_cnts[(threadIdx.x + 1) * num_experts + i] = 0; - } - - for (size_t i = tid; i < numel; i += stride) { - ++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i]]; - } - - __syncthreads(); - - if (threadIdx.x < num_experts) { - tokens_cnts[threadIdx.x] = 0; - for (int i = 1; i <= blockDim.x; ++i) { - tokens_cnts[i * num_experts + threadIdx.x] += - tokens_cnts[(i - 1) * num_experts + threadIdx.x]; - } - } - - __syncthreads(); - + // Populate the token_mask based on the token-LoRA mapping + int num_tokens = numel / topk_num; if (threadIdx.x == 0) { - cumsum[0] = 0; - for (int i = 1; i <= num_experts; ++i) { - cumsum[i] = - cumsum[i - 1] + - CEILDIV(tokens_cnts[blockDim.x * num_experts + i - 1], block_size) * - block_size; + total_tokens_post_pad[lora_id] = 0; + + for (int i = 0; i < num_tokens; i++) { + token_mask[(lora_id * num_tokens) + i] = + (int)token_lora_mapping[i] == lora_id; } - *total_tokens_post_pad = static_cast(cumsum[num_experts]); } __syncthreads(); - if (threadIdx.x < num_experts) { - for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; - i += block_size) { - expert_ids[i / block_size] = threadIdx.x; + _moe_align_block_size( + topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map, + num_experts, padded_num_experts, experts_per_warp, block_size, numel, + cumsum, max_num_tokens_padded, max_num_m_blocks, lora_id, -1, topk_num, + &token_mask[(lora_id * num_tokens)], has_expert_map); +} + +template +__global__ void lora_count_and_sort_expert_tokens_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, + int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts, + int32_t max_num_tokens_padded, int32_t topk_num, int32_t* token_mask, + int32_t* lora_ids, bool has_expert_map) { + int lora_idx = blockIdx.x; + int lora_id = lora_ids[lora_idx]; + if (lora_id == -1) { + return; + } + + int num_tokens = numel / topk_num; + + _count_and_sort_expert_tokens( + topk_ids, sorted_token_ids, cumsum_buffer, expert_map, numel, num_experts, + max_num_tokens_padded, &token_mask[(lora_id * num_tokens)], lora_id, + topk_num, has_expert_map); +} + +template +__global__ void moe_lora_align_block_size_small_batch_expert_kernel( + scalar_t* __restrict__ topk_ids, int32_t* token_lora_mapping, + int64_t block_size, int32_t* __restrict__ expert_map, int num_experts, + int max_loras, size_t numel, int max_num_tokens_padded, + int max_num_m_blocks, int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, int topk_num, + int32_t* total_tokens_post_pad, int32_t* adapter_enabled, int32_t* lora_ids, + int32_t* token_mask, bool has_expert_map) { + int lora_idx = blockIdx.x; + int lora_id = lora_ids[lora_idx]; + if (lora_id == -1 || adapter_enabled[lora_id] == 0) { + return; + } + + int num_tokens = numel / topk_num; + if (threadIdx.x == 0) { + total_tokens_post_pad[lora_id] = 0; + + for (int i = 0; i < num_tokens; i++) { + token_mask[(lora_id * num_tokens) + i] = + (int)token_lora_mapping[i] == lora_id; } } - // Fill remaining expert_ids with 0 - const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x; - const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size); - for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) { - expert_ids[i] = 0; - } + __syncthreads(); - for (size_t i = tid; i < numel; i += stride) { - int32_t expert_id = topk_ids[i]; - int32_t rank_post_pad = - tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id]; - sorted_token_ids[rank_post_pad] = i; - ++tokens_cnts[threadIdx.x * num_experts + expert_id]; - } + _moe_align_block_size_small_batch_expert( + topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map, + num_experts, block_size, numel, max_num_tokens_padded, max_num_m_blocks, + -1, lora_id, topk_num, &token_mask[(lora_id * num_tokens)], + has_expert_map); } } // namespace moe @@ -275,7 +477,8 @@ __global__ void moe_align_block_size_small_batch_expert_kernel( void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad) { + torch::Tensor num_tokens_post_pad, + std::optional maybe_expert_map) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); int64_t padded_num_experts = @@ -287,14 +490,19 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, // BlockScan uses 1024 threads and assigns one thread per expert. TORCH_CHECK(padded_num_experts < 1024, "padded_num_experts must be less than 1024"); + auto options_int = + torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); + bool has_expert_map = maybe_expert_map.has_value(); + torch::Tensor expert_map; + if (has_expert_map) { + expert_map = maybe_expert_map.value(); + } else { + expert_map = torch::empty({0}, options_int); + } VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { // calc needed amount of shared mem for `cumsum` tensors - auto options_int = - torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); - torch::Tensor cumsum_buffer = - torch::empty({num_experts + 1}, options_int); bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64); @@ -304,43 +512,58 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, ((threads + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); + // threadIdx.x >= fill_threads: counting experts and aligning + // threadIdx.x < fill_threads: filling sorted_token_ids + constexpr int32_t fill_threads = 256; auto small_batch_expert_kernel = vllm::moe::moe_align_block_size_small_batch_expert_kernel< - scalar_t>; - small_batch_expert_kernel<<<1, threads, shared_mem_size, stream>>>( + scalar_t, fill_threads>; + small_batch_expert_kernel<<<1, fill_threads + threads, + shared_mem_size, stream>>>( topk_ids.data_ptr(), sorted_token_ids.data_ptr(), experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), num_experts, block_size, - topk_ids.numel(), sorted_token_ids.size(0)); + num_tokens_post_pad.data_ptr(), + expert_map.data_ptr(), num_experts, block_size, + topk_ids.numel(), sorted_token_ids.size(0), topk_ids.size(1), + has_expert_map); } else { + torch::Tensor cumsum_buffer = + torch::empty({num_experts + 1}, options_int); auto align_kernel = vllm::moe::moe_align_block_size_kernel; size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp); size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t); - align_kernel<<<1, threads, shared_mem_size, stream>>>( + // launch two threadblocks + // blockIdx.x == 0: counting experts and aligning + // blockIdx.x == 1: filling sorted_token_ids + align_kernel<<<2, threads, shared_mem_size, stream>>>( topk_ids.data_ptr(), sorted_token_ids.data_ptr(), experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), num_experts, - padded_num_experts, experts_per_warp, block_size, - topk_ids.numel(), cumsum_buffer.data_ptr(), - sorted_token_ids.size(0)); + num_tokens_post_pad.data_ptr(), + expert_map.data_ptr(), num_experts, padded_num_experts, + experts_per_warp, block_size, topk_ids.numel(), + cumsum_buffer.data_ptr(), sorted_token_ids.size(0), + topk_ids.size(1), has_expert_map); const int block_threads = std::min(256, (int)threads); const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads; const int max_blocks = 65535; const int actual_blocks = std::min(num_blocks, max_blocks); + dim3 gridDims(1, actual_blocks); auto sort_kernel = vllm::moe::count_and_sort_expert_tokens_kernel; - sort_kernel<<>>( + sort_kernel<<>>( topk_ids.data_ptr(), sorted_token_ids.data_ptr(), - cumsum_buffer.data_ptr(), topk_ids.numel(), num_experts); + cumsum_buffer.data_ptr(), expert_map.data_ptr(), + topk_ids.numel(), num_experts, sorted_token_ids.size(0), + topk_ids.size(1), has_expert_map); } }); } @@ -414,3 +637,123 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] break; } } + +void moe_lora_align_block_size( + torch::Tensor topk_ids, torch::Tensor token_lora_mapping, + int64_t num_experts, int64_t block_size, int64_t max_loras, + int64_t max_num_tokens_padded, int64_t max_num_m_blocks, + torch::Tensor sorted_token_ids, torch::Tensor expert_ids, + torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled, + torch::Tensor lora_ids, std::optional maybe_expert_map) { + const int topk_num = topk_ids.size(1); + + TORCH_CHECK(block_size > 0, "block_size should be greater than 0. "); + + int device_max_shared_mem; + auto dev = topk_ids.get_device(); + cudaDeviceGetAttribute(&device_max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int64_t padded_num_experts = + ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + + // BlockScan uses 1024 threads and assigns one thread per expert. + TORCH_CHECK(padded_num_experts < 1024, + "padded_num_experts must be less than 1024"); + + auto options_int = + torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); + torch::Tensor token_mask = + torch::empty({max_loras * topk_ids.size(0)}, options_int); + bool has_expert_map = maybe_expert_map.has_value(); + torch::Tensor expert_map; + if (has_expert_map) { + expert_map = maybe_expert_map.value(); + } else { + expert_map = torch::empty({0}, options_int); + } + + VLLM_DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] { + bool small_batch_expert_mode = + (topk_ids.numel() < 1024) && (num_experts <= 64); + + if (small_batch_expert_mode) { + const int32_t num_thread = max((int32_t)num_experts, 128); + const int32_t shared_mem = + (num_thread + 1) * num_experts * sizeof(int32_t) + + (num_experts + 1) * sizeof(int32_t); + if (shared_mem > device_max_shared_mem) { + TORCH_CHECK(false, "Shared memory usage exceeds device limit."); + } + + // threadIdx.x >= fill_threads: counting experts and aligning + // threadIdx.x < fill_threads: filling sorted_token_ids + constexpr int32_t fill_threads = 256; + + dim3 blockDim(num_thread + fill_threads); + auto kernel = + vllm::moe::moe_lora_align_block_size_small_batch_expert_kernel< + scalar_t, fill_threads>; + AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( + (void*)kernel, shared_mem)); + kernel<<>>( + topk_ids.data_ptr(), + token_lora_mapping.data_ptr(), block_size, + expert_map.data_ptr(), num_experts, max_loras, + topk_ids.numel(), max_num_tokens_padded, max_num_m_blocks, + sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), topk_num, + num_tokens_post_pad.data_ptr(), + adapter_enabled.data_ptr(), lora_ids.data_ptr(), + token_mask.data_ptr(), has_expert_map); + } else { + int num_thread = 1024; + dim3 blockDim(num_thread); + size_t num_warps = CEILDIV(padded_num_experts, WARP_SIZE); + + size_t shared_mem_size = num_warps * WARP_SIZE * sizeof(int32_t); + + // cumsum buffer + torch::Tensor cumsum = + torch::zeros({max_loras * (num_experts + 1)}, options_int); + + auto align_kernel = + vllm::moe::moe_lora_align_block_size_kernel; + + // launch two threadblocks for each lora + // blockIdx.x % 2 == 0: counting experts and aligning + // blockIdx.x % 2 == 1: filling sorted_token_ids + align_kernel<<>>( + topk_ids.data_ptr(), + token_lora_mapping.data_ptr(), block_size, + expert_map.data_ptr(), num_experts, max_loras, + topk_ids.numel(), max_num_tokens_padded, max_num_m_blocks, + sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), topk_num, + num_tokens_post_pad.data_ptr(), + adapter_enabled.data_ptr(), cumsum.data_ptr(), + WARP_SIZE, padded_num_experts, lora_ids.data_ptr(), + token_mask.data_ptr(), has_expert_map); + + const int block_threads = std::min(256, (int)num_thread); + const int num_blocks = + (topk_ids.numel() + block_threads - 1) / block_threads; + + const int max_blocks = 65535; + const int actual_blocks = std::min(num_blocks, max_blocks); + + dim3 gridDims(max_loras, actual_blocks); + auto sort_kernel = + vllm::moe::lora_count_and_sort_expert_tokens_kernel; + + sort_kernel<<>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), cumsum.data_ptr(), + expert_map.data_ptr(), topk_ids.numel(), num_experts, + max_num_tokens_padded, topk_num, token_mask.data_ptr(), + lora_ids.data_ptr(), has_expert_map); + } + }); +} \ No newline at end of file diff --git a/csrc/moe/moe_lora_align_sum_kernels.cu b/csrc/moe/moe_lora_align_sum_kernels.cu deleted file mode 100644 index 360f1312cf579..0000000000000 --- a/csrc/moe/moe_lora_align_sum_kernels.cu +++ /dev/null @@ -1,174 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include -#include - -#include "../cuda_compat.h" -#include "../dispatch_utils.h" -#include "core/math.hpp" - -namespace { - -__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, - int32_t col) { - return row * total_col + col; -} - -} // namespace - -// TODO: Refactor common parts with moe_align_sum_kernels -template -__global__ void moe_lora_align_sum_kernel( - scalar_t* __restrict__ topk_ids, int32_t* token_lora_mapping, - int64_t block_size, int num_experts, int max_loras, size_t numel, - int max_num_tokens_padded, int max_num_m_blocks, - int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, - int topk_num, int32_t* total_tokens_post_pad, int32_t* adapter_enabled, - int32_t* lora_ids) { - const size_t tokens_per_thread = div_ceil(numel, blockDim.x); - const size_t start_idx = threadIdx.x * tokens_per_thread; - - int lora_idx = blockIdx.x; - int lora_id = lora_ids[lora_idx]; - if (lora_id == -1 || adapter_enabled[lora_id] == 0) { - return; - } - extern __shared__ int32_t shared_mem[]; - int32_t* cumsum = shared_mem; - token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1); - - // Initialize sorted_token_ids with numel - for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) { - sorted_token_ids[lora_id * max_num_tokens_padded + it] = numel; - } - - // Initialize expert_ids with -1 - for (size_t it = threadIdx.x; it < max_num_m_blocks; it += blockDim.x) { - expert_ids[lora_id * max_num_m_blocks + it] = -1; - } - - // Initialize total_tokens_post_pad with 0 - if (threadIdx.x == 0) { - total_tokens_post_pad[lora_id] = 0; - } - - for (int i = 0; i < num_experts; ++i) { - tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; - } - - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - int mask = token_lora_mapping[i / topk_num] == lora_id; - int idx = index(num_experts, threadIdx.x + 1, topk_ids[i]); - tokens_cnts[idx] += mask; - } - - __syncthreads(); - - // For each expert we accumulate the token counts from the different threads. - if (threadIdx.x < num_experts) { - tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; - for (int i = 1; i <= blockDim.x; ++i) { - tokens_cnts[index(num_experts, i, threadIdx.x)] += - tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; - } - } - - __syncthreads(); - - // We accumulate the token counts of all experts in thread 0. - if (threadIdx.x == 0) { - cumsum[0] = 0; - for (int i = 1; i <= num_experts; ++i) { - cumsum[i] = cumsum[i - 1] + - div_ceil(tokens_cnts[index(num_experts, blockDim.x, i - 1)], - block_size) * - block_size; - } - total_tokens_post_pad[lora_id] = static_cast(cumsum[num_experts]); - } - - __syncthreads(); - - /** - * For each expert, each thread processes the tokens of the corresponding - * blocks and stores the corresponding expert_id for each block. - */ - if (threadIdx.x < num_experts) { - for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; - i += block_size) { - expert_ids[index(max_num_m_blocks, lora_id, i / block_size)] = - threadIdx.x; - } - } - - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - int32_t expert_id = topk_ids[i]; - /** The cumsum[expert_id] stores the starting index of the tokens that the - * expert with expert_id needs to process, and - * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens - * processed by the expert with expert_id within the current thread's token - * shard. - */ - int32_t rank_post_pad = - tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + - cumsum[expert_id]; - - int mask = (int)token_lora_mapping[i / topk_num] == lora_id; - atomicAdd( - &sorted_token_ids[index(max_num_tokens_padded, lora_id, rank_post_pad)], - (i - numel) * mask); - tokens_cnts[index(num_experts, threadIdx.x, expert_id)] += mask; - } -} - -void moe_lora_align_block_size( - torch::Tensor topk_ids, torch::Tensor token_lora_mapping, - int64_t num_experts, int64_t block_size, int64_t max_loras, - int64_t max_num_tokens_padded, int64_t max_num_m_blocks, - torch::Tensor sorted_token_ids, torch::Tensor expert_ids, - torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled, - torch::Tensor lora_ids) { - const int topk_num = topk_ids.size(1); - - TORCH_CHECK(block_size > 0, "block_size should be greater than 0. "); - - int device_max_shared_mem; - auto dev = topk_ids.get_device(); - cudaDeviceGetAttribute(&device_max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - const int32_t num_thread = max((int32_t)num_experts, 128); // WARP_SIZE, - TORCH_CHECK(num_thread <= 1024, - "num_thread must be less than 1024, " - "and fallback is not implemented yet."); - const int32_t shared_mem = (num_thread + 1) * num_experts * sizeof(int32_t) + - (num_experts + 1) * sizeof(int32_t); - - if (shared_mem > device_max_shared_mem) { - TORCH_CHECK(false, - "Shared memory usage exceeds device limit, and global memory " - "fallback is not implemented yet."); - } - - VLLM_DISPATCH_INTEGRAL_TYPES( - topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] { - dim3 blockDim(num_thread); - auto kernel = moe_lora_align_sum_kernel; - AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( - (void*)kernel, shared_mem)); - kernel<<>>( - topk_ids.data_ptr(), - token_lora_mapping.data_ptr(), block_size, num_experts, - max_loras, topk_ids.numel(), max_num_tokens_padded, - max_num_m_blocks, sorted_token_ids.data_ptr(), - expert_ids.data_ptr(), topk_num, - num_tokens_post_pad.data_ptr(), - adapter_enabled.data_ptr(), lora_ids.data_ptr()); - }); -} \ No newline at end of file diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 11c6875f7f1d0..337dcc50b079e 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -11,7 +11,8 @@ void moe_sum(torch::Tensor& input, torch::Tensor& output); void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad); + torch::Tensor num_tokens_post_pad, + std::optional maybe_expert_map); void batched_moe_align_block_size(int64_t max_tokens_per_batch, int64_t block_size, @@ -26,7 +27,7 @@ void moe_lora_align_block_size( int64_t max_num_tokens_padded, int64_t max_num_m_blocks, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled, - torch::Tensor lora_ids); + torch::Tensor lora_ids, std::optional maybe_expert_map); #ifndef USE_ROCM torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index e0a8280722f3c..779ad70ad1e09 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -19,7 +19,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "moe_align_block_size(Tensor topk_ids, int num_experts," " int block_size, Tensor! sorted_token_ids," " Tensor! experts_ids," - " Tensor! num_tokens_post_pad) -> ()"); + " Tensor! num_tokens_post_pad," + " Tensor? maybe_expert_map) -> ()"); m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); // Aligning the number of tokens to be processed by each expert such @@ -46,7 +47,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { " Tensor !experts_ids," " Tensor !num_tokens_post_pad," " Tensor !adapter_enabled," - " Tensor !lora_ids) -> () "); + " Tensor !lora_ids," + " Tensor? maybe_expert_map) -> () "); m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size); #ifndef USE_ROCM diff --git a/csrc/ops.h b/csrc/ops.h index 4bb7857b15032..37e3aaf7499d5 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -102,13 +102,16 @@ void apply_repetition_penalties_(torch::Tensor& logits, const torch::Tensor& output_mask, const torch::Tensor& repetition_penalties); -void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts, - const torch::Tensor& rowEnds, torch::Tensor& indices, - int64_t numRows, int64_t stride0, int64_t stride1); +void top_k_per_row_prefill(const torch::Tensor& logits, + const torch::Tensor& rowStarts, + const torch::Tensor& rowEnds, torch::Tensor& indices, + int64_t numRows, int64_t stride0, int64_t stride1, + int64_t topK); void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n, - const torch::Tensor& seq_lens, torch::Tensor& indices, - int64_t numRows, int64_t stride0, int64_t stride1); + const torch::Tensor& seqLens, torch::Tensor& indices, + int64_t numRows, int64_t stride0, int64_t stride1, + int64_t topK); void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, torch::Tensor& scale, @@ -128,6 +131,13 @@ void rms_norm_dynamic_per_token_quant(torch::Tensor& out, std::optional scale_ub, std::optional residual); +void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, + torch::Tensor const& weight, + torch::Tensor& scales, double const epsilon, + std::optional scale_ub, + std::optional residual, + int64_t group_size, bool is_scale_transposed); + void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, std::optional key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); @@ -252,7 +262,8 @@ void get_cutlass_moe_mm_data( void get_cutlass_moe_mm_problem_sizes( const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, - const int64_t k, const std::optional& blockscale_offsets); + const int64_t k, const std::optional& blockscale_offsets, + std::optional force_swap_ab = std::nullopt); void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, @@ -299,6 +310,14 @@ void per_token_group_quant_int8(const torch::Tensor& input, torch::Tensor& output_q, torch::Tensor& output_s, int64_t group_size, double eps, double int8_min, double int8_max); + +// Fused activation quantisation + DeepGEMM-compatible UE8M0-packed scales. +void per_token_group_quant_8bit_packed(const torch::Tensor& input, + torch::Tensor& output_q, + torch::Tensor& output_s_packed, + int64_t group_size, double eps, + double min_8bit, double max_8bit); + #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, diff --git a/csrc/quantization/cutlass_w4a8/get_group_starts.cuh b/csrc/quantization/cutlass_w4a8/get_group_starts.cuh new file mode 100644 index 0000000000000..fec142d0d87a1 --- /dev/null +++ b/csrc/quantization/cutlass_w4a8/get_group_starts.cuh @@ -0,0 +1,104 @@ +// see csrc/quantization/w8a8/cutlass/moe/get_group_starts.cuh +#pragma once + +#include +#include +#include + +#include "core/scalar_type.hpp" +#include "cutlass/bfloat16.h" +#include "cutlass/float8.h" + +// ElementB is int32 (packed int4) +// ElementGroupScale is cutlass::Array (packed fp8) +template +__global__ void get_group_gemm_starts( + int64_t* expert_offsets, ElementA** a_offsets, ElementB** b_offsets, + ElementC** out_offsets, ElementAccumulator** a_scales_offsets, + ElementAccumulator** b_scales_offsets, + ElementGroupScale** b_group_scales_offsets, ElementA* a_base_as_int, + ElementB* b_base_as_int, ElementC* out_base_as_int, + ElementAccumulator* a_scales_base_as_int, + ElementAccumulator* b_scales_base_as_int, + ElementGroupScale* b_group_scales_base_as_int, int64_t n, int64_t k, + int64_t scale_k) { + int expert_id = threadIdx.x; + + int64_t expert_offset = expert_offsets[expert_id]; + + // same as w8a8 + a_offsets[expert_id] = a_base_as_int + expert_offset * k; + out_offsets[expert_id] = out_base_as_int + expert_offset * n; + a_scales_offsets[expert_id] = a_scales_base_as_int + expert_offset; + b_scales_offsets[expert_id] = b_scales_base_as_int + (n * expert_id); + + // w4a8 specific + constexpr int pack_factor = 8; // pack 8 int4 into int32 + b_offsets[expert_id] = b_base_as_int + (expert_id * k * n / pack_factor); + b_group_scales_offsets[expert_id] = + b_group_scales_base_as_int + (expert_id * scale_k * n); +} + +#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \ + else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ + get_group_gemm_starts> \ + <<<1, num_experts, 0, stream>>>( \ + static_cast(expert_offsets.data_ptr()), \ + static_cast(a_ptrs.data_ptr()), \ + static_cast(b_ptrs.data_ptr()), \ + static_cast(out_ptrs.data_ptr()), \ + static_cast(a_scales_ptrs.data_ptr()), \ + static_cast(b_scales_ptrs.data_ptr()), \ + static_cast**>( \ + b_group_scales_ptrs.data_ptr()), \ + static_cast(a_tensors.data_ptr()), \ + static_cast(b_tensors.data_ptr()), \ + static_cast(out_tensors.data_ptr()), \ + static_cast(a_scales.data_ptr()), \ + static_cast(b_scales.data_ptr()), \ + static_cast*>( \ + b_group_scales.data_ptr()), \ + n, k, scale_k); \ + } + +namespace { + +void run_get_group_gemm_starts( + torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs, + torch::Tensor& b_ptrs, torch::Tensor& out_ptrs, + torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs, + torch::Tensor& b_group_scales_ptrs, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor& out_tensors, + torch::Tensor const& a_scales, torch::Tensor const& b_scales, + torch::Tensor const& b_group_scales, const int64_t b_group_size) { + TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b_tensors.dtype() == torch::kInt32); // int4 8x packed into int32 + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_group_scales.dtype() == + torch::kFloat8_e4m3fn); // the underlying torch type is e4m3 + TORCH_CHECK(out_tensors.dtype() == + torch::kBFloat16); // only support bf16 for now + // expect int64_t to avoid overflow during offset calculations + TORCH_CHECK(expert_offsets.dtype() == torch::kInt64); + + int num_experts = static_cast(expert_offsets.size(0)); + // logical k, n + int64_t n = out_tensors.size(1); + int64_t k = a_tensors.size(1); + int64_t scale_k = cutlass::ceil_div(k, b_group_size); + + auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); + + if (false) { + } + __CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t) + __CALL_GET_STARTS_KERNEL(torch::kFloat16, half) + else { + TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)"); + } +} + +} // namespace \ No newline at end of file diff --git a/csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu b/csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu new file mode 100644 index 0000000000000..4b425790dbac7 --- /dev/null +++ b/csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu @@ -0,0 +1,483 @@ +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/mixed_dtype_utils.hpp" + +// vllm includes +#include +#include +#include +#include "cutlass_extensions/torch_utils.hpp" +#include "cutlass_extensions/common.hpp" + +#include "core/registration.h" +#include "get_group_starts.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" +#include "w4a8_utils.cuh" + +namespace vllm::cutlass_w4a8_moe { + +using namespace cute; + +// ------------------------------------------------------------------------------------- +// Static configuration shared across all instantiations +// ------------------------------------------------------------------------------------- +using ProblemShape = + cutlass::gemm::GroupProblemShape>; // per + // group +using MmaType = cutlass::float_e4m3_t; +using QuantType = cutlass::int4b_t; + +constexpr int TileShapeK = 128 * 8 / sizeof_bits::value; +static int constexpr PackFactor = 8; // 8 int4 packed into int32 + +// A matrix configuration +using ElementA = MmaType; +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = + 128 / + cutlass::sizeof_bits::value; // Alignment of A matrix in units of + // elements (up to 16 bytes) + +// B matrix configuration +using ElementB = QuantType; // Element type for B matrix operand +using LayoutB = + cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = + 128 / cutlass::sizeof_bits< + ElementB>::value; // Memory access granularity/alignment of B + // matrix in units of elements (up to 16 bytes) + +// This example manually swaps and transposes, so keep transpose of input +// layouts +using LayoutA_Transpose = + typename cutlass::layout::LayoutTranspose::type; +using LayoutB_Transpose = + typename cutlass::layout::LayoutTranspose::type; + +// Need to pass a pointer type to make the 3rd dimension of Stride be _0 +using StrideA = + cute::remove_pointer_t>; +using StrideB = + cute::remove_pointer_t>; + +// Define the CuTe layout for reoredered quantized tensor B +// LayoutAtomQuant places values that will be read by the same thread in +// contiguous locations in global memory. It specifies the reordering within a +// single warp's fragment +using LayoutAtomQuant = + decltype(cutlass::compute_memory_reordering_atom()); +using LayoutB_Reordered = decltype(cute::tile_to_shape( + LayoutAtomQuant{}, Layout>, StrideB>{})); + +using ElementScale = cutlass::float_e4m3_t; +using LayoutScale = cutlass::layout::RowMajor; + +// C/D matrix configuration +using ElementC = + cutlass::bfloat16_t; // Element type for C and D matrix operands +using LayoutC = + cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = + 128 / cutlass::sizeof_bits< + ElementC>::value; // Memory access granularity/alignment of C + // matrix in units of elements (up to 16 bytes) + +// D matrix configuration +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using StageCountType = + cutlass::gemm::collective::StageCountAuto; // Stage count maximized based + // on the tile size + +// per-channel and per-token scales for epilogue +using ElementSChannel = float; + +template +struct W4A8GroupedGemmKernel { + using TileShape = + decltype(cute::append(TileShape_MN{}, cute::Int{})); + using ClusterShape = ClusterShape_MNK; + + // per-channel, per-token scales epilogue + using ChTokScalesEpilogue = + typename vllm::c3x::ScaledEpilogueArray; + using EVTCompute = typename ChTokScalesEpilogue::EVTCompute; + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, + ElementSChannel, ElementC, + typename cutlass::layout::LayoutTranspose::type*, AlignmentC, + ElementD, typename cutlass::layout::LayoutTranspose::type*, + AlignmentD, EpilogueSchedule, EVTCompute>::CollectiveOp; + + // =========================================================== MIXED INPUT + // WITH SCALES + // =========================================================================== + // The Scale information must get paired with the operand that will be scaled. + // In this example, B is scaled so we make a tuple of B's information and the + // scale information. + using CollectiveMainloopShuffled = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple>, + LayoutB_Reordered*, AlignmentB, ElementA, LayoutA_Transpose*, + AlignmentA, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, CollectiveMainloopShuffled, CollectiveEpilogue>; + + using GemmShuffled = + cutlass::gemm::device::GemmUniversalAdapter; + + using StrideC = typename GemmKernelShuffled::InternalStrideC; + using StrideD = typename GemmKernelShuffled::InternalStrideD; + + using StrideC_ref = cutlass::detail::TagToStrideC_t; + using StrideD_ref = cutlass::detail::TagToStrideC_t; + using StrideS = typename CollectiveMainloopShuffled::StrideScale; + using StrideS_ref = cutlass::detail::TagToStrideB_t; + + // static asserts for passing in strides/layouts + // pack to 2x int64 + static_assert(sizeof(StrideS) == 2 * sizeof(int64_t)); + // pack to 3xint32, + static_assert(sizeof(LayoutB_Reordered) % sizeof(int32_t) == 0, + "LayoutB_Reordered size must be divisible by 4 bytes"); + + static void grouped_mm( + torch::Tensor& out_tensors, const torch::Tensor& a_tensors, + const torch::Tensor& b_tensors, const torch::Tensor& a_scales, + const torch::Tensor& b_scales, const torch::Tensor& b_group_scales, + const int64_t b_group_size, const torch::Tensor& expert_offsets, + const torch::Tensor& problem_sizes_torch, const torch::Tensor& a_strides, + const torch::Tensor& b_strides, const torch::Tensor& c_strides, + const torch::Tensor& group_scale_strides) { + auto device = a_tensors.device(); + auto device_id = device.index(); + const at::cuda::OptionalCUDAGuard device_guard(device); + auto stream = at::cuda::getCurrentCUDAStream(device_id); + + int num_experts = static_cast(expert_offsets.size(0)); + int n = static_cast(b_tensors.size(1)); + int k = static_cast(b_tensors.size(2)) * PackFactor; + + auto options_int = + torch::TensorOptions().dtype(torch::kInt64).device(device); + torch::Tensor a_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_ptrs = torch::empty(num_experts, options_int); + torch::Tensor out_ptrs = torch::empty(num_experts, options_int); + torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_group_scales_ptrs = torch::empty(num_experts, options_int); + + // get the correct offsets to pass to gemm + run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs, + a_scales_ptrs, b_scales_ptrs, b_group_scales_ptrs, + a_tensors, b_tensors, out_tensors, a_scales, + b_scales, b_group_scales, b_group_size); + + // construct args + using Args = typename GemmShuffled::Arguments; + using MainloopArguments = typename GemmKernelShuffled::MainloopArguments; + using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments; + Args arguments; + + ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes = + static_cast( + problem_sizes_torch.data_ptr()); + ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr}; + + // SwapAB so B operands come first + MainloopArguments mainloop_arguments{ + static_cast(b_ptrs.data_ptr()), + static_cast(b_strides.data_ptr()), + static_cast(a_ptrs.data_ptr()), + static_cast(a_strides.data_ptr()), + static_cast**>( + b_group_scales_ptrs.data_ptr()), + static_cast(group_scale_strides.data_ptr()), + static_cast(b_group_size)}; + + EpilogueArguments epilogue_arguments{ + // since we are doing SwapAB the channel scales comes first, then token + // scales + ChTokScalesEpilogue::prepare_args( // see ScaledEpilogueArray + static_cast( + b_scales_ptrs.data_ptr()), // per-channel + static_cast( + a_scales_ptrs.data_ptr()), // per-token + true, true), + nullptr, // C + static_cast(c_strides.data_ptr()), // C + static_cast(out_ptrs.data_ptr()), // D + static_cast(c_strides.data_ptr()) // D + }; + + static const cutlass::KernelHardwareInfo hw_info{ + device_id, + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + device_id)}; + + arguments = Args{cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, + mainloop_arguments, epilogue_arguments, hw_info}; + + // Allocate workspace + size_t workspace_size = GemmShuffled::get_workspace_size(arguments); + torch::Tensor workspace = + torch::empty(workspace_size, + torch::TensorOptions().dtype(torch::kU8).device(device)); + + // Run GEMM + GemmShuffled gemm; + CUTLASS_CHECK(gemm.can_implement(arguments)); + CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream)); + CUTLASS_CHECK(gemm.run(stream)); + } +}; + +// ---------------------------------------------------------------------------- +// Kernel instantiations and dispatch logic +// ---------------------------------------------------------------------------- +using Coop = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; +using CoopEpi = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; + +// Kernel_TileShape_ClusterShape_Schedule +using Kernel_128x16_1x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_1, _1, _1>, Coop, CoopEpi>; +using Kernel_128x16_2x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_2, _1, _1>, Coop, CoopEpi>; + +using Kernel_256x16_1x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_1, _1, _1>, Coop, CoopEpi>; +using Kernel_256x16_2x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_2, _1, _1>, Coop, CoopEpi>; + +using Kernel_256x32_1x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_1, _1, _1>, Coop, CoopEpi>; +using Kernel_256x32_2x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_2, _1, _1>, Coop, CoopEpi>; + +using Kernel_256x64_1x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_1, _1, _1>, Coop, CoopEpi>; +using Kernel_256x64_2x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_2, _1, _1>, Coop, CoopEpi>; + +using Kernel_256x128_1x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_1, _1, _1>, Coop, CoopEpi>; +using Kernel_256x128_2x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_2, _1, _1>, Coop, CoopEpi>; + +using Kernel_128x256_2x1x1_Coop = + W4A8GroupedGemmKernel, Shape<_2, _1, _1>, Coop, CoopEpi>; + +void mm_dispatch( + torch::Tensor& out_tensors, const torch::Tensor& a_tensors, + const torch::Tensor& b_tensors, const torch::Tensor& a_scales, + const torch::Tensor& b_scales, const torch::Tensor& b_group_scales, + const int64_t b_group_size, const torch::Tensor& expert_offsets, + const torch::Tensor& problem_sizes, const torch::Tensor& a_strides, + const torch::Tensor& b_strides, const torch::Tensor& c_strides, + const torch::Tensor& group_scale_strides, const std::string& schedule) { + if (schedule == "Kernel_128x16_1x1x1_Coop") { + Kernel_128x16_1x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_128x16_2x1x1_Coop") { + Kernel_128x16_2x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_256x16_1x1x1_Coop") { + Kernel_256x16_1x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_256x16_2x1x1_Coop") { + Kernel_256x16_2x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_256x32_1x1x1_Coop") { + Kernel_256x32_1x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_256x32_2x1x1_Coop") { + Kernel_256x32_2x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_256x64_1x1x1_Coop") { + Kernel_256x64_1x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_256x64_2x1x1_Coop") { + Kernel_256x64_2x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_256x128_1x1x1_Coop") { + Kernel_256x128_1x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_256x128_2x1x1_Coop") { + Kernel_256x128_2x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else if (schedule == "Kernel_128x256_2x1x1_Coop") { + Kernel_128x256_2x1x1_Coop::grouped_mm( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, + b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, group_scale_strides); + } else { + TORCH_CHECK(false, + "cutlass_w4a8_moe_mm: unknown schedule string: ", schedule); + } +} + +void mm(torch::Tensor& out_tensors, const torch::Tensor& a_tensors, + const torch::Tensor& b_tensors, const torch::Tensor& a_scales, + const torch::Tensor& b_scales, const torch::Tensor& b_group_scales, + const int64_t b_group_size, const torch::Tensor& expert_offsets, + const torch::Tensor& problem_sizes, const torch::Tensor& a_strides, + const torch::Tensor& b_strides, const torch::Tensor& c_strides, + const torch::Tensor& group_scale_strides, + std::optional maybe_schedule) { + // user has specified a schedule + if (maybe_schedule) { + mm_dispatch(out_tensors, a_tensors, b_tensors, a_scales, b_scales, + b_group_scales, b_group_size, expert_offsets, problem_sizes, + a_strides, b_strides, c_strides, group_scale_strides, + *maybe_schedule); + return; + } + + // use heuristic + int m_full = a_tensors.size(0); + int n = b_tensors.size(1); + int k = b_tensors.size(2) * PackFactor; // logical k + int num_experts = b_tensors.size(0); + // per-expert batch size assuming uniform distribution + int m_expert = m_full / num_experts; + + std::string schedule; + if (m_expert <= 16) { + schedule = "Kernel_128x16_2x1x1_Coop"; + } else if (m_expert <= 32) { + schedule = "Kernel_256x32_1x1x1_Coop"; + } else if (m_expert <= 64) { + schedule = "Kernel_256x64_1x1x1_Coop"; + } else if (m_expert <= 128) { + schedule = "Kernel_256x128_2x1x1_Coop"; + } else { // m_expert > 128 + schedule = "Kernel_128x256_2x1x1_Coop"; + } + + mm_dispatch(out_tensors, a_tensors, b_tensors, a_scales, b_scales, + b_group_scales, b_group_size, expert_offsets, problem_sizes, + a_strides, b_strides, c_strides, group_scale_strides, schedule); +} + +std::tuple encode_and_reorder_int4b( + torch::Tensor const& b_tensors) { + TORCH_CHECK(b_tensors.dtype() == torch::kInt32); + TORCH_CHECK(b_tensors.dim() == 3); // (experts, n, k) + TORCH_CHECK(b_tensors.is_contiguous()); + TORCH_CHECK(b_tensors.is_cuda()); + + int n = static_cast(b_tensors.size(1)); + int k = static_cast(b_tensors.size(2)) * PackFactor; // logical k + + // CUTLASS reorder_tensor requires k % 256 == 0 and n % 16 == 0. + // These misalignments cause silent OOB unless run under Compute Sanitizer. + TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256"); + TORCH_CHECK(n % 16 == 0, "n must be divisible by 16"); + + // we will store the layout to an int32 tensor; + // this is the number of elements we need per layout + constexpr size_t layout_width = sizeof(LayoutB_Reordered) / sizeof(int32_t); + + torch::Tensor b_tensors_packed = torch::empty_like(b_tensors); + int num_experts = static_cast(b_tensors.size(0)); + + auto b_ptr = static_cast(b_tensors.const_data_ptr()); + auto b_packed_ptr = static_cast(b_tensors_packed.data_ptr()); + + // multiply by ull so result does not overflow int32 + size_t num_int4_elems = 1ull * num_experts * n * k; + bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(b_ptr, b_packed_ptr, + num_int4_elems); + TORCH_CHECK(ok, "unified_encode_int4b failed"); + + // construct the layout once; assumes each expert has the same layout + using LayoutType = LayoutB_Reordered; + std::vector layout_B_reordered_host(num_experts); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, Int<1>{}}); + auto shape_B = cute::make_shape(n, k, Int<1>{}); + auto layout_B = make_layout(shape_B, stride_B); + LayoutType layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B); + + // reorder weights for each expert + for (int i = 0; i < num_experts; i++) { + // since the storage type of int4b is 1 byte but one element is 4 bits + // we need to adjust the offset + int64_t offset = + 1ull * i * n * k * cutlass::sizeof_bits::value / 8; + cutlass::reorder_tensor(b_packed_ptr + offset, layout_B, + layout_B_reordered); + } + + // save the packed layout to torch tensor so we can re-use it + auto cpu_opts = + torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU); + torch::Tensor layout_cpu = + torch::empty({num_experts, layout_width}, cpu_opts); + + int32_t* layout_data = layout_cpu.data_ptr(); + for (int i = 0; i < num_experts; ++i) { + std::memcpy(layout_data + i * layout_width, // dst (int32*) + &layout_B_reordered, // src (LayoutType*) + sizeof(LayoutType)); // number of bytes + } + + torch::Tensor packed_layout = + layout_cpu.to(b_tensors.device(), /*non_blocking=*/false); + + return {b_tensors_packed, packed_layout}; +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("cutlass_w4a8_moe_mm", &mm); + m.impl("cutlass_encode_and_reorder_int4b_grouped", &encode_and_reorder_int4b); +} + +} // namespace vllm::cutlass_w4a8_moe +///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu b/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu index 2d1568b08651c..f77af06cd6c08 100644 --- a/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu +++ b/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu @@ -7,6 +7,7 @@ #include #include #include "cutlass_extensions/torch_utils.hpp" +#include "w4a8_utils.cuh" #include "core/registration.h" @@ -395,71 +396,6 @@ torch::Tensor pack_scale_fp8(torch::Tensor const& scales) { return packed_scales; } -/* - GPU-accelerated implementation of cutlass::unified_encode_int4b. - Constructs a lookup table in constant memory to map 8 bits - (two 4-bit values) at a time. Assumes memory is contiguous - and pointers are 16-byte aligned. -*/ -__constant__ uint8_t kNibbleLUT[256]; - -__global__ void unified_encode_int4b_device(const uint8_t* in, uint8_t* out, - size_t nbytes) { - constexpr size_t V = sizeof(uint4); // 16 bytes - const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; - const size_t nthreads = size_t(gridDim.x) * blockDim.x; - const size_t nvec = nbytes / V; - - // 1-D grid-stride loop over 16-byte chunks - for (size_t vec = tid; vec < nvec; vec += nthreads) { - uint4 v = reinterpret_cast(in)[vec]; - uint8_t* b = reinterpret_cast(&v); -#pragma unroll - for (int i = 0; i < int(V); ++i) b[i] = kNibbleLUT[b[i]]; - reinterpret_cast(out)[vec] = v; - } -} - -static bool upload_lut() { - std::array lut{}; - auto map_nib = [](uint8_t v) -> uint8_t { - // 1..7 -> (8 - v); keep 0 and 8..15 - return (v == 0 || (v & 0x8)) ? v : uint8_t(8 - v); - }; - for (int b = 0; b < 256; ++b) { - uint8_t lo = b & 0xF; - uint8_t hi = (b >> 4) & 0xF; - lut[b] = uint8_t((map_nib(hi) << 4) | map_nib(lo)); - } - cudaError_t e = cudaMemcpyToSymbol(kNibbleLUT, lut.data(), lut.size(), - /*offset=*/0, cudaMemcpyHostToDevice); - - return (e == cudaSuccess); -} - -static bool unified_encode_int4b(cutlass::int4b_t const* in, - cutlass::int4b_t* out, size_t num_int4_elems) { - // Build/upload LUT - if (!upload_lut()) return false; - - static_assert(sizeof(typename cutlass::int4b_t::Storage) == 1, - "int4 storage must be 1 byte"); - const size_t nbytes = num_int4_elems >> 1; - - auto* in_bytes = reinterpret_cast(in); - auto* out_bytes = reinterpret_cast(out); - - // kernel launch params - constexpr int block = 256; - const size_t nvec = nbytes / sizeof(uint4); // # of 16B vectors - int grid = int((nvec + block - 1) / block); - if (grid == 0) grid = 1; // ensure we still cover the tail in the kernel - - unified_encode_int4b_device<<>>(in_bytes, out_bytes, nbytes); - cudaError_t err = cudaGetLastError(); - return (err == cudaSuccess); -} - torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) { TORCH_CHECK(B.dtype() == torch::kInt32); TORCH_CHECK(B.dim() == 2); @@ -477,8 +413,8 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) { LayoutB_Reordered layout_B_reordered = cute::tile_to_shape(LayoutAtomQuant{}, shape_B); - bool ok = - vllm::cutlass_w4a8::unified_encode_int4b(B_ptr, B_packed_ptr, n * k); + bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(B_ptr, B_packed_ptr, + n * k); TORCH_CHECK(ok, "unified_encode_int4b failed"); cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered); diff --git a/csrc/quantization/cutlass_w4a8/w4a8_utils.cu b/csrc/quantization/cutlass_w4a8/w4a8_utils.cu new file mode 100644 index 0000000000000..f238d0a5b2d78 --- /dev/null +++ b/csrc/quantization/cutlass_w4a8/w4a8_utils.cu @@ -0,0 +1,90 @@ +#include "w4a8_utils.cuh" + +#include +#include +#include + +namespace vllm::cutlass_w4a8_utils { + +/* + GPU-accelerated implementation of cutlass::unified_encode_int4b. + Constructs a lookup table in constant memory to map 8 bits + (two 4-bit values) at a time. Assumes memory is contiguous + and pointers are 16-byte aligned. +*/ +__constant__ uint8_t kNibbleLUT[256]; + +__global__ void unified_encode_int4b_device(const uint8_t* in, uint8_t* out, + size_t nbytes) { + constexpr size_t V = sizeof(uint4); // 16 bytes + const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const size_t nthreads = size_t(gridDim.x) * blockDim.x; + const size_t nvec = nbytes / V; + + // 1-D grid-stride loop over 16-byte chunks + for (size_t vec = tid; vec < nvec; vec += nthreads) { + uint4 v = reinterpret_cast(in)[vec]; + uint8_t* b = reinterpret_cast(&v); +#pragma unroll + for (int i = 0; i < int(V); ++i) b[i] = kNibbleLUT[b[i]]; + reinterpret_cast(out)[vec] = v; + } +} + +static bool upload_lut() { + std::array lut{}; + auto map_nib = [](uint8_t v) -> uint8_t { + // 1..7 -> (8 - v); keep 0 and 8..15 + return (v == 0 || (v & 0x8)) ? v : uint8_t(8 - v); + }; + for (int b = 0; b < 256; ++b) { + uint8_t lo = b & 0xF; + uint8_t hi = (b >> 4) & 0xF; + lut[b] = uint8_t((map_nib(hi) << 4) | map_nib(lo)); + } + cudaError_t e = cudaMemcpyToSymbol(kNibbleLUT, lut.data(), lut.size(), + /*offset=*/0, cudaMemcpyHostToDevice); + + return (e == cudaSuccess); +} + +bool unified_encode_int4b(cutlass::int4b_t const* in, cutlass::int4b_t* out, + size_t num_int4_elems) { + // Build/upload LUT + if (!upload_lut()) return false; + + static_assert(sizeof(typename cutlass::int4b_t::Storage) == 1, + "int4 storage must be 1 byte"); + const size_t nbytes = num_int4_elems >> 1; + + auto* in_bytes = reinterpret_cast(in); + auto* out_bytes = reinterpret_cast(out); + + // kernel launch params + constexpr int block = 256; + const size_t nvec = nbytes / sizeof(uint4); // # of 16B vectors + int grid = int((nvec + block - 1) / block); + if (grid == 0) grid = 1; // ensure we still cover the tail in the kernel + + unified_encode_int4b_device<<>>(in_bytes, out_bytes, nbytes); + + // launch errors + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("unified_encode_int4b_device launch error: %s (%d)\n", + cudaGetErrorString(err), err); + return false; + } + + // runtime errors + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + printf("unified_encode_int4b_device runtime error: %s (%d)\n", + cudaGetErrorString(err), err); + return false; + } + + return true; +} + +} // namespace vllm::cutlass_w4a8_utils \ No newline at end of file diff --git a/csrc/quantization/cutlass_w4a8/w4a8_utils.cuh b/csrc/quantization/cutlass_w4a8/w4a8_utils.cuh new file mode 100644 index 0000000000000..25090091a368d --- /dev/null +++ b/csrc/quantization/cutlass_w4a8/w4a8_utils.cuh @@ -0,0 +1,11 @@ +#pragma once + +#include +#include "cutlass/numeric_types.h" + +namespace vllm::cutlass_w4a8_utils { + +bool unified_encode_int4b(cutlass::int4b_t const* in, cutlass::int4b_t* out, + size_t num_int4_elems); + +} // namespace vllm::cutlass_w4a8_utils \ No newline at end of file diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu index 9cba2828aac2e..d9c4d24d8e1f2 100644 --- a/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu +++ b/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu @@ -15,6 +15,8 @@ */ #include +#include +#include "cutlass_extensions/common.hpp" #if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A, @@ -32,23 +34,34 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A, torch::Tensor const& alpha); #endif -void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A, - torch::Tensor const& B, torch::Tensor const& A_sf, - torch::Tensor const& B_sf, - torch::Tensor const& alpha) { -#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 - return cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha); -#elif defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120 - return cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha); +void cutlass_scaled_fp4_mm(torch::Tensor& D, const torch::Tensor& A, + const torch::Tensor& B, const torch::Tensor& A_sf, + const torch::Tensor& B_sf, + const torch::Tensor& alpha) { + // Make sure we’re on A’s device. + const c10::cuda::OptionalCUDAGuard device_guard(device_of(A)); + const int32_t sm = get_sm_version_num(); + +#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100 + if (sm >= 100 && sm < 120) { + cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha); + return; + } #endif - TORCH_CHECK_NOT_IMPLEMENTED(false, - "No compiled nvfp4 mm kernel, vLLM should " - "be compiled using CUDA 12.8 and target " - "compute capability 100 or above."); + +#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120 + if (sm >= 120 && sm < 130) { + cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha); + return; + } +#endif + + TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel for SM ", sm, + ". Recompile with CUDA >= 12.8 and CC >= 100."); } bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) { int runtimeVersion; cudaRuntimeGetVersion(&runtimeVersion); return cuda_device_capability >= 100 && runtimeVersion >= 12080; -} \ No newline at end of file +} diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index 92d6c2f402a24..2080ef3cd39b5 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -31,14 +31,15 @@ __device__ void rms_norm_dynamic_per_token_quant_vec( // RMS Norm + Quant if constexpr (std::is_same_v) { + token_scale = 1.0f / token_scale; vllm::vectorized::norm_and_quant( - out, input, weight, rms, 1.0f / token_scale, hidden_size, residual); + out, input, weight, rms, &token_scale, hidden_size, residual); } else { // FP8 - Do not invert token_scale for exact match with FBGemm vllm::vectorized::norm_and_quant( - out, input, weight, rms, token_scale, hidden_size, residual); + out, input, weight, rms, &token_scale, hidden_size, residual); } } @@ -75,14 +76,52 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel( // RMS Norm + Quant if constexpr (std::is_same_v) { + token_scale = 1.0f / token_scale; vllm::norm_and_quant( - out, input, weight, rms, 1.0f / token_scale, hidden_size, residual); + out, input, weight, rms, &token_scale, hidden_size, residual); } else { // FP8 - Do not invert s_token_scale for exact match with FBGemm vllm::norm_and_quant( - out, input, weight, rms, token_scale, hidden_size, residual); + out, input, weight, rms, &token_scale, hidden_size, residual); } } + +// RMS norm + quant kernel +template +__global__ void rms_norm_per_block_quant_kernel( + scalar_out_t* __restrict__ out, // [..., hidden_size] + float* __restrict__ scales, // [num_tokens, hidden_size / group_size] + // or + // [hidden_size / group_size, num_tokens] + scalar_t const* __restrict__ input, // [..., hidden_size] + scalar_t const* __restrict__ weight, // [hidden_size] + float const* scale_ub, float const var_epsilon, int32_t const hidden_size, + scalar_t* __restrict__ residual = nullptr) { + float rms; + // Compute RMS + // Always able to vectorize due to constraints on hidden_size + vllm::vectorized::compute_rms( + &rms, input, hidden_size, var_epsilon, residual); + + // Compute Scale + // Always able to vectorize due to constraints on hidden_size and group_size + vllm::vectorized::compute_dynamic_per_token_scales< + scalar_t, scalar_out_t, has_residual, is_scale_transposed, group_size>( + nullptr, scales, input, weight, rms, scale_ub, hidden_size, residual); + + // RMS Norm + Quant + // Always able to vectorize due to constraints on hidden_size + // For int8, don't invert token_scale here: do it inside the norm_and_quant + // kernel. We do it because particular elements of token_scale can be shared + // between multiple threads, so this way, we avoid extra synchronization + // overhead. + vllm::vectorized::norm_and_quant< + scalar_t, scalar_out_t, std::is_same_v, + has_residual, is_scale_transposed, group_size>( + out, input, weight, rms, scales, hidden_size, residual); +} + } // namespace vllm // Residual add + RMS norm + dynamic per token @@ -103,30 +142,19 @@ void rms_norm_dynamic_per_token_quant_dispatch( const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (residual.has_value()) { + VLLM_DISPATCH_BOOL(residual.has_value(), has_residual, [&] { VLLM_DISPATCH_QUANT_TYPES( out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] { vllm::rms_norm_dynamic_per_token_quant_kernel + has_residual> <<>>( out.data_ptr(), scales.data_ptr(), input.data_ptr(), weight.data_ptr(), scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, - var_epsilon, hidden_size, residual->data_ptr()); + var_epsilon, hidden_size, + has_residual ? residual->data_ptr() : nullptr); }); - - } else { - VLLM_DISPATCH_QUANT_TYPES( - out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] { - vllm::rms_norm_dynamic_per_token_quant_kernel - <<>>( - out.data_ptr(), scales.data_ptr(), - input.data_ptr(), weight.data_ptr(), - scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, - var_epsilon, hidden_size, nullptr); - }); - } + }); } void rms_norm_dynamic_per_token_quant( @@ -157,3 +185,79 @@ void rms_norm_dynamic_per_token_quant( out, input, weight, scales, var_epsilon, scale_ub, residual); }); } + +// Residual add + RMS norm + dynamic per token +void rms_norm_per_block_quant_dispatch( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor const& input, // [..., hidden_size] + torch::Tensor const& weight, // [hidden_size] + torch::Tensor& scales, // [num_tokens, hidden_size / group_size] or + // [hidden_size / group_size, num_tokens] + int32_t group_size, + double const var_epsilon, // Variance epsilon used in norm calculation + std::optional const& scale_ub, + std::optional& residual, bool is_scale_transposed) { + int32_t hidden_size = input.size(-1); + auto num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + const int max_block_size = (num_tokens <= 256) ? 512 : 256; + dim3 block(std::min(hidden_size, max_block_size)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "rms_norm_per_block_quant_fp_dispatch", [&] { + using scalar_in_t = scalar_t; + VLLM_DISPATCH_GROUP_SIZE(group_size, gs, [&] { + VLLM_DISPATCH_BOOL(residual.has_value(), has_residual, [&] { + VLLM_DISPATCH_BOOL(is_scale_transposed, transpose_scale, [&] { + VLLM_DISPATCH_QUANT_TYPES( + out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] { + vllm::rms_norm_per_block_quant_kernel + <<>>( + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), + weight.data_ptr(), + scale_ub.has_value() ? scale_ub->data_ptr() + : nullptr, + var_epsilon, hidden_size, + has_residual ? residual->data_ptr() + : nullptr); + }); + }); + }); + }); + }); +} + +void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, + torch::Tensor const& weight, + torch::Tensor& scales, double const var_epsilon, + std::optional scale_ub, + std::optional residual, + int64_t group_size, bool is_scale_transposed) { + static c10::ScalarType kFp8Type = is_fp8_ocp() + ? c10::ScalarType::Float8_e4m3fn + : c10::ScalarType::Float8_e4m3fnuz; + TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8); + TORCH_CHECK(out.is_contiguous() && input.is_contiguous()); + + if (scale_ub.has_value()) { + TORCH_CHECK(out.dtype() == kFp8Type); + } + TORCH_CHECK(weight.dtype() == input.dtype()); + TORCH_CHECK(scales.dtype() == torch::kFloat32); + if (residual) { + TORCH_CHECK(residual->scalar_type() == input.scalar_type()); + } + + TORCH_CHECK(group_size == 128 || group_size == 64, + "Unsupported group size: ", group_size); + + rms_norm_per_block_quant_dispatch(out, input, weight, scales, group_size, + var_epsilon, scale_ub, residual, + is_scale_transposed); +} \ No newline at end of file diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index 2d2fd771205c7..cb7adc3125734 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -9,6 +9,7 @@ #include "quant_conversions.cuh" #include "../../cub_helpers.h" +#include "../../cuda_compat.h" namespace vllm { @@ -43,62 +44,150 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, *rms = s_rms; } -template +__device__ float warpReduceMaxSpecialized(volatile float* val, int64_t tid, + int64_t thread_in_warp, + int64_t reduced_elems) { + static_assert(WARP_SIZE == 32 || WARP_SIZE == 64); + if constexpr (WARP_SIZE == 64) { + if (thread_in_warp + 64 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 64]); + } + if (thread_in_warp + 32 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 32]); + if (thread_in_warp + 16 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 16]); + if (thread_in_warp + 8 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 8]); + if (thread_in_warp + 4 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 4]); + if (thread_in_warp + 2 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 2]); + if (thread_in_warp + 1 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 1]); + return val[tid]; +} + +template __device__ void compute_dynamic_per_token_scales( float* __restrict__ token_scale, float* __restrict__ all_token_scales, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, float const rms, float const* __restrict__ scale_ub, - int32_t const hidden_size, - scalar_t const* __restrict__ residual = nullptr) { - int64_t const token_offset = blockIdx.x * static_cast(hidden_size); - ; - constexpr scalar_out_t qmax{quant_type_max_v}; - + int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr, + int32_t const group_size = 0) { float block_absmax_val_maybe = 0.0f; - for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) { - float x = static_cast(input[token_offset + i]); - if constexpr (has_residual) { - x += static_cast(residual[token_offset + i]); - } - - x = static_cast(static_cast(x * rms) * weight[i]); - block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x)); - } - - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage reduceStore; - block_absmax_val_maybe = - BlockReduce(reduceStore) - .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); - - __shared__ float s_token_scale; - if (threadIdx.x == 0) { - float scale = 0.0f; - if (scale_ub) { - scale = min(block_absmax_val_maybe, *scale_ub); - } else { - scale = block_absmax_val_maybe; - } - // token scale computation - scale = max(scale / qmax, min_scaling_factor::val()); - s_token_scale = scale; // Shared memory store - all_token_scales[blockIdx.x] = scale; // Global output store - } + constexpr scalar_out_t qmax{quant_type_max_v}; __syncthreads(); + if (group_size > 0) { + __shared__ float s_max_vals[1024]; + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + int64_t num_groups = hidden_size / group_size; + int64_t const threads_per_group = blockDim.x / num_groups; + int64_t const thread_in_group = threadIdx.x % threads_per_group; + int64_t const group_offset = threadIdx.x / threads_per_group * group_size; + int64_t const thread_offset = group_offset + thread_in_group; + int64_t const thread_end = + min(group_offset + group_size, static_cast(hidden_size)); + for (auto i = thread_offset; i < thread_end; i += threads_per_group) { + float x = static_cast(input[token_offset + i]); + if constexpr (has_residual) { + x += static_cast(residual[token_offset + i]); + } + x = static_cast(static_cast(x * rms) * weight[i]); + block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x)); + } + s_max_vals[threadIdx.x] = block_absmax_val_maybe; + __syncthreads(); - *token_scale = s_token_scale; + int64_t const warp_size = WARP_SIZE; + int64_t const num_warps = blockDim.x / warp_size; + int64_t const warp_id = threadIdx.x / warp_size; + int64_t const thread_in_warp = threadIdx.x % warp_size; + int64_t const groups_per_warp = (num_groups + num_warps - 1) / num_warps; + for (auto i = 0; i < groups_per_warp; ++i) { + int64_t const group_id = i * num_warps + warp_id; + if (group_id < num_groups) { + int64_t warp_start = group_id * threads_per_group; + int64_t const start = warp_start + thread_in_warp; + int64_t const warp_end = min(warp_start + threads_per_group, + static_cast(hidden_size)); + for (auto j = start; j + warp_size < warp_end; j += warp_size) { + s_max_vals[start] = + fmaxf(s_max_vals[start], s_max_vals[j + warp_size]); + } + warpReduceMaxSpecialized(s_max_vals, start, thread_in_warp, + min(warp_end - warp_start, warp_size)); + } + } + __syncthreads(); + + if (thread_in_group == 0 && thread_offset < thread_end) { + block_absmax_val_maybe = s_max_vals[threadIdx.x]; + float scale = 0.0f; + if (scale_ub) { + scale = min(block_absmax_val_maybe, *scale_ub); + } else { + scale = block_absmax_val_maybe; + } + // token scale computation + scale = max(scale / qmax, min_scaling_factor::val()); + // Global output store + if constexpr (is_scale_transposed) { + all_token_scales[(threadIdx.x / threads_per_group) * gridDim.x + + blockIdx.x] = scale; + } else { + all_token_scales[blockIdx.x * num_groups + + threadIdx.x / threads_per_group] = scale; + } + } + __syncthreads(); + } else { + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + + for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float x = static_cast(input[token_offset + i]); + if constexpr (has_residual) { + x += static_cast(residual[token_offset + i]); + } + + x = static_cast(static_cast(x * rms) * weight[i]); + block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x)); + } + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + block_absmax_val_maybe = + BlockReduce(reduceStore) + .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); + + __shared__ float s_token_scale; + if (threadIdx.x == 0) { + float scale = 0.0f; + if (scale_ub) { + scale = min(block_absmax_val_maybe, *scale_ub); + } else { + scale = block_absmax_val_maybe; + } + // token scale computation + scale = max(scale / qmax, min_scaling_factor::val()); + s_token_scale = scale; // Shared memory store + all_token_scales[blockIdx.x] = scale; // Global output store + } + __syncthreads(); + + *token_scale = s_token_scale; + } } template + bool has_residual = false, bool is_scale_transposed = false> __device__ void norm_and_quant(scalar_out_t* __restrict__ output, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, - float const rms, float const scale, + float const rms, float* const scale, int32_t const hidden_size, - scalar_t* __restrict__ residual = nullptr) { + scalar_t* __restrict__ residual = nullptr, + int32_t const group_size = 0) { int64_t const token_offset = blockIdx.x * static_cast(hidden_size); - ; for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) { float x = static_cast(input[token_offset + i]); @@ -109,8 +198,21 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output, // Norm x = static_cast(static_cast(x * rms) * weight[i]); // Quant + // If groupwise is_scale_inverted is true, so we invert the scale here. + int64_t scale_idx = 0; + if (group_size > 0) { + if constexpr (is_scale_transposed) { + scale_idx = (i / group_size) * gridDim.x + blockIdx.x; + } else { + scale_idx = blockIdx.x * (hidden_size / group_size) + i / group_size; + } + } + auto scale_val = + (group_size > 0 + ? (is_scale_inverted ? 1.0f / scale[scale_idx] : scale[scale_idx]) + : *scale); output[token_offset + i] = - ScaledQuant::quant_fn(x, scale); + ScaledQuant::quant_fn(x, scale_val); } } @@ -178,95 +280,191 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, // Vectorized version of vllm::compute_dynamic_per_token_scales // hidden_size must be a multiple of 4 -template +template __device__ void compute_dynamic_per_token_scales( float* __restrict__ token_scale, float* __restrict__ all_token_scales, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, float const rms, float const* __restrict__ scale_ub, int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr) { - int64_t const token_offset = blockIdx.x * static_cast(hidden_size); - ; - - // Vectorized input/weight/residual to better utilize memory bandwidth. - vec4_t const* vec_input = - reinterpret_cast const*>(&input[token_offset]); - vec4_t const* vec_weight = - reinterpret_cast const*>(weight); - vec4_t const* vec_residual = nullptr; - if constexpr (has_residual) { - vec_residual = - reinterpret_cast const*>(&residual[token_offset]); - } - constexpr scalar_out_t qmax{quant_type_max_v}; const int VEC_SIZE = 4; - int32_t const num_vec_elems = hidden_size >> 2; float block_absmax_val_maybe = 0.0f; -#pragma unroll 4 - for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) { - vec4_t in = vec_input[i]; - vec4_t const w = vec_weight[i]; + // Vectorized input/weight/residual to better utilize memory bandwidth. + vec4_t const* vec_input = nullptr; + vec4_t const* vec_weight = nullptr; + vec4_t const* vec_residual = nullptr; - vec4_t x; -#pragma unroll - for (int j = 0; j < VEC_SIZE; ++j) { - x.val[j] = static_cast(in.val[j]); - } + if constexpr (group_size > 0) { + __shared__ float s_max_vals[1024]; + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + int64_t const num_groups = hidden_size / group_size; + int64_t const threads_per_group = blockDim.x / num_groups; + int64_t const thread_in_group = threadIdx.x % threads_per_group; + int64_t const group_offset = + threadIdx.x / threads_per_group * (group_size >> 2); + int64_t const thread_offset = group_offset + thread_in_group; + int64_t const thread_end = min(group_offset + (group_size >> 2), + static_cast(hidden_size >> 2)); + vec_input = reinterpret_cast const*>(&input[token_offset]); + vec_weight = reinterpret_cast const*>(weight); if constexpr (has_residual) { - vec4_t r = vec_residual[i]; + vec_residual = + reinterpret_cast const*>(&residual[token_offset]); + } + int32_t const num_vec_elems = thread_end; + +#pragma unroll 4 + for (auto i = thread_offset; i < num_vec_elems; i += threads_per_group) { + vec4_t in = vec_input[i]; + vec4_t const w = vec_weight[i]; + + vec4_t x; #pragma unroll for (int j = 0; j < VEC_SIZE; ++j) { - x.val[j] += static_cast(r.val[j]); + x.val[j] = static_cast(in.val[j]); + } + + if constexpr (has_residual) { + vec4_t r = vec_residual[i]; +#pragma unroll + for (int j = 0; j < VEC_SIZE; ++j) { + x.val[j] += static_cast(r.val[j]); + } + } + +#pragma unroll + for (int j = 0; j < VEC_SIZE; ++j) { + block_absmax_val_maybe = + fmaxf(block_absmax_val_maybe, + fabs(static_cast(x.val[j] * rms) * w.val[j])); } } + s_max_vals[threadIdx.x] = block_absmax_val_maybe; + __syncthreads(); + + int64_t const warp_size = WARP_SIZE; + int64_t const num_warps = blockDim.x / warp_size; + int64_t const warp_id = threadIdx.x / warp_size; + int64_t const thread_in_warp = threadIdx.x % warp_size; + int64_t const groups_per_warp = (num_groups + num_warps - 1) / num_warps; + for (auto i = 0; i < groups_per_warp; ++i) { + int64_t const group_id = i * num_warps + warp_id; + if (group_id < num_groups) { + int64_t warp_start = group_id * threads_per_group; + int64_t const start = warp_start + thread_in_warp; + int64_t const warp_end = min(warp_start + threads_per_group, + static_cast(hidden_size)); + for (auto j = start; j + warp_size < warp_end; j += warp_size) { + s_max_vals[start] = + fmaxf(s_max_vals[start], s_max_vals[j + warp_size]); + } + warpReduceMaxSpecialized(s_max_vals, start, thread_in_warp, + min(warp_end - warp_start, warp_size)); + } + } + __syncthreads(); + + if (thread_in_group == 0 && thread_offset < thread_end) { + block_absmax_val_maybe = s_max_vals[threadIdx.x]; + float scale = 0.0f; + if (scale_ub) { + scale = min(block_absmax_val_maybe, *scale_ub); + } else { + scale = block_absmax_val_maybe; + } + // token scale computation + scale = max(scale / qmax, min_scaling_factor::val()); + // Global output store + if constexpr (is_scale_transposed) { + all_token_scales[(threadIdx.x / threads_per_group) * gridDim.x + + blockIdx.x] = scale; + } else { + all_token_scales[blockIdx.x * num_groups + + threadIdx.x / threads_per_group] = scale; + } + } + __syncthreads(); + + } else { + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + vec_input = reinterpret_cast const*>(&input[token_offset]); + vec_weight = reinterpret_cast const*>(weight); + if constexpr (has_residual) { + vec_residual = + reinterpret_cast const*>(&residual[token_offset]); + } + + int32_t const num_vec_elems = (hidden_size >> 2); + +#pragma unroll 4 + for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) { + vec4_t in = vec_input[i]; + vec4_t const w = vec_weight[i]; + + vec4_t x; #pragma unroll - for (int j = 0; j < VEC_SIZE; ++j) { - block_absmax_val_maybe = - fmaxf(block_absmax_val_maybe, - fabs(static_cast(x.val[j] * rms) * w.val[j])); + for (int j = 0; j < VEC_SIZE; ++j) { + x.val[j] = static_cast(in.val[j]); + } + + if constexpr (has_residual) { + vec4_t r = vec_residual[i]; +#pragma unroll + for (int j = 0; j < VEC_SIZE; ++j) { + x.val[j] += static_cast(r.val[j]); + } + } + +#pragma unroll + for (int j = 0; j < VEC_SIZE; ++j) { + block_absmax_val_maybe = + fmaxf(block_absmax_val_maybe, + fabs(static_cast(x.val[j] * rms) * w.val[j])); + } } - } - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage reduceStore; - block_absmax_val_maybe = - BlockReduce(reduceStore) - .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + block_absmax_val_maybe = + BlockReduce(reduceStore) + .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); - __shared__ float s_token_scale; - if (threadIdx.x == 0) { - float scale = 0.0f; - if (scale_ub) { - scale = min(block_absmax_val_maybe, *scale_ub); - } else { - scale = block_absmax_val_maybe; + __shared__ float s_token_scale; + if (threadIdx.x == 0) { + float scale = 0.0f; + if (scale_ub) { + scale = min(block_absmax_val_maybe, *scale_ub); + } else { + scale = block_absmax_val_maybe; + } + // token scale computation + scale = max(scale / qmax, min_scaling_factor::val()); + s_token_scale = scale; // shared memory store + all_token_scales[blockIdx.x] = scale; // global output store } - // token scale computation - scale = max(scale / qmax, min_scaling_factor::val()); - s_token_scale = scale; // shared memory store - all_token_scales[blockIdx.x] = scale; // global output store - } - __syncthreads(); + __syncthreads(); - *token_scale = s_token_scale; + *token_scale = s_token_scale; + } } // hidden_size must be a multiple of 4 template + bool has_residual = false, bool is_scale_transposed = false, + int32_t group_size = 0> __device__ void norm_and_quant(scalar_out_t* __restrict__ output, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, - float const rms, float const scale, + float const rms, float* const scale, int32_t const hidden_size, scalar_t* __restrict__ residual = nullptr) { int64_t const token_offset = blockIdx.x * static_cast(hidden_size); - ; // Vectorized input/output/weight/residual to better utilize memory bandwidth. vec4_t const* vec_input = @@ -311,10 +509,26 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output, } q8x4_t out; + + float scale_val; + + if constexpr (group_size > 0) { + int64_t const num_groups = hidden_size / group_size; + int64_t scale_idx = 0; + if constexpr (is_scale_transposed) { + scale_idx = (i * VEC_SIZE / group_size) * gridDim.x + blockIdx.x; + } else { + scale_idx = blockIdx.x * num_groups + i * VEC_SIZE / group_size; + } + scale_val = + is_scale_inverted ? 1.0f / scale[scale_idx] : scale[scale_idx]; + } else { + scale_val = *scale; + } #pragma unroll for (int j = 0; j < VEC_SIZE; ++j) { out.val[j] = ScaledQuant::quant_fn( - static_cast(x.val[j] * rms) * w.val[j], scale); + static_cast(x.val[j] * rms) * w.val[j], scale_val); } vec_output[i] = out; } diff --git a/csrc/quantization/machete/machete_mainloop.cuh b/csrc/quantization/machete/machete_mainloop.cuh index 2f52a6b7a0246..9f02f4f179741 100644 --- a/csrc/quantization/machete/machete_mainloop.cuh +++ b/csrc/quantization/machete/machete_mainloop.cuh @@ -617,7 +617,7 @@ struct MacheteCollectiveMma { // Same as upstream, should be kept the same when possible, not formatted for // easier comparison - // with `SwapAB ? N : M -> M` since we dont support SwapAB + // with `SwapAB ? N : M -> M` since we don't support SwapAB // clang-format off template static bool diff --git a/csrc/quantization/w8a8/cutlass/moe/moe_data.cu b/csrc/quantization/w8a8/cutlass/moe/moe_data.cu index 49cafcc32adc6..99fec8fd6febc 100644 --- a/csrc/quantization/w8a8/cutlass/moe/moe_data.cu +++ b/csrc/quantization/w8a8/cutlass/moe/moe_data.cu @@ -136,15 +136,17 @@ inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids, void get_cutlass_moe_mm_problem_sizes_caller( const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, - const int64_t k, const std::optional& blockscale_offsets) { + const int64_t k, const std::optional& blockscale_offsets, + std::optional force_swap_ab = std::nullopt) { auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index()); auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device()); torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); // Swap-AB should be disabled for FP4 path - bool may_swap_ab = (!blockscale_offsets.has_value()) && - (topk_ids.numel() <= SWAP_AB_THRESHOLD); + bool may_swap_ab = + force_swap_ab.value_or((!blockscale_offsets.has_value()) && + (topk_ids.numel() <= SWAP_AB_THRESHOLD)); launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2, atomic_buffer, num_experts, n, k, stream, diff --git a/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu index c5012a8669317..5de21cfbbaafb 100644 --- a/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu +++ b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu @@ -80,7 +80,8 @@ void get_cutlass_moe_mm_data_caller( void get_cutlass_moe_mm_problem_sizes_caller( const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, - const int64_t k, const std::optional& blockscale_offsets); + const int64_t k, const std::optional& blockscale_offsets, + std::optional force_swap_ab = std::nullopt); void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, @@ -303,14 +304,15 @@ void get_cutlass_moe_mm_data( void get_cutlass_moe_mm_problem_sizes( const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, - const int64_t k, const std::optional& blockscale_offsets) { + const int64_t k, const std::optional& blockscale_offsets, + std::optional force_swap_ab = std::nullopt) { int32_t version_num = get_sm_version_num(); #if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \ (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120) get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k, - blockscale_offsets); + blockscale_offsets, force_swap_ab); return; #endif TORCH_CHECK_NOT_IMPLEMENTED( diff --git a/csrc/quantization/w8a8/fp8/per_token_group_quant.cu b/csrc/quantization/w8a8/fp8/per_token_group_quant.cu index e3ab0676b254e..49d1b2086b8db 100644 --- a/csrc/quantization/w8a8/fp8/per_token_group_quant.cu +++ b/csrc/quantization/w8a8/fp8/per_token_group_quant.cu @@ -22,6 +22,62 @@ __device__ __forceinline__ float GroupReduceMax(float val) { return val; } +template +__device__ __forceinline__ float ComputeGroupScale( + const T* __restrict__ group_input, T* __restrict__ smem_group, + const int group_size, const int lane_id, const int threads_per_group, + const float eps, const float max_8bit) { + float local_absmax = eps; + + constexpr int vec_size = 16 / sizeof(T); + + // copy global -> shared & compute absmax + auto scalar_op_cache = [&] __device__(T & dst, const T& src) { + float abs_v = fabsf(static_cast(src)); + local_absmax = fmaxf(local_absmax, abs_v); + dst = src; + }; + + vllm::vectorize_with_alignment( + group_input, // in + smem_group, // out (shared) + group_size, // elements per group + lane_id, // thread id + threads_per_group, // stride in group + scalar_op_cache); // scalar handler + + local_absmax = GroupReduceMax(local_absmax); + + float y_s = local_absmax / max_8bit; + if constexpr (SCALE_UE8M0) { + y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f)))); + } + + return y_s; +} + +template +__device__ __forceinline__ void QuantizeGroup( + const T* __restrict__ smem_group, DST_DTYPE* __restrict__ group_output, + const int group_size, const int lane_id, const int threads_per_group, + const float y_s, const float min_8bit, const float max_8bit) { + constexpr int vec_size = 16 / sizeof(T); + + // quantize shared -> global 8-bit + auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) { + float q = fminf(fmaxf(static_cast(src) / y_s, min_8bit), max_8bit); + dst = DST_DTYPE(q); + }; + + vllm::vectorize_with_alignment( + smem_group, // in (shared) + group_output, // out (global quant tensor) + group_size, // elements + lane_id, // tid + threads_per_group, // stride + scalar_op_quant); // scalar handler +} + template __global__ void per_token_group_quant_8bit_kernel( @@ -38,8 +94,6 @@ __global__ void per_token_group_quant_8bit_kernel( const int64_t global_group_id = block_group_id + local_group_id; const int64_t block_group_offset = global_group_id * group_size; - float local_absmax = eps; - using scale_element_t = float; static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0); @@ -68,30 +122,9 @@ __global__ void per_token_group_quant_8bit_kernel( T* smem = reinterpret_cast(smem_raw); T* smem_group = smem + local_group_id * group_size; - constexpr int vec_size = 16 / sizeof(T); - using vec_t = vllm::vec_n_t; - - // copy global -> shared & compute absmax - auto scalar_op_cache = [&] __device__(T & dst, const T& src) { - float abs_v = fabsf(static_cast(src)); - local_absmax = fmaxf(local_absmax, abs_v); - dst = src; - }; - - vllm::vectorize_with_alignment( - group_input, // in - smem_group, // out (shared) - group_size, // elements per group - lane_id, // thread id - threads_per_group, // stride in group - scalar_op_cache); // scalar handler - - local_absmax = GroupReduceMax(local_absmax); - - float y_s = local_absmax / max_8bit; - if constexpr (SCALE_UE8M0) { - y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f)))); - } + const float y_s = ComputeGroupScale( + group_input, smem_group, group_size, lane_id, threads_per_group, eps, + max_8bit); scale_element_t y_s_quant = y_s; @@ -101,19 +134,24 @@ __global__ void per_token_group_quant_8bit_kernel( __syncthreads(); - // quantize shared -> global 8-bit - auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) { - float q = fminf(fmaxf(static_cast(src) / y_s, min_8bit), max_8bit); - dst = DST_DTYPE(q); - }; + QuantizeGroup(smem_group, group_output, group_size, lane_id, + threads_per_group, y_s, min_8bit, max_8bit); +} - vllm::vectorize_with_alignment( - smem_group, // in (shared) - group_output, // out (global quant tensor) - group_size, // elements - lane_id, // tid - threads_per_group, // stride - scalar_op_quant); // scalar handler +inline int GetGroupsPerBlock(int64_t num_groups) { + if (num_groups % 16 == 0) { + return 16; + } + if (num_groups % 8 == 0) { + return 8; + } + if (num_groups % 4 == 0) { + return 4; + } + if (num_groups % 2 == 0) { + return 2; + } + return 1; } void per_token_group_quant_8bit(const torch::Tensor& input, @@ -133,17 +171,7 @@ void per_token_group_quant_8bit(const torch::Tensor& input, constexpr int THREADS_PER_GROUP = 16; - int groups_per_block = 1; - - if (num_groups % 16 == 0) { - groups_per_block = 16; - } else if (num_groups % 8 == 0) { - groups_per_block = 8; - } else if (num_groups % 4 == 0) { - groups_per_block = 4; - } else if (num_groups % 2 == 0) { - groups_per_block = 2; - } + const int groups_per_block = GetGroupsPerBlock(num_groups); auto dst_type = output_q.scalar_type(); const int num_blocks = num_groups / groups_per_block; @@ -206,6 +234,148 @@ void per_token_group_quant_8bit(const torch::Tensor& input, #undef LAUNCH_KERNEL } +template +__global__ void per_token_group_quant_8bit_packed_kernel( + const T* __restrict__ input, void* __restrict__ output_q, + unsigned int* __restrict__ output_s_packed, const int group_size, + const int num_groups, const int groups_per_block, const int groups_per_row, + const int mn, const int tma_aligned_mn, const float eps, + const float min_8bit, const float max_8bit) { + const int threads_per_group = 16; + const int64_t local_group_id = threadIdx.x / threads_per_group; + const int lane_id = threadIdx.x % threads_per_group; + + const int64_t block_group_id = blockIdx.x * groups_per_block; + const int64_t global_group_id = block_group_id + local_group_id; + if (global_group_id >= num_groups) { + return; + } + + const int64_t block_group_offset = global_group_id * group_size; + + const T* group_input = input + block_group_offset; + DST_DTYPE* group_output = + static_cast(output_q) + block_group_offset; + + // shared memory to cache each group's data to avoid double DRAM reads. + extern __shared__ __align__(16) char smem_raw[]; + T* smem = reinterpret_cast(smem_raw); + T* smem_group = smem + local_group_id * group_size; + const float y_s = + ComputeGroupScale(group_input, smem_group, group_size, lane_id, + threads_per_group, eps, max_8bit); + + // pack 4 scales into a uint32 + if (lane_id == 0) { + // map flat group id to 2D indices (mn_idx, sf_k_idx) + const int sf_k_idx = static_cast(global_group_id % groups_per_row); + const int mn_idx = static_cast(global_group_id / groups_per_row); + + if (mn_idx < mn) { + // each uint32 in output_s_packed stores 4 packed scales + const int sf_k_pack_idx = sf_k_idx / 4; + const int pos = sf_k_idx % 4; + + // reinterpret the UE8M0 scale y_s as IEEE bits, extract the 8-bit + // exponent, and place it into the correct byte of the 32-bit word. + const unsigned int bits = __float_as_uint(y_s); + const unsigned int exponent = (bits >> 23u) & 0xffu; + const unsigned int contrib = exponent << (pos * 8u); + + const int out_idx = sf_k_pack_idx * tma_aligned_mn + mn_idx; + // atomically OR 8-bit exponent into the packed scales buffer + atomicOr(output_s_packed + out_idx, contrib); + } + } + + __syncthreads(); + + QuantizeGroup(smem_group, group_output, group_size, lane_id, + threads_per_group, y_s, min_8bit, max_8bit); +} + +void per_token_group_quant_8bit_packed(const torch::Tensor& input, + torch::Tensor& output_q, + torch::Tensor& output_s_packed, + int64_t group_size, double eps, + double min_8bit, double max_8bit) { + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(output_q.is_contiguous()); + + const int64_t k = input.size(-1); + TORCH_CHECK(k % group_size == 0, "Last dimension (", k, + ") must be divisible by group_size (", group_size, ")."); + + const int64_t mn = input.numel() / k; + const int64_t groups_per_row = k / group_size; + const int64_t num_groups = mn * groups_per_row; + + TORCH_CHECK(output_s_packed.dim() == 2, + "output_s_packed must be 2D, got dim=", output_s_packed.dim(), + "."); + + const int64_t k_num_packed_sfk = (groups_per_row + 3) / 4; + const int64_t tma_aligned_mn = ((mn + 3) / 4) * 4; + + TORCH_CHECK(output_s_packed.scalar_type() == at::ScalarType::Int, + "output_s_packed must have dtype int32 for UE8M0-packed scales."); + // DeepGEMM expects SFA scales in MN-major form with shape + // [mn, ceil_div(K, 128 * 4)] and TMA-aligned stride on the last + // dimension. + TORCH_CHECK(output_s_packed.size(0) == mn && + output_s_packed.size(1) == k_num_packed_sfk, + "output_s_packed shape must be [", mn, ", ", k_num_packed_sfk, + "], but got [", output_s_packed.size(0), ", ", + output_s_packed.size(1), "]."); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + constexpr int THREADS_PER_GROUP = 16; + + const int groups_per_block = GetGroupsPerBlock(num_groups); + + auto dst_type = output_q.scalar_type(); + const int num_blocks = num_groups / groups_per_block; + const int num_threads = groups_per_block * THREADS_PER_GROUP; + + // zero-initialize packed scales, since we use atomicOr to accumulate + // exponents from different groups. + output_s_packed.zero_(); + +#define LAUNCH_PACKED_KERNEL(T, DST_DTYPE) \ + do { \ + dim3 grid(num_blocks); \ + dim3 block(num_threads); \ + size_t smem_bytes = \ + static_cast(groups_per_block) * group_size * sizeof(T); \ + per_token_group_quant_8bit_packed_kernel \ + <<>>( \ + static_cast(input.data_ptr()), output_q.data_ptr(), \ + reinterpret_cast(output_s_packed.data_ptr()), \ + static_cast(group_size), static_cast(num_groups), \ + groups_per_block, static_cast(groups_per_row), \ + static_cast(mn), static_cast(tma_aligned_mn), \ + static_cast(eps), static_cast(min_8bit), \ + static_cast(max_8bit)); \ + } while (0) + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "per_token_group_quant_8bit_packed", ([&] { + if (dst_type == at::ScalarType::Float8_e4m3fn) { + LAUNCH_PACKED_KERNEL(scalar_t, __nv_fp8_e4m3); + } else if (dst_type == at::ScalarType::Char) { + LAUNCH_PACKED_KERNEL(scalar_t, int8_t); + } else { + TORCH_CHECK( + false, + "per_token_group_quant_8bit_packed only supports FP8/INT8 " + "outputs."); + } + })); + +#undef LAUNCH_PACKED_KERNEL +} + void per_token_group_quant_fp8(const torch::Tensor& input, torch::Tensor& output_q, torch::Tensor& output_s, int64_t group_size, double eps, double fp8_min, diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index 2ef579a1b7537..8ebe55cef391d 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -1241,33 +1241,16 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx, } #endif // defined(__HIP__GFX9__) TODO: Add NAVI support +// Find the min val of div2 that doesn't increase N/(div1*div2) int mindiv(int N, int div1, int div2) { int nPrRnd = div1 * div2; - int rnds0 = N / nPrRnd; - nPrRnd -= div1 * 3; - int rnds3 = N / nPrRnd; - nPrRnd -= div1; - int rnds4 = N / nPrRnd; - nPrRnd -= div1; - int rnds5 = N / nPrRnd; - nPrRnd -= div1; - int rnds6 = N / nPrRnd; - nPrRnd -= div1; - int rnds7 = N / nPrRnd; - nPrRnd -= div1; - int rnds8 = N / nPrRnd; - nPrRnd -= div1; - int rnds9 = N / nPrRnd; - nPrRnd -= div1; - int rtn = div2; - if (rnds0 == rnds3) rtn = div2 - 3; - if (rnds0 == rnds4) rtn = div2 - 4; - if (rnds0 == rnds5) rtn = div2 - 5; - if (rnds0 == rnds6) rtn = div2 - 6; - if (rnds0 == rnds7) rtn = div2 - 7; - if (rnds0 == rnds8) rtn = div2 - 8; - if (rnds0 == rnds9) rtn = div2 - 9; - return rtn; + int rnds[13]; + for (int i = 0; i < 13; i++) { + rnds[i] = (N + nPrRnd - 1) / nPrRnd; + nPrRnd -= div1; + } + for (int i = 12; i >= 0; i--) + if (rnds[0] == rnds[i]) return (div2 - i); } torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, @@ -1300,26 +1283,37 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const int max_lds_len = get_lds_size() / 2; -#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ - _N) \ - { \ - dim3 block(64, _WvPrGrp); \ - if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \ - int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \ - wvSplitK_hf_sml_ \ - <<>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ - biasf4, c, __wvPrGrp, CuCount); \ - } else if (K_in * N_in <= max_lds_len * 1.2) { \ - int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \ - wvSplitK_hf_ \ - <<>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ - biasf4, c, __wvPrGrp, CuCount); \ - } else { \ - int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \ - wvSplitK_hf_big_ \ - <<>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ - biasf4, c, __wvPrGrp, CuCount); \ - } \ +#define WVSPLITK(_YTILE, _UNRL, _N) \ + { \ + dim3 block(64, 16); \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \ + if ((K_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \ + wvSplitK_hf_sml_ \ + <<>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ + biasf4, c, __wvPrGrp, CuCount); \ + else if (K_in * N_in <= max_lds_len * 1.2) \ + wvSplitK_hf_ \ + <<>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ + biasf4, c, __wvPrGrp, CuCount); \ + else \ + wvSplitK_hf_big_ \ + <<>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ + biasf4, c, __wvPrGrp, CuCount); \ + } + +#define WVSPLIT_TILE(_sYT, __N) \ + { \ + bool fit_lds = (K_in * N_in <= max_lds_len); \ + if (_sYT <= 1) \ + WVSPLITK(1, 4, __N) \ + else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \ + WVSPLITK(2, 2, __N) \ + else if (_sYT <= 4 * 3) \ + WVSPLITK(3, 2, __N) \ + else if (__N == 4) \ + WVSPLITK(4, 1, __N) \ + else \ + WVSPLITK(4, 2, __N) \ } AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] { @@ -1331,18 +1325,23 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, ? reinterpret_cast(in_bias->data_ptr()) : nullptr; fptype* c = reinterpret_cast(out_c.data_ptr()); + + // first shoot for biggest tile-size that keeps all simd busy, + // then cut the active waves to balance their distribution... + int sYT = (M_in + CuCount * 4 - 1) / (CuCount * 4); + switch (N_in) { case 1: - WVSPLITK(16, 2, 2, 2, 2, 2, 2, 1) + WVSPLIT_TILE(sYT, 1) break; case 2: - WVSPLITK(16, 2, 2, 2, 2, 2, 2, 2) + WVSPLIT_TILE(sYT, 2) break; case 3: - WVSPLITK(16, 4, 7, 7, 1, 1, 1, 3) + WVSPLIT_TILE(sYT, 3) break; case 4: - WVSPLITK(16, 4, 7, 7, 1, 1, 1, 4) + WVSPLIT_TILE(sYT, 4) break; default: throw std::runtime_error( diff --git a/csrc/sampler.cu b/csrc/sampler.cu index 410b8988f4939..fc2154beff9e0 100644 --- a/csrc/sampler.cu +++ b/csrc/sampler.cu @@ -44,41 +44,300 @@ __global__ void apply_repetition_penalties_kernel( } } -static inline __device__ uint16_t extractBinIdx(float x) { - union { - __half h; - uint16_t u16; - } tmp; - tmp.h = __float2half_rn(x); - tmp.u16 = (x < 0.f) ? (~tmp.u16 & 0xffff) : (tmp.u16 | 0x8000); - return 511 - (tmp.u16 >> 7); +__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000) ? bits : ~bits & 0x7fffffff; } -template -__device__ void topKPerRowJob(const float* logits, const int rowStart, - const int rowEnd, const int rowIdx, - int* outIndices, int stride0, int stride1) { - // The number of elements per thread for the final top-k sort. - static constexpr int kNumTopKItemsPerThread = kTopK / kNumThreadsPerBlock; - // The class to sort the elements during the final top-k sort. - using TopKSort = cub::BlockRadixSort; +template +static inline __device__ uint32_t extractBinIdx(float x) { + if constexpr (step == 0) { + __half hx = __float2half(x); + uint16_t bits = __half_as_ushort(hx); + bits = (bits & 0x8000) ? bits : ~bits & 0x7fff; + return bits >> 5; + } else { + uint32_t bits = __float_as_uint(x); + bits = (bits & 0x80000000) ? bits : ~bits & 0x7fffffff; + if constexpr (step == 1) { + return bits >> 21; + } else if constexpr (step == 2) { + return (bits >> 10) & 0x7ff; + } else if constexpr (step == 3) { + return bits & 0x3ff; + } + } +} + +template +static inline __device__ bool isPartialMatch(float x, uint32_t pattern) { + if constexpr (shift == 0) { + return true; + } + uint32_t bits = __float_as_uint(x); + bits = (bits & 0x80000000) ? bits : ~bits & 0x7fffffff; + return (bits ^ pattern) >> shift == 0; +} + +/** + * Map a Func over the input data, using vectorized load instructions if + * possible. + * + * @tparam T element type + * @tparam IdxT indexing type + * @tparam Func void (T x, IdxT idx) + * + * @param thread_rank rank of the calling thread among all participating threads + * @param num_threads number of the threads that participate in processing + * @param in the input data + * @param len the number of elements to read + * @param f the lambda taking two arguments (T x, IdxT idx) + */ +template +__device__ void vectorized_process(size_t thread_rank, size_t num_threads, + const T* in, idxT len, Func f) { + constexpr int WARP_SIZE = 32; + using WideT = float4; + if constexpr (sizeof(T) >= sizeof(WideT)) { + for (idxT i = thread_rank; i < len; i += num_threads) { + f(in[i], i); + } + } else { + static_assert(sizeof(WideT) % sizeof(T) == 0); + constexpr int items_per_scalar = sizeof(WideT) / sizeof(T); + // TODO: it's UB + union { + WideT scalar; + T array[items_per_scalar]; + } wide; + + int skip_cnt = + (reinterpret_cast(in) % sizeof(WideT)) + ? ((sizeof(WideT) - reinterpret_cast(in) % sizeof(WideT)) / + sizeof(T)) + : 0; + if (skip_cnt > len) { + skip_cnt = len; + } + const WideT* in_cast = reinterpret_cast(in + skip_cnt); + const idxT len_cast = (len - skip_cnt) / items_per_scalar; + + for (idxT i = thread_rank; i < len_cast; i += num_threads) { + wide.scalar = in_cast[i]; + const idxT real_i = skip_cnt + i * items_per_scalar; +#pragma unroll + for (int j = 0; j < items_per_scalar; ++j) { + f(wide.array[j], real_i + j); + } + } + + static_assert(WARP_SIZE >= items_per_scalar); + // and because items_per_scalar > skip_cnt, WARP_SIZE > skip_cnt + // no need to use loop + if (thread_rank < skip_cnt) { + f(in[thread_rank], thread_rank); + } + // because len_cast = (len - skip_cnt) / items_per_scalar, + // len_cast * items_per_scalar + items_per_scalar > len - skip_cnt; + // and so + // len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <= + // WARP_SIZE no need to use loop + const idxT remain_i = skip_cnt + len_cast * items_per_scalar + thread_rank; + if (remain_i < len) { + f(in[remain_i], remain_i); + } + } +} + +template +__device__ bool processHistogramStep( + const int* indices, const float* logits, int rowEnd, uint32_t& logitPattern, + int& thresholdBinIdx, SmemOutputType& smemOutput, int* smemThresholdBinIdx, + int* smemFinalDstIdx, int* smemFinalBinSize, int* smemFoundTopKValues, + SmemFinalType& smemFinal, int stride1, int rowStart, int topK) { + // Clear the histogram. +#pragma unroll + for (int idx = threadIdx.x; idx < kNumBins; idx += kNumThreadsPerBlock) { + smemFinal.histo.data[idx] = 0; + } + + // Make sure the histogram is ready. + __syncthreads(); + + // Update pattern + constexpr auto patternShift = step < 2 ? 0 : step == 2 ? 21 : 10; + if constexpr (step == 2) { + logitPattern = static_cast(thresholdBinIdx & 0x7ff) + << patternShift; + } else if constexpr (step == 3) { + logitPattern |= static_cast(thresholdBinIdx & 0x7ff) + << patternShift; + } + + auto distributeToBins = [&](float logit, int /* idx */ = 0) { + if (isPartialMatch(logit, logitPattern)) { + uint32_t binIdx = extractBinIdx(logit); + atomicAdd(&smemFinal.histo.data[binIdx], 1); + } + }; + + // Distribute the elements to the histogram bins. + if (stride1 == 1) { + vectorized_process(threadIdx.x, kNumThreadsPerBlock, logits + rowStart, + rowEnd - rowStart, distributeToBins); + } else { + for (int idx = rowStart + threadIdx.x; idx < rowEnd; + idx += kNumThreadsPerBlock) { + float logit = logits[idx * stride1]; + distributeToBins(logit, idx); + } + } + // Make sure the histogram is ready. + __syncthreads(); + + // Reads the value of the starting position in the smemOutput array + int lastValue = smemFoundTopKValues[0]; + + for (int round = 0; round < kNumBins / kNumThreadsPerBlock; round++) { + // Read the values from SMEM. + int idx = threadIdx.x + kNumThreadsPerBlock * round; + int binCount{0}; + binCount = smemFinal.histo.data[idx]; + + // Make sure each thread has read its value. + __syncthreads(); + + // Compute the prefix sum. + int prefixSum{0}, totalSum{0}; + using Scan = cub::BlockScan; + Scan(smemFinal.histo.scan).ExclusiveSum(binCount, prefixSum, totalSum); + + // Update the histogram with the prefix sums. + prefixSum += lastValue; + totalSum += lastValue; + smemFinal.histo.data[idx] = prefixSum; + + // Make sure the data is in shared memory. + __syncthreads(); + + // Find the last valid bin. + bool foundThreshold = false; + if (prefixSum < topK) { + int nextPrefixSum = threadIdx.x == kNumThreadsPerBlock - 1 + ? totalSum + : smemFinal.histo.data[idx + 1]; + + if (nextPrefixSum >= topK) { + smemThresholdBinIdx[0] = idx; + smemFinalBinSize[0] = nextPrefixSum - prefixSum; + foundThreshold = true; + } + } + + // Early exit: if any thread found the threshold, we can skip remaining + // rounds + if (__syncthreads_or(foundThreshold)) { + break; + } + + lastValue = totalSum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The threshold bin. + thresholdBinIdx = smemThresholdBinIdx[0]; + + auto processBins = [&](float logit, int idx) { + if (isPartialMatch(logit, logitPattern)) { + uint32_t binIdx = extractBinIdx(logit); + if (binIdx < thresholdBinIdx) { + // The element is part of the top-k selection + int dstIdx = atomicAdd(&smemFoundTopKValues[0], 1); + + if constexpr (mergeBlocks) { + smemOutput[dstIdx] = indices[idx]; + } else if constexpr (multipleBlocksPerRow) { + smemOutput[dstIdx] = idx + rowStart; + reinterpret_cast(smemOutput + topK)[dstIdx] = logit; + } else { + smemOutput[dstIdx] = idx; + } + } + if constexpr (step < 3) { + // Only fill the final items for sorting if the threshold bin fits + if (binIdx == thresholdBinIdx && + smemFinalBinSize[0] <= kNumFinalItems) { + int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1); + smemFinal.items.logits[dstIdx] = logit; + if constexpr (mergeBlocks) { + smemFinal.items.indices[dstIdx] = indices[idx]; + } else if constexpr (multipleBlocksPerRow) { + smemFinal.items.indices[dstIdx] = idx + rowStart; + } else { + smemFinal.items.indices[dstIdx] = idx; + } + } + } else { + if (binIdx == thresholdBinIdx) { + // The elements in the threshold bin share the same 32 bits at step 3 + int dstIdx = atomicAdd(&smemFinal.histo.data[binIdx], 1); + if (dstIdx < topK) { + if constexpr (mergeBlocks) { + smemOutput[dstIdx] = indices[idx]; + } else if constexpr (multipleBlocksPerRow) { + smemOutput[dstIdx] = idx + rowStart; + reinterpret_cast(smemOutput + topK)[dstIdx] = logit; + } else { + smemOutput[dstIdx] = idx; + } + } + } + } + } + }; + + if (stride1 == 1) { + vectorized_process(threadIdx.x, kNumThreadsPerBlock, logits + rowStart, + rowEnd - rowStart, processBins); + } else { + for (int idx = rowStart + threadIdx.x; idx < rowEnd; + idx += kNumThreadsPerBlock) { + float logit = logits[idx * stride1]; + processBins(logit, idx); + } + } + + // Make sure the elements are in shared memory. + __syncthreads(); + + // Check if we should continue to next step + return smemFinalBinSize[0] > kNumFinalItems; +} + +// Follows half - 11 - 11 - 10 bit iterations +template +static __device__ void topKPerRowJob(const int* indices, const float* logits, + int rowStart, int rowEnd, int* outIndices, + float* outLogits, int stride1, int topK) { // The number of slots for the final pass. - static constexpr int kNumFinalItems = 3072; + static constexpr int kNumFinalItems = 2048; // The number of elements per thread for the final sort. static constexpr int kNumFinalItemsPerThread = kNumFinalItems / kNumThreadsPerBlock; // The class to sort the elements during the final pass. using FinalSort = cub::BlockRadixSort; - + using FinalSortTempStorage = + std::conditional_t; // The class to compute the inclusive prefix-sum over the histogram. using Scan = cub::BlockScan; - // Shared memory to compute the block scan. - __shared__ typename Scan::TempStorage smemScan; - // The structure to store the final items (for the final pass). struct FinalItems { // Shared memory to store the indices for the final pass. @@ -87,200 +346,225 @@ __device__ void topKPerRowJob(const float* logits, const int rowStart, float logits[kNumFinalItems]; }; + struct Histogram { + typename Scan::TempStorage scan; + int data[kNumBins]; + }; + // Shared memory to compute the block sort. __shared__ union { FinalItems items; - typename FinalSort::TempStorage finalSort; - typename TopKSort::TempStorage topKSort; + FinalSortTempStorage finalSort; + Histogram histo; } smemFinal; - // Shared memory to store the histogram. - __shared__ int smemHistogram[kNumBins]; // Shared memory to store the selected indices. - __shared__ int smemIndices[kTopK]; + // If we are processing using multiple blocks, we need to store the logits and + // indices. + extern __shared__ int32_t smemOutput[]; + // Shared memory to store the threshold bin. __shared__ int smemThresholdBinIdx[1]; // Shared memory counter to register the candidates for the final phase. __shared__ int smemFinalDstIdx[1]; + // Shared memory to determine if the threshold bin fits in the final items. + __shared__ int smemFinalBinSize[1]; + // Shared memory to keep track of the top-k values found so far by the + // previous iterations + __shared__ int smemFoundTopKValues[1]; // The length of the row. int rowLen = rowEnd - rowStart; // Shortcut if the length of the row is smaller than Top-K. Indices are not // sorted by their corresponding logit. - if (rowLen <= kTopK) { + if (rowLen <= topK) { for (int rowIt = threadIdx.x; rowIt < rowLen; rowIt += kNumThreadsPerBlock) { - int idx = rowStart + rowIt; - outIndices[rowIdx * kTopK + rowIt] = idx - rowStart; + if constexpr (multipleBlocksPerRow) { + outIndices[rowIt] = rowIt + rowStart; + outLogits[rowIt] = logits[rowIt + rowStart]; + } else { + outIndices[rowIt] = rowIt; + } } - for (int rowIt = rowLen + threadIdx.x; rowIt < kTopK; + for (int rowIt = rowLen + threadIdx.x; rowIt < topK; rowIt += kNumThreadsPerBlock) { - outIndices[rowIdx * kTopK + rowIt] = -1; + outIndices[rowIt] = -1; + if constexpr (multipleBlocksPerRow) { + outLogits[rowIt] = -FLT_MAX; + } } + return; } - - // Clear the histogram. - if (threadIdx.x < kNumBins) { - smemHistogram[threadIdx.x] = 0; - } - - // Make sure the histogram is ready. - __syncthreads(); - - // Fetch elements one-by-one. - for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd; - rowIt += kNumThreadsPerBlock) { - uint16_t idx = extractBinIdx(logits[rowIdx * stride0 + rowIt * stride1]); - atomicAdd(&smemHistogram[idx], 1); - } - - // Make sure the histogram is ready. - __syncthreads(); - - // Read the values from SMEM. - int binCount{0}; - if (threadIdx.x < kNumBins) { - binCount = smemHistogram[threadIdx.x]; - } - - // Make sure each thread has read its value. - __syncthreads(); - - // Compute the prefix sum. - int prefixSum{0}, totalSum{0}; - Scan(smemScan).ExclusiveSum(binCount, prefixSum, totalSum); - - // Update the histogram with the prefix sums. - if (threadIdx.x < kNumBins) { - smemHistogram[threadIdx.x] = prefixSum; - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // Find the last valid bin. - if (threadIdx.x < kNumBins) { - int nextPrefixSum = - threadIdx.x == kNumBins - 1 ? totalSum : smemHistogram[threadIdx.x + 1]; - if (prefixSum < kTopK && nextPrefixSum >= kTopK) { - smemThresholdBinIdx[0] = threadIdx.x; - } - } - - // Clear the counter to store the items for the final phase. + // Initialize values if (threadIdx.x == 0) { smemFinalDstIdx[0] = 0; + smemFoundTopKValues[0] = 0; + } + __syncthreads(); + int thresholdBinIdx = -1; + uint32_t logitPattern = 0; + + // Step 0: Process first 11 bits of half representation + bool continueToNextStep = + processHistogramStep<0, kNumThreadsPerBlock, kNumBins, kNumFinalItems, + multipleBlocksPerRow, mergeBlocks>( + indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput, + smemThresholdBinIdx, smemFinalDstIdx, smemFinalBinSize, + smemFoundTopKValues, smemFinal, stride1, rowStart, topK); + + if (continueToNextStep) { + // Step 1: Process next 11 bits + continueToNextStep = + processHistogramStep<1, kNumThreadsPerBlock, kNumBins, kNumFinalItems, + multipleBlocksPerRow, mergeBlocks>( + indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput, + smemThresholdBinIdx, smemFinalDstIdx, smemFinalBinSize, + smemFoundTopKValues, smemFinal, stride1, rowStart, topK); } - // Make sure the data is in shared memory. - __syncthreads(); + if (continueToNextStep) { + // Step 2: Process next 11 bits + continueToNextStep = + processHistogramStep<2, kNumThreadsPerBlock, kNumBins, kNumFinalItems, + multipleBlocksPerRow, mergeBlocks>( + indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput, + smemThresholdBinIdx, smemFinalDstIdx, smemFinalBinSize, + smemFoundTopKValues, smemFinal, stride1, rowStart, topK); + } - // The threshold bin. - int thresholdBinIdx = smemThresholdBinIdx[0]; + if (continueToNextStep) { + // Step 3: Process last 10 bits + processHistogramStep<3, kNumThreadsPerBlock, kNumBins, kNumFinalItems, + multipleBlocksPerRow, mergeBlocks>( + indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput, + smemThresholdBinIdx, smemFinalDstIdx, smemFinalBinSize, + smemFoundTopKValues, smemFinal, stride1, rowStart, topK); + } - // Fetch elements one-by-one and populate the shared memory buffers. - for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd; - rowIt += kNumThreadsPerBlock) { - float logit = logits[rowIdx * stride0 + rowIt * stride1]; - uint16_t idx = extractBinIdx(logit); - if (idx < thresholdBinIdx) { - int dstIdx = atomicAdd(&smemHistogram[idx], 1); - smemIndices[dstIdx] = rowIt; - } else if (idx == thresholdBinIdx) { - int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1); - if (dstIdx < kNumFinalItems) { - smemFinal.items.logits[dstIdx] = logit; - smemFinal.items.indices[dstIdx] = rowIt; + if (!continueToNextStep) { + // The histogram did not proceed to the final 10 bits, therefore we need to + // sort the final items The logits of the elements to be sorted in the final + // pass. + if constexpr (useRadixSort) { + // Sorting with radix sort + float finalLogits[kNumFinalItemsPerThread]; + // The indices of the elements to be sorted in the final pass. + int finalIndices[kNumFinalItemsPerThread]; + +#pragma unroll + for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) { + finalLogits[ii] = -FLT_MAX; + } + + // Read the elements from SMEM. +#pragma unroll + for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) { + int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x; + if (srcIdx < smemFinalDstIdx[0]) { + finalLogits[ii] = smemFinal.items.logits[srcIdx]; + finalIndices[ii] = smemFinal.items.indices[srcIdx]; + } + } + // Make sure the shared memory has been read. + __syncthreads(); + + // Sort the elements. + FinalSort(smemFinal.finalSort) + .SortDescendingBlockedToStriped(finalLogits, finalIndices); + + // Copy the data back to the shared memory storage. + int baseIdx = smemFoundTopKValues[0]; + +#pragma unroll + for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) { + int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x; + int dstIdx = baseIdx + srcIdx; + + if (dstIdx < topK) { + smemOutput[dstIdx] = finalIndices[ii]; + if constexpr (multipleBlocksPerRow) { + reinterpret_cast(smemOutput + topK)[dstIdx] = + finalLogits[ii]; + } + } + } + } else { + // Sorting with insertion sort + auto baseIdx = smemFoundTopKValues[0]; + for (int i = threadIdx.x; i < smemFinalDstIdx[0]; + i += kNumThreadsPerBlock) { + int outIndex = 0; + auto logit = smemFinal.items.logits[i]; + for (int j = 0; j < smemFinalDstIdx[0]; j++) { + auto otherLogit = smemFinal.items.logits[j]; + if (logit < otherLogit || (logit == otherLogit && i < j)) { + outIndex++; + } + } + // Store if outIndex is in bounds + if (outIndex + baseIdx < topK) { + smemOutput[outIndex + baseIdx] = smemFinal.items.indices[i]; + if constexpr (multipleBlocksPerRow) { + reinterpret_cast(smemOutput + topK)[outIndex + baseIdx] = + smemFinal.items.logits[i]; + } + } + } + } + __syncthreads(); + } + + // Store to global memory. + for (int i = threadIdx.x; i < topK; i += kNumThreadsPerBlock) { + if constexpr (multipleBlocksPerRow) { + outIndices[i] = smemOutput[i]; + outLogits[i] = reinterpret_cast(smemOutput + topK)[i]; + } else { + if (stride1 == 1) { + // stride1 == 1 will use vectorized_process, which indexes already skip + // the rowStart. + outIndices[i] = smemOutput[i]; + } else { + outIndices[i] = smemOutput[i] - rowStart; } } } - - // Make sure the elements are in shared memory. - __syncthreads(); - - // The logits of the elements to be sorted in the final pass. - float finalLogits[kNumFinalItemsPerThread]; - // The indices of the elements to be sorted in the final pass. - int finalIndices[kNumFinalItemsPerThread]; - -// Init. -#pragma unroll - for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) { - finalLogits[ii] = -FLT_MAX; - } - -// Read the elements from SMEM. -#pragma unroll - for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) { - int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x; - if (srcIdx < smemFinalDstIdx[0]) { - finalLogits[ii] = smemFinal.items.logits[srcIdx]; - finalIndices[ii] = smemFinal.items.indices[srcIdx]; - } - } - - // Make sure the shared memory has been read. - __syncthreads(); - - // Sort the elements. - FinalSort(smemFinal.finalSort) - .SortDescendingBlockedToStriped(finalLogits, finalIndices); - - // Copy the data back to the shared memory storage. - int baseIdx = thresholdBinIdx > 0 ? smemHistogram[thresholdBinIdx - 1] : 0; -#pragma unroll - for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) { - int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x; - int dstIdx = baseIdx + srcIdx; - if (dstIdx < kTopK) { - smemIndices[dstIdx] = finalIndices[ii]; - } - } - - // Make sure the data is in shared memory. - __syncthreads(); - -// Store to global memory. -#pragma unroll - for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) { - int offset = rowIdx * kTopK + ii * kNumThreadsPerBlock + threadIdx.x; - outIndices[offset] = - smemIndices[ii * kNumThreadsPerBlock + threadIdx.x] - rowStart; - } } -template -static __global__ void topKPerRow(const float* logits, const int* rowStarts, - const int* rowEnds, int* outIndices, - int stride0, int stride1) { +template +static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowPrefill( + const float* logits, const int* rowStarts, const int* rowEnds, + int* outIndices, int stride0, int stride1, const int topK, + const int offsetIndex) { // The number of bins in the histogram. - static constexpr int kNumBins = 512; - - // The top-k width. - static constexpr int kTopK = 2048; + static constexpr int kNumBins = 2048; // The row computed by this block. - int rowIdx = blockIdx.x; + int rowIdx = blockIdx.x + offsetIndex; // The range of logits within the row. int rowStart = rowStarts[rowIdx]; int rowEnd = rowEnds[rowIdx]; - topKPerRowJob( - logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1); + // Local pointers to this block + outIndices += rowIdx * topK; + logits += rowIdx * stride0; + + topKPerRowJob( + nullptr, logits, rowStart, rowEnd, outIndices, nullptr, stride1, topK); } -template -static __global__ void topKPerRowDecode(const float* logits, const int* seqLens, - int* outIndices, int stride0, - int stride1, int next_n) { +template +static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode( + const float* logits, const int* seqLens, int* outIndices, int stride0, + int stride1, const int topK, int next_n, float* outLogits = nullptr, + const int numBlocksToMerge = 0, const int* indices = nullptr) { // The number of bins in the histogram. - static constexpr int kNumBins = 512; - - // The top-k width. - static constexpr int kTopK = 2048; + static constexpr int kNumBins = 2048; // The row computed by this block. int rowIdx = blockIdx.x; @@ -290,8 +574,25 @@ static __global__ void topKPerRowDecode(const float* logits, const int* seqLens, int seq_len = seqLens[rowIdx / next_n]; int rowEnd = seq_len - next_n + (rowIdx % next_n) + 1; - topKPerRowJob( - logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1); + // Local pointers to this block + if constexpr (!multipleBlocksPerRow && !mergeBlocks) { + outIndices += rowIdx * topK; + } else if constexpr (multipleBlocksPerRow) { + const auto blockSize = rowEnd / gridDim.y; // 16384 / 2 = 8192 + rowStart = blockSize * blockIdx.y; // 8192 * 1 = 8192 + rowEnd = gridDim.y == blockIdx.y + 1 ? rowEnd : rowStart + blockSize; + outIndices += rowIdx * gridDim.y * topK + blockIdx.y * topK; + outLogits += rowIdx * gridDim.y * topK + blockIdx.y * topK; + } else if constexpr (mergeBlocks) { + rowEnd = numBlocksToMerge * topK; + indices += rowIdx * numBlocksToMerge * topK; + outIndices += rowIdx * topK; + } + logits += rowIdx * stride0; + + topKPerRowJob( + indices, logits, rowStart, rowEnd, outIndices, outLogits, stride1, topK); } } // namespace vllm @@ -339,28 +640,84 @@ void apply_repetition_penalties_( void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n, const torch::Tensor& seqLens, torch::Tensor& indices, - int64_t numRows, int64_t stride0, int64_t stride1) { - // Compute the results on the device. + int64_t numRows, int64_t stride0, int64_t stride1, + int64_t topK) { + constexpr int kSortingAlgorithmThreshold = 12288; + constexpr int kSplitWorkThreshold = 200 * 1000; + constexpr int kNumThreadsPerBlock = 512; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const auto numColumns = logits.size(1); + + if (numColumns < kSortingAlgorithmThreshold) { + // Use insertion sort + vllm::topKPerRowDecode + <<>>( + logits.data_ptr(), seqLens.data_ptr(), + indices.data_ptr(), static_cast(stride0), + static_cast(stride1), static_cast(topK), + static_cast(next_n)); + } else if (numColumns < kSplitWorkThreshold) { + // From this threshold, use radix sort instead + vllm::topKPerRowDecode + <<>>( + logits.data_ptr(), seqLens.data_ptr(), + indices.data_ptr(), static_cast(stride0), + static_cast(stride1), static_cast(topK), + static_cast(next_n)); + } else { + // Long sequences are run in two steps + constexpr auto multipleBlocksPerRowConfig = 10; + + const auto outIndicesAux = + torch::empty({numRows, multipleBlocksPerRowConfig, topK}, + torch::dtype(torch::kInt32).device(logits.device())); + const auto outLogitsAux = + torch::empty({numRows, multipleBlocksPerRowConfig, topK}, + torch::dtype(torch::kFloat).device(logits.device())); + + vllm::topKPerRowDecode + <<>>( + logits.data_ptr(), seqLens.data_ptr(), + outIndicesAux.data_ptr(), static_cast(stride0), + static_cast(stride1), static_cast(topK), + static_cast(next_n), outLogitsAux.data_ptr()); + + constexpr int kNumThreadsPerBlockMerge = 1024; + vllm::topKPerRowDecode + <<>>( + outLogitsAux.data_ptr(), seqLens.data_ptr(), + indices.data_ptr(), multipleBlocksPerRowConfig * topK, 1, + static_cast(topK), static_cast(next_n), nullptr, + multipleBlocksPerRowConfig, outIndicesAux.data_ptr()); + } +} + +void top_k_per_row_prefill(const torch::Tensor& logits, + const torch::Tensor& rowStarts, + const torch::Tensor& rowEnds, torch::Tensor& indices, + int64_t numRows, int64_t stride0, int64_t stride1, + int64_t topK) { + constexpr int kSortingAlgorithmThreshold = 12288; constexpr int kNumThreadsPerBlock = 512; const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - vllm::topKPerRowDecode - <<>>( - logits.data_ptr(), seqLens.data_ptr(), - indices.data_ptr(), static_cast(stride0), - static_cast(stride1), static_cast(next_n)); -} + int numInsertionBlocks = + std::min(static_cast(numRows), kSortingAlgorithmThreshold); + vllm::topKPerRowPrefill + <<>>(logits.data_ptr(), rowStarts.data_ptr(), + rowEnds.data_ptr(), indices.data_ptr(), + static_cast(stride0), static_cast(stride1), + static_cast(topK), 0); -void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts, - const torch::Tensor& rowEnds, torch::Tensor& indices, - int64_t numRows, int64_t stride0, int64_t stride1) { - // Compute the results on the device. - constexpr int kNumThreadsPerBlock = 512; - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - vllm::topKPerRow - <<>>( - logits.data_ptr(), rowStarts.data_ptr(), - rowEnds.data_ptr(), indices.data_ptr(), - static_cast(stride0), static_cast(stride1)); + if (numRows > kSortingAlgorithmThreshold) { + int numRadixBlocks = numRows - kSortingAlgorithmThreshold; + vllm::topKPerRowPrefill + <<>>(logits.data_ptr(), rowStarts.data_ptr(), + rowEnds.data_ptr(), indices.data_ptr(), + static_cast(stride0), static_cast(stride1), + static_cast(topK), kSortingAlgorithmThreshold); + } } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 914227838558a..83d4943d62776 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -179,15 +179,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Optimized top-k per row operation ops.def( - "top_k_per_row(Tensor logits, Tensor rowStarts, Tensor rowEnds, " + "top_k_per_row_prefill(Tensor logits, Tensor rowStarts, Tensor rowEnds, " "Tensor! indices, int numRows, int stride0, " - "int stride1) -> ()"); - ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row); + "int stride1, int topK) -> ()"); + ops.impl("top_k_per_row_prefill", torch::kCUDA, &top_k_per_row_prefill); ops.def( "top_k_per_row_decode(Tensor logits, int next_n, " - "Tensor seq_lens, Tensor! indices, int numRows, " - "int stride0, int stride1) -> ()"); + "Tensor seq_lens, Tensor! indices, " + "int numRows, int stride0, int stride1, int topK) -> ()"); ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode); // Layernorm-quant @@ -215,6 +215,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA, &rms_norm_dynamic_per_token_quant); + // Fused Layernorm + Block quant kernels + ops.def( + "rms_norm_per_block_quant(Tensor! result, Tensor input, " + "Tensor weight, Tensor! scale, float epsilon, " + "Tensor? scale_ub, Tensor!? residual, int group_size, " + "bool is_scale_transposed) -> ()"); + ops.impl("rms_norm_per_block_quant", torch::kCUDA, &rms_norm_per_block_quant); + // Rotary embedding // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. ops.def( @@ -342,6 +350,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor"); // conditionally compiled so impl registration is in source file + // CUTLASS w4a8 grouped GEMM + ops.def( + "cutlass_w4a8_moe_mm(" + " Tensor! out_tensors," + " Tensor a_tensors," + " Tensor b_tensors," + " Tensor a_scales," + " Tensor b_scales," + " Tensor b_group_scales," + " int b_group_size," + " Tensor expert_offsets," + " Tensor problem_sizes," + " Tensor a_strides," + " Tensor b_strides," + " Tensor c_strides," + " Tensor group_scale_strides," + " str? maybe_schedule" + ") -> ()"); + ops.def( + "cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, " + "Tensor)"); + // conditionally compiled so impl registration is in source file + #endif // Dequantization for GGML. @@ -458,7 +489,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! problem_sizes1, " " Tensor! problem_sizes2, " " int num_experts, int n, int k, " - " Tensor? blockscale_offsets) -> ()"); + " Tensor? blockscale_offsets, " + " bool? force_swap_ab) -> ()"); ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA, &get_cutlass_moe_mm_problem_sizes); @@ -617,6 +649,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("per_token_group_fp8_quant", torch::kCUDA, &per_token_group_quant_fp8); + // Compute per-token-group 8-bit quantized tensor and UE8M0-packed, + // TMA-aligned scales for DeepGEMM. + ops.def( + "per_token_group_fp8_quant_packed(Tensor input, Tensor! output_q, " + "Tensor! output_s_packed, int group_size, float eps, float fp8_min, " + "float fp8_max) -> ()"); + ops.impl("per_token_group_fp8_quant_packed", torch::kCUDA, + &per_token_group_quant_8bit_packed); + // Compute per-token-group INT8 quantized tensor and scaling factor. ops.def( "per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! " @@ -713,6 +754,13 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()"); cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache); + cache_ops.def( + "cp_gather_and_upconvert_fp8_kv_cache(Tensor src_cache, Tensor! dst, " + "Tensor block_table, Tensor seq_lens, Tensor workspace_starts, int " + "batch_size) -> ()"); + cache_ops.impl("cp_gather_and_upconvert_fp8_kv_cache", torch::kCUDA, + &cp_gather_and_upconvert_fp8_kv_cache); + cache_ops.def( "indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor " "slot_mapping, " diff --git a/docker/Dockerfile b/docker/Dockerfile index eb7c105071c00..ae2624ace67b9 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -32,7 +32,7 @@ ARG DEADSNAKES_GPGKEY_URL # The PyPA get-pip.py script is a self contained script+zip file, that provides # both the installer script and the pip base85-encoded zip archive. This allows -# bootstrapping pip in environment where a dsitribution package does not exist. +# bootstrapping pip in environment where a distribution package does not exist. # # By parameterizing the URL for get-pip.py installation script, we allow # third-party to use their own copy of the script stored in a private mirror. @@ -73,15 +73,13 @@ ARG INSTALL_KV_CONNECTORS=false #################### BASE BUILD IMAGE #################### # prepare basic build environment FROM ${BUILD_BASE_IMAGE} AS base + ARG CUDA_VERSION ARG PYTHON_VERSION -ARG TARGETPLATFORM -ARG INSTALL_KV_CONNECTORS=false + ENV DEBIAN_FRONTEND=noninteractive -ARG GET_PIP_URL - -# Install system dependencies and uv, then create Python virtual environment +# Install system dependencies including build tools RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ && apt-get update -y \ @@ -107,32 +105,30 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && ln -s /opt/venv/bin/pip /usr/bin/pip \ && python3 --version && python3 -m pip --version -ARG PIP_INDEX_URL UV_INDEX_URL -ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL -ARG PYTORCH_CUDA_INDEX_BASE_URL -ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER - # Activate virtual environment and add uv to PATH ENV PATH="/opt/venv/bin:/root/.local/bin:$PATH" ENV VIRTUAL_ENV="/opt/venv" -# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out -# Reference: https://github.com/astral-sh/uv/pull/1694 +# Environment for uv ENV UV_HTTP_TIMEOUT=500 ENV UV_INDEX_STRATEGY="unsafe-best-match" -# Use copy mode to avoid hardlink failures with Docker cache mounts ENV UV_LINK_MODE=copy -RUN <> /etc/environment -# Install Python and other dependencies +# Install Python and system dependencies RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ && apt-get update -y \ @@ -355,58 +421,104 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && curl -sS ${GET_PIP_URL} | python${PYTHON_VERSION} \ && python3 --version && python3 -m pip --version -# Install CUDA development tools and build essentials for runtime JIT compilation +# Install CUDA development tools for runtime JIT compilation # (FlashInfer, DeepGEMM, EP kernels all require compilation at runtime) RUN CUDA_VERSION_DASH=$(echo $CUDA_VERSION | cut -d. -f1,2 | tr '.' '-') && \ apt-get update -y && \ apt-get install -y --no-install-recommends \ - cuda-nvcc-${CUDA_VERSION_DASH} \ - cuda-cudart-${CUDA_VERSION_DASH} \ - cuda-nvrtc-${CUDA_VERSION_DASH} \ - cuda-cuobjdump-${CUDA_VERSION_DASH} \ - libcublas-${CUDA_VERSION_DASH} && \ + cuda-nvcc-${CUDA_VERSION_DASH} \ + cuda-cudart-${CUDA_VERSION_DASH} \ + cuda-nvrtc-${CUDA_VERSION_DASH} \ + cuda-cuobjdump-${CUDA_VERSION_DASH} \ + libcurand-dev-${CUDA_VERSION_DASH} \ + libcublas-${CUDA_VERSION_DASH} \ + # Fixes nccl_allocator requiring nccl.h at runtime + # https://github.com/vllm-project/vllm/blob/1336a1ea244fa8bfd7e72751cabbdb5b68a0c11a/vllm/distributed/device_communicators/pynccl_allocator.py#L22 + libnccl-dev && \ rm -rf /var/lib/apt/lists/* +# Install uv for faster pip installs +RUN python3 -m pip install uv + +# Environment for uv +ENV UV_HTTP_TIMEOUT=500 +ENV UV_INDEX_STRATEGY="unsafe-best-match" +ENV UV_LINK_MODE=copy + +# Workaround for triton/pytorch issues +RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/ + +# ============================================================ +# SLOW-CHANGING DEPENDENCIES BELOW +# These are the expensive layers that we want to cache +# ============================================================ + +# Install PyTorch and core CUDA dependencies +# This is ~2GB and rarely changes +ARG PYTORCH_CUDA_INDEX_BASE_URL +COPY requirements/common.txt /tmp/common.txt +COPY requirements/cuda.txt /tmp/requirements-cuda.txt +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system -r /tmp/requirements-cuda.txt \ + --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') && \ + rm /tmp/requirements-cuda.txt /tmp/common.txt + +# Install FlashInfer pre-compiled kernel cache and binaries +# This is ~1.1GB and only changes when FlashInfer version bumps +# https://docs.flashinfer.ai/installation.html +ARG FLASHINFER_VERSION=0.5.3 +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system flashinfer-cubin==${FLASHINFER_VERSION} \ + && uv pip install --system flashinfer-jit-cache==${FLASHINFER_VERSION} \ + --extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \ + && flashinfer show-config + +# ============================================================ +# OPENAI API SERVER DEPENDENCIES +# Pre-install these to avoid reinstalling on every vLLM wheel rebuild +# ============================================================ + +# Install gdrcopy (saves ~6s per build) +# TODO (huydhn): There is no prebuilt gdrcopy package on 12.9 at the moment +ARG GDRCOPY_CUDA_VERSION=12.8 +ARG GDRCOPY_OS_VERSION=Ubuntu22_04 +ARG TARGETPLATFORM +COPY tools/install_gdrcopy.sh /tmp/install_gdrcopy.sh +RUN set -eux; \ + case "${TARGETPLATFORM}" in \ + linux/arm64) UUARCH="aarch64" ;; \ + linux/amd64) UUARCH="x64" ;; \ + *) echo "Unsupported TARGETPLATFORM: ${TARGETPLATFORM}" >&2; exit 1 ;; \ + esac; \ + /tmp/install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "${GDRCOPY_CUDA_VERSION}" "${UUARCH}" && \ + rm /tmp/install_gdrcopy.sh + +# Install vllm-openai dependencies (saves ~2.6s per build) +# These are stable packages that don't depend on vLLM itself +RUN --mount=type=cache,target=/root/.cache/uv \ + if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ + BITSANDBYTES_VERSION="0.42.0"; \ + else \ + BITSANDBYTES_VERSION="0.46.1"; \ + fi; \ + uv pip install --system accelerate hf_transfer modelscope \ + "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' 'runai-model-streamer[s3,gcs]>=0.15.3' + +# ============================================================ +# VLLM INSTALLATION (depends on build stage) +# ============================================================ + ARG PIP_INDEX_URL UV_INDEX_URL ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL ARG PYTORCH_CUDA_INDEX_BASE_URL ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER -# Install uv for faster pip installs -RUN --mount=type=cache,target=/root/.cache/uv \ - python3 -m pip install uv - -# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out -# Reference: https://github.com/astral-sh/uv/pull/1694 -ENV UV_HTTP_TIMEOUT=500 -ENV UV_INDEX_STRATEGY="unsafe-best-match" -# Use copy mode to avoid hardlink failures with Docker cache mounts -ENV UV_LINK_MODE=copy - -# Workaround for https://github.com/openai/triton/issues/2507 and -# https://github.com/pytorch/pytorch/issues/107960 -- hopefully -# this won't be needed for future versions of this docker image -# or future versions of triton. -RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/ - # Install vllm wheel first, so that torch etc will be installed. RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ --mount=type=cache,target=/root/.cache/uv \ uv pip install --system dist/*.whl --verbose \ --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') -# Install FlashInfer pre-compiled kernel cache and binaries -# https://docs.flashinfer.ai/installation.html -RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install --system flashinfer-cubin==0.5.3 \ - && uv pip install --system flashinfer-jit-cache==0.5.3 \ - --extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \ - && flashinfer show-config - -COPY examples examples -COPY benchmarks benchmarks -COPY ./vllm/collect_env.py . - RUN --mount=type=cache,target=/root/.cache/uv \ . /etc/environment && \ uv pip list @@ -420,7 +532,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ echo "No DeepGEMM wheels to install; skipping."; \ fi' -# Pytorch now installs NVSHMEM, setting LD_LIBRARY_PATH (https://github.com/pytorch/pytorch/blob/d38164a545b4a4e4e0cf73ce67173f70574890b6/.ci/manywheel/build_cuda.sh#L141C14-L141C36) +# Pytorch now installs NVSHMEM, setting LD_LIBRARY_PATH ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH # Install EP kernels wheels (pplx-kernels and DeepEP) that have been built in the `build` stage @@ -429,23 +541,17 @@ RUN --mount=type=bind,from=build,src=/tmp/ep_kernels_workspace/dist,target=/vllm uv pip install --system ep_kernels/dist/*.whl --verbose \ --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') -RUN --mount=type=bind,source=tools/install_gdrcopy.sh,target=/tmp/install_gdrcopy.sh,ro \ - set -eux; \ - case "${TARGETPLATFORM}" in \ - linux/arm64) UUARCH="aarch64" ;; \ - linux/amd64) UUARCH="x64" ;; \ - *) echo "Unsupported TARGETPLATFORM: ${TARGETPLATFORM}" >&2; exit 1 ;; \ - esac; \ - /tmp/install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "${GDRCOPY_CUDA_VERSION}" "${UUARCH}" - # CUDA image changed from /usr/local/nvidia to /usr/local/cuda in 12.8 but will # return to /usr/local/nvidia in 13.0 to allow container providers to mount drivers # consistently from the host (see https://github.com/vllm-project/vllm/issues/18859). # Until then, add /usr/local/nvidia/lib64 before the image cuda path to allow override. ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib64:${LD_LIBRARY_PATH} +# Copy examples and benchmarks at the end to minimize cache invalidation +COPY examples examples +COPY benchmarks benchmarks +COPY ./vllm/collect_env.py . #################### vLLM installation IMAGE #################### - #################### TEST IMAGE #################### # image to run unit testing suite # note that this uses vllm installed by `pip` @@ -511,18 +617,12 @@ ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL # Reference: https://github.com/astral-sh/uv/pull/1694 ENV UV_HTTP_TIMEOUT=500 -# install additional dependencies for openai api server +# install kv_connectors if requested RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,source=requirements/kv_connectors.txt,target=/tmp/kv_connectors.txt,ro \ if [ "$INSTALL_KV_CONNECTORS" = "true" ]; then \ uv pip install --system -r /tmp/kv_connectors.txt; \ - fi; \ - if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ - BITSANDBYTES_VERSION="0.42.0"; \ - else \ - BITSANDBYTES_VERSION="0.46.1"; \ - fi; \ - uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' 'runai-model-streamer[s3,gcs]>=0.15.0' + fi ENV VLLM_USAGE_SOURCE production-docker-image diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index 4aabe2661088a..1b6bdabc7a539 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -65,7 +65,6 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/tests /tests COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/examples /examples COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/docker/Dockerfile.rocm /docker/ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/.buildkite /.buildkite -# Centralized v1 package - copied to both test and final stages COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/vllm/v1 /vllm_v1 # ----------------------- @@ -98,7 +97,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system hf_transfer ENV HF_HUB_ENABLE_HF_TRANSFER=1 -# Copy in the v1 package +# Copy in the v1 package (for python-only install test group) COPY --from=export_vllm /vllm_v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1 # Source code is used in the `python_only_compile.sh` test @@ -130,9 +129,6 @@ RUN --mount=type=bind,from=export_vllm,src=/,target=/install \ && pip uninstall -y vllm \ && uv pip install --system *.whl -# Copy in the v1 package -COPY --from=export_vllm /vllm_v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1 - ARG COMMON_WORKDIR # Copy over the benchmark scripts as well diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu index adac43c6accbe..72d2053102c22 100644 --- a/docker/Dockerfile.xpu +++ b/docker/Dockerfile.xpu @@ -76,6 +76,9 @@ RUN python3 -m pip install -e tests/vllm_test_utils ENV NIXL_VERSION=0.7.0 RUN python3 /workspace/vllm/tools/install_nixl_from_source_ubuntu.py +# PyJWT-2.7.0 will influence some wheel behaviors, remove its dist-info to avoid conflicts +RUN rm /usr/lib/python3/dist-packages/PyJWT-2.7.0.dist-info/ -rf + # remove torch bundled oneccl to avoid conflicts RUN --mount=type=cache,target=/root/.cache/pip \ pip uninstall oneccl oneccl-devel -y diff --git a/docs/.nav.yml b/docs/.nav.yml index d30c0f12eba4c..835cc773e7599 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -5,11 +5,7 @@ nav: - Getting Started: - getting_started/quickstart.md - getting_started/installation - - Examples: - - examples/README.md - - Offline Inference: examples/offline_inference - - Online Serving: examples/online_serving - - Others: examples/others + - Examples: examples - General: - usage/v1_guide.md - usage/* @@ -63,6 +59,7 @@ nav: - CLI Reference: cli - Community: - community/* + - Governance: governance - Blog: https://blog.vllm.ai - Forum: https://discuss.vllm.ai - Slack: https://slack.vllm.ai diff --git a/docs/api/README.md b/docs/api/README.md index d3a141f327308..d51329ec2faa3 100644 --- a/docs/api/README.md +++ b/docs/api/README.md @@ -15,6 +15,7 @@ API documentation for vLLM's configuration classes. - [vllm.config.MultiModalConfig][] - [vllm.config.PoolerConfig][] - [vllm.config.StructuredOutputsConfig][] +- [vllm.config.ProfilerConfig][] - [vllm.config.ObservabilityConfig][] - [vllm.config.KVTransferConfig][] - [vllm.config.CompilationConfig][] diff --git a/docs/assets/contributing/dockerfile-stages-dependency.png b/docs/assets/contributing/dockerfile-stages-dependency.png index b327eb2151f50..c8839eb93de95 100644 Binary files a/docs/assets/contributing/dockerfile-stages-dependency.png and b/docs/assets/contributing/dockerfile-stages-dependency.png differ diff --git a/docs/benchmarking/cli.md b/docs/benchmarking/cli.md index 44a4c40125952..dd5a12e408b02 100644 --- a/docs/benchmarking/cli.md +++ b/docs/benchmarking/cli.md @@ -84,7 +84,7 @@ Total input tokens: 1369 Total generated tokens: 2212 Request throughput (req/s): 1.73 Output token throughput (tok/s): 382.89 -Total Token throughput (tok/s): 619.85 +Total token throughput (tok/s): 619.85 ---------------Time to First Token---------------- Mean TTFT (ms): 71.54 Median TTFT (ms): 73.88 @@ -670,6 +670,35 @@ vllm bench serve \ +### 🧪 Hashing Benchmarks + +
+Show more + +Two helper scripts live in `benchmarks/` to compare hashing options used by prefix caching and related utilities. They are standalone (no server required) and help choose a hash algorithm before enabling prefix caching in production. + +- `benchmarks/benchmark_hash.py`: Micro-benchmark that measures per-call latency of three implementations on a representative `(bytes, tuple[int])` payload. + +```bash +python benchmarks/benchmark_hash.py --iterations 20000 --seed 42 +``` + +- `benchmarks/benchmark_prefix_block_hash.py`: End-to-end block hashing benchmark that runs the full prefix-cache hash pipeline (`hash_block_tokens`) across many fake blocks and reports throughput. + +```bash +python benchmarks/benchmark_prefix_block_hash.py --num-blocks 20000 --block-size 32 --trials 5 +``` + +Supported algorithms: `sha256`, `sha256_cbor`, `xxhash`, `xxhash_cbor`. Install optional deps to exercise all variants: + +```bash +uv pip install xxhash cbor2 +``` + +If an algorithm’s dependency is missing, the script will skip it and continue. + +
+ ### ⚡ Request Prioritization Benchmark
diff --git a/docs/community/sponsors.md b/docs/community/sponsors.md index 8abb07caaab62..847b99cce45c9 100644 --- a/docs/community/sponsors.md +++ b/docs/community/sponsors.md @@ -18,16 +18,19 @@ Compute Resources: - Alibaba Cloud - AMD - Anyscale +- Arm - AWS - Crusoe Cloud - Databricks - DeepInfra - Google Cloud +- IBM - Intel - Lambda Lab - Nebius - Novita AI - NVIDIA +- Red Hat - Replicate - Roblox - RunPod diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index fdd9c317b022f..556d9f8b9420a 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -7,7 +7,7 @@ This guide covers optimization strategies and performance tuning for vLLM V1. ## Preemption -Due to the auto-regressive nature of transformer architecture, there are times when KV cache space is insufficient to handle all batched requests. +Due to the autoregressive nature of transformer architecture, there are times when KV cache space is insufficient to handle all batched requests. In such cases, vLLM can preempt requests to free up KV cache space for other requests. Preempted requests are recomputed when sufficient KV cache space becomes available again. When this occurs, you may see the following warning: diff --git a/docs/contributing/ci/nightly_builds.md b/docs/contributing/ci/nightly_builds.md new file mode 100644 index 0000000000000..a07b9c1c2fa4a --- /dev/null +++ b/docs/contributing/ci/nightly_builds.md @@ -0,0 +1,160 @@ +# Nightly Builds of vLLM Wheels + +vLLM maintains a per-commit wheel repository (commonly referred to as "nightly") at `https://wheels.vllm.ai` that provides pre-built wheels for every commit on the `main` branch since `v0.5.3`. This document explains how the nightly wheel index mechanism works. + +## Build and Upload Process on CI + +### Wheel Building + +Wheels are built in the `Release` pipeline (`.buildkite/release-pipeline.yaml`) after a PR is merged into the main branch, with multiple variants: + +- **Backend variants**: `cpu` and `cuXXX` (e.g., `cu129`, `cu130`). +- **Architecture variants**: `x86_64` and `aarch64`. + +Each build step: + +1. Builds the wheel in a Docker container. +2. Renames the wheel filename to use the correct manylinux tag (currently `manylinux_2_31`) for PEP 600 compliance. +3. Uploads the wheel to S3 bucket `vllm-wheels` under `/{commit_hash}/`. + +### Index Generation + +After uploading each wheel, the `.buildkite/scripts/upload-wheels.sh` script: + +1. **Lists all existing wheels** in the commit directory from S3 +2. **Generates indices** using `.buildkite/scripts/generate-nightly-index.py`: + - Parses wheel filenames to extract metadata (version, variant, platform tags). + - Creates HTML index files (`index.html`) for PyPI compatibility. + - Generates machine-readable `metadata.json` files. +3. **Uploads indices** to multiple locations (overriding existing ones): + - `/{commit_hash}/` - Always uploaded for commit-specific access. + - `/nightly/` - Only for commits on `main` branch (not PRs). + - `/{version}/` - Only for release wheels (no `dev` in its version). + +!!! tip "Handling Concurrent Builds" + The index generation script can handle multiple variants being built concurrently by always listing all wheels in the commit directory before generating indices, avoiding race conditions. + +## Directory Structure + +The S3 bucket structure follows this pattern: + +```text +s3://vllm-wheels/ +├── {commit_hash}/ # Commit-specific wheels and indices +│ ├── vllm-*.whl # All wheel files +│ ├── index.html # Project list (default variant) +│ ├── vllm/ +│ │ ├── index.html # Package index (default variant) +│ │ └── metadata.json # Metadata (default variant) +│ ├── cu129/ # Variant subdirectory +│ │ ├── index.html # Project list (cu129 variant) +│ │ └── vllm/ +│ │ ├── index.html # Package index (cu129 variant) +│ │ └── metadata.json # Metadata (cu129 variant) +│ ├── cu130/ # Variant subdirectory +│ ├── cpu/ # Variant subdirectory +│ └── .../ # More variant subdirectories +├── nightly/ # Latest main branch wheels (mirror of latest commit) +└── {version}/ # Release version indices (e.g., 0.11.2) +``` + +All built wheels are stored in `/{commit_hash}/`, while different indices are generated and reference them. +This avoids duplication of wheel files. + +For example, you can specify the following URLs to use different indices: + +- `https://wheels.vllm.ai/nightly/cu130` for the latest main branch wheels built with CUDA 13.0. +- `https://wheels.vllm.ai/{commit_hash}` for wheels built at a specific commit (default variant). +- `https://wheels.vllm.ai/0.12.0/cpu` for 0.12.0 release wheels built for CPU variant. + +Please note that not all variants are present on every commit. The available variants are subject to change over time, e.g., changing cu130 to cu131. + +### Variant Organization + +Indices are organized by variant: + +- **Default variant**: Wheels without variant suffix (i.e., built with the current `VLLM_MAIN_CUDA_VERSION`) are placed in the root. +- **Variant subdirectories**: Wheels with variant suffixes (e.g., `+cu130`, `.cpu`) are organized in subdirectories. +- **Alias to default**: The default variant can have an alias (e.g., `cu129` for now) for consistency and convenience. + +The variant is extracted from the wheel filename (as described in the [file name convention](https://packaging.python.org/en/latest/specifications/binary-distribution-format/#file-name-convention)): + +- The variant is encoded in the local version identifier (e.g. `+cu129` or `dev+g.cu130`). +- Examples: + - `vllm-0.11.2.dev278+gdbc3d9991-cp38-abi3-manylinux1_x86_64.whl` → default variant + - `vllm-0.10.2rc2+cu129-cp38-abi3-manylinux2014_aarch64.whl` → `cu129` variant + - `vllm-0.11.1rc8.dev14+gaa384b3c0.cu130-cp38-abi3-manylinux1_x86_64.whl` → `cu130` variant + +## Index Generation Details + +The `generate-nightly-index.py` script performs the following: + +1. **Parses wheel filenames** using regex to extract: + - Package name + - Version (with variant extracted) + - Python tag, ABI tag, platform tag + - Build tag (if present) +2. **Groups wheels by variant**, then by package name: + - Currently only `vllm` is built, but the structure supports multiple packages in the future. +3. **Generates HTML indices** (compliant with the [Simple repository API](https://packaging.python.org/en/latest/specifications/simple-repository-api/#simple-repository-api)): + - Top-level `index.html`: Lists all packages and variant subdirectories + - Package-level `index.html`: Lists all wheel files for that package + - Uses relative paths to wheel files for portability +4. **Generates metadata.json**: + - Machine-readable JSON containing all wheel metadata + - Includes `path` field with URL-encoded relative path to wheel file + - Used by `setup.py` to locate compatible pre-compiled wheels during Python-only builds + +### Special Handling for AWS Services + +The wheels and indices are directly stored on AWS S3, and we use AWS CloudFront as a CDN in front of the S3 bucket. + +Since S3 does not provide proper directory listing, to support PyPI-compatible simple repository API behavior, we deploy a CloudFront Function that: + +- redirects any URL that does not end with `/` and does not look like a file (i.e., does not contain a dot `.` in the last path segment) to the same URL with a trailing `/` +- appends `/index.html` to any URL that ends with `/` + +For example, the following requests would be handled as: + +- `/nightly` -> `/nightly/index.html` +- `/nightly/cu130/` -> `/nightly/cu130/index.html` +- `/nightly/index.html` or `/nightly/vllm.whl` -> unchanged + +!!! note "AWS S3 Filename Escaping" + + S3 will automatically escape filenames upon upload according to its [naming rule](https://docs.aws.amazon.com/AmazonS3/latest/userguide/object-keys.html). The direct impact on vllm is that `+` in filenames will be converted to `%2B`. We take special care in the index generation script to escape filenames properly when generating the HTML indices and JSON metadata, to ensure the URLs are correct and can be directly used. + +## Usage of precompiled wheels in `setup.py` {#precompiled-wheels-usage} + +When installing vLLM with `VLLM_USE_PRECOMPILED=1`, the `setup.py` script: + +1. **Determines wheel location** via `precompiled_wheel_utils.determine_wheel_url()`: + - Env var `VLLM_PRECOMPILED_WHEEL_LOCATION` (user-specified URL/path) always takes precedence and skips all other steps. + - Determines the variant from `VLLM_MAIN_CUDA_VERSION` (can be overridden with env var `VLLM_PRECOMPILED_WHEEL_VARIANT`); the default variant will also be tried as a fallback. + - Determines the _base commit_ (explained later) of this branch (can be overridden with env var `VLLM_PRECOMPILED_WHEEL_COMMIT`). +2. **Fetches metadata** from `https://wheels.vllm.ai/{commit}/vllm/metadata.json` (for the default variant) or `https://wheels.vllm.ai/{commit}/{variant}/vllm/metadata.json` (for a specific variant). +3. **Selects compatible wheel** based on: + - Package name (`vllm`) + - Platform tag (architecture match) +4. **Downloads and extracts** precompiled binaries from the wheel: + - C++ extension modules (`.so` files) + - Flash Attention Python modules + - Triton kernel Python files +5. **Patches package_data** to include extracted files in the installation + +!!! note "What is the base commit?" + + The base commit is determined by finding the merge-base + between the current branch and upstream `main`, ensuring + compatibility between source code and precompiled binaries. + +_Note: it's users' responsibility to ensure there is no native code (e.g., C++ or CUDA) changes before using precompiled wheels._ + +## Implementation Files + +Key files involved in the nightly wheel mechanism: + +- **`.buildkite/release-pipeline.yaml`**: CI pipeline that builds wheels +- **`.buildkite/scripts/upload-wheels.sh`**: Script that uploads wheels and generates indices +- **`.buildkite/scripts/generate-nightly-index.py`**: Python script that generates PyPI-compatible indices +- **`setup.py`**: Contains `precompiled_wheel_utils` class for fetching and using precompiled wheels diff --git a/docs/contributing/model/basic.md b/docs/contributing/model/basic.md index a68d1f0162a10..d37501b86556f 100644 --- a/docs/contributing/model/basic.md +++ b/docs/contributing/model/basic.md @@ -113,8 +113,6 @@ See [this page](registration.md) for instructions on how to register your new mo ### How to support models with interleaving sliding windows? -For models with interleaving sliding windows (e.g. `google/gemma-2-2b-it` and `mistralai/Ministral-8B-Instruct-2410`), the scheduler will treat the model as a full-attention model, i.e., kv-cache of all tokens will not be dropped. This is to make sure prefix caching works with these models. Sliding window only appears as a parameter to the attention kernel computation. - To support a model with interleaving sliding windows, we need to take care of the following details: - Make sure the model's `config.json` contains `layer_types`. diff --git a/docs/contributing/profiling.md b/docs/contributing/profiling.md index 65382afbe4f21..cbce14ce992ec 100644 --- a/docs/contributing/profiling.md +++ b/docs/contributing/profiling.md @@ -5,16 +5,15 @@ ## Profile with PyTorch Profiler -We support tracing vLLM workers using the `torch.profiler` module. You can enable tracing by setting the `VLLM_TORCH_PROFILER_DIR` environment variable to the directory where you want to save the traces: `VLLM_TORCH_PROFILER_DIR=/mnt/traces/`. Additionally, you can control the profiling content by specifying the following environment variables: +We support tracing vLLM workers using the `torch.profiler` module. You can enable the torch profiler by setting `--profiler-config` +when launching the server, and setting the entries `profiler` to `'torch'` and `torch_profiler_dir` to the directory where you want to save the traces. Additionally, you can control the profiling content by specifying the following additional arguments in the config: -- `VLLM_TORCH_PROFILER_RECORD_SHAPES=1` to enable recording Tensor Shapes, off by default -- `VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1` to record memory, off by default -- `VLLM_TORCH_PROFILER_WITH_STACK=1` to enable recording stack information, on by default -- `VLLM_TORCH_PROFILER_WITH_FLOPS=1` to enable recording FLOPs, off by default -- `VLLM_TORCH_PROFILER_USE_GZIP=0` to disable gzip-compressing profiling files, on by default -- `VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL=0` to disable dumping and printing the aggregated CUDA self time table, on by default - -The OpenAI server also needs to be started with the `VLLM_TORCH_PROFILER_DIR` environment variable set. +- `torch_profiler_record_shapes` to enable recording Tensor Shapes, off by default +- `torch_profiler_with_memory` to record memory, off by default +- `torch_profiler_with_stack` to enable recording stack information, on by default +- `torch_profiler_with_flops` to enable recording FLOPs, off by default +- `torch_profiler_use_gzip` to control gzip-compressing profiling files, on by default +- `torch_profiler_dump_cuda_time_total` to control dumping and printing the aggregated CUDA self time table, on by default When using `vllm bench serve`, you can enable profiling by passing the `--profile` flag. @@ -40,8 +39,7 @@ Refer to [examples/offline_inference/simple_profiling.py](../../examples/offline #### OpenAI Server ```bash -VLLM_TORCH_PROFILER_DIR=./vllm_profile \ - vllm serve meta-llama/Llama-3.1-8B-Instruct +vllm serve meta-llama/Llama-3.1-8B-Instruct --profiler-config '{"profiler": "torch", "torch_profiler_dir": "./vllm_profile"}' ``` vllm bench command: @@ -104,13 +102,12 @@ To profile the server, you will want to prepend your `vllm serve` command with ` ```bash # server -VLLM_TORCH_CUDA_PROFILE=1 \ nsys profile \ --trace-fork-before-exec=true \ --cuda-graph-trace=node \ --capture-range=cudaProfilerApi \ --capture-range-end repeat \ - vllm serve meta-llama/Llama-3.1-8B-Instruct + vllm serve meta-llama/Llama-3.1-8B-Instruct --profiler-config.profiler cuda # client vllm bench serve \ diff --git a/docs/deployment/docker.md b/docs/deployment/docker.md index 0e636c87f38a4..d70e0142e3202 100644 --- a/docs/deployment/docker.md +++ b/docs/deployment/docker.md @@ -82,7 +82,7 @@ DOCKER_BUILDKIT=1 docker build . \ ## Building for Arm64/aarch64 -A docker container can be built for aarch64 systems such as the Nvidia Grace-Hopper. At time of this writing, this should be considered **experimental**. Using the flag `--platform "linux/arm64"` will attempt to build for arm64. +A docker container can be built for aarch64 systems such as the Nvidia Grace-Hopper and Grace-Blackwell. Using the flag `--platform "linux/arm64"` will build for arm64. !!! note Multiple modules must be compiled, so this process can take a while. Recommend using `--build-arg max_jobs=` & `--build-arg nvcc_threads=` @@ -104,6 +104,25 @@ A docker container can be built for aarch64 systems such as the Nvidia Grace-Hop --build-arg RUN_WHEEL_CHECK=false ``` +For (G)B300, we recommend using CUDA 13, as shown in the following command. + +??? console "Command" + + ```bash + DOCKER_BUILDKIT=1 docker build \ + --build-arg CUDA_VERSION=13.0.1 \ + --build-arg BUILD_BASE_IMAGE=nvidia/cuda:13.0.1-devel-ubuntu22.04 \ + --build-arg max_jobs=256 \ + --build-arg nvcc_threads=2 \ + --build-arg RUN_WHEEL_CHECK=false \ + --build-arg torch_cuda_arch_list='9.0 10.0+PTX' \ + --platform "linux/arm64" \ + --tag vllm/vllm-gb300-openai:latest \ + --target vllm-openai \ + -f docker/Dockerfile \ + . + ``` + !!! note If you are building the `linux/arm64` image on a non-ARM host (e.g., an x86_64 machine), you need to ensure your system is set up for cross-compilation using QEMU. This allows your host machine to emulate ARM64 execution. diff --git a/docs/deployment/integrations/kthena.md b/docs/deployment/integrations/kthena.md new file mode 100644 index 0000000000000..483dd7474440b --- /dev/null +++ b/docs/deployment/integrations/kthena.md @@ -0,0 +1,333 @@ +# Kthena + +[**Kthena**](https://github.com/volcano-sh/kthena) is a Kubernetes-native LLM inference platform that transforms how organizations deploy and manage Large Language Models in production. Built with declarative model lifecycle management and intelligent request routing, it provides high performance and enterprise-grade scalability for LLM inference workloads. + +This guide shows how to deploy a production-grade, **multi-node vLLM** service on Kubernetes. + +We’ll: + +- Install the required components (Kthena + Volcano). +- Deploy a multi-node vLLM model via Kthena’s `ModelServing` CR. +- Validate the deployment. + +--- + +## 1. Prerequisites + +You need: + +- A Kubernetes cluster with **GPU nodes**. +- `kubectl` access with cluster-admin or equivalent permissions. +- **Volcano** installed for gang scheduling. +- **Kthena** installed with the `ModelServing` CRD available. +- A valid **Hugging Face token** if loading models from Hugging Face Hub. + +### 1.1 Install Volcano + +```bash +helm repo add volcano-sh https://volcano-sh.github.io/helm-charts +helm repo update +helm install volcano volcano-sh/volcano -n volcano-system --create-namespace +``` + +This provides the gang-scheduling and network topology features used by Kthena. + +### 1.2 Install Kthena + +```bash +helm install kthena oci://ghcr.io/volcano-sh/charts/kthena --version v0.1.0 --namespace kthena-system --create-namespace +``` + +- The `kthena-system` namespace is created. +- Kthena controllers and CRDs, including `ModelServing`, are installed and healthy. + +Validate: + +```bash +kubectl get crd | grep modelserving +``` + +You should see: + +```text +modelservings.workload.serving.volcano.sh ... +``` + +--- + +## 2. The Multi-Node vLLM `ModelServing` Example + +Kthena provides an example manifest to deploy a **multi-node vLLM cluster running Llama**. Conceptually this is equivalent to the vLLM production stack Helm deployment, but expressed with `ModelServing`. + +A simplified version of the example (`llama-multinode`) looks like: + +- `spec.replicas: 1` – one `ServingGroup` (one logical model deployment). +- `roles`: + - `entryTemplate` – defines **leader** pods that run: + - vLLM’s **multi-node cluster bootstrap script** (Ray cluster). + - vLLM **OpenAI-compatible API server**. + - `workerTemplate` – defines **worker** pods that join the leader’s Ray cluster. + +Key points from the example YAML: + +- **Image**: `vllm/vllm-openai:latest` (matches upstream vLLM images). +- **Command** (leader): + + ```yaml + command: + - sh + - -c + - > + bash /vllm-workspace/examples/online_serving/multi-node-serving.sh leader --ray_cluster_size=2; + python3 -m vllm.entrypoints.openai.api_server + --port 8080 + --model meta-llama/Llama-3.1-405B-Instruct + --tensor-parallel-size 8 + --pipeline-parallel-size 2 + ``` + +- **Command** (worker): + + ```yaml + command: + - sh + - -c + - > + bash /vllm-workspace/examples/online_serving/multi-node-serving.sh worker --ray_address=$(ENTRY_ADDRESS) + ``` + +--- + +## 3. Deploying Multi-Node llama vLLM via Kthena + +### 3.1 Prepare the Manifest + +**Recommended**: use a Secret instead of a raw env var: + +```bash +kubectl create secret generic hf-token \ + -n default \ + --from-literal=HUGGING_FACE_HUB_TOKEN='' +``` + +### 3.2 Apply the `ModelServing` + +```bash +cat <---`. + +The first number indicates `ServingGroup`. The second (`405b`) is the `Role`. The remaining indices identify the pod within the role. + +--- + +## 6. Accessing the vLLM OpenAI-Compatible API + +Expose the entry via a Service: + +```yaml +apiVersion: v1 +kind: Service +metadata: + name: llama-multinode-openai + namespace: default +spec: + selector: + modelserving.volcano.sh/name: llama-multinode + modelserving.volcano.sh/entry: "true" + # optionally further narrow to leader role if you label it + ports: + - name: http + port: 80 + targetPort: 8080 + type: ClusterIP +``` + +Port-forward from your local machine: + +```bash +kubectl port-forward svc/llama-multinode-openai 30080:80 -n default +``` + +Then: + +- List models: + + ```bash + curl -s http://localhost:30080/v1/models + ``` + +- Send a completion request (mirroring vLLM production stack docs): + + ```bash + curl -X POST http://localhost:30080/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "meta-llama/Llama-3.1-405B-Instruct", + "prompt": "Once upon a time,", + "max_tokens": 10 + }' + ``` + +You should see an OpenAI-style response from vLLM. + +--- + +## 7. Clean Up + +To remove the deployment and its resources: + +```bash +kubectl delete modelserving llama-multinode -n default +``` + +If you’re done with the entire stack: + +```bash +helm uninstall kthena -n kthena-system # or your Kthena release name +helm uninstall volcano -n volcano-system +``` diff --git a/docs/deployment/integrations/production-stack.md b/docs/deployment/integrations/production-stack.md index 2f1894ccf0022..624e98a08c98d 100644 --- a/docs/deployment/integrations/production-stack.md +++ b/docs/deployment/integrations/production-stack.md @@ -4,7 +4,7 @@ Deploying vLLM on Kubernetes is a scalable and efficient way to serve machine le * **Upstream vLLM compatibility** – It wraps around upstream vLLM without modifying its code. * **Ease of use** – Simplified deployment via Helm charts and observability through Grafana dashboards. -* **High performance** – Optimized for LLM workloads with features like multi-model support, model-aware and prefix-aware routing, fast vLLM bootstrapping, and KV cache offloading with [LMCache](https://github.com/LMCache/LMCache), among others. +* **High performance** – Optimized for LLM workloads with features like multimodel support, model-aware and prefix-aware routing, fast vLLM bootstrapping, and KV cache offloading with [LMCache](https://github.com/LMCache/LMCache), among others. If you are new to Kubernetes, don't worry: in the vLLM production stack [repo](https://github.com/vllm-project/production-stack), we provide a step-by-step [guide](https://github.com/vllm-project/production-stack/blob/main/tutorials/00-install-kubernetes-env.md) and a [short video](https://www.youtube.com/watch?v=EsTJbQtzj0g) to set up everything and get started in **4 minutes**! diff --git a/docs/deployment/k8s.md b/docs/deployment/k8s.md index abffb7bc5f948..05814cbad9bfc 100644 --- a/docs/deployment/k8s.md +++ b/docs/deployment/k8s.md @@ -14,6 +14,7 @@ Alternatively, you can deploy vLLM to Kubernetes using any of the following: - [InftyAI/llmaz](integrations/llmaz.md) - [KAITO](integrations/kaito.md) - [KServe](integrations/kserve.md) +- [Kthena](integrations/kthena.md) - [KubeRay](integrations/kuberay.md) - [kubernetes-sigs/lws](frameworks/lws.md) - [meta-llama/llama-stack](integrations/llamastack.md) diff --git a/docs/design/cuda_graphs.md b/docs/design/cuda_graphs.md index 7baadf8ba23cb..19c02fc88641c 100644 --- a/docs/design/cuda_graphs.md +++ b/docs/design/cuda_graphs.md @@ -41,7 +41,7 @@ These features allow the most flexibility for cudagraph capture and compilation * `NONE` — turn CUDA Graphs off. Good for debugging. * `PIECEWISE` — a single-mode strategy (and past default). It is the most flexible: attention or other CUDA Graphs-incompatible operations stay eager, everything else goes into CUDA Graphs. Requires piecewise compilation. * `FULL` — a single-mode strategy, which only captures full CUDA Graphs for non-uniform batches, then uniform-decode batches reuse the CUDA Graph of non-uniform batch of the same batch_size, since they are compatible; can be good for small models or workloads with small prompts. -* `FULL_DECODE_ONLY` — full CUDA Graph for uniform decode, no cudagraph for prefill/mixed etc; suitable for decode instances in a P/D setup where prefill is not as important, this way we can save the memory needed for `PIECEWISE` CUDA Graphs. +* `FULL_DECODE_ONLY` — full CUDA Graph for uniform decode, no cudagraph for prefill/mixed etc.; suitable for decode instances in a P/D setup where prefill is not as important, this way we can save the memory needed for `PIECEWISE` CUDA Graphs. * `FULL_AND_PIECEWISE` — (default mode) full CUDA Graph for uniform decode, piecewise CUDA Graphs for others; generally the most performant setting, especially for low latency with small models or MoEs, but also requires the most memory and takes the longest to capture. Defaults: If you’re on v1 with piecewise compilation, we default to `FULL_AND_PIECEWISE` for better performance, (for pooling models, it's still `PIECEWISE`). Otherwise, e.g. if piecewise compilation unavailable, we default to `NONE`. @@ -49,7 +49,7 @@ Defaults: If you’re on v1 with piecewise compilation, we default to `FULL_AND_ While `NONE` , `PIECEWISE`, and `FULL` are single-mode configurations and simply equivalent to past implementations of eager execution, piecewise CUDA Graphs, and full CUDA Graphs respectively, `FULL_DECODE_ONLY` and `FULL_AND_PIECEWISE` are newly appended dual-mode configurations, which require dispatching to switch between concrete runtime modes according to runtime batches dynamically. !!! note - Here, the single-modes `NONE`, `PIECEWISE`, and `FULL` are treated as the runtime modes for CUDA Graphs dispatching. If using a dual-mode, the dispatcher will always dispatch to one of its member modes (plus a potantial `NONE` if no suitable CUDA Graph available), depending on the batch composition. + Here, the single-modes `NONE`, `PIECEWISE`, and `FULL` are treated as the runtime modes for CUDA Graphs dispatching. If using a dual-mode, the dispatcher will always dispatch to one of its member modes (plus a potential `NONE` if no suitable CUDA Graph available), depending on the batch composition. While cascade attention is not cudagraph compatible, it is now compatible with all possible cudagraph mode configurations. If a batch uses cascade attention, it always gets dispatched to `PIECEWISE` mode if available (otherwise `NONE`). diff --git a/docs/design/debug_vllm_compile.md b/docs/design/debug_vllm_compile.md index e565f17da62ad..731e542a0307b 100644 --- a/docs/design/debug_vllm_compile.md +++ b/docs/design/debug_vllm_compile.md @@ -86,7 +86,7 @@ LLM(model, enforce_eager=True) ``` To turn off just torch.compile, pass `mode = NONE` to the compilation config. -(`-cc` is short for `--compilation_config`; `-O.*` dotted syntax is deprecated): +(`-cc` is short for `--compilation_config`): ```sh # Online diff --git a/docs/design/io_processor_plugins.md b/docs/design/io_processor_plugins.md index b4a30cda35a01..5a86940fa9f13 100644 --- a/docs/design/io_processor_plugins.md +++ b/docs/design/io_processor_plugins.md @@ -79,7 +79,7 @@ The `post_process*` methods take `PoolingRequestOutput` objects as input and gen The `validate_or_generate_params` method is used for validating with the plugin any `SamplingParameters`/`PoolingParameters` received with the user request, or to generate new ones if none are specified. The function always returns the validated/generated parameters. The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available here [vllm/entrypoints/openai/serving_pooling.py](../../vllm/entrypoints/pooling/pooling/serving.py). -An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please, also refer to our online ([examples/online_serving/pooling/prithvi_geospatial_mae.py](../../examples/online_serving/pooling/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py)) inference examples. +An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please, also refer to our online ([examples/pooling/plugin/prithvi_geospatial_mae_client.py](../../examples/pooling/plugin/prithvi_geospatial_mae_client.py)) and offline ([examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py](../../examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py)) inference examples. ## Using an IO Processor plugin diff --git a/docs/design/metrics.md b/docs/design/metrics.md index 313c9aaebd26b..2722e12fdaeaf 100644 --- a/docs/design/metrics.md +++ b/docs/design/metrics.md @@ -21,30 +21,20 @@ The mental model is that server-level metrics help explain the values of request ### v1 Metrics -In v1, the following metrics are exposed via a Prometheus-compatible `/metrics` endpoint using the `vllm:` prefix: +In v1, an extensive set of metrics are exposed via a Prometheus-compatible `/metrics` endpoint using the `vllm:` prefix, for example: - `vllm:num_requests_running` (Gauge) - Number of requests currently running. -- `vllm:num_requests_waiting` (Gauge) - Number of requests currently waiting. - `vllm:kv_cache_usage_perc` (Gauge) - Fraction of used KV cache blocks (0–1). - `vllm:prefix_cache_queries` (Counter) - Number of prefix cache queries. - `vllm:prefix_cache_hits` (Counter) - Number of prefix cache hits. -- `vllm:mm_cache_queries` (Counter) - (For multimodal models) Number of multimodal cache queries. -- `vllm:mm_cache_hits` (Counter) - (For multimodal models) Number of multimodal cache hits. -- `vllm:num_preemptions_total` (Counter) - Number of preemptions. - `vllm:prompt_tokens_total` (Counter) - Total number of prompt tokens processed. - `vllm:generation_tokens_total` (Counter) - Total number of generated tokens. -- `vllm:iteration_tokens_total` (Histogram) - Histogram of tokens processed in each engine step. -- `vllm:cache_config_info` (Gauge) - Information about the cache configuration. - `vllm:request_success_total` (Counter) - Number of finished requests (by finish reason). - `vllm:request_prompt_tokens` (Histogram) - Histogram of input prompt token counts. - `vllm:request_generation_tokens` (Histogram) - Histogram of generation token counts. -- `vllm:request_params_n` (Histogram) - Histogram of request parameter n. -- `vllm:request_params_max_tokens` - (Histogram) - Histogram of max_tokens parameter in requests. - `vllm:time_to_first_token_seconds` (Histogram) - Time to first token (TTFT). - `vllm:inter_token_latency_seconds` (Histogram) - Inter-token latency. - `vllm:e2e_request_latency_seconds` (Histogram) - End-to-end request latency. -- `vllm:request_queue_time_seconds` (Histogram) - Time spent in the queue. -- `vllm:request_inference_time_seconds` (Histogram) - Request inference time. - `vllm:request_prefill_time_seconds` (Histogram) - Request prefill time. - `vllm:request_decode_time_seconds` (Histogram) - Request decode time. @@ -57,15 +47,15 @@ vLLM also provides [a reference example](../../examples/online_serving/prometheu The subset of metrics exposed in the Grafana dashboard gives us an indication of which metrics are especially important: - `vllm:e2e_request_latency_seconds_bucket` - End to end request latency measured in seconds. -- `vllm:prompt_tokens_total` - Prompt tokens. -- `vllm:generation_tokens_total` - Generation tokens. +- `vllm:prompt_tokens` - Prompt tokens. +- `vllm:generation_tokens` - Generation tokens. - `vllm:time_per_output_token_seconds` - Inter-token latency (Time Per Output Token, TPOT) in seconds. - `vllm:time_to_first_token_seconds` - Time to First Token (TTFT) latency in seconds. - `vllm:num_requests_running` (also, `_swapped` and `_waiting`) - Number of requests in the RUNNING, WAITING, and SWAPPED states. -- `vllm:gpu_cache_usage_perc` - Percentage of used cache blocks by vLLM. +- `vllm:kv_cache_usage_perc` - Percentage of used cache blocks by vLLM. - `vllm:request_prompt_tokens` - Request prompt length. - `vllm:request_generation_tokens` - Request generation length. -- `vllm:request_success_total` - Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached. +- `vllm:request_success` - Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached. - `vllm:request_queue_time_seconds` - Queue time. - `vllm:request_prefill_time_seconds` - Requests prefill time. - `vllm:request_decode_time_seconds` - Requests decode time. @@ -263,6 +253,29 @@ record: - End-to-end latency - the interval between frontend `arrival_time` and the frontend receiving the final token. +### KV Cache Residency Metrics + +We also emit a set of histograms that describe how long sampled KV cache +blocks stay resident and how often they are reused. Sampling +(`--kv-cache-metrics-sample`) keeps the overhead tiny; when a block is +chosen we record: + +- `lifetime` – allocation ⟶ eviction +- `idle before eviction` – last touch ⟶ eviction +- `reuse gaps` – the pauses between touches when the block gets reused + +Those map directly to the Prometheus metrics: + +- `vllm:kv_block_lifetime_seconds` – how long each sampled block exists. +- `vllm:kv_block_idle_before_evict_seconds` – idle tail after the final access. +- `vllm:kv_block_reuse_gap_seconds` – time between consecutive touches. + +The engine core only ships raw eviction events via `SchedulerStats`; the +frontend drains them, turns them into Prometheus observations, and also +exposes the same data through `LLM.get_metrics()` when logging is on. +Looking at lifetime and idle time on one chart makes it easy to spot +stranded cache or workloads that pin prompts for a long decode. + ### Metrics Publishing - Logging The `LoggingStatLogger` metrics publisher outputs a log `INFO` message @@ -548,9 +561,9 @@ model and then validate those tokens with the larger model. - `vllm:spec_decode_draft_acceptance_rate` (Gauge) - `vllm:spec_decode_efficiency` (Gauge) -- `vllm:spec_decode_num_accepted_tokens_total` (Counter) -- `vllm:spec_decode_num_draft_tokens_total` (Counter) -- `vllm:spec_decode_num_emitted_tokens_total` (Counter) +- `vllm:spec_decode_num_accepted_tokens` (Counter) +- `vllm:spec_decode_num_draft_tokens` (Counter) +- `vllm:spec_decode_num_emitted_tokens` (Counter) There is a PR under review () to add "prompt lookup (ngram)" speculative decoding to v1. Other techniques will follow. We should diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 44aaa65218cc4..48341d199cb80 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -90,7 +90,6 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels | cutlass_fp8 | standard,
batched | fp8 | A,T | silu, gelu | Y | Y | [`cutlass_moe_fp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp8],
[`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],
[`CutlasBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] | | flashinfer | standard | nvfp4,
fp8 | T | 5 | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],
[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] | | gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],
[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] | -| deep gemm+triton2 | standard,
batched | all1 | G(128),A,T | silu, gelu | 6 | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts],
[`BatchedTritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe.BatchedTritonOrDeepGemmExperts] | | marlin | standard,
batched | 3 / N/A | 3 / N/A | silu,
swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],
[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],
[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] | | trtllm | standard | mxfp4,
nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] | | pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] | @@ -114,5 +113,5 @@ The following table shows "families" of modular kernels that are intended to wor | backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses | |---------|-----------------------------------------|----------------------------------------------| | deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,
`TritonExperts`,
`TritonOrDeepGemmExperts`,
`CutlassExpertsFp8`,
`MarlinExperts` | -| deepep_low_latency,
pplx | `DeepEPLLPrepareAndFinalize`,
`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,
`BatchedTritonExperts`,
`BatchedTritonOrDeepGemmExperts`,
`CutlassBatchedExpertsFp8`,
`BatchedMarlinExperts` | +| deepep_low_latency,
pplx | `DeepEPLLPrepareAndFinalize`,
`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,
`BatchedTritonExperts`,
`CutlassBatchedExpertsFp8`,
`BatchedMarlinExperts` | | flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` | diff --git a/docs/design/optimization_levels.md b/docs/design/optimization_levels.md index 940286071ef3c..4987c1820ad32 100644 --- a/docs/design/optimization_levels.md +++ b/docs/design/optimization_levels.md @@ -4,7 +4,7 @@ ## Overview -vLLM now supports optimization levels (`-O0`, `-O1`, `-O2`, `-O3`). Optimization levels provide an intuitive mechnaism for users to trade startup time for performance. Higher levels have better performance but worse startup time. These optimization levels have associated defaults to help users get desired out of the box performance. Importantly, defaults set by optimization levels are purely defaults; explicit user settings will not be overwritten. +vLLM now supports optimization levels (`-O0`, `-O1`, `-O2`, `-O3`). Optimization levels provide an intuitive mechanism for users to trade startup time for performance. Higher levels have better performance but worse startup time. These optimization levels have associated defaults to help users get desired out-of-the-box performance. Importantly, defaults set by optimization levels are purely defaults; explicit user settings will not be overwritten. ## Level Summaries and Usage Examples ```bash diff --git a/docs/design/paged_attention.md b/docs/design/paged_attention.md index d87b2a639df12..5cc5878425515 100644 --- a/docs/design/paged_attention.md +++ b/docs/design/paged_attention.md @@ -36,7 +36,7 @@ the input pointers `q`, `k_cache`, and `v_cache`, which point to query, key, and value data on global memory that need to be read and processed. The output pointer `out` points to global memory where the result should be written. These four pointers actually -refer to multi-dimensional arrays, but each thread only accesses the +refer to multidimensional arrays, but each thread only accesses the portion of data assigned to it. I have omitted all other runtime parameters here for simplicity. @@ -229,7 +229,7 @@ manner. ## QK -As shown the pseudo code below, before the entire for loop block, we +As shown the pseudocode below, before the entire for loop block, we fetch the query data for one token and store it in `q_vecs`. Then, in the outer for loop, we iterate through different `k_ptrs` that point to different tokens and prepare the `k_vecs` in the inner for @@ -403,7 +403,7 @@ for ... { // Iteration over different blocks. } ``` -As shown in the above pseudo code, in the outer loop, similar to +As shown in the above pseudocode, in the outer loop, similar to `k_ptr`, `logits_vec` iterates over different blocks and reads `V_VEC_SIZE` elements from `logits`. In the inner loop, each thread reads `V_VEC_SIZE` elements from the same tokens as a diff --git a/docs/design/plugin_system.md b/docs/design/plugin_system.md index 3485c40c36811..b0ca2dad23d5b 100644 --- a/docs/design/plugin_system.md +++ b/docs/design/plugin_system.md @@ -152,5 +152,5 @@ The interface for the model/module may change during vLLM's development. If you ## Deprecation announcement !!! warning "Deprecations" - - `use_v1` parameter in `Platform.get_attn_backend_cls` is deprecated. It will be removed in v0.13.0 or v1.0.0. - - `_Backend` in `vllm.attention` is deprecated. It will be removed in v0.13.0 or v1.0.0. Please use `vllm.attention.backends.registry.register_backend` to add new attention backend to `AttentionBackendEnum` instead. + - `use_v1` parameter in `Platform.get_attn_backend_cls` is deprecated. It has been removed in v0.13.0. + - `_Backend` in `vllm.attention` is deprecated. It has been removed in v0.13.0. Please use `vllm.attention.backends.registry.register_backend` to add new attention backend to `AttentionBackendEnum` instead. diff --git a/docs/design/prefix_caching.md b/docs/design/prefix_caching.md index cf792fdabe1a6..6f2eb3062b3b9 100644 --- a/docs/design/prefix_caching.md +++ b/docs/design/prefix_caching.md @@ -22,8 +22,8 @@ In the example above, the KV cache in the first block can be uniquely identified We only cache full blocks. !!! note "Note 2" - The above hash key structure is not 100% collision free. Theoretically it’s still possible for the different prefix tokens to have the same hash value. To avoid any hash collisions **in a multi-tenant setup, we advise to use SHA256** as hash function instead of the default builtin hash. - SHA256 is supported since vLLM v0.8.3 and must be enabled with a command line argument. It comes with a performance impact of about 100-200ns per token (~6ms for 50k tokens of context). + The above hash key structure is not 100% collision free. Theoretically it’s still possible for the different prefix tokens to have the same hash value. To avoid any hash collisions **in a multi-tenant setup, we use SHA256** as hash function instead of the builtin hash. + SHA256 is supported since vLLM v0.8.3 and the default since v0.10.2. It comes with a negligible performance impact of about 75ns per token (<4ms for 50k tokens of context). **A hashing example with multi-modality inputs** In this example, we illustrate how prefix caching works with multi-modality inputs (e.g., images). Assuming we have a request with the following messages: diff --git a/docs/features/README.md b/docs/features/README.md index 5faf3768f3214..e9e5232929b72 100644 --- a/docs/features/README.md +++ b/docs/features/README.md @@ -54,7 +54,7 @@ th:not(:first-child) { | beam-search | ✅ | ✅ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](https://github.com/vllm-project/vllm/issues/7968) | ❔ | ✅ | ✅ | | | [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ | -\* Chunked prefill and prefix caching are only applicable to last-token pooling. +\* Chunked prefill and prefix caching are only applicable to last-token or all pooling with causal attention. ^ LoRA is only applicable to the language backbone of multimodal models. ### Feature x Hardware @@ -68,8 +68,8 @@ th:not(:first-child) { | CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/26970) | | [pooling](../models/pooling_models.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | enc-dec | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | -| [mm](multimodal_inputs.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | [🟠](https://github.com/vllm-project/vllm/issues/26965) | -| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | +| [mm](multimodal_inputs.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | logP | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | prmpt logP | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | async output | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | diff --git a/docs/features/disagg_encoder.md b/docs/features/disagg_encoder.md index 7d40af7069822..f18a0e85e4b3b 100644 --- a/docs/features/disagg_encoder.md +++ b/docs/features/disagg_encoder.md @@ -32,14 +32,14 @@ Design doc: NOTE: The Mooncake Connector currently uses the proxy from nixl_integration. This will be replaced with a self-developed proxy in the future. + +Now you can send requests to the proxy server through port 8000. + +## Environment Variables + +- `VLLM_MOONCAKE_BOOTSTRAP_PORT`: Port for Mooncake bootstrap server + - Default: 8998 + - Required only for prefiller instances + - Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine + - For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank * tp_size + tp_rank + - Used for the decoder notifying the prefiller + +- `VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT`: Timeout (in seconds) for automatically releasing the prefiller’s KV cache for a particular request. (Optional) + - Default: 480 + - If a request is aborted and the decoder has not yet notified the prefiller, the prefill instance will release its KV-cache blocks after this timeout to avoid holding them indefinitely. + +## KV Role Options + +- **kv_producer**: For prefiller instances that generate KV caches +- **kv_consumer**: For decoder instances that consume KV caches from prefiller +- **kv_both**: Enables symmetric functionality where the connector can act as both producer and consumer. This provides flexibility for experimental setups and scenarios where the role distinction is not predetermined. diff --git a/docs/features/multimodal_inputs.md b/docs/features/multimodal_inputs.md index 4656ee43ea251..c3fd726e9938c 100644 --- a/docs/features/multimodal_inputs.md +++ b/docs/features/multimodal_inputs.md @@ -443,7 +443,9 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embedd print(generated_text) ``` -#### Audio Embeddings +For Qwen3-VL, the `image_embeds` should contain both the base image embedding and deepstack features. + +#### Audio Embedding Inputs You can pass pre-computed audio embeddings similar to image embeddings: @@ -795,14 +797,12 @@ The following example demonstrates how to pass image embeddings to the OpenAI se ??? code ```python + from vllm.utils.serial_utils import tensor2base64 + image_embedding = torch.load(...) grid_thw = torch.load(...) # Required by Qwen/Qwen2-VL-2B-Instruct - buffer = io.BytesIO() - torch.save(image_embedding, buffer) - buffer.seek(0) - binary_data = buffer.read() - base64_image_embedding = base64.b64encode(binary_data).decode('utf-8') + base64_image_embedding = tensor2base64(image_embedding) client = OpenAI( # defaults to os.environ.get("OPENAI_API_KEY") @@ -892,5 +892,11 @@ For Online Serving, you can also skip sending media if you expect cache hits wit ``` !!! note - Only one message can contain `{"type": "image_embeds"}`. + Multiple messages can now contain `{"type": "image_embeds"}`, enabling you to pass multiple image embeddings in a single request (similar to regular images). The number of embeddings is limited by `--limit-mm-per-prompt`. + + **Important**: The embedding shape format differs based on the number of embeddings: + + - **Single embedding**: 3D tensor of shape `(1, feature_size, hidden_size)` + - **Multiple embeddings**: List of 2D tensors, each of shape `(feature_size, hidden_size)` + If used with a model that requires additional parameters, you must also provide a tensor for each of them, e.g. `image_grid_thw`, `image_sizes`, etc. diff --git a/docs/features/nixl_connector_usage.md b/docs/features/nixl_connector_usage.md index f0e25e31aa0b3..601205e1ed0b1 100644 --- a/docs/features/nixl_connector_usage.md +++ b/docs/features/nixl_connector_usage.md @@ -22,7 +22,7 @@ python tools/install_nixl_from_source_ubuntu.py NixlConnector uses NIXL library for underlying communication, which supports multiple transport backends. UCX (Unified Communication X) is the primary default transport library used by NIXL. Configure transport environment variables: ```bash -# Example UCX configuration, adjust according to your enviroment +# Example UCX configuration, adjust according to your environment export UCX_TLS=all # or specify specific transports like "rc,ud,sm,^cuda_ipc" ..etc export UCX_NET_DEVICES=all # or specify network devices like "mlx5_0:1,mlx5_1:1" ``` @@ -146,6 +146,8 @@ python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \ --decoder-ports 8000 8000 ``` +For multi-host DP deployment, only need to provide the host/port of the head instances. + ### KV Role Options - **kv_producer**: For prefiller instances that generate KV caches diff --git a/docs/features/quantization/README.md b/docs/features/quantization/README.md index 7b5287bad3bb8..8b4dcf01969ae 100644 --- a/docs/features/quantization/README.md +++ b/docs/features/quantization/README.md @@ -14,7 +14,7 @@ Contents: - [INT4 W4A16](int4.md) - [INT8 W8A8](int8.md) - [FP8 W8A8](fp8.md) -- [NVIDIA TensorRT Model Optimizer](modelopt.md) +- [NVIDIA Model Optimizer](modelopt.md) - [AMD Quark](quark.md) - [Quantized KV Cache](quantized_kvcache.md) - [TorchAO](torchao.md) diff --git a/docs/features/quantization/modelopt.md b/docs/features/quantization/modelopt.md index c48ccb719a79d..b02d5ba9e89a2 100644 --- a/docs/features/quantization/modelopt.md +++ b/docs/features/quantization/modelopt.md @@ -1,6 +1,6 @@ -# NVIDIA TensorRT Model Optimizer +# NVIDIA Model Optimizer -The [NVIDIA TensorRT Model Optimizer](https://github.com/NVIDIA/TensorRT-Model-Optimizer) is a library designed to optimize models for inference with NVIDIA GPUs. It includes tools for Post-Training Quantization (PTQ) and Quantization Aware Training (QAT) of Large Language Models (LLMs), Vision Language Models (VLMs), and diffusion models. +The [NVIDIA Model Optimizer](https://github.com/NVIDIA/Model-Optimizer) is a library designed to optimize models for inference with NVIDIA GPUs. It includes tools for Post-Training Quantization (PTQ) and Quantization Aware Training (QAT) of Large Language Models (LLMs), Vision Language Models (VLMs), and diffusion models. We recommend installing the library with: @@ -10,7 +10,7 @@ pip install nvidia-modelopt ## Quantizing HuggingFace Models with PTQ -You can quantize HuggingFace models using the example scripts provided in the TensorRT Model Optimizer repository. The primary script for LLM PTQ is typically found within the `examples/llm_ptq` directory. +You can quantize HuggingFace models using the example scripts provided in the Model Optimizer repository. The primary script for LLM PTQ is typically found within the `examples/llm_ptq` directory. Below is an example showing how to quantize a model using modelopt's PTQ API: diff --git a/docs/features/reasoning_outputs.md b/docs/features/reasoning_outputs.md index 08a0dd69efa90..93cca23856a9b 100644 --- a/docs/features/reasoning_outputs.md +++ b/docs/features/reasoning_outputs.md @@ -18,6 +18,7 @@ vLLM currently supports the following reasoning models: | [ERNIE-4.5-VL series](https://huggingface.co/baidu/ERNIE-4.5-VL-28B-A3B-PT) | `ernie45` | `json`, `regex` | ❌ | | [ERNIE-4.5-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking) | `ernie45` | `json`, `regex` | ✅ | | [GLM-4.5 series](https://huggingface.co/collections/zai-org/glm-45-687c621d34bda8c9e4bf503b) | `glm45` | `json`, `regex` | ✅ | +| [Holo2 series](https://huggingface.co/collections/Hcompany/holo2) | `holo2` | `json`, `regex` | ✅ | | [Hunyuan A13B series](https://huggingface.co/collections/tencent/hunyuan-a13b-685ec38e5b46321e3ea7c4be) | `hunyuan_a13b` | `json`, `regex` | ✅ | | [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ | | [MiniMax-M2](https://huggingface.co/MiniMaxAI/MiniMax-M2) | `minimax_m2_append_think` | `json`, `regex` | ✅ | @@ -28,6 +29,7 @@ vLLM currently supports the following reasoning models: IBM Granite 3.2 and DeepSeek-V3.1 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`. The reasoning feature for the Qwen3 series is enabled by default. To disable it, you must pass `enable_thinking=False` in your `chat_template_kwargs`. DeepSeek-V3.1 tool calling is supported in non-thinking mode. + Holo2 reasoning is enabled by default. To disable it, you must also pass `thinking=False` in your `chat_template_kwargs`. ## Quickstart @@ -297,6 +299,9 @@ Additionally, to enable structured output, you'll need to create a new `Reasoner def is_reasoning_end(self, input_ids: list[int]) -> bool: return self.end_token_id in input_ids + + def is_reasoning_end_streaming(self, input_ids: list[int], delta_ids: list[int]) -> bool: + return self.end_token_id in delta_token_ids ... ``` diff --git a/docs/features/structured_outputs.md b/docs/features/structured_outputs.md index 7d52891bea7b9..3ac987559e622 100644 --- a/docs/features/structured_outputs.md +++ b/docs/features/structured_outputs.md @@ -61,7 +61,7 @@ Now let´s see an example for each of the cases, starting with the `choice`, as print(completion.choices[0].message.content) ``` -The next example shows how to use the `regex`. The idea is to generate an email address, given a simple regex template: +The next example shows how to use the `regex`. The supported regex syntax depends on the structured output backend. For example, `xgrammar`, `guidance`, and `outlines` use Rust-style regex, while `lm-format-enforcer` uses Python's `re` module. The idea is to generate an email address, given a simple regex template: ??? code diff --git a/docs/features/tool_calling.md b/docs/features/tool_calling.md index b6dfbf10b4568..70a11d6def566 100644 --- a/docs/features/tool_calling.md +++ b/docs/features/tool_calling.md @@ -376,6 +376,19 @@ Supported models: Flags: `--tool-call-parser olmo3` +### Gigachat 3 Models (`gigachat3`) + +Use chat template from the Hugging Face model files. + +Supported models: + +* `ai-sage/GigaChat3-702B-A36B-preview` +* `ai-sage/GigaChat3-702B-A36B-preview-bf16` +* `ai-sage/GigaChat3-10B-A1.8B` +* `ai-sage/GigaChat3-10B-A1.8B-bf16` + +Flags: `--tool-call-parser gigachat3` + ### Models with Pythonic Tool Calls (`pythonic`) A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models. @@ -407,7 +420,7 @@ Flags: `--tool-call-parser pythonic --chat-template {see_above}` ## How to Write a Tool Parser Plugin -A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in [vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py](../../vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py). +A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in [vllm/tool_parsers/hermes_tool_parser.py](../../vllm/tool_parsers/hermes_tool_parser.py). Here is a summary of a plugin file: @@ -455,7 +468,7 @@ Here is a summary of a plugin file: # register the tool parser to ToolParserManager ToolParserManager.register_lazy_module( name="example", - module_path="vllm.entrypoints.openai.tool_parsers.example", + module_path="vllm.tool_parsers.example", class_name="ExampleToolParser", ) diff --git a/docs/getting_started/installation/README.md b/docs/getting_started/installation/README.md index d5082bc7dd3a9..cff7ce1a882a1 100644 --- a/docs/getting_started/installation/README.md +++ b/docs/getting_started/installation/README.md @@ -26,3 +26,4 @@ The backends below live **outside** the main `vllm` repository and follow the | Rebellions ATOM / REBEL NPU | `vllm-rbln` | | | IBM Spyre AIU | `vllm-spyre` | | | Cambricon MLU | `vllm-mlu` | | +| Baidu Kunlun XPU | N/A, install from source | | diff --git a/docs/getting_started/installation/cpu.apple.inc.md b/docs/getting_started/installation/cpu.apple.inc.md index 4dc707d5f9a14..9f1f6e3821397 100644 --- a/docs/getting_started/installation/cpu.apple.inc.md +++ b/docs/getting_started/installation/cpu.apple.inc.md @@ -4,9 +4,6 @@ vLLM has experimental support for macOS with Apple Silicon. For now, users must Currently the CPU implementation for macOS supports FP32 and FP16 datatypes. -!!! warning - There are no pre-built wheels or images for this device, so you must build vLLM from source. - # --8<-- [end:installation] # --8<-- [start:requirements] @@ -20,6 +17,8 @@ Currently the CPU implementation for macOS supports FP32 and FP16 datatypes. # --8<-- [end:set-up-using-python] # --8<-- [start:pre-built-wheels] +Currently, there are no pre-built Apple silicon CPU wheels. + # --8<-- [end:pre-built-wheels] # --8<-- [start:build-wheel-from-source] @@ -78,6 +77,8 @@ uv pip install -e . # --8<-- [end:build-wheel-from-source] # --8<-- [start:pre-built-images] +Currently, there are no pre-built Arm silicon CPU images. + # --8<-- [end:pre-built-images] # --8<-- [start:build-image-from-source] diff --git a/docs/getting_started/installation/cpu.arm.inc.md b/docs/getting_started/installation/cpu.arm.inc.md index 9cae9ed1a212e..657bf2509db01 100644 --- a/docs/getting_started/installation/cpu.arm.inc.md +++ b/docs/getting_started/installation/cpu.arm.inc.md @@ -1,11 +1,6 @@ # --8<-- [start:installation] -vLLM has been adapted to work on ARM64 CPUs with NEON support, leveraging the CPU backend initially developed for the x86 platform. - -ARM CPU backend currently supports Float32, FP16 and BFloat16 datatypes. - -!!! warning - There are no pre-built wheels or images for this device, so you must build vLLM from source. +vLLM offers basic model inferencing and serving on Arm CPU platform, with support NEON, data types FP32, FP16 and BF16. # --8<-- [end:installation] # --8<-- [start:requirements] @@ -20,6 +15,50 @@ ARM CPU backend currently supports Float32, FP16 and BFloat16 datatypes. # --8<-- [end:set-up-using-python] # --8<-- [start:pre-built-wheels] +Pre-built vLLM wheels for Arm are available since version 0.11.2. These wheels contain pre-compiled C++ binaries. + +```bash +export VLLM_VERSION=$(curl -s https://api.github.com/repos/vllm-project/vllm/releases/latest | jq -r .tag_name | sed 's/^v//') +uv pip install vllm --extra-index-url https://wheels.vllm.ai/${VLLM_VERSION}/cpu +``` + +??? console "pip" + ```bash + pip install vllm==${VLLM_VERSION}+cpu --extra-index-url https://wheels.vllm.ai/${VLLM_VERSION}/cpu + ``` + +The `uv` approach works for vLLM `v0.6.6` and later. A unique feature of `uv` is that packages in `--extra-index-url` have [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes). If the latest public release is `v0.6.6.post1`, `uv`'s behavior allows installing a commit before `v0.6.6.post1` by specifying the `--extra-index-url`. In contrast, `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version, which makes it difficult to install a development version prior to the released version. + +**Install the latest code** + +LLM inference is a fast-evolving field, and the latest code may contain bug fixes, performance improvements, and new features that are not released yet. To allow users to try the latest code without waiting for the next release, vLLM provides working pre-built Arm CPU wheels for every commit since `v0.11.2` on . For native CPU wheels, this index should be used: + +* `https://wheels.vllm.ai/nightly/cpu/vllm` + +To install from nightly index, run: +```bash +uv pip install vllm --extra-index-url https://wheels.vllm.ai/nightly/cpu +``` + +??? console "pip (there's a caveat)" + + Using `pip` to install from nightly indices is _not supported_, because `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version, which makes it difficult to install a development version prior to the released version. In contrast, `uv` gives the extra index [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes). + + If you insist on using `pip`, you have to specify the full URL (link address) of the wheel file (which can be obtained from https://wheels.vllm.ai/nightly/cpu/vllm). + + ```bash + pip install https://wheels.vllm.ai/4fa7ce46f31cbd97b4651694caf9991cc395a259/vllm-0.13.0rc2.dev104%2Bg4fa7ce46f.cpu-cp38-abi3-manylinux_2_35_aarch64.whl # current nightly build (the filename will change!) + ``` + +**Install specific revisions** + +If you want to access the wheels for previous commits (e.g. to bisect the behavior change, performance regression), you can specify the commit hash in the URL: + +```bash +export VLLM_COMMIT=730bd35378bf2a5b56b6d3a45be28b3092d26519 # use full commit hash from the main branch +uv pip install vllm --extra-index-url https://wheels.vllm.ai/${VLLM_COMMIT}/cpu +``` + # --8<-- [end:pre-built-wheels] # --8<-- [start:build-wheel-from-source] @@ -69,6 +108,24 @@ Testing has been conducted on AWS Graviton3 instances for compatibility. # --8<-- [end:build-wheel-from-source] # --8<-- [start:pre-built-images] +See [Using Docker](../../deployment/docker.md) for instructions on using the official Docker image. + +Stable vLLM Docker images are being pre-built for Arm from version 0.12.0. Available image tags are here: [https://gallery.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo](https://gallery.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo). + +```bash +export VLLM_VERSION=$(curl -s https://api.github.com/repos/vllm-project/vllm/releases/latest | jq -r .tag_name | sed 's/^v//') +docker pull public.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo:v${VLLM_VERSION} +``` + +You can also access the latest code with Docker images. These are not intended for production use and are meant for CI and testing only. They will expire after several days. + +The latest code can contain bugs and may not be stable. Please use it with caution. + +```bash +export VLLM_COMMIT=6299628d326f429eba78736acb44e76749b281f5 # use full commit hash from the main branch +docker pull public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:${VLLM_COMMIT}-arm64-cpu +``` + # --8<-- [end:pre-built-images] # --8<-- [start:build-image-from-source] ```bash diff --git a/docs/getting_started/installation/cpu.md b/docs/getting_started/installation/cpu.md index d1beab7855b18..210f720e2d92a 100644 --- a/docs/getting_started/installation/cpu.md +++ b/docs/getting_started/installation/cpu.md @@ -46,10 +46,37 @@ vLLM is a Python library that supports the following CPU variants. Select your C ### Pre-built wheels -Currently, there are no pre-built CPU wheels. +When specifying the index URL, please make sure to use the `cpu` variant subdirectory. +For example, the nightly build index is: `https://wheels.vllm.ai/nightly/cpu/`. + +=== "Intel/AMD x86" + + --8<-- "docs/getting_started/installation/cpu.x86.inc.md:pre-built-wheels" + +=== "ARM AArch64" + + --8<-- "docs/getting_started/installation/cpu.arm.inc.md:pre-built-wheels" + +=== "Apple silicon" + + --8<-- "docs/getting_started/installation/cpu.apple.inc.md:pre-built-wheels" + +=== "IBM Z (S390X)" + + --8<-- "docs/getting_started/installation/cpu.s390x.inc.md:pre-built-wheels" ### Build wheel from source +#### Set up using Python-only build (without compilation) {#python-only-build} + +Please refer to the instructions for [Python-only build on GPU](./gpu.md#python-only-build), and replace the build commands with: + +```bash +VLLM_USE_PRECOMPILED=1 VLLM_PRECOMPILED_WHEEL_VARIANT=cpu VLLM_TARGET_DEVICE=cpu uv pip install --editable . +``` + +#### Full build (with compilation) {#full-build} + === "Intel/AMD x86" --8<-- "docs/getting_started/installation/cpu.x86.inc.md:build-wheel-from-source" @@ -74,6 +101,18 @@ Currently, there are no pre-built CPU wheels. --8<-- "docs/getting_started/installation/cpu.x86.inc.md:pre-built-images" +=== "ARM AArch64" + + --8<-- "docs/getting_started/installation/cpu.arm.inc.md:pre-built-images" + +=== "Apple silicon" + + --8<-- "docs/getting_started/installation/cpu.apple.inc.md:pre-built-images" + +=== "IBM Z (S390X)" + + --8<-- "docs/getting_started/installation/cpu.s390x.inc.md:pre-built-images" + ### Build image from source === "Intel/AMD x86" @@ -125,6 +164,35 @@ vllm serve facebook/opt-125m --dtype=bfloat16 Note, it is recommended to manually reserve 1 CPU for vLLM front-end process when `world_size == 1`. +### What are supported models on CPU? + +For the full and up-to-date list of models validated on CPU platforms, please see the official documentation: [Supported Models on CPU](https://docs.vllm.ai/en/latest/models/hardware_supported_models/cpu) + +### How to find benchmark configuration examples for supported CPU models? + +For any model listed under [Supported Models on CPU](https://docs.vllm.ai/en/latest/models/hardware_supported_models/cpu), optimized runtime configurations are provided in the vLLM Benchmark Suite’s CPU test cases, defined in [cpu test cases](https://github.com/vllm-project/vllm/blob/main/.buildkite/performance-benchmarks/tests/serving-tests-cpu.json) +For details on how these optimized configurations are determined, see: [performance-benchmark-details](https://github.com/vllm-project/vllm/tree/main/.buildkite/performance-benchmarks#performance-benchmark-details). +To benchmark the supported models using these optimized settings, follow the steps in [running vLLM Benchmark Suite manually](https://docs.vllm.ai/en/latest/contributing/benchmarks/#manually-trigger-the-benchmark) and run the Benchmark Suite on a CPU environment. + +Below is an example command to benchmark all CPU-supported models using optimized configurations. + +```bash +ON_CPU=1 bash .buildkite/performance-benchmarks/scripts/run-performance-benchmarks.sh +``` + +The benchmark results will be saved in `./benchmark/results/`. +In the directory, the generated `.commands` files contain all example commands for the benchmark. + +We recommend configuring tensor-parallel-size to match the number of NUMA nodes on your system. Note that the current release does not support tensor-parallel-size=6. +To determine the number of NUMA nodes available, use the following command: + +```bash +lscpu | grep "NUMA node(s):" | awk '{print $3}' +``` + +For performance reference, users may also consult the [vLLM Performance Dashboard](https://hud.pytorch.org/benchmark/llms?repoName=vllm-project%2Fvllm&deviceName=cpu) +, which publishes default-model CPU results produced using the same Benchmark Suite. + ### How to decide `VLLM_CPU_OMP_THREADS_BIND`? - Default `auto` thread-binding is recommended for most cases. Ideally, each OpenMP thread will be bound to a dedicated physical core respectively, threads of each rank will be bound to the same NUMA node respectively, and 1 CPU per rank will be reserved for other vLLM components when `world_size > 1`. If you have any performance problems or unexpected binding behaviours, please try to bind threads as following. diff --git a/docs/getting_started/installation/cpu.s390x.inc.md b/docs/getting_started/installation/cpu.s390x.inc.md index c2163139a7c5d..4984c87c17b01 100644 --- a/docs/getting_started/installation/cpu.s390x.inc.md +++ b/docs/getting_started/installation/cpu.s390x.inc.md @@ -4,9 +4,6 @@ vLLM has experimental support for s390x architecture on IBM Z platform. For now, Currently, the CPU implementation for s390x architecture supports FP32 datatype only. -!!! warning - There are no pre-built wheels or images for this device, so you must build vLLM from source. - # --8<-- [end:installation] # --8<-- [start:requirements] @@ -21,6 +18,8 @@ Currently, the CPU implementation for s390x architecture supports FP32 datatype # --8<-- [end:set-up-using-python] # --8<-- [start:pre-built-wheels] +Currently, there are no pre-built IBM Z CPU wheels. + # --8<-- [end:pre-built-wheels] # --8<-- [start:build-wheel-from-source] @@ -69,6 +68,8 @@ Execute the following commands to build and install vLLM from source. # --8<-- [end:build-wheel-from-source] # --8<-- [start:pre-built-images] +Currently, there are no pre-built IBM Z CPU images. + # --8<-- [end:pre-built-images] # --8<-- [start:build-image-from-source] diff --git a/docs/getting_started/installation/cpu.x86.inc.md b/docs/getting_started/installation/cpu.x86.inc.md index 310f179cb89ca..1fad7f4338822 100644 --- a/docs/getting_started/installation/cpu.x86.inc.md +++ b/docs/getting_started/installation/cpu.x86.inc.md @@ -17,6 +17,8 @@ vLLM supports basic model inferencing and serving on x86 CPU platform, with data # --8<-- [end:set-up-using-python] # --8<-- [start:pre-built-wheels] +Currently, there are no pre-built x86 CPU wheels. + # --8<-- [end:pre-built-wheels] # --8<-- [start:build-wheel-from-source] diff --git a/docs/getting_started/installation/gpu.cuda.inc.md b/docs/getting_started/installation/gpu.cuda.inc.md index 601d3659af886..03ce28c78efc9 100644 --- a/docs/getting_started/installation/gpu.cuda.inc.md +++ b/docs/getting_started/installation/gpu.cuda.inc.md @@ -26,42 +26,49 @@ uv pip install vllm --torch-backend=auto ??? console "pip" ```bash - # Install vLLM with CUDA 12.8. - pip install vllm --extra-index-url https://download.pytorch.org/whl/cu128 + # Install vLLM with CUDA 12.9. + pip install vllm --extra-index-url https://download.pytorch.org/whl/cu129 ``` -We recommend leveraging `uv` to [automatically select the appropriate PyTorch index at runtime](https://docs.astral.sh/uv/guides/integration/pytorch/#automatic-backend-selection) by inspecting the installed CUDA driver version via `--torch-backend=auto` (or `UV_TORCH_BACKEND=auto`). To select a specific backend (e.g., `cu126`), set `--torch-backend=cu126` (or `UV_TORCH_BACKEND=cu126`). If this doesn't work, try running `uv self update` to update `uv` first. +We recommend leveraging `uv` to [automatically select the appropriate PyTorch index at runtime](https://docs.astral.sh/uv/guides/integration/pytorch/#automatic-backend-selection) by inspecting the installed CUDA driver version via `--torch-backend=auto` (or `UV_TORCH_BACKEND=auto`). To select a specific backend (e.g., `cu128`), set `--torch-backend=cu128` (or `UV_TORCH_BACKEND=cu128`). If this doesn't work, try running `uv self update` to update `uv` first. !!! note NVIDIA Blackwell GPUs (B200, GB200) require a minimum of CUDA 12.8, so make sure you are installing PyTorch wheels with at least that version. PyTorch itself offers a [dedicated interface](https://pytorch.org/get-started/locally/) to determine the appropriate pip command to run for a given target configuration. -As of now, vLLM's binaries are compiled with CUDA 12.8 and public PyTorch release versions by default. We also provide vLLM binaries compiled with CUDA 12.6, 11.8, and public PyTorch release versions: +As of now, vLLM's binaries are compiled with CUDA 12.9 and public PyTorch release versions by default. We also provide vLLM binaries compiled with CUDA 12.8, 13.0, and public PyTorch release versions: ```bash -# Install vLLM with a specific CUDA version (e.g., 11.8 or 12.6). +# Install vLLM with a specific CUDA version (e.g., 13.0). export VLLM_VERSION=$(curl -s https://api.github.com/repos/vllm-project/vllm/releases/latest | jq -r .tag_name | sed 's/^v//') -export CUDA_VERSION=118 # or 126 -uv pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cu${CUDA_VERSION}-cp38-abi3-manylinux1_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu${CUDA_VERSION} +export CUDA_VERSION=130 # or other +uv pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cu${CUDA_VERSION}-cp38-abi3-manylinux_2_31_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu${CUDA_VERSION} ``` #### Install the latest code -LLM inference is a fast-evolving field, and the latest code may contain bug fixes, performance improvements, and new features that are not released yet. To allow users to try the latest code without waiting for the next release, vLLM provides wheels for Linux running on an x86 platform with CUDA 12 for every commit since `v0.5.3`. +LLM inference is a fast-evolving field, and the latest code may contain bug fixes, performance improvements, and new features that are not released yet. To allow users to try the latest code without waiting for the next release, vLLM provides wheels for every commit since `v0.5.3` on . There are multiple indices that could be used: + +* `https://wheels.vllm.ai/nightly`: the default variant (CUDA with version specified in `VLLM_MAIN_CUDA_VERSION`) built with the last commit on the `main` branch. Currently it is CUDA 12.9. +* `https://wheels.vllm.ai/nightly/`: all other variants. Now this includes `cu130`, and `cpu`. The default variant (`cu129`) also has a subdirectory to keep consistency. + +To install from nightly index, run: ```bash uv pip install -U vllm \ --torch-backend=auto \ - --extra-index-url https://wheels.vllm.ai/nightly + --extra-index-url https://wheels.vllm.ai/nightly # add variant subdirectory here if needed ``` -??? console "pip" - ```bash - pip install -U vllm \ - --pre \ - --extra-index-url https://wheels.vllm.ai/nightly - ``` +!!! warning "`pip` caveat" - `--pre` is required for `pip` to consider pre-released versions. + Using `pip` to install from nightly indices is _not supported_, because `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version, which makes it difficult to install a development version prior to the released version. In contrast, `uv` gives the extra index [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes). + + If you insist on using `pip`, you have to specify the full URL of the wheel file (which can be obtained from the web page). + + ```bash + pip install -U https://wheels.vllm.ai/nightly/vllm-0.11.2.dev399%2Bg3c7461c18-cp38-abi3-manylinux_2_31_x86_64.whl # current nightly build (the filename will change!) + pip install -U https://wheels.vllm.ai/${VLLM_COMMIT}/vllm-0.11.2.dev399%2Bg3c7461c18-cp38-abi3-manylinux_2_31_x86_64.whl # from specific commit + ``` ##### Install specific revisions @@ -71,33 +78,13 @@ If you want to access the wheels for previous commits (e.g. to bisect the behavi export VLLM_COMMIT=72d9c316d3f6ede485146fe5aabd4e61dbc59069 # use full commit hash from the main branch uv pip install vllm \ --torch-backend=auto \ - --extra-index-url https://wheels.vllm.ai/${VLLM_COMMIT} + --extra-index-url https://wheels.vllm.ai/${VLLM_COMMIT} # add variant subdirectory here if needed ``` -The `uv` approach works for vLLM `v0.6.6` and later and offers an easy-to-remember command. A unique feature of `uv` is that packages in `--extra-index-url` have [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes). If the latest public release is `v0.6.6.post1`, `uv`'s behavior allows installing a commit before `v0.6.6.post1` by specifying the `--extra-index-url`. In contrast, `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version, which makes it difficult to install a development version prior to the released version. - -??? note "pip" - If you want to access the wheels for previous commits (e.g. to bisect the behavior change, - performance regression), due to the limitation of `pip`, you have to specify the full URL of the - wheel file by embedding the commit hash in the URL: - - ```bash - export VLLM_COMMIT=33f460b17a54acb3b6cc0b03f4a17876cff5eafd # use full commit hash from the main branch - pip install https://wheels.vllm.ai/${VLLM_COMMIT}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl - ``` - - Note that the wheels are built with Python 3.8 ABI (see [PEP - 425](https://peps.python.org/pep-0425/) for more details about ABI), so **they are compatible - with Python 3.8 and later**. The version string in the wheel file name (`1.0.0.dev`) is just a - placeholder to have a unified URL for the wheels, the actual versions of wheels are contained in - the wheel metadata (the wheels listed in the extra index url have correct versions). Although we - don't support Python 3.8 any more (because PyTorch 2.5 dropped support for Python 3.8), the - wheels are still built with Python 3.8 ABI to keep the same wheel name as before. - # --8<-- [end:pre-built-wheels] # --8<-- [start:build-wheel-from-source] -#### Set up using Python-only build (without compilation) +#### Set up using Python-only build (without compilation) {#python-only-build} If you only need to change Python code, you can build and install vLLM without compilation. Using `uv pip`'s [`--editable` flag](https://docs.astral.sh/uv/pip/packages/#editable-packages), changes you make to the code will be reflected when you run vLLM: @@ -121,18 +108,24 @@ This command will do the following: In case you see an error about wheel not found when running the above command, it might be because the commit you based on in the main branch was just merged and the wheel is being built. In this case, you can wait for around an hour to try again, or manually assign the previous commit in the installation using the `VLLM_PRECOMPILED_WHEEL_LOCATION` environment variable. ```bash -export VLLM_COMMIT=72d9c316d3f6ede485146fe5aabd4e61dbc59069 # use full commit hash from the main branch -export VLLM_PRECOMPILED_WHEEL_LOCATION=https://wheels.vllm.ai/${VLLM_COMMIT}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl +export VLLM_PRECOMPILED_WHEEL_COMMIT=$(git rev-parse HEAD~1) # or earlier commit on main +export VLLM_USE_PRECOMPILED=1 uv pip install --editable . ``` +There are more environment variables to control the behavior of Python-only build: + +* `VLLM_PRECOMPILED_WHEEL_LOCATION`: specify the exact wheel URL or local file path of a pre-compiled wheel to use. All other logic to find the wheel will be skipped. +* `VLLM_PRECOMPILED_WHEEL_COMMIT`: override the commit hash to download the pre-compiled wheel. It can be `nightly` to use the last **already built** commit on the main branch. +* `VLLM_PRECOMPILED_WHEEL_VARIANT`: specify the variant subdirectory to use on the nightly index, e.g., `cu129`, `cpu`. If not specified, the CUDA variant with `VLLM_MAIN_CUDA_VERSION` will be tried, then fallback to the default variant on the remote index. + You can find more information about vLLM's wheels in [Install the latest code](#install-the-latest-code). !!! note There is a possibility that your source code may have a different commit ID compared to the latest vLLM wheel, which could potentially lead to unknown errors. It is recommended to use the same commit ID for the source code as the vLLM wheel you have installed. Please refer to [Install the latest code](#install-the-latest-code) for instructions on how to install a specified wheel. -#### Full build (with compilation) +#### Full build (with compilation) {#full-build} If you want to modify C++ or CUDA code, you'll need to build vLLM from source. This can take several minutes: diff --git a/docs/getting_started/installation/gpu.md b/docs/getting_started/installation/gpu.md index bc7508b29475f..fb750f4499858 100644 --- a/docs/getting_started/installation/gpu.md +++ b/docs/getting_started/installation/gpu.md @@ -52,7 +52,7 @@ vLLM is a Python library that supports the following GPU variants. Select your G --8<-- "docs/getting_started/installation/gpu.xpu.inc.md:set-up-using-python" -### Pre-built wheels +### Pre-built wheels {#pre-built-wheels} === "NVIDIA CUDA" diff --git a/docs/getting_started/installation/gpu.rocm.inc.md b/docs/getting_started/installation/gpu.rocm.inc.md index c80ba9478f6be..21120cc6fcd98 100644 --- a/docs/getting_started/installation/gpu.rocm.inc.md +++ b/docs/getting_started/installation/gpu.rocm.inc.md @@ -5,9 +5,6 @@ vLLM supports AMD GPUs with ROCm 6.3 or above, and torch 2.8.0 and above. !!! tip [Docker](#set-up-using-docker) is the recommended way to use vLLM on ROCm. -!!! warning - There are no pre-built wheels for this device, so you must either use the pre-built Docker image or build vLLM from source. - # --8<-- [end:installation] # --8<-- [start:requirements] diff --git a/docs/getting_started/installation/gpu.xpu.inc.md b/docs/getting_started/installation/gpu.xpu.inc.md index 620a660a240ed..7e9c6a2b9de07 100644 --- a/docs/getting_started/installation/gpu.xpu.inc.md +++ b/docs/getting_started/installation/gpu.xpu.inc.md @@ -2,9 +2,6 @@ vLLM initially supports basic model inference and serving on Intel GPU platform. -!!! warning - There are no pre-built wheels for this device, so you need build vLLM from source. Or you can use pre-built images which are based on vLLM released versions. - # --8<-- [end:installation] # --8<-- [start:requirements] diff --git a/docs/getting_started/quickstart.md b/docs/getting_started/quickstart.md index 94920dc5306b3..e3974354d8f3b 100644 --- a/docs/getting_started/quickstart.md +++ b/docs/getting_started/quickstart.md @@ -281,17 +281,27 @@ Alternatively, you can use the `openai` Python package: Currently, vLLM supports multiple backends for efficient Attention computation across different platforms and accelerator architectures. It automatically selects the most performant backend compatible with your system and model specifications. -If desired, you can also manually set the backend of your choice by configuring the environment variable `VLLM_ATTENTION_BACKEND` to one of the following options: +If desired, you can also manually set the backend of your choice using the `--attention-backend` CLI argument: + +```bash +# For online serving +vllm serve Qwen/Qwen2.5-1.5B-Instruct --attention-backend FLASH_ATTN + +# For offline inference +python script.py --attention-backend FLASHINFER +``` + +Some of the available backend options include: - On NVIDIA CUDA: `FLASH_ATTN` or `FLASHINFER`. - On AMD ROCm: `TRITON_ATTN`, `ROCM_ATTN`, `ROCM_AITER_FA` or `ROCM_AITER_UNIFIED_ATTN`. -For AMD ROCm, you can further control the specific Attention implementation using the following variables: +For AMD ROCm, you can further control the specific Attention implementation using the following options: -- Triton Unified Attention: `VLLM_ROCM_USE_AITER=0 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=0 VLLM_ROCM_USE_AITER_MHA=0` -- AITER Unified Attention: `VLLM_ROCM_USE_AITER=1 VLLM_USE_AITER_UNIFIED_ATTENTION=1 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=0 VLLM_ROCM_USE_AITER_MHA=0` -- Triton Prefill-Decode Attention: `VLLM_ROCM_USE_AITER=1 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1 VLLM_ROCM_USE_AITER_MHA=0` -- AITER Multi-head Attention: `VLLM_ROCM_USE_AITER=1 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=0 VLLM_ROCM_USE_AITER_MHA=1` +- Triton Unified Attention: Set the environment variables `VLLM_ROCM_USE_AITER=0 VLLM_ROCM_USE_AITER_MHA=0` and pass `--attention-config.use_prefill_decode_attention=false` as a CLI argument. +- AITER Unified Attention: Set the environment variables `VLLM_ROCM_USE_AITER=1 VLLM_USE_AITER_UNIFIED_ATTENTION=1 VLLM_ROCM_USE_AITER_MHA=0` and pass `--attention-config.use_prefill_decode_attention=false` as a CLI argument. +- Triton Prefill-Decode Attention: Set the environment variables `VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_MHA=0` and pass `--attention-config.use_prefill_decode_attention=true` as a CLI argument. +- AITER Multi-head Attention: Set the environment variables `VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_MHA=1` and pass `--attention-config.use_prefill_decode_attention=false` as a CLI argument. !!! warning There are no pre-built vllm wheels containing Flash Infer, so you must install it in your environment first. Refer to the [Flash Infer official docs](https://docs.flashinfer.ai/) or see [docker/Dockerfile](../../docker/Dockerfile) for instructions on how to install it. diff --git a/docs/governance/collaboration.md b/docs/governance/collaboration.md new file mode 100644 index 0000000000000..5b3d2beffe5b9 --- /dev/null +++ b/docs/governance/collaboration.md @@ -0,0 +1,43 @@ +# Collaboration Policy + +This page outlines how vLLM collaborates with model providers, hardware vendors, and other stakeholders. + +## Adding New Major Features + +Anyone can contribute to vLLM. For major features, submit an RFC (request for comments) first. To submit an RFC, create an [issue](https://github.com/vllm-project/vllm/issues/new/choose) and select the `RFC` template. +RFCs are similar to design docs that discuss the motivation, problem solved, alternatives considered, and proposed change. + +Once you submit the RFC, please post it in the #contributors channel in vLLM Slack, and loop in area owners and committers for feedback. +For high-interest features, the committers nominate a person to help with the RFC process and PR review. This makes sure someone is guiding you through the process. It is reflected as the "assignee" field in the RFC issue. +If the assignee and lead maintainers find the feature to be contentious, the maintainer team aims to make decisions quickly after learning the details from everyone. This involves assigning a committer as the DRI (Directly Responsible Individual) to make the decision and shepherd the code contribution process. + +For features that you intend to maintain, please feel free to add yourself in [`mergify.yml`](https://github.com/vllm-project/vllm/blob/main/.github/mergify.yml) to receive notifications and auto-assignment when the PRs touching the feature you are maintaining. Over time, the ownership will be evaluated and updated through the committers nomination and voting process. + +## Adding New Models + +If you use vLLM, we recommend you making the model work with vLLM by following the [model registration](../contributing/model/registration.md) process before you release it publicly. + +The vLLM team helps with new model architectures not supported by vLLM, especially models pushing architectural frontiers. +Here's how the vLLM team works with model providers. The vLLM team includes all [committers](./committers.md) of the project. model providers can exclude certain members but shouldn't, as this may harm release timelines due to missing expertise. Contact [project leads](./process.md) if you want to collaborate. + +Once we establish the connection between the vLLM team and model provider: + +- The vLLM team learns the model architecture and relevant changes, then plans which area owners to involve and what features to include. +- The vLLM team creates a private communication channel (currently a Slack channel in the vLLM workspace) and a private fork within the vllm-project organization. The model provider team can invite others to the channel and repo. +- Third parties like compute providers, hosted inference providers, hardware vendors, and other organizations often work with both the model provider and vLLM on model releases. We establish direct communication (with permission) or three-way communication as needed. + +The vLLM team works with model providers on features, integrations, and release timelines. We work to meet release timelines, but engineering challenges like feature development, model accuracy alignment, and optimizations can cause delays. + +The vLLM maintainers will not publicly share details about model architecture, release timelines, or upcoming releases. We maintain model weights on secure servers with security measures (though we can work with security reviews and testing without certification). We delete pre-release weights or artifacts upon request. + +The vLLM team collaborates on marketing and promotional efforts for model releases. model providers can use vLLM's trademark and logo in publications and materials. + +## Adding New Hardware + +vLLM is designed as a platform for frontier model architectures and high-performance accelerators. +For new hardware, follow the [hardware plugin](../design/plugin_system.md) system to add support. +Use the platform plugin system to add hardware support. +As hardware gains popularity, we help endorse it in our documentation and marketing materials. +The vLLM GitHub organization can host hardware plugin repositories, especially for collaborative efforts among companies. + +We rarely add new hardware to vLLM directly. Instead, we make existing hardware platforms modular to keep the vLLM core hardware-agnostic. diff --git a/docs/governance/committers.md b/docs/governance/committers.md new file mode 100644 index 0000000000000..c9428027da953 --- /dev/null +++ b/docs/governance/committers.md @@ -0,0 +1,183 @@ +# Committers + +This document lists the current committers of the vLLM project and the core areas they maintain. +Committers have write access to the vLLM repository and are responsible for reviewing and merging PRs. +You can also refer to the [CODEOWNERS](https://github.com/vllm-project/vllm/blob/main/.github/CODEOWNERS) file for concrete file-level ownership and reviewers. Both this documents and the CODEOWNERS file are living documents and they complement each other. + +## Active Committers + +We try to summarize each committer's role in vLLM in a few words. In general, vLLM committers cover a wide range of areas and help each other in the maintenance process. +Please refer to the later section about Area Owners for exact component ownership details. +Sorted alphabetically by GitHub handle: + +- [@22quinn](https://github.com/22quinn): RL API +- [@aarnphm](https://github.com/aarnphm): Structured output +- [@alexm-redhat](https://github.com/alexm-redhat): Performance +- [@ApostaC](https://github.com/ApostaC): Connectors, offloading +- [@benchislett](https://github.com/benchislett): Engine core and spec decode +- [@bigPYJ1151](https://github.com/bigPYJ1151): Intel CPU/XPU integration +- [@chaunceyjiang](https://github.com/chaunceyjiang): Tool use and reasoning parser +- [@DarkLight1337](https://github.com/DarkLight1337): Multimodality, API server +- [@esmeetu](https://github.com/esmeetu): developer marketing, community +- [@gshtras](https://github.com/gshtras): AMD integration +- [@heheda12345](https://github.com/heheda12345): Hybrid memory allocator +- [@hmellor](https://github.com/hmellor): Hugging Face integration, documentation +- [@houseroad](https://github.com/houseroad): Engine core and Llama models +- [@Isotr0py](https://github.com/Isotr0py): Multimodality, new model support +- [@jeejeelee](https://github.com/jeejeelee): LoRA, new model support +- [@jikunshang](https://github.com/jikunshang): Intel CPU/XPU integration +- [@khluu](https://github.com/khluu): CI infrastructure +- [@KuntaiDu](https://github.com/KuntaiDu): KV Connector +- [@LucasWilkinson](https://github.com/LucasWilkinson): Kernels and performance +- [@luccafong](https://github.com/luccafong): Llama models, speculative decoding, distributed +- [@markmc](https://github.com/markmc): Observability +- [@mgoin](https://github.com/mgoin): Quantization and performance +- [@NickLucche](https://github.com/NickLucche): KV connector +- [@njhill](https://github.com/njhill): Distributed, API server, engine core +- [@noooop](https://github.com/noooop): Pooling models +- [@patrickvonplaten](https://github.com/patrickvonplaten): Mistral models, new model support +- [@pavanimajety](https://github.com/pavanimajety): NVIDIA GPU integration +- [@ProExpertProg](https://github.com/ProExpertProg): Compilation, startup UX +- [@robertgshaw2-redhat](https://github.com/robertgshaw2-redhat): Core, distributed, disagg +- [@ruisearch42](https://github.com/ruisearch42): Pipeline parallelism, Ray Support +- [@russellb](https://github.com/russellb): Structured output, engine core, security +- [@sighingnow](https://github.com/sighingnow): Qwen models, new model support +- [@simon-mo](https://github.com/simon-mo): Project lead, API entrypoints, community +- [@tdoublep](https://github.com/tdoublep): State space models +- [@tjtanaa](https://github.com/tjtanaa): AMD GPU integration +- [@tlrmchlsmth](https://github.com/tlrmchlsmth): Kernels and performance, distributed, disagg +- [@WoosukKwon](https://github.com/WoosukKwon): Project lead, engine core +- [@yaochengji](https://github.com/yaochengji): TPU integration +- [@yeqcharlotte](https://github.com/yeqcharlotte): Benchmark, Llama models +- [@yewentao256](https://github.com/yewentao256): Kernels and performance +- [@Yikun](https://github.com/Yikun): Pluggable hardware interface +- [@youkaichao](https://github.com/youkaichao): Project lead, distributed, compile, community +- [@ywang96](https://github.com/ywang96): Multimodality, benchmarks +- [@zhuohan123](https://github.com/zhuohan123): Project lead, RL integration, numerics +- [@zou3519](https://github.com/zou3519): Compilation + +### Emeritus Committers + +Committers who have contributed to vLLM significantly in the past (thank you!) but no longer active: + +- [@andoorve](https://github.com/andoorve): Pipeline parallelism +- [@cadedaniel](https://github.com/cadedaniel): Speculative decoding +- [@comaniac](https://github.com/comaniac): KV cache management, pipeline parallelism +- [@LiuXiaoxuanPKU](https://github.com/LiuXiaoxuanPKU): Speculative decoding +- [@pcmoritz](https://github.com/pcmoritz): MoE +- [@rkooo567](https://github.com/rkooo567): Chunked prefill +- [@sroy745](https://github.com/sroy745): Speculative decoding +- [@Yard1](https://github.com/Yard1): kernels and performance +- [@zhisbug](https://github.com/zhisbug): Arctic models, distributed + +## Area Owners + +This section breaks down the active committers by vLLM components and lists the area owners. +If you have PRs touching the area, please feel free to ping the area owner for review. + +### Engine Core + +- Scheduler: the core vLLM engine loop scheduling requests to next batch + - @WoosukKwon, @robertgshaw2-redhat, @njhill, @heheda12345 +- KV Cache Manager: memory management layer within scheduler maintaining KV cache logical block data + - @heheda12345, @WoosukKwon +- AsyncLLM: the zmq based protocol hosting engine core and making it accessible for entrypoints + - @robertgshaw2-redhat, @njhill, @russellb +- ModelRunner, Executor, Worker: the abstractions for engine wrapping model implementation + - @WoosukKwon, @tlrmchlsmth, @heheda12345, @LucasWilkinson, @ProExpertProg +- KV Connector: Connector interface and implementation for KV cache offload and transfer + - @robertgshaw2-redhat, @njhill, @KuntaiDu, @NickLucche, @ApostaC +- Distributed, Parallelism, Process Management: Process launchers managing each worker, and assign them to the right DP/TP/PP/EP ranks + - @youkaichao, @njhill, @WoosukKwon, @ruisearch42 +- Collectives: the usage of nccl and other communication libraries/kernels + - @tlrmchlsmth, @youkaichao +- Multimodality engine and memory management: core scheduling and memory management concerning vision, audio, and video inputs. + - @ywang96, @DarkLight1337 + +### Model Implementations + +- Model Interface: The `nn.Module` interface and implementation for various models + - @zhuohan123, @mgoin, @simon-mo, @houseroad, @ywang96 (multimodality), @jeejeelee (lora) +- Logits Processors / Sampler: The provided sampler class and pluggable logits processors + - @njhill, @houseroad, @22quinn +- Custom Layers: Utility layers in vLLM such as rotary embedding and rms norms + - @ProExpertProg +- Attention: Attention interface for paged attention + - @WoosukKwon, @LucasWilkinson, @heheda12345 +- FusedMoE: FusedMoE kernel, Modular kernel framework, EPLB + - @tlrmchlsmth +- Quantization: Various quantization config, weight loading, and kernel. + - @mgoin, @Isotr0py, @yewentao256 +- Custom quantized GEMM kernels (cutlass_scaled_mm, marlin, machete) + - @tlrmchlsmth, @LucasWilkinson +- Multi-modal Input Processing: Components that load and process image/video/audio data into feature tensors + - @DarkLight1337, @ywang96, @Isotr0py +- torch compile: The torch.compile integration in vLLM, custom passes & transformations + - @ProExpertProg, @zou3519, @youkaichao +- State space models: The state space models implementation in vLLM + - @tdoublep, @tlrmchlsmth +- Reasoning and tool calling parsers + - @chaunceyjiang, @aarnphm + +### Entrypoints + +- LLM Class: The LLM class for offline inference + - @DarkLight1337 +- API Server: The OpenAI-compatible API server + - @DarkLight1337, @njhill, @aarnphm, @simon-mo, @heheda12345 (Responses API) +- Batch Runner: The OpenAI-compatible batch runner + - @simon-mo + +### Features + +- Spec Decode: Covers model definition, attention, sampler, and scheduler related to n-grams, EAGLE, and MTP. + - @WoosukKwon, @benchislett, @luccafong +- Structured Output: The structured output implementation + - @russellb, @aarnphm +- RL: The RL related features such as collective rpc, sleep mode, etc. + - @youkaichao, @zhuohan123, @22quinn +- LoRA: @jeejeelee +- Observability: Metrics and Logging + - @markmc, @robertgshaw2-redhat, @simon-mo + +### Code Base + +- Config: Configuration registration and parsing + - @hmellor +- Documentation: @hmellor, @DarkLight1337, @simon-mo +- Benchmarks: @ywang96, @simon-mo +- CI, Build, Release Process: @khluu, @njhill, @simon-mo +- Security: @russellb + +### External Kernels Integration + +- FlashAttention: @LucasWilkinson +- FlashInfer: @LucasWilkinson, @mgoin, @WoosukKwon +- Blackwell Kernels: @mgoin, @yewentao256 +- DeepEP/DeepGEMM/pplx: @mgoin, @yewentao256 + +### Integrations + +- Hugging Face: @hmellor, @Isotr0py +- Ray: @ruisearch42 +- NIXL: @robertgshaw2-redhat, @NickLucche + +### Collaboration with Model Vendors + +- gpt-oss: @heheda12345, @simon-mo, @zhuohan123 +- Llama: @luccafong +- Qwen: @sighingnow +- Mistral: @patrickvonplaten + +### Hardware + +- Plugin Interface: @youkaichao, @Yikun +- NVIDIA GPU: @pavanimajety +- AMD GPU: @gshtras, @tjtanaa +- Intel CPU/GPU: @jikunshang, @bigPYJ1151 +- Google TPU: @yaochengji + +### Ecosystem Projects + +- Ascend NPU: [@wangxiyuan](https://github.com/wangxiyuan) and [see more details](https://vllm-ascend.readthedocs.io/en/latest/community/contributors.html#maintainers) +- Intel Gaudi HPU [@xuechendi](https://github.com/xuechendi) and [@kzawora-intel](https://github.com/kzawora-intel) diff --git a/docs/governance/process.md b/docs/governance/process.md new file mode 100644 index 0000000000000..1e088dd3c1e64 --- /dev/null +++ b/docs/governance/process.md @@ -0,0 +1,125 @@ +# Governance Process + +vLLM's success comes from our strong open source community. We favor informal, meritocratic norms over formal policies. This document clarifies our governance philosophy and practices. + +## Values + +vLLM aims to be the fastest and easiest-to-use LLM inference and serving engine. We stay current with advances, enable innovation, and support diverse models, modalities, and hardware. + +### Design Values + +1. **Top performance**: System performance is our top priority. We monitor overheads, optimize kernels, and publish benchmarks. We never leave performance on the table. +2. **Ease of use**: vLLM must be simple to install, configure, and operate. We provide clear documentation, fast startup, clean logs, helpful error messages, and monitoring guides. Many users fork our code or study it deeply, so we keep it readable and modular. +3. **Wide coverage**: vLLM supports frontier models and high-performance accelerators. We make it easy to add new models and hardware. vLLM + PyTorch form a simple interface that avoids complexity. +4. **Production ready**: vLLM runs 24/7 in production. It must be easy to operate and monitor for health issues. +5. **Extensibility**: vLLM serves as fundamental LLM infrastructure. Our codebase cannot cover every use case, so we design for easy forking and customization. + +### Collaboration Values + +1. **Tightly Knit and Fast-Moving**: Our maintainer team is aligned on vision, philosophy, and roadmap. We work closely to unblock each other and move quickly. +2. **Individual Merit**: No one buys their way into governance. Committer status belongs to individuals, not companies. We reward contribution, maintenance, and project stewardship. + +## Project Maintainers + +Maintainers form a hierarchy based on sustained, high-quality contributions and alignment with our design philosophy. + +### Core Maintainers + +Core Maintainers function like a project planning and decision making committee. In other convention, they might be called a Technical Steering Committee (TSC). In vLLM vocabulary, they are often known as "Project Leads". They meet weekly to coordinate roadmap priorities and allocate engineering resources. Current active leads: @WoosukKwon, @zhuohan123, @simon-mo, @youkaichao, @robertshaw2-redhat, @tlrmchlsmth, @mgoin, @njhill, @ywang96, @houseroad, @yeqcharlotte, @ApostaC + +The responsibilities of the core maintainers are: + +* Author quarterly roadmap and responsible for each development effort. +* Making major changes to the technical direction or scope of vLLM and vLLM projects. +* Defining the project's release strategy. +* Work with model providers, hardware vendors, and key users of vLLM to ensure the project is on the right track. + +### Lead Maintainers + +While Core maintainers assume the day-to-day responsibilities of the project, Lead maintainers are responsible for the overall direction and strategy of the project. A committee of @WoosukKwon, @zhuohan123, @simon-mo, and @youkaichao currently shares this role with divided responsibilities. + +The responsibilities of the lead maintainers are: + +* Making decisions where consensus among core maintainers cannot be reached. +* Adopting changes to the project's technical governance. +* Organizing the voting process for new committers. + +### Committers and Area Owners + +Committers have write access and merge rights. They typically have deep expertise in specific areas and help the community. + +The responsibilities of the committers are: + +* Reviewing PRs and providing feedback. +* Addressing issues and questions from the community. +* Own specific areas of the codebase and development efforts: reviewing PRs, addressing issues, answering questions, improving documentation. + +Specially, committers are almost all area owners. They author subsystems, review PRs, refactor code, monitor tests, and ensure compatibility with other areas. All area owners are committers with deep expertise in that area, but not all committers own areas. + +For a full list of committers and their respective areas, see the [committers](./committers.md) page. + +#### Nomination Process + +Any committer can nominate candidates via our private mailing list: + +1. **Nominate**: Any committer may nominate a candidate by email to the private maintainers’ list, citing evidence mapped to the pre‑existing standards with links to PRs, reviews, RFCs, issues, benchmarks, and adoption evidence. +2. **Vote**: The lead maintainers will group voices support or concerns. Shared concerns can stop the process. The vote typically last 3 working days. For concerns, committers group discuss the clear criteria for such person to be nominated again. The lead maintainers will make the final decision. +3. **Confirm**: The lead maintainers send invitation, update CODEOWNERS, assign permissions, add to communications channels (mailing list and Slack). + +Committership is highly selective and merit based. The selection criteria requires: + +* **Area expertise**: leading design/implementation of core subsystems, material performance or reliability improvements adopted project‑wide, or accepted RFCs that shape technical direction. +* **Sustained contributions**: high‑quality merged contributions and reviews across releases, responsiveness to feedback, and stewardship of code health. +* **Community leadership**: mentoring contributors, triaging issues, improving docs, and elevating project standards. + +To further illustrate, a committer typically satisfies at least two of the following accomplishment patterns: + +* Author of an accepted RFC or design that materially shaped project direction +* Measurable, widely adopted performance or reliability improvement in core paths +* Long‑term ownership of a subsystem with demonstrable quality and stability gains +* Significant cross‑project compatibility or ecosystem enablement work (models, hardware, tooling) + +While there isn't a quantitative bar, past committers have: + +* Submitted approximately 30+ PRs of substantial quality and scope +* Provided high-quality reviews of approximately 10+ substantial external contributor PRs +* Addressed multiple issues and questions from the community in issues/forums/Slack +* Led concentrated efforts on RFCs and their implementation, or significant performance or reliability improvements adopted project‑wide + +### Working Groups + +vLLM runs informal working groups such as CI, CI infrastructure, torch compile, and startup UX. These can be loosely tracked via `#sig-` (or `#feat-`) channels in vLLM Slack. Some groups have regular sync meetings. + +### Advisory Board + +vLLM project leads consult with an informal advisory board that is composed of model providers, hardware vendors, and ecosystem partners. This manifests as a collaboration channel in Slack and frequent communications. + +## Process + +### Project Roadmap + +Project Leads publish quarterly roadmaps as GitHub issues. These clarify current priorities. Unlisted topics aren't excluded but may get less review attention. See [https://roadmap.vllm.ai/](https://roadmap.vllm.ai/). + +### Decision Making + +We make technical decisions in Slack and GitHub using RFCs and design docs. Discussion may happen elsewhere, but we maintain public records of significant changes: problem statements, rationale, and alternatives considered. + +### Merging Code + +Contributors and maintainers often collaborate closely on code changes, especially within organizations or specific areas. Maintainers should give others appropriate review opportunities based on change significance. + +PRs requires at least one committer review and approval. If the code is covered by CODEOWNERS, the PR should be reviewed by the CODEOWNERS. There are cases where the code is trivial or hotfix, the PR can be merged by the lead maintainers directly. + +In case where CI didn't pass due to the failure is not related to the PR, the PR can be merged by the lead maintainers using "force merge" option that overrides the CI checks. + +### Slack + +Contributors are encouraged to join `#pr-reviews` and `#contributors` channels. + +There are `#sig-` and `#feat-` channels for discussion and coordination around specific topics. + +The project maintainer group also uses a private channel for high-bandwidth collaboration. + +### Meetings + +We hold weekly contributor syncs with standup-style updates on progress, blockers, and plans. You can refer to the notes [standup.vllm.ai](https://standup.vllm.ai) for joining instructions. diff --git a/docs/mkdocs/hooks/generate_examples.py b/docs/mkdocs/hooks/generate_examples.py index 6e4fb039e3a07..e886a91e65732 100644 --- a/docs/mkdocs/hooks/generate_examples.py +++ b/docs/mkdocs/hooks/generate_examples.py @@ -2,7 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools import logging -from dataclasses import dataclass, field +from dataclasses import dataclass +from functools import cached_property from pathlib import Path from typing import Literal @@ -16,13 +17,18 @@ EXAMPLE_DIR = ROOT_DIR / "examples" EXAMPLE_DOC_DIR = ROOT_DIR / "docs/examples" -def fix_case(text: str) -> str: +def title(text: str) -> str: + # Default title case + text = text.replace("_", " ").replace("/", " - ").title() + # Custom substitutions subs = { + "io": "IO", "api": "API", "cli": "CLI", "cpu": "CPU", "llm": "LLM", "mae": "MAE", + "ner": "NER", "tpu": "TPU", "gguf": "GGUF", "lora": "LoRA", @@ -48,71 +54,65 @@ class Example: Attributes: path (Path): The path to the main directory or file. category (str): The category of the document. - main_file (Path): The main file in the directory. - other_files (list[Path]): list of other files in the directory. - title (str): The title of the document. + + Properties:: + main_file() -> Path | None: Determines the main file in the given path. + other_files() -> list[Path]: Determines other files in the directory excluding + the main file. + title() -> str: Determines the title of the document. Methods: - __post_init__(): Initializes the main_file, other_files, and title attributes. - determine_main_file() -> Path: Determines the main file in the given path. - determine_other_files() -> list[Path]: Determines other files in the directory excluding the main file. - determine_title() -> str: Determines the title of the document. generate() -> str: Generates the documentation content. - """ # noqa: E501 + """ path: Path - category: str = None - main_file: Path = field(init=False) - other_files: list[Path] = field(init=False) - title: str = field(init=False) + category: str - def __post_init__(self): - self.main_file = self.determine_main_file() - self.other_files = self.determine_other_files() - self.title = self.determine_title() + @cached_property + def main_file(self) -> Path | None: + """Determines the main file in the given path. - @property - def is_code(self) -> bool: - return self.main_file.suffix != ".md" + If path is a file, it returns the path itself. If path is a directory, it + searches for Markdown files (*.md) in the directory and returns the first one + found. If no Markdown files are found, it returns None.""" + # Single file example + if self.path.is_file(): + return self.path + # Multi file example with a README + if md_paths := list(self.path.glob("*.md")): + return md_paths[0] + # Multi file example without a README + return None - def determine_main_file(self) -> Path: - """ - Determines the main file in the given path. - If the path is a file, it returns the path itself. Otherwise, it searches - for Markdown files (*.md) in the directory and returns the first one found. - Returns: - Path: The main file path, either the original path if it's a file or the first - Markdown file found in the directory. - Raises: - IndexError: If no Markdown files are found in the directory. - """ # noqa: E501 - return self.path if self.path.is_file() else list(self.path.glob("*.md")).pop() + @cached_property + def other_files(self) -> list[Path]: + """Determine other files in the directory excluding the main file. - def determine_other_files(self) -> list[Path]: - """ - Determine other files in the directory excluding the main file. - - This method checks if the given path is a file. If it is, it returns an empty list. - Otherwise, it recursively searches through the directory and returns a list of all - files that are not the main file. - - Returns: - list[Path]: A list of Path objects representing the other files in the directory. - """ # noqa: E501 + If path is a file, it returns an empty list. Otherwise, it returns every file + in the directory except the main file in a list.""" + # Single file example if self.path.is_file(): return [] + # Multi file example is_other_file = lambda file: file.is_file() and file != self.main_file - return [file for file in self.path.rglob("*") if is_other_file(file)] + return sorted(file for file in self.path.rglob("*") if is_other_file(file)) - def determine_title(self) -> str: - if not self.is_code: - # Specify encoding for building on Windows - with open(self.main_file, encoding="utf-8") as f: - first_line = f.readline().strip() - match = re.match(r"^#\s+(?P.+)$", first_line) - if match: - return match.group("title") - return fix_case(self.path.stem.replace("_", " ").title()) + @cached_property + def is_code(self) -> bool: + return self.main_file is not None and self.main_file.suffix != ".md" + + @cached_property + def title(self) -> str: + # Generate title from filename if no main md file found + if self.main_file is None or self.is_code: + return title(self.path.stem) + # Specify encoding for building on Windows + with open(self.main_file, encoding="utf-8") as f: + first_line = f.readline().strip() + match = re.match(r"^#\s+(?P<title>.+)$", first_line) + if match: + return match.group("title") + raise ValueError(f"Title not found in {self.main_file}") def fix_relative_links(self, content: str) -> str: """ @@ -156,24 +156,35 @@ class Example: # included files containing code fences too code_fence = "``````" - if self.is_code: - content += ( - f"{code_fence}{self.main_file.suffix[1:]}\n" - f'--8<-- "{self.main_file}"\n' - f"{code_fence}\n" - ) + if self.main_file is not None: + # Single file example or multi file example with a README + if self.is_code: + content += ( + f"{code_fence}{self.main_file.suffix[1:]}\n" + f'--8<-- "{self.main_file}"\n' + f"{code_fence}\n" + ) + else: + with open(self.main_file, encoding="utf-8") as f: + # Skip the title from md snippets as it's been included above + main_content = f.readlines()[1:] + content += self.fix_relative_links("".join(main_content)) + content += "\n" else: - with open(self.main_file) as f: - # Skip the title from md snippets as it's been included above - main_content = f.readlines()[1:] - content += self.fix_relative_links("".join(main_content)) - content += "\n" + # Multi file example without a README + for file in self.other_files: + file_title = title(str(file.relative_to(self.path).with_suffix(""))) + content += f"## {file_title}\n\n" + content += ( + f'{code_fence}{file.suffix[1:]}\n--8<-- "{file}"\n{code_fence}\n\n' + ) + return content if not self.other_files: return content content += "## Example materials\n\n" - for file in sorted(self.other_files): + for file in self.other_files: content += f'??? abstract "{file.relative_to(self.path)}"\n' if file.suffix != ".md": content += f" {code_fence}{file.suffix[1:]}\n" @@ -200,11 +211,13 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): glob_patterns = ["*.py", "*.md", "*.sh"] # Find categorised examples for category in categories: + logger.info("Processing category: %s", category.stem) globs = [category.glob(pattern) for pattern in glob_patterns] for path in itertools.chain(*globs): examples.append(Example(path, category.stem)) # Find examples in subdirectories - for path in category.glob("*/*.md"): + globs = [category.glob(f"*/{pattern}") for pattern in glob_patterns] + for path in itertools.chain(*globs): examples.append(Example(path.parent, category.stem)) # Generate the example documentation @@ -217,3 +230,4 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): with open(doc_path, "w+", encoding="utf-8") as f: f.write(example.generate()) logger.debug("Example generated: %s", doc_path.relative_to(ROOT_DIR)) + logger.info("Total examples generated: %d", len(examples)) diff --git a/docs/mkdocs/hooks/generate_metrics.py b/docs/mkdocs/hooks/generate_metrics.py new file mode 100644 index 0000000000000..b20d43c4b2e92 --- /dev/null +++ b/docs/mkdocs/hooks/generate_metrics.py @@ -0,0 +1,149 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import ast +import logging +from pathlib import Path +from typing import Literal + +logger = logging.getLogger("mkdocs") + +ROOT_DIR = Path(__file__).parent.parent.parent.parent +DOCS_DIR = ROOT_DIR / "docs" +GENERATED_METRICS_DIR = DOCS_DIR / "generated" / "metrics" + +# Files to scan for metric definitions - each will generate a separate table +METRIC_SOURCE_FILES = [ + {"path": "vllm/v1/metrics/loggers.py", "output": "general.md"}, + { + "path": "vllm/v1/spec_decode/metrics.py", + "output": "spec_decode.md", + }, + { + "path": "vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py", + "output": "nixl_connector.md", + }, +] + + +class MetricExtractor(ast.NodeVisitor): + """AST visitor to extract metric definitions.""" + + def __init__(self): + self.metrics: list[dict[str, str]] = [] + + def visit_Call(self, node: ast.Call) -> None: + """Visit function calls to find metric class instantiations.""" + metric_type = self._get_metric_type(node) + if metric_type: + name = self._extract_kwarg(node, "name") + documentation = self._extract_kwarg(node, "documentation") + + if name: + self.metrics.append( + { + "name": name, + "type": metric_type, + "documentation": documentation or "", + } + ) + + self.generic_visit(node) + + def _get_metric_type(self, node: ast.Call) -> str | None: + """Determine if this call creates a metric and return its type.""" + metric_type_map = { + "_gauge_cls": "gauge", + "_counter_cls": "counter", + "_histogram_cls": "histogram", + } + if isinstance(node.func, ast.Attribute): + return metric_type_map.get(node.func.attr) + return None + + def _extract_kwarg(self, node: ast.Call, key: str) -> str | None: + """Extract a keyword argument value from a function call.""" + for keyword in node.keywords: + if keyword.arg == key: + return self._get_string_value(keyword.value) + return None + + def _get_string_value(self, node: ast.AST) -> str | None: + """Extract string value from an AST node.""" + if isinstance(node, ast.Constant): + return str(node.value) if node.value is not None else None + return None + + +def extract_metrics_from_file(filepath: Path) -> list[dict[str, str]]: + """Parse a Python file and extract all metric definitions.""" + try: + with open(filepath, encoding="utf-8") as f: + source = f.read() + + tree = ast.parse(source, filename=str(filepath)) + extractor = MetricExtractor() + extractor.visit(tree) + return extractor.metrics + except Exception as e: + raise RuntimeError(f"Failed to parse {filepath}: {e}") from e + + +def generate_markdown_table(metrics: list[dict[str, str]]) -> str: + """Generate a markdown table from extracted metrics.""" + if not metrics: + return "No metrics found.\n" + + # Sort by type, then by name + metrics_sorted = sorted(metrics, key=lambda m: (m["type"], m["name"])) + + lines = [] + lines.append("| Metric Name | Type | Description |") + lines.append("|-------------|------|-------------|") + + for metric in metrics_sorted: + name = metric["name"] + metric_type = metric["type"].capitalize() + doc = metric["documentation"].replace("\n", " ").strip() + lines.append(f"| `{name}` | {metric_type} | {doc} |") + + return "\n".join(lines) + "\n" + + +def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): + """Generate metrics documentation tables from source files.""" + logger.info("Generating metrics documentation") + + # Create generated directory if it doesn't exist + GENERATED_METRICS_DIR.mkdir(parents=True, exist_ok=True) + + total_metrics = 0 + for source_config in METRIC_SOURCE_FILES: + source_path = source_config["path"] + output_file = source_config["output"] + + filepath = ROOT_DIR / source_path + if not filepath.exists(): + raise FileNotFoundError(f"Metrics source file not found: {filepath}") + + logger.debug("Extracting metrics from: %s", source_path) + metrics = extract_metrics_from_file(filepath) + logger.debug("Found %d metrics in %s", len(metrics), source_path) + + # Generate and write the markdown table for this source + table_content = generate_markdown_table(metrics) + output_path = GENERATED_METRICS_DIR / output_file + with open(output_path, "w", encoding="utf-8") as f: + f.write(table_content) + + total_metrics += len(metrics) + logger.info( + "Generated metrics table: %s (%d metrics)", + output_path.relative_to(ROOT_DIR), + len(metrics), + ) + + logger.info( + "Total metrics generated: %d across %d files", + total_metrics, + len(METRIC_SOURCE_FILES), + ) diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md index aca865f4bf77d..b4b0150faf841 100644 --- a/docs/models/pooling_models.md +++ b/docs/models/pooling_models.md @@ -33,8 +33,8 @@ shown in the table below. | Architecture | `--convert` | Supported pooling tasks | |-------------------------------------------------|-------------|---------------------------------------| | `*ForTextEncoding`, `*EmbeddingModel`, `*Model` | `embed` | `token_embed`, `embed` | +| `*ForRewardModeling`, `*RewardModel` | `embed` | `token_embed`, `embed` | | `*For*Classification`, `*ClassificationModel` | `classify` | `token_classify`, `classify`, `score` | -| `*ForRewardModeling`, `*RewardModel` | `reward` | `token_classify` | !!! tip You can explicitly set `--convert <type>` to specify how to convert the model. @@ -70,7 +70,6 @@ the pooler assigned to each task has the following attributes by default: | Task | Pooling Type | Normalization | Softmax | |------------|--------------|---------------|---------| -| `reward` | `ALL` | ❌ | ❌ | | `embed` | `LAST` | ✅︎ | ❌ | | `classify` | `LAST` | ❌ | ✅︎ | @@ -274,7 +273,7 @@ outputs = llm.embed( print(outputs[0].outputs) ``` -A code example can be found here: [examples/offline_inference/pooling/embed_matryoshka_fy.py](../../examples/offline_inference/pooling/embed_matryoshka_fy.py) +A code example can be found here: [examples/pooling/embed/embed_matryoshka_fy.py](../../examples/pooling/embed/embed_matryoshka_fy.py) ### Online Inference @@ -304,7 +303,7 @@ Expected output: {"id":"embd-5c21fc9a5c9d4384a1b021daccaf9f64","object":"list","created":1745476417,"model":"jinaai/jina-embeddings-v3","data":[{"index":0,"object":"embedding","embedding":[-0.3828125,-0.1357421875,0.03759765625,0.125,0.21875,0.09521484375,-0.003662109375,0.1591796875,-0.130859375,-0.0869140625,-0.1982421875,0.1689453125,-0.220703125,0.1728515625,-0.2275390625,-0.0712890625,-0.162109375,-0.283203125,-0.055419921875,-0.0693359375,0.031982421875,-0.04052734375,-0.2734375,0.1826171875,-0.091796875,0.220703125,0.37890625,-0.0888671875,-0.12890625,-0.021484375,-0.0091552734375,0.23046875]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0,"prompt_tokens_details":null}} ``` -An OpenAI client example can be found here: [examples/online_serving/pooling/openai_embedding_matryoshka_fy.py](../../examples/online_serving/pooling/openai_embedding_matryoshka_fy.py) +An OpenAI client example can be found here: [examples/pooling/embed/openai_embedding_matryoshka_fy.py](../../examples/pooling/embed/openai_embedding_matryoshka_fy.py) ## Deprecated Features @@ -317,4 +316,14 @@ We have split the `encode` task into two more specific token-wise tasks: `token_ ### Remove softmax from PoolingParams -We are going to remove `softmax` and `activation` from `PoolingParams`. Instead, use `use_activation`, since we allow `classify` and `token_classify` to use any activation function. +We are going to remove `softmax` and `activation` from `PoolingParams` in v0.15. Instead, use `use_activation`, since we allow `classify` and `token_classify` to use any activation function. + +### as_reward_model + +!!! warning + We are going to remove `--convert reward` in v0.15, use `--convert embed` instead. + +Pooling models now default support all pooling, you can use it without any settings. + +- Extracting hidden states prefers using `token_embed` task. +- Reward models prefers using `token_classify` task. diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index da7c5edf66bfb..9ba0f4ca9096e 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -417,7 +417,8 @@ th { | `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ | | `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ | | `MiniMaxM2ForCausalLM` | MiniMax-M2 |`MiniMaxAI/MiniMax-M2`, etc. | | ✅︎ | -| `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ | +| `MistralForCausalLM` | Ministral-3, Mistral, Mistral-Instruct | `mistralai/Ministral-3-3B-Instruct-2512`, `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ | +| `MistralLarge3ForCausalLM` | Mistral-Large-3-675B-Base-2512, Mistral-Large-3-675B-Instruct-2512 | `mistralai/Mistral-Large-3-675B-Base-2512`, `mistralai/Mistral-Large-3-675B-Instruct-2512`, etc. | ✅︎ | ✅︎ | | `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ | | `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | | `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | @@ -567,7 +568,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A ``` !!! note - Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: [examples/offline_inference/pooling/qwen3_reranker.py](../../examples/offline_inference/pooling/qwen3_reranker.py). + Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: [examples/pooling/score/offline_reranker.py](../../examples/pooling/score/offline_reranker.py). ```bash vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}' @@ -580,16 +581,9 @@ These models primarily support the [`LLM.reward`](./pooling_models.md#llmreward) | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | |--------------|--------|-------------------|----------------------|---------------------------| | `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. | ✅︎ | ✅︎ | -| `LlamaForCausalLM`<sup>C</sup> | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ | +| `LlamaForCausalLM` | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ | | `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B`, etc. | ✅︎ | ✅︎ | | `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B`, etc. | ✅︎ | ✅︎ | -| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | - -<sup>C</sup> Automatically converted into a reward model via `--convert reward`. ([details](./pooling_models.md#model-conversion)) -\* Feature support is the same as that of the original model. - -If your model is not in the above list, we will try to automatically convert the model using -[as_reward_model][vllm.model_executor.models.adapters.as_reward_model]. By default, we return the hidden states of each token directly. !!! important For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly, @@ -605,7 +599,7 @@ These models primarily support the [`LLM.encode`](./pooling_models.md#llmencode) | `ModernBertForTokenClassification` | ModernBERT-based | `disham993/electrical-ner-ModernBERT-base` | | | !!! note - Named Entity Recognition (NER) usage, please refer to [examples/offline_inference/pooling/ner.py](../../examples/offline_inference/pooling/ner.py), [examples/online_serving/pooling/ner_client.py](../../examples/online_serving/pooling/ner_client.py). + Named Entity Recognition (NER) usage, please refer to [examples/pooling/token_classify/ner.py](../../examples/pooling/token_classify/ner.py), [examples/pooling/token_classify/ner_client.py](../../examples/pooling/token_classify/ner_client.py). ## List of Multimodal Language Models @@ -665,7 +659,9 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | |--------------|--------|--------|-------------------|----------------------|---------------------------| | `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | | | -| `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereForAI/aya-vision-8b`, `CohereForAI/aya-vision-32b`, etc. | | ✅︎ | +| `AudioFlamingo3ForConditionalGeneration` | AudioFlamingo3 | T + A<sup>+</sup> | `nvidia/audio-flamingo-3-hf`, `nvidia/music-flamingo-hf` | ✅︎ | ✅︎ | +| `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereLabs/aya-vision-8b`, `CohereLabs/aya-vision-32b`, etc. | | ✅︎ | +| `BagelForConditionalGeneration` | BAGEL | T + I<sup>+</sup> | `ByteDance-Seed/BAGEL-7B-MoT` | ✅︎ | ✅︎ | | `BeeForConditionalGeneration` | Bee-8B | T + I<sup>E+</sup> | `Open-Bee/Bee-8B-RL`, `Open-Bee/Bee-8B-SFT` | | ✅︎ | | `Blip2ForConditionalGeneration` | BLIP-2 | T + I<sup>E</sup> | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | | ✅︎ | | `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ | @@ -710,8 +706,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | | `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | | `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | -| `Phi4MultimodalForCausalLM` | Phi-4-multimodal (HF Transformers) | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct` (with revision `refs/pr/70`), etc. | ✅︎ | ✅︎ | -| `PixtralForConditionalGeneration` | Mistral 3 (Mistral format), Pixtral (Mistral format) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistralai/Pixtral-12B-2409`, etc. | | ✅︎ | +| `PixtralForConditionalGeneration` | Ministral 3 (Mistral format), Mistral 3 (Mistral format), Mistral Large 3 (Mistral format), Pixtral (Mistral format) | T + I<sup>+</sup> | `mistralai/Ministral-3-3B-Instruct-2512`, `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistralai/Mistral-Large-3-675B-Instruct-2512` `mistralai/Pixtral-12B-2409` etc. | | ✅︎ | | `QwenVLForConditionalGeneration`<sup>^</sup> | Qwen-VL | T + I<sup>E+</sup> | `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. | ✅︎ | ✅︎ | | `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A<sup>+</sup> | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ | | `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | @@ -740,23 +735,6 @@ Some models are supported only via the [Transformers modeling backend](#transfor <sup>E</sup> Pre-computed embeddings can be inputted for this modality. <sup>+</sup> Multiple items can be inputted per text prompt for this modality. -!!! warning - Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs. - However, there are differences in how they handle text + image inputs: - - V0 correctly implements the model's attention pattern: - - Uses bidirectional attention between the image tokens corresponding to the same image - - Uses causal attention for other tokens - - Implemented via (naive) PyTorch SDPA with masking tensors - - Note: May use significant memory for long prompts with image - - V1 currently uses a simplified attention pattern: - - Uses causal attention for all tokens, including image tokens - - Generates reasonable outputs but does not match the original model's attention for text + image inputs, especially when `{"do_pan_and_scan": true}` - - Will be updated in the future to support the correct behavior - - This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends. - !!! note `Gemma3nForConditionalGeneration` is only supported on V1 due to shared KV caching and it depends on `timm>=1.0.17` to make use of its MobileNet-v5 vision backbone. @@ -767,7 +745,7 @@ Some models are supported only via the [Transformers modeling backend](#transfor - There's no PLE caching or out-of-memory swapping support, as described in [Google's blog](https://developers.googleblog.com/en/introducing-gemma-3n/). These features might be too model-specific for vLLM, and swapping in particular may be better suited for constrained setups. !!! note - For `InternVLChatModel`, only InternVL2.5 with Qwen2.5 text backbone (`OpenGVLab/InternVL2.5-1B` etc), InternVL3 and InternVL3.5 have video inputs support currently. + For `InternVLChatModel`, only InternVL2.5 with Qwen2.5 text backbone (`OpenGVLab/InternVL2.5-1B` etc.), InternVL3 and InternVL3.5 have video inputs support currently. !!! note To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM. @@ -776,9 +754,6 @@ Some models are supported only via the [Transformers modeling backend](#transfor The official `openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (`HwwwH/MiniCPM-V-2`) for now. For more details, please see: <https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630> -!!! warning - Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1. - !!! note For Qwen2.5-Omni and Qwen3-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`) is currently work in progress and not yet supported. diff --git a/docs/serving/data_parallel_deployment.md b/docs/serving/data_parallel_deployment.md index eff9c5d5e4efa..f0946eaf407a9 100644 --- a/docs/serving/data_parallel_deployment.md +++ b/docs/serving/data_parallel_deployment.md @@ -8,11 +8,11 @@ For MoE models, particularly those like DeepSeek that employ MLA (Multi-head Lat In these cases, the data parallel ranks are not completely independent. Forward passes must be aligned, and expert layers across all ranks are required to synchronize during every forward pass, even when there are fewer requests to be processed than DP ranks. -The expert layers will by default form a (DP x TP) sized tensor parallel group. To enable expert parallelism, include the `--enable-expert-parallel` CLI arg (on all nodes in the multi-node case). +By default, expert layers form a tensor parallel group of size `DP × TP`. To use expert parallelism instead, include the `--enable-expert-parallel` CLI arg (on all nodes in the multi-node case). See [Expert Parallel Deployment](expert_parallel_deployment.md) for details on how attention and expert layers behave differently with EP enabled. In vLLM, each DP rank is deployed as a separate "core engine" process that communicates with front-end process(es) via ZMQ sockets. Data Parallel attention can be combined with Tensor Parallel attention, in which case each DP engine owns a number of per-GPU worker processes equal to the configured TP size. -For MoE models, when any requests are in progress in any rank, we must ensure that empty "dummy" forward passes are performed in all ranks that don't currently have any requests scheduled. This is handled via a separate DP Coordinator process that communicates with all ranks, and a collective operation performed every N steps to determine when all ranks become idle and can be paused. When TP is used in conjunction with DP, expert layers form an EP or TP group of size (DP x TP). +For MoE models, when any requests are in progress in any rank, we must ensure that empty "dummy" forward passes are performed in all ranks that don't currently have any requests scheduled. This is handled via a separate DP Coordinator process that communicates with all ranks, and a collective operation performed every N steps to determine when all ranks become idle and can be paused. When TP is used in conjunction with DP, expert layers form a group of size `DP × TP` (using either tensor parallelism by default, or expert parallelism if `--enable-expert-parallel` is set). In all cases, it is beneficial to load-balance requests between DP ranks. For online deployments, this balancing can be optimized by taking into account the state of each DP engine - in particular its currently scheduled and waiting (queued) requests, and KV cache state. Each DP engine has an independent KV cache, and the benefit of prefix caching can be maximized by directing prompts intelligently. @@ -24,7 +24,7 @@ There are two distinct modes supported for online deployments - self-contained w vLLM supports "self-contained" data parallel deployments that expose a single API endpoint. -It can be configured by simply including e.g. `--data-parallel-size=4` in the vllm serve command line arguments. This will require 4 GPUs. It can be combined with tensor parallel, for example `--data-parallel-size=4 --tensor-parallel-size=2`, which would require 8 GPUs. +It can be configured by simply including e.g. `--data-parallel-size=4` in the vllm serve command line arguments. This will require 4 GPUs. It can be combined with tensor parallel, for example `--data-parallel-size=4 --tensor-parallel-size=2`, which would require 8 GPUs. When sizing DP deployments, remember that `--max-num-seqs` applies per DP rank. Running a single data parallel deployment across multiple nodes requires a different `vllm serve` to be run on each node, specifying which DP ranks should run on that node. In this case, there will still be a single HTTP entrypoint - the API server(s) will run only on one node, but it doesn't necessarily need to be co-located with the DP ranks. @@ -80,6 +80,18 @@ When deploying large DP sizes using this method, the API server process can beco ![DP Internal LB Diagram](../assets/deployment/dp_internal_lb.png) </figure> +## Hybrid Load Balancing + +Hybrid load balancing sits between the internal and external approaches. Each node runs its own API server(s) that only queue requests to the data-parallel engines colocated on that node. An upstream load balancer (for example, an ingress controller or traffic router) spreads user requests across those per-node endpoints. + +Enable this mode with `--data-parallel-hybrid-lb` while still launching every node with the global data-parallel size. The key differences from internal load balancing are: + +- You must provide `--data-parallel-size-local` and `--data-parallel-start-rank` so each node knows which ranks it owns. +- Not compatible with `--headless` since every node exposes an API endpoint. +- Scale `--api-server-count` per node based on the number of local ranks + +In this configuration, each node keeps scheduling decisions local, which reduces cross-node traffic and avoids single node bottlenecks at larger DP sizes. + ## External Load Balancing For larger scale deployments especially, it can make sense to handle the orchestration and load balancing of data parallel ranks externally. diff --git a/docs/serving/expert_parallel_deployment.md b/docs/serving/expert_parallel_deployment.md index ec07896592ba3..82fde27d71fd4 100644 --- a/docs/serving/expert_parallel_deployment.md +++ b/docs/serving/expert_parallel_deployment.md @@ -40,10 +40,32 @@ EP_SIZE = TP_SIZE × DP_SIZE Where: -- `TP_SIZE`: Tensor parallel size (always 1 for now) +- `TP_SIZE`: Tensor parallel size - `DP_SIZE`: Data parallel size - `EP_SIZE`: Expert parallel size (computed automatically) +### Layer Behavior with EP Enabled + +When EP is enabled, different layers in MoE models behave differently: + +| Layer Type | Behavior | Parallelism Used | +|------------|----------|------------------| +| **Expert (MoE) Layers** | Sharded across all EP ranks | Expert Parallel (EP) of size `TP × DP` | +| **Attention Layers** | Behavior depends on TP size | See below | + +**Attention layer parallelism:** + +- **When `TP = 1`**: Attention weights are **replicated** across all DP ranks (data parallelism) +- **When `TP > 1`**: Attention weights are **sharded** using tensor parallelism across TP ranks within each DP group + +For example, with `TP=2, DP=4` (8 GPUs total): + +- Expert layers form an EP group of size 8, with experts distributed across all GPUs +- Attention layers use TP=2 within each of the 4 DP groups + +!!! note "Key Difference from Data Parallel Deployment" + Without `--enable-expert-parallel`, MoE layers would use tensor parallelism (forming a TP group of size `TP × DP`), similar to dense models. With EP enabled, expert layers switch to expert parallelism, which can provide better efficiency and locality for MoE models. + ### Example Command The following command serves a `DeepSeek-V3-0324` model with 1-way tensor parallel, 8-way (attention) data parallel, and 8-way expert parallel. The attention weights are replicated across all GPUs, while the expert weights are split across GPUs. It will work on a H200 (or H20) node with 8 GPUs. For H100, you can try to serve a smaller model or refer to the multi-node deployment section. @@ -81,7 +103,7 @@ vllm serve deepseek-ai/DeepSeek-V3-0324 \ --data-parallel-size-local 8 \ # Local DP size on this node (8 GPUs per node) --data-parallel-address 192.168.1.100 \ # Replace with actual IP of Node 1 --data-parallel-rpc-port 13345 \ # RPC communication port, can be any port as long as reachable by all nodes - --api-server-count=8 # Number of API servers for load handling (scaling this out to total ranks are recommended) + --api-server-count=8 # Number of API servers for load handling (scaling this out to # local ranks is recommended) # Node 2 (Secondary - headless mode, no API server) vllm serve deepseek-ai/DeepSeek-V3-0324 \ @@ -119,9 +141,6 @@ While MoE models are typically trained so that each expert receives a similar nu Enable EPLB with the `--enable-eplb` flag. -!!! note "Model Support" - Currently only DeepSeek V3 architecture is supported. - When enabled, vLLM collects load statistics with every forward pass and periodically rebalances expert distribution. ### EPLB Parameters @@ -134,6 +153,8 @@ Configure EPLB with the `--eplb-config` argument, which accepts a JSON string. T | `step_interval`| Frequency of rebalancing (every N engine steps) | 3000 | | `log_balancedness` | Log balancedness metrics (avg tokens per expert ÷ max tokens per expert) | `false` | | `num_redundant_experts` | Additional global experts per EP rank beyond equal distribution | `0` | +| `use_async` | Use non-blocking EPLB for reduced latency overhead | `false` | +| `policy` | The policy type for expert parallel load balancing | `"default"` | For example: @@ -183,6 +204,26 @@ vllm serve deepseek-ai/DeepSeek-V3-0324 \ For multi-node deployment, add these EPLB flags to each node's command. We recommend setting `--eplb-config '{"num_redundant_experts":32}'` to 32 in large scale use cases so the most popular experts are always available. +## Advanced Configuration + +### Performance Optimization + +- **DeepEP kernels**: The `high_throughput` and `low_latency` kernels are optimized for disaggregated serving and may show poor performance for mixed workloads +- **Dual Batch Overlap**: Use `--enable-dbo` to overlap all-to-all communication with compute. See [Dual Batch Overlap](../design/dbo.md) for more details. +- **Async scheduling (experimental)**: Try `--async-scheduling` to overlap scheduling with model execution. + +### Troubleshooting + +- **`non-zero status: 7 cannot register cq buf`**: When using Infiniband/RoCE, make sure host VM and pods show `ulimit -l` "unlimited". +- **`init failed for transport: IBGDA`**: The InfiniBand GDA kernel modules are missing. Run `tools/ep_kernels/configure_system_drivers.sh` on each GPU node and reboot. Also fixes error `NVSHMEM API called before NVSHMEM initialization has completed`. +- **NVSHMEM peer disconnect**: Usually a networking misconfiguration. If deploying via Kubernetes, verify that every pod runs with `hostNetwork: true`, `securityContext.privileged: true` to access Infiniband. + +### Benchmarking + +- Use simulator flags `VLLM_MOE_ROUTING_SIMULATION_STRATEGY=uniform_random` and `VLLM_RANDOMIZE_DP_DUMMY_INPUTS=1` so token routing is balanced across EP ranks. + +- Increasing `VLLM_MOE_DP_CHUNK_SIZE` may increase throughput by increasing the maximum batch size for inter-rank token transfers. This may cause DeepEP to throw `assert self.nvshmem_qp_depth >= (num_max_dispatch_tokens_per_rank + 1) * 2`, which can be fixed by increasing environment variable `NVSHMEM_QP_DEPTH`. + ## Disaggregated Serving (Prefill/Decode Split) For production deployments requiring strict SLA guarantees for time-to-first-token and inter-token latency, disaggregated serving allows independent scaling of prefill and decode operations. @@ -273,3 +314,9 @@ except Exception as e: print(f"❌ Error during disaggregated serving: {e}") print("Check that both prefill and decode instances are running and accessible") ``` + +### Benchmarking + +- To simulate the decode deployment of disaggregated serving, pass `--kv-transfer-config '{"kv_connector":"DecodeBenchConnector","kv_role":"kv_both"}'` to the `vllm serve` invocation. The connector populates KV cache with random values so decode can be profiled in isolation. + +- **CUDAGraph capture**: Use `--compilation_config '{"cudagraph_mode": "FULL_DECODE_ONLY"}'` to enable CUDA graph capture for decode only and save KV cache. diff --git a/docs/serving/openai_compatible_server.md b/docs/serving/openai_compatible_server.md index ac98efb7b88a6..0e29204f8947c 100644 --- a/docs/serving/openai_compatible_server.md +++ b/docs/serving/openai_compatible_server.md @@ -234,7 +234,7 @@ The following extra parameters are supported: Our Embeddings API is compatible with [OpenAI's Embeddings API](https://platform.openai.com/docs/api-reference/embeddings); you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it. -Code example: [examples/online_serving/pooling/openai_embedding_client.py](../../examples/online_serving/pooling/openai_embedding_client.py) +Code example: [examples/pooling/embed/openai_embedding_client.py](../../examples/pooling/embed/openai_embedding_client.py) If the model has a [chat template](../serving/openai_compatible_server.md#chat-template), you can replace `inputs` with a list of `messages` (same schema as [Chat API](#chat-api)) which will be treated as a single prompt to the model. Here is a convenience function for calling the API while retaining OpenAI's type annotations: @@ -335,7 +335,7 @@ and passing a list of `messages` in the request. Refer to the examples below for `MrLight/dse-qwen2-2b-mrl-v1` requires a placeholder image of the minimum image size for text query embeddings. See the full code example below for details. -Full example: [examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py](../../examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py) +Full example: [examples/pooling/embed/openai_chat_embedding_client_for_multimodal.py](../../examples/pooling/embed/openai_chat_embedding_client_for_multimodal.py) #### Extra parameters @@ -456,6 +456,7 @@ For `verbose_json` response format: ] } ``` +Currently “verbose_json” response format doesn’t support avg_logprob, compression_ratio, no_speech_prob. #### Extra Parameters @@ -515,7 +516,7 @@ Our Pooling API encodes input prompts using a [pooling model](../models/pooling_ The input format is the same as [Embeddings API](#embeddings-api), but the output data can contain an arbitrary nested list, not just a 1-D list of floats. -Code example: [examples/online_serving/pooling/openai_pooling_client.py](../../examples/online_serving/pooling/openai_pooling_client.py) +Code example: [examples/pooling/pooling/openai_pooling_client.py](../../examples/pooling/pooling/openai_pooling_client.py) ### Classification API @@ -523,7 +524,7 @@ Our Classification API directly supports Hugging Face sequence-classification mo We automatically wrap any other transformer via `as_seq_cls_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities. -Code example: [examples/online_serving/pooling/openai_classification_client.py](../../examples/online_serving/pooling/openai_classification_client.py) +Code example: [examples/pooling/classify/openai_classification_client.py](../../examples/pooling/classify/openai_classification_client.py) #### Example Requests @@ -639,7 +640,7 @@ Usually, the score for a sentence pair refers to the similarity between two sent You can find the documentation for cross encoder models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html). -Code example: [examples/online_serving/pooling/openai_cross_encoder_score.py](../../examples/online_serving/pooling/openai_cross_encoder_score.py) +Code example: [examples/pooling/score/openai_cross_encoder_score.py](../../examples/pooling/score/openai_cross_encoder_score.py) #### Single inference @@ -820,7 +821,7 @@ You can pass multi-modal inputs to scoring models by passing `content` including print("Scoring output:", response_json["data"][0]["score"]) print("Scoring output:", response_json["data"][1]["score"]) ``` -Full example: [examples/online_serving/pooling/openai_cross_encoder_score_for_multimodal.py](../../examples/online_serving/pooling/openai_cross_encoder_score_for_multimodal.py) +Full example: [examples/pooling/score/openai_cross_encoder_score_for_multimodal.py](../../examples/pooling/score/openai_cross_encoder_score_for_multimodal.py) #### Extra parameters @@ -850,7 +851,7 @@ endpoints are compatible with both [Jina AI's re-rank API interface](https://jin [Cohere's re-rank API interface](https://docs.cohere.com/v2/reference/rerank) to ensure compatibility with popular open-source tools. -Code example: [examples/online_serving/pooling/jinaai_rerank_client.py](../../examples/online_serving/pooling/jinaai_rerank_client.py) +Code example: [examples/pooling/score/openai_reranker.py](../../examples/pooling/score/openai_reranker.py) #### Example Request diff --git a/docs/serving/parallelism_scaling.md b/docs/serving/parallelism_scaling.md index a32840ea73b9a..ed93432701f35 100644 --- a/docs/serving/parallelism_scaling.md +++ b/docs/serving/parallelism_scaling.md @@ -62,7 +62,7 @@ If a single node lacks sufficient GPUs to hold the model, deploy vLLM across mul ### What is Ray? -Ray is a distributed computing framework for scaling Python programs. Multi-node vLLM deployments require Ray as the runtime engine. +Ray is a distributed computing framework for scaling Python programs. Multi-node vLLM deployments can use Ray as the runtime engine. vLLM uses Ray to manage the distributed execution of tasks across multiple nodes and control where execution happens. @@ -130,9 +130,31 @@ vllm serve /path/to/the/model/in/the/container \ --distributed-executor-backend ray ``` +### Running vLLM with MultiProcessing + +Besides Ray, Multi-node vLLM deployments can also use `multiprocessing` as the runtime engine. Here's an example to deploy model across 2 nodes (8 GPUs per node) with `tp_size=8` and `pp_size=2`. + +Choose one node as the head node and run: + +```bash +vllm serve /path/to/the/model/in/the/container \ + --tensor-parallel-size 8 --pipeline-parallel-size 2 \ + --nnodes 2 --node-rank 0 \ + --master-addr <HEAD_NODE_IP> +``` + +On the other worker node, run: + +```bash +vllm serve /path/to/the/model/in/the/container \ + --tensor-parallel-size 8 --pipeline-parallel-size 2 \ + --nnodes 2 --node-rank 1 \ + --master-addr <HEAD_NODE_IP> --headless +``` + ## Optimizing network communication for tensor parallelism -Efficient tensor parallelism requires fast inter-node communication, preferably through high-speed network adapters such as InfiniBand. +Efficient tensor parallelism requires fast internode communication, preferably through high-speed network adapters such as InfiniBand. To set up the cluster to use InfiniBand, append additional arguments like `--privileged -e NCCL_IB_HCA=mlx5` to the [examples/online_serving/run_cluster.sh](../../examples/online_serving/run_cluster.sh) helper script. Contact your system administrator for more information about the required flags. diff --git a/docs/usage/metrics.md b/docs/usage/metrics.md index d756e32476f0a..829533b84328f 100644 --- a/docs/usage/metrics.md +++ b/docs/usage/metrics.md @@ -33,11 +33,19 @@ Then query the endpoint to get the latest metrics from the server: The following metrics are exposed: -??? code +## General Metrics - ```python - --8<-- "vllm/engine/metrics.py:metrics-definitions" - ``` +--8<-- "docs/generated/metrics/general.md" + +## Speculative Decoding Metrics + +--8<-- "docs/generated/metrics/spec_decode.md" + +## NIXL KV Connector Metrics + +--8<-- "docs/generated/metrics/nixl_connector.md" + +## Deprecation Policy Note: when metrics are deprecated in version `X.Y`, they are hidden in version `X.Y+1` but can be re-enabled using the `--show-hidden-metrics-for-version=X.Y` escape hatch, diff --git a/docs/usage/security.md b/docs/usage/security.md index 9d10b66a5a97f..e619eec660aee 100644 --- a/docs/usage/security.md +++ b/docs/usage/security.md @@ -10,7 +10,7 @@ All communications between nodes in a multi-node vLLM deployment are **insecure ### Configuration Options for Inter-Node Communications -The following options control inter-node communications in vLLM: +The following options control internode communications in vLLM: #### 1. **Environment Variables:** @@ -28,7 +28,7 @@ The following options control inter-node communications in vLLM: ### Notes on PyTorch Distributed -vLLM uses PyTorch's distributed features for some inter-node communication. For +vLLM uses PyTorch's distributed features for some internode communication. For detailed information about PyTorch Distributed security considerations, please refer to the [PyTorch Security Guide](https://github.com/pytorch/pytorch/security/policy#using-distributed-features). @@ -108,6 +108,116 @@ networks. Consult your operating system or application platform documentation for specific firewall configuration instructions. +## API Key Authentication Limitations + +### Overview + +The `--api-key` flag (or `VLLM_API_KEY` environment variable) provides authentication for vLLM's HTTP server, but **only for OpenAI-compatible API endpoints under the `/v1` path prefix**. Many other sensitive endpoints are exposed on the same HTTP server without any authentication enforcement. + +**Important:** Do not rely exclusively on `--api-key` for securing access to vLLM. Additional security measures are required for production deployments. + +### Protected Endpoints (Require API Key) + +When `--api-key` is configured, the following `/v1` endpoints require Bearer token authentication: + +- `/v1/models` - List available models +- `/v1/chat/completions` - Chat completions +- `/v1/completions` - Text completions +- `/v1/embeddings` - Generate embeddings +- `/v1/audio/transcriptions` - Audio transcription +- `/v1/audio/translations` - Audio translation +- `/v1/messages` - Anthropic-compatible messages API +- `/v1/responses` - Response management +- `/v1/score` - Scoring API +- `/v1/rerank` - Reranking API + +### Unprotected Endpoints (No API Key Required) + +The following endpoints **do not require authentication** even when `--api-key` is configured: + +**Inference endpoints:** + +- `/invocations` - SageMaker-compatible endpoint (routes to the same inference functions as `/v1` endpoints) +- `/inference/v1/generate` - Generate completions +- `/pooling` - Pooling API +- `/classify` - Classification API +- `/score` - Scoring API (non-`/v1` variant) +- `/rerank` - Reranking API (non-`/v1` variant) + +**Operational control endpoints (always enabled):** + +- `/pause` - Pause generation (causes denial of service) +- `/resume` - Resume generation +- `/scale_elastic_ep` - Trigger scaling operations + +**Utility endpoints:** + +- `/tokenize` - Tokenize text +- `/detokenize` - Detokenize tokens +- `/health` - Health check +- `/ping` - SageMaker health check +- `/version` - Version information +- `/load` - Server load metrics + +**Tokenizer information endpoint (only when `--enable-tokenizer-info-endpoint` is set):** + +This endpoint is **only available when the `--enable-tokenizer-info-endpoint` flag is set**. It may expose sensitive information such as chat templates and tokenizer configuration: + +- `/tokenizer_info` - Get comprehensive tokenizer information including chat templates and configuration + +**Development endpoints (only when `VLLM_SERVER_DEV_MODE=1`):** + +These endpoints are **only available when the environment variable `VLLM_SERVER_DEV_MODE` is set to `1`**. They are intended for development and debugging purposes and should never be enabled in production: + +- `/server_info` - Get detailed server configuration +- `/reset_prefix_cache` - Reset prefix cache (can disrupt service) +- `/reset_mm_cache` - Reset multimodal cache (can disrupt service) +- `/sleep` - Put engine to sleep (causes denial of service) +- `/wake_up` - Wake engine from sleep +- `/is_sleeping` - Check if engine is sleeping +- `/collective_rpc` - Execute arbitrary RPC methods on the engine (extremely dangerous) + +**Profiler endpoints (only when `VLLM_TORCH_PROFILER_DIR` or `VLLM_TORCH_CUDA_PROFILE` are set):** + +These endpoints are only available when profiling is enabled and should only be used for local development: + +- `/start_profile` - Start PyTorch profiler +- `/stop_profile` - Stop PyTorch profiler + +**Note:** The `/invocations` endpoint is particularly concerning as it provides unauthenticated access to the same inference capabilities as the protected `/v1` endpoints. + +### Security Implications + +An attacker who can reach the vLLM HTTP server can: + +1. **Bypass authentication** by using non-`/v1` endpoints like `/invocations`, `/inference/v1/generate`, `/pooling`, `/classify`, `/score`, or `/rerank` to run arbitrary inference without credentials +2. **Cause denial of service** by calling `/pause` or `/scale_elastic_ep` without a token +3. **Access operational controls** to manipulate server state (e.g., pausing generation) +4. **If `--enable-tokenizer-info-endpoint` is set:** Access sensitive tokenizer configuration including chat templates, which may reveal prompt engineering strategies or other implementation details +5. **If `VLLM_SERVER_DEV_MODE=1` is set:** Execute arbitrary RPC commands via `/collective_rpc`, reset caches, put the engine to sleep, and access detailed server configuration + +### Recommended Security Practices + +#### 1. Minimize Exposed Endpoints + +**CRITICAL:** Never set `VLLM_SERVER_DEV_MODE=1` in production environments. Development endpoints expose extremely dangerous functionality including: + +- Arbitrary RPC execution via `/collective_rpc` +- Cache manipulation that can disrupt service +- Detailed server configuration disclosure + +Similarly, never enable profiler endpoints (`VLLM_TORCH_PROFILER_DIR` or `VLLM_TORCH_CUDA_PROFILE`) in production. + +**Be cautious with `--enable-tokenizer-info-endpoint`:** Only enable the `/tokenizer_info` endpoint if you need to expose tokenizer configuration information. This endpoint reveals chat templates and tokenizer settings that may contain sensitive implementation details or prompt engineering strategies. + +#### 2. Deploy Behind a Reverse Proxy + +The most effective approach is to deploy vLLM behind a reverse proxy (such as nginx, Envoy, or a Kubernetes Gateway) that: + +- Explicitly allowlists only the endpoints you want to expose to end users +- Blocks all other endpoints, including the unauthenticated inference and operational control endpoints +- Implements additional authentication, rate limiting, and logging at the proxy layer + ## Reporting Security Vulnerabilities If you believe you have found a security vulnerability in vLLM, please report it following the project's security policy. For more information on how to report security issues and the project's security policy, please see the [vLLM Security Policy](https://github.com/vllm-project/vllm/blob/main/SECURITY.md). diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index df6e96ca375fc..a6d0c5d12dd41 100755 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -42,60 +42,31 @@ class ModelRequestData(NamedTuple): # Unless specified, these settings have been tested to work on a single L4. -# Voxtral -# Make sure to install mistral-common[audio]. -def run_voxtral(question: str, audio_count: int) -> ModelRequestData: - from mistral_common.audio import Audio - from mistral_common.protocol.instruct.chunk import ( - AudioChunk, - RawAudio, - TextChunk, - ) - from mistral_common.protocol.instruct.messages import ( - UserMessage, - ) - from mistral_common.protocol.instruct.request import ChatCompletionRequest - from mistral_common.tokens.tokenizers.mistral import MistralTokenizer - - model_name = "mistralai/Voxtral-Mini-3B-2507" - tokenizer = MistralTokenizer.from_hf_hub(model_name) - +# AudioFlamingo3 +def run_audioflamingo3(question: str, audio_count: int) -> ModelRequestData: + model_name = "nvidia/audio-flamingo-3-hf" engine_args = EngineArgs( model=model_name, - max_model_len=8192, + max_model_len=4096, max_num_seqs=2, limit_mm_per_prompt={"audio": audio_count}, - config_format="mistral", - load_format="mistral", - tokenizer_mode="mistral", enforce_eager=True, - enable_chunked_prefill=False, ) - text_chunk = TextChunk(text=question) - audios = [ - Audio.from_file(str(audio_assets[i].get_local_path()), strict=False) - for i in range(audio_count) - ] - audio_chunks = [ - AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios - ] + # AudioFlamingo3 uses <sound> token for audio + audio_placeholder = "<sound>" * audio_count - messages = [UserMessage(content=[*audio_chunks, text_chunk])] - - req = ChatCompletionRequest(messages=messages, model=model_name) - - tokens = tokenizer.encode_chat_completion(req) - prompt_ids, audios = tokens.tokens, tokens.audios - - audios_and_sr = [(au.audio_array, au.sampling_rate) for au in audios] - - multi_modal_data = {"audio": audios_and_sr} + prompt = ( + "<|im_start|>system\n" + "You are a helpful assistant.<|im_end|>\n" + "<|im_start|>user\n" + f"{audio_placeholder}{question}<|im_end|>\n" + "<|im_start|>assistant\n" + ) return ModelRequestData( engine_args=engine_args, - prompt_token_ids=prompt_ids, - multi_modal_data=multi_modal_data, + prompt=prompt, ) @@ -361,6 +332,63 @@ def run_ultravox(question: str, audio_count: int) -> ModelRequestData: ) +# Voxtral +# Make sure to install mistral-common[audio]. +def run_voxtral(question: str, audio_count: int) -> ModelRequestData: + from mistral_common.audio import Audio + from mistral_common.protocol.instruct.chunk import ( + AudioChunk, + RawAudio, + TextChunk, + ) + from mistral_common.protocol.instruct.messages import ( + UserMessage, + ) + from mistral_common.protocol.instruct.request import ChatCompletionRequest + from mistral_common.tokens.tokenizers.mistral import MistralTokenizer + + model_name = "mistralai/Voxtral-Mini-3B-2507" + tokenizer = MistralTokenizer.from_hf_hub(model_name) + + engine_args = EngineArgs( + model=model_name, + max_model_len=8192, + max_num_seqs=2, + limit_mm_per_prompt={"audio": audio_count}, + config_format="mistral", + load_format="mistral", + tokenizer_mode="mistral", + enforce_eager=True, + enable_chunked_prefill=False, + ) + + text_chunk = TextChunk(text=question) + audios = [ + Audio.from_file(str(audio_assets[i].get_local_path()), strict=False) + for i in range(audio_count) + ] + audio_chunks = [ + AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios + ] + + messages = [UserMessage(content=[*audio_chunks, text_chunk])] + + req = ChatCompletionRequest(messages=messages, model=model_name) + + tokens = tokenizer.encode_chat_completion(req) + prompt_ids, audios = tokens.tokens, tokens.audios + + audios_and_sr = [(au.audio_array, au.sampling_rate) for au in audios] + + multi_modal_data = {"audio": audios_and_sr} + + return ModelRequestData( + engine_args=engine_args, + prompt_token_ids=prompt_ids, + multi_modal_data=multi_modal_data, + ) + + # Whisper def run_whisper(question: str, audio_count: int) -> ModelRequestData: assert audio_count == 1, "Whisper only support single audio input per prompt" @@ -382,7 +410,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData: model_example_map = { - "voxtral": run_voxtral, + "audioflamingo3": run_audioflamingo3, "gemma3n": run_gemma3n, "granite_speech": run_granite_speech, "midashenglm": run_midashenglm, @@ -392,6 +420,7 @@ model_example_map = { "qwen2_audio": run_qwen2_audio, "qwen2_5_omni": run_qwen2_5_omni, "ultravox": run_ultravox, + "voxtral": run_voxtral, "whisper": run_whisper, } @@ -422,7 +451,7 @@ def parse_args(): parser.add_argument( "--seed", type=int, - default=None, + default=0, help="Set the seed when initializing `vllm.LLM`.", ) parser.add_argument( diff --git a/examples/offline_inference/basic/embed.py b/examples/offline_inference/basic/embed.py index eeb7137ff7bae..17f727b33d321 100644 --- a/examples/offline_inference/basic/embed.py +++ b/examples/offline_inference/basic/embed.py @@ -4,6 +4,9 @@ from argparse import Namespace from vllm import LLM, EngineArgs +from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.config import AttentionConfig +from vllm.platforms import current_platform from vllm.utils.argparse_utils import FlexibleArgumentParser @@ -20,6 +23,11 @@ def parse_args(): def main(args: Namespace): + if current_platform.is_rocm(): + args.attention_config = AttentionConfig( + backend=AttentionBackendEnum.FLEX_ATTENTION + ) + # Sample prompts. prompts = [ "Hello, my name is", diff --git a/examples/offline_inference/basic/score.py b/examples/offline_inference/basic/score.py index cbca50eb5efa8..b2dadffd249f5 100644 --- a/examples/offline_inference/basic/score.py +++ b/examples/offline_inference/basic/score.py @@ -4,6 +4,9 @@ from argparse import Namespace from vllm import LLM, EngineArgs +from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.config import AttentionConfig +from vllm.platforms import current_platform from vllm.utils.argparse_utils import FlexibleArgumentParser @@ -20,6 +23,11 @@ def parse_args(): def main(args: Namespace): + if current_platform.is_rocm(): + args.attention_config = AttentionConfig( + backend=AttentionBackendEnum.FLEX_ATTENTION + ) + # Sample prompts. text_1 = "What is the capital of France?" texts_2 = [ diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index 0b281fc41a341..be0b846995a92 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -33,6 +33,7 @@ import os from time import sleep from vllm import LLM, SamplingParams +from vllm.platforms import current_platform from vllm.utils.network_utils import get_open_port @@ -222,6 +223,11 @@ if __name__ == "__main__": from multiprocessing import Process + if current_platform.is_rocm(): + from multiprocessing import set_start_method + + set_start_method("spawn", force=True) + procs = [] for local_dp_rank, global_dp_rank in enumerate( range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node) diff --git a/examples/offline_inference/disaggregated-prefill-v1/decode_example.py b/examples/offline_inference/disaggregated-prefill-v1/decode_example.py index 8f3d1a5c00369..2d575840e6a71 100644 --- a/examples/offline_inference/disaggregated-prefill-v1/decode_example.py +++ b/examples/offline_inference/disaggregated-prefill-v1/decode_example.py @@ -30,7 +30,7 @@ def main(): max_num_batched_tokens=64, max_num_seqs=16, kv_transfer_config=KVTransferConfig( - kv_connector="SharedStorageConnector", + kv_connector="ExampleConnector", kv_role="kv_both", kv_connector_extra_config={"shared_storage_path": "local_storage"}, ), diff --git a/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py b/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py index 0bfe7ec0e6cf6..207c6daebc2f5 100644 --- a/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py +++ b/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py @@ -26,7 +26,7 @@ def main(): enforce_eager=True, gpu_memory_utilization=0.8, kv_transfer_config=KVTransferConfig( - kv_connector="SharedStorageConnector", + kv_connector="ExampleConnector", kv_role="kv_both", kv_connector_extra_config={"shared_storage_path": "local_storage"}, ), diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py index c1d6c6db53dfb..857767ac3c628 100644 --- a/examples/offline_inference/encoder_decoder_multimodal.py +++ b/examples/offline_inference/encoder_decoder_multimodal.py @@ -77,7 +77,7 @@ def parse_args(): parser.add_argument( "--seed", type=int, - default=None, + default=0, help="Set the seed when initializing `vllm.LLM`.", ) return parser.parse_args() diff --git a/examples/offline_inference/kv_load_failure_recovery/README.md b/examples/offline_inference/kv_load_failure_recovery/README.md index 230a16812b25e..1f29a6ff56dbc 100644 --- a/examples/offline_inference/kv_load_failure_recovery/README.md +++ b/examples/offline_inference/kv_load_failure_recovery/README.md @@ -10,7 +10,7 @@ It demonstrates vLLM's ability to recover from KV load failures in both synchron - `decode_example.py` – performs the decode stage. Accepts: - `--simulate-failure`: simulates KV load failure using a custom connector. - `--async-load`: enables asynchronous KV loading mode. -- `rogue_shared_storage_connector.py` – defines `RogueSharedStorageConnector`, a subclass of `SharedStorageConnector`, that simulates missing or corrupted external KV blocks by failing to load blocks for the first decode request. +- `load_recovery_example_connector.py` – defines `LoadRecoveryExampleConnector`, a subclass of `ExampleConnector`, that simulates missing or corrupted external KV blocks by failing to load blocks for the first decode request. - `run.sh` – orchestrates the test: runs the prefill stage, then three decode stages: 1. Normal decode (baseline). 2. Decode with simulated sync KV load failure. @@ -20,7 +20,7 @@ It demonstrates vLLM's ability to recover from KV load failures in both synchron ## How It Works -- The test dynamically loads `RogueSharedStorageConnector` via `KVTransferConfig.kv_connector_module_path`, enabling controlled simulation of load failures without modifying the original connector. +- The test dynamically loads `LoadRecoveryExampleConnector` via `KVTransferConfig.kv_connector_module_path`, enabling controlled simulation of load failures without modifying the original connector. - The decode stages that simulate failure are expected to trigger recovery logic in vLLM, resulting in the same output as the baseline decode. - If recovery fails, the script prints a unified diff of the output mismatch and exits with error. diff --git a/examples/offline_inference/kv_load_failure_recovery/decode_example.py b/examples/offline_inference/kv_load_failure_recovery/decode_example.py index 69523f56eace3..d0df54167aeac 100644 --- a/examples/offline_inference/kv_load_failure_recovery/decode_example.py +++ b/examples/offline_inference/kv_load_failure_recovery/decode_example.py @@ -35,13 +35,13 @@ def main(): if args.simulate_failure: ktc = KVTransferConfig( - kv_connector="RogueSharedStorageConnector", + kv_connector="LoadRecoveryExampleConnector", kv_role="kv_both", kv_connector_extra_config={ "shared_storage_path": "local_storage", "async_load": args.async_load, }, - kv_connector_module_path="rogue_shared_storage_connector", + kv_connector_module_path="load_recovery_example_connector", ) out_file = ( "async_decode_recovered_output.txt" @@ -50,7 +50,7 @@ def main(): ) else: ktc = KVTransferConfig( - kv_connector="SharedStorageConnector", + kv_connector="ExampleConnector", kv_role="kv_both", kv_connector_extra_config={ "shared_storage_path": "local_storage", diff --git a/examples/offline_inference/kv_load_failure_recovery/rogue_shared_storage_connector.py b/examples/offline_inference/kv_load_failure_recovery/load_recovery_example_connector.py similarity index 88% rename from examples/offline_inference/kv_load_failure_recovery/rogue_shared_storage_connector.py rename to examples/offline_inference/kv_load_failure_recovery/load_recovery_example_connector.py index 5b2acea4c9457..7aab07f8a2c33 100644 --- a/examples/offline_inference/kv_load_failure_recovery/rogue_shared_storage_connector.py +++ b/examples/offline_inference/kv_load_failure_recovery/load_recovery_example_connector.py @@ -10,9 +10,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorMetadata, KVConnectorRole, ) -from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( - SharedStorageConnector, - SharedStorageConnectorMetadata, +from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import ( + ExampleConnector, + ExampleConnectorMetadata, ) from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks @@ -26,15 +26,15 @@ logging.basicConfig(level=logging.INFO) @dataclass -class RogueSharedStorageConnectorMetadata(SharedStorageConnectorMetadata): +class LoadRecoveryExampleConnectorMetadata(ExampleConnectorMetadata): req_to_block_ids: dict[str, set[int]] = field(default_factory=dict) @classmethod - def from_base(cls, base: SharedStorageConnectorMetadata): + def from_base(cls, base: ExampleConnectorMetadata): return cls(requests=base.requests) -class RogueSharedStorageConnector(SharedStorageConnector): +class LoadRecoveryExampleConnector(ExampleConnector): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): super().__init__(vllm_config=vllm_config, role=role) self._async_load = vllm_config.kv_transfer_config.get_from_extra_config( @@ -45,7 +45,7 @@ class RogueSharedStorageConnector(SharedStorageConnector): self._req_to_block_ids: dict[str, list[int]] = dict() def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None: - assert isinstance(connector_metadata, RogueSharedStorageConnectorMetadata) + assert isinstance(connector_metadata, LoadRecoveryExampleConnectorMetadata) index, failed_request = next( ( (i, x) @@ -84,7 +84,7 @@ class RogueSharedStorageConnector(SharedStorageConnector): ) -> tuple[set[str] | None, set[str] | None]: if self._async_load: meta = self._get_connector_metadata() - assert isinstance(meta, RogueSharedStorageConnectorMetadata) + assert isinstance(meta, LoadRecoveryExampleConnectorMetadata) if meta.req_to_block_ids: return None, set(meta.req_to_block_ids) @@ -126,9 +126,9 @@ class RogueSharedStorageConnector(SharedStorageConnector): ) -> KVConnectorMetadata: if not self._async_load: base = super().build_connector_meta(scheduler_output) - meta = RogueSharedStorageConnectorMetadata.from_base(base) + meta = LoadRecoveryExampleConnectorMetadata.from_base(base) else: - meta = RogueSharedStorageConnectorMetadata() + meta = LoadRecoveryExampleConnectorMetadata() if self._requests_need_load: for req_id, request in self._requests_need_load.items(): meta.add_request( diff --git a/examples/offline_inference/kv_load_failure_recovery/prefill_example.py b/examples/offline_inference/kv_load_failure_recovery/prefill_example.py index 047b81c82df53..ee4a84fd95003 100644 --- a/examples/offline_inference/kv_load_failure_recovery/prefill_example.py +++ b/examples/offline_inference/kv_load_failure_recovery/prefill_example.py @@ -26,7 +26,7 @@ def main(): enforce_eager=True, gpu_memory_utilization=0.8, kv_transfer_config=KVTransferConfig( - kv_connector="SharedStorageConnector", + kv_connector="ExampleConnector", kv_role="kv_both", kv_connector_extra_config={"shared_storage_path": "local_storage"}, ), diff --git a/examples/offline_inference/llm_engine_reset_kv.py b/examples/offline_inference/llm_engine_reset_kv.py new file mode 100644 index 0000000000000..3fbe7fa7545e6 --- /dev/null +++ b/examples/offline_inference/llm_engine_reset_kv.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +This file demonstrates preempt requests when using the `LLMEngine` +for processing prompts with various sampling parameters. +""" + +import argparse + +from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams +from vllm.utils.argparse_utils import FlexibleArgumentParser + + +def create_test_prompts() -> list[tuple[str, SamplingParams]]: + """Create a list of test prompts with their sampling parameters.""" + return [ + ( + "A robot may not injure a human being " * 50, + SamplingParams( + temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=16 + ), + ), + ( + "A robot may not injure a human being " * 50, + SamplingParams( + temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=16 + ), + ), + ( + "To be or not to be,", + SamplingParams( + temperature=0.8, top_k=5, presence_penalty=0.2, max_tokens=128 + ), + ), + ( + "What is the meaning of life?", + SamplingParams( + n=2, temperature=0.8, top_p=0.95, frequency_penalty=0.1, max_tokens=128 + ), + ), + ] + + +def process_requests(engine: LLMEngine, test_prompts: list[tuple[str, SamplingParams]]): + """Continuously process a list of prompts and handle the outputs.""" + request_id = 0 + + print("-" * 50) + step_id = 0 + while test_prompts or engine.has_unfinished_requests(): + print("-" * 50) + import os + + print(f"Step {step_id} (pid={os.getpid()})") + + if test_prompts: + prompt, sampling_params = test_prompts.pop(0) + engine.add_request(str(request_id), prompt, sampling_params) + request_id += 1 + + if step_id == 10: + print(f"Resetting prefix cache at {step_id}") + engine.reset_prefix_cache(reset_running_requests=True) + + request_outputs: list[RequestOutput] = engine.step() + + for request_output in request_outputs: + if request_output.finished: + print("-" * 50) + print(request_output) + print("-" * 50) + step_id += 1 + + +def initialize_engine(args: argparse.Namespace) -> LLMEngine: + """Initialize the LLMEngine from the command line arguments.""" + engine_args = EngineArgs.from_cli_args(args) + return LLMEngine.from_engine_args(engine_args) + + +def parse_args(): + parser = FlexibleArgumentParser( + description="Demo on using the LLMEngine class directly" + ) + parser = EngineArgs.add_cli_args(parser) + return parser.parse_args() + + +def main(args: argparse.Namespace): + """Main function that sets up and runs the prompt processing.""" + engine = initialize_engine(args) + test_prompts = create_test_prompts() + process_requests(engine, test_prompts) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/offline_inference/lora_with_quantization_inference.py b/examples/offline_inference/lora_with_quantization_inference.py index dc5c6202fa57b..2f3564b597556 100644 --- a/examples/offline_inference/lora_with_quantization_inference.py +++ b/examples/offline_inference/lora_with_quantization_inference.py @@ -23,31 +23,23 @@ def create_test_prompts( # this is an example of using quantization without LoRA ( "My name is", - SamplingParams( - temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128 - ), + SamplingParams(temperature=0.0, logprobs=1, max_tokens=128), None, ), # the next three examples use quantization with LoRA ( "my name is", - SamplingParams( - temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128 - ), + SamplingParams(temperature=0.0, logprobs=1, max_tokens=128), LoRARequest("lora-test-1", 1, lora_path), ), ( "The capital of USA is", - SamplingParams( - temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128 - ), + SamplingParams(temperature=0.0, logprobs=1, max_tokens=128), LoRARequest("lora-test-2", 1, lora_path), ), ( "The capital of France is", - SamplingParams( - temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128 - ), + SamplingParams(temperature=0.0, logprobs=1, max_tokens=128), LoRARequest("lora-test-3", 1, lora_path), ), ] diff --git a/examples/offline_inference/multilora_inference.py b/examples/offline_inference/multilora_inference.py index 5e5da2c0144c9..92021f9fb226c 100644 --- a/examples/offline_inference/multilora_inference.py +++ b/examples/offline_inference/multilora_inference.py @@ -27,9 +27,7 @@ def create_test_prompts( return [ ( "A robot may not injure a human being", - SamplingParams( - temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128 - ), + SamplingParams(temperature=0.0, logprobs=1, max_tokens=128), None, ), ( @@ -41,22 +39,12 @@ def create_test_prompts( ), ( "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 - SamplingParams( - temperature=0.0, - logprobs=1, - prompt_logprobs=1, - max_tokens=128, - ), + SamplingParams(temperature=0.0, logprobs=1, max_tokens=128), LoRARequest("sql-lora", 1, lora_path), ), ( "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 - SamplingParams( - temperature=0.0, - logprobs=1, - prompt_logprobs=1, - max_tokens=128, - ), + SamplingParams(temperature=0.0, logprobs=1, max_tokens=128), LoRARequest("sql-lora2", 2, lora_path), ), ] diff --git a/examples/offline_inference/pooling/README.md b/examples/offline_inference/pooling/README.md deleted file mode 100644 index ad78be38716b6..0000000000000 --- a/examples/offline_inference/pooling/README.md +++ /dev/null @@ -1,57 +0,0 @@ -# Pooling models - -## Convert llm model to seq cls - -```bash -# for BAAI/bge-reranker-v2-gemma -# Caution: "Yes" and "yes" are two different tokens -python examples/offline_inference/pooling/convert_model_to_seq_cls.py --model_name BAAI/bge-reranker-v2-gemma --classifier_from_tokens '["Yes"]' --method no_post_processing --path ./bge-reranker-v2-gemma-seq-cls -# for mxbai-rerank-v2 -python examples/offline_inference/pooling/convert_model_to_seq_cls.py --model_name mixedbread-ai/mxbai-rerank-base-v2 --classifier_from_tokens '["0", "1"]' --method from_2_way_softmax --path ./mxbai-rerank-base-v2-seq-cls -# for Qwen3-Reranker -python examples/offline_inference/pooling/convert_model_to_seq_cls.py --model_name Qwen/Qwen3-Reranker-0.6B --classifier_from_tokens '["no", "yes"]' --method from_2_way_softmax --path ./Qwen3-Reranker-0.6B-seq-cls -``` - -## Embed jina_embeddings_v3 usage - -Only text matching task is supported for now. See <https://github.com/vllm-project/vllm/pull/16120> - -```bash -python examples/offline_inference/pooling/embed_jina_embeddings_v3.py -``` - -## Embed matryoshka dimensions usage - -```bash -python examples/offline_inference/pooling/embed_matryoshka_fy.py -``` - -## Multi vector retrieval usage - -```bash -python examples/offline_inference/pooling/multi_vector_retrieval.py -``` - -## Named Entity Recognition (NER) usage - -```bash -python examples/offline_inference/pooling/ner.py -``` - -## Prithvi Geospatial MAE usage - -```bash -python examples/offline_inference/pooling/prithvi_geospatial_mae.py -``` - -## IO Processor Plugins for Prithvi Geospatial MAE - -```bash -python examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py -``` - -## Qwen3 reranker usage - -```bash -python examples/offline_inference/pooling/qwen3_reranker.py -``` diff --git a/examples/offline_inference/qwen2_5_omni/only_thinker.py b/examples/offline_inference/qwen2_5_omni/only_thinker.py index ed005e6a69b80..cee83519fadcc 100644 --- a/examples/offline_inference/qwen2_5_omni/only_thinker.py +++ b/examples/offline_inference/qwen2_5_omni/only_thinker.py @@ -158,7 +158,7 @@ def parse_args(): parser.add_argument( "--seed", type=int, - default=None, + default=0, help="Set the seed when initializing `vllm.LLM`.", ) diff --git a/examples/offline_inference/qwen3_omni/only_thinker.py b/examples/offline_inference/qwen3_omni/only_thinker.py index 88a61ed694c2e..62131633da8aa 100644 --- a/examples/offline_inference/qwen3_omni/only_thinker.py +++ b/examples/offline_inference/qwen3_omni/only_thinker.py @@ -158,7 +158,7 @@ def parse_args(): parser.add_argument( "--seed", type=int, - default=None, + default=0, help="Set the seed when initializing `vllm.LLM`.", ) diff --git a/examples/offline_inference/simple_profiling.py b/examples/offline_inference/simple_profiling.py index 46858fffadc52..e8a75cd03befb 100644 --- a/examples/offline_inference/simple_profiling.py +++ b/examples/offline_inference/simple_profiling.py @@ -1,14 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os import time from vllm import LLM, SamplingParams -# enable torch profiler, can also be set on cmd line -os.environ["VLLM_TORCH_PROFILER_DIR"] = "./vllm_profile" - # Sample prompts. prompts = [ "Hello, my name is", @@ -22,7 +18,14 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) def main(): # Create an LLM. - llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1) + llm = LLM( + model="facebook/opt-125m", + tensor_parallel_size=1, + profiler_config={ + "profiler": "torch", + "torch_profiler_dir": "./vllm_profile", + }, + ) llm.start_profile() diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 8f72bf6f0b0d1..dd5b22ae9b0f6 100755 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -72,7 +72,7 @@ def run_aria(questions: list[str], modality: str) -> ModelRequestData: # Aya Vision def run_aya_vision(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" - model_name = "CohereForAI/aya-vision-8b" + model_name = "CohereLabs/aya-vision-8b" engine_args = EngineArgs( model=model_name, @@ -118,6 +118,32 @@ def run_bee(questions: list[str], modality: str) -> ModelRequestData: ) +def run_bagel(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + model_name = "ByteDance-Seed/BAGEL-7B-MoT" + + engine_args = EngineArgs( + model=model_name, + trust_remote_code=True, + max_model_len=8192, + max_num_seqs=2, + limit_mm_per_prompt={modality: 1}, + ) + + prompts = [ + ( + f"<|im_start|>user\n<|image_pad|>\n{question}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # BLIP-2 def run_blip2(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1801,7 +1827,10 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData: engine_args = EngineArgs( model=model_name, max_model_len=4096, - hf_overrides={"architectures": ["Tarsier2ForConditionalGeneration"]}, + hf_overrides={ + "architectures": ["Tarsier2ForConditionalGeneration"], + "model_type": "tarsier2", + }, limit_mm_per_prompt={modality: 1}, ) @@ -1829,6 +1858,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData: model_example_map = { "aria": run_aria, "aya_vision": run_aya_vision, + "bagel": run_bagel, "bee": run_bee, "blip-2": run_blip2, "chameleon": run_chameleon, @@ -2028,7 +2058,7 @@ def parse_args(): parser.add_argument( "--seed", type=int, - default=None, + default=0, help="Set the seed when initializing `vllm.LLM`.", ) diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 7ba4e64b567de..3c01806baa203 100755 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -76,7 +76,7 @@ def load_aria(question: str, image_urls: list[str]) -> ModelRequestData: def load_aya_vision(question: str, image_urls: list[str]) -> ModelRequestData: - model_name = "CohereForAI/aya-vision-8b" + model_name = "CohereLabs/aya-vision-8b" engine_args = EngineArgs( model=model_name, @@ -309,6 +309,28 @@ def load_h2ovl(question: str, image_urls: list[str]) -> ModelRequestData: ) +# HunyuanOCR +def load_hunyuan_vl(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "tencent/HunyuanOCR" + + engine_args = EngineArgs( + model=model_name, + max_model_len=8192, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholder = ( + "<|hy_place▁holder▁no▁100|><|hy_place▁holder▁no▁102|><|hy_place▁holder▁no▁101|>" # noqa: E501 + ) * len(image_urls) + prompt = f"<|hy_begin▁of▁sentence|>{placeholder}{question}<|hy_User|>" + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=[fetch_image(url) for url in image_urls], + ) + + def load_hyperclovax_seed_vision( question: str, image_urls: list[str] ) -> ModelRequestData: @@ -1222,7 +1244,10 @@ def load_tarsier2(question: str, image_urls: list[str]) -> ModelRequestData: trust_remote_code=True, max_model_len=32768, limit_mm_per_prompt={"image": len(image_urls)}, - hf_overrides={"architectures": ["Tarsier2ForConditionalGeneration"]}, + hf_overrides={ + "architectures": ["Tarsier2ForConditionalGeneration"], + "model_type": "tarsier2", + }, ) prompt = ( @@ -1319,6 +1344,7 @@ model_example_map = { "deepseek_ocr": load_deepseek_ocr, "gemma3": load_gemma3, "h2ovl_chat": load_h2ovl, + "hunyuan_vl": load_hunyuan_vl, "hyperclovax_seed_vision": load_hyperclovax_seed_vision, "idefics3": load_idefics3, "interns1": load_interns1, @@ -1356,7 +1382,7 @@ def run_generate( model, question: str, image_urls: list[str], - seed: int | None, + seed: int, tensor_parallel_size: int | None, ): req_data = model_example_map[model](question, image_urls) @@ -1390,7 +1416,7 @@ def run_chat( model: str, question: str, image_urls: list[str], - seed: int | None, + seed: int, tensor_parallel_size: int | None, ): req_data = model_example_map[model](question, image_urls) @@ -1468,7 +1494,7 @@ def parse_args(): parser.add_argument( "--seed", type=int, - default=None, + default=0, help="Set the seed when initializing `vllm.LLM`.", ) parser.add_argument( diff --git a/examples/online_serving/disaggregated_encoder/README.md b/examples/online_serving/disaggregated_encoder/README.md index 5813a3cecf73b..b2c3bb974dfab 100644 --- a/examples/online_serving/disaggregated_encoder/README.md +++ b/examples/online_serving/disaggregated_encoder/README.md @@ -50,12 +50,12 @@ The vllm instances and `disagg_encoder_proxy` supports local URIs with ```{"url" ## EC connector and KV transfer -The `ECSharedStorageConnector` is used to store the encoder cache on local disk and facilitate transfer. To enable the encoder disaggregation feature, add the following configuration: +The `ECExampleonnector` is used to store the encoder cache on local disk and facilitate transfer. To enable the encoder disaggregation feature, add the following configuration: ```bash # Add to encoder instance: --ec-transfer-config '{ - "ec_connector": "ECSharedStorageConnector", + "ec_connector": "ECExampleConnector", "ec_role": "ec_producer", "ec_connector_extra_config": { "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" @@ -64,7 +64,7 @@ The `ECSharedStorageConnector` is used to store the encoder cache on local disk # Add to prefill/prefill+decode instance: --ec-transfer-config '{ - "ec_connector": "ECSharedStorageConnector", + "ec_connector": "ECExampleConnector", "ec_role": "ec_consumer", "ec_connector_extra_config": { "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" diff --git a/examples/online_serving/disaggregated_encoder/disagg_1e1p1d_example.sh b/examples/online_serving/disaggregated_encoder/disagg_1e1p1d_example.sh index 57489df64f51e..95a418374ad28 100644 --- a/examples/online_serving/disaggregated_encoder/disagg_1e1p1d_example.sh +++ b/examples/online_serving/disaggregated_encoder/disagg_1e1p1d_example.sh @@ -102,7 +102,7 @@ CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \ --max-num-seqs 128 \ --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ --ec-transfer-config '{ - "ec_connector": "ECSharedStorageConnector", + "ec_connector": "ECExampleConnector", "ec_role": "ec_producer", "ec_connector_extra_config": { "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" @@ -126,7 +126,7 @@ vllm serve "$MODEL" \ --max-num-seqs 128 \ --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ --ec-transfer-config '{ - "ec_connector": "ECSharedStorageConnector", + "ec_connector": "ECExampleConnector", "ec_role": "ec_consumer", "ec_connector_extra_config": { "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" diff --git a/examples/online_serving/disaggregated_encoder/disagg_1e1pd_example.sh b/examples/online_serving/disaggregated_encoder/disagg_1e1pd_example.sh index 6073e0580b11d..c4a591d7438cb 100644 --- a/examples/online_serving/disaggregated_encoder/disagg_1e1pd_example.sh +++ b/examples/online_serving/disaggregated_encoder/disagg_1e1pd_example.sh @@ -96,7 +96,7 @@ CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \ --max-num-seqs 128 \ --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ --ec-transfer-config '{ - "ec_connector": "ECSharedStorageConnector", + "ec_connector": "ECExampleConnector", "ec_role": "ec_producer", "ec_connector_extra_config": { "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" @@ -117,7 +117,7 @@ CUDA_VISIBLE_DEVICES="$GPU_PD" vllm serve "$MODEL" \ --max-num-seqs 128 \ --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ --ec-transfer-config '{ - "ec_connector": "ECSharedStorageConnector", + "ec_connector": "ECExampleConnector", "ec_role": "ec_consumer", "ec_connector_extra_config": { "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" diff --git a/examples/online_serving/openai_responses_client_with_tools.py b/examples/online_serving/openai_responses_client_with_tools.py index 276010197b5ab..c85c8cf807b49 100644 --- a/examples/online_serving/openai_responses_client_with_tools.py +++ b/examples/online_serving/openai_responses_client_with_tools.py @@ -3,7 +3,7 @@ """ Set up this example by starting a vLLM OpenAI-compatible server with tool call options enabled. -Reasoning models can be used through the Responses API as seen here +Reasoning models can be used through the Responses API as seen here https://platform.openai.com/docs/api-reference/responses For example: vllm serve Qwen/Qwen3-1.7B --reasoning-parser qwen3 \ diff --git a/examples/online_serving/pooling/README.md b/examples/online_serving/pooling/README.md deleted file mode 100644 index b76ad21f04818..0000000000000 --- a/examples/online_serving/pooling/README.md +++ /dev/null @@ -1,97 +0,0 @@ -# Pooling models - -## Cohere rerank usage - -```bash -# vllm serve BAAI/bge-reranker-base -python examples/online_serving/pooling/cohere_rerank_client.py -``` - -## Embedding requests base64 encoding_format usage - -```bash -# vllm serve intfloat/e5-small -python examples/online_serving/pooling/embedding_requests_base64_client.py -``` - -## Embedding requests bytes encoding_format usage - -```bash -# vllm serve intfloat/e5-small -python examples/online_serving/pooling/embedding_requests_bytes_client.py -``` - -## Jinaai rerank usage - -```bash -# vllm serve BAAI/bge-reranker-base -python examples/online_serving/pooling/jinaai_rerank_client.py -``` - -## Multi vector retrieval usage - -```bash -# vllm serve BAAI/bge-m3 -python examples/online_serving/pooling/multi_vector_retrieval_client.py -``` - -## Named Entity Recognition (NER) usage - -```bash -# vllm serve boltuix/NeuroBERT-NER -python examples/online_serving/pooling/ner_client.py -``` - -## OpenAI chat embedding for multimodal usage - -```bash -python examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py -``` - -## OpenAI classification usage - -```bash -# vllm serve jason9693/Qwen2.5-1.5B-apeach -python examples/online_serving/pooling/openai_classification_client.py -``` - -## OpenAI cross_encoder score usage - -```bash -# vllm serve BAAI/bge-reranker-v2-m3 -python examples/online_serving/pooling/openai_cross_encoder_score.py -``` - -## OpenAI cross_encoder score for multimodal usage - -```bash -# vllm serve jinaai/jina-reranker-m0 -python examples/online_serving/pooling/openai_cross_encoder_score_for_multimodal.py -``` - -## OpenAI embedding usage - -```bash -# vllm serve intfloat/e5-small -python examples/online_serving/pooling/openai_embedding_client.py -``` - -## OpenAI embedding matryoshka dimensions usage - -```bash -# vllm serve jinaai/jina-embeddings-v3 --trust-remote-code -python examples/online_serving/pooling/openai_embedding_matryoshka_fy.py -``` - -## OpenAI pooling usage - -```bash -# vllm serve internlm/internlm2-1_8b-reward --trust-remote-code -python examples/online_serving/pooling/openai_pooling_client.py -``` - -## Online Prithvi Geospatial MAE usage - -```bash -python examples/online_serving/pooling/prithvi_geospatial_mae.py -``` diff --git a/examples/online_serving/prompt_embed_inference_with_openai_client.py b/examples/online_serving/prompt_embed_inference_with_openai_client.py index 0bbe4b8f5ee9b..889be6820e70a 100644 --- a/examples/online_serving/prompt_embed_inference_with_openai_client.py +++ b/examples/online_serving/prompt_embed_inference_with_openai_client.py @@ -28,13 +28,11 @@ Dependencies: - openai """ -import base64 -import io - -import torch import transformers from openai import OpenAI +from vllm.utils.serial_utils import tensor2base64 + def main(): client = OpenAI( @@ -58,11 +56,7 @@ def main(): prompt_embeds = embedding_layer(token_ids).squeeze(0) # Prompt embeddings - buffer = io.BytesIO() - torch.save(prompt_embeds, buffer) - buffer.seek(0) - binary_data = buffer.read() - encoded_embeds = base64.b64encode(binary_data).decode("utf-8") + encoded_embeds = tensor2base64(prompt_embeds) completion = client.completions.create( model=model_name, diff --git a/examples/online_serving/run_cluster.sh b/examples/online_serving/run_cluster.sh index 0756d4b0ae556..5996098eb25aa 100644 --- a/examples/online_serving/run_cluster.sh +++ b/examples/online_serving/run_cluster.sh @@ -21,7 +21,7 @@ # --worker \ # /abs/path/to/huggingface/cache \ # -e VLLM_HOST_IP=<worker_node_ip> -# +# # Each worker requires a unique VLLM_HOST_IP value. # Keep each terminal session open. Closing a session stops the associated Ray # node and thereby shuts down the entire cluster. @@ -59,6 +59,34 @@ if [ "${NODE_TYPE}" != "--head" ] && [ "${NODE_TYPE}" != "--worker" ]; then exit 1 fi +# Extract VLLM_HOST_IP from ADDITIONAL_ARGS (e.g. "-e VLLM_HOST_IP=..."). +VLLM_HOST_IP="" +for ((i = 0; i < ${#ADDITIONAL_ARGS[@]}; i++)); do + arg="${ADDITIONAL_ARGS[$i]}" + case "${arg}" in + -e) + next="${ADDITIONAL_ARGS[$((i + 1))]:-}" + if [[ "${next}" == VLLM_HOST_IP=* ]]; then + VLLM_HOST_IP="${next#VLLM_HOST_IP=}" + break + fi + ;; + -eVLLM_HOST_IP=* | VLLM_HOST_IP=*) + VLLM_HOST_IP="${arg#*=}" + break + ;; + esac +done + +# For the head node, HEAD_NODE_ADDRESS and VLLM_HOST_IP should be consistent. +if [[ "${NODE_TYPE}" == "--head" && -n "${VLLM_HOST_IP}" ]]; then + if [[ "${VLLM_HOST_IP}" != "${HEAD_NODE_ADDRESS}" ]]; then + echo "Warning: VLLM_HOST_IP (${VLLM_HOST_IP}) differs from head_node_ip (${HEAD_NODE_ADDRESS})." + echo "Using VLLM_HOST_IP as the head node address." + HEAD_NODE_ADDRESS="${VLLM_HOST_IP}" + fi +fi + # Generate a unique container name with random suffix. # Docker container names must be unique on each host. # The random suffix allows multiple Ray containers to run simultaneously on the same machine, @@ -74,36 +102,17 @@ cleanup() { trap cleanup EXIT # Build the Ray start command based on the node role. -# The head node manages the cluster and accepts connections on port 6379, +# The head node manages the cluster and accepts connections on port 6379, # while workers connect to the head's address. RAY_START_CMD="ray start --block" if [ "${NODE_TYPE}" == "--head" ]; then - RAY_START_CMD+=" --head --port=6379" + RAY_START_CMD+=" --head --node-ip-address=${HEAD_NODE_ADDRESS} --port=6379" else + RAY_START_CMD+=" --address=${HEAD_NODE_ADDRESS}:6379" -fi - -# Parse VLLM_HOST_IP from additional args if present. -# This is needed for multi-NIC configurations where Ray needs explicit IP bindings. -VLLM_HOST_IP="" -for arg in "${ADDITIONAL_ARGS[@]}"; do - if [[ $arg == "-e" ]]; then - continue + if [ -n "${VLLM_HOST_IP}" ]; then + RAY_START_CMD+=" --node-ip-address=${VLLM_HOST_IP}" fi - if [[ $arg == VLLM_HOST_IP=* ]]; then - VLLM_HOST_IP="${arg#VLLM_HOST_IP=}" - break - fi -done - -# Build Ray IP environment variables if VLLM_HOST_IP is set. -# These variables ensure Ray binds to the correct network interface on multi-NIC systems. -RAY_IP_VARS=() -if [ -n "${VLLM_HOST_IP}" ]; then - RAY_IP_VARS=( - -e "RAY_NODE_IP_ADDRESS=${VLLM_HOST_IP}" - -e "RAY_OVERRIDE_NODE_IP_ADDRESS=${VLLM_HOST_IP}" - ) fi # Launch the container with the assembled parameters. @@ -118,6 +127,5 @@ docker run \ --shm-size 10.24g \ --gpus all \ -v "${PATH_TO_HF_HOME}:/root/.cache/huggingface" \ - "${RAY_IP_VARS[@]}" \ "${ADDITIONAL_ARGS[@]}" \ "${DOCKER_IMAGE}" -c "${RAY_START_CMD}" diff --git a/examples/online_serving/structured_outputs/structured_outputs.py b/examples/online_serving/structured_outputs/structured_outputs.py index ff473d044e323..2599c951ef8ad 100644 --- a/examples/online_serving/structured_outputs/structured_outputs.py +++ b/examples/online_serving/structured_outputs/structured_outputs.py @@ -112,7 +112,7 @@ PARAMS: dict[ConstraintsFormat, dict[str, Any]] = { "messages": [ { "role": "user", - "content": "Generate an SQL query to show the 'username' and 'email'from the 'users' table.", + "content": "Generate an SQL query to show the 'username' and 'email' from the 'users' table.", } ], "extra_body": { diff --git a/examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py b/examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py index 5d8e38c73b89a..c8965e050ff0b 100644 --- a/examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py +++ b/examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py @@ -26,9 +26,21 @@ async def lifespan(app: FastAPI): ) app.state.prefill_client = httpx.AsyncClient( - timeout=None, base_url=prefiller_base_url + timeout=None, + base_url=prefiller_base_url, + limits=httpx.Limits( + max_connections=None, + max_keepalive_connections=None, + ), + ) + app.state.decode_client = httpx.AsyncClient( + timeout=None, + base_url=decoder_base_url, + limits=httpx.Limits( + max_connections=None, + max_keepalive_connections=None, + ), ) - app.state.decode_client = httpx.AsyncClient(timeout=None, base_url=decoder_base_url) yield @@ -105,6 +117,11 @@ async def send_request_to_service( headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} response = await client.post(endpoint, json=req_data, headers=headers) response.raise_for_status() + + # read/consume the response body to release the connection + # otherwise, it would http.ReadError + await response.aread() + return response diff --git a/examples/online_serving/pooling/openai_classification_client.py b/examples/pooling/classify/openai_classification_client.py similarity index 100% rename from examples/online_serving/pooling/openai_classification_client.py rename to examples/pooling/classify/openai_classification_client.py diff --git a/examples/offline_inference/pooling/embed_jina_embeddings_v3.py b/examples/pooling/embed/embed_jina_embeddings_v3.py similarity index 100% rename from examples/offline_inference/pooling/embed_jina_embeddings_v3.py rename to examples/pooling/embed/embed_jina_embeddings_v3.py diff --git a/examples/offline_inference/pooling/embed_matryoshka_fy.py b/examples/pooling/embed/embed_matryoshka_fy.py similarity index 100% rename from examples/offline_inference/pooling/embed_matryoshka_fy.py rename to examples/pooling/embed/embed_matryoshka_fy.py diff --git a/examples/online_serving/pooling/embedding_requests_base64_client.py b/examples/pooling/embed/embedding_requests_base64_client.py similarity index 100% rename from examples/online_serving/pooling/embedding_requests_base64_client.py rename to examples/pooling/embed/embedding_requests_base64_client.py diff --git a/examples/online_serving/pooling/embedding_requests_bytes_client.py b/examples/pooling/embed/embedding_requests_bytes_client.py similarity index 58% rename from examples/online_serving/pooling/embedding_requests_bytes_client.py rename to examples/pooling/embed/embedding_requests_bytes_client.py index c2832f1b54ce7..5ea4525241497 100644 --- a/examples/online_serving/pooling/embedding_requests_bytes_client.py +++ b/examples/pooling/embed/embedding_requests_bytes_client.py @@ -16,6 +16,7 @@ from vllm.utils.serial_utils import ( EMBED_DTYPE_TO_TORCH_DTYPE, ENDIANNESS, MetadataItem, + build_metadata_items, decode_pooling_output, ) @@ -38,6 +39,11 @@ def parse_args(): def main(args): api_url = f"http://{args.host}:{args.port}/v1/embeddings" model_name = args.model + embedding_size = 0 + + input_texts = [ + "The best thing about vLLM is that it supports many different models", + ] * 2 # The OpenAI client does not support the bytes encoding_format. # The OpenAI client does not support the embed_dtype and endianness parameters. @@ -45,7 +51,7 @@ def main(args): for endianness in ENDIANNESS: prompt = { "model": model_name, - "input": "vLLM is great!", + "input": input_texts, "encoding_format": "bytes", "embed_dtype": embed_dtype, "endianness": endianness, @@ -57,7 +63,34 @@ def main(args): embedding = decode_pooling_output(items=items, body=body) embedding = [x.to(torch.float32) for x in embedding] - embedding = torch.cat(embedding) + embedding = torch.stack(embedding) + embedding_size = embedding.shape[-1] + print(embed_dtype, endianness, embedding.shape) + + # The vllm server always sorts the returned embeddings in the order of input. So + # returning metadata is not necessary. You can set encoding_format to bytes_only + # to let the server not return metadata. + for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE: + for endianness in ENDIANNESS: + prompt = { + "model": model_name, + "input": input_texts, + "encoding_format": "bytes_only", + "embed_dtype": embed_dtype, + "endianness": endianness, + } + response = post_http_request(prompt=prompt, api_url=api_url) + body = response.content + + items = build_metadata_items( + embed_dtype=embed_dtype, + endianness=endianness, + shape=(embedding_size,), + n_request=len(input_texts), + ) + embedding = decode_pooling_output(items=items, body=body) + embedding = [x.to(torch.float32) for x in embedding] + embedding = torch.stack(embedding) print(embed_dtype, endianness, embedding.shape) diff --git a/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py b/examples/pooling/embed/openai_chat_embedding_client_for_multimodal.py similarity index 99% rename from examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py rename to examples/pooling/embed/openai_chat_embedding_client_for_multimodal.py index 47c2c5030078c..a7ab7e73e7d42 100644 --- a/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py +++ b/examples/pooling/embed/openai_chat_embedding_client_for_multimodal.py @@ -150,7 +150,8 @@ def run_siglip(client: OpenAI, model: str): Start the server using: vllm serve google/siglip-base-patch16-224 \ - --runner pooling + --runner pooling \ + --chat-template template_basic.jinja """ response = create_chat_embeddings( diff --git a/examples/online_serving/pooling/openai_embedding_client.py b/examples/pooling/embed/openai_embedding_client.py similarity index 100% rename from examples/online_serving/pooling/openai_embedding_client.py rename to examples/pooling/embed/openai_embedding_client.py diff --git a/examples/online_serving/openai_embedding_long_text/README.md b/examples/pooling/embed/openai_embedding_long_text/README.md similarity index 100% rename from examples/online_serving/openai_embedding_long_text/README.md rename to examples/pooling/embed/openai_embedding_long_text/README.md diff --git a/examples/online_serving/openai_embedding_long_text/client.py b/examples/pooling/embed/openai_embedding_long_text/client.py similarity index 100% rename from examples/online_serving/openai_embedding_long_text/client.py rename to examples/pooling/embed/openai_embedding_long_text/client.py diff --git a/examples/online_serving/openai_embedding_long_text/service.sh b/examples/pooling/embed/openai_embedding_long_text/service.sh similarity index 100% rename from examples/online_serving/openai_embedding_long_text/service.sh rename to examples/pooling/embed/openai_embedding_long_text/service.sh diff --git a/examples/online_serving/pooling/openai_embedding_matryoshka_fy.py b/examples/pooling/embed/openai_embedding_matryoshka_fy.py similarity index 100% rename from examples/online_serving/pooling/openai_embedding_matryoshka_fy.py rename to examples/pooling/embed/openai_embedding_matryoshka_fy.py diff --git a/examples/online_serving/pooling/prithvi_geospatial_mae.py b/examples/pooling/plugin/prithvi_geospatial_mae_client.py similarity index 97% rename from examples/online_serving/pooling/prithvi_geospatial_mae.py rename to examples/pooling/plugin/prithvi_geospatial_mae_client.py index a6246999c14d6..1ba1fd6a92ca4 100644 --- a/examples/online_serving/pooling/prithvi_geospatial_mae.py +++ b/examples/pooling/plugin/prithvi_geospatial_mae_client.py @@ -16,7 +16,7 @@ import requests # - start vllm in serving mode with the below args # --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM' # --model-impl terratorch -# --task embed --trust-remote-code +# --trust-remote-code # --skip-tokenizer-init --enforce-eager # --io-processor-plugin terratorch_segmentation # --enable-mm-embeds diff --git a/examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py b/examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py similarity index 100% rename from examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py rename to examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py diff --git a/examples/offline_inference/pooling/prithvi_geospatial_mae.py b/examples/pooling/plugin/prithvi_geospatial_mae_offline.py similarity index 100% rename from examples/offline_inference/pooling/prithvi_geospatial_mae.py rename to examples/pooling/plugin/prithvi_geospatial_mae_offline.py diff --git a/examples/online_serving/pooling/openai_pooling_client.py b/examples/pooling/pooling/openai_pooling_client.py similarity index 100% rename from examples/online_serving/pooling/openai_pooling_client.py rename to examples/pooling/pooling/openai_pooling_client.py diff --git a/examples/offline_inference/vision_language_pooling.py b/examples/pooling/pooling/vision_language_pooling.py similarity index 98% rename from examples/offline_inference/vision_language_pooling.py rename to examples/pooling/pooling/vision_language_pooling.py index 530aad4bc031c..dda56bc34df2e 100644 --- a/examples/offline_inference/vision_language_pooling.py +++ b/examples/pooling/pooling/vision_language_pooling.py @@ -305,7 +305,7 @@ def get_query(modality: QueryModality): raise ValueError(msg) -def run_encode(model: str, modality: QueryModality, seed: int | None): +def run_encode(model: str, modality: QueryModality, seed: int): query = get_query(modality) req_data = model_example_map[model](query) @@ -335,7 +335,7 @@ def run_encode(model: str, modality: QueryModality, seed: int | None): print("-" * 50) -def run_score(model: str, modality: QueryModality, seed: int | None): +def run_score(model: str, modality: QueryModality, seed: int): query = get_query(modality) req_data = model_example_map[model](query) @@ -390,7 +390,7 @@ def parse_args(): parser.add_argument( "--seed", type=int, - default=None, + default=0, help="Set the seed when initializing `vllm.LLM`.", ) return parser.parse_args() diff --git a/examples/online_serving/pooling/cohere_rerank_client.py b/examples/pooling/score/cohere_rerank_client.py similarity index 100% rename from examples/online_serving/pooling/cohere_rerank_client.py rename to examples/pooling/score/cohere_rerank_client.py diff --git a/examples/offline_inference/pooling/convert_model_to_seq_cls.py b/examples/pooling/score/convert_model_to_seq_cls.py similarity index 100% rename from examples/offline_inference/pooling/convert_model_to_seq_cls.py rename to examples/pooling/score/convert_model_to_seq_cls.py diff --git a/examples/offline_inference/pooling/qwen3_reranker.py b/examples/pooling/score/offline_reranker.py similarity index 100% rename from examples/offline_inference/pooling/qwen3_reranker.py rename to examples/pooling/score/offline_reranker.py diff --git a/examples/online_serving/pooling/openai_cross_encoder_score.py b/examples/pooling/score/openai_cross_encoder_score.py similarity index 100% rename from examples/online_serving/pooling/openai_cross_encoder_score.py rename to examples/pooling/score/openai_cross_encoder_score.py diff --git a/examples/online_serving/pooling/openai_cross_encoder_score_for_multimodal.py b/examples/pooling/score/openai_cross_encoder_score_for_multimodal.py similarity index 100% rename from examples/online_serving/pooling/openai_cross_encoder_score_for_multimodal.py rename to examples/pooling/score/openai_cross_encoder_score_for_multimodal.py diff --git a/examples/online_serving/pooling/jinaai_rerank_client.py b/examples/pooling/score/openai_reranker.py similarity index 100% rename from examples/online_serving/pooling/jinaai_rerank_client.py rename to examples/pooling/score/openai_reranker.py diff --git a/examples/offline_inference/pooling/ner.py b/examples/pooling/token_classify/ner.py similarity index 100% rename from examples/offline_inference/pooling/ner.py rename to examples/pooling/token_classify/ner.py diff --git a/examples/online_serving/pooling/ner_client.py b/examples/pooling/token_classify/ner_client.py similarity index 100% rename from examples/online_serving/pooling/ner_client.py rename to examples/pooling/token_classify/ner_client.py diff --git a/examples/pooling/token_embed/jina_embeddings_v4.py b/examples/pooling/token_embed/jina_embeddings_v4.py new file mode 100644 index 0000000000000..83d4c446d426c --- /dev/null +++ b/examples/pooling/token_embed/jina_embeddings_v4.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm import LLM +from vllm.inputs.data import TextPrompt +from vllm.multimodal.utils import fetch_image + +# Initialize model +model = LLM( + model="jinaai/jina-embeddings-v4-vllm-text-matching", + runner="pooling", + max_model_len=1024, + gpu_memory_utilization=0.8, +) + +# Create text prompts +text1 = "Ein wunderschöner Sonnenuntergang am Strand" +text1_prompt = TextPrompt(prompt=f"Query: {text1}") + +text2 = "浜辺に沈む美しい夕日" +text2_prompt = TextPrompt(prompt=f"Query: {text2}") + +# Create image prompt +image = fetch_image( + "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/eskimo.jpg" # noqa: E501 +) +image_prompt = TextPrompt( + prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n", # noqa: E501 + multi_modal_data={"image": image}, +) + +# Encode all prompts +prompts = [text1_prompt, text2_prompt, image_prompt] +outputs = model.encode(prompts, pooling_task="token_embed") + + +def get_embeddings(outputs): + VISION_START_TOKEN_ID, VISION_END_TOKEN_ID = 151652, 151653 + + embeddings = [] + for output in outputs: + if VISION_START_TOKEN_ID in output.prompt_token_ids: + # Gather only vision tokens + img_start_pos = torch.where( + torch.tensor(output.prompt_token_ids) == VISION_START_TOKEN_ID + )[0][0] + img_end_pos = torch.where( + torch.tensor(output.prompt_token_ids) == VISION_END_TOKEN_ID + )[0][0] + embeddings_tensor = output.outputs.data.detach().clone()[ + img_start_pos : img_end_pos + 1 + ] + else: + # Use all tokens for text-only prompts + embeddings_tensor = output.outputs.data.detach().clone() + + # Pool and normalize embeddings + pooled_output = ( + embeddings_tensor.sum(dim=0, dtype=torch.float32) + / embeddings_tensor.shape[0] + ) + embeddings.append(torch.nn.functional.normalize(pooled_output, dim=-1)) + return embeddings + + +embeddings = get_embeddings(outputs) + +for embedding in embeddings: + print(embedding.shape) diff --git a/examples/offline_inference/pooling/multi_vector_retrieval.py b/examples/pooling/token_embed/multi_vector_retrieval.py similarity index 100% rename from examples/offline_inference/pooling/multi_vector_retrieval.py rename to examples/pooling/token_embed/multi_vector_retrieval.py diff --git a/examples/online_serving/pooling/multi_vector_retrieval_client.py b/examples/pooling/token_embed/multi_vector_retrieval_client.py similarity index 100% rename from examples/online_serving/pooling/multi_vector_retrieval_client.py rename to examples/pooling/token_embed/multi_vector_retrieval_client.py diff --git a/mkdocs.yaml b/mkdocs.yaml index bf97093dafb11..8fb8f0568c6ef 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -51,6 +51,7 @@ hooks: - docs/mkdocs/hooks/remove_announcement.py - docs/mkdocs/hooks/generate_examples.py - docs/mkdocs/hooks/generate_argparse.py + - docs/mkdocs/hooks/generate_metrics.py - docs/mkdocs/hooks/url_schemes.py plugins: diff --git a/requirements/common.txt b/requirements/common.txt index 8b9e6b935bd20..31c8fb404f63a 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -46,7 +46,9 @@ scipy # Required for phi-4-multimodal-instruct ninja # Required for xgrammar, rocm, tpu, xpu pybase64 # fast base64 implementation cbor2 # Required for cross-language serialization of hashable objects +ijson # Required for mistral streaming tool parser setproctitle # Used to set process names for better debugging and monitoring openai-harmony >= 0.0.3 # Required for gpt-oss anthropic == 0.71.0 -model-hosting-container-standards >= 0.1.9, < 1.0.0 \ No newline at end of file +model-hosting-container-standards >= 0.1.9, < 1.0.0 +mcp \ No newline at end of file diff --git a/requirements/cpu-build.txt b/requirements/cpu-build.txt index e18e0825fc428..1ea401a04a12c 100644 --- a/requirements/cpu-build.txt +++ b/requirements/cpu-build.txt @@ -3,7 +3,6 @@ ninja packaging>=24.2 setuptools>=77.0.3,<81.0.0 setuptools-scm>=8 ---extra-index-url https://download.pytorch.org/whl/cpu torch==2.9.1+cpu; platform_machine == "x86_64" or platform_machine == "s390x" torch==2.9.1; platform_system == "Darwin" or platform_machine == "ppc64le" or platform_machine == "aarch64" scons; platform_machine == "aarch64" # needed to build Arm Compute Library (ACL) diff --git a/requirements/cpu.txt b/requirements/cpu.txt index 21571be479c83..7a670812e8943 100644 --- a/requirements/cpu.txt +++ b/requirements/cpu.txt @@ -4,7 +4,6 @@ numba == 0.61.2; platform_machine != "s390x" # Required for N-gram speculative decoding # Dependencies for CPUs ---extra-index-url https://download.pytorch.org/whl/cpu torch==2.9.1+cpu; platform_machine == "x86_64" or platform_machine == "s390x" torch==2.9.1; platform_system == "Darwin" or platform_machine == "ppc64le" or platform_machine == "aarch64" diff --git a/requirements/nightly_torch_test.txt b/requirements/nightly_torch_test.txt index 53b012372be8e..7b2c665448a3b 100644 --- a/requirements/nightly_torch_test.txt +++ b/requirements/nightly_torch_test.txt @@ -42,6 +42,6 @@ tritonclient==2.51.0 numba == 0.61.2 # Required for N-gram speculative decoding numpy -runai-model-streamer[s3,gcs]==0.15.0 +runai-model-streamer[s3,gcs]==0.15.3 fastsafetensors>=0.1.10 pydantic>=2.12 # 2.11 leads to error on python 3.13 diff --git a/requirements/rocm-test.txt b/requirements/rocm-test.txt index ae61d4c6c6a81..3f0fd235fba50 100644 --- a/requirements/rocm-test.txt +++ b/requirements/rocm-test.txt @@ -49,6 +49,7 @@ blobfile==3.0.0 # Multi-Modal Models Test decord==0.6.0 # video processing, required by entrypoints/openai/test_video.py +rapidfuzz==3.12.1 # OpenAI compatibility and testing gpt-oss==0.0.8 @@ -58,10 +59,14 @@ schemathesis==3.39.15 # Evaluation and benchmarking lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d +jiwer==4.0.0 # Required for multiprocessed tests that use spawn method, Datasets and Evaluate Test multiprocess==0.70.16 +# Required for v1/metrics/test_engine_logger_apis.py +ray[cgraph,default]>=2.48.0 + # Plugins test terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e torchgeo==0.7.0 @@ -70,8 +75,8 @@ torchgeo==0.7.0 mteb==2.1.2 # Data processing -xgrammar @ git+https://github.com/mlc-ai/xgrammar.git@eafd4db51b78acc64b3f0764ef27dfd206c28628 - # Test async scheduling +xgrammar @ git+https://github.com/divakar-amd/xgrammar@3272f7c520564858056a60480d5afdf69ae79c84 +# Test async scheduling # Utilities num2words==0.5.14 diff --git a/requirements/rocm.txt b/requirements/rocm.txt index abbd33d6e1240..05b9a21791c92 100644 --- a/requirements/rocm.txt +++ b/requirements/rocm.txt @@ -12,7 +12,7 @@ tensorizer==2.10.1 packaging>=24.2 setuptools>=77.0.3,<80.0.0 setuptools-scm>=8 -runai-model-streamer[s3,gcs]==0.15.0 +runai-model-streamer[s3,gcs]==0.15.3 conch-triton-kernels==1.2.1 timm>=1.0.17 fastsafetensors @ git+https://github.com/foundation-model-stack/fastsafetensors.git@d6f998a03432b2452f8de2bb5cefb5af9795d459 diff --git a/requirements/test.in b/requirements/test.in index da7a7db1f00c9..dfae5b75821f8 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -51,7 +51,7 @@ tritonclient==2.51.0 arctic-inference == 0.1.1 # Required for suffix decoding test numba == 0.61.2 # Required for N-gram speculative decoding numpy -runai-model-streamer[s3,gcs]==0.15.0 +runai-model-streamer[s3,gcs]==0.15.3 fastsafetensors>=0.1.10 pydantic>=2.12 # 2.11 leads to error on python 3.13 decord==0.6.0 diff --git a/requirements/test.txt b/requirements/test.txt index c5f103b8b0d78..571194e05c1ba 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -965,11 +965,11 @@ rsa==4.9.1 # via google-auth rtree==1.4.0 # via torchgeo -runai-model-streamer==0.15.0 +runai-model-streamer==0.15.3 # via -r requirements/test.in -runai-model-streamer-gcs==0.15.0 +runai-model-streamer-gcs==0.15.3 # via runai-model-streamer -runai-model-streamer-s3==0.15.0 +runai-model-streamer-s3==0.15.3 # via runai-model-streamer s3transfer==0.10.3 # via boto3 diff --git a/requirements/tpu.txt b/requirements/tpu.txt index e6fff58f7b794..7695b4ba2f4cb 100644 --- a/requirements/tpu.txt +++ b/requirements/tpu.txt @@ -11,5 +11,4 @@ ray[default] ray[data] setuptools==78.1.0 nixl==0.3.0 -tpu_info==0.4.0 -tpu-inference==0.11.1 +tpu-inference==0.12.0 diff --git a/setup.py b/setup.py index 0022e7fe0bf36..6fcb6653bc4a3 100644 --- a/setup.py +++ b/setup.py @@ -311,7 +311,7 @@ class precompiled_build_ext(build_ext): """Disables extension building when using precompiled binaries.""" def run(self) -> None: - assert _is_cuda(), "VLLM_USE_PRECOMPILED is only supported for CUDA builds" + return def build_extensions(self) -> None: print("Skipping build_ext: using precompiled extensions.") @@ -322,14 +322,127 @@ class precompiled_wheel_utils: """Extracts libraries and other files from an existing wheel.""" @staticmethod - def extract_precompiled_and_patch_package(wheel_url_or_path: str) -> dict: + def fetch_metadata_for_variant( + commit: str, variant: str | None + ) -> tuple[list[dict], str]: + """ + Fetches metadata for a specific variant of the precompiled wheel. + """ + variant_dir = f"{variant}/" if variant is not None else "" + repo_url = f"https://wheels.vllm.ai/{commit}/{variant_dir}vllm/" + meta_url = repo_url + "metadata.json" + print(f"Trying to fetch nightly build metadata from {meta_url}") + from urllib.request import urlopen + + with urlopen(meta_url) as resp: + # urlopen raises HTTPError on unexpected status code + wheels = json.loads(resp.read().decode("utf-8")) + return wheels, repo_url + + @staticmethod + def determine_wheel_url() -> tuple[str, str | None]: + """ + Try to determine the precompiled wheel URL or path to use. + The order of preference is: + 1. user-specified wheel location (can be either local or remote, via + VLLM_PRECOMPILED_WHEEL_LOCATION) + 2. user-specified variant (VLLM_PRECOMPILED_WHEEL_VARIANT) from nightly repo + 3. the variant corresponding to VLLM_MAIN_CUDA_VERSION from nightly repo + 4. the default variant from nightly repo + + If downloading from the nightly repo, the commit can be specified via + VLLM_PRECOMPILED_WHEEL_COMMIT; otherwise, the head commit in the main branch + is used. + """ + wheel_location = os.getenv("VLLM_PRECOMPILED_WHEEL_LOCATION", None) + if wheel_location is not None: + print(f"Using user-specified precompiled wheel location: {wheel_location}") + return wheel_location, None + else: + import platform + + arch = platform.machine() + # try to fetch the wheel metadata from the nightly wheel repo + main_variant = "cu" + envs.VLLM_MAIN_CUDA_VERSION.replace(".", "") + variant = os.getenv("VLLM_PRECOMPILED_WHEEL_VARIANT", main_variant) + commit = os.getenv("VLLM_PRECOMPILED_WHEEL_COMMIT", "").lower() + if not commit or len(commit) != 40: + print( + f"VLLM_PRECOMPILED_WHEEL_COMMIT not valid: {commit}" + ", trying to fetch base commit in main branch" + ) + commit = precompiled_wheel_utils.get_base_commit_in_main_branch() + print(f"Using precompiled wheel commit {commit} with variant {variant}") + try_default = False + wheels, repo_url, download_filename = None, None, None + try: + wheels, repo_url = precompiled_wheel_utils.fetch_metadata_for_variant( + commit, variant + ) + except Exception as e: + logger.warning( + "Failed to fetch precompiled wheel metadata for variant %s: %s", + variant, + e, + ) + try_default = True # try outside handler to keep the stacktrace simple + if try_default: + print("Trying the default variant from remote") + wheels, repo_url = precompiled_wheel_utils.fetch_metadata_for_variant( + commit, None + ) + # if this also fails, then we have nothing more to try / cache + assert wheels is not None and repo_url is not None, ( + "Failed to fetch precompiled wheel metadata" + ) + # The metadata.json has the following format: + # see .buildkite/scripts/generate-nightly-index.py for details + """[{ + "package_name": "vllm", + "version": "0.11.2.dev278+gdbc3d9991", + "build_tag": null, + "python_tag": "cp38", + "abi_tag": "abi3", + "platform_tag": "manylinux1_x86_64", + "variant": null, + "filename": "vllm-0.11.2.dev278+gdbc3d9991-cp38-abi3-manylinux1_x86_64.whl", + "path": "../vllm-0.11.2.dev278%2Bgdbc3d9991-cp38-abi3-manylinux1_x86_64.whl" + }, + ...]""" + from urllib.parse import urljoin + + for wheel in wheels: + # TODO: maybe check more compatibility later? (python_tag, abi_tag, etc) + if wheel.get("package_name") == "vllm" and arch in wheel.get( + "platform_tag", "" + ): + print(f"Found precompiled wheel metadata: {wheel}") + if "path" not in wheel: + raise ValueError(f"Wheel metadata missing path: {wheel}") + wheel_url = urljoin(repo_url, wheel["path"]) + download_filename = wheel.get("filename") + print(f"Using precompiled wheel URL: {wheel_url}") + break + else: + raise ValueError( + f"No precompiled vllm wheel found for architecture {arch} " + f"from repo {repo_url}. All available wheels: {wheels}" + ) + + return wheel_url, download_filename + + @staticmethod + def extract_precompiled_and_patch_package( + wheel_url_or_path: str, download_filename: str | None + ) -> dict: import tempfile import zipfile temp_dir = None try: if not os.path.isfile(wheel_url_or_path): - wheel_filename = wheel_url_or_path.split("/")[-1] + # use provided filename first, then derive from URL + wheel_filename = download_filename or wheel_url_or_path.split("/")[-1] temp_dir = tempfile.mkdtemp(prefix="vllm-wheels") wheel_path = os.path.join(temp_dir, wheel_filename) print(f"Downloading wheel from {wheel_url_or_path} to {wheel_path}") @@ -354,14 +467,22 @@ class precompiled_wheel_utils: "vllm/cumem_allocator.abi3.so", ] - compiled_regex = re.compile( + flash_attn_regex = re.compile( r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py" ) + triton_kernels_regex = re.compile( + r"vllm/third_party/triton_kernels/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py" + ) file_members = list( filter(lambda x: x.filename in files_to_copy, wheel.filelist) ) file_members += list( - filter(lambda x: compiled_regex.match(x.filename), wheel.filelist) + filter(lambda x: flash_attn_regex.match(x.filename), wheel.filelist) + ) + file_members += list( + filter( + lambda x: triton_kernels_regex.match(x.filename), wheel.filelist + ) ) for file in file_members: @@ -387,10 +508,6 @@ class precompiled_wheel_utils: @staticmethod def get_base_commit_in_main_branch() -> str: - # Force to use the nightly wheel. This is mainly used for CI testing. - if envs.VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL: - return "nightly" - try: # Get the latest commit hash of the upstream main branch. resp_json = subprocess.check_output( @@ -401,6 +518,7 @@ class precompiled_wheel_utils: ] ).decode("utf-8") upstream_main_commit = json.loads(resp_json)["sha"] + print(f"Upstream main branch latest commit: {upstream_main_commit}") # In Docker build context, .git may be immutable or missing. if envs.VLLM_DOCKER_BUILD_CONTEXT: @@ -541,7 +659,7 @@ def get_vllm_version() -> str: if envs.VLLM_TARGET_DEVICE == "empty": version += f"{sep}empty" elif _is_cuda(): - if envs.VLLM_USE_PRECOMPILED: + if envs.VLLM_USE_PRECOMPILED and not envs.VLLM_SKIP_PRECOMPILED_VERSION_SUFFIX: version += f"{sep}precompiled" else: cuda_version = str(get_nvcc_cuda_version()) @@ -648,38 +766,13 @@ package_data = { ] } + # If using precompiled, extract and patch package_data (in advance of setup) if envs.VLLM_USE_PRECOMPILED: - assert _is_cuda(), "VLLM_USE_PRECOMPILED is only supported for CUDA builds" - wheel_location = os.getenv("VLLM_PRECOMPILED_WHEEL_LOCATION", None) - if wheel_location is not None: - wheel_url = wheel_location - else: - import platform - - arch = platform.machine() - if arch == "x86_64": - wheel_tag = "manylinux1_x86_64" - elif arch == "aarch64": - wheel_tag = "manylinux2014_aarch64" - else: - raise ValueError(f"Unsupported architecture: {arch}") - base_commit = precompiled_wheel_utils.get_base_commit_in_main_branch() - wheel_url = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl" - nightly_wheel_url = ( - f"https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl" - ) - from urllib.request import urlopen - - try: - with urlopen(wheel_url) as resp: - if resp.status != 200: - wheel_url = nightly_wheel_url - except Exception as e: - print(f"[warn] Falling back to nightly wheel: {e}") - wheel_url = nightly_wheel_url - - patch = precompiled_wheel_utils.extract_precompiled_and_patch_package(wheel_url) + wheel_url, download_filename = precompiled_wheel_utils.determine_wheel_url() + patch = precompiled_wheel_utils.extract_precompiled_and_patch_package( + wheel_url, download_filename + ) for pkg, files in patch.items(): package_data.setdefault(pkg, []).extend(files) @@ -704,7 +797,7 @@ setup( "bench": ["pandas", "matplotlib", "seaborn", "datasets"], "tensorizer": ["tensorizer==2.10.1"], "fastsafetensors": ["fastsafetensors >= 0.1.10"], - "runai": ["runai-model-streamer[s3,gcs] >= 0.15.0"], + "runai": ["runai-model-streamer[s3,gcs] >= 0.15.3"], "audio": [ "librosa", "soundfile", diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 521d6c33dd390..9e1cc309edd1d 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -13,12 +13,15 @@ import pytest import torch from vllm import LLM +from vllm.platforms import current_platform from vllm.v1.engine.llm_engine import LLMEngine from ..conftest import HfRunner, VllmRunner from ..models.utils import check_outputs_equal from ..utils import multi_gpu_test +ATTN_BACKEND = ["ROCM_ATTN"] if current_platform.is_rocm() else ["FLASH_ATTN"] + MODELS = [ "hmellor/tiny-random-Gemma2ForCausalLM", "meta-llama/Llama-3.2-1B-Instruct", @@ -57,7 +60,7 @@ def _fix_prompt_embed_outputs( @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("backend", ["FLASH_ATTN"]) +@pytest.mark.parametrize("backend", ATTN_BACKEND) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("enforce_eager", [False]) @pytest.mark.parametrize("async_scheduling", [True, False]) diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index dc9c69bf58b95..3bd0b6609d88d 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -260,13 +260,18 @@ def test_deep_sleep_fp8_kvcache(): llm.sleep(level=2) used_bytes = current_platform.get_current_memory_usage() - used_bytes_baseline - assert used_bytes < 3 * GiB_bytes + + # Rocm uses more memory for CudaGraphs, so we add 2 GiB more for the threshold + rocm_extra_mem_bytes = 2 * GiB_bytes if current_platform.is_rocm() else 0 + mem_threshold_after_sleep = 3 * GiB_bytes + rocm_extra_mem_bytes + assert used_bytes < mem_threshold_after_sleep llm.wake_up(tags=["weights"]) llm.collective_rpc("reload_weights") used_bytes = current_platform.get_current_memory_usage() - used_bytes_baseline - assert used_bytes < 4 * GiB_bytes + mem_threshold_after_wake_up = 4 * GiB_bytes + rocm_extra_mem_bytes + assert used_bytes < mem_threshold_after_wake_up # now allocate kv cache and cuda graph memory llm.wake_up(tags=["kv_cache"]) diff --git a/tests/benchmarks/test_param_sweep.py b/tests/benchmarks/test_param_sweep.py new file mode 100644 index 0000000000000..467797d9915c9 --- /dev/null +++ b/tests/benchmarks/test_param_sweep.py @@ -0,0 +1,249 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +import tempfile +from pathlib import Path + +import pytest + +from vllm.benchmarks.sweep.param_sweep import ParameterSweep, ParameterSweepItem + + +class TestParameterSweepItem: + """Test ParameterSweepItem functionality.""" + + @pytest.mark.parametrize( + "input_dict,expected", + [ + ( + {"compilation_config.use_inductor_graph_partition": False}, + "--compilation-config.use_inductor_graph_partition=false", + ), + ( + {"compilation_config.use_inductor_graph_partition": True}, + "--compilation-config.use_inductor_graph_partition=true", + ), + ], + ) + def test_nested_boolean_params(self, input_dict, expected): + """Test that nested boolean params use =true/false syntax.""" + item = ParameterSweepItem.from_record(input_dict) + cmd = item.apply_to_cmd(["vllm", "serve", "model"]) + assert expected in cmd + + @pytest.mark.parametrize( + "input_dict,expected", + [ + ({"enable_prefix_caching": False}, "--no-enable-prefix-caching"), + ({"enable_prefix_caching": True}, "--enable-prefix-caching"), + ({"disable_log_stats": False}, "--no-disable-log-stats"), + ({"disable_log_stats": True}, "--disable-log-stats"), + ], + ) + def test_non_nested_boolean_params(self, input_dict, expected): + """Test that non-nested boolean params use --no- prefix.""" + item = ParameterSweepItem.from_record(input_dict) + cmd = item.apply_to_cmd(["vllm", "serve", "model"]) + assert expected in cmd + + @pytest.mark.parametrize( + "compilation_config", + [ + {"cudagraph_mode": "full", "mode": 2, "use_inductor_graph_partition": True}, + { + "cudagraph_mode": "piecewise", + "mode": 3, + "use_inductor_graph_partition": False, + }, + ], + ) + def test_nested_dict_value(self, compilation_config): + """Test that nested dict values are serialized as JSON.""" + item = ParameterSweepItem.from_record( + {"compilation_config": compilation_config} + ) + cmd = item.apply_to_cmd(["vllm", "serve", "model"]) + assert "--compilation-config" in cmd + # The dict should be JSON serialized + idx = cmd.index("--compilation-config") + assert json.loads(cmd[idx + 1]) == compilation_config + + @pytest.mark.parametrize( + "input_dict,expected_key,expected_value", + [ + ({"model": "test-model"}, "--model", "test-model"), + ({"max_tokens": 100}, "--max-tokens", "100"), + ({"temperature": 0.7}, "--temperature", "0.7"), + ], + ) + def test_string_and_numeric_values(self, input_dict, expected_key, expected_value): + """Test that string and numeric values are handled correctly.""" + item = ParameterSweepItem.from_record(input_dict) + cmd = item.apply_to_cmd(["vllm", "serve"]) + assert expected_key in cmd + assert expected_value in cmd + + @pytest.mark.parametrize( + "input_dict,expected_key,key_idx_offset", + [ + ({"max_tokens": 200}, "--max-tokens", 1), + ({"enable_prefix_caching": False}, "--no-enable-prefix-caching", 0), + ], + ) + def test_replace_existing_parameter(self, input_dict, expected_key, key_idx_offset): + """Test that existing parameters in cmd are replaced.""" + item = ParameterSweepItem.from_record(input_dict) + + if key_idx_offset == 1: + # Key-value pair + cmd = item.apply_to_cmd(["vllm", "serve", "--max-tokens", "100", "model"]) + assert expected_key in cmd + idx = cmd.index(expected_key) + assert cmd[idx + 1] == "200" + assert "100" not in cmd + else: + # Boolean flag + cmd = item.apply_to_cmd( + ["vllm", "serve", "--enable-prefix-caching", "model"] + ) + assert expected_key in cmd + assert "--enable-prefix-caching" not in cmd + + +class TestParameterSweep: + """Test ParameterSweep functionality.""" + + def test_from_records_list(self): + """Test creating ParameterSweep from a list of records.""" + records = [ + {"max_tokens": 100, "temperature": 0.7}, + {"max_tokens": 200, "temperature": 0.9}, + ] + sweep = ParameterSweep.from_records(records) + assert len(sweep) == 2 + assert sweep[0]["max_tokens"] == 100 + assert sweep[1]["max_tokens"] == 200 + + def test_read_from_dict(self): + """Test creating ParameterSweep from a dict format.""" + data = { + "experiment1": {"max_tokens": 100, "temperature": 0.7}, + "experiment2": {"max_tokens": 200, "temperature": 0.9}, + } + sweep = ParameterSweep.read_from_dict(data) + assert len(sweep) == 2 + + # Check that items have the _benchmark_name field + names = {item["_benchmark_name"] for item in sweep} + assert names == {"experiment1", "experiment2"} + + # Check that parameters are preserved + for item in sweep: + if item["_benchmark_name"] == "experiment1": + assert item["max_tokens"] == 100 + assert item["temperature"] == 0.7 + elif item["_benchmark_name"] == "experiment2": + assert item["max_tokens"] == 200 + assert item["temperature"] == 0.9 + + def test_read_json_list_format(self): + """Test reading JSON file with list format.""" + records = [ + {"max_tokens": 100, "temperature": 0.7}, + {"max_tokens": 200, "temperature": 0.9}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(records, f) + temp_path = Path(f.name) + + try: + sweep = ParameterSweep.read_json(temp_path) + assert len(sweep) == 2 + assert sweep[0]["max_tokens"] == 100 + assert sweep[1]["max_tokens"] == 200 + finally: + temp_path.unlink() + + def test_read_json_dict_format(self): + """Test reading JSON file with dict format.""" + data = { + "experiment1": {"max_tokens": 100, "temperature": 0.7}, + "experiment2": {"max_tokens": 200, "temperature": 0.9}, + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(data, f) + temp_path = Path(f.name) + + try: + sweep = ParameterSweep.read_json(temp_path) + assert len(sweep) == 2 + + # Check that items have the _benchmark_name field + names = {item["_benchmark_name"] for item in sweep} + assert names == {"experiment1", "experiment2"} + finally: + temp_path.unlink() + + def test_unique_benchmark_names_validation(self): + """Test that duplicate _benchmark_name values raise an error.""" + # Test with duplicate names in list format + records = [ + {"_benchmark_name": "exp1", "max_tokens": 100}, + {"_benchmark_name": "exp1", "max_tokens": 200}, + ] + + with pytest.raises(ValueError, match="Duplicate _benchmark_name values"): + ParameterSweep.from_records(records) + + def test_unique_benchmark_names_multiple_duplicates(self): + """Test validation with multiple duplicate names.""" + records = [ + {"_benchmark_name": "exp1", "max_tokens": 100}, + {"_benchmark_name": "exp1", "max_tokens": 200}, + {"_benchmark_name": "exp2", "max_tokens": 300}, + {"_benchmark_name": "exp2", "max_tokens": 400}, + ] + + with pytest.raises(ValueError, match="Duplicate _benchmark_name values"): + ParameterSweep.from_records(records) + + def test_no_benchmark_names_allowed(self): + """Test that records without _benchmark_name are allowed.""" + records = [ + {"max_tokens": 100, "temperature": 0.7}, + {"max_tokens": 200, "temperature": 0.9}, + ] + sweep = ParameterSweep.from_records(records) + assert len(sweep) == 2 + + def test_mixed_benchmark_names_allowed(self): + """Test that mixing records with and without _benchmark_name is allowed.""" + records = [ + {"_benchmark_name": "exp1", "max_tokens": 100}, + {"max_tokens": 200, "temperature": 0.9}, + ] + sweep = ParameterSweep.from_records(records) + assert len(sweep) == 2 + + +class TestParameterSweepItemKeyNormalization: + """Test key normalization in ParameterSweepItem.""" + + def test_underscore_to_hyphen_conversion(self): + """Test that underscores are converted to hyphens in CLI.""" + item = ParameterSweepItem.from_record({"max_tokens": 100}) + cmd = item.apply_to_cmd(["vllm", "serve"]) + assert "--max-tokens" in cmd + + def test_nested_key_preserves_suffix(self): + """Test that nested keys preserve the suffix format.""" + # The suffix after the dot should preserve underscores + item = ParameterSweepItem.from_record( + {"compilation_config.some_nested_param": "value"} + ) + cmd = item.apply_to_cmd(["vllm", "serve"]) + # The prefix (compilation_config) gets converted to hyphens, + # but the suffix (some_nested_param) is preserved + assert any("compilation-config.some_nested_param" in arg for arg in cmd) diff --git a/tests/benchmarks/test_plot_filters.py b/tests/benchmarks/test_plot_filters.py new file mode 100644 index 0000000000000..2b58a99125e6c --- /dev/null +++ b/tests/benchmarks/test_plot_filters.py @@ -0,0 +1,171 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pandas as pd +import pytest + +from vllm.benchmarks.sweep.plot import ( + PlotEqualTo, + PlotFilterBase, + PlotFilters, + PlotGreaterThan, + PlotGreaterThanOrEqualTo, + PlotLessThan, + PlotLessThanOrEqualTo, + PlotNotEqualTo, +) + + +class TestPlotFilters: + """Test PlotFilter functionality including 'inf' edge case.""" + + def setup_method(self): + """Create sample DataFrames for testing.""" + # DataFrame with numeric values + self.df_numeric = pd.DataFrame( + { + "request_rate": [1.0, 5.0, 10.0, 50.0, 100.0], + "value": [10, 20, 30, 40, 50], + } + ) + + # DataFrame with float('inf') - note: string "inf" values are coerced + # to float when loading data, so we only test with float('inf') + self.df_inf_float = pd.DataFrame( + { + "request_rate": [1.0, 5.0, 10.0, float("inf"), float("inf")], + "value": [10, 20, 30, 40, 50], + } + ) + + @pytest.mark.parametrize( + "target,expected_count", + [ + ("5.0", 1), + ("10.0", 1), + ("1.0", 1), + ], + ) + def test_equal_to_numeric(self, target, expected_count): + """Test PlotEqualTo with numeric values.""" + filter_obj = PlotEqualTo("request_rate", target) + result = filter_obj.apply(self.df_numeric) + assert len(result) == expected_count + + def test_equal_to_inf_float(self): + """Test PlotEqualTo with float('inf').""" + filter_obj = PlotEqualTo("request_rate", "inf") + result = filter_obj.apply(self.df_inf_float) + # Should match both float('inf') entries because float('inf') == float('inf') + assert len(result) == 2 + + @pytest.mark.parametrize( + "target,expected_count", + [ + ("5.0", 4), # All except 5.0 + ("1.0", 4), # All except 1.0 + ], + ) + def test_not_equal_to_numeric(self, target, expected_count): + """Test PlotNotEqualTo with numeric values.""" + filter_obj = PlotNotEqualTo("request_rate", target) + result = filter_obj.apply(self.df_numeric) + assert len(result) == expected_count + + def test_not_equal_to_inf_float(self): + """Test PlotNotEqualTo with float('inf').""" + filter_obj = PlotNotEqualTo("request_rate", "inf") + result = filter_obj.apply(self.df_inf_float) + # Should exclude float('inf') entries + assert len(result) == 3 + + @pytest.mark.parametrize( + "target,expected_count", + [ + ("10.0", 2), # 1.0, 5.0 + ("50.0", 3), # 1.0, 5.0, 10.0 + ("5.0", 1), # 1.0 + ], + ) + def test_less_than(self, target, expected_count): + """Test PlotLessThan with numeric values.""" + filter_obj = PlotLessThan("request_rate", target) + result = filter_obj.apply(self.df_numeric) + assert len(result) == expected_count + + @pytest.mark.parametrize( + "target,expected_count", + [ + ("10.0", 3), # 1.0, 5.0, 10.0 + ("5.0", 2), # 1.0, 5.0 + ], + ) + def test_less_than_or_equal_to(self, target, expected_count): + """Test PlotLessThanOrEqualTo with numeric values.""" + filter_obj = PlotLessThanOrEqualTo("request_rate", target) + result = filter_obj.apply(self.df_numeric) + assert len(result) == expected_count + + @pytest.mark.parametrize( + "target,expected_count", + [ + ("10.0", 2), # 50.0, 100.0 + ("5.0", 3), # 10.0, 50.0, 100.0 + ], + ) + def test_greater_than(self, target, expected_count): + """Test PlotGreaterThan with numeric values.""" + filter_obj = PlotGreaterThan("request_rate", target) + result = filter_obj.apply(self.df_numeric) + assert len(result) == expected_count + + @pytest.mark.parametrize( + "target,expected_count", + [ + ("10.0", 3), # 10.0, 50.0, 100.0 + ("5.0", 4), # 5.0, 10.0, 50.0, 100.0 + ], + ) + def test_greater_than_or_equal_to(self, target, expected_count): + """Test PlotGreaterThanOrEqualTo with numeric values.""" + filter_obj = PlotGreaterThanOrEqualTo("request_rate", target) + result = filter_obj.apply(self.df_numeric) + assert len(result) == expected_count + + @pytest.mark.parametrize( + "filter_str,expected_var,expected_target,expected_type", + [ + ("request_rate==5.0", "request_rate", "5.0", PlotEqualTo), + ("request_rate!=10.0", "request_rate", "10.0", PlotNotEqualTo), + ("request_rate<50.0", "request_rate", "50.0", PlotLessThan), + ("request_rate<=50.0", "request_rate", "50.0", PlotLessThanOrEqualTo), + ("request_rate>10.0", "request_rate", "10.0", PlotGreaterThan), + ("request_rate>=10.0", "request_rate", "10.0", PlotGreaterThanOrEqualTo), + ("request_rate==inf", "request_rate", "inf", PlotEqualTo), + ("request_rate!='inf'", "request_rate", "inf", PlotNotEqualTo), + ], + ) + def test_parse_str(self, filter_str, expected_var, expected_target, expected_type): + """Test parsing filter strings.""" + filter_obj = PlotFilterBase.parse_str(filter_str) + assert isinstance(filter_obj, expected_type) + assert filter_obj.var == expected_var + assert filter_obj.target == expected_target + + def test_parse_str_inf_edge_case(self): + """Test parsing 'inf' string in filter.""" + filter_obj = PlotFilterBase.parse_str("request_rate==inf") + assert isinstance(filter_obj, PlotEqualTo) + assert filter_obj.var == "request_rate" + assert filter_obj.target == "inf" + + def test_parse_multiple_filters(self): + """Test parsing multiple filters.""" + filters = PlotFilters.parse_str("request_rate>5.0,value<=40") + assert len(filters) == 2 + assert isinstance(filters[0], PlotGreaterThan) + assert isinstance(filters[1], PlotLessThanOrEqualTo) + + def test_parse_empty_filter(self): + """Test parsing empty filter string.""" + filters = PlotFilters.parse_str("") + assert len(filters) == 0 diff --git a/tests/compile/distributed/test_async_tp.py b/tests/compile/distributed/test_async_tp.py index 86d409f1eadb0..2eb18e25c98bf 100644 --- a/tests/compile/distributed/test_async_tp.py +++ b/tests/compile/distributed/test_async_tp.py @@ -326,7 +326,7 @@ def async_tp_pass_on_test_model( vllm_config = VllmConfig() vllm_config.compilation_config = CompilationConfig( pass_config=PassConfig( - enable_async_tp=True, + fuse_gemm_comms=True, ), ) vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) @@ -413,7 +413,7 @@ def test_async_tp_pass_correctness( "mode": CompilationMode.VLLM_COMPILE, "compile_sizes": [2, 4, 8], "splitting_ops": [], - "pass_config": {"enable_async_tp": async_tp_enabled}, + "pass_config": {"fuse_gemm_comms": async_tp_enabled}, } async_tp_args = [ diff --git a/tests/compile/distributed/test_fusion_all_reduce.py b/tests/compile/distributed/test_fusion_all_reduce.py index d401d57032752..fc8d1f98ebf87 100644 --- a/tests/compile/distributed/test_fusion_all_reduce.py +++ b/tests/compile/distributed/test_fusion_all_reduce.py @@ -295,7 +295,7 @@ def all_reduce_fusion_pass_on_test_model( ) ) vllm_config.compilation_config.pass_config = PassConfig( - enable_fi_allreduce_fusion=True, enable_noop=True + fuse_allreduce_rms=True, eliminate_noops=True ) vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) vllm_config.parallel_config.rank = local_rank # Setup rank for debug path diff --git a/tests/compile/distributed/test_fusions_e2e.py b/tests/compile/distributed/test_fusions_e2e.py index 661172e1965b5..bd326f1157d8f 100644 --- a/tests/compile/distributed/test_fusions_e2e.py +++ b/tests/compile/distributed/test_fusions_e2e.py @@ -20,13 +20,14 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer from ...utils import flat_product, multi_gpu_test -is_blackwell = lambda: current_platform.is_device_capability(100) +is_blackwell = lambda: current_platform.is_device_capability_family(100) """Are we running on Blackwell, a lot of tests depend on it""" class Matches(NamedTuple): attention_fusion: int = 0 allreduce_fusion: int = 0 + rms_quant_norm_fusion: int = 0 sequence_parallel: int = 0 async_tp: int = 0 @@ -40,6 +41,7 @@ class ModelBackendTestCase(NamedTuple): MODELS_FP8: list[ModelBackendTestCase] = [] MODELS_FP4: list[ModelBackendTestCase] = [] +MODELS_GROUP_FP8: list[ModelBackendTestCase] = [] MODELS: list[ModelBackendTestCase] = [] # tp-only if current_platform.is_cuda(): @@ -138,6 +140,17 @@ elif current_platform.is_rocm(): CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"] +def has_cuda_graph_wrapper_metadata() -> bool: + from importlib import import_module + + try: + module = import_module("torch._inductor.utils") + module.CUDAGraphWrapperMetadata # noqa B018 + except AttributeError: + return False + return True + + @pytest.mark.parametrize( "model_name, model_kwargs, backend, matches, custom_ops", # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8 @@ -145,7 +158,20 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"] # quant_fp4 only has the custom impl + list(flat_product(MODELS_FP4, [""])), ) -@pytest.mark.parametrize("inductor_graph_partition", [True, False]) +@pytest.mark.parametrize( + "inductor_graph_partition", + [ + pytest.param( + True, + marks=pytest.mark.skipif( + not has_cuda_graph_wrapper_metadata(), + reason="This test requires" + "torch._inductor.utils.CUDAGraphWrapperMetadata to run", + ), + ), + False, + ], +) def test_attn_quant( model_name: str, model_kwargs: dict[str, Any], @@ -192,7 +218,7 @@ def test_attn_quant( splitting_ops=splitting_ops, # Common mode=CompilationMode.VLLM_COMPILE, - pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True), + pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True), # Inductor caches custom passes by default as well via uuid inductor_compile_config={"force_disable_caches": True}, ) @@ -282,9 +308,9 @@ def test_tp2_attn_quant_allreduce_rmsnorm( # Common mode=CompilationMode.VLLM_COMPILE, pass_config=PassConfig( - enable_attn_fusion=True, - enable_noop=True, - enable_fi_allreduce_fusion=True, + fuse_attn_quant=True, + eliminate_noops=True, + fuse_allreduce_rms=True, ), # Inductor caches custom passes by default as well via uuid inductor_compile_config={"force_disable_caches": True}, @@ -298,10 +324,14 @@ def test_tp2_attn_quant_allreduce_rmsnorm( r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", log_holder.text, ) - assert len(log_matches) == 2, log_holder.text + # 2 for each compile range + # (global compile range can be split due to fuse_allreduce_rmsnorm) + num_compile_ranges = len(compilation_config.get_compile_ranges()) + assert num_compile_ranges in [1, 2] - assert int(log_matches[0]) == matches.attention_fusion - assert int(log_matches[1]) == matches.attention_fusion + assert len(log_matches) == 2 * num_compile_ranges, log_holder.text + + assert all(int(log_match) == matches.attention_fusion for log_match in log_matches) log_matches = re.findall( r"collective_fusion.py:\d+] Replaced (\d+) patterns", @@ -312,6 +342,12 @@ def test_tp2_attn_quant_allreduce_rmsnorm( assert int(log_matches[0]) == matches.allreduce_fusion assert int(log_matches[1]) == matches.allreduce_fusion + log_matches = re.findall( + r"pass_manager.py:\d+] Skipping .*AllReduceFusionPass.* with compile range", + log_holder.text, + ) + assert len(log_matches) == 2 * (num_compile_ranges - 1), log_holder.text + @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( @@ -384,10 +420,10 @@ def test_tp2_attn_quant_async_tp( # Common level=CompilationMode.VLLM_COMPILE, pass_config=PassConfig( - enable_attn_fusion=True, - enable_noop=True, - enable_sequence_parallelism=True, - enable_async_tp=True, + fuse_attn_quant=True, + eliminate_noops=True, + enable_sp=True, + fuse_gemm_comms=True, ), # Inductor caches custom passes by default as well via uuid inductor_compile_config={"force_disable_caches": True}, @@ -446,7 +482,6 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg # No cudagraphs by default if compilation_config.cudagraph_mode is None: compilation_config.cudagraph_mode = CUDAGraphMode.NONE - llm = LLM( model=model, compilation_config=compilation_config, @@ -459,3 +494,85 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + # Get the compile ranges split points after vllm config post init + # in order to compute compile ranges correctly + compilation_config.compile_ranges_split_points = ( + llm.llm_engine.vllm_config.compilation_config.compile_ranges_split_points + ) + + +if current_platform.is_cuda(): + MODELS_GROUP_FP8 = [ + ModelBackendTestCase( + model_name="Qwen/Qwen3-30B-A3B-FP8", + model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), + backend=AttentionBackendEnum.TRITON_ATTN, + matches=Matches( + rms_quant_norm_fusion=48, + ), + ), + ] + +CUSTOM_OPS_QUANT_RMS_NORM = ["+quant_fp8,+rms_norm"] + + +@pytest.mark.parametrize( + "model_name, model_kwargs, backend, matches, custom_ops", + # Test rms norm+group quant_fp8 fusion + list[tuple[Any, ...]](flat_product(MODELS_GROUP_FP8, CUSTOM_OPS_QUANT_RMS_NORM)), +) +@pytest.mark.parametrize("inductor_graph_partition", [True, False]) +def test_rms_group_quant( + model_name: str, + model_kwargs: dict[str, Any], + backend: AttentionBackendEnum, + matches: Matches, + custom_ops: str, + inductor_graph_partition: bool, + caplog_mp_spawn, + monkeypatch, +): + if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("Inductor graph partition requires torch>=2.9") + + custom_ops_list = custom_ops.split(",") if custom_ops else [] + + if inductor_graph_partition: + mode = CUDAGraphMode.FULL_AND_PIECEWISE + splitting_ops: list[str] | None = None + else: + mode = CUDAGraphMode.FULL_DECODE_ONLY + splitting_ops = [] + + # Disable, compile cache to make sure custom passes run. + # Otherwise, we can't verify fusion happened through the logs. + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + + # To capture subprocess logs, we need to know whether spawn or fork is used. + # Force spawn as it is more general. + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) + + compilation_config = CompilationConfig( + # Testing properties + custom_ops=custom_ops_list, + use_inductor_graph_partition=inductor_graph_partition, + cudagraph_mode=mode, + splitting_ops=splitting_ops, + # Common + mode=CompilationMode.VLLM_COMPILE, + pass_config=PassConfig(eliminate_noops=True, enable_fusion=True), + # Inductor caches custom passes by default as well via uuid + inductor_compile_config={"force_disable_caches": True}, + ) + + with caplog_mp_spawn(logging.DEBUG) as log_holder: + run_model(compilation_config, model_name, **model_kwargs) + + log_matches = re.findall( + r"\[fusion.py:\d+] Replaced (\d+) patterns", + log_holder.text, + ) + assert len(log_matches) == 1, log_holder.text + assert int(log_matches[0]) == matches.rms_quant_norm_fusion diff --git a/tests/compile/distributed/test_sequence_parallelism.py b/tests/compile/distributed/test_sequence_parallelism.py index 30084dfd5a950..d9fdc3acc3d6f 100644 --- a/tests/compile/distributed/test_sequence_parallelism.py +++ b/tests/compile/distributed/test_sequence_parallelism.py @@ -153,7 +153,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): ] def ops_in_model(self): - if self.vllm_config.compilation_config.pass_config.enable_fusion: + if self.vllm_config.compilation_config.pass_config.fuse_norm_quant: return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default] elif RMSNorm.enabled(): return [ @@ -183,7 +183,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): @pytest.mark.parametrize("seq_len", [16]) @pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("enable_fusion", [True, False]) +@pytest.mark.parametrize("fuse_norm_quant", [True, False]) @pytest.mark.parametrize("dynamic", [False, True]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") def test_sequence_parallelism_pass( @@ -193,7 +193,7 @@ def test_sequence_parallelism_pass( seq_len: int, hidden_size: int, dtype: torch.dtype, - enable_fusion: bool, + fuse_norm_quant: bool, dynamic: bool, ): num_processes = 2 @@ -211,7 +211,7 @@ def test_sequence_parallelism_pass( seq_len, hidden_size, dtype, - enable_fusion, + fuse_norm_quant, dynamic, ), nprocs=nprocs, @@ -229,7 +229,7 @@ def sequence_parallelism_pass_on_test_model( seq_len: int, hidden_size: int, dtype: torch.dtype, - enable_fusion: bool, + fuse_norm_quant: bool, dynamic: bool, ): current_platform.seed_everything(0) @@ -260,9 +260,9 @@ def sequence_parallelism_pass_on_test_model( cudagraph_mode=CUDAGraphMode.NONE, # avoid piecewise warnings custom_ops=custom_ops_list, pass_config=PassConfig( - enable_sequence_parallelism=True, - enable_fusion=enable_fusion, - enable_noop=True, + enable_sp=True, + fuse_norm_quant=fuse_norm_quant, + eliminate_noops=True, ), ) # NoOp needed for fusion device_config = DeviceConfig(device=torch.device("cuda")) @@ -297,7 +297,7 @@ def sequence_parallelism_pass_on_test_model( sequence_parallelism_pass, ] - if enable_fusion: + if fuse_norm_quant: fusion_pass = RMSNormQuantFusionPass(vllm_config) passes_for_backend.append(fusion_pass) diff --git a/tests/compile/fullgraph/test_full_graph.py b/tests/compile/fullgraph/test_full_graph.py index 2c11ecef7f029..3cd1d4be2ebdc 100644 --- a/tests/compile/fullgraph/test_full_graph.py +++ b/tests/compile/fullgraph/test_full_graph.py @@ -122,7 +122,9 @@ def test_full_graph( CompilationConfig( mode=CompilationMode.VLLM_COMPILE, custom_ops=["+rms_norm"], - pass_config=PassConfig(enable_fusion=True, enable_noop=True), + pass_config=PassConfig( + fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True + ), ), *model_info, ) diff --git a/tests/compile/test_aot_compile.py b/tests/compile/test_aot_compile.py index c65e5a25934d2..8fa305d6d72f5 100644 --- a/tests/compile/test_aot_compile.py +++ b/tests/compile/test_aot_compile.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools +import multiprocessing import tempfile from contextlib import contextmanager @@ -137,3 +139,67 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch): artifacts = compiled_mod.aot_compiled_fn._artifacts guards_string = artifacts.compiled_fn.shape_env.format_guards() assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)" + + +@pytest.mark.skipif( + not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" +) +@use_vllm_config(make_vllm_config()) +def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch): + """ + Test that compiling gpt2 twice results in a cache hit and + capture torch dynamic symbol creations to ensure make_symbol + not called on cache hit. + """ + + import torch.fx.experimental.symbolic_shapes as symbolic_shapes_module + from torch.utils._sympy.symbol import make_symbol + + from vllm import LLM + + create_symbol_counter = multiprocessing.Value("i", 0) + original_make_symbol = make_symbol + + @functools.wraps(original_make_symbol) + def counting_make_symbol(prefix, idx, **kwargs): + with create_symbol_counter.get_lock(): + create_symbol_counter.value += 1 + return original_make_symbol(prefix, idx, **kwargs) + + symbolic_shapes_module.make_symbol = counting_make_symbol + try: + with monkeypatch.context() as m, tempfile.TemporaryDirectory() as tmpdirname: + m.setenv("VLLM_CACHE_ROOT", tmpdirname) + m.setenv("VLLM_USE_AOT_COMPILE", "1") + # First compilation - initialize model and generate + llm_model = LLM( + model="gpt2", + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + ), + max_model_len=256, + ) + + llm_model.generate("Hello, my name is") + assert create_symbol_counter.value == 2 + create_symbol_counter.value = 0 + + # Clean up first model + del llm_model + + # Second compilation - should hit cache + m.setenv("VLLM_FORCE_AOT_LOAD", "1") + llm_model = LLM( + model="gpt2", + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + ), + max_model_len=256, + ) + llm_model.generate("Hello, my name is") + + assert create_symbol_counter.value == 0 + + finally: + # Restore original method + symbolic_shapes_module.make_symbol = original_make_symbol diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py new file mode 100644 index 0000000000000..14ae8233f1131 --- /dev/null +++ b/tests/compile/test_compile_ranges.py @@ -0,0 +1,174 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + +import torch +from torch import fx as fx +from torch import nn + +# This import automatically registers `torch.ops.silly.attention` +import tests.compile.silly_attention # noqa +from vllm.compilation.counter import compilation_counter +from vllm.compilation.decorators import support_torch_compile +from vllm.compilation.inductor_pass import ( + InductorPass, + get_pass_context, +) +from vllm.config import ( + VllmConfig, + set_current_vllm_config, +) +from vllm.config.compilation import CompilationConfig, CompilationMode +from vllm.config.scheduler import SchedulerConfig +from vllm.config.utils import Range +from vllm.forward_context import set_forward_context + +BATCH_SIZE = 64 +MLP_SIZE = 128 + + +@support_torch_compile +class TestModel(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + x + attn_output = torch.empty_like(x) + torch.ops.silly.attention(x, x, x, attn_output) + x = attn_output + x = x * 3 + return x + + +@torch.inference_mode +def run_model(vllm_config: VllmConfig, model: nn.Module, batch_sizes: list[int]): + with set_forward_context({}, vllm_config=vllm_config): + model(torch.randn(BATCH_SIZE, MLP_SIZE)) + for batch_size in batch_sizes: + model(torch.randn(batch_size, MLP_SIZE)) + + +class PostGradRangeChecker(InductorPass): + def __init__(self, ranges: list[Range]): + self.ranges = ranges + self.num_calls = 0 + + def __call__(self, graph: fx.Graph): + compile_range = get_pass_context().compile_range + assert compile_range in self.ranges, ( + f"Compile range {compile_range} not in {self.ranges}" + ) + self.num_calls += 1 + + def uuid(self) -> str: + state: dict[str, Any] = {} + return InductorPass.hash_dict(state) + + +def test_compile_ranges(use_fresh_inductor_cache): + post_grad_range_checker = PostGradRangeChecker( + [ + Range(start=1, end=8), + Range(start=16, end=16), + Range(start=9, end=32), + Range(start=64, end=64), + Range(start=33, end=8192), + ] + ) + torch.set_default_device("cuda") + vllm_config = VllmConfig( + scheduler_config=SchedulerConfig( + max_num_batched_tokens=8192, + max_model_len=8192, + is_encoder_decoder=False, + ), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + compile_ranges_split_points=[8, 32], + compile_sizes=[16, 64, 128], + inductor_compile_config={ + "post_grad_custom_post_pass": post_grad_range_checker, + }, + ), + ) + + with set_current_vllm_config(vllm_config): + model = TestModel(vllm_config=vllm_config, prefix="").eval() + # Number of compilations: 3 for each compile range + 2 compile sizes + batch_sizes = [1, 4, 16, 24, 48, 64, 8192] + + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=1, + num_backend_compilations=5, + ): + run_model(vllm_config, model, batch_sizes) + assert post_grad_range_checker.num_calls == 5 + + +def test_compile_config_get_compile_ranges(): + compilation_config = CompilationConfig( + compile_ranges_split_points=[8, 32], + ) + VllmConfig( + scheduler_config=SchedulerConfig( + max_num_batched_tokens=8192, + max_model_len=8192, + is_encoder_decoder=False, + ), + compilation_config=compilation_config, + ) + assert compilation_config.get_compile_ranges() == [ + Range(start=1, end=8), + Range(start=9, end=32), + Range(start=33, end=8192), + ] + + +def test_inductor_cache_compile_ranges(monkeypatch, use_fresh_inductor_cache): + # To force multiple compilations, we disable the compile cache + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + + post_grad_range_checker = PostGradRangeChecker( + ranges=[ + Range(start=1, end=8), + Range(start=9, end=8192), + ] + ) + scheduler_config = SchedulerConfig( + max_num_batched_tokens=8192, + max_model_len=8192, + is_encoder_decoder=False, + ) + torch.set_default_device("cuda") + + def create_vllm_config(): + return VllmConfig( + scheduler_config=scheduler_config, + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + compile_ranges_split_points=[8], + inductor_compile_config={ + "post_grad_custom_post_pass": post_grad_range_checker, + }, + ), + ) + + vllm_config_1 = create_vllm_config() + with set_current_vllm_config(vllm_config_1): + model1 = TestModel(vllm_config=vllm_config_1, prefix="").eval() + batch_sizes = [1, 16] + run_model(vllm_config_1, model1, batch_sizes) + assert post_grad_range_checker.num_calls == 2 + + post_grad_range_checker.num_calls = 0 + # Create a new vllm config with the new pass context + vllm_config_2 = create_vllm_config() + with set_current_vllm_config(vllm_config_2): + model2 = TestModel(vllm_config=vllm_config_2, prefix="").eval() + batch_sizes = [4, 32] + run_model(vllm_config_2, model2, batch_sizes) + # Check that cache is used, so the number of calls + # should be 0 + assert post_grad_range_checker.num_calls == 0 diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index a9e5ccee520e3..04bb56ecb6470 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -9,8 +9,8 @@ from pydantic import ValidationError from vllm.compilation.counter import compilation_counter from vllm.compilation.fix_functionalization import FixFunctionalizationPass -from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig -from vllm.config.compilation import CompilationMode +from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmConfig +from vllm.config.compilation import CompilationMode, PassConfig from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform from vllm.utils.torch_utils import _is_torch_equal_or_newer @@ -191,7 +191,7 @@ def test_splitting_ops_dynamic(): config = VllmConfig( compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, - pass_config={"enable_attn_fusion": True, "enable_noop": True}, + pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True), custom_ops=["+quant_fp8"], cudagraph_mode=CUDAGraphMode.PIECEWISE, ) @@ -206,7 +206,7 @@ def test_splitting_ops_dynamic(): config = VllmConfig( compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, - pass_config={"enable_attn_fusion": True, "enable_noop": True}, + pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True), custom_ops=["+quant_fp8"], cudagraph_mode=CUDAGraphMode.PIECEWISE, # work around for accessing all attntion ops @@ -219,7 +219,7 @@ def test_splitting_ops_dynamic(): compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, use_inductor_graph_partition=True, - pass_config={"enable_attn_fusion": True, "enable_noop": True}, + pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True), custom_ops=["+quant_fp8"], cudagraph_mode=CUDAGraphMode.PIECEWISE, ) @@ -227,12 +227,76 @@ def test_splitting_ops_dynamic(): # With inductor graph partition, attn_fusion and splitting_ops # work together. Default splitting_ops include attention ops. assert config.compilation_config.splitting_ops_contain_attention() - # enable_attn_fusion is directly supported under + # fuse_attn_quant is directly supported under # use_inductor_graph_partition=True, and cudagraph_mode # is unchanged. assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE +def test_moe_splitting_ops_deepep_ht_piecewise(): + # Non-inductor, non-attn-fusion case: DeepEP HT with dp>1 + # should add MoE ops to splitting_ops on top of attention ops. + config = VllmConfig( + parallel_config=ParallelConfig( + all2all_backend="deepep_high_throughput", + data_parallel_size=8, + ), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + ), + ) + splitting_ops = config.compilation_config.splitting_ops + assert splitting_ops is not None + assert "vllm::moe_forward" in splitting_ops + assert "vllm::moe_forward_shared" in splitting_ops + + +def test_moe_splitting_ops_deepep_ht_inductor_partition(): + # Inductor partition case: user-provided splitting_ops should be + # preserved and MoE ops should be appended for DeepEP HT with dp>1. + config = VllmConfig( + parallel_config=ParallelConfig( + all2all_backend="deepep_high_throughput", + data_parallel_size=8, + ), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + use_inductor_graph_partition=True, + splitting_ops=[ + "vllm::unified_attention", + "vllm::moe_forward", + "vllm::moe_forward_shared", + ], + ), + ) + splitting_ops = config.compilation_config.splitting_ops + assert splitting_ops == [ + "vllm::unified_attention", + "vllm::moe_forward", + "vllm::moe_forward_shared", + ] + + +def test_moe_splitting_ops_deepep_ht_attn_fusion_no_inductor(): + # Pure attn-fusion case without inductor partition: even with + # DeepEP HT and dp>1, we should not re-enable piecewise compilation + # or add MoE ops into splitting_ops. + config = VllmConfig( + parallel_config=ParallelConfig( + all2all_backend="deepep_high_throughput", + data_parallel_size=8, + ), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + pass_config={"fuse_attn_quant": True, "eliminate_noops": True}, + custom_ops=["+quant_fp8"], + cudagraph_mode=CUDAGraphMode.PIECEWISE, + ), + ) + assert config.compilation_config.splitting_ops == [] + assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL + + def test_should_split(): import torch @@ -301,7 +365,7 @@ def test_should_split(): "cudagraph_capture_sizes", "max_cudagraph_capture_size", "tp_size", - "enable_sequence_parallelism", + "enable_sp", "max_num_batched_tokens", "cudagraph_mode", "expected_max_size", @@ -339,7 +403,7 @@ def test_cudagraph_sizes_post_init( cudagraph_capture_sizes, max_cudagraph_capture_size, tp_size, - enable_sequence_parallelism, + enable_sp, max_num_batched_tokens, cudagraph_mode, expected_max_size, @@ -355,11 +419,12 @@ def test_cudagraph_sizes_post_init( compilation_config = CompilationConfig( cudagraph_capture_sizes=cudagraph_capture_sizes, max_cudagraph_capture_size=max_cudagraph_capture_size, - pass_config={ - "enable_sequence_parallelism": enable_sequence_parallelism, - "enable_fusion": True, - "enable_noop": True, - }, + pass_config=PassConfig( + enable_sp=enable_sp, + fuse_norm_quant=True, + fuse_act_quant=True, + eliminate_noops=True, + ), cudagraph_mode=cudagraph_mode, ) engine_args = EngineArgs( diff --git a/tests/compile/test_dynamic_shapes_compilation.py b/tests/compile/test_dynamic_shapes_compilation.py index c20aea822fe81..9ccb363b088f5 100644 --- a/tests/compile/test_dynamic_shapes_compilation.py +++ b/tests/compile/test_dynamic_shapes_compilation.py @@ -2,13 +2,22 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import gc +import tempfile +from contextlib import contextmanager import pytest import torch from vllm import LLM, SamplingParams -from vllm.config.compilation import CompilationMode, DynamicShapesType -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config +from vllm.config.compilation import ( + CompilationMode, + DynamicShapesConfig, + DynamicShapesType, +) +from vllm.forward_context import set_forward_context +from vllm.tokenizers import get_tokenizer from vllm.utils.torch_utils import is_torch_equal_or_newer @@ -27,23 +36,30 @@ def get_test_models(): DynamicShapesType.BACKED_SIZE_OBLIVIOUS, ], ) -@pytest.mark.parametrize("use_aot_compile", ["0"]) +@pytest.mark.parametrize("use_aot_compile", ["0", "1"]) @pytest.mark.parametrize("use_bytecode_hook", [True, False]) +@pytest.mark.parametrize("evaluate_guards", [False, True]) @pytest.mark.skipif( not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" ) def test_dynamic_shapes_compilation( - monkeypatch, model_name, shapes_type, use_aot_compile, use_bytecode_hook + monkeypatch, + model_name, + shapes_type, + use_aot_compile, + use_bytecode_hook, + evaluate_guards, ): """Test that all dynamic shapes types compile successfully""" - print( - f"\nTesting model: {model_name} with {shapes_type.name}, " - f"AOT compile: {use_aot_compile}, " - f"Bytecode hook: {use_bytecode_hook}" - ) if use_bytecode_hook and shapes_type == DynamicShapesType.UNBACKED: pytest.skip("UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0") + if evaluate_guards and shapes_type == DynamicShapesType.UNBACKED: + pytest.skip("unbacked dynamic shapes do not add guards") + + if evaluate_guards and use_aot_compile: + pytest.skip("evaluate_guards requires use_aot_compile=0") + monkeypatch.setenv("VLLM_USE_AOT_COMPILE", use_aot_compile) monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0") @@ -58,6 +74,7 @@ def test_dynamic_shapes_compilation( "mode": CompilationMode.VLLM_COMPILE, "dynamic_shapes_config": { "type": shapes_type.value, + "evaluate_guards": evaluate_guards, }, }, ) @@ -86,3 +103,117 @@ def test_dynamic_shapes_compilation( torch.cuda.empty_cache() torch.cuda.synchronize() print("GPU memory cleared") + + +@pytest.mark.parametrize("use_aot_compile", ["0", "1"]) +@pytest.mark.parametrize( + "dynamic_shapes_type", + [ + DynamicShapesType.BACKED, + DynamicShapesType.BACKED_SIZE_OBLIVIOUS, + ], +) +@pytest.mark.parametrize("evaluate_guards", [False, True]) +def test_model_specialization_with_evaluate_guards( + monkeypatch, use_aot_compile, dynamic_shapes_type, evaluate_guards +): + """Test that evaluate_guards correctly detects shape specialization + violations. + """ + + if ( + use_aot_compile == "1" + and dynamic_shapes_type == DynamicShapesType.BACKED + and evaluate_guards + ): + pytest.skip("evaluate_guards for backed does not work with aot_compile=1") + + @support_torch_compile + class ModelWithSizeCheck(torch.nn.Module): + def __init__(self, **kwargs): + super().__init__() + + def forward(self, x: torch.Tensor): + # This will cause specialization - torch.compile will guard on + # sx.shape[0] + if x.shape[0] >= 10: + return x * 10 + else: + return x * 10 + + @support_torch_compile + class ModelWithOneSizeCheck(torch.nn.Module): + def __init__(self, **kwargs): + super().__init__() + + def forward(self, x: torch.Tensor): + # This will cause 0/1 specializations. + if x.shape[0] == 0: + return x * 10 + if x.shape[0] == 1: + return x * 10 + else: + return x * 10 + + @contextmanager + def use_vllm_config(vllm_config: VllmConfig): + with set_forward_context({}, vllm_config), set_current_vllm_config(vllm_config): + yield + + monkeypatch.setenv("TOKENIZERS_PARALLELISM", "true") + monkeypatch.setenv("VLLM_USE_AOT_COMPILE", use_aot_compile) + monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "0") + + # Create vllm config with the desired settings + from vllm.config import CompilationMode + + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + dynamic_shapes_config=DynamicShapesConfig( + type=dynamic_shapes_type, + evaluate_guards=evaluate_guards, + ), + ) + ) + + def test(model_class, input1, input2, is_01_specialization=False): + with ( + torch.no_grad(), + use_vllm_config(vllm_config), + tempfile.TemporaryDirectory() as tmpdirname, + ): + monkeypatch.setenv("VLLM_CACHE_ROOT", tmpdirname) + + model = model_class(vllm_config=vllm_config).cuda() + + model(input1) + + if evaluate_guards and ( + not ( + is_01_specialization + and dynamic_shapes_type == DynamicShapesType.BACKED + ) + ): + # This should fail because guards were added. + with pytest.raises(RuntimeError) as excinfo: + model(input2) + + # Expected failure - guard was violated + error_msg = str(excinfo.value) + assert ( + "GuardManager check failed" in error_msg + or "Detected recompile when torch.compile stance" in error_msg + ), error_msg + + else: + model(input2) + + test(ModelWithSizeCheck, torch.randn(20, 10).cuda(), torch.randn(5, 10).cuda()) + test(ModelWithSizeCheck, torch.randn(5, 10).cuda(), torch.randn(20, 10).cuda()) + test( + ModelWithOneSizeCheck, + torch.randn(20, 10).cuda(), + torch.randn(1, 10).cuda(), + is_01_specialization=True, + ) diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index 515e0a93ac2a8..ad5ead36e2310 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -128,14 +128,12 @@ class TestFusedAddRMSNorm(torch.nn.Module): class TestRotaryEmbedding(torch.nn.Module): - def __init__(self, head_dim=64, rotary_dim=None, max_position=2048, base=10000): + def __init__(self, head_dim=64, max_position=2048, base=10000): super().__init__() self.head_dim = head_dim - self.rotary_dim = rotary_dim or head_dim self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.rotary_dim, max_position=max_position, rope_parameters={"rope_type": "default", "rope_theta": base}, ) @@ -170,7 +168,6 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position, rope_parameters={"rope_type": "default", "rope_theta": base}, ) @@ -223,7 +220,11 @@ def test_fix_functionalization( model_config=ModelConfig(dtype=dtype), compilation_config=CompilationConfig( custom_ops=["all"], - pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True), + pass_config=PassConfig( + fuse_norm_quant=do_fusion, + fuse_act_quant=do_fusion, + eliminate_noops=True, + ), ), ) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 286f2276367a0..6b72c595cd779 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -1,10 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import itertools + import pytest import torch import vllm.plugins +from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.matcher_utils import QUANT_OPS @@ -18,6 +21,9 @@ from vllm.config import ( VllmConfig, ) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + W8A8BlockFp8LinearOp, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, QuantKey, @@ -25,10 +31,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, + cutlass_block_fp8_supported, cutlass_fp8_supported, maybe_create_device_identity, ) from vllm.platforms import current_platform +from vllm.utils.deep_gemm import is_deep_gemm_supported from ..utils import override_cutlass_fp8_supported from .backend import TestBackend @@ -44,7 +52,7 @@ class TestModel(torch.nn.Module): self, hidden_size: int, eps: float, - static: bool, + group_shape: GroupShape, cuda_force_torch: bool, *args, **kwargs, @@ -52,8 +60,17 @@ class TestModel(torch.nn.Module): super().__init__(*args, **kwargs) self.cuda_force_torch = cuda_force_torch self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)] - self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] - group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN + if group_shape.is_per_group(): + self.wscale = [ + torch.rand( + (hidden_size // group_shape[1], hidden_size // group_shape[1]), + dtype=torch.float32, + ) + for _ in range(3) + ] + else: + self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + static = group_shape == GroupShape.PER_TENSOR quant_scale = ScaleDesc(torch.float32, static, group_shape) self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) if static: @@ -61,18 +78,29 @@ class TestModel(torch.nn.Module): else: self.scale = [None for _ in range(3)] self.w = [ - torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() - for _ in range(3) + torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE) for _ in range(3) ] + if not group_shape.is_per_group(): + self.w = [self.w[0].t() for _ in range(3)] - with override_cutlass_fp8_supported(not cuda_force_torch): - self.fp8_linear = Fp8LinearOp( - act_quant_static=static, + if group_shape.is_per_group(): + self.fp8_linear = W8A8BlockFp8LinearOp( + weight_group_shape=GroupShape(group_shape[1], group_shape[1]), act_quant_group_shape=group_shape, + cutlass_block_fp8_supported=cutlass_block_fp8_supported(), + use_aiter_and_is_supported=False, ) + self.enable_quant_fp8_custom_op = self.fp8_linear.input_quant_op.enabled() + else: + with override_cutlass_fp8_supported(not cuda_force_torch): + self.fp8_linear = Fp8LinearOp( + act_quant_static=static, + act_quant_group_shape=group_shape, + ) + self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() self.enable_rms_norm_custom_op = self.norm[0].enabled() - self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() + self.group_shape = group_shape def forward(self, x): # avoid having graph input be an arg to a pattern directly @@ -119,13 +147,87 @@ class TestModel(torch.nn.Module): ) +GROUP_SHAPES = [ + GroupShape.PER_TOKEN, + GroupShape.PER_TENSOR, + GroupShape(1, 128), + GroupShape(1, 64), +] + + +class TestRmsnormGroupFp8QuantModel(torch.nn.Module): + def __init__(self, hidden_size: int, eps: float, **kwargs): + super().__init__() + self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( + weight_group_shape=GroupShape(128, 128), + act_quant_group_shape=GroupShape(1, 128), + cutlass_block_fp8_supported=False, + use_aiter_and_is_supported=True, + ) + self.w = [ + torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() + for _ in range(3) + ] + + scale_hidden_size = (hidden_size + 128 - 1) // 128 + self.wscale = [ + torch.rand((scale_hidden_size, scale_hidden_size), dtype=torch.float32) + for _ in range(3) + ] + + self.norm_weight = [torch.ones(hidden_size) for _ in range(4)] + self.eps = eps + + def forward(self, x): + # avoid having graph input be an arg to a pattern directly + x = resid = torch.relu(x) + y = rocm_aiter_ops.rms_norm(x, self.norm_weight[0], self.eps) + + x2 = self.w8a8_block_fp8_linear.apply(y, self.w[0], self.wscale[0]) + # make sure resid is used for replacement to work + y2, resid = rocm_aiter_ops.rms_norm2d_with_add( + x2, resid, self.norm_weight[1], self.eps + ) + + x3 = self.w8a8_block_fp8_linear.apply(y2, self.w[1], self.wscale[1]) + + y3, resid = rocm_aiter_ops.rms_norm2d_with_add( + x3, resid, self.norm_weight[2], self.eps + ) + + x4 = self.w8a8_block_fp8_linear.apply(y3, self.w[2], self.wscale[2]) + + y4, resid = rocm_aiter_ops.rms_norm2d_with_add( + x4, resid, self.norm_weight[3], self.eps + ) + return y4 + + def ops_in_model_before(self): + return [ + torch.ops.vllm.rocm_aiter_rms_norm, + torch.ops.vllm.rocm_aiter_group_fp8_quant, + ] + + def ops_in_model_before_partial(self): + return [] + + def ops_in_model_after(self): + return [ + torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant, + torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant, + ] + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("hidden_size", [64]) +@pytest.mark.parametrize("hidden_size", [256]) @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) -@pytest.mark.parametrize("static", [True, False]) -@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) -@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False]) +@pytest.mark.parametrize("group_shape", GROUP_SHAPES) +@pytest.mark.parametrize( + "model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op", + list(itertools.product([TestModel], [True, False], [True, False])) + + [(TestRmsnormGroupFp8QuantModel, False, False)], +) # cuda_force_torch used to test torch code path on platforms that # cutlass_fp8_supported() == True. @pytest.mark.parametrize( @@ -139,16 +241,29 @@ def test_fusion_rmsnorm_quant( hidden_size, num_tokens, eps, - static, + group_shape, + model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op, cuda_force_torch, ): + if model_class is TestRmsnormGroupFp8QuantModel and not IS_AITER_FOUND: + pytest.skip("AITER is not supported on this GPU.") + torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) maybe_create_device_identity() # needed for certain non-cutlass fp8 paths + if not enable_quant_fp8_custom_op and group_shape.is_per_group(): + pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization") + + # Skip test for 64-bit group shape when running with cutlass or deepgemm + if group_shape == GroupShape(1, 64) and ( + cutlass_block_fp8_supported() or is_deep_gemm_supported() + ): + pytest.skip("Unsupported group shape 64 for CUTLASS/DeepGemm") + custom_ops = [] if enable_rms_norm_custom_op: custom_ops.append("+rms_norm") @@ -159,19 +274,32 @@ def test_fusion_rmsnorm_quant( compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, custom_ops=custom_ops, - pass_config=PassConfig(enable_fusion=True, enable_noop=True), + pass_config=PassConfig( + fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True + ), ), ) with vllm.config.set_current_vllm_config(vllm_config): # Reshape pass is needed for the fusion pass to work noop_pass = NoOpEliminationPass(vllm_config) - fusion_pass = RMSNormQuantFusionPass(vllm_config) + if model_class is TestRmsnormGroupFp8QuantModel: + from vllm.compilation.rocm_aiter_fusion import ( + RocmAiterRMSNormFp8GroupQuantFusionPass, + ) + + fusion_pass = RocmAiterRMSNormFp8GroupQuantFusionPass(vllm_config) + else: + fusion_pass = RMSNormQuantFusionPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config) backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) backend2 = TestBackend(noop_pass, cleanup_pass) - model = TestModel(hidden_size, eps, static, cuda_force_torch) - + model = model_class( + hidden_size=hidden_size, + eps=eps, + group_shape=group_shape, + cuda_force_torch=cuda_force_torch, + ) # First dimension dynamic x = torch.rand(num_tokens, hidden_size) torch._dynamo.mark_dynamic(x, 0) @@ -200,7 +328,10 @@ def test_fusion_rmsnorm_quant( # there's a risk that the fused add doesn't get included in the # replacement and only the rms part gets fused with quant. # Hence, we check only 2 add nodes are left (final fused rmsnorm add). - if not enable_rms_norm_custom_op: + if ( + not enable_rms_norm_custom_op + and model_class is not TestRmsnormGroupFp8QuantModel + ): n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g)) # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each) assert n_add_nodes(backend.graph_pre_pass) == 7 diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index dbe12dc5de705..db95dff5e0fc7 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -12,13 +12,13 @@ from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import Attention -from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.matcher_utils import QUANT_OPS from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass from vllm.config import ( + AttentionConfig, CacheConfig, CompilationConfig, CompilationMode, @@ -318,18 +318,24 @@ def test_attention_quant_pattern( torch.set_default_dtype(dtype) torch.manual_seed(42) + model_config = ModelConfig( + model=model_name, + max_model_len=2048, + dtype=dtype, + ) vllm_config = VllmConfig( - model_config=ModelConfig( - model=model_name, - max_model_len=2048, - dtype=dtype, + model_config=model_config, + scheduler_config=SchedulerConfig( + max_num_seqs=1024, + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, ), - scheduler_config=SchedulerConfig(max_num_seqs=1024), compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, custom_ops=custom_ops_list, ), cache_config=CacheConfig(cache_dtype="fp8"), + attention_config=AttentionConfig(backend=backend), ) # Create test inputs @@ -347,7 +353,6 @@ def test_attention_quant_pattern( with ( set_current_vllm_config(vllm_config_unfused), set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused), - global_force_attn_backend_context_manager(backend), ): model_unfused = model_class( num_qo_heads=num_qo_heads, @@ -368,12 +373,11 @@ def test_attention_quant_pattern( # Run model with attn fusion enabled vllm_config.compilation_config.pass_config = PassConfig( - enable_attn_fusion=True, enable_noop=True + fuse_attn_quant=True, eliminate_noops=True ) with ( set_current_vllm_config(vllm_config), set_forward_context(attn_metadata=None, vllm_config=vllm_config), - global_force_attn_backend_context_manager(backend), ): model_fused = model_class( num_qo_heads=num_qo_heads, diff --git a/tests/compile/test_noop_elimination.py b/tests/compile/test_noop_elimination.py index 0ccc1a0161629..bfe08382fd949 100644 --- a/tests/compile/test_noop_elimination.py +++ b/tests/compile/test_noop_elimination.py @@ -51,7 +51,7 @@ def test_noop_elimination(dtype, num_tokens, hidden_size, buffer_size): vllm_config = VllmConfig( compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, - pass_config=PassConfig(enable_noop=True), + pass_config=PassConfig(eliminate_noops=True), ) ) with vllm.config.set_current_vllm_config(vllm_config): @@ -99,7 +99,7 @@ def test_non_noop_slice_preserved(): vllm_config = VllmConfig( compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, - pass_config=PassConfig(enable_noop=True), + pass_config=PassConfig(eliminate_noops=True), ) ) with vllm.config.set_current_vllm_config(vllm_config): diff --git a/tests/compile/test_pass_manager.py b/tests/compile/test_pass_manager.py index 1c40c599f7487..6ed77b0085f51 100644 --- a/tests/compile/test_pass_manager.py +++ b/tests/compile/test_pass_manager.py @@ -5,9 +5,14 @@ import copy import pytest import torch -from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass +from vllm.compilation.inductor_pass import ( + CallableInductorPass, + InductorPass, + pass_context, +) from vllm.compilation.pass_manager import PostGradPassManager from vllm.config import ModelConfig, VllmConfig +from vllm.config.utils import Range # dummy custom pass that doesn't inherit @@ -42,32 +47,37 @@ class ProperPass(InductorPass): ], ) def test_pass_manager_uuid(callable): - # Some passes need dtype to be set - config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16)) + # Set the pass context as PassManager uuid uses it + with pass_context(Range(start=1, end=8)): + # Some passes need dtype to be set + config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16)) - pass_manager = PostGradPassManager() - pass_manager.configure(config) + pass_manager = PostGradPassManager() + pass_manager.configure(config) - # Check that UUID is different if the same pass is added 2x - pass_manager.add(callable) - uuid1 = pass_manager.uuid() - pass_manager.add(callable) - uuid2 = pass_manager.uuid() - assert uuid1 != uuid2 + # Check that UUID is different if the same pass is added 2x + pass_manager.add(callable) + uuid1 = pass_manager.uuid() + pass_manager.add(callable) + uuid2 = pass_manager.uuid() + assert uuid1 != uuid2 - # UUID should be the same as the original one, - # as we constructed in the same way. - pass_manager2 = PostGradPassManager() - pass_manager2.configure(config) - pass_manager2.add(callable) - assert uuid1 == pass_manager2.uuid() + # UUID should be the same as the original one, + # as we constructed in the same way. + pass_manager2 = PostGradPassManager() + pass_manager2.configure(config) + pass_manager2.add(callable) + assert uuid1 == pass_manager2.uuid() - # UUID should be different due to config change - config2 = copy.deepcopy(config) - config2.compilation_config.pass_config.enable_fusion = ( - not config2.compilation_config.pass_config.enable_fusion - ) - pass_manager3 = PostGradPassManager() - pass_manager3.configure(config2) - pass_manager3.add(callable) - assert uuid1 != pass_manager3.uuid() + # UUID should be different due to config change + config2 = copy.deepcopy(config) + config2.compilation_config.pass_config.fuse_norm_quant = ( + not config2.compilation_config.pass_config.fuse_norm_quant + ) + config2.compilation_config.pass_config.fuse_act_quant = ( + not config2.compilation_config.pass_config.fuse_act_quant + ) + pass_manager3 = PostGradPassManager() + pass_manager3.configure(config2) + pass_manager3.add(callable) + assert uuid1 != pass_manager3.uuid() diff --git a/tests/compile/test_qk_norm_rope_fusion.py b/tests/compile/test_qk_norm_rope_fusion.py index 5ebb95b6db332..e0968ac799256 100644 --- a/tests/compile/test_qk_norm_rope_fusion.py +++ b/tests/compile/test_qk_norm_rope_fusion.py @@ -140,7 +140,7 @@ def test_qk_norm_rope_fusion( custom_ops=custom_ops, pass_config=PassConfig( enable_qk_norm_rope_fusion=True, - enable_noop=True, + eliminate_noops=True, ), ), ) diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index 0ddb82b7c3fc2..eb0dee8d4e399 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -7,6 +7,7 @@ import torch import vllm.envs as envs from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor +from vllm._aiter_ops import IS_AITER_FOUND from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.compilation.activation_quant_fusion import ( FUSED_OPS, @@ -24,6 +25,7 @@ from vllm.config import ( set_current_vllm_config, ) from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, kFp8StaticTensorSym, @@ -126,6 +128,39 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module): return [FUSED_OPS[kNvfp4Quant]] +class TestSiluMulGroupFp8QuantModel(torch.nn.Module): + def __init__(self, hidden_size: int, **kwargs): + super().__init__() + self.silu_and_mul = SiluAndMul() + self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( + weight_group_shape=GroupShape(128, 128), + act_quant_group_shape=GroupShape(1, 128), + cutlass_block_fp8_supported=False, + use_aiter_and_is_supported=True, + ) + self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() + + scale_hidden_size = (hidden_size + 128 - 1) // 128 + self.wscale = torch.rand( + (scale_hidden_size, scale_hidden_size), dtype=torch.float32 + ) + + self.enable_silu_mul_custom_op = self.silu_and_mul.enabled() + + def forward(self, x): + y = self.silu_and_mul(x) + x2 = self.w8a8_block_fp8_linear.apply(y, self.w, self.wscale) + return x2 + + def ops_in_model_before(self): + return [ + SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul, + ] + + def ops_in_model_after(self): + return [torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant] + + @pytest.mark.parametrize("num_tokens", [32, 64]) @pytest.mark.parametrize("hidden_size", [128, 256]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @@ -133,7 +168,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module): @pytest.mark.parametrize( "model_class, enable_quant_fp8_custom_op, cuda_force_torch", list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False])) - + [(TestSiluMulNvfp4QuantModel, False, False)], + + [ + (TestSiluMulNvfp4QuantModel, False, False), + (TestSiluMulGroupFp8QuantModel, False, False), + ], ) # cuda_force_torch used to test torch code path on platforms that # cutlass_fp8_supported() == True. @@ -144,13 +182,19 @@ def test_fusion_silu_and_mul_quant( num_tokens: int, hidden_size: int, dtype: torch.dtype, - model_class: type[TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel], + model_class: type[ + TestSiluMulFp8QuantModel + | TestSiluMulNvfp4QuantModel + | TestSiluMulGroupFp8QuantModel + ], enable_silu_mul_custom_op: bool, enable_quant_fp8_custom_op: bool, cuda_force_torch: bool, ): if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported(): pytest.skip("NVFP4 is not supported on this GPU.") + if model_class is TestSiluMulGroupFp8QuantModel and not IS_AITER_FOUND: + pytest.skip("AITER is not supported on this GPU.") torch.set_default_device("cuda") torch.set_default_dtype(dtype) @@ -168,14 +212,20 @@ def test_fusion_silu_and_mul_quant( compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, custom_ops=custom_ops, - pass_config=PassConfig(enable_fusion=True, enable_noop=True), + pass_config=PassConfig(fuse_act_quant=True, eliminate_noops=True), ), ) with set_current_vllm_config(config): - fusion_pass = ActivationQuantFusionPass(config) + fusion_passes = [ActivationQuantFusionPass(config)] + if IS_AITER_FOUND: + from vllm.compilation.rocm_aiter_fusion import ( + RocmAiterSiluMulFp8GroupQuantFusionPass, + ) - passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)] + fusion_passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)] + + passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)] backend = TestBackend(*passes) model = model_class( hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x @@ -194,12 +244,14 @@ def test_fusion_silu_and_mul_quant( atol, rtol = 1e-3, 1e-3 elif model_class == TestSiluMulNvfp4QuantModel: atol, rtol = 1e-1, 1e-1 + elif model_class == TestSiluMulGroupFp8QuantModel: + atol, rtol = 5e-2, 5e-2 torch.testing.assert_close( result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol ) - assert fusion_pass.matched_count == 1 + assert sum([p.matched_count for p in fusion_passes]) == 1 # In pre-nodes, quant op should be present and fused kernels should not backend.check_before_ops(model.ops_in_model_before()) diff --git a/tests/conftest.py b/tests/conftest.py index 317b36ba6cb80..a03f40a9a72ac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,7 +27,7 @@ import threading from collections.abc import Generator from contextlib import nullcontext from enum import Enum -from typing import Any, Callable, TypedDict, TypeVar, cast +from typing import Any, Callable, TypedDict, TypeVar, cast, TYPE_CHECKING import numpy as np import pytest @@ -59,6 +59,7 @@ from vllm.distributed import ( ) from vllm.logger import init_logger from vllm.logprobs import Logprob +from vllm.multimodal.base import MediaWithBytes from vllm.multimodal.utils import fetch_image from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams @@ -66,6 +67,14 @@ from vllm.transformers_utils.utils import maybe_model_redirect from vllm.utils.collection_utils import is_list_of from vllm.utils.torch_utils import set_default_torch_num_threads +from torch._inductor.utils import fresh_cache + + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + from transformers.generation.utils import GenerateOutput + + logger = init_logger(__name__) _TEST_DIR = os.path.dirname(__file__) @@ -193,6 +202,27 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool): cleanup_dist_env_and_memory() +@pytest.fixture +def workspace_init(): + """Initialize the workspace manager for tests that need it. + + This fixture initializes the workspace manager with a CUDA device + if available, and resets it after the test completes. Tests that + create a full vLLM engine should NOT use this fixture as the engine + will initialize the workspace manager itself. + """ + from vllm.v1.worker.workspace import ( + init_workspace_manager, + reset_workspace_manager, + ) + + if torch.cuda.is_available(): + device = torch.device("cuda:0") + init_workspace_manager(device) + yield + reset_workspace_manager() + + @pytest.fixture(autouse=True) def dynamo_reset(): yield @@ -201,10 +231,7 @@ def dynamo_reset(): @pytest.fixture def example_prompts() -> list[str]: - prompts = [] - for filename in _TEST_PROMPTS: - prompts += _read_prompts(filename) - return prompts + return [prompt for filename in _TEST_PROMPTS for prompt in _read_prompts(filename)] @pytest.fixture @@ -223,10 +250,7 @@ class DecoderPromptType(Enum): @pytest.fixture def example_long_prompts() -> list[str]: - prompts = [] - for filename in _LONG_PROMPTS: - prompts += _read_prompts(filename) - return prompts + return [prompt for filename in _LONG_PROMPTS for prompt in _read_prompts(filename)] @pytest.fixture(scope="session") @@ -352,10 +376,13 @@ class HfRunner: trust_remote_code=trust_remote_code, ) else: - model = auto_cls.from_pretrained( - model_name, - trust_remote_code=trust_remote_code, - **model_kwargs, + model = cast( + nn.Module, + auto_cls.from_pretrained( + model_name, + trust_remote_code=trust_remote_code, + **model_kwargs, + ), ) # in case some unquantized custom models are not in same dtype @@ -373,10 +400,12 @@ class HfRunner: self.model = model if not skip_tokenizer_init: - self.tokenizer = AutoTokenizer.from_pretrained( - model_name, - dtype=dtype, - trust_remote_code=trust_remote_code, + self.tokenizer: "PreTrainedTokenizer | PreTrainedTokenizerFast" = ( + AutoTokenizer.from_pretrained( + model_name, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) ) # don't put this import at the top level @@ -397,6 +426,7 @@ class HfRunner: images: PromptImageInput | None = None, videos: PromptVideoInput | None = None, audios: PromptAudioInput | None = None, + tokenization_kwargs: dict[str, Any] | None = None, ) -> list[BatchFeature | BatchEncoding | dict[str, torch.Tensor]]: if images is not None: assert len(prompts) == len(images) @@ -410,10 +440,18 @@ class HfRunner: all_inputs: list[BatchFeature | BatchEncoding | dict[str, torch.Tensor]] = [] for i, prompt in enumerate(prompts): if isinstance(prompt, str): - processor_kwargs: dict[str, Any] = { - "text": prompt, - "return_tensors": "pt", - } + # Create a copy to avoid modifying the original dict + processor_kwargs = ( + tokenization_kwargs.copy() + if tokenization_kwargs is not None + else {} + ) + processor_kwargs.update( + { + "text": prompt, + "return_tensors": "pt", + } + ) if images is not None and (image := images[i]) is not None: processor_kwargs["images"] = image if videos is not None and (video := videos[i]) is not None: @@ -494,7 +532,7 @@ class HfRunner: outputs: list[tuple[list[list[int]], list[str]]] = [] for inputs in all_inputs: - output_ids = self.model.generate( + output_ids: torch.Tensor = self.model.generate( **self.wrap_device(inputs), use_cache=True, **kwargs, @@ -504,8 +542,7 @@ class HfRunner: skip_special_tokens=True, clean_up_tokenization_spaces=False, ) - output_ids = output_ids.cpu().tolist() - outputs.append((output_ids, output_str)) + outputs.append((output_ids.cpu().tolist(), output_str)) return outputs def generate_greedy( @@ -573,7 +610,7 @@ class HfRunner: all_logprobs: list[list[torch.Tensor]] = [] for inputs in all_inputs: - output = self.model.generate( + output: "GenerateOutput" = self.model.generate( **self.wrap_device(inputs), use_cache=True, do_sample=False, @@ -655,7 +692,7 @@ class HfRunner: all_output_strs: list[str] = [] for inputs in all_inputs: - output = self.model.generate( + output: "GenerateOutput" = self.model.generate( **self.wrap_device(inputs), use_cache=True, do_sample=False, @@ -665,10 +702,16 @@ class HfRunner: **kwargs, ) + # Encoder-decoder models return decoder_hidden_states instead of + # hidden_states + hidden_states = ( + getattr(output, "hidden_states", None) or output.decoder_hidden_states + ) + ( seq_logprobs_lst, output_len, - ) = self._hidden_states_to_logprobs(output.hidden_states, num_logprobs) + ) = self._hidden_states_to_logprobs(hidden_states, num_logprobs) all_logprobs.append(seq_logprobs_lst) seq_ids = output.sequences[0] @@ -725,7 +768,7 @@ class VllmRunner: tokenizer_name: str | None = None, tokenizer_mode: str = "auto", trust_remote_code: bool = True, - seed: int | None = 0, + seed: int = 0, max_model_len: int | None = 1024, dtype: str = "auto", disable_log_stats: bool = True, @@ -1174,6 +1217,7 @@ def caplog_mp_spawn(tmp_path, monkeypatch): "level": level, "filename": log_path.as_posix(), } + config["loggers"]["vllm"]["level"] = level config_path.write_text(json.dumps(config)) @@ -1388,7 +1432,11 @@ class LocalAssetServer: return f"{self.base_url}/{name}" def get_image_asset(self, name: str) -> Image.Image: - return fetch_image(self.url_for(name)) + image = fetch_image(self.url_for(name)) + # Unwrap MediaWithBytes if present + if isinstance(image, MediaWithBytes): + image = image.media + return image @pytest.fixture(scope="session") @@ -1456,3 +1504,14 @@ def clean_gpu_memory_between_tests(): if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() + + +@pytest.fixture +def use_fresh_inductor_cache(): + """ + Use a fresh inductor cache for the test. + This is useful to ensure that the test is not affected by the + previous test calls. + """ + with fresh_cache(): + yield diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py index 7e4713b8aece0..aa47f28a34dd5 100644 --- a/tests/distributed/test_context_parallel.py +++ b/tests/distributed/test_context_parallel.py @@ -16,16 +16,35 @@ from typing import Literal, NamedTuple import pytest import torch +from tests.evals.gsm8k.gsm8k_eval import evaluate_gsm8k +from tests.utils import RemoteOpenAIServer, create_new_process_for_each_test from vllm.config.model import RunnerOption from vllm.logger import init_logger from ..models.registry import HF_EXAMPLE_MODELS -from ..utils import compare_two_settings, create_new_process_for_each_test logger = init_logger("test_context_parallel") VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" +CP_TEST_MODELS = [ + # TODO support other models + # [LANGUAGE GENERATION] + "deepseek-ai/DeepSeek-V2-Lite-Chat", + "Qwen/Qwen2.5-1.5B-Instruct", +] + +# GSM8K eval configuration +NUM_QUESTIONS = 256 # Fast eval for CI +NUM_SHOTS = 5 # Few-shot examples +# tp accuracy with 2% buffer +MIN_ACCURACY = { + # .buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml + "deepseek-ai/DeepSeek-V2-Lite-Chat": 0.64, + # .buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml + "Qwen/Qwen2.5-1.5B-Instruct": 0.52, +} + class ParallelSetup(NamedTuple): tp_size: int @@ -38,7 +57,6 @@ class ParallelSetup(NamedTuple): class CPTestOptions(NamedTuple): multi_node_only: bool - load_format: str | None = None attn_backend: str | None = None @@ -54,17 +72,20 @@ class CPTestSettings: *, tp_base: int = 4, pp_base: int = 1, - dcp_base: int = 1, + dcp_multipliers: list[float] | None = None, cp_kv_cache_interleave_size: int = 1, multi_node_only: bool = False, runner: RunnerOption = "auto", - load_format: str | None = None, attn_backend: str | None = None, ): parallel_setups = [] + if dcp_multipliers is None: + dcp_multipliers = [ + 0.5, + ] for eager_mode_val in [False]: for pp_multiplier in [1]: - for dcp_multiplier in [0.5, 1]: + for dcp_multiplier in dcp_multipliers: for chunked_prefill_val in [True]: parallel_setups.append( ParallelSetup( @@ -82,7 +103,6 @@ class CPTestSettings: runner=runner, test_options=CPTestOptions( multi_node_only=multi_node_only, - load_format=load_format, attn_backend=attn_backend, ), ) @@ -101,7 +121,27 @@ class CPTestSettings: ) -def _compare_cp_with_tp( +CP_TEXT_GENERATION_MODELS = { + "deepseek-ai/DeepSeek-V2-Lite-Chat": [ + CPTestSettings.detailed(dcp_multipliers=[1]), + CPTestSettings.detailed( + dcp_multipliers=[0.5], + cp_kv_cache_interleave_size=64, + attn_backend="FLASHMLA", + ), + ], + "Qwen/Qwen2.5-1.5B-Instruct": [ + CPTestSettings.detailed( + cp_kv_cache_interleave_size=16, attn_backend="FLASH_ATTN" + ), + CPTestSettings.detailed( + cp_kv_cache_interleave_size=16, attn_backend="FLASHINFER" + ), + ], +} + + +def _test_cp_gsm8k( model_id: str, parallel_setup: ParallelSetup, distributed_backend: str, @@ -121,7 +161,7 @@ def _compare_cp_with_tp( chunked_prefill, ) = parallel_setup - multi_node_only, load_format, attn_backend = test_options + multi_node_only, attn_backend = test_options model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) model_info.check_transformers_version(on_fail="skip") @@ -130,22 +170,7 @@ def _compare_cp_with_tp( tokenizer_mode = model_info.tokenizer_mode hf_overrides = model_info.hf_overrides - if load_format == "dummy": - # Avoid OOM - text_overrides = { - "num_hidden_layers": 4, - "hidden_size": 512, - "intermediate_size": 800, - "num_attention_heads": 4, - "num_key_value_heads": 1, - } - - if is_multimodal: - hf_overrides.update({"text_config": text_overrides}) - else: - hf_overrides.update(text_overrides) - else: - model_info.check_available_online(on_fail="skip") + model_info.check_available_online(on_fail="skip") if num_gpus_available < tp_size * pp_size: pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") @@ -157,90 +182,70 @@ def _compare_cp_with_tp( if multi_node_only and not VLLM_MULTI_NODE: pytest.skip("Not in multi-node setting") - common_args = [ + server_args = [ # use half precision for speed and memory savings in CI environment "--dtype", "bfloat16", "--max-model-len", - "2048", + "4096", "--max-num-seqs", - "8", + "64", ] if chunked_prefill: - common_args.append("--enable-chunked-prefill") + server_args.append("--enable-chunked-prefill") if eager_mode: - common_args.append("--enforce-eager") + server_args.append("--enforce-eager") if runner != "auto": - common_args.extend(["--runner", runner]) + server_args.extend(["--runner", runner]) if trust_remote_code: - common_args.append("--trust-remote-code") + server_args.append("--trust-remote-code") if tokenizer_mode: - common_args.extend(["--tokenizer-mode", tokenizer_mode]) - if load_format: - common_args.extend(["--load-format", load_format]) + server_args.extend(["--tokenizer-mode", tokenizer_mode]) if hf_overrides: - common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) + server_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) - if not attn_backend: - cp_env = tp_env = {} - else: - cp_env = tp_env = { - "VLLM_ATTENTION_BACKEND": attn_backend, - } - - cp_args = [ - *common_args, - "--tensor-parallel-size", - str(tp_size), - "--pipeline-parallel-size", - str(pp_size), - "--decode-context-parallel-size", - str(dcp_size), - "--dcp-kv-cache-interleave-size", - str(cp_kv_cache_interleave_size), - "--distributed-executor-backend", - distributed_backend, - ] - - tp_args = [ - *common_args, - "--tensor-parallel-size", - str(tp_size), - "--pipeline-parallel-size", - str(pp_size), - "--distributed-executor-backend", - distributed_backend, - ] - - compare_two_settings( - model_id, - cp_args, - tp_args, - cp_env, - tp_env, - method=method, - max_wait_seconds=720, + server_args.extend( + [ + "--tensor-parallel-size", + str(tp_size), + "--pipeline-parallel-size", + str(pp_size), + "--decode-context-parallel-size", + str(dcp_size), + "--dcp-kv-cache-interleave-size", + str(cp_kv_cache_interleave_size), + "--distributed-executor-backend", + distributed_backend, + ] ) + server_env = {} + if attn_backend: + server_env["VLLM_ATTENTION_BACKEND"] = attn_backend -CP_TEXT_GENERATION_MODELS = { - "deepseek-ai/DeepSeek-V2-Lite-Chat": [ - CPTestSettings.detailed(), - CPTestSettings.detailed(tp_base=2), - CPTestSettings.detailed(tp_base=2, cp_kv_cache_interleave_size=64), - ], - "bigcode/gpt_bigcode-santacoder": [ - CPTestSettings.detailed(), - CPTestSettings.detailed(tp_base=2), - ], -} + with RemoteOpenAIServer( + model_id, + server_args, + env_dict=server_env, + max_wait_seconds=720, + ) as remote_server: + host = f"http://{remote_server.host}" + port = remote_server.port -CP_TEST_MODELS = [ - # TODO support other models - # [LANGUAGE GENERATION] - "deepseek-ai/DeepSeek-V2-Lite-Chat", - "bigcode/gpt_bigcode-santacoder", -] + # Run GSM8K evaluation + results = evaluate_gsm8k( + num_questions=NUM_QUESTIONS, + num_shots=NUM_SHOTS, + host=host, + port=port, + ) + + # Validate accuracy is reasonable + accuracy = results["accuracy"] + min_accuracy = MIN_ACCURACY[model_id] + assert accuracy >= min_accuracy, ( + f"TP+DCP accuracy too low: {accuracy:.3f} < {min_accuracy:.3f}" + ) @pytest.mark.parametrize( @@ -274,12 +279,12 @@ def test_cp_generation( ): pytest.skip(reason="MLA+DCP requires compute capability of 9.0 or higher") if ( - model_id == "bigcode/gpt_bigcode-santacoder" + model_id == "Qwen/Qwen2.5-1.5B-Instruct" and torch.cuda.get_device_capability() != (9, 0) ): pytest.skip(reason="GQA+DCP currently requires compute capability of 9.0") - _compare_cp_with_tp( + _test_cp_gsm8k( model_id, parallel_setup, distributed_backend, diff --git a/tests/distributed/test_eplb_algo.py b/tests/distributed/test_eplb_algo.py index 79805a7cce53b..a53a61840e79e 100644 --- a/tests/distributed/test_eplb_algo.py +++ b/tests/distributed/test_eplb_algo.py @@ -4,7 +4,7 @@ import pytest import torch -from vllm.distributed.eplb.rebalance_algo import rebalance_experts +from vllm.distributed.eplb.policy.default import DefaultEplbPolicy def test_basic_rebalance(): @@ -23,7 +23,7 @@ def test_basic_rebalance(): num_nodes = 2 num_gpus = 8 - phy2log, log2phy, logcnt = rebalance_experts( + phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( weight, num_replicas, num_groups, num_nodes, num_gpus ) @@ -77,7 +77,7 @@ def test_single_gpu_case(): num_nodes = 1 num_gpus = 1 - phy2log, log2phy, logcnt = rebalance_experts( + phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( weight, num_replicas, num_groups, num_nodes, num_gpus ) @@ -99,7 +99,7 @@ def test_equal_weights(): num_nodes = 2 num_gpus = 4 - phy2log, log2phy, logcnt = rebalance_experts( + phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( weight, num_replicas, num_groups, num_nodes, num_gpus ) @@ -122,7 +122,7 @@ def test_extreme_weight_imbalance(): num_nodes = 2 num_gpus = 4 - phy2log, log2phy, logcnt = rebalance_experts( + phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( weight, num_replicas, num_groups, num_nodes, num_gpus ) @@ -150,7 +150,7 @@ def test_multiple_layers(): num_nodes = 2 num_gpus = 4 - phy2log, log2phy, logcnt = rebalance_experts( + phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( weight, num_replicas, num_groups, num_nodes, num_gpus ) @@ -175,14 +175,14 @@ def test_parameter_validation(): # Test non-divisible case - this should handle normally without throwing # errors because the function will fall back to global load balancing # strategy - phy2log, log2phy, logcnt = rebalance_experts(weight, 8, 3, 2, 4) + phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(weight, 8, 3, 2, 4) assert phy2log.shape == (1, 8) assert logcnt.shape == (1, 4) # Test cases that will actually cause errors: # num_physical_experts not divisible by num_gpus with pytest.raises(AssertionError): - rebalance_experts(weight, 7, 2, 2, 4) # 7 not divisible by 4 + DefaultEplbPolicy.rebalance_experts(weight, 7, 2, 2, 4) # 7 not divisible by 4 def test_small_scale_hierarchical(): @@ -197,7 +197,7 @@ def test_small_scale_hierarchical(): num_nodes = 2 # 2 nodes num_gpus = 4 # 4 GPUs - phy2log, log2phy, logcnt = rebalance_experts( + phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( weight, num_replicas, num_groups, num_nodes, num_gpus ) @@ -224,7 +224,7 @@ def test_global_load_balance_fallback(): num_nodes = 2 num_gpus = 4 - phy2log, log2phy, logcnt = rebalance_experts( + phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( weight, num_replicas, num_groups, num_nodes, num_gpus ) @@ -246,7 +246,7 @@ def test_device_compatibility(device): num_nodes = 1 num_gpus = 2 - phy2log, log2phy, logcnt = rebalance_experts( + phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( weight, num_replicas, num_groups, num_nodes, num_gpus ) @@ -263,7 +263,9 @@ def test_additional_cases(): weight1 = torch.tensor( [[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]] ) - phy2log1, log2phy1, logcnt1 = rebalance_experts(weight1, 24, 8, 4, 8) + phy2log1, log2phy1, logcnt1 = DefaultEplbPolicy.rebalance_experts( + weight1, 24, 8, 4, 8 + ) assert phy2log1.shape == (1, 24) assert logcnt1.shape == (1, 16) @@ -276,7 +278,9 @@ def test_additional_cases(): [12, 25, 50, 100, 150, 200], # Increasing weights ] ) - phy2log2, log2phy2, logcnt2 = rebalance_experts(weight2, 10, 3, 1, 2) + phy2log2, log2phy2, logcnt2 = DefaultEplbPolicy.rebalance_experts( + weight2, 10, 3, 1, 2 + ) assert phy2log2.shape == (2, 10) assert logcnt2.shape == (2, 6) @@ -300,7 +304,7 @@ if __name__ == "__main__": num_nodes = 2 num_gpus = 8 - phy2log, log2phy, logcnt = rebalance_experts( + phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts( weight, num_replicas, num_groups, num_nodes, num_gpus ) print(phy2log) diff --git a/tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py b/tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py new file mode 100644 index 0000000000000..951b692e1edaf --- /dev/null +++ b/tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py @@ -0,0 +1,276 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Test that the interaction between EPLB and FusedMoE Layer is okay for DP w/ NVFP4 + +from dataclasses import dataclass + +import pytest +import torch + +from tests.kernels.moe.utils import make_test_quant_config +from vllm.config import VllmConfig, set_current_vllm_config +from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace +from vllm.distributed.parallel_state import ( + ensure_model_parallel_initialized, + get_dp_group, +) +from vllm.forward_context import set_forward_context +from vllm.model_executor.layers.fused_moe.layer import FusedMoE +from vllm.model_executor.layers.quantization.modelopt import ( + ModelOptNvFp4Config, + ModelOptNvFp4FusedMoE, +) + +from .eplb_utils import distributed_run, set_env_vars_and_device + + +@dataclass +class TestConfig: + num_layers: int + num_experts: int + num_local_experts: int + num_topk: int + hidden_size: int + intermediate_size: int + num_tokens: int + + +def make_fused_moe_layer( + rank: int, + layer_idx: int, + test_config: TestConfig, +) -> FusedMoE: + quant_config = None + + device = torch.device(f"cuda:{rank}") + + quant_config = ModelOptNvFp4Config( + is_checkpoint_nvfp4_serialized=True, + kv_cache_quant_algo=None, + exclude_modules=[], + ) + + fml = FusedMoE( + num_experts=test_config.num_experts, + top_k=test_config.num_topk, + hidden_size=test_config.hidden_size, + intermediate_size=test_config.intermediate_size, + prefix=f"dummy_layer_{layer_idx}", + activation="silu", + is_act_and_mul=True, + params_dtype=torch.bfloat16, + quant_config=quant_config, + ) + + nvfp4_fused_moe = ModelOptNvFp4FusedMoE(quant_config, fml) + nvfp4_fused_moe.create_weights( + fml, + test_config.num_local_experts, + test_config.hidden_size, + test_config.intermediate_size, + params_dtype=torch.uint8, + global_num_experts=test_config.num_experts, + ) + + fml = fml.to(device) + w1_q, w2_q, quant_config = make_test_quant_config( + test_config.num_local_experts, + test_config.intermediate_size, + test_config.hidden_size, + in_dtype=torch.bfloat16, + quant_dtype="nvfp4", + block_shape=None, + per_act_token_quant=False, + ) + + fml.w13_weight.data = w1_q + fml.w2_weight.data = w2_q + + fml.w2_input_scale.data = torch.randn_like(fml.w2_input_scale.data) / 5 + fml.w13_input_scale.data = torch.randn_like(fml.w13_input_scale.data) / 5 + fml.w2_weight_scale_2.data = torch.randn_like(fml.w2_weight_scale_2.data) / 5 + fml.w13_weight_scale_2.data = torch.randn_like(fml.w13_weight_scale_2.data) / 5 + fml.w2_weight_scale.data = ( + torch.randn(fml.w2_weight_scale.data.shape, device=device) / 5 + ).to(fml.w2_weight_scale.data.dtype) + fml.w13_weight_scale.data = ( + torch.randn(fml.w13_weight_scale.data.shape, device=device) / 5 + ).to(fml.w13_weight_scale.data.dtype) + + nvfp4_fused_moe.process_weights_after_loading(fml) + + fml.maybe_init_modular_kernel() + + return fml + + +def _test_eplb_fml(env, world_size: int, test_config: TestConfig): + set_env_vars_and_device(env) + + vllm_config = VllmConfig() + vllm_config.parallel_config.data_parallel_size = world_size + vllm_config.parallel_config.enable_expert_parallel = True + + with set_current_vllm_config(vllm_config): + ensure_model_parallel_initialized( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + + ep_group = get_dp_group().cpu_group + ep_rank = torch.distributed.get_rank() + + device = torch.device(f"cuda:{ep_rank}") + + fml_layers = [ + make_fused_moe_layer(ep_rank, layer_idx, test_config).to(device) + for layer_idx in range(test_config.num_layers) + ] + rank_expert_weights = [fml.get_expert_weights() for fml in fml_layers] + + hidden_states = [] + router_logits = [] + for layer_idx in range(test_config.num_layers): + hidden_states.append( + torch.randn( + (test_config.num_tokens, test_config.hidden_size), + dtype=torch.bfloat16, + device=device, + ) + ) + router_logits.append( + torch.randn( + (test_config.num_tokens, test_config.num_experts), + dtype=torch.bfloat16, + device=device, + ) + ) + + out_before_shuffle = [] + with set_forward_context( + {}, + num_tokens=test_config.num_tokens, + num_tokens_across_dp=torch.tensor( + [test_config.num_tokens] * world_size, device="cpu", dtype=torch.int + ), + vllm_config=vllm_config, + ): + for lidx, fml in enumerate(fml_layers): + out_before_shuffle.append( + fml(hidden_states[lidx].clone(), router_logits[lidx].clone()) + ) + + indices = torch.zeros( + test_config.num_layers, test_config.num_experts, dtype=torch.long + ) + for lidx in range(test_config.num_layers): + indices[lidx] = torch.Tensor(range(test_config.num_experts)) + + shuffled_indices = torch.zeros_like(indices) + for lidx in range(test_config.num_layers): + shuffled_indices[lidx] = torch.randperm(test_config.num_experts) + + rearrange_expert_weights_inplace( + indices, + shuffled_indices, + rank_expert_weights, + ep_group, + is_profile=False, + ) + + num_global_experts = test_config.num_experts + + logical_to_physical_map_list = [] + for lidx, fml in enumerate(fml_layers): + physical_to_logical_map = shuffled_indices[lidx].to(device) + logical_to_physical_map = torch.empty( + (num_global_experts,), dtype=torch.int32, device=device + ) + logical_to_physical_map[physical_to_logical_map] = torch.arange( + 0, num_global_experts, dtype=torch.int32, device=device + ) + logical_to_physical_map_list.append( + logical_to_physical_map.reshape(num_global_experts, 1) + ) + + logical_to_physical_map = torch.stack(logical_to_physical_map_list) + + for lidx, fml in enumerate(fml_layers): + logical_replica_count = torch.ones( + (test_config.num_layers, num_global_experts), + dtype=torch.int32, + device=device, + ) + fml.enable_eplb = True + fml.set_eplb_state( + lidx, + torch.zeros( + (test_config.num_layers, num_global_experts), + dtype=torch.int32, + device=device, + ), + logical_to_physical_map, + logical_replica_count, + ) + + out_after_shuffle = [] + with set_forward_context( + {}, + num_tokens=test_config.num_tokens, + num_tokens_across_dp=torch.tensor( + [test_config.num_tokens] * world_size, device="cpu", dtype=torch.int + ), + vllm_config=vllm_config, + ): + for lidx, fml in enumerate(fml_layers): + out_after_shuffle.append( + fml(hidden_states[lidx].clone(), router_logits[lidx].clone()) + ) + + for lidx in range(test_config.num_layers): + torch.testing.assert_close( + out_before_shuffle[lidx], out_after_shuffle[lidx], atol=1e-1, rtol=1e-1 + ) + + +@pytest.mark.parametrize("world_size", [2, 4]) +@pytest.mark.parametrize("num_layers", [8]) +@pytest.mark.parametrize("num_experts", [32]) +@pytest.mark.parametrize("hidden_size", [256]) +@pytest.mark.parametrize("intermediate_size", [256]) +@pytest.mark.parametrize("num_tokens", [256]) +@pytest.mark.parametrize("backend", ["latency", "throughput"]) +def test_eplb_fml( + world_size: int, + num_layers: int, + num_experts: int, + hidden_size: int, + intermediate_size: int, + num_tokens: int, + backend: str, + monkeypatch, +): + monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1") + monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", backend) + + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need at least {world_size} GPUs to run the test") + + num_local_experts = num_experts // world_size + num_topk = 4 + + test_config = TestConfig( + num_layers=num_layers, + num_experts=num_experts, + num_local_experts=num_local_experts, + num_topk=num_topk, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_tokens=num_tokens, + ) + + distributed_run( + _test_eplb_fml, + world_size, + test_config, + ) diff --git a/tests/distributed/test_eplb_spec_decode.py b/tests/distributed/test_eplb_spec_decode.py index c055b7a3f6dd7..22977ce94404b 100644 --- a/tests/distributed/test_eplb_spec_decode.py +++ b/tests/distributed/test_eplb_spec_decode.py @@ -6,6 +6,7 @@ import lm_eval import pytest from tests.utils import large_gpu_mark +from vllm.platforms import current_platform def get_model_args( @@ -22,7 +23,14 @@ def get_model_args( "num_speculative_tokens": 1, "max_model_len": model_max_len, } - + eplb_config = { + "num_redundant_experts": tp_size, + "window_size": 128, + "step_interval": 1024, + "log_balancedness": False, + } + if use_async: + eplb_config["use_async"] = True model_args = { "pretrained": model_name, "dtype": "auto", @@ -31,18 +39,19 @@ def get_model_args( "gpu_memory_utilization": 0.7, "speculative_config": speculative_config, "enable_expert_parallel": True, - "num_redundant_experts": tp_size, - "eplb_window_size": 128, - "eplb_step_interval": 1024, - "eplb_log_balancedness": False, + "eplb_config": eplb_config, "enable_eplb": True, "max_model_len": model_max_len, } - if use_async: - model_args["eplb_config"] = {"use_async": True} return model_args +pytestmark = pytest.mark.skipif( + current_platform.is_rocm(), + reason="EPLB with Spec Decode is a work in progress on ROCm.", +) + + @pytest.mark.parametrize( "model_setup", [ diff --git a/tests/distributed/test_kvlayout.py b/tests/distributed/test_kvlayout.py index b190b2820451b..c8177f1c7c2ff 100644 --- a/tests/distributed/test_kvlayout.py +++ b/tests/distributed/test_kvlayout.py @@ -61,7 +61,7 @@ def test_get_kv_connector_cache_layout_with_multi_connector(): kv_role="kv_both", kv_connector_extra_config={ "connectors": [ - {"kv_connector": "SharedStorageConnector", "kv_role": "kv_both"}, + {"kv_connector": "ExampleConnector", "kv_role": "kv_both"}, {"kv_connector": "NixlConnector", "kv_role": "kv_both"}, ] }, diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 89f035d2cdd6f..cc6251514c3dc 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -109,7 +109,7 @@ TEXT_GENERATION_MODELS = { "baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(), "bigscience/bloomz-1b1": PPTestSettings.fast(), "zai-org/chatglm3-6b": PPTestSettings.fast(), - "CohereForAI/c4ai-command-r-v01": PPTestSettings.fast(load_format="dummy"), + "CohereLabs/c4ai-command-r-v01": PPTestSettings.fast(load_format="dummy"), "databricks/dbrx-instruct": PPTestSettings.fast(load_format="dummy"), "Deci/DeciLM-7B-instruct": PPTestSettings.fast(), "deepseek-ai/deepseek-llm-7b-chat": PPTestSettings.fast(), diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index f38c509775ed5..0a7907aadeab5 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -32,7 +32,8 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" class ParallelSetup(NamedTuple): tp_size: int pp_size: int - enable_fusion: bool + fuse_norm_quant: bool + fuse_act_quant: bool eager_mode: bool chunked_prefill: bool @@ -66,7 +67,8 @@ class SPTestSettings: ParallelSetup( tp_size=tp_base, pp_size=pp_multiplier * pp_base, - enable_fusion=False, + fuse_norm_quant=False, + fuse_act_quant=False, eager_mode=eager_mode_val, chunked_prefill=chunked_prefill_val, ) @@ -97,7 +99,8 @@ class SPTestSettings: ParallelSetup( tp_size=tp_base, pp_size=pp_multiplier * pp_base, - enable_fusion=False, + fuse_norm_quant=False, + fuse_act_quant=False, eager_mode=eager_mode_val, chunked_prefill=chunked_prefill_val, ) @@ -126,7 +129,8 @@ class SPTestSettings: ParallelSetup( tp_size=tp_base, pp_size=pp_base, - enable_fusion=fusion_val, + fuse_norm_quant=fusion_val, + fuse_act_quant=fusion_val, eager_mode=True, chunked_prefill=False, ) @@ -162,7 +166,7 @@ def _compare_sp( test_options: SPTestOptions, num_gpus_available: int, use_inductor_graph_partition: bool, - enable_async_tp: bool, + fuse_gemm_comms: bool, *, method: Literal["generate", "encode"], is_multimodal: bool, @@ -170,7 +174,8 @@ def _compare_sp( ( tp_size, pp_size, - enable_fusion, + fuse_norm_quant, + fuse_act_quant, eager_mode, chunked_prefill, ) = parallel_setup @@ -248,10 +253,11 @@ def _compare_sp( "mode": CompilationMode.VLLM_COMPILE, "compile_sizes": [4, 8], "pass_config": { - "enable_sequence_parallelism": True, - "enable_async_tp": enable_async_tp, - "enable_fusion": enable_fusion, - "enable_noop": True, + "enable_sp": True, + "fuse_gemm_comms": fuse_gemm_comms, + "fuse_norm_quant": fuse_norm_quant, + "fuse_act_quant": fuse_act_quant, + "eliminate_noops": True, }, "use_inductor_graph_partition": use_inductor_graph_partition, } @@ -309,7 +315,7 @@ SP_TEST_MODELS = [ ], ) @pytest.mark.parametrize("use_inductor_graph_partition", [True, False]) -@pytest.mark.parametrize("enable_async_tp", [False]) # TODO: enable async TP +@pytest.mark.parametrize("fuse_gemm_comms", [False]) # TODO: enable async TP @create_new_process_for_each_test() def test_tp_sp_generation( model_id: str, @@ -319,7 +325,7 @@ def test_tp_sp_generation( test_options: SPTestOptions, num_gpus_available, use_inductor_graph_partition: bool, - enable_async_tp: bool, + fuse_gemm_comms: bool, ): if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): pytest.skip("inductor graph partition is only available in PyTorch 2.9+") @@ -328,7 +334,7 @@ def test_tp_sp_generation( if ( "fp8" in model_id.lower() and current_platform.get_device_capability() < (9, 0) - and (not enable_async_tp) + and (not fuse_gemm_comms) ): pytest.skip("FP8 reduction support begins with sm90 capable devices.") @@ -340,7 +346,7 @@ def test_tp_sp_generation( test_options, num_gpus_available, use_inductor_graph_partition, - enable_async_tp=enable_async_tp, + fuse_gemm_comms=fuse_gemm_comms, method="generate", is_multimodal=False, ) diff --git a/tests/distributed/test_shm_storage.py b/tests/distributed/test_shm_storage.py index b9a5c22447fd8..9ab35a292f872 100644 --- a/tests/distributed/test_shm_storage.py +++ b/tests/distributed/test_shm_storage.py @@ -28,7 +28,7 @@ def _dummy_elem(modality: str, key: str, size: int): modality=modality, key=key, data=torch.empty((size,), dtype=torch.int8), - field=MultiModalSharedField(1), + field=MultiModalSharedField(batch_size=1), ) diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index e46f118f8e846..c2cf77ffa12b6 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -350,21 +350,35 @@ def test_human_readable_model_len(): assert args.max_model_len == 1_000_000 args = parser.parse_args(["--max-model-len", "10k"]) assert args.max_model_len == 10_000 + args = parser.parse_args(["--max-model-len", "2g"]) + assert args.max_model_len == 2_000_000_000 + args = parser.parse_args(["--max-model-len", "2t"]) + assert args.max_model_len == 2_000_000_000_000 # Capital args = parser.parse_args(["--max-model-len", "3K"]) - assert args.max_model_len == 1024 * 3 + assert args.max_model_len == 2**10 * 3 args = parser.parse_args(["--max-model-len", "10M"]) assert args.max_model_len == 2**20 * 10 + args = parser.parse_args(["--max-model-len", "4G"]) + assert args.max_model_len == 2**30 * 4 + args = parser.parse_args(["--max-model-len", "4T"]) + assert args.max_model_len == 2**40 * 4 # Decimal values args = parser.parse_args(["--max-model-len", "10.2k"]) assert args.max_model_len == 10200 # ..truncated to the nearest int - args = parser.parse_args(["--max-model-len", "10.212345k"]) + args = parser.parse_args(["--max-model-len", "10.2123451234567k"]) assert args.max_model_len == 10212 + args = parser.parse_args(["--max-model-len", "10.2123451234567m"]) + assert args.max_model_len == 10212345 + args = parser.parse_args(["--max-model-len", "10.2123451234567g"]) + assert args.max_model_len == 10212345123 + args = parser.parse_args(["--max-model-len", "10.2123451234567t"]) + assert args.max_model_len == 10212345123456 # Invalid (do not allow decimals with binary multipliers) - for invalid in ["1a", "pwd", "10.24", "1.23M"]: + for invalid in ["1a", "pwd", "10.24", "1.23M", "1.22T"]: with pytest.raises(ArgumentError): - args = parser.parse_args(["--max-model-len", invalid]) + parser.parse_args(["--max-model-len", invalid]) diff --git a/tests/model_executor/model_loader/runai_model_streamer/__init__.py b/tests/entrypoints/openai/parser/__init__.py similarity index 100% rename from tests/model_executor/model_loader/runai_model_streamer/__init__.py rename to tests/entrypoints/openai/parser/__init__.py diff --git a/tests/entrypoints/openai/parser/test_harmony_utils.py b/tests/entrypoints/openai/parser/test_harmony_utils.py new file mode 100644 index 0000000000000..1d34fc51ad563 --- /dev/null +++ b/tests/entrypoints/openai/parser/test_harmony_utils.py @@ -0,0 +1,1201 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +from openai.types.responses import ResponseFunctionToolCall, ResponseReasoningItem +from openai.types.responses.response_output_item import McpCall +from openai_harmony import Author, Message, Role, TextContent + +from tests.entrypoints.openai.utils import verify_harmony_messages +from vllm.entrypoints.openai.parser.harmony_utils import ( + auto_drop_analysis_messages, + get_encoding, + has_custom_tools, + parse_chat_input_to_harmony_message, + parse_chat_output, + parse_input_to_harmony_message, + parse_output_message, +) + + +class TestCommonParseInputToHarmonyMessage: + """ + Tests for scenarios that are common to both Chat Completion + parse_chat_input_to_harmony_message and Responsees API + parse_input_to_harmony_message functions. + """ + + @pytest.fixture( + params=[parse_chat_input_to_harmony_message, parse_input_to_harmony_message] + ) + def parse_function(self, request): + return request.param + + def test_assistant_message_with_tool_calls(self, parse_function): + """Test parsing assistant message with tool calls.""" + chat_msg = { + "role": "assistant", + "tool_calls": [ + { + "function": { + "name": "get_weather", + "arguments": '{"location": "San Francisco"}', + } + }, + { + "function": { + "name": "search_web", + "arguments": '{"query": "latest news"}', + } + }, + ], + } + + messages = parse_function(chat_msg) + + assert len(messages) == 2 + + # First tool call + assert messages[0].author.role == Role.ASSISTANT + assert messages[0].content[0].text == '{"location": "San Francisco"}' + assert messages[0].channel == "commentary" + assert messages[0].recipient == "functions.get_weather" + assert messages[0].content_type == "json" + + # Second tool call + assert messages[1].author.role == Role.ASSISTANT + assert messages[1].content[0].text == '{"query": "latest news"}' + assert messages[1].channel == "commentary" + assert messages[1].recipient == "functions.search_web" + assert messages[1].content_type == "json" + + def test_assistant_message_with_empty_tool_call_arguments(self, parse_function): + """Test parsing assistant message with tool call having None arguments.""" + chat_msg = { + "role": "assistant", + "tool_calls": [ + { + "function": { + "name": "get_current_time", + "arguments": None, + } + } + ], + } + + messages = parse_function(chat_msg) + + assert len(messages) == 1 + assert messages[0].content[0].text == "" + assert messages[0].recipient == "functions.get_current_time" + + def test_system_message(self, parse_function): + """Test parsing system message.""" + chat_msg = { + "role": "system", + "content": "You are a helpful assistant", + } + + messages = parse_function(chat_msg) + + assert len(messages) == 1 + # System messages are converted using Message.from_dict + # which should preserve the role + assert messages[0].author.role == Role.SYSTEM + + def test_developer_message(self, parse_function): + """Test parsing developer message.""" + chat_msg = { + "role": "developer", + "content": "Use concise language", + } + + messages = parse_function(chat_msg) + + assert len(messages) == 1 + assert messages[0].author.role == Role.DEVELOPER + + def test_user_message_with_string_content(self, parse_function): + """Test parsing user message with string content.""" + chat_msg = { + "role": "user", + "content": "What's the weather in San Francisco?", + } + + messages = parse_function(chat_msg) + + assert len(messages) == 1 + assert messages[0].author.role == Role.USER + assert messages[0].content[0].text == "What's the weather in San Francisco?" + + def test_user_message_with_array_content(self, parse_function): + """Test parsing user message with array content.""" + chat_msg = { + "role": "user", + "content": [ + {"text": "What's in this image? "}, + {"text": "Please describe it."}, + ], + } + + messages = parse_function(chat_msg) + + assert len(messages) == 1 + assert messages[0].author.role == Role.USER + assert len(messages[0].content) == 2 + assert messages[0].content[0].text == "What's in this image? " + assert messages[0].content[1].text == "Please describe it." + + def test_assistant_message_with_string_content(self, parse_function): + """Test parsing assistant message with string content (no tool calls).""" + chat_msg = { + "role": "assistant", + "content": "Hello! How can I help you today?", + } + + messages = parse_function(chat_msg) + + assert len(messages) == 1 + assert messages[0].author.role == Role.ASSISTANT + assert messages[0].content[0].text == "Hello! How can I help you today?" + + def test_pydantic_model_input(self, parse_function): + """Test parsing Pydantic model input (has model_dump method).""" + + class MockPydanticModel: + def model_dump(self, exclude_none=True): + return { + "role": "user", + "content": "Test message", + } + + chat_msg = MockPydanticModel() + messages = parse_function(chat_msg) + + assert len(messages) == 1 + assert messages[0].author.role == Role.USER + assert messages[0].content[0].text == "Test message" + + def test_tool_call_with_missing_function_fields(self, parse_function): + """Test parsing tool call with missing name or arguments.""" + chat_msg = { + "role": "assistant", + "tool_calls": [ + { + "function": {} # Missing both name and arguments + } + ], + } + + messages = parse_function(chat_msg) + + assert len(messages) == 1 + assert messages[0].recipient == "functions." + assert messages[0].content[0].text == "" + + def test_array_content_with_missing_text(self, parse_function): + """Test parsing array content where text field is missing.""" + chat_msg = { + "role": "user", + "content": [ + {}, # Missing text field + {"text": "actual text"}, + ], + } + + messages = parse_function(chat_msg) + + assert len(messages) == 1 + assert len(messages[0].content) == 2 + assert messages[0].content[0].text == "" + assert messages[0].content[1].text == "actual text" + + +class TestParseInputToHarmonyMessage: + """ + Tests for scenarios that are specific to the Responses API + parse_input_to_harmony_message function. + """ + + def test_message_with_empty_content(self): + """Test parsing message with empty string content.""" + chat_msg = { + "role": "user", + "content": "", + } + + messages = parse_input_to_harmony_message(chat_msg) + + assert len(messages) == 1 + assert messages[0].content[0].text == "" + + def test_tool_message_with_string_content(self): + """Test parsing tool message with string content.""" + chat_msg = { + "role": "tool", + "name": "get_weather", + "content": "The weather in San Francisco is sunny, 72°F", + } + + messages = parse_input_to_harmony_message(chat_msg) + + assert len(messages) == 1 + assert messages[0].author.role == Role.TOOL + assert messages[0].author.name == "functions.get_weather" + assert ( + messages[0].content[0].text == "The weather in San Francisco is sunny, 72°F" + ) + assert messages[0].channel == "commentary" + + def test_tool_message_with_array_content(self): + """Test parsing tool message with array content.""" + chat_msg = { + "role": "tool", + "name": "search_results", + "content": [ + {"type": "text", "text": "Result 1: "}, + {"type": "text", "text": "Result 2: "}, + { + "type": "image", + "url": "http://example.com/img.png", + }, # Should be ignored + {"type": "text", "text": "Result 3"}, + ], + } + + messages = parse_input_to_harmony_message(chat_msg) + + assert len(messages) == 1 + assert messages[0].author.role == Role.TOOL + assert messages[0].author.name == "functions.search_results" + assert messages[0].content[0].text == "Result 1: Result 2: Result 3" + + def test_tool_message_with_empty_content(self): + """Test parsing tool message with None content.""" + chat_msg = { + "role": "tool", + "name": "empty_tool", + "content": None, + } + + messages = parse_input_to_harmony_message(chat_msg) + + assert len(messages) == 1 + assert messages[0].author.role == Role.TOOL + assert messages[0].author.name == "functions.empty_tool" + assert messages[0].content[0].text == "" + + +class TestParseChatInputToHarmonyMessage: + """ + Tests for scenarios that are specific to the Chat Completion API + parse_chat_input_to_harmony_message function. + """ + + def test_user_message_with_empty_content(self): + chat_msg = { + "role": "user", + "content": "", + } + + messages = parse_chat_input_to_harmony_message(chat_msg) + + verify_harmony_messages( + messages, + [ + { + "role": "user", + "content": "", + }, + ], + ) + + def test_user_message_with_none_content(self): + chat_msg = { + "role": "user", + "content": None, + } + + messages = parse_chat_input_to_harmony_message(chat_msg) + + verify_harmony_messages( + messages, + [ + { + "role": "user", + "content": "", + }, + ], + ) + + def test_assistant_message_with_empty_content(self): + chat_msg = { + "role": "assistant", + "content": "", + } + + messages = parse_chat_input_to_harmony_message(chat_msg) + + assert len(messages) == 0 + + def test_assistant_message_with_none_content(self): + chat_msg = { + "role": "assistant", + "content": None, + } + + messages = parse_chat_input_to_harmony_message(chat_msg) + + assert len(messages) == 0 + + def test_assistant_message_with_content_but_empty_reasoning(self): + chat_msg = { + "role": "assistant", + "content": "The answer is 4.", + "reasoning": "", + } + + messages = parse_chat_input_to_harmony_message(chat_msg) + + verify_harmony_messages( + messages, + [ + { + "role": "assistant", + "channel": "final", + "content": "The answer is 4.", + }, + ], + ) + + def test_assistant_message_with_reasoning_but_empty_content(self): + chat_msg = { + "role": "assistant", + "reasoning": "I'm thinking about the user's question.", + "content": "", + } + + messages = parse_chat_input_to_harmony_message(chat_msg) + + verify_harmony_messages( + messages, + [ + { + "role": "assistant", + "channel": "analysis", + "content": "I'm thinking about the user's question.", + }, + ], + ) + + def test_assistant_message_with_reasoning_but_none_content(self): + chat_msg = { + "role": "assistant", + "reasoning": "I'm thinking about the user's question.", + "content": None, + } + + messages = parse_chat_input_to_harmony_message(chat_msg) + + verify_harmony_messages( + messages, + [ + { + "role": "assistant", + "channel": "analysis", + "content": "I'm thinking about the user's question.", + }, + ], + ) + + def test_assistant_message_with_tool_calls_but_no_content(self): + chat_msg = { + "role": "assistant", + "tool_calls": [ + { + "function": { + "name": "get_weather", + "arguments": '{"location": "San Francisco"}', + } + } + ], + } + + messages = parse_chat_input_to_harmony_message(chat_msg) + + verify_harmony_messages( + messages, + [ + { + "role": "assistant", + "channel": "commentary", + "recipient": "functions.get_weather", + "content": '{"location": "San Francisco"}', + "content_type": "json", + }, + ], + ) + + def test_assistant_message_with_tool_calls_and_content(self): + chat_msg = { + "role": "assistant", + "tool_calls": [ + { + "function": { + "name": "get_weather", + "arguments": '{"location": "San Francisco"}', + } + } + ], + "content": "I'll call the tool.", + } + + messages = parse_chat_input_to_harmony_message(chat_msg) + + verify_harmony_messages( + messages, + [ + { + "role": "assistant", + "channel": "commentary", + "content": "I'll call the tool.", + }, + { + "role": "assistant", + "channel": "commentary", + "recipient": "functions.get_weather", + "content": '{"location": "San Francisco"}', + "content_type": "json", + }, + ], + ) + + def test_assistant_message_with_tool_calls_and_reasoning(self): + chat_msg = { + "role": "assistant", + "tool_calls": [ + { + "function": { + "name": "get_weather", + "arguments": '{"location": "San Francisco"}', + } + } + ], + "reasoning": "I should use the get_weather tool.", + } + + messages = parse_chat_input_to_harmony_message(chat_msg) + + verify_harmony_messages( + messages, + [ + { + "role": "assistant", + "channel": "analysis", + "content": "I should use the get_weather tool.", + }, + { + "role": "assistant", + "channel": "commentary", + "recipient": "functions.get_weather", + "content": '{"location": "San Francisco"}', + "content_type": "json", + }, + ], + ) + + def test_assistant_message_with_tool_calls_and_reasoning_and_content(self): + chat_msg = { + "role": "assistant", + "tool_calls": [ + { + "function": { + "name": "get_weather", + "arguments": '{"location": "San Francisco"}', + } + } + ], + "reasoning": "I should use the get_weather tool.", + "content": "I'll call the tool.", + } + + messages = parse_chat_input_to_harmony_message(chat_msg) + + verify_harmony_messages( + messages, + [ + { + "role": "assistant", + "channel": "commentary", + "content": "I'll call the tool.", + }, + { + "role": "assistant", + "channel": "analysis", + "content": "I should use the get_weather tool.", + }, + { + "role": "assistant", + "channel": "commentary", + "recipient": "functions.get_weather", + "content": '{"location": "San Francisco"}', + "content_type": "json", + }, + ], + ) + + def test_tool_message_with_string_content(self): + tool_id_names = { + "call_123": "get_weather", + } + chat_msg = { + "role": "tool", + "tool_call_id": "call_123", + "content": "The weather in San Francisco is sunny, 72°F", + } + + messages = parse_chat_input_to_harmony_message( + chat_msg, tool_id_names=tool_id_names + ) + + verify_harmony_messages( + messages, + [ + { + "role": "tool", + "name": "functions.get_weather", + "content": "The weather in San Francisco is sunny, 72°F", + "channel": "commentary", + }, + ], + ) + + def test_tool_message_with_array_content(self): + tool_id_names = { + "call_123": "search_results", + } + chat_msg = { + "role": "tool", + "tool_call_id": "call_123", + "content": [ + {"type": "text", "text": "Result 1: "}, + {"type": "text", "text": "Result 2: "}, + { + "type": "image", + "url": "http://example.com/img.png", + }, # Should be ignored + {"type": "text", "text": "Result 3"}, + ], + } + + messages = parse_chat_input_to_harmony_message( + chat_msg, tool_id_names=tool_id_names + ) + + verify_harmony_messages( + messages, + [ + { + "role": "tool", + "name": "functions.search_results", + "content": "Result 1: Result 2: Result 3", + "channel": "commentary", + }, + ], + ) + + def test_tool_message_with_empty_content(self): + tool_id_names = { + "call_123": "empty_tool", + } + chat_msg = { + "role": "tool", + "tool_call_id": "call_123", + "content": "", + } + + messages = parse_chat_input_to_harmony_message( + chat_msg, tool_id_names=tool_id_names + ) + + verify_harmony_messages( + messages, + [ + { + "role": "tool", + "name": "functions.empty_tool", + "content": "", + "channel": "commentary", + }, + ], + ) + + def test_tool_message_with_none_content(self): + tool_id_names = { + "call_123": "empty_tool", + } + chat_msg = { + "role": "tool", + "tool_call_id": "call_123", + "content": None, + } + + messages = parse_chat_input_to_harmony_message( + chat_msg, tool_id_names=tool_id_names + ) + + verify_harmony_messages( + messages, + [ + { + "role": "tool", + "name": "functions.empty_tool", + "content": "", + "channel": "commentary", + }, + ], + ) + + +class TestAutoDropAnalysisMessages: + def test_no_analysis_messages(self) -> None: + messages = [ + Message.from_role_and_content( + Role.ASSISTANT, "The answer is 4." + ).with_channel("final"), + ] + cleaned_messages = auto_drop_analysis_messages(messages) + assert cleaned_messages == messages + + def test_only_analysis_message(self) -> None: + messages = [ + Message.from_role_and_content( + Role.ASSISTANT, "I'm thinking about the user's question." + ).with_channel("analysis"), + ] + cleaned_messages = auto_drop_analysis_messages(messages) + assert cleaned_messages == messages + + def test_multiple_analysis_messages_without_final_message(self) -> None: + messages = [ + Message.from_role_and_content( + Role.ASSISTANT, "I'm thinking about the user's question." + ).with_channel("analysis"), + Message.from_role_and_content( + Role.ASSISTANT, "I'm thinking more." + ).with_channel("analysis"), + Message.from_role_and_content( + Role.ASSISTANT, "I'm thinking even more." + ).with_channel("analysis"), + ] + cleaned_messages = auto_drop_analysis_messages(messages) + assert cleaned_messages == messages + + def test_only_final_message(self) -> None: + messages = [ + Message.from_role_and_content( + Role.ASSISTANT, "The answer is 4." + ).with_channel("final"), + ] + cleaned_messages = auto_drop_analysis_messages(messages) + assert cleaned_messages == messages + + def test_drops_one_analysis_messages_before_final_message(self) -> None: + messages = [ + Message.from_role_and_content( + Role.ASSISTANT, "I'm thinking about the user's question." + ).with_channel("analysis"), + Message.from_role_and_content( + Role.ASSISTANT, "The answer is 4." + ).with_channel("final"), + Message.from_role_and_content( + Role.ASSISTANT, "I should think harder." + ).with_channel("analysis"), + ] + cleaned_messages = auto_drop_analysis_messages(messages) + # Should have dropped the first analysis message + assert cleaned_messages == messages[1:] + + def test_drops_all_analysis_messages_before_final_message(self) -> None: + messages = [ + Message.from_role_and_content( + Role.ASSISTANT, "I'm thinking about the user's question." + ).with_channel("analysis"), + Message.from_role_and_content( + Role.ASSISTANT, "I'm thinking more." + ).with_channel("analysis"), + Message.from_role_and_content( + Role.ASSISTANT, "I'm thinking even more." + ).with_channel("analysis"), + Message.from_role_and_content( + Role.ASSISTANT, "The answer is 4." + ).with_channel("final"), + Message.from_role_and_content( + Role.ASSISTANT, "I should think harder." + ).with_channel("analysis"), + ] + cleaned_messages = auto_drop_analysis_messages(messages) + # Should have dropped the first 3 analysis messages + assert cleaned_messages == messages[3:] + + def test_multiple_analysis_messages_with_multiple_final_messages(self) -> None: + messages = [ + Message.from_role_and_content( + Role.ASSISTANT, "I'm thinking about the user's question." + ).with_channel("analysis"), + Message.from_role_and_content( + Role.ASSISTANT, "I'm thinking more." + ).with_channel("analysis"), + Message.from_role_and_content( + Role.ASSISTANT, "I'm thinking even more." + ).with_channel("analysis"), + Message.from_role_and_content( + Role.ASSISTANT, "The answer is 4." + ).with_channel("final"), + Message.from_role_and_content( + Role.ASSISTANT, "I should think harder." + ).with_channel("analysis"), + Message.from_role_and_content( + Role.ASSISTANT, "The answer is 5." + ).with_channel("final"), + ] + cleaned_messages = auto_drop_analysis_messages(messages) + # Should have dropped all those analysis messages + assert len(cleaned_messages) == 2 + assert cleaned_messages[0].content[0].text == "The answer is 4." + assert cleaned_messages[1].content[0].text == "The answer is 5." + + def test_drops_non_assistant_analysis_messages(self) -> None: + messages = [ + Message.from_role_and_content( + Role.TOOL, "The tool thinks we should think harder." + ).with_channel("analysis"), + Message.from_role_and_content( + Role.ASSISTANT, "The answer is 4." + ).with_channel("final"), + ] + cleaned_messages = auto_drop_analysis_messages(messages) + # Should have dropped the analysis message + assert cleaned_messages == messages[1:] + + +class TestParseChatOutput: + def test_parse_chat_output_interrupted_first_message(self) -> None: + harmony_str = "<|channel|>final<|message|>I'm in the middle of answering" + token_ids = get_encoding().encode(harmony_str, allowed_special="all") + reasoning, final_content, _ = parse_chat_output(token_ids) + assert reasoning is None + assert final_content == "I'm in the middle of answering" + + def test_parse_chat_output_interrupted_reasoning_first_message(self) -> None: + harmony_str = "<|channel|>analysis<|message|>I'm in the middle of thinking" + token_ids = get_encoding().encode(harmony_str, allowed_special="all") + reasoning, final_content, _ = parse_chat_output(token_ids) + assert reasoning == "I'm in the middle of thinking" + assert final_content is None + + def test_parse_chat_output_complete_reasoning_interrupted_content(self) -> None: + harmony_str = ( + "<|channel|>analysis<|message|>I'm thinking.<|end|>" + "<|start|>assistant<|channel|>final" + "<|message|>I'm in the middle of answering" + ) + token_ids = get_encoding().encode(harmony_str, allowed_special="all") + reasoning, final_content, _ = parse_chat_output(token_ids) + assert reasoning == "I'm thinking." + assert final_content == "I'm in the middle of answering" + + def test_parse_chat_output_complete_content(self) -> None: + harmony_str = "<|channel|>final<|message|>The answer is 4.<|end|>" + token_ids = get_encoding().encode(harmony_str, allowed_special="all") + reasoning, final_content, _ = parse_chat_output(token_ids) + assert reasoning is None + assert final_content == "The answer is 4." + + def test_parse_chat_output_complete_commentary(self) -> None: + harmony_str = ( + "<|channel|>commentary<|message|>I need to call some tools.<|end|>" + ) + token_ids = get_encoding().encode(harmony_str, allowed_special="all") + reasoning, final_content, _ = parse_chat_output(token_ids) + assert reasoning is None + assert final_content == "I need to call some tools." + + def test_parse_chat_output_complete_reasoning(self) -> None: + harmony_str = ( + "<|channel|>analysis<|message|>I've thought hard about this.<|end|>" + ) + token_ids = get_encoding().encode(harmony_str, allowed_special="all") + reasoning, final_content, _ = parse_chat_output(token_ids) + assert reasoning == "I've thought hard about this." + assert final_content is None + + def test_parse_chat_output_complete_reasoning_and_content(self) -> None: + harmony_str = ( + "<|channel|>analysis<|message|>I've thought hard about this.<|end|>" + "<|start|>assistant<|channel|>final<|message|>The answer is 4.<|end|>" + ) + token_ids = get_encoding().encode(harmony_str, allowed_special="all") + reasoning, final_content, _ = parse_chat_output(token_ids) + assert reasoning == "I've thought hard about this." + assert final_content == "The answer is 4." + + +class TestParseOutputMessage: + """Tests for parse_output_message function.""" + + def test_commentary_with_no_recipient_creates_reasoning(self): + """Test that commentary with recipient=None (preambles) creates reasoning items. + + Per Harmony format, commentary channel can contain preambles to calling + multiple functions - explanatory text with no recipient. + """ + message = Message.from_role_and_content( + Role.ASSISTANT, "I will now search for the weather information." + ) + message = message.with_channel("commentary") + # recipient is None by default, representing a preamble + + output_items = parse_output_message(message) + + assert len(output_items) == 1 + assert isinstance(output_items[0], ResponseReasoningItem) + assert output_items[0].type == "reasoning" + assert ( + output_items[0].content[0].text + == "I will now search for the weather information." + ) + assert output_items[0].content[0].type == "reasoning_text" + + def test_commentary_with_function_recipient_creates_function_call(self): + """Test commentary with recipient='functions.X' creates function calls.""" + message = Message.from_role_and_content( + Role.ASSISTANT, '{"location": "San Francisco", "units": "celsius"}' + ) + message = message.with_channel("commentary") + message = message.with_recipient("functions.get_weather") + + output_items = parse_output_message(message) + + assert len(output_items) == 1 + assert isinstance(output_items[0], ResponseFunctionToolCall) + assert output_items[0].type == "function_call" + assert output_items[0].name == "get_weather" + assert ( + output_items[0].arguments + == '{"location": "San Francisco", "units": "celsius"}' + ) + assert output_items[0].call_id.startswith("call_") + assert output_items[0].id.startswith("fc_") + + def test_commentary_with_python_recipient_creates_reasoning(self): + """Test that commentary with recipient='python' creates reasoning items.""" + message = Message.from_role_and_content( + Role.ASSISTANT, "import numpy as np\nprint(np.array([1, 2, 3]))" + ) + message = message.with_channel("commentary") + message = message.with_recipient("python") + + output_items = parse_output_message(message) + + assert len(output_items) == 1 + assert isinstance(output_items[0], ResponseReasoningItem) + assert output_items[0].type == "reasoning" + assert ( + output_items[0].content[0].text + == "import numpy as np\nprint(np.array([1, 2, 3]))" + ) + + def test_commentary_with_browser_recipient_creates_reasoning(self): + """Test that commentary with recipient='browser' creates reasoning items.""" + message = Message.from_role_and_content( + Role.ASSISTANT, "Navigating to the specified URL" + ) + message = message.with_channel("commentary") + message = message.with_recipient("browser") + + output_items = parse_output_message(message) + + assert len(output_items) == 1 + assert isinstance(output_items[0], ResponseReasoningItem) + assert output_items[0].type == "reasoning" + assert output_items[0].content[0].text == "Navigating to the specified URL" + + def test_commentary_with_container_recipient_creates_reasoning(self): + """Test that commentary with recipient='container' creates reasoning items.""" + message = Message.from_role_and_content( + Role.ASSISTANT, "Running command in container" + ) + message = message.with_channel("commentary") + message = message.with_recipient("container") + + output_items = parse_output_message(message) + + assert len(output_items) == 1 + assert isinstance(output_items[0], ResponseReasoningItem) + assert output_items[0].type == "reasoning" + assert output_items[0].content[0].text == "Running command in container" + + def test_commentary_with_empty_content_and_no_recipient(self): + """Test edge case: empty commentary with recipient=None.""" + message = Message.from_role_and_content(Role.ASSISTANT, "") + message = message.with_channel("commentary") + + output_items = parse_output_message(message) + + assert len(output_items) == 1 + assert isinstance(output_items[0], ResponseReasoningItem) + assert output_items[0].content[0].text == "" + + def test_commentary_with_multiple_contents_and_no_recipient(self): + """Test multiple content items in commentary with no recipient.""" + contents = [ + TextContent(text="Step 1: Analyze the request"), + TextContent(text="Step 2: Prepare to call functions"), + ] + message = Message.from_role_and_contents(Role.ASSISTANT, contents) + message = message.with_channel("commentary") + + output_items = parse_output_message(message) + + assert len(output_items) == 2 + assert all(isinstance(item, ResponseReasoningItem) for item in output_items) + assert output_items[0].content[0].text == "Step 1: Analyze the request" + assert output_items[1].content[0].text == "Step 2: Prepare to call functions" + + def test_commentary_with_multiple_function_calls(self): + """Test multiple function calls in commentary channel.""" + contents = [ + TextContent(text='{"location": "San Francisco"}'), + TextContent(text='{"location": "New York"}'), + ] + message = Message.from_role_and_contents(Role.ASSISTANT, contents) + message = message.with_channel("commentary") + message = message.with_recipient("functions.get_weather") + + output_items = parse_output_message(message) + + assert len(output_items) == 2 + assert all(isinstance(item, ResponseFunctionToolCall) for item in output_items) + assert output_items[0].name == "get_weather" + assert output_items[1].name == "get_weather" + assert output_items[0].arguments == '{"location": "San Francisco"}' + assert output_items[1].arguments == '{"location": "New York"}' + + def test_commentary_with_unknown_recipient_creates_mcp_call(self): + """Test that commentary with unknown recipient creates MCP call.""" + message = Message.from_role_and_content(Role.ASSISTANT, '{"arg": "value"}') + message = message.with_channel("commentary") + message = message.with_recipient("custom_tool") + + output_items = parse_output_message(message) + + assert len(output_items) == 1 + assert isinstance(output_items[0], McpCall) + assert output_items[0].type == "mcp_call" + assert output_items[0].name == "custom_tool" + assert output_items[0].server_label == "custom_tool" + + def test_analysis_channel_creates_reasoning(self): + """Test that analysis channel creates reasoning items.""" + message = Message.from_role_and_content( + Role.ASSISTANT, "Analyzing the problem step by step..." + ) + message = message.with_channel("analysis") + + output_items = parse_output_message(message) + + assert len(output_items) == 1 + assert isinstance(output_items[0], ResponseReasoningItem) + assert output_items[0].type == "reasoning" + assert ( + output_items[0].content[0].text == "Analyzing the problem step by step..." + ) + + def test_non_assistant_message_returns_empty(self): + """Test that non-assistant messages return empty list. + + Per the implementation, tool messages to assistant (e.g., search results) + are not included in final output to align with OpenAI behavior. + """ + message = Message.from_author_and_content( + Author.new(Role.TOOL, "functions.get_weather"), + "The weather is sunny, 72°F", + ) + + output_items = parse_output_message(message) + + assert len(output_items) == 0 + + +def test_has_custom_tools() -> None: + assert not has_custom_tools(set()) + assert not has_custom_tools({"web_search_preview", "code_interpreter", "container"}) + assert has_custom_tools({"others"}) + assert has_custom_tools( + {"web_search_preview", "code_interpreter", "container", "others"} + ) + + +def test_parse_mcp_call_basic() -> None: + """Test that MCP calls are parsed with correct type and server_label.""" + message = Message.from_role_and_content(Role.ASSISTANT, '{"path": "/tmp"}') + message = message.with_recipient("filesystem") + message = message.with_channel("commentary") + + output_items = parse_output_message(message) + + assert len(output_items) == 1 + assert isinstance(output_items[0], McpCall) + assert output_items[0].type == "mcp_call" + assert output_items[0].name == "filesystem" + assert output_items[0].server_label == "filesystem" + assert output_items[0].arguments == '{"path": "/tmp"}' + assert output_items[0].status == "completed" + + +def test_parse_mcp_call_dotted_recipient() -> None: + """Test that dotted recipients extract the tool name correctly.""" + message = Message.from_role_and_content(Role.ASSISTANT, '{"cmd": "ls"}') + message = message.with_recipient("repo_browser.list") + message = message.with_channel("commentary") + + output_items = parse_output_message(message) + + assert len(output_items) == 1 + assert isinstance(output_items[0], McpCall) + assert output_items[0].name == "list" + assert output_items[0].server_label == "repo_browser" + + +def test_mcp_vs_function_call() -> None: + """Test that function calls are not parsed as MCP calls.""" + func_message = Message.from_role_and_content(Role.ASSISTANT, '{"arg": "value"}') + func_message = func_message.with_recipient("functions.my_tool") + func_message = func_message.with_channel("commentary") + + func_items = parse_output_message(func_message) + + assert len(func_items) == 1 + assert not isinstance(func_items[0], McpCall) + assert func_items[0].type == "function_call" + + +def test_mcp_vs_builtin_tools() -> None: + """Test that built-in tools (python, container) are not parsed as MCP calls.""" + # Test python (built-in tool) - should be reasoning, not MCP + python_message = Message.from_role_and_content(Role.ASSISTANT, "print('hello')") + python_message = python_message.with_recipient("python") + python_message = python_message.with_channel("commentary") + + python_items = parse_output_message(python_message) + + assert len(python_items) == 1 + assert not isinstance(python_items[0], McpCall) + assert python_items[0].type == "reasoning" + + +def test_parse_remaining_state_commentary_channel() -> None: + """Test parse_remaining_state with commentary channel and various recipients.""" + from unittest.mock import Mock + + from vllm.entrypoints.openai.parser.harmony_utils import parse_remaining_state + + # Test 1: functions.* recipient → should return function tool call + parser_func = Mock() + parser_func.current_content = '{"arg": "value"}' + parser_func.current_role = Role.ASSISTANT + parser_func.current_channel = "commentary" + parser_func.current_recipient = "functions.my_tool" + + func_items = parse_remaining_state(parser_func) + + assert len(func_items) == 1 + assert not isinstance(func_items[0], McpCall) + assert func_items[0].type == "function_call" + assert func_items[0].name == "my_tool" + assert func_items[0].status == "in_progress" + + # Test 2: MCP tool (not builtin) → should return MCP call + parser_mcp = Mock() + parser_mcp.current_content = '{"path": "/tmp"}' + parser_mcp.current_role = Role.ASSISTANT + parser_mcp.current_channel = "commentary" + parser_mcp.current_recipient = "filesystem" + + mcp_items = parse_remaining_state(parser_mcp) + + assert len(mcp_items) == 1 + assert isinstance(mcp_items[0], McpCall) + assert mcp_items[0].type == "mcp_call" + assert mcp_items[0].name == "filesystem" + assert mcp_items[0].server_label == "filesystem" + assert mcp_items[0].status == "in_progress" + + # Test 3: Built-in tool (python) + # should NOT return MCP call, falls through to reasoning + parser_builtin = Mock() + parser_builtin.current_content = "print('hello')" + parser_builtin.current_role = Role.ASSISTANT + parser_builtin.current_channel = "commentary" + parser_builtin.current_recipient = "python" + + builtin_items = parse_remaining_state(parser_builtin) + + # Should fall through to reasoning logic + assert len(builtin_items) == 1 + assert not isinstance(builtin_items[0], McpCall) + assert builtin_items[0].type == "reasoning" + + +def test_parse_remaining_state_analysis_channel() -> None: + """Test parse_remaining_state with analysis channel and various recipients.""" + from unittest.mock import Mock + + from vllm.entrypoints.openai.parser.harmony_utils import parse_remaining_state + + # Test 1: functions.* recipient → should return function tool call + parser_func = Mock() + parser_func.current_content = '{"arg": "value"}' + parser_func.current_role = Role.ASSISTANT + parser_func.current_channel = "analysis" + parser_func.current_recipient = "functions.my_tool" + + func_items = parse_remaining_state(parser_func) + + assert len(func_items) == 1 + assert not isinstance(func_items[0], McpCall) + assert func_items[0].type == "function_call" + assert func_items[0].name == "my_tool" + assert func_items[0].status == "in_progress" + + # Test 2: MCP tool (not builtin) → should return MCP call + parser_mcp = Mock() + parser_mcp.current_content = '{"query": "test"}' + parser_mcp.current_role = Role.ASSISTANT + parser_mcp.current_channel = "analysis" + parser_mcp.current_recipient = "database" + + mcp_items = parse_remaining_state(parser_mcp) + + assert len(mcp_items) == 1 + assert isinstance(mcp_items[0], McpCall) + assert mcp_items[0].type == "mcp_call" + assert mcp_items[0].name == "database" + assert mcp_items[0].server_label == "database" + assert mcp_items[0].status == "in_progress" + + # Test 3: Built-in tool (container) + # should NOT return MCP call, falls through to reasoning + parser_builtin = Mock() + parser_builtin.current_content = "docker run" + parser_builtin.current_role = Role.ASSISTANT + parser_builtin.current_channel = "analysis" + parser_builtin.current_recipient = "container" + + builtin_items = parse_remaining_state(parser_builtin) + + # Should fall through to reasoning logic + assert len(builtin_items) == 1 + assert not isinstance(builtin_items[0], McpCall) + assert builtin_items[0].type == "reasoning" diff --git a/tests/entrypoints/openai/test_basic.py b/tests/entrypoints/openai/test_basic.py index 3d581a300b6a9..1ff30de31bbe5 100644 --- a/tests/entrypoints/openai/test_basic.py +++ b/tests/entrypoints/openai/test_basic.py @@ -232,7 +232,7 @@ async def test_server_load(server: RemoteOpenAIServer): @pytest.mark.asyncio async def test_health_check_engine_dead_error(): # Import the health function directly to test it in isolation - from vllm.entrypoints.openai.api_server import health + from vllm.entrypoints.serve.instrumentator.health import health # Create a mock request that simulates what FastAPI would provide mock_request = Mock(spec=Request) diff --git a/tests/entrypoints/openai/test_chat_error.py b/tests/entrypoints/openai/test_chat_error.py new file mode 100644 index 0000000000000..b194e9b74d874 --- /dev/null +++ b/tests/entrypoints/openai/test_chat_error.py @@ -0,0 +1,227 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass, field +from http import HTTPStatus +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from vllm.config.multimodal import MultiModalConfig +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels +from vllm.outputs import CompletionOutput, RequestOutput +from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.v1.engine.async_llm import AsyncLLM + +MODEL_NAME = "openai-community/gpt2" +MODEL_NAME_SHORT = "gpt2" +BASE_MODEL_PATHS = [ + BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME), + BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT), +] + + +@dataclass +class MockHFConfig: + model_type: str = "any" + + +@dataclass +class MockModelConfig: + task = "generate" + runner_type = "generate" + tokenizer = MODEL_NAME + trust_remote_code = False + tokenizer_mode = "auto" + max_model_len = 100 + tokenizer_revision = None + multimodal_config = MultiModalConfig() + hf_config = MockHFConfig() + logits_processor_pattern = None + logits_processors: list[str] | None = None + diff_sampling_param: dict | None = None + allowed_local_media_path: str = "" + allowed_media_domains: list[str] | None = None + encoder_config = None + generation_config: str = "auto" + media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) + skip_tokenizer_init = False + + def get_diff_sampling_param(self): + return self.diff_sampling_param or {} + + +def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat: + models = OpenAIServingModels( + engine_client=engine, + base_model_paths=BASE_MODEL_PATHS, + ) + serving_chat = OpenAIServingChat( + engine, + models, + response_role="assistant", + request_logger=None, + chat_template=None, + chat_template_content_format="auto", + ) + + async def _fake_process_inputs( + request_id, + engine_prompt, + sampling_params, + *, + lora_request, + trace_headers, + priority, + ): + return dict(engine_prompt), {} + + async def _fake_preprocess_chat(*args, **kwargs): + # return conversation, engine_prompts + return ( + [{"role": "user", "content": "Test"}], + [{"prompt_token_ids": [1, 2, 3]}], + ) + + serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs) + serving_chat._preprocess_chat = AsyncMock(side_effect=_fake_preprocess_chat) + return serving_chat + + +@pytest.mark.asyncio +async def test_chat_error_non_stream(): + """test finish_reason='error' returns 500 InternalServerError (non-streaming)""" + mock_engine = MagicMock(spec=AsyncLLM) + mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False + mock_engine.model_config = MockModelConfig() + mock_engine.input_processor = MagicMock() + mock_engine.io_processor = MagicMock() + + serving_chat = _build_serving_chat(mock_engine) + + completion_output = CompletionOutput( + index=0, + text="", + token_ids=[], + cumulative_logprob=None, + logprobs=None, + finish_reason="error", + ) + + request_output = RequestOutput( + request_id="test-id", + prompt="Test prompt", + prompt_token_ids=[1, 2, 3], + prompt_logprobs=None, + outputs=[completion_output], + finished=True, + metrics=None, + lora_request=None, + encoder_prompt=None, + encoder_prompt_token_ids=None, + ) + + async def mock_generate(*args, **kwargs): + yield request_output + + mock_engine.generate = MagicMock(side_effect=mock_generate) + + request = ChatCompletionRequest( + model=MODEL_NAME, + messages=[{"role": "user", "content": "Test prompt"}], + max_tokens=10, + stream=False, + ) + + response = await serving_chat.create_chat_completion(request) + + assert isinstance(response, ErrorResponse) + assert response.error.type == "InternalServerError" + assert response.error.message == "Internal server error" + assert response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR + + +@pytest.mark.asyncio +async def test_chat_error_stream(): + """test finish_reason='error' returns 500 InternalServerError (streaming)""" + mock_engine = MagicMock(spec=AsyncLLM) + mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False + mock_engine.model_config = MockModelConfig() + mock_engine.input_processor = MagicMock() + mock_engine.io_processor = MagicMock() + + serving_chat = _build_serving_chat(mock_engine) + + completion_output_1 = CompletionOutput( + index=0, + text="Hello", + token_ids=[100], + cumulative_logprob=None, + logprobs=None, + finish_reason=None, + ) + + request_output_1 = RequestOutput( + request_id="test-id", + prompt="Test prompt", + prompt_token_ids=[1, 2, 3], + prompt_logprobs=None, + outputs=[completion_output_1], + finished=False, + metrics=None, + lora_request=None, + encoder_prompt=None, + encoder_prompt_token_ids=None, + ) + + completion_output_2 = CompletionOutput( + index=0, + text="Hello", + token_ids=[100], + cumulative_logprob=None, + logprobs=None, + finish_reason="error", + ) + + request_output_2 = RequestOutput( + request_id="test-id", + prompt="Test prompt", + prompt_token_ids=[1, 2, 3], + prompt_logprobs=None, + outputs=[completion_output_2], + finished=True, + metrics=None, + lora_request=None, + encoder_prompt=None, + encoder_prompt_token_ids=None, + ) + + async def mock_generate(*args, **kwargs): + yield request_output_1 + yield request_output_2 + + mock_engine.generate = MagicMock(side_effect=mock_generate) + + request = ChatCompletionRequest( + model=MODEL_NAME, + messages=[{"role": "user", "content": "Test prompt"}], + max_tokens=10, + stream=True, + ) + + response = await serving_chat.create_chat_completion(request) + + chunks = [] + async for chunk in response: + chunks.append(chunk) + + assert len(chunks) >= 2 + assert any("Internal server error" in chunk for chunk in chunks), ( + f"Expected error message in chunks: {chunks}" + ) + assert chunks[-1] == "data: [DONE]\n\n" diff --git a/tests/entrypoints/openai/test_chat_template.py b/tests/entrypoints/openai/test_chat_template.py index ee79ed59c4102..77087ac21ea8b 100644 --- a/tests/entrypoints/openai/test_chat_template.py +++ b/tests/entrypoints/openai/test_chat_template.py @@ -6,7 +6,7 @@ import pytest from vllm.config import ModelConfig from vllm.entrypoints.chat_utils import apply_hf_chat_template, load_chat_template from vllm.entrypoints.openai.protocol import ChatCompletionRequest -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.tokenizers import get_tokenizer from ...models.registry import HF_EXAMPLE_MODELS from ...utils import VLLM_PATH diff --git a/tests/entrypoints/openai/test_completion_error.py b/tests/entrypoints/openai/test_completion_error.py new file mode 100644 index 0000000000000..ca56cc2ddb6a7 --- /dev/null +++ b/tests/entrypoints/openai/test_completion_error.py @@ -0,0 +1,216 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass, field +from http import HTTPStatus +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from vllm.config.multimodal import MultiModalConfig +from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse +from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion +from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels +from vllm.outputs import CompletionOutput, RequestOutput +from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.v1.engine.async_llm import AsyncLLM + +MODEL_NAME = "openai-community/gpt2" +MODEL_NAME_SHORT = "gpt2" +BASE_MODEL_PATHS = [ + BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME), + BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT), +] + + +@dataclass +class MockHFConfig: + model_type: str = "any" + + +@dataclass +class MockModelConfig: + task = "generate" + runner_type = "generate" + tokenizer = MODEL_NAME + trust_remote_code = False + tokenizer_mode = "auto" + max_model_len = 100 + tokenizer_revision = None + multimodal_config = MultiModalConfig() + hf_config = MockHFConfig() + logits_processor_pattern = None + logits_processors: list[str] | None = None + diff_sampling_param: dict | None = None + allowed_local_media_path: str = "" + allowed_media_domains: list[str] | None = None + encoder_config = None + generation_config: str = "auto" + media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) + skip_tokenizer_init = False + + def get_diff_sampling_param(self): + return self.diff_sampling_param or {} + + +def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion: + models = OpenAIServingModels( + engine_client=engine, + base_model_paths=BASE_MODEL_PATHS, + ) + serving_completion = OpenAIServingCompletion( + engine, + models, + request_logger=None, + ) + + async def _fake_process_inputs( + request_id, + engine_prompt, + sampling_params, + *, + lora_request, + trace_headers, + priority, + ): + return dict(engine_prompt), {} + + serving_completion._process_inputs = AsyncMock(side_effect=_fake_process_inputs) + return serving_completion + + +@pytest.mark.asyncio +async def test_completion_error_non_stream(): + """test finish_reason='error' returns 500 InternalServerError (non-streaming)""" + mock_engine = MagicMock(spec=AsyncLLM) + mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False + mock_engine.model_config = MockModelConfig() + mock_engine.input_processor = MagicMock() + mock_engine.io_processor = MagicMock() + + serving_completion = _build_serving_completion(mock_engine) + + completion_output = CompletionOutput( + index=0, + text="", + token_ids=[], + cumulative_logprob=None, + logprobs=None, + finish_reason="error", + ) + + request_output = RequestOutput( + request_id="test-id", + prompt="Test prompt", + prompt_token_ids=[1, 2, 3], + prompt_logprobs=None, + outputs=[completion_output], + finished=True, + metrics=None, + lora_request=None, + encoder_prompt=None, + encoder_prompt_token_ids=None, + ) + + async def mock_generate(*args, **kwargs): + yield request_output + + mock_engine.generate = MagicMock(side_effect=mock_generate) + + request = CompletionRequest( + model=MODEL_NAME, + prompt="Test prompt", + max_tokens=10, + stream=False, + ) + + response = await serving_completion.create_completion(request) + + assert isinstance(response, ErrorResponse) + assert response.error.type == "InternalServerError" + assert response.error.message == "Internal server error" + assert response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR + + +@pytest.mark.asyncio +async def test_completion_error_stream(): + """test finish_reason='error' returns 500 InternalServerError (streaming)""" + mock_engine = MagicMock(spec=AsyncLLM) + mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False + mock_engine.model_config = MockModelConfig() + mock_engine.input_processor = MagicMock() + mock_engine.io_processor = MagicMock() + + serving_completion = _build_serving_completion(mock_engine) + + completion_output_1 = CompletionOutput( + index=0, + text="Hello", + token_ids=[100], + cumulative_logprob=None, + logprobs=None, + finish_reason=None, + ) + + request_output_1 = RequestOutput( + request_id="test-id", + prompt="Test prompt", + prompt_token_ids=[1, 2, 3], + prompt_logprobs=None, + outputs=[completion_output_1], + finished=False, + metrics=None, + lora_request=None, + encoder_prompt=None, + encoder_prompt_token_ids=None, + ) + + completion_output_2 = CompletionOutput( + index=0, + text="Hello", + token_ids=[100], + cumulative_logprob=None, + logprobs=None, + finish_reason="error", + ) + + request_output_2 = RequestOutput( + request_id="test-id", + prompt="Test prompt", + prompt_token_ids=[1, 2, 3], + prompt_logprobs=None, + outputs=[completion_output_2], + finished=True, + metrics=None, + lora_request=None, + encoder_prompt=None, + encoder_prompt_token_ids=None, + ) + + async def mock_generate(*args, **kwargs): + yield request_output_1 + yield request_output_2 + + mock_engine.generate = MagicMock(side_effect=mock_generate) + + request = CompletionRequest( + model=MODEL_NAME, + prompt="Test prompt", + max_tokens=10, + stream=True, + ) + + response = await serving_completion.create_completion(request) + + chunks = [] + async for chunk in response: + chunks.append(chunk) + + assert len(chunks) >= 2 + assert any("Internal server error" in chunk for chunk in chunks), ( + f"Expected error message in chunks: {chunks}" + ) + assert chunks[-1] == "data: [DONE]\n\n" diff --git a/tests/entrypoints/openai/test_lora_resolvers.py b/tests/entrypoints/openai/test_lora_resolvers.py index 4856cafef44b3..ea6b3d812d8fe 100644 --- a/tests/entrypoints/openai/test_lora_resolvers.py +++ b/tests/entrypoints/openai/test_lora_resolvers.py @@ -14,7 +14,7 @@ from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.lora.request import LoRARequest from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.tokenizers import get_tokenizer from vllm.v1.engine.async_llm import AsyncLLM MODEL_NAME = "openai-community/gpt2" diff --git a/tests/entrypoints/openai/test_messages.py b/tests/entrypoints/openai/test_messages.py index 3e390ad496428..8de6c4cb6c887 100644 --- a/tests/entrypoints/openai/test_messages.py +++ b/tests/entrypoints/openai/test_messages.py @@ -69,9 +69,23 @@ async def test_anthropic_streaming(client: anthropic.AsyncAnthropic): stream=True, ) + first_chunk = None + chunk_count = 0 async for chunk in resp: + chunk_count += 1 + if first_chunk is None and chunk.type == "message_start": + first_chunk = chunk print(chunk.model_dump_json()) + assert chunk_count > 0 + assert first_chunk is not None, "message_start chunk was never observed" + assert first_chunk.message is not None, "first chunk should include message" + assert first_chunk.message.usage is not None, ( + "first chunk should include usage stats" + ) + assert first_chunk.message.usage.output_tokens == 0 + assert first_chunk.message.usage.input_tokens > 5 + @pytest.mark.asyncio async def test_anthropic_tool_call(client: anthropic.AsyncAnthropic): diff --git a/tests/entrypoints/openai/test_response_api_parsable_context.py b/tests/entrypoints/openai/test_response_api_parsable_context.py new file mode 100644 index 0000000000000..6d97602f32475 --- /dev/null +++ b/tests/entrypoints/openai/test_response_api_parsable_context.py @@ -0,0 +1,186 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import importlib +import json + +import pytest +import pytest_asyncio +from openai import OpenAI + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "Qwen/Qwen3-8B" + + +@pytest.fixture(scope="module") +def server(): + assert importlib.util.find_spec("gpt_oss") is not None, ( + "Harmony tests require gpt_oss package to be installed" + ) + + args = [ + "--reasoning-parser", + "qwen3", + "--max_model_len", + "5000", + "--structured-outputs-config.backend", + "xgrammar", + "--enable-auto-tool-choice", + "--tool-call-parser", + "hermes", + "--tool-server", + "demo", + ] + env_dict = dict( + VLLM_ENABLE_RESPONSES_API_STORE="1", + VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT="1", + PYTHON_EXECUTION_BACKEND="dangerously_use_uv", + ) + + with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_basic(client: OpenAI, model_name: str): + response = await client.responses.create( + model=model_name, + input="What is 13 * 24?", + ) + assert response is not None + print("response: ", response) + assert response.status == "completed" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_reasoning_and_function_items(client: OpenAI, model_name: str): + response = await client.responses.create( + model=model_name, + input=[ + {"type": "message", "content": "Hello.", "role": "user"}, + { + "type": "reasoning", + "id": "lol", + "content": [ + { + "type": "reasoning_text", + "text": "We need to respond: greeting.", + } + ], + "summary": [], + }, + { + "arguments": '{"location": "Paris", "unit": "celsius"}', + "call_id": "call_5f7b38f3b81e4b8380fd0ba74f3ca3ab", + "name": "get_weather", + "type": "function_call", + "id": "fc_4fe5d6fc5b6c4d6fa5f24cc80aa27f78", + "status": "completed", + }, + { + "call_id": "call_5f7b38f3b81e4b8380fd0ba74f3ca3ab", + "id": "fc_4fe5d6fc5b6c4d6fa5f24cc80aa27f78", + "output": "The weather in Paris is 20 Celsius", + "status": "completed", + "type": "function_call_output", + }, + ], + temperature=0.0, + ) + assert response is not None + assert response.status == "completed" + # make sure we get a reasoning and text output + assert response.output[0].type == "reasoning" + assert response.output[1].type == "message" + assert type(response.output[1].content[0].text) is str + + +def get_horoscope(sign): + return f"{sign}: Next Tuesday you will befriend a baby otter." + + +def call_function(name, args): + if name == "get_horoscope": + return get_horoscope(**args) + else: + raise ValueError(f"Unknown function: {name}") + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_function_call_first_turn(client: OpenAI, model_name: str): + tools = [ + { + "type": "function", + "name": "get_horoscope", + "description": "Get today's horoscope for an astrological sign.", + "parameters": { + "type": "object", + "properties": { + "sign": {"type": "string"}, + }, + "required": ["sign"], + "additionalProperties": False, + }, + "strict": True, + } + ] + + response = await client.responses.create( + model=model_name, + input="What is the horoscope for Aquarius today?", + tools=tools, + temperature=0.0, + ) + assert response is not None + assert response.status == "completed" + assert len(response.output) == 2 + assert response.output[0].type == "reasoning" + assert response.output[1].type == "function_call" + + function_call = response.output[1] + assert function_call.name == "get_horoscope" + assert function_call.call_id is not None + + args = json.loads(function_call.arguments) + assert "sign" in args + + # the multi turn function call is tested above in + # test_reasoning_and_function_items + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_mcp_tool_call(client: OpenAI, model_name: str): + response = await client.responses.create( + model=model_name, + input="What is 13 * 24? Use python to calculate the result.", + tools=[{"type": "code_interpreter", "container": {"type": "auto"}}], + extra_body={"enable_response_messages": True}, + temperature=0.0, + ) + + assert response is not None + assert response.status == "completed" + assert response.output[0].type == "reasoning" + assert response.output[1].type == "mcp_call" + assert type(response.output[1].arguments) is str + assert type(response.output[1].output) is str + assert response.output[2].type == "reasoning" + # make sure the correct math is in the final output + assert response.output[3].type == "message" + assert "312" in response.output[3].content[0].text + + # test raw input_messages / output_messages + assert len(response.input_messages) == 1 + assert len(response.output_messages) == 3 + assert "312" in response.output_messages[2]["message"] diff --git a/tests/entrypoints/openai/test_response_api_simple.py b/tests/entrypoints/openai/test_response_api_simple.py index 425b8199a0fd0..02e06297f3987 100644 --- a/tests/entrypoints/openai/test_response_api_simple.py +++ b/tests/entrypoints/openai/test_response_api_simple.py @@ -42,6 +42,24 @@ async def test_basic(client: OpenAI, model_name: str): assert response.status == "completed" +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_enable_response_messages(client: OpenAI, model_name: str): + response = await client.responses.create( + model=model_name, + input="Hello?", + extra_body={"enable_response_messages": True}, + ) + assert response.status == "completed" + assert response.input_messages[0]["type"] == "raw_message_tokens" + assert type(response.input_messages[0]["message"]) is str + assert len(response.input_messages[0]["message"]) > 10 + assert type(response.input_messages[0]["tokens"][0]) is int + assert type(response.output_messages[0]["message"]) is str + assert len(response.output_messages[0]["message"]) > 10 + assert type(response.output_messages[0]["tokens"][0]) is int + + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_reasoning_item(client: OpenAI, model_name: str): @@ -69,3 +87,48 @@ async def test_reasoning_item(client: OpenAI, model_name: str): assert response.output[0].type == "reasoning" assert response.output[1].type == "message" assert type(response.output[1].content[0].text) is str + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_streaming_output_consistency(client: OpenAI, model_name: str): + """Test that streaming delta text matches the final response output_text. + + This test verifies that when using streaming mode: + 1. The concatenated text from all 'response.output_text.delta' events + 2. Matches the 'output_text' in the final 'response.completed' event + """ + response = await client.responses.create( + model=model_name, + input="Say hello in one sentence.", + stream=True, + ) + + events = [] + async for event in response: + events.append(event) + + assert len(events) > 0 + + # Concatenate all delta text from streaming events + streaming_text = "".join( + event.delta for event in events if event.type == "response.output_text.delta" + ) + + # Get the final response from the last event + response_completed_event = events[-1] + assert response_completed_event.type == "response.completed" + assert response_completed_event.response.status == "completed" + + # Get output_text from the final response + final_output_text = response_completed_event.response.output_text + + # Verify final response has output + assert len(response_completed_event.response.output) > 0 + + # Verify streaming text matches final output_text + assert streaming_text == final_output_text, ( + f"Streaming text does not match final output_text.\n" + f"Streaming: {streaming_text!r}\n" + f"Final: {final_output_text!r}" + ) diff --git a/tests/entrypoints/openai/test_response_api_with_harmony.py b/tests/entrypoints/openai/test_response_api_with_harmony.py index 8fd3545eccffa..6f2a50020699c 100644 --- a/tests/entrypoints/openai/test_response_api_with_harmony.py +++ b/tests/entrypoints/openai/test_response_api_with_harmony.py @@ -726,7 +726,7 @@ async def test_function_calling_required(client: OpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_system_message_with_tools(client: OpenAI, model_name: str): - from vllm.entrypoints.harmony_utils import get_system_message + from vllm.entrypoints.openai.parser.harmony_utils import get_system_message # Test with custom tools enabled - commentary channel should be available sys_msg = get_system_message(with_custom_tools=True) diff --git a/tests/entrypoints/openai/test_responses_error.py b/tests/entrypoints/openai/test_responses_error.py new file mode 100644 index 0000000000000..f8ea178288835 --- /dev/null +++ b/tests/entrypoints/openai/test_responses_error.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from http import HTTPStatus +from unittest.mock import MagicMock + +import pytest + +from vllm.entrypoints.openai.protocol import ErrorResponse +from vllm.entrypoints.openai.serving_engine import GenerationError, OpenAIServing + + +@pytest.mark.asyncio +async def test_raise_if_error_raises_generation_error(): + """test _raise_if_error raises GenerationError""" + # create a minimal OpenAIServing instance + mock_engine = MagicMock() + mock_engine.model_config = MagicMock() + mock_engine.model_config.max_model_len = 100 + mock_models = MagicMock() + + serving = OpenAIServing( + engine_client=mock_engine, + models=mock_models, + request_logger=None, + ) + + # test that error finish_reason raises GenerationError + with pytest.raises(GenerationError) as exc_info: + serving._raise_if_error("error", "test-request-id") + + assert str(exc_info.value) == "Internal server error" + assert exc_info.value.status_code == HTTPStatus.INTERNAL_SERVER_ERROR + + # test that other finish_reasons don't raise + serving._raise_if_error("stop", "test-request-id") # should not raise + serving._raise_if_error("length", "test-request-id") # should not raise + serving._raise_if_error(None, "test-request-id") # should not raise + + +@pytest.mark.asyncio +async def test_convert_generation_error_to_response(): + """test _convert_generation_error_to_response creates proper ErrorResponse""" + mock_engine = MagicMock() + mock_engine.model_config = MagicMock() + mock_engine.model_config.max_model_len = 100 + mock_models = MagicMock() + + serving = OpenAIServing( + engine_client=mock_engine, + models=mock_models, + request_logger=None, + ) + + # create a GenerationError + gen_error = GenerationError("Internal server error") + + # convert to ErrorResponse + error_response = serving._convert_generation_error_to_response(gen_error) + + assert isinstance(error_response, ErrorResponse) + assert error_response.error.type == "InternalServerError" + assert error_response.error.message == "Internal server error" + assert error_response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR + + +@pytest.mark.asyncio +async def test_convert_generation_error_to_streaming_response(): + """test _convert_generation_error_to_streaming_response output""" + mock_engine = MagicMock() + mock_engine.model_config = MagicMock() + mock_engine.model_config.max_model_len = 100 + mock_models = MagicMock() + + serving = OpenAIServing( + engine_client=mock_engine, + models=mock_models, + request_logger=None, + ) + + # create a GenerationError + gen_error = GenerationError("Internal server error") + + # convert to streaming error response + error_json = serving._convert_generation_error_to_streaming_response(gen_error) + + assert isinstance(error_json, str) + assert "Internal server error" in error_json + assert "InternalServerError" in error_json diff --git a/tests/entrypoints/openai/test_return_token_ids.py b/tests/entrypoints/openai/test_return_token_ids.py index feef48a36dfa1..8537082e3f8d1 100644 --- a/tests/entrypoints/openai/test_return_token_ids.py +++ b/tests/entrypoints/openai/test_return_token_ids.py @@ -3,7 +3,7 @@ import pytest -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.tokenizers import get_tokenizer from ...utils import RemoteOpenAIServer diff --git a/tests/entrypoints/openai/test_return_tokens_as_ids.py b/tests/entrypoints/openai/test_return_tokens_as_ids.py index cedf6ce160607..d4d9a6c5b6120 100644 --- a/tests/entrypoints/openai/test_return_tokens_as_ids.py +++ b/tests/entrypoints/openai/test_return_tokens_as_ids.py @@ -7,7 +7,7 @@ import pytest -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.tokenizers import get_tokenizer from ...utils import RemoteOpenAIServer diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 6a1b15c4131e0..444275e061c61 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -11,13 +11,25 @@ import pytest_asyncio from openai import OpenAI from vllm.config.multimodal import MultiModalConfig -from vllm.entrypoints.openai.protocol import ChatCompletionRequest +from vllm.entrypoints.openai.parser.harmony_utils import get_encoding +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + RequestResponseMetadata, +) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.outputs import CompletionOutput, RequestOutput +from vllm.tokenizers import get_tokenizer +from vllm.tool_parsers import ToolParserManager from vllm.v1.engine.async_llm import AsyncLLM from ...utils import RemoteOpenAIServer +from .utils import ( + accumulate_streaming_response, + verify_chat_response, + verify_harmony_messages, +) GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b" @@ -728,3 +740,635 @@ async def test_serving_chat_data_parallel_rank_extraction(): # Verify that data_parallel_rank defaults to None assert "data_parallel_rank" in mock_engine.generate.call_args.kwargs assert mock_engine.generate.call_args.kwargs["data_parallel_rank"] is None + + +class TestServingChatWithHarmony: + """ + These tests ensure Chat Completion requests are being properly converted into + Harmony messages and Harmony response messages back into Chat Completion responses. + These tests are not exhaustive, but each one was created to cover a specific case + that we got wrong but is now fixed. + + Any changes to the tests and their expectations may result in changes to the + accuracy of model prompting and responses generated. It is suggested to run + an evaluation or benchmarking suite (such as bfcl multi_turn) to understand + any impact of changes in how we prompt Harmony models. + """ + + @pytest.fixture(params=[False, True], ids=["non_streaming", "streaming"]) + def stream(self, request) -> bool: + """Parameterize tests to run in both non-streaming and streaming modes.""" + return request.param + + @pytest.fixture() + def mock_engine(self) -> AsyncLLM: + mock_engine = MagicMock(spec=AsyncLLM) + mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False + mock_engine.model_config = MockModelConfig() + mock_engine.input_processor = MagicMock() + mock_engine.io_processor = MagicMock() + return mock_engine + + @pytest.fixture() + def serving_chat(self, mock_engine) -> OpenAIServingChat: + chat = _build_serving_chat(mock_engine) + chat.use_harmony = True + chat.tool_parser = ToolParserManager.get_tool_parser("openai") + return chat + + def mock_request_output_from_req_and_token_ids( + self, req: ChatCompletionRequest, token_ids: list[int], finished: bool = False + ) -> RequestOutput: + # Our tests don't use most fields, so just get the token ids correct + completion_output = CompletionOutput( + index=0, + text="", + token_ids=token_ids, + cumulative_logprob=0.0, + logprobs=None, + ) + return RequestOutput( + request_id=req.request_id, + prompt=[], + prompt_token_ids=[], + prompt_logprobs=None, + outputs=[completion_output], + finished=finished, + ) + + @pytest.fixture + def weather_tools(self) -> list[dict[str, Any]]: + return [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string"}, + }, + "required": ["location"], + }, + }, + }, + ] + + @pytest.fixture + def weather_messages_start(self) -> list[dict[str, Any]]: + return [ + { + "role": "user", + "content": "What's the weather like in Paris today?", + }, + ] + + async def generate_response_from_harmony_str( + self, + serving_chat: OpenAIServingChat, + req: ChatCompletionRequest, + harmony_str: str, + stream: bool = False, + ) -> ChatCompletionResponse: + harmony_token_ids = get_encoding().encode(harmony_str, allowed_special="all") + + async def result_generator(): + if stream: + for token_id in harmony_token_ids: + yield self.mock_request_output_from_req_and_token_ids( + req, [token_id] + ) + yield self.mock_request_output_from_req_and_token_ids( + req, [], finished=True + ) + else: + yield self.mock_request_output_from_req_and_token_ids( + req, harmony_token_ids, finished=True + ) + + generator_func = ( + serving_chat.chat_completion_stream_generator + if stream + else serving_chat.chat_completion_full_generator + ) + + result = generator_func( + request=req, + result_generator=result_generator(), + request_id=req.request_id, + model_name=req.model, + conversation=[], + tokenizer=get_tokenizer(req.model), + request_metadata=RequestResponseMetadata( + request_id=req.request_id, + model_name=req.model, + ), + ) + + if stream: + return await accumulate_streaming_response(result) + return await result + + @pytest.mark.asyncio + async def test_simple_chat(self, serving_chat, stream): + messages = [{"role": "user", "content": "what is 1+1?"}] + + # Test the Harmony messages for the first turn's input + req = ChatCompletionRequest(model=MODEL_NAME, messages=messages) + input_messages, _ = serving_chat._make_request_with_harmony(req) + verify_harmony_messages( + input_messages, + [ + {"role": "system"}, + {"role": "developer"}, + {"role": "user", "content": messages[0]["content"]}, + ], + ) + + # Test the Chat Completion response for the first turn's output + reasoning_str = "We need to think really hard about this." + final_str = "The answer is 2." + response_str = ( + f"<|channel|>analysis<|message|>{reasoning_str}<|end|>" + f"<|start|>assistant<|channel|>final<|message|>{final_str}<|end|>" + ) + response = await self.generate_response_from_harmony_str( + serving_chat, req, response_str, stream=stream + ) + verify_chat_response(response, content=final_str, reasoning=reasoning_str) + + # Add the output messages from the first turn as input to the second turn + for choice in response.choices: + messages.append(choice.message.model_dump(exclude_none=True)) + + # Test the Harmony messages for the second turn's input + req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages) + input_messages_2, _ = serving_chat._make_request_with_harmony(req_2) + verify_harmony_messages( + input_messages_2, + [ + {"role": "system"}, + {"role": "developer"}, + {"role": "user"}, + # The analysis message should be dropped on subsequent inputs because + # of the subsequent assistant message to the final channel. + {"role": "assistant", "channel": "final", "content": final_str}, + ], + ) + + @pytest.mark.asyncio + async def test_tool_call_response_with_content( + self, serving_chat, stream, weather_tools, weather_messages_start + ): + tools = weather_tools + messages = list(weather_messages_start) + + # Test the Harmony messages for the first turn's input + req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools) + input_messages, _ = serving_chat._make_request_with_harmony(req) + verify_harmony_messages( + input_messages, + [ + {"role": "system"}, + {"role": "developer", "tool_definitions": ["get_weather"]}, + {"role": "user", "content": messages[0]["content"]}, + ], + ) + + # Test the Chat Completion response for the first turn's output + commentary_str = "We'll call get_weather." + tool_args_str = '{"location": "Paris"}' + response_str = ( + f"<|channel|>commentary<|message|>{commentary_str}<|end|>" + "<|start|>assistant to=functions.get_weather<|channel|>commentary" + f"<|constrain|>json<|message|>{tool_args_str}<|call|>" + ) + response = await self.generate_response_from_harmony_str( + serving_chat, req, response_str, stream=stream + ) + verify_chat_response( + response, + content=commentary_str, + tool_calls=[("get_weather", tool_args_str)], + ) + + tool_call = response.choices[0].message.tool_calls[0] + + # Add the output messages from the first turn as input to the second turn + for choice in response.choices: + messages.append(choice.message.model_dump(exclude_none=True)) + + # Add our tool output message + messages.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": "20 degrees Celsius", + }, + ) + + # Test the Harmony messages for the second turn's input + req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages) + input_messages_2, _ = serving_chat._make_request_with_harmony(req_2) + verify_harmony_messages( + input_messages_2, + [ + {"role": "system"}, + {"role": "developer"}, + {"role": "user"}, + { + "role": "assistant", + "channel": "commentary", + "content": commentary_str, + }, + { + "role": "assistant", + "channel": "commentary", + "recipient": "functions.get_weather", + "content": tool_args_str, + }, + { + "role": "tool", + "author_name": "functions.get_weather", + "channel": "commentary", + "recipient": "assistant", + "content": "20 degrees Celsius", + }, + ], + ) + + @pytest.mark.asyncio + async def test_tools_and_reasoning( + self, serving_chat, stream, weather_tools, weather_messages_start + ): + tools = weather_tools + messages = list(weather_messages_start) + + # Test the Harmony messages for the first turn's input + req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools) + input_messages, _ = serving_chat._make_request_with_harmony(req) + verify_harmony_messages( + input_messages, + [ + {"role": "system"}, + {"role": "developer", "tool_definitions": ["get_weather"]}, + {"role": "user", "content": messages[0]["content"]}, + ], + ) + + # Test the Chat Completion response for the first turn's output + reasoning_str = "I'll call get_weather." + tool_args_str = '{"location": "Paris"}' + response_str = ( + f"<|channel|>analysis<|message|>{reasoning_str}<|end|>" + "<|start|>assistant to=functions.get_weather<|channel|>commentary" + f"<|constrain|>json<|message|>{tool_args_str}<|call|>" + ) + response = await self.generate_response_from_harmony_str( + serving_chat, req, response_str, stream=stream + ) + verify_chat_response( + response, + reasoning=reasoning_str, + tool_calls=[("get_weather", tool_args_str)], + ) + + tool_call = response.choices[0].message.tool_calls[0] + + # Add the output messages from the first turn as input to the second turn + for choice in response.choices: + messages.append(choice.message.model_dump(exclude_none=True)) + + # Add our tool output message + messages.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": "20 degrees Celsius", + }, + ) + + # Test the Harmony messages for the second turn's input + req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages) + input_messages_2, _ = serving_chat._make_request_with_harmony(req_2) + verify_harmony_messages( + input_messages_2, + [ + {"role": "system"}, + {"role": "developer"}, + {"role": "user"}, + { + "role": "assistant", + "channel": "analysis", + "content": reasoning_str, + }, + { + "role": "assistant", + "channel": "commentary", + "recipient": "functions.get_weather", + "content": tool_args_str, + }, + { + "role": "tool", + "author_name": "functions.get_weather", + "channel": "commentary", + "recipient": "assistant", + "content": "20 degrees Celsius", + }, + ], + ) + + @pytest.mark.asyncio + async def test_multi_turn_tools_and_reasoning( + self, serving_chat, stream, weather_tools, weather_messages_start + ): + tools = weather_tools + messages = list(weather_messages_start) + + # Test the Harmony messages for the first turn's input + req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools) + input_messages, _ = serving_chat._make_request_with_harmony(req) + verify_harmony_messages( + input_messages, + [ + {"role": "system"}, + {"role": "developer", "tool_definitions": ["get_weather"]}, + {"role": "user", "content": messages[0]["content"]}, + ], + ) + + # Test the Chat Completion response for the first turn's output + reasoning_str = "I'll call get_weather." + paris_tool_args_str = '{"location": "Paris"}' + response_str = ( + f"<|channel|>analysis<|message|>{reasoning_str}<|end|>" + "<|start|>assistant to=functions.get_weather<|channel|>commentary" + f"<|constrain|>json<|message|>{paris_tool_args_str}<|call|>" + ) + response = await self.generate_response_from_harmony_str( + serving_chat, req, response_str, stream=stream + ) + verify_chat_response( + response, + reasoning=reasoning_str, + tool_calls=[("get_weather", paris_tool_args_str)], + ) + + tool_call = response.choices[0].message.tool_calls[0] + + # Add the output messages from the first turn as input to the second turn + for choice in response.choices: + messages.append(choice.message.model_dump(exclude_none=True)) + + # Add our tool output message + messages.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": "20 degrees Celsius", + }, + ) + + # Test the Harmony messages for the second turn's input + req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages) + input_messages_2, _ = serving_chat._make_request_with_harmony(req_2) + verify_harmony_messages( + input_messages_2, + [ + {"role": "system"}, + {"role": "developer"}, + {"role": "user"}, + { + "role": "assistant", + "channel": "analysis", + "content": reasoning_str, + }, + { + "role": "assistant", + "channel": "commentary", + "recipient": "functions.get_weather", + "content": paris_tool_args_str, + }, + { + "role": "tool", + "author_name": "functions.get_weather", + "channel": "commentary", + "recipient": "assistant", + "content": "20 degrees Celsius", + }, + ], + ) + + # Test the Chat Completion response for the second turn's output + paris_weather_str = "The weather in Paris today is 20 degrees Celsius." + response_str = f"<|channel|>final<|message|>{paris_weather_str}<|end|>" + response_2 = await self.generate_response_from_harmony_str( + serving_chat, req_2, response_str, stream=stream + ) + verify_chat_response(response_2, content=paris_weather_str) + + # Add the output messages from the second turn as input to the third turn + for choice in response_2.choices: + messages.append(choice.message.model_dump(exclude_none=True)) + + # Add a new user message for the third turn + messages.append( + { + "role": "user", + "content": "What's the weather like in Boston today?", + }, + ) + + # Test the Harmony messages for the third turn's input + req_3 = ChatCompletionRequest(model=MODEL_NAME, messages=messages) + input_messages_3, _ = serving_chat._make_request_with_harmony(req_3) + verify_harmony_messages( + input_messages_3, + [ + {"role": "system"}, + {"role": "developer"}, + {"role": "user"}, + { + "role": "assistant", + "channel": "commentary", + "recipient": "functions.get_weather", + "content": paris_tool_args_str, + }, + { + "role": "tool", + "author_name": "functions.get_weather", + "channel": "commentary", + "recipient": "assistant", + "content": "20 degrees Celsius", + }, + { + "role": "assistant", + "channel": "final", + "content": paris_weather_str, + }, + {"role": "user", "content": messages[-1]["content"]}, + ], + ) + + # Test the Chat Completion response for the third turn's output + reasoning_str = "I'll call get_weather." + boston_tool_args_str = '{"location": "Boston"}' + response_str = ( + f"<|channel|>analysis<|message|>{reasoning_str}<|end|>" + "<|start|>assistant to=functions.get_weather<|channel|>commentary" + f"<|constrain|>json<|message|>{boston_tool_args_str}<|call|>" + ) + response_3 = await self.generate_response_from_harmony_str( + serving_chat, req, response_str, stream=stream + ) + verify_chat_response( + response_3, + reasoning=reasoning_str, + tool_calls=[("get_weather", boston_tool_args_str)], + ) + + tool_call = response_3.choices[0].message.tool_calls[0] + + # Add the output messages from the third turn as input to the fourth turn + for choice in response_3.choices: + messages.append(choice.message.model_dump(exclude_none=True)) + + # Add our tool output message + messages.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": "10 degrees Celsius", + }, + ) + + # Test the Harmony messages for the fourth turn's input + req_4 = ChatCompletionRequest(model=MODEL_NAME, messages=messages) + input_messages_4, _ = serving_chat._make_request_with_harmony(req_4) + verify_harmony_messages( + input_messages_4, + [ + {"role": "system"}, + {"role": "developer"}, + {"role": "user"}, + {"role": "assistant"}, + {"role": "tool"}, + { + "role": "assistant", + "channel": "final", + }, + {"role": "user"}, + { + "role": "assistant", + "channel": "analysis", + "content": reasoning_str, + }, + { + "role": "assistant", + "channel": "commentary", + "recipient": "functions.get_weather", + "content": boston_tool_args_str, + }, + { + "role": "tool", + "author_name": "functions.get_weather", + "channel": "commentary", + "recipient": "assistant", + "content": "10 degrees Celsius", + }, + ], + ) + + @pytest.mark.asyncio + async def test_non_tool_reasoning(self, serving_chat): + messages: list[dict[str, Any]] = [ + { + "role": "user", + "content": "What's 2+2?", + }, + { + "role": "assistant", + "reasoning": "Adding 2 and 2 is easy. The result is 4.", + "content": "4", + }, + ] + req = ChatCompletionRequest(model=MODEL_NAME, messages=messages) + input_messages, _ = serving_chat._make_request_with_harmony(req) + + verify_harmony_messages( + input_messages, + [ + {"role": "system"}, + {"role": "developer"}, + {"role": "user", "content": messages[0]["content"]}, + # The reasoning that would have resulted in an analysis message is + # dropped because of a later assistant message to the final channel. + { + "role": "assistant", + "channel": "final", + "content": messages[1]["content"], + }, + ], + ) + + @pytest.mark.asyncio + async def test_non_tool_reasoning_empty_content(self, serving_chat): + messages: list[dict[str, Any]] = [ + { + "role": "user", + "content": "What's 2+2?", + }, + { + "role": "assistant", + "reasoning": "Adding 2 and 2 is easy. The result is 4.", + "content": "", + }, + ] + req = ChatCompletionRequest(model=MODEL_NAME, messages=messages) + input_messages, _ = serving_chat._make_request_with_harmony(req) + + verify_harmony_messages( + input_messages, + [ + {"role": "system"}, + {"role": "developer"}, + {"role": "user", "content": messages[0]["content"]}, + { + "role": "assistant", + "channel": "analysis", + "content": messages[1]["reasoning"], + }, + ], + ) + + @pytest.mark.asyncio + async def test_non_tool_reasoning_empty_content_list(self, serving_chat): + messages: list[dict[str, Any]] = [ + { + "role": "user", + "content": "What's 2+2?", + }, + { + "role": "assistant", + "reasoning": "Adding 2 and 2 is easy. The result is 4.", + "content": [], + }, + ] + req = ChatCompletionRequest(model=MODEL_NAME, messages=messages) + input_messages, _ = serving_chat._make_request_with_harmony(req) + + verify_harmony_messages( + input_messages, + [ + {"role": "system"}, + {"role": "developer"}, + {"role": "user", "content": messages[0]["content"]}, + { + "role": "assistant", + "channel": "analysis", + "content": messages[1]["reasoning"], + }, + ], + ) diff --git a/tests/entrypoints/openai/test_serving_engine.py b/tests/entrypoints/openai/test_serving_engine.py index 956a06dc5487c..192c7cafb7493 100644 --- a/tests/entrypoints/openai/test_serving_engine.py +++ b/tests/entrypoints/openai/test_serving_engine.py @@ -10,7 +10,7 @@ import pytest from vllm.config import ModelConfig from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.tokenizers import MistralTokenizer +from vllm.tokenizers.mistral import MistralTokenizer @pytest.fixture() diff --git a/tests/entrypoints/openai/test_serving_responses.py b/tests/entrypoints/openai/test_serving_responses.py index cf00f0a042241..7d03dccec30de 100644 --- a/tests/entrypoints/openai/test_serving_responses.py +++ b/tests/entrypoints/openai/test_serving_responses.py @@ -21,7 +21,7 @@ from vllm.entrypoints.openai.serving_responses import ( extract_tool_types, ) from vllm.entrypoints.tool_server import ToolServer -from vllm.inputs.data import TokensPrompt as EngineTokensPrompt +from vllm.inputs.data import TokensPrompt class MockConversationContext(ConversationContext): @@ -237,7 +237,7 @@ class TestValidateGeneratorInput: """Test _validate_generator_input with valid prompt length""" # Create an engine prompt with valid length (less than max_model_len) valid_prompt_token_ids = list(range(5)) # 5 tokens < 100 max_model_len - engine_prompt = EngineTokensPrompt(prompt_token_ids=valid_prompt_token_ids) + engine_prompt = TokensPrompt(prompt_token_ids=valid_prompt_token_ids) # Call the method result = serving_responses_instance._validate_generator_input(engine_prompt) @@ -247,7 +247,7 @@ class TestValidateGeneratorInput: # create an invalid engine prompt invalid_prompt_token_ids = list(range(200)) # 100 tokens >= 100 max_model_len - engine_prompt = EngineTokensPrompt(prompt_token_ids=invalid_prompt_token_ids) + engine_prompt = TokensPrompt(prompt_token_ids=invalid_prompt_token_ids) # Call the method result = serving_responses_instance._validate_generator_input(engine_prompt) diff --git a/tests/entrypoints/openai/test_sparse_tensor_validation.py b/tests/entrypoints/openai/test_sparse_tensor_validation.py new file mode 100644 index 0000000000000..907c82b57dead --- /dev/null +++ b/tests/entrypoints/openai/test_sparse_tensor_validation.py @@ -0,0 +1,342 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Sparse tensor validation in embedding APIs. + +Tests verify that malicious sparse tensors are rejected before they can trigger +out-of-bounds memory writes during to_dense() operations. +""" + +import base64 +import io + +import pytest +import torch + +from vllm.entrypoints.renderer import CompletionRenderer +from vllm.multimodal.audio import AudioEmbeddingMediaIO +from vllm.multimodal.image import ImageEmbeddingMediaIO + + +def _encode_tensor(tensor: torch.Tensor) -> bytes: + """Helper to encode a tensor as base64 bytes.""" + buffer = io.BytesIO() + torch.save(tensor, buffer) + buffer.seek(0) + return base64.b64encode(buffer.read()) + + +def _create_malicious_sparse_tensor() -> torch.Tensor: + """ + Create a malicious sparse COO tensor with out-of-bounds indices. + + This tensor has indices that point beyond the declared shape, which would + cause an out-of-bounds write when converted to dense format without + validation. + """ + # Create a 3x3 sparse tensor but with indices pointing to (10, 10) + indices = torch.tensor([[10], [10]]) # Out of bounds for 3x3 shape + values = torch.tensor([1.0]) + shape = (3, 3) + + # Create sparse tensor (this will be invalid) + sparse_tensor = torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float32) + return sparse_tensor + + +def _create_valid_sparse_tensor() -> torch.Tensor: + """Create a valid sparse COO tensor for baseline testing.""" + indices = torch.tensor([[0, 1, 2], [0, 1, 2]]) + values = torch.tensor([1.0, 2.0, 3.0]) + shape = (3, 3) + + sparse_tensor = torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float32) + return sparse_tensor + + +def _create_valid_dense_tensor() -> torch.Tensor: + """Create a valid dense tensor for baseline testing.""" + return torch.randn(10, 768, dtype=torch.float32) # (seq_len, hidden_size) + + +class TestPromptEmbedsValidation: + """Test sparse tensor validation in prompt embeddings (Completions API).""" + + def test_valid_dense_tensor_accepted(self, model_config): + """Baseline: Valid dense tensors should work normally.""" + renderer = CompletionRenderer(model_config) + + valid_tensor = _create_valid_dense_tensor() + encoded = _encode_tensor(valid_tensor) + + # Should not raise any exception + result = renderer.load_prompt_embeds(encoded) + assert len(result) == 1 + assert result[0]["prompt_embeds"].shape == valid_tensor.shape + + def test_valid_sparse_tensor_accepted(self): + """Baseline: Valid sparse tensors should load successfully.""" + io_handler = ImageEmbeddingMediaIO() + + valid_sparse = _create_valid_sparse_tensor() + encoded = _encode_tensor(valid_sparse) + + # Should not raise any exception (sparse tensors remain sparse) + result = io_handler.load_base64("", encoded.decode("utf-8")) + assert result.shape == valid_sparse.shape + + def test_malicious_sparse_tensor_rejected(self, model_config): + """Security: Malicious sparse tensors should be rejected.""" + renderer = CompletionRenderer(model_config) + + malicious_tensor = _create_malicious_sparse_tensor() + encoded = _encode_tensor(malicious_tensor) + + # Should raise RuntimeError due to invalid sparse tensor + with pytest.raises((RuntimeError, ValueError)) as exc_info: + renderer.load_prompt_embeds(encoded) + + # Error should indicate sparse tensor validation failure + error_msg = str(exc_info.value).lower() + assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg + + def test_extremely_large_indices_rejected(self, model_config): + """Security: Sparse tensors with extremely large indices should be rejected.""" + renderer = CompletionRenderer(model_config) + + # Create tensor with indices far beyond reasonable bounds + indices = torch.tensor([[999999], [999999]]) + values = torch.tensor([1.0]) + shape = (10, 10) + + malicious_tensor = torch.sparse_coo_tensor( + indices, values, shape, dtype=torch.float32 + ) + encoded = _encode_tensor(malicious_tensor) + + with pytest.raises((RuntimeError, ValueError)): + renderer.load_prompt_embeds(encoded) + + def test_negative_indices_rejected(self, model_config): + """Security: Sparse tensors with negative indices should be rejected.""" + renderer = CompletionRenderer(model_config) + + # Create tensor with negative indices + indices = torch.tensor([[-1], [-1]]) + values = torch.tensor([1.0]) + shape = (10, 10) + + malicious_tensor = torch.sparse_coo_tensor( + indices, values, shape, dtype=torch.float32 + ) + encoded = _encode_tensor(malicious_tensor) + + with pytest.raises((RuntimeError, ValueError)): + renderer.load_prompt_embeds(encoded) + + +class TestImageEmbedsValidation: + """Test sparse tensor validation in image embeddings (Chat API).""" + + def test_valid_dense_tensor_accepted(self): + """Baseline: Valid dense tensors should work normally.""" + io_handler = ImageEmbeddingMediaIO() + + valid_tensor = _create_valid_dense_tensor() + encoded = _encode_tensor(valid_tensor) + + # Should not raise any exception + result = io_handler.load_base64("", encoded.decode("utf-8")) + assert result.shape == valid_tensor.shape + + def test_valid_sparse_tensor_accepted(self): + """Baseline: Valid sparse tensors should load successfully.""" + io_handler = AudioEmbeddingMediaIO() + + valid_sparse = _create_valid_sparse_tensor() + encoded = _encode_tensor(valid_sparse) + + # Should not raise any exception (sparse tensors remain sparse) + result = io_handler.load_base64("", encoded.decode("utf-8")) + assert result.shape == valid_sparse.shape + + def test_malicious_sparse_tensor_rejected(self): + """Security: Malicious sparse tensors should be rejected.""" + io_handler = ImageEmbeddingMediaIO() + + malicious_tensor = _create_malicious_sparse_tensor() + encoded = _encode_tensor(malicious_tensor) + + # Should raise RuntimeError due to invalid sparse tensor + with pytest.raises((RuntimeError, ValueError)) as exc_info: + io_handler.load_base64("", encoded.decode("utf-8")) + + error_msg = str(exc_info.value).lower() + assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg + + def test_load_bytes_validates(self): + """Security: Validation should also work for load_bytes method.""" + io_handler = ImageEmbeddingMediaIO() + + malicious_tensor = _create_malicious_sparse_tensor() + buffer = io.BytesIO() + torch.save(malicious_tensor, buffer) + buffer.seek(0) + + with pytest.raises((RuntimeError, ValueError)): + io_handler.load_bytes(buffer.read()) + + +class TestAudioEmbedsValidation: + """Test sparse tensor validation in audio embeddings (Chat API).""" + + def test_valid_dense_tensor_accepted(self): + """Baseline: Valid dense tensors should work normally.""" + io_handler = AudioEmbeddingMediaIO() + + valid_tensor = _create_valid_dense_tensor() + encoded = _encode_tensor(valid_tensor) + + # Should not raise any exception + result = io_handler.load_base64("", encoded.decode("utf-8")) + assert result.shape == valid_tensor.shape + + def test_valid_sparse_tensor_accepted(self): + """Baseline: Valid sparse tensors should be converted successfully.""" + io_handler = AudioEmbeddingMediaIO() + + valid_sparse = _create_valid_sparse_tensor() + encoded = _encode_tensor(valid_sparse) + + # Should not raise any exception + result = io_handler.load_base64("", encoded.decode("utf-8")) + assert result.is_sparse is False + + def test_malicious_sparse_tensor_rejected(self): + """Security: Malicious sparse tensors should be rejected.""" + io_handler = AudioEmbeddingMediaIO() + + malicious_tensor = _create_malicious_sparse_tensor() + encoded = _encode_tensor(malicious_tensor) + + # Should raise RuntimeError due to invalid sparse tensor + with pytest.raises((RuntimeError, ValueError)) as exc_info: + io_handler.load_base64("", encoded.decode("utf-8")) + + error_msg = str(exc_info.value).lower() + assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg + + def test_load_bytes_validates(self): + """Security: Validation should also work for load_bytes method.""" + io_handler = AudioEmbeddingMediaIO() + + malicious_tensor = _create_malicious_sparse_tensor() + buffer = io.BytesIO() + torch.save(malicious_tensor, buffer) + buffer.seek(0) + + with pytest.raises((RuntimeError, ValueError)): + io_handler.load_bytes(buffer.read()) + + +class TestSparseTensorValidationIntegration: + """ + These tests verify the complete attack chain is blocked at all entry points. + """ + + def test_attack_scenario_completions_api(self, model_config): + """ + Simulate a complete attack through the Completions API. + + Attack scenario: + 1. Attacker crafts malicious sparse tensor + 2. Encodes it as base64 + 3. Sends to /v1/completions with prompt_embeds parameter + 4. Server should reject before memory corruption occurs + """ + renderer = CompletionRenderer(model_config) + + # Step 1-2: Attacker creates malicious payload + attack_payload = _encode_tensor(_create_malicious_sparse_tensor()) + + # Step 3-4: Server processes and should reject + with pytest.raises((RuntimeError, ValueError)): + renderer.load_prompt_embeds(attack_payload) + + def test_attack_scenario_chat_api_image(self): + """ + Simulate attack through Chat API with image_embeds. + + Verifies the image embeddings path is protected. + """ + io_handler = ImageEmbeddingMediaIO() + attack_payload = _encode_tensor(_create_malicious_sparse_tensor()) + + with pytest.raises((RuntimeError, ValueError)): + io_handler.load_base64("", attack_payload.decode("utf-8")) + + def test_attack_scenario_chat_api_audio(self): + """ + Simulate attack through Chat API with audio_embeds. + + Verifies the audio embeddings path is protected. + """ + io_handler = AudioEmbeddingMediaIO() + attack_payload = _encode_tensor(_create_malicious_sparse_tensor()) + + with pytest.raises((RuntimeError, ValueError)): + io_handler.load_base64("", attack_payload.decode("utf-8")) + + def test_multiple_valid_embeddings_in_batch(self, model_config): + """ + Regression test: Multiple valid embeddings should still work. + + Ensures the fix doesn't break legitimate batch processing. + """ + renderer = CompletionRenderer(model_config) + + valid_tensors = [ + _encode_tensor(_create_valid_dense_tensor()), + _encode_tensor(_create_valid_dense_tensor()), + _encode_tensor(_create_valid_dense_tensor()), + ] + + # Should process all without error + result = renderer.load_prompt_embeds(valid_tensors) + assert len(result) == 3 + + def test_mixed_valid_and_malicious_rejected(self, model_config): + """ + Security: Batch with one malicious tensor should be rejected. + + Even if most tensors are valid, a single malicious one should + cause rejection of the entire batch. + """ + renderer = CompletionRenderer(model_config) + + mixed_batch = [ + _encode_tensor(_create_valid_dense_tensor()), + _encode_tensor(_create_malicious_sparse_tensor()), # Malicious + _encode_tensor(_create_valid_dense_tensor()), + ] + + # Should fail on the malicious tensor + with pytest.raises((RuntimeError, ValueError)): + renderer.load_prompt_embeds(mixed_batch) + + +# Pytest fixtures +@pytest.fixture +def model_config(): + """Mock ModelConfig for testing.""" + from vllm.config import ModelConfig + + return ModelConfig( + model="facebook/opt-125m", + tokenizer="facebook/opt-125m", + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float32", + seed=0, + enable_prompt_embeds=True, # Required for prompt embeds tests + ) diff --git a/tests/entrypoints/openai/test_token_in_token_out.py b/tests/entrypoints/openai/test_token_in_token_out.py index 25eb5882be89c..c7f8abe27e6e0 100644 --- a/tests/entrypoints/openai/test_token_in_token_out.py +++ b/tests/entrypoints/openai/test_token_in_token_out.py @@ -7,7 +7,7 @@ import tempfile import pytest from vllm.model_executor.model_loader.weight_utils import download_weights_from_hf -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.tokenizers import get_tokenizer from ...utils import RemoteOpenAIServer diff --git a/tests/entrypoints/openai/test_tokenization.py b/tests/entrypoints/openai/test_tokenization.py index 751f94319eb9f..052f9fecc18de 100644 --- a/tests/entrypoints/openai/test_tokenization.py +++ b/tests/entrypoints/openai/test_tokenization.py @@ -5,7 +5,7 @@ import pytest import pytest_asyncio import requests -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.tokenizers import get_tokenizer from ...utils import RemoteOpenAIServer diff --git a/tests/entrypoints/openai/test_transcription_validation_whisper.py b/tests/entrypoints/openai/test_transcription_validation_whisper.py index 82c50e58a0168..3c507ee0a3fa7 100644 --- a/tests/entrypoints/openai/test_transcription_validation_whisper.py +++ b/tests/entrypoints/openai/test_transcription_validation_whisper.py @@ -32,24 +32,20 @@ async def whisper_client(server): @pytest.mark.asyncio -async def test_basic_audio(mary_had_lamb): - server_args = ["--enforce-eager"] - +async def test_basic_audio(whisper_client, mary_had_lamb): # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb. - with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server: - client = remote_server.get_async_client() - transcription = await client.audio.transcriptions.create( - model=MODEL_NAME, - file=mary_had_lamb, - language="en", - response_format="text", - temperature=0.0, - ) - out = json.loads(transcription) - out_text = out["text"] - out_usage = out["usage"] - assert "Mary had a little lamb," in out_text - assert out_usage["seconds"] == 16, out_usage["seconds"] + transcription = await whisper_client.audio.transcriptions.create( + model=MODEL_NAME, + file=mary_had_lamb, + language="en", + response_format="text", + temperature=0.0, + ) + out = json.loads(transcription) + out_text = out["text"] + out_usage = out["usage"] + assert "Mary had a little lamb," in out_text + assert out_usage["seconds"] == 16, out_usage["seconds"] @pytest.mark.asyncio @@ -235,3 +231,16 @@ async def test_audio_prompt(mary_had_lamb, whisper_client): ) out_prompt = json.loads(transcription_wprompt)["text"] assert prefix in out_prompt + + +@pytest.mark.asyncio +async def test_audio_with_timestamp(mary_had_lamb, whisper_client): + transcription = await whisper_client.audio.transcriptions.create( + model=MODEL_NAME, + file=mary_had_lamb, + language="en", + response_format="verbose_json", + temperature=0.0, + ) + assert transcription.segments is not None + assert len(transcription.segments) > 0 diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index d83c6726e72da..ae8860ee877b4 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -8,6 +8,7 @@ import pytest import pytest_asyncio from transformers import AutoProcessor +from vllm.multimodal.base import MediaWithBytes from vllm.multimodal.utils import encode_image_base64, fetch_image from ...utils import RemoteOpenAIServer @@ -111,7 +112,11 @@ def get_hf_prompt_tokens(model_name, content, image_url): "content": f"{placeholder}{content}", } ] - images = [fetch_image(image_url)] + image = fetch_image(image_url) + # Unwrap MediaWithBytes if present + if isinstance(image, MediaWithBytes): + image = image.media + images = [image] prompt = processor.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True diff --git a/tests/entrypoints/openai/test_vision_embeds.py b/tests/entrypoints/openai/test_vision_embeds.py index a6593c5b05e2e..42d9fe4840bbe 100644 --- a/tests/entrypoints/openai/test_vision_embeds.py +++ b/tests/entrypoints/openai/test_vision_embeds.py @@ -2,64 +2,47 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import base64 -import io import numpy as np import pytest import requests import torch +from vllm.utils.serial_utils import tensor2base64 + from ...utils import RemoteOpenAIServer -MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11" -DTYPE = "float16" - -def _terratorch_dummy_inputs(model_name: str): +def _terratorch_dummy_messages(): pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16) location_coords = torch.full((1, 2), 1.0, dtype=torch.float16) - buffer_tiff = io.BytesIO() - torch.save(pixel_values, buffer_tiff) - buffer_tiff.seek(0) - binary_data = buffer_tiff.read() - base64_tensor_embedding = base64.b64encode(binary_data).decode("utf-8") - - buffer_coord = io.BytesIO() - torch.save(location_coords, buffer_coord) - buffer_coord.seek(0) - binary_data = buffer_coord.read() - base64_coord_embedding = base64.b64encode(binary_data).decode("utf-8") - - return { - "model": model_name, - "additional_data": {"prompt_token_ids": [1]}, - "encoding_format": "base64", - "messages": [ - { - "role": "user", - "content": [ - { - "type": "image_embeds", - "image_embeds": { - "pixel_values": base64_tensor_embedding, - "location_coords": base64_coord_embedding, - }, - } - ], - } - ], - } + return [ + { + "role": "user", + "content": [ + { + "type": "image_embeds", + "image_embeds": { + "pixel_values": tensor2base64(pixel_values), + "location_coords": tensor2base64(location_coords), + }, + } + ], + } + ] -@pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_single_request(model_name: str): +@pytest.mark.parametrize( + "model_name", ["ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"] +) +def test_single_request(model_name: str): args = [ "--runner", "pooling", # use half precision for speed and memory savings in CI environment "--dtype", - DTYPE, + "float16", "--enforce-eager", "--trust-remote-code", "--max-num-seqs", @@ -70,11 +53,15 @@ async def test_single_request(model_name: str): "--enable-mm-embeds", ] - with RemoteOpenAIServer(MODEL_NAME, args) as server: - prompt = _terratorch_dummy_inputs(model_name) - - # test single pooling - response = requests.post(server.url_for("pooling"), json=prompt) + with RemoteOpenAIServer(model_name, args) as server: + response = requests.post( + server.url_for("pooling"), + json={ + "model": model_name, + "messages": _terratorch_dummy_messages(), + "encoding_format": "base64", + }, + ) response.raise_for_status() output = response.json()["data"][0]["data"] diff --git a/tests/entrypoints/openai/tool_parsers/test_gigachat3_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_gigachat3_tool_parser.py new file mode 100644 index 0000000000000..6ac48317e8bc6 --- /dev/null +++ b/tests/entrypoints/openai/tool_parsers/test_gigachat3_tool_parser.py @@ -0,0 +1,176 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json + +import pytest + +from tests.entrypoints.openai.tool_parsers.utils import ( + run_tool_extraction, + run_tool_extraction_streaming, +) +from vllm.entrypoints.openai.protocol import FunctionCall +from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers import ToolParser, ToolParserManager + +SIMPLE_ARGS_DICT = { + "action": "create", + "id": "preferences", +} +SIMPLE_FUNCTION_JSON = json.dumps( + { + "name": "manage_user_memory", + "arguments": SIMPLE_ARGS_DICT, + }, + ensure_ascii=False, +) +SIMPLE_FUNCTION_OUTPUT = "function call" + SIMPLE_FUNCTION_JSON +SIMPLE_FUNCTION_CALL = FunctionCall( + name="manage_user_memory", + arguments=json.dumps(SIMPLE_ARGS_DICT, ensure_ascii=False), +) + + +PARAMETERLESS_FUNCTION_JSON = json.dumps( + { + "name": "manage_user_memory", + "arguments": {}, + }, + ensure_ascii=False, +) +PARAMETERLESS_FUNCTION_OUTPUT = "function call" + PARAMETERLESS_FUNCTION_JSON +PARAMETERLESS_FUNCTION_CALL = FunctionCall( + name="manage_user_memory", + arguments=json.dumps({}, ensure_ascii=False), +) + + +COMPLEX_ARGS_DICT = { + "action": "create", + "id": "preferences", + "content": { + "short_answers": True, + "hate_emojis": True, + "english_ui": False, + "russian_math_explanations": True, + }, +} +COMPLEX_FUNCTION_JSON = json.dumps( + { + "name": "manage_user_memory", + "arguments": COMPLEX_ARGS_DICT, + }, + ensure_ascii=False, +) +COMPLEX_FUNCTION_OUTPUT = "function call" + COMPLEX_FUNCTION_JSON +COMPLEX_FUNCTION_CALL = FunctionCall( + name="manage_user_memory", + arguments=json.dumps(COMPLEX_ARGS_DICT, ensure_ascii=False), +) + + +@pytest.mark.parametrize("streaming", [True, False]) +def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike): + tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")( + default_tokenizer + ) + model_output = "How can I help you today?" + content, tool_calls = run_tool_extraction( + tool_parser, model_output, streaming=streaming + ) + assert content == model_output + assert len(tool_calls) == 0 + + +TEST_CASES = [ + pytest.param( + True, + SIMPLE_FUNCTION_OUTPUT, + [SIMPLE_FUNCTION_CALL], + None, + id="simple_streaming", + ), + pytest.param( + False, + SIMPLE_FUNCTION_OUTPUT, + [SIMPLE_FUNCTION_CALL], + None, + id="simple_nonstreaming", + ), + pytest.param( + True, + PARAMETERLESS_FUNCTION_OUTPUT, + [PARAMETERLESS_FUNCTION_CALL], + None, + id="parameterless_streaming", + ), + pytest.param( + False, + PARAMETERLESS_FUNCTION_OUTPUT, + [PARAMETERLESS_FUNCTION_CALL], + None, + id="parameterless_nonstreaming", + ), + pytest.param( + True, + COMPLEX_FUNCTION_OUTPUT, + [COMPLEX_FUNCTION_CALL], + None, + id="complex_streaming", + ), + pytest.param( + False, + COMPLEX_FUNCTION_OUTPUT, + [COMPLEX_FUNCTION_CALL], + None, + id="complex_nonstreaming", + ), +] + + +@pytest.mark.parametrize( + "streaming, model_output, expected_tool_calls, expected_content", TEST_CASES +) +def test_tool_call( + streaming: bool, + model_output: str, + expected_tool_calls: list[FunctionCall], + expected_content: str | None, + default_tokenizer: TokenizerLike, +): + tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")( + default_tokenizer + ) + content, tool_calls = run_tool_extraction( + tool_parser, model_output, streaming=streaming + ) + assert content == expected_content + assert len(tool_calls) == len(expected_tool_calls) + for actual, expected in zip(tool_calls, expected_tool_calls): + assert actual.type == "function" + assert actual.function.name == expected.name + actual_args = json.loads(actual.function.arguments) + expected_args = json.loads(expected.arguments) + assert actual_args == expected_args + + +def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike): + tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")( + default_tokenizer + ) + model_output_deltas = [ + "function call", + COMPLEX_FUNCTION_JSON[:40], + COMPLEX_FUNCTION_JSON[40:], + ] + reconstructor = run_tool_extraction_streaming( + tool_parser, + model_output_deltas, + assert_one_tool_per_delta=False, + ) + assert len(reconstructor.tool_calls) == 1 + call = reconstructor.tool_calls[0] + assert call.type == "function" + assert call.function.name == "manage_user_memory" + args_dict = json.loads(call.function.arguments) + assert args_dict == COMPLEX_ARGS_DICT diff --git a/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py index b2303ab0e7b7c..8600aaf639431 100644 --- a/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py @@ -6,8 +6,8 @@ import json import pytest from vllm.entrypoints.openai.protocol import ChatCompletionRequest -from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser from ....utils import RemoteOpenAIServer @@ -271,7 +271,7 @@ async def test_streaming_product_tool_call(): @pytest.fixture def qwen_tokenizer() -> TokenizerLike: - from vllm.transformers_utils.tokenizer import get_tokenizer + from vllm.tokenizers import get_tokenizer return get_tokenizer("Qwen/Qwen3-32B") diff --git a/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py index bdd5344652c4b..3944575321391 100644 --- a/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py @@ -12,7 +12,7 @@ from tests.entrypoints.openai.tool_parsers.utils import ( run_tool_extraction_streaming, ) from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall -from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager +from vllm.tool_parsers import ToolParser, ToolParserManager def make_tool_call(name, arguments): diff --git a/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py index 6c286ca90ce48..3ce7801b45975 100644 --- a/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py @@ -6,8 +6,8 @@ from unittest.mock import MagicMock, patch import pytest from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation -from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import Llama3JsonToolParser from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.llama_tool_parser import Llama3JsonToolParser @pytest.fixture diff --git a/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py index 8aa88a007188f..3bd1ca7f528d0 100644 --- a/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py @@ -10,8 +10,8 @@ from tests.entrypoints.openai.tool_parsers.utils import ( run_tool_extraction_streaming, ) from vllm.entrypoints.openai.protocol import FunctionCall -from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers import ToolParser, ToolParserManager # Test cases similar to pythonic parser but with Llama4 specific format SIMPLE_FUNCTION_OUTPUT = "[get_weather(city='LA', metric='C')]" diff --git a/tests/entrypoints/openai/tool_parsers/test_olmo3_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_olmo3_tool_parser.py index a0b9a3c563bc2..3774b3d1833e9 100644 --- a/tests/entrypoints/openai/tool_parsers/test_olmo3_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_olmo3_tool_parser.py @@ -10,8 +10,8 @@ from tests.entrypoints.openai.tool_parsers.utils import ( run_tool_extraction_streaming, ) from vllm.entrypoints.openai.protocol import FunctionCall -from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers import ToolParser, ToolParserManager # https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1 SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')" diff --git a/tests/entrypoints/openai/tool_parsers/test_openai_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_openai_tool_parser.py new file mode 100644 index 0000000000000..7cb87fd13ecfa --- /dev/null +++ b/tests/entrypoints/openai/tool_parsers/test_openai_tool_parser.py @@ -0,0 +1,359 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json + +import jsonschema +import openai +import pytest +import pytest_asyncio +from rapidfuzz import fuzz + +from ....utils import RemoteOpenAIServer + +MODEL_NAME = "openai/gpt-oss-20b" + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--max-model-len", + "8192", + "--enforce-eager", + "--enable-auto-tool-choice", + "--tool-call-parser", + "openai", + ] + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + """Async fixture providing an OpenAI-compatible vLLM client.""" + async with server.get_async_client() as async_client: + yield async_client + + +# ========================================================== +# Tool Definitions +# ========================================================== +TOOLS = [ + { + "type": "function", + "function": { + "name": "calculator", + "description": "Performs basic arithmetic calculations.", + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": ( + "Arithmetic expression to evaluate, e.g. '123 + 456'." + ), + } + }, + "required": ["expression"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_time", + "description": "Retrieves the current local time for a given city.", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name, e.g. 'New York'.", + } + }, + "required": ["city"], + }, + }, + }, +] + + +# ========================================================== +# Message Examples +# ========================================================== +MESSAGES_CALC = [ + {"role": "user", "content": "Calculate 123 + 456 using the calculator."} +] + +MESSAGES_GET_TIME = [ + {"role": "user", "content": "What is the current time in New York?"} +] + +MESSAGES_MULTIPLE_CALLS = [ + { + "role": "system", + "content": ( + "You can call multiple tools. " + "When using more than one, return single JSON object with tool_calls array" + "containing each tool call with its function name and arguments. " + "Do not output multiple JSON objects separately." + ), + }, + { + "role": "user", + "content": "First, calculate 7 * 8 using the calculator. " + "Then, use get_time to tell me the current time in New York.", + }, +] + +MESSAGES_INVALID_CALL = [ + { + "role": "user", + "content": "Can you help with something, " + "but don’t actually perform any calculation?", + } +] + + +# Expected outputs +FUNC_CALC = "calculator" +FUNC_ARGS_CALC = '{"expression":"123 + 456"}' + +FUNC_TIME = "get_time" +FUNC_ARGS_TIME = '{"city": "New York"}' + + +# ========================================================== +# Utility to extract reasoning and tool calls +# ========================================================== +def extract_reasoning_and_calls(chunks: list) -> tuple[str, list[str], list[str]]: + """ + Extract accumulated reasoning text and tool call arguments + from streaming chunks. + """ + reasoning_content: str = "" + tool_calls: dict[int, dict[str, str]] = {} + + for chunk in chunks: + choice = getattr(chunk.choices[0], "delta", None) + if not choice: + continue + + if hasattr(choice, "reasoning_content") and choice.reasoning_content: + reasoning_content += choice.reasoning_content + + for tc in getattr(choice, "tool_calls", []) or []: + idx = getattr(tc, "index", 0) + tool_entry = tool_calls.setdefault(idx, {"name": "", "arguments": ""}) + + if getattr(tc, "function", None): + func = tc.function + if getattr(func, "name", None): + tool_entry["name"] = func.name + if getattr(func, "arguments", None): + tool_entry["arguments"] += func.arguments + + function_names: list[str] = [v["name"] for _, v in sorted(tool_calls.items())] + arguments: list[str] = [v["arguments"] for _, v in sorted(tool_calls.items())] + + return reasoning_content, arguments, function_names + + +# ========================================================== +# Test Scenarios +# ========================================================== +@pytest.mark.asyncio +async def test_calculator_tool_call_and_argument_accuracy(client: openai.AsyncOpenAI): + """Verify calculator tool call is made and arguments are accurate.""" + + response = await client.chat.completions.create( + model=MODEL_NAME, + messages=MESSAGES_CALC, + tools=TOOLS, + temperature=0.0, + stream=False, + ) + + message = response.choices[0].message + tool_calls = getattr(message, "tool_calls", []) + assert tool_calls, "No tool calls detected" + + calc_call = next((c for c in tool_calls if c.function.name == FUNC_CALC), None) + assert calc_call, "Calculator function not called" + + raw_args = calc_call.function.arguments + assert raw_args, "Calculator arguments missing" + assert "123" in raw_args and "456" in raw_args, ( + f"Expected values not in raw arguments: {raw_args}" + ) + + try: + parsed_args = json.loads(raw_args) + except json.JSONDecodeError: + pytest.fail(f"Invalid JSON in calculator arguments: {raw_args}") + + expected_expr = "123 + 456" + actual_expr = parsed_args.get("expression", "") + similarity = fuzz.ratio(actual_expr, expected_expr) + + assert similarity > 90, ( + f"Expression mismatch: expected '{expected_expr}' " + f"got '{actual_expr}' (similarity={similarity}%)" + ) + + +@pytest.mark.asyncio +async def test_streaming_tool_call_get_time_with_reasoning(client: openai.AsyncOpenAI): + """Verify streamed reasoning and tool call behavior for get_time.""" + + stream = await client.chat.completions.create( + model=MODEL_NAME, + messages=MESSAGES_GET_TIME, + tools=TOOLS, + temperature=0.0, + stream=True, + ) + + chunks = [chunk async for chunk in stream] + reasoning, arguments, function_names = extract_reasoning_and_calls(chunks) + + assert FUNC_TIME in function_names, "get_time function not called" + + assert any("New York" in arg for arg in arguments), ( + f"Expected get_time arguments for New York not found in {arguments}" + ) + + assert len(reasoning) > 0, "Expected reasoning content missing" + + assert any(keyword in reasoning for keyword in ["New York", "time", "current"]), ( + f"Reasoning is not relevant to the request: {reasoning}" + ) + + +@pytest.mark.asyncio +async def test_streaming_multiple_tools(client: openai.AsyncOpenAI): + """Test streamed multi-tool response with reasoning.""" + stream = await client.chat.completions.create( + model=MODEL_NAME, + messages=MESSAGES_MULTIPLE_CALLS, + tools=TOOLS, + temperature=0.0, + stream=True, + ) + + chunks = [chunk async for chunk in stream] + reasoning, arguments, function_names = extract_reasoning_and_calls(chunks) + + try: + assert FUNC_CALC in function_names, ( + f"Calculator tool missing — found {function_names}" + ) + assert FUNC_TIME in function_names, ( + f"Time tool missing — found {function_names}" + ) + assert len(reasoning) > 0, "Expected reasoning content in streamed response" + except AssertionError as e: + print(f"ERROR: {e}") + + +@pytest.mark.asyncio +async def test_invalid_tool_call(client: openai.AsyncOpenAI): + """ + Verify that ambiguous instructions that should not trigger a tool + do not produce any tool calls. + """ + response = await client.chat.completions.create( + model=MODEL_NAME, + messages=MESSAGES_INVALID_CALL, + tools=TOOLS, + temperature=0.0, + stream=False, + ) + + message = response.choices[0].message + + assert message is not None, "Expected message in response" + assert hasattr(message, "content"), "Expected 'content' field in message" + + tool_calls = getattr(message, "tool_calls", []) + assert not tool_calls, ( + f"Model unexpectedly attempted a tool call on invalid input: {tool_calls}" + ) + + +@pytest.mark.asyncio +async def test_tool_call_with_temperature(client: openai.AsyncOpenAI): + """ + Verify model produces valid tool or text output + under non-deterministic sampling. + """ + response = await client.chat.completions.create( + model=MODEL_NAME, + messages=MESSAGES_CALC, + tools=TOOLS, + temperature=0.7, + stream=False, + ) + + message = response.choices[0].message + assert message is not None, "Expected non-empty message in response" + assert message.tool_calls or message.content, ( + "Response missing both text and tool calls" + ) + + print(f"\nTool calls: {message.tool_calls}") + print(f"Text: {message.content}") + + +@pytest.mark.asyncio +async def test_tool_response_schema_accuracy(client: openai.AsyncOpenAI): + """Validate that tool call arguments adhere to their declared JSON schema.""" + response = await client.chat.completions.create( + model=MODEL_NAME, + messages=MESSAGES_MULTIPLE_CALLS, + tools=TOOLS, + temperature=0.0, + ) + + calls = response.choices[0].message.tool_calls + assert calls, "No tool calls produced" + + for call in calls: + func_name = call.function.name + args = json.loads(call.function.arguments) + + schema: dict[str, object] | None = None + for tool_entry in TOOLS: + function_def = tool_entry.get("function") + if ( + function_def + and isinstance(function_def, dict) + and function_def.get("name") == func_name + ): + schema = function_def.get("parameters") + break + + assert schema is not None, f"No matching tool schema found for {func_name}" + + jsonschema.validate(instance=args, schema=schema) + + +@pytest.mark.asyncio +async def test_semantic_consistency_with_temperature(client: openai.AsyncOpenAI): + """Test that temperature variation doesn't cause contradictory reasoning.""" + responses = [] + for temp in [0.0, 0.5, 1.0]: + resp = await client.chat.completions.create( + model=MODEL_NAME, + messages=MESSAGES_CALC, + tools=TOOLS, + temperature=temp, + ) + text = (resp.choices[0].message.content or "").strip() + responses.append(text) + + # Compare fuzzy similarity between low- and mid-temperature outputs + low_mid_sim = fuzz.ratio(responses[0], responses[1]) + assert low_mid_sim > 60, ( + f"Semantic drift too large between T=0.0 and T=0.5 ({low_mid_sim}%)" + ) diff --git a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py index 52202c55e8405..c4cad17fd2d01 100644 --- a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py @@ -10,8 +10,8 @@ from tests.entrypoints.openai.tool_parsers.utils import ( run_tool_extraction_streaming, ) from vllm.entrypoints.openai.protocol import FunctionCall -from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers import ToolParser, ToolParserManager # https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1 SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')" diff --git a/tests/entrypoints/openai/tool_parsers/utils.py b/tests/entrypoints/openai/tool_parsers/utils.py index 2d4f5f1734102..0b32e5f899ff4 100644 --- a/tests/entrypoints/openai/tool_parsers/utils.py +++ b/tests/entrypoints/openai/tool_parsers/utils.py @@ -10,8 +10,8 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers import ToolParser from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers import ToolParser class StreamingToolReconstructor: diff --git a/tests/entrypoints/openai/utils.py b/tests/entrypoints/openai/utils.py new file mode 100644 index 0000000000000..501f6dcc91543 --- /dev/null +++ b/tests/entrypoints/openai/utils.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +from collections.abc import AsyncGenerator +from typing import Any + +from vllm.entrypoints.openai.protocol import ( + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatCompletionStreamResponse, + ChatMessage, + UsageInfo, +) + + +async def accumulate_streaming_response( + stream_generator: AsyncGenerator[str, None], +) -> ChatCompletionResponse: + """ + Accumulate streaming SSE chunks into a complete ChatCompletionResponse. + + This helper parses the SSE format and builds up the complete response + by combining all the delta chunks. + """ + accumulated_content = "" + accumulated_reasoning = None + accumulated_tool_calls: list[dict[str, Any]] = [] + role = None + finish_reason = None + response_id = None + created = None + model = None + index = 0 + + async for chunk_str in stream_generator: + # Skip empty lines and [DONE] marker + if not chunk_str.strip() or chunk_str.strip() == "data: [DONE]": + continue + + # Parse SSE format: "data: {json}\n\n" + if chunk_str.startswith("data: "): + json_str = chunk_str[6:].strip() + try: + chunk_data = json.loads(json_str) + # print(f"DEBUG: Parsed chunk_data: {chunk_data}") + chunk = ChatCompletionStreamResponse(**chunk_data) + + # Store metadata from first chunk + if response_id is None: + response_id = chunk.id + created = chunk.created + model = chunk.model + + # Process each choice in the chunk + for choice in chunk.choices: + if choice.delta.role: + role = choice.delta.role + if choice.delta.content: + accumulated_content += choice.delta.content + if choice.delta.reasoning: + if accumulated_reasoning is None: + accumulated_reasoning = "" + accumulated_reasoning += choice.delta.reasoning + if choice.delta.tool_calls: + # Accumulate tool calls + for tool_call_delta in choice.delta.tool_calls: + # Find or create the tool call at this index + while len(accumulated_tool_calls) <= tool_call_delta.index: + accumulated_tool_calls.append( + { + "id": None, + "type": "function", + "function": {"name": "", "arguments": ""}, + } + ) + + if tool_call_delta.id: + accumulated_tool_calls[tool_call_delta.index]["id"] = ( + tool_call_delta.id + ) + if tool_call_delta.function: + if tool_call_delta.function.name: + accumulated_tool_calls[tool_call_delta.index][ + "function" + ]["name"] += tool_call_delta.function.name + if tool_call_delta.function.arguments: + accumulated_tool_calls[tool_call_delta.index][ + "function" + ]["arguments"] += tool_call_delta.function.arguments + + if choice.finish_reason: + finish_reason = choice.finish_reason + if choice.index is not None: + index = choice.index + + except json.JSONDecodeError: + continue + + # Build the final message + message_kwargs = { + "role": role or "assistant", + "content": accumulated_content if accumulated_content else None, + "reasoning": accumulated_reasoning, + } + + # Only include tool_calls if there are any + if accumulated_tool_calls: + message_kwargs["tool_calls"] = [ + {"id": tc["id"], "type": tc["type"], "function": tc["function"]} + for tc in accumulated_tool_calls + ] + + message = ChatMessage(**message_kwargs) + + # Build the final response + choice = ChatCompletionResponseChoice( + index=index, + message=message, + finish_reason=finish_reason or "stop", + ) + + # Create usage info (with dummy values for tests) + usage = UsageInfo( + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + ) + + response = ChatCompletionResponse( + id=response_id or "chatcmpl-test", + object="chat.completion", + created=created or 0, + model=model or "test-model", + choices=[choice], + usage=usage, + ) + + return response + + +def verify_harmony_messages( + messages: list[Any], expected_messages: list[dict[str, Any]] +): + assert len(messages) == len(expected_messages) + for msg, expected in zip(messages, expected_messages): + if "role" in expected: + assert msg.author.role == expected["role"] + if "author_name" in expected: + assert msg.author.name == expected["author_name"] + if "channel" in expected: + assert msg.channel == expected["channel"] + if "recipient" in expected: + assert msg.recipient == expected["recipient"] + if "content" in expected: + assert msg.content[0].text == expected["content"] + if "content_type" in expected: + assert msg.content_type == expected["content_type"] + if "tool_definitions" in expected: + # Check that the tool definitions match the expected list of tool names + actual_tools = [t.name for t in msg.content[0].tools["functions"].tools] + assert actual_tools == expected["tool_definitions"] + + +def verify_chat_response( + response: ChatCompletionResponse, + content: str | None = None, + reasoning: str | None = None, + tool_calls: list[tuple[str, str]] | None = None, +): + assert len(response.choices) == 1 + message = response.choices[0].message + + if content is not None: + assert message.content == content + else: + assert not message.content + + if reasoning is not None: + assert message.reasoning == reasoning + else: + assert not message.reasoning + + if tool_calls: + assert message.tool_calls is not None + assert len(message.tool_calls) == len(tool_calls) + for tc, (expected_name, expected_args) in zip(message.tool_calls, tool_calls): + assert tc.function.name == expected_name + assert tc.function.arguments == expected_args + else: + assert not message.tool_calls diff --git a/tests/entrypoints/pooling/classify/test_offline.py b/tests/entrypoints/pooling/classify/test_offline.py index 1063c3b6b755c..a07fcd372721a 100644 --- a/tests/entrypoints/pooling/classify/test_offline.py +++ b/tests/entrypoints/pooling/classify/test_offline.py @@ -61,11 +61,8 @@ def test_pooling_params(llm: LLM): @pytest.mark.skip_global_cleanup -def test_encode_api(llm: LLM): - # chunked prefill does not support all pooling - err_msg = "pooling_task must be one of.+" - with pytest.raises(ValueError, match=err_msg): - llm.encode(prompts, pooling_task="token_classify", use_tqdm=False) +def test_token_classify(llm: LLM): + llm.encode(prompts, pooling_task="token_classify", use_tqdm=False) def test_score_api(llm: LLM): diff --git a/tests/entrypoints/pooling/classify/test_online.py b/tests/entrypoints/pooling/classify/test_online.py index 6fef688586955..1a6c33b455e65 100644 --- a/tests/entrypoints/pooling/classify/test_online.py +++ b/tests/entrypoints/pooling/classify/test_online.py @@ -255,21 +255,21 @@ async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str): - # token_classify uses ALL pooling, which does not support chunked prefill. task = "token_classify" + input_text = ["This product was excellent and exceeded my expectations"] response = requests.post( server.url_for("pooling"), json={ "model": model_name, - "input": "test", + "input": input_text, "encoding_format": "float", "task": task, }, ) - assert response.json()["error"]["type"] == "BadRequestError" - assert response.json()["error"]["message"].startswith( - f"Task {task} is not supported" - ) + poolings = PoolingResponse.model_validate(response.json()) + assert len(poolings.data) == 1 + assert len(poolings.data[0].data) == 8 + assert len(poolings.data[0].data[0]) == 2 @pytest.mark.asyncio diff --git a/tests/entrypoints/pooling/embed/test_offline.py b/tests/entrypoints/pooling/embed/test_offline.py index f5eab4c29ae18..12b47b1a08a8b 100644 --- a/tests/entrypoints/pooling/embed/test_offline.py +++ b/tests/entrypoints/pooling/embed/test_offline.py @@ -42,7 +42,7 @@ def llm(): @pytest.mark.skip_global_cleanup -def test_encode_api(llm: LLM): +def test_token_embed(llm: LLM): outputs = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False) multi_vector = outputs[0].outputs.data assert multi_vector.shape == (11, 384) diff --git a/tests/entrypoints/pooling/embed/test_online.py b/tests/entrypoints/pooling/embed/test_online.py index 6aac649bc3035..f96338c47f0be 100644 --- a/tests/entrypoints/pooling/embed/test_online.py +++ b/tests/entrypoints/pooling/embed/test_online.py @@ -18,12 +18,13 @@ from tests.utils import RemoteOpenAIServer from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse from vllm.platforms import current_platform -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.tokenizers import get_tokenizer from vllm.utils.serial_utils import ( EMBED_DTYPE_TO_TORCH_DTYPE, ENDIANNESS, MetadataItem, binary2tensor, + build_metadata_items, decode_pooling_output, ) @@ -344,6 +345,55 @@ async def test_bytes_embed_dtype_and_endianness( ) +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_bytes_only_embed_dtype_and_endianness( + server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str +): + input_texts = [ + "The best thing about vLLM is that it supports many different models", + ] * 2 + + responses_float = await client.embeddings.create( + input=input_texts, model=model_name, encoding_format="float" + ) + float_data = [d.embedding for d in responses_float.data] + embedding_size = len(float_data[0]) + + for embed_dtype in list(EMBED_DTYPE_TO_TORCH_DTYPE.keys()): + for endianness in ENDIANNESS: + responses_bytes = requests.post( + server.url_for("/v1/embeddings"), + json={ + "model": model_name, + "input": input_texts, + "encoding_format": "bytes_only", + "embed_dtype": embed_dtype, + "endianness": endianness, + }, + ) + + assert "metadata" not in responses_bytes.headers + body = responses_bytes.content + items = build_metadata_items( + embed_dtype=embed_dtype, + endianness=endianness, + shape=(embedding_size,), + n_request=len(input_texts), + ) + + bytes_data = decode_pooling_output(items=items, body=body) + bytes_data = [x.to(torch.float32).tolist() for x in bytes_data] + + check_embeddings_close( + embeddings_0_lst=float_data, + embeddings_1_lst=bytes_data, + name_0="float_data", + name_1="bytes_data", + tol=1e-2, + ) + + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("param_name", ["encoding_format", "embed_dtype", "endianness"]) diff --git a/tests/entrypoints/pooling/embed/test_online_vision.py b/tests/entrypoints/pooling/embed/test_online_vision.py index 83e7048b9def6..eebbcdd2e4396 100644 --- a/tests/entrypoints/pooling/embed/test_online_vision.py +++ b/tests/entrypoints/pooling/embed/test_online_vision.py @@ -9,6 +9,7 @@ from transformers import AutoProcessor from tests.utils import VLLM_PATH, RemoteOpenAIServer from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse +from vllm.multimodal.base import MediaWithBytes from vllm.multimodal.utils import encode_image_base64, fetch_image MODEL_NAME = "TIGER-Lab/VLM2Vec-Full" @@ -62,7 +63,11 @@ def get_hf_prompt_tokens(model_name, content, image_url): placeholder = "<|image_1|> " prompt = f"{placeholder}{content}" - images = [fetch_image(image_url)] + image = fetch_image(image_url) + # Unwrap MediaWithBytes if present + if isinstance(image, MediaWithBytes): + image = image.media + images = [image] inputs = processor(prompt, images, return_tensors="pt") return inputs.input_ids.shape[1] diff --git a/tests/entrypoints/pooling/pooling/test_online.py b/tests/entrypoints/pooling/pooling/test_online.py index 977c74d54a351..33add5bdaef49 100644 --- a/tests/entrypoints/pooling/pooling/test_online.py +++ b/tests/entrypoints/pooling/pooling/test_online.py @@ -12,12 +12,13 @@ import torch from tests.models.utils import check_embeddings_close from tests.utils import RemoteOpenAIServer from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.tokenizers import get_tokenizer from vllm.utils.serial_utils import ( EMBED_DTYPE_TO_TORCH_DTYPE, ENDIANNESS, MetadataItem, binary2tensor, + build_metadata_items, decode_pooling_output, ) @@ -352,6 +353,61 @@ async def test_bytes_embed_dtype_and_endianness( ) +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_bytes_only_embed_dtype_and_endianness( + server: RemoteOpenAIServer, model_name: str +): + input_texts = [ + "The best thing about vLLM is that it supports many different models", + ] * 2 + + url = server.url_for("pooling") + float_response = requests.post( + url, + json={ + "model": model_name, + "input": input_texts, + "encoding_format": "float", + }, + ) + responses_float = PoolingResponse.model_validate(float_response.json()) + float_data = [np.array(d.data).squeeze(-1).tolist() for d in responses_float.data] + n_tokens = responses_float.usage.prompt_tokens // len(input_texts) + + for embed_dtype in list(EMBED_DTYPE_TO_TORCH_DTYPE.keys()): + for endianness in ENDIANNESS: + responses_bytes = requests.post( + url, + json={ + "model": model_name, + "input": input_texts, + "encoding_format": "bytes_only", + "embed_dtype": embed_dtype, + "endianness": endianness, + }, + ) + + assert "metadata" not in responses_bytes.headers + body = responses_bytes.content + items = build_metadata_items( + embed_dtype=embed_dtype, + endianness=endianness, + shape=(n_tokens, 1), + n_request=len(input_texts), + ) + bytes_data = decode_pooling_output(items=items, body=body) + bytes_data = [x.to(torch.float32).view(-1).tolist() for x in bytes_data] + + check_embeddings_close( + embeddings_0_lst=float_data, + embeddings_1_lst=bytes_data, + name_0="float_data", + name_1="bytes_data", + tol=1e-2, + ) + + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("param_name", ["encoding_format", "embed_dtype", "endianness"]) diff --git a/tests/entrypoints/pooling/reward/test_offline.py b/tests/entrypoints/pooling/reward/test_offline.py index 0255704cecd94..b061b55145155 100644 --- a/tests/entrypoints/pooling/reward/test_offline.py +++ b/tests/entrypoints/pooling/reward/test_offline.py @@ -36,6 +36,13 @@ def llm(): cleanup_dist_env_and_memory() +@pytest.mark.skip_global_cleanup +def test_config(llm: LLM): + vllm_config = llm.llm_engine.vllm_config + assert vllm_config.cache_config.enable_prefix_caching + assert vllm_config.scheduler_config.enable_chunked_prefill + + def test_pooling_params(llm: LLM): def get_outputs(use_activation): outputs = llm.reward( diff --git a/tests/entrypoints/sagemaker/conftest.py b/tests/entrypoints/sagemaker/conftest.py index ad219eec18b79..1c34d738fa7a3 100644 --- a/tests/entrypoints/sagemaker/conftest.py +++ b/tests/entrypoints/sagemaker/conftest.py @@ -45,7 +45,10 @@ def basic_server_with_lora(smollm2_lora_files): "64", ] - envs = {"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"} + envs = { + "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True", + "SAGEMAKER_ENABLE_STATEFUL_SESSIONS": "True", + } with RemoteOpenAIServer(MODEL_NAME_SMOLLM, args, env_dict=envs) as remote_server: yield remote_server diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index a351cda60621f..a87a4c35d3dc7 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -6,6 +6,7 @@ from collections.abc import Mapping from typing import Literal import pytest +import torch from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy from vllm.assets.audio import AudioAsset @@ -28,8 +29,9 @@ from vllm.multimodal.utils import ( encode_image_base64, encode_video_base64, ) -from vllm.tokenizers import MistralTokenizer -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.tokenizers import get_tokenizer +from vllm.tokenizers.mistral import MistralTokenizer +from vllm.utils.serial_utils import tensor2base64 from ..models.registry import HF_EXAMPLE_MODELS from ..utils import VLLM_PATH @@ -86,11 +88,6 @@ def phi3v_model_config_image_embeds(): ) -@pytest.fixture(scope="module") -def phi3v_tokenizer(): - return get_tokenizer(PHI3V_MODEL_ID) - - @pytest.fixture(scope="function") def qwen2_audio_model_config(): return ModelConfig( @@ -116,11 +113,6 @@ def audio_embeds_model_config(): ) -@pytest.fixture(scope="module") -def qwen2_audio_tokenizer(): - return get_tokenizer(QWEN2AUDIO_MODEL_ID) - - @pytest.fixture(scope="function") def qwen25omni_model_config_mm_interleaved(): return ModelConfig( @@ -135,11 +127,6 @@ def qwen25omni_model_config_mm_interleaved(): ) -@pytest.fixture(scope="module") -def qwen25omni_tokenizer(): - return get_tokenizer(QWEN25OMNI_MODEL_ID) - - @pytest.fixture(scope="function") def mistral_model_config(): return ModelConfig( @@ -151,11 +138,6 @@ def mistral_model_config(): ) -@pytest.fixture(scope="module") -def mistral_tokenizer(): - return get_tokenizer(MISTRAL_MODEL_ID) - - @pytest.fixture(scope="module") def image_url(): image = ImageAsset("cherry_blossom") @@ -240,7 +222,6 @@ def _assert_mm_data_inputs( def test_parse_chat_messages_single_image( phi3v_model_config, - phi3v_tokenizer, image_url, ): conversation, mm_data, mm_uuids = parse_chat_messages( @@ -254,7 +235,6 @@ def test_parse_chat_messages_single_image( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -267,7 +247,6 @@ def test_parse_chat_messages_single_image( def test_parse_chat_messages_single_image_with_uuid( phi3v_model_config, - phi3v_tokenizer, image_url, ): image_uuid = str(hash(image_url)) @@ -288,7 +267,6 @@ def test_parse_chat_messages_single_image_with_uuid( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -301,7 +279,6 @@ def test_parse_chat_messages_single_image_with_uuid( def test_parse_chat_messages_single_empty_image_with_uuid( phi3v_model_config, - phi3v_tokenizer, image_url, ): image_uuid = str(hash(image_url)) @@ -320,7 +297,6 @@ def test_parse_chat_messages_single_empty_image_with_uuid( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -333,7 +309,6 @@ def test_parse_chat_messages_single_empty_image_with_uuid( def test_parse_chat_messages_single_image_with_bad_uuid_format( phi3v_model_config, - phi3v_tokenizer, image_url, ): image_uuid = str(hash(image_url)) @@ -355,7 +330,6 @@ def test_parse_chat_messages_single_image_with_bad_uuid_format( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -368,7 +342,6 @@ def test_parse_chat_messages_single_image_with_bad_uuid_format( def test_parse_chat_messages_multiple_images_with_uuids( phi3v_model_config, - phi3v_tokenizer, image_url, ): image_uuid1 = "my_uuid_1" @@ -398,7 +371,6 @@ def test_parse_chat_messages_multiple_images_with_uuids( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -414,7 +386,6 @@ def test_parse_chat_messages_multiple_images_with_uuids( def test_parse_chat_messages_multiple_empty_images_with_uuids( phi3v_model_config, - phi3v_tokenizer, image_url, ): image_uuid1 = "my_uuid_1" @@ -440,7 +411,6 @@ def test_parse_chat_messages_multiple_empty_images_with_uuids( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -456,7 +426,6 @@ def test_parse_chat_messages_multiple_empty_images_with_uuids( def test_parse_chat_messages_mixed_empty_images_with_uuids( phi3v_model_config, - phi3v_tokenizer, image_url, ): image_uuid1 = "my_uuid_1" @@ -484,7 +453,6 @@ def test_parse_chat_messages_mixed_empty_images_with_uuids( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -501,7 +469,6 @@ def test_parse_chat_messages_mixed_empty_images_with_uuids( @pytest.mark.asyncio async def test_parse_chat_messages_single_image_with_uuid_async( phi3v_model_config, - phi3v_tokenizer, image_url, ): image_uuid = str(hash(image_url)) @@ -520,7 +487,6 @@ async def test_parse_chat_messages_single_image_with_uuid_async( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -534,7 +500,6 @@ async def test_parse_chat_messages_single_image_with_uuid_async( @pytest.mark.asyncio async def test_parse_chat_messages_empty_image_with_uuid_async( phi3v_model_config, - phi3v_tokenizer, image_url, ): image_uuid = str(hash(image_url)) @@ -553,7 +518,6 @@ async def test_parse_chat_messages_empty_image_with_uuid_async( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -567,7 +531,6 @@ async def test_parse_chat_messages_empty_image_with_uuid_async( @pytest.mark.asyncio async def test_parse_chat_messages_multiple_images_with_uuids_async( phi3v_model_config, - phi3v_tokenizer, image_url, ): image_uuid1 = "my_uuid_1" @@ -593,7 +556,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_async( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -610,7 +572,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_async( @pytest.mark.asyncio async def test_parse_chat_messages_multiple_empty_images_with_uuids_async( phi3v_model_config, - phi3v_tokenizer, image_url, ): image_uuid1 = "my_uuid_1" @@ -636,7 +597,6 @@ async def test_parse_chat_messages_multiple_empty_images_with_uuids_async( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -653,7 +613,6 @@ async def test_parse_chat_messages_multiple_empty_images_with_uuids_async( @pytest.mark.asyncio async def test_parse_chat_messages_multiple_images_with_partial_uuids_async( phi3v_model_config, - phi3v_tokenizer, image_url, ): image_uuid2 = "my_uuid_2" @@ -677,7 +636,6 @@ async def test_parse_chat_messages_multiple_images_with_partial_uuids_async( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -693,7 +651,6 @@ async def test_parse_chat_messages_multiple_images_with_partial_uuids_async( def test_parse_chat_messages_empty_system( mistral_model_config, - mistral_tokenizer, ): # Test string format conversation, _, _ = parse_chat_messages( @@ -705,7 +662,6 @@ def test_parse_chat_messages_empty_system( }, ], mistral_model_config, - mistral_tokenizer, content_format="string", ) assert conversation == [ @@ -723,7 +679,6 @@ def test_parse_chat_messages_empty_system( }, ], mistral_model_config, - mistral_tokenizer, content_format="openai", ) assert conversation == [ @@ -735,7 +690,6 @@ def test_parse_chat_messages_empty_system( @pytest.mark.asyncio async def test_parse_chat_messages_single_image_async( phi3v_model_config, - phi3v_tokenizer, image_url, ): conversation, mm_future, mm_uuids = parse_chat_messages_futures( @@ -749,7 +703,6 @@ async def test_parse_chat_messages_single_image_async( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -762,7 +715,6 @@ async def test_parse_chat_messages_single_image_async( def test_parse_chat_messages_multiple_images( phi3v_model_config, - phi3v_tokenizer, image_url, ): conversation, mm_data, mm_uuids = parse_chat_messages( @@ -780,7 +732,6 @@ def test_parse_chat_messages_multiple_images( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -796,7 +747,6 @@ def test_parse_chat_messages_multiple_images( def test_parse_chat_messages_empty_pil_image_with_uuid( phi3v_model_config, - phi3v_tokenizer, ): uuid = "abcd" conversation, mm_data, mm_uuids = parse_chat_messages( @@ -810,7 +760,6 @@ def test_parse_chat_messages_empty_pil_image_with_uuid( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -826,7 +775,6 @@ def test_parse_chat_messages_empty_pil_image_with_uuid( def test_parse_chat_messages_empty_image_embeds_with_uuid( phi3v_model_config_image_embeds, - phi3v_tokenizer, ): uuid = "abcd" conversation, mm_data, mm_uuids = parse_chat_messages( @@ -840,7 +788,6 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid( } ], phi3v_model_config_image_embeds, - phi3v_tokenizer, content_format="string", ) @@ -850,15 +797,18 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid( "content": "<|image_1|>\nWhat's in this image?", } ] + assert mm_data is not None assert "image" in mm_data - assert mm_data["image"] is None + assert isinstance(mm_data["image"], list) + assert len(mm_data["image"]) == 1 + assert mm_data["image"][0] is None + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid]) def test_parse_chat_messages_empty_audio_embeds_with_uuid( audio_embeds_model_config, - qwen2_audio_tokenizer, ): """Test audio_embeds with UUID (no actual embeds data).""" uuid = "test-audio-uuid-123" @@ -874,27 +824,24 @@ def test_parse_chat_messages_empty_audio_embeds_with_uuid( } ], audio_embeds_model_config, - qwen2_audio_tokenizer, content_format="string", ) # Should have audio in mm_data as None (UUID provided) assert mm_data is not None assert "audio" in mm_data - assert mm_data["audio"] is None + assert isinstance(mm_data["audio"], list) + assert len(mm_data["audio"]) == 1 + assert mm_data["audio"][0] is None + # UUID should be recorded - assert mm_uuids is not None - assert "audio" in mm_uuids _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[uuid]) def test_parse_chat_messages_audio_embeds_with_string( audio_embeds_model_config, - qwen2_audio_tokenizer, ): """Test audio_embeds with base64 string embedding data.""" - import base64 - import io import torch @@ -902,11 +849,7 @@ def test_parse_chat_messages_audio_embeds_with_string( audio_embedding = torch.randn(1, 128, 768) # Encode it as base64 - buffer = io.BytesIO() - torch.save(audio_embedding, buffer) - buffer.seek(0) - binary_data = buffer.read() - base64_audio_embedding = base64.b64encode(binary_data).decode("utf-8") + base64_audio_embedding = tensor2base64(audio_embedding) conversation, mm_data, mm_uuids = parse_chat_messages( [ @@ -922,7 +865,6 @@ def test_parse_chat_messages_audio_embeds_with_string( } ], audio_embeds_model_config, - qwen2_audio_tokenizer, content_format="string", ) @@ -940,11 +882,8 @@ def test_parse_chat_messages_audio_embeds_with_string( @pytest.mark.asyncio async def test_parse_chat_messages_audio_embeds_async( audio_embeds_model_config, - qwen2_audio_tokenizer, ): """Test audio_embeds with async futures.""" - import base64 - import io import torch @@ -952,11 +891,7 @@ async def test_parse_chat_messages_audio_embeds_async( audio_embedding = torch.randn(1, 128, 768) # Encode it as base64 - buffer = io.BytesIO() - torch.save(audio_embedding, buffer) - buffer.seek(0) - binary_data = buffer.read() - base64_audio_embedding = base64.b64encode(binary_data).decode("utf-8") + base64_audio_embedding = tensor2base64(audio_embedding) conversation, mm_future, mm_uuids = parse_chat_messages_futures( [ @@ -972,7 +907,6 @@ async def test_parse_chat_messages_audio_embeds_async( } ], audio_embeds_model_config, - qwen2_audio_tokenizer, content_format="string", ) @@ -988,10 +922,186 @@ async def test_parse_chat_messages_audio_embeds_async( _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None]) +def test_parse_chat_messages_multiple_image_embeds( + phi3v_model_config_image_embeds, +): + """Test that multiple image_embeds in a single message are now supported. + + This test validates the fix for the limitation that previously only allowed + one message with {'type': 'image_embeds'}. Now multiple image embeddings + can be provided in a single request, similar to regular images. + """ + # Create two sample image embedding tensors + image_embedding_1 = torch.randn(256, 1024) + image_embedding_2 = torch.randn(128, 1024) + + # Encode them as base64 using the convenience function + base64_image_embedding_1 = tensor2base64(image_embedding_1) + base64_image_embedding_2 = tensor2base64(image_embedding_2) + + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + { + "type": "image_embeds", + "image_embeds": base64_image_embedding_1, + }, + { + "type": "image_embeds", + "image_embeds": base64_image_embedding_2, + }, + {"type": "text", "text": "Describe these two images."}, + ], + } + ], + phi3v_model_config_image_embeds, + content_format="string", + ) + + # Verify conversation structure + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\n<|image_2|>\nDescribe these two images.", + } + ] + + # Verify mm_data contains a list of embeddings (not a single embedding) + assert mm_data is not None + assert "image" in mm_data + assert isinstance(mm_data["image"], list) + assert len(mm_data["image"]) == 2 + + # Verify each embedding has the correct shape + assert isinstance(mm_data["image"][0], torch.Tensor) + assert mm_data["image"][0].shape == image_embedding_1.shape + assert isinstance(mm_data["image"][1], torch.Tensor) + assert mm_data["image"][1].shape == image_embedding_2.shape + + # Verify UUIDs (None since we didn't provide any) + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) + + +def test_parse_chat_messages_multiple_image_embeds_with_uuids( + phi3v_model_config_image_embeds, +): + """Test multiple image_embeds with UUIDs. + + This validates that UUIDs are properly tracked for multiple embeddings. + """ + uuid1 = "image-uuid-1" + uuid2 = "image-uuid-2" + + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + { + "type": "image_embeds", + "image_embeds": None, + "uuid": uuid1, + }, + { + "type": "image_embeds", + "image_embeds": None, + "uuid": uuid2, + }, + {"type": "text", "text": "Compare these images."}, + ], + } + ], + phi3v_model_config_image_embeds, + content_format="string", + ) + + # Verify conversation structure + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\n<|image_2|>\nCompare these images.", + } + ] + + # Verify mm_data contains a list with None values (UUID references) + assert mm_data is not None + assert "image" in mm_data + assert isinstance(mm_data["image"], list) + assert len(mm_data["image"]) == 2 + assert mm_data["image"][0] is None + assert mm_data["image"][1] is None + + # Verify UUIDs are correctly tracked + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[uuid1, uuid2]) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_multiple_image_embeds_async( + phi3v_model_config_image_embeds, +): + """Test multiple image_embeds with async parsing. + + This validates the AsyncMultiModalItemTracker also supports multiple embeddings. + """ + # Create two sample image embedding tensors + image_embedding_1 = torch.randn(200, 768) + image_embedding_2 = torch.randn(150, 768) + + # Encode them as base64 using the convenience function + base64_image_embedding_1 = tensor2base64(image_embedding_1) + base64_image_embedding_2 = tensor2base64(image_embedding_2) + + conversation, mm_future, mm_uuids = parse_chat_messages_futures( + [ + { + "role": "user", + "content": [ + { + "type": "image_embeds", + "image_embeds": base64_image_embedding_1, + }, + { + "type": "image_embeds", + "image_embeds": base64_image_embedding_2, + }, + {"type": "text", "text": "What do these images show?"}, + ], + } + ], + phi3v_model_config_image_embeds, + content_format="string", + ) + + # Verify conversation structure + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\n<|image_2|>\nWhat do these images show?", + } + ] + + # Await the future and verify mm_data + mm_data = await mm_future + assert mm_data is not None + assert "image" in mm_data + assert isinstance(mm_data["image"], list) + assert len(mm_data["image"]) == 2 + + # Verify each embedding has the correct shape + assert isinstance(mm_data["image"][0], torch.Tensor) + assert mm_data["image"][0].shape == image_embedding_1.shape + assert isinstance(mm_data["image"][1], torch.Tensor) + assert mm_data["image"][1].shape == image_embedding_2.shape + + # Verify UUIDs + _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) + + @pytest.mark.asyncio async def test_parse_chat_messages_empty_image_embeds_with_uuid_async( phi3v_model_config_image_embeds, - phi3v_tokenizer, ): uuid = "abcd" conversation, mm_future, mm_uuids = parse_chat_messages_futures( @@ -1005,7 +1115,6 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async( } ], phi3v_model_config_image_embeds, - phi3v_tokenizer, content_format="string", ) @@ -1018,14 +1127,108 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async( mm_data = await mm_future assert mm_data is not None assert "image" in mm_data - assert mm_data["image"] is None + assert isinstance(mm_data["image"], list) + assert len(mm_data["image"]) == 1 + assert mm_data["image"][0] is None + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid]) +def test_parse_chat_messages_empty_dict_image_embeds( + phi3v_model_config_image_embeds, +): + """Test that empty dictionary for image_embeds is handled without errors.""" + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + {"type": "image_embeds", "image_embeds": {}}, + {"type": "text", "text": "What's in this image?"}, + ], + } + ], + phi3v_model_config_image_embeds, + content_format="string", + ) + + # Verify conversation structure + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\nWhat's in this image?", + } + ] + + # Verify mm_data contains an empty dictionary of embeddings + assert mm_data is not None + assert "image" in mm_data + assert isinstance(mm_data["image"], dict) + assert len(mm_data["image"]) == 0 + + # Verify UUIDs (None since we didn't provide any) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[None]) + + +def test_parse_chat_messages_multiple_dict_image_embeds( + phi3v_model_config_image_embeds, +): + """Test that multiple dictionaries for image_embeds is handled without errors.""" + # Create two sample image embedding tensors + batch_size = 2 + image_embedding_1 = torch.randn(batch_size, 256, 1024) + image_embedding_2 = torch.randn(batch_size, 3) + + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": [ + { + "type": "image_embeds", + "image_embeds": { + "image_embedding_1": tensor2base64(p), + "image_embedding_2": tensor2base64(i), + }, + } + for p, i in zip(image_embedding_1, image_embedding_2) + ] + + [ + {"type": "text", "text": "Describe these two images."}, + ], + } + ], + phi3v_model_config_image_embeds, + content_format="string", + ) + + # Verify conversation structure + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\n<|image_2|>\nDescribe these two images.", + } + ] + + # Verify mm_data contains a dictionary of multi-embeddings + assert mm_data is not None + assert "image" in mm_data + assert isinstance(mm_data["image"], dict) + assert len(mm_data["image"]) == batch_size + + # Verify each embedding has the correct shape + assert isinstance(mm_data["image"]["image_embedding_1"], torch.Tensor) + assert mm_data["image"]["image_embedding_1"].shape == image_embedding_1.shape + assert isinstance(mm_data["image"]["image_embedding_2"], torch.Tensor) + assert mm_data["image"]["image_embedding_2"].shape == image_embedding_2.shape + + # Verify UUIDs (None since we didn't provide any) + _assert_mm_uuids(mm_uuids, batch_size, expected_uuids=[None, None]) + + @pytest.mark.asyncio async def test_parse_chat_messages_multiple_images_async( phi3v_model_config, - phi3v_tokenizer, image_url, ): conversation, mm_future, mm_uuids = parse_chat_messages_futures( @@ -1043,7 +1246,6 @@ async def test_parse_chat_messages_multiple_images_async( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -1059,7 +1261,6 @@ async def test_parse_chat_messages_multiple_images_async( def test_parse_chat_messages_placeholder_already_in_prompt( phi3v_model_config, - phi3v_tokenizer, image_url, ): conversation, mm_data, mm_uuids = parse_chat_messages( @@ -1077,7 +1278,6 @@ def test_parse_chat_messages_placeholder_already_in_prompt( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) assert conversation == [ @@ -1092,7 +1292,6 @@ def test_parse_chat_messages_placeholder_already_in_prompt( def test_parse_chat_messages_placeholder_one_already_in_prompt( phi3v_model_config, - phi3v_tokenizer, image_url, ): conversation, mm_data, mm_uuids = parse_chat_messages( @@ -1111,7 +1310,6 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -1128,7 +1326,6 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt( def test_parse_chat_messages_multiple_images_across_messages( phi3v_model_config, - phi3v_tokenizer, image_url, ): conversation, mm_data, mm_uuids = parse_chat_messages( @@ -1150,7 +1347,6 @@ def test_parse_chat_messages_multiple_images_across_messages( }, ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -1165,7 +1361,6 @@ def test_parse_chat_messages_multiple_images_across_messages( def test_parse_chat_messages_multiple_images_with_uuids_across_messages( phi3v_model_config, - phi3v_tokenizer, image_url, ): image_uuid = str(hash(image_url)) @@ -1196,7 +1391,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_across_messages( }, ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -1211,7 +1405,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_across_messages( def test_parse_chat_messages_context_text_format( phi3v_model_config, - phi3v_tokenizer, ): conversation, mm_data, mm_uuids = parse_chat_messages( [ @@ -1223,7 +1416,6 @@ def test_parse_chat_messages_context_text_format( {"role": "user", "content": "What about this one?"}, ], phi3v_model_config, - phi3v_tokenizer, content_format="openai", ) @@ -1247,7 +1439,6 @@ def test_parse_chat_messages_context_text_format( def test_parse_chat_messages_rejects_too_many_images_in_one_message( phi3v_model_config, - phi3v_tokenizer, image_url, ): with warnings.catch_warnings(): @@ -1278,14 +1469,12 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) def test_parse_chat_messages_rejects_too_many_images_across_messages( phi3v_model_config, - phi3v_tokenizer, image_url, ): with warnings.catch_warnings(): @@ -1323,14 +1512,12 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages( }, ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) def test_parse_chat_messages_multiple_images_uncommon_input( phi3v_model_config, - phi3v_tokenizer, image_url, ): conversation, mm_data, mm_uuids = parse_chat_messages( @@ -1345,7 +1532,6 @@ def test_parse_chat_messages_multiple_images_uncommon_input( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -1361,7 +1547,6 @@ def test_parse_chat_messages_multiple_images_uncommon_input( def test_parse_chat_messages_multiple_images_interleave( phi3v_model_config_mm_interleaved, - phi3v_tokenizer, image_url, ): conversation, mm_data, mm_uuids = parse_chat_messages( @@ -1381,7 +1566,6 @@ def test_parse_chat_messages_multiple_images_interleave( } ], phi3v_model_config_mm_interleaved, - phi3v_tokenizer, content_format="string", ) @@ -1399,7 +1583,6 @@ def test_parse_chat_messages_multiple_images_interleave( @pytest.mark.asyncio async def test_parse_chat_messages_multiple_images_interleave_async( phi3v_model_config_mm_interleaved, - phi3v_tokenizer, image_url, ): conversation, mm_data, mm_uuids = parse_chat_messages_futures( @@ -1419,7 +1602,6 @@ async def test_parse_chat_messages_multiple_images_interleave_async( } ], phi3v_model_config_mm_interleaved, - phi3v_tokenizer, content_format="string", ) @@ -1437,7 +1619,6 @@ async def test_parse_chat_messages_multiple_images_interleave_async( @pytest.mark.asyncio async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async( phi3v_model_config_mm_interleaved, - phi3v_tokenizer, image_url, ): image_uuid = str(hash(image_url)) @@ -1466,7 +1647,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async( } ], phi3v_model_config_mm_interleaved, - phi3v_tokenizer, content_format="string", ) @@ -1483,7 +1663,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async( def test_parse_chat_messages_multiple_images_multiple_messages_interleave( phi3v_model_config_mm_interleaved, - phi3v_tokenizer, image_url, ): conversation, mm_data, mm_uuids = parse_chat_messages( @@ -1506,7 +1685,6 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave( }, ], phi3v_model_config_mm_interleaved, - phi3v_tokenizer, content_format="string", ) @@ -1524,7 +1702,6 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave( def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interleave( phi3v_model_config_mm_interleaved, - phi3v_tokenizer, image_url, ): image_uuid = str(hash(image_url)) @@ -1556,7 +1733,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interl }, ], phi3v_model_config_mm_interleaved, - phi3v_tokenizer, content_format="string", ) @@ -1574,7 +1750,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interl def test_parse_chat_messages_multiple_modals_multiple_messages_interleave( qwen25omni_model_config_mm_interleaved, - qwen25omni_tokenizer, image_url, video_url, audio_url, @@ -1602,7 +1777,6 @@ def test_parse_chat_messages_multiple_modals_multiple_messages_interleave( }, ], qwen25omni_model_config_mm_interleaved, - qwen25omni_tokenizer, content_format="string", ) @@ -1628,7 +1802,6 @@ def test_parse_chat_messages_multiple_modals_multiple_messages_interleave( def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interleave( qwen25omni_model_config_mm_interleaved, - qwen25omni_tokenizer, image_url, video_url, audio_url, @@ -1672,7 +1845,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interl }, ], qwen25omni_model_config_mm_interleaved, - qwen25omni_tokenizer, content_format="string", ) @@ -1700,7 +1872,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interl def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_messages_interleave( # noqa: E501 qwen25omni_model_config_mm_interleaved, - qwen25omni_tokenizer, image_url, video_url, audio_url, @@ -1744,7 +1915,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_mes }, ], qwen25omni_model_config_mm_interleaved, - qwen25omni_tokenizer, content_format="string", ) @@ -1776,7 +1946,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_mes def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_messages_interleave( # noqa: E501 qwen25omni_model_config_mm_interleaved, - qwen25omni_tokenizer, image_url, video_url, audio_url, @@ -1812,7 +1981,6 @@ def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_message }, ], qwen25omni_model_config_mm_interleaved, - qwen25omni_tokenizer, content_format="string", ) @@ -1838,7 +2006,6 @@ def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_message def test_parse_chat_messages_multiple_images_interleave_with_placeholders( phi3v_model_config_mm_interleaved, - phi3v_tokenizer, image_url, ): with pytest.raises( @@ -1862,7 +2029,6 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders( } ], phi3v_model_config_mm_interleaved, - phi3v_tokenizer, content_format="string", ) @@ -2238,9 +2404,7 @@ def test_resolve_content_format_examples(template_path, expected_format): assert resolved_format == expected_format -def test_parse_chat_messages_include_thinking_chunk( - mistral_model_config, mistral_tokenizer -): +def test_parse_chat_messages_include_thinking_chunk(mistral_model_config): messages = [ { "role": "system", @@ -2270,7 +2434,6 @@ def test_parse_chat_messages_include_thinking_chunk( conversation_with_thinking, _, _ = parse_chat_messages( messages, mistral_model_config, - mistral_tokenizer, content_format="openai", ) @@ -2354,7 +2517,6 @@ def test_apply_mistral_chat_template_thinking_chunk(): def test_parse_chat_messages_single_empty_audio_with_uuid( qwen2_audio_model_config, - qwen2_audio_tokenizer, ): audio_uuid = "abcd" conversation, mm_data, mm_uuids = parse_chat_messages( @@ -2372,7 +2534,6 @@ def test_parse_chat_messages_single_empty_audio_with_uuid( } ], qwen2_audio_model_config, - qwen2_audio_tokenizer, content_format="string", ) @@ -2390,7 +2551,6 @@ def test_parse_chat_messages_single_empty_audio_with_uuid( @pytest.mark.asyncio async def test_parse_chat_messages_single_empty_audio_with_uuid_async( qwen2_audio_model_config, - qwen2_audio_tokenizer, ): audio_uuid = "abcd" conversation, mm_future, mm_uuids = parse_chat_messages_futures( @@ -2408,7 +2568,6 @@ async def test_parse_chat_messages_single_empty_audio_with_uuid_async( } ], qwen2_audio_model_config, - qwen2_audio_tokenizer, content_format="string", ) diff --git a/tests/entrypoints/test_harmony_utils.py b/tests/entrypoints/test_harmony_utils.py deleted file mode 100644 index 6fa051a678d68..0000000000000 --- a/tests/entrypoints/test_harmony_utils.py +++ /dev/null @@ -1,266 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from openai_harmony import Role - -from vllm.entrypoints.harmony_utils import ( - has_custom_tools, - parse_input_to_harmony_message, -) - - -class TestParseInputToHarmonyMessage: - """Tests for parse_input_to_harmony_message function.""" - - def test_assistant_message_with_tool_calls(self): - """Test parsing assistant message with tool calls.""" - chat_msg = { - "role": "assistant", - "tool_calls": [ - { - "function": { - "name": "get_weather", - "arguments": '{"location": "San Francisco"}', - } - }, - { - "function": { - "name": "search_web", - "arguments": '{"query": "latest news"}', - } - }, - ], - } - - messages = parse_input_to_harmony_message(chat_msg) - - assert len(messages) == 2 - - # First tool call - assert messages[0].author.role == Role.ASSISTANT - assert messages[0].content[0].text == '{"location": "San Francisco"}' - assert messages[0].channel == "commentary" - assert messages[0].recipient == "functions.get_weather" - assert messages[0].content_type == "json" - - # Second tool call - assert messages[1].author.role == Role.ASSISTANT - assert messages[1].content[0].text == '{"query": "latest news"}' - assert messages[1].channel == "commentary" - assert messages[1].recipient == "functions.search_web" - assert messages[1].content_type == "json" - - def test_assistant_message_with_empty_tool_call_arguments(self): - """Test parsing assistant message with tool call having None arguments.""" - chat_msg = { - "role": "assistant", - "tool_calls": [ - { - "function": { - "name": "get_current_time", - "arguments": None, - } - } - ], - } - - messages = parse_input_to_harmony_message(chat_msg) - - assert len(messages) == 1 - assert messages[0].content[0].text == "" - assert messages[0].recipient == "functions.get_current_time" - - def test_tool_message_with_string_content(self): - """Test parsing tool message with string content.""" - chat_msg = { - "role": "tool", - "name": "get_weather", - "content": "The weather in San Francisco is sunny, 72°F", - } - - messages = parse_input_to_harmony_message(chat_msg) - - assert len(messages) == 1 - assert messages[0].author.role == Role.TOOL - assert messages[0].author.name == "functions.get_weather" - assert ( - messages[0].content[0].text == "The weather in San Francisco is sunny, 72°F" - ) - assert messages[0].channel == "commentary" - - def test_tool_message_with_array_content(self): - """Test parsing tool message with array content.""" - chat_msg = { - "role": "tool", - "name": "search_results", - "content": [ - {"type": "text", "text": "Result 1: "}, - {"type": "text", "text": "Result 2: "}, - { - "type": "image", - "url": "http://example.com/img.png", - }, # Should be ignored - {"type": "text", "text": "Result 3"}, - ], - } - - messages = parse_input_to_harmony_message(chat_msg) - - assert len(messages) == 1 - assert messages[0].author.role == Role.TOOL - assert messages[0].content[0].text == "Result 1: Result 2: Result 3" - - def test_tool_message_with_empty_content(self): - """Test parsing tool message with None content.""" - chat_msg = { - "role": "tool", - "name": "empty_tool", - "content": None, - } - - messages = parse_input_to_harmony_message(chat_msg) - - assert len(messages) == 1 - assert messages[0].content[0].text == "" - - def test_system_message(self): - """Test parsing system message.""" - chat_msg = { - "role": "system", - "content": "You are a helpful assistant", - } - - messages = parse_input_to_harmony_message(chat_msg) - - assert len(messages) == 1 - # System messages are converted using Message.from_dict - # which should preserve the role - assert messages[0].author.role == Role.SYSTEM - - def test_developer_message(self): - """Test parsing developer message.""" - chat_msg = { - "role": "developer", - "content": "Use concise language", - } - - messages = parse_input_to_harmony_message(chat_msg) - - assert len(messages) == 1 - assert messages[0].author.role == Role.DEVELOPER - - def test_user_message_with_string_content(self): - """Test parsing user message with string content.""" - chat_msg = { - "role": "user", - "content": "What's the weather in San Francisco?", - } - - messages = parse_input_to_harmony_message(chat_msg) - - assert len(messages) == 1 - assert messages[0].author.role == Role.USER - assert messages[0].content[0].text == "What's the weather in San Francisco?" - - def test_user_message_with_array_content(self): - """Test parsing user message with array content.""" - chat_msg = { - "role": "user", - "content": [ - {"text": "What's in this image? "}, - {"text": "Please describe it."}, - ], - } - - messages = parse_input_to_harmony_message(chat_msg) - - assert len(messages) == 1 - assert messages[0].author.role == Role.USER - assert len(messages[0].content) == 2 - assert messages[0].content[0].text == "What's in this image? " - assert messages[0].content[1].text == "Please describe it." - - def test_assistant_message_with_string_content(self): - """Test parsing assistant message with string content (no tool calls).""" - chat_msg = { - "role": "assistant", - "content": "Hello! How can I help you today?", - } - - messages = parse_input_to_harmony_message(chat_msg) - - assert len(messages) == 1 - assert messages[0].author.role == Role.ASSISTANT - assert messages[0].content[0].text == "Hello! How can I help you today?" - - def test_pydantic_model_input(self): - """Test parsing Pydantic model input (has model_dump method).""" - - class MockPydanticModel: - def model_dump(self, exclude_none=True): - return { - "role": "user", - "content": "Test message", - } - - chat_msg = MockPydanticModel() - messages = parse_input_to_harmony_message(chat_msg) - - assert len(messages) == 1 - assert messages[0].author.role == Role.USER - assert messages[0].content[0].text == "Test message" - - def test_message_with_empty_content(self): - """Test parsing message with empty string content.""" - chat_msg = { - "role": "user", - "content": "", - } - - messages = parse_input_to_harmony_message(chat_msg) - - assert len(messages) == 1 - assert messages[0].content[0].text == "" - - def test_tool_call_with_missing_function_fields(self): - """Test parsing tool call with missing name or arguments.""" - chat_msg = { - "role": "assistant", - "tool_calls": [ - { - "function": {} # Missing both name and arguments - } - ], - } - - messages = parse_input_to_harmony_message(chat_msg) - - assert len(messages) == 1 - assert messages[0].recipient == "functions." - assert messages[0].content[0].text == "" - - def test_array_content_with_missing_text(self): - """Test parsing array content where text field is missing.""" - chat_msg = { - "role": "user", - "content": [ - {}, # Missing text field - {"text": "actual text"}, - ], - } - - messages = parse_input_to_harmony_message(chat_msg) - - assert len(messages) == 1 - assert len(messages[0].content) == 2 - assert messages[0].content[0].text == "" - assert messages[0].content[1].text == "actual text" - - -def test_has_custom_tools() -> None: - assert not has_custom_tools(set()) - assert not has_custom_tools({"web_search_preview", "code_interpreter", "container"}) - assert has_custom_tools({"others"}) - assert has_custom_tools( - {"web_search_preview", "code_interpreter", "container", "others"} - ) diff --git a/tests/entrypoints/test_responses_utils.py b/tests/entrypoints/test_responses_utils.py index 893d806b65742..a522967111307 100644 --- a/tests/entrypoints/test_responses_utils.py +++ b/tests/entrypoints/test_responses_utils.py @@ -2,9 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest +from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall from openai.types.responses.response_function_tool_call_output_item import ( ResponseFunctionToolCallOutputItem, ) +from openai.types.responses.response_output_message import ResponseOutputMessage +from openai.types.responses.response_output_text import ResponseOutputText from openai.types.responses.response_reasoning_item import ( Content, ResponseReasoningItem, @@ -12,7 +15,8 @@ from openai.types.responses.response_reasoning_item import ( ) from vllm.entrypoints.responses_utils import ( - construct_chat_message_with_tool_call, + _construct_single_message_from_response_item, + construct_chat_messages_with_tool_call, convert_tool_responses_to_completions_format, ) @@ -40,7 +44,43 @@ class TestResponsesUtils: assert result == {"type": "function", "function": input_tool} - def test_construct_chat_message_with_tool_call(self): + def test_construct_chat_messages_with_tool_call(self): + """Test construction of chat messages with tool calls.""" + reasoning_item = ResponseReasoningItem( + id="lol", + summary=[], + type="reasoning", + content=[ + Content( + text="Leroy Jenkins", + type="reasoning_text", + ) + ], + encrypted_content=None, + status=None, + ) + mcp_tool_item = ResponseFunctionToolCall( + id="mcp_123", + call_id="call_123", + type="function_call", + status="completed", + name="python", + arguments='{"code": "123+456"}', + ) + input_items = [reasoning_item, mcp_tool_item] + messages = construct_chat_messages_with_tool_call(input_items) + + assert len(messages) == 1 + message = messages[0] + assert message["role"] == "assistant" + assert message["reasoning"] == "Leroy Jenkins" + assert message["tool_calls"][0]["id"] == "call_123" + assert message["tool_calls"][0]["function"]["name"] == "python" + assert ( + message["tool_calls"][0]["function"]["arguments"] == '{"code": "123+456"}' + ) + + def test_construct_single_message_from_response_item(self): item = ResponseReasoningItem( id="lol", summary=[], @@ -54,7 +94,7 @@ class TestResponsesUtils: encrypted_content=None, status=None, ) - formatted_item = construct_chat_message_with_tool_call(item) + formatted_item = _construct_single_message_from_response_item(item) assert formatted_item["role"] == "assistant" assert formatted_item["reasoning"] == "Leroy Jenkins" @@ -72,7 +112,7 @@ class TestResponsesUtils: status=None, ) - formatted_item = construct_chat_message_with_tool_call(item) + formatted_item = _construct_single_message_from_response_item(item) assert formatted_item["role"] == "assistant" assert ( formatted_item["reasoning"] @@ -86,7 +126,7 @@ class TestResponsesUtils: output="1234", status="completed", ) - formatted_item = construct_chat_message_with_tool_call(tool_call_output) + formatted_item = _construct_single_message_from_response_item(tool_call_output) assert formatted_item["role"] == "tool" assert formatted_item["content"] == "1234" assert formatted_item["tool_call_id"] == "temp" @@ -100,4 +140,23 @@ class TestResponsesUtils: status=None, ) with pytest.raises(ValueError): - construct_chat_message_with_tool_call(item) + _construct_single_message_from_response_item(item) + + output_item = ResponseOutputMessage( + id="msg_bf585bbbe3d500e0", + content=[ + ResponseOutputText( + annotations=[], + text="dongyi", + type="output_text", + logprobs=None, + ) + ], + role="assistant", + status="completed", + type="message", + ) + + formatted_item = _construct_single_message_from_response_item(output_item) + assert formatted_item["role"] == "assistant" + assert formatted_item["content"] == "dongyi" diff --git a/tests/kernels/attention/test_cpu_attn.py b/tests/kernels/attention/test_cpu_attn.py index fb3b1799ba48e..be5d66197f6ef 100644 --- a/tests/kernels/attention/test_cpu_attn.py +++ b/tests/kernels/attention/test_cpu_attn.py @@ -7,7 +7,8 @@ import math import pytest import torch -from vllm.platforms import current_platform +from vllm.platforms import CpuArchEnum, current_platform +from vllm.v1.attention.backends.cpu_attn import _get_attn_isa if not current_platform.is_cpu(): pytest.skip("skipping CPU-only tests", allow_module_level=True) @@ -36,6 +37,21 @@ SEQ_LENS = [ # (q_len, kv_len) ] +def get_attn_isa( + block_size: int | None = None, + dtype: torch.dtype | None = None, +): + if block_size and dtype: + return _get_attn_isa(dtype, block_size) + else: + if current_platform.get_cpu_architecture() == CpuArchEnum.ARM: + return "neon" + elif torch._C._cpu._is_amx_tile_supported(): + return "amx" + else: + return "vec" + + # rand number generation takes too much time, cache rand tensors @functools.lru_cache(maxsize=128, typed=False) def tensor_cache( @@ -452,6 +468,49 @@ def test_varlen_with_paged_kv_normal_vec16( ) +@pytest.mark.parametrize("seq_lens", SEQ_LENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", [96, 128]) +@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS) +@pytest.mark.parametrize("dtype", QTYPES) +@pytest.mark.parametrize("soft_cap", [None]) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("use_alibi", [False]) +@pytest.mark.parametrize("use_sink", [False]) +@pytest.mark.parametrize("isa", ["neon"]) +@pytest.mark.skipif( + current_platform.get_cpu_architecture() != CpuArchEnum.ARM, + reason="Not an Arm CPU.", +) +def test_varlen_with_paged_kv_normal_neon( + seq_lens: list[tuple[int, int]], + num_heads: tuple[int, int], + head_size: int, + sliding_window: int | None, + dtype: torch.dtype, + block_size: int, + soft_cap: float | None, + num_blocks: int, + use_alibi: bool, + use_sink: bool, + isa: str, +) -> None: + varlen_with_paged_kv( + seq_lens=seq_lens, + num_heads=num_heads, + head_size=head_size, + sliding_window=sliding_window, + dtype=dtype, + block_size=block_size, + soft_cap=soft_cap, + num_blocks=num_blocks, + use_alibi=use_alibi, + use_sink=use_sink, + isa=isa, + ) + + @pytest.mark.parametrize("seq_lens", SEQ_LENS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", [96]) @@ -462,9 +521,7 @@ def test_varlen_with_paged_kv_normal_vec16( @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("use_alibi", [False]) @pytest.mark.parametrize("use_sink", [False]) -@pytest.mark.parametrize( - "isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"] -) +@pytest.mark.parametrize("isa", [get_attn_isa()]) def test_varlen_with_paged_kv_softcap( seq_lens: list[tuple[int, int]], num_heads: tuple[int, int], @@ -503,9 +560,7 @@ def test_varlen_with_paged_kv_softcap( @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("use_alibi", [True]) @pytest.mark.parametrize("use_sink", [False]) -@pytest.mark.parametrize( - "isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"] -) +@pytest.mark.parametrize("isa", [get_attn_isa()]) def test_varlen_with_paged_kv_alibi( seq_lens: list[tuple[int, int]], num_heads: tuple[int, int], @@ -544,9 +599,7 @@ def test_varlen_with_paged_kv_alibi( @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("use_alibi", [False]) @pytest.mark.parametrize("use_sink", [True]) -@pytest.mark.parametrize( - "isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"] -) +@pytest.mark.parametrize("isa", [get_attn_isa()]) def test_varlen_with_paged_kv_sink( seq_lens: list[tuple[int, int]], num_heads: tuple[int, int], diff --git a/tests/kernels/attention/test_cutlass_mla_decode.py b/tests/kernels/attention/test_cutlass_mla_decode.py index a60f4e385a893..784c16304a286 100644 --- a/tests/kernels/attention/test_cutlass_mla_decode.py +++ b/tests/kernels/attention/test_cutlass_mla_decode.py @@ -32,8 +32,8 @@ def cal_diff( CUTLASS_MLA_UNSUPPORTED_REASON = ( - "Cutlass MLA Requires compute capability of 10 or above." - if not current_platform.is_device_capability(100) + "Cutlass MLA Requires compute capability of 100 or above." + if not current_platform.is_device_capability_family(100) else "Cutlass MLA is supported" ) diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py index 98ea40608b468..06a7085a82ba0 100644 --- a/tests/kernels/attention/test_flashinfer_trtllm_attention.py +++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py @@ -11,7 +11,7 @@ from tests.kernels.quantization.nvfp4_utils import ( from vllm.platforms import current_platform from vllm.utils.math_utils import round_up -if not current_platform.is_device_capability(100): +if not current_platform.is_device_capability_family(100): pytest.skip( "This TRTLLM kernel requires NVIDIA Blackwell.", allow_module_level=True ) @@ -443,7 +443,7 @@ def test_flashinfer_trtllm_prefill_with_baseline( output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2]) if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE: - rtol, atol = 1e-1, 2e-1 + rtol, atol = 3e-1, 4e-1 elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: rtol, atol = 4e-2, 6e-2 elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype: diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py index ae3c63cc62d6b..639abdf6f0487 100644 --- a/tests/kernels/attention/test_mha_attn.py +++ b/tests/kernels/attention/test_mha_attn.py @@ -26,7 +26,14 @@ def clear_cache(): _cached_get_attn_backend.cache_clear() -@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"]) +devices = ["cpu"] +if current_platform.is_cuda(): + devices.append("cuda") +if current_platform.is_rocm(): + devices.append("hip") + + +@pytest.mark.parametrize("device", devices) def test_mha_attn_platform(device: str): """ Test the attention selector between different platform and device. @@ -46,7 +53,7 @@ def test_mha_attn_platform(device: str): patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()), ): attn = MultiHeadAttention(16, 64, scale=1) - assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA + assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN else: # Test CUDA with head_size=64 (divisible by 32) # - should use vLLM's FlashAttention diff --git a/tests/kernels/attention/test_triton_unified_attention.py b/tests/kernels/attention/test_triton_unified_attention.py index bf4d2179af5f9..7fb08e5780f51 100644 --- a/tests/kernels/attention/test_triton_unified_attention.py +++ b/tests/kernels/attention/test_triton_unified_attention.py @@ -7,6 +7,7 @@ import torch from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.platforms import current_platform +from vllm.utils.math_utils import next_power_of_2 NUM_HEADS = [(4, 4), (8, 2)] HEAD_SIZES = [128, 256] @@ -22,6 +23,10 @@ QDTYPES = ( # one value small enough to test the schema op check NUM_BLOCKS = [32768, 2048] +# 0: use 2D kernel for decode +# 8: use 3D kernel for decode +SEQ_THRESHOLD_3D_VALUES = [0, 8] + def ref_paged_attn( query: torch.Tensor, @@ -92,6 +97,7 @@ def ref_paged_attn( @pytest.mark.parametrize("soft_cap", [None, 50.0]) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("q_dtype", QDTYPES) +@pytest.mark.parametrize("seq_threshold_3D", SEQ_THRESHOLD_3D_VALUES) @torch.inference_mode() def test_triton_unified_attn( seq_lens: list[tuple[int, int]], @@ -103,6 +109,7 @@ def test_triton_unified_attn( soft_cap: float | None, num_blocks: int, q_dtype: torch.dtype | None, + seq_threshold_3D: int, ) -> None: torch.set_default_device("cuda") @@ -152,6 +159,21 @@ def test_triton_unified_attn( k_descale = torch.rand(scale_shape, dtype=torch.float32) v_descale = torch.rand(scale_shape, dtype=torch.float32) + num_par_softmax_segments = 16 + head_size_padded = next_power_of_2(head_size) + softmax_segm_output = torch.empty( + (seq_threshold_3D, num_query_heads, num_par_softmax_segments, head_size_padded), + dtype=torch.float32, + ) + softmax_segm_max = torch.empty( + (seq_threshold_3D, num_query_heads, num_par_softmax_segments), + dtype=torch.float32, + ) + softmax_segm_expsum = torch.empty( + (seq_threshold_3D, num_query_heads, num_par_softmax_segments), + dtype=torch.float32, + ) + unified_attention( q=maybe_quantized_query, k=maybe_quantized_key_cache, @@ -169,6 +191,11 @@ def test_triton_unified_attn( q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + seq_threshold_3D=seq_threshold_3D, + num_par_softmax_segments=num_par_softmax_segments, + softmax_segm_output=softmax_segm_output, + softmax_segm_max=softmax_segm_max, + softmax_segm_expsum=softmax_segm_expsum, ) ref_output = ref_paged_attn( diff --git a/tests/kernels/core/test_apply_rotary_emb.py b/tests/kernels/core/test_apply_rotary_emb.py new file mode 100644 index 0000000000000..23c722fa5e638 --- /dev/null +++ b/tests/kernels/core/test_apply_rotary_emb.py @@ -0,0 +1,203 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Tests for ApplyRotaryEmb CustomOp dispatch behavior. + +This test ensures that RotaryEmbedding classes correctly call the appropriate +ApplyRotaryEmb methods based on the calling context: + +1. RotaryEmbedding.forward_native() -> ApplyRotaryEmb.forward_native() +2. RotaryEmbedding.forward_cuda() -> ApplyRotaryEmb.forward() (auto-dispatch) +3. RotaryEmbedding.forward_hip() -> ApplyRotaryEmb.forward() (auto-dispatch) +""" + +from dataclasses import dataclass + +import pytest +import torch + +from vllm.config import ( + CompilationConfig, + VllmConfig, + get_cached_compilation_config, + set_current_vllm_config, +) +from vllm.platforms import current_platform + +CUDA_DEVICES = ["cuda:0"] + + +@dataclass +class RotaryEmbeddingTestCase: + """Test case configuration for RotaryEmbedding dispatch tests.""" + + name: str + rope_class: type + rope_kwargs: dict + method_name: str # forward_native, forward_cuda, forward + positions_shape: tuple # (num_tokens,) or (3, num_tokens) or (4, num_tokens) + expect_forward_native: bool # Should call ApplyRotaryEmb.forward_native() + expect_forward: bool # Should call ApplyRotaryEmb.forward() + + +def get_test_cases() -> list[RotaryEmbeddingTestCase]: + """Generate test cases for all RotaryEmbedding classes.""" + from vllm.model_executor.layers.rotary_embedding.ernie45_vl_rope import ( + Ernie4_5_VLRotaryEmbedding, + ) + from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding + from vllm.model_executor.layers.rotary_embedding.xdrope import XDRotaryEmbedding + + common_kwargs = { + "head_size": 128, + "rotary_dim": 128, + "max_position_embeddings": 4096, + "base": 10000, + "is_neox_style": True, + "dtype": torch.bfloat16, + } + + return [ + # MRotaryEmbedding tests + RotaryEmbeddingTestCase( + name="MRotaryEmbedding.forward_native", + rope_class=MRotaryEmbedding, + rope_kwargs={**common_kwargs, "mrope_section": [16, 24, 24]}, + method_name="forward_native", + positions_shape=(3, 32), # 2D for multimodal + expect_forward_native=True, + expect_forward=False, + ), + RotaryEmbeddingTestCase( + name="MRotaryEmbedding.forward_cuda_1d", + rope_class=MRotaryEmbedding, + rope_kwargs={**common_kwargs, "mrope_section": [16, 24, 24]}, + method_name="forward_cuda", + positions_shape=(32,), # 1D triggers apply_rotary_emb path + expect_forward_native=False, + expect_forward=True, + ), + # XDRotaryEmbedding tests + RotaryEmbeddingTestCase( + name="XDRotaryEmbedding.forward", + rope_class=XDRotaryEmbedding, + rope_kwargs={ + **common_kwargs, + "scaling_alpha": 1.0, + "xdrope_section": [16, 16, 16, 16], + }, + method_name="forward", + positions_shape=(4, 32), # 4D for P/W/H/T + expect_forward_native=False, + expect_forward=True, + ), + # Ernie4_5_VLRotaryEmbedding tests + RotaryEmbeddingTestCase( + name="Ernie4_5_VLRotaryEmbedding.forward_native", + rope_class=Ernie4_5_VLRotaryEmbedding, + rope_kwargs={**common_kwargs, "mrope_section": [22, 22, 20]}, + method_name="forward_native", + positions_shape=(3, 32), # 2D for multimodal + expect_forward_native=True, + expect_forward=False, + ), + ] + + +def run_dispatch_test( + test_case: RotaryEmbeddingTestCase, + device: str, +): + """Run a dispatch test for a RotaryEmbedding class.""" + vllm_config = VllmConfig( + compilation_config=CompilationConfig(custom_ops=["all", "+apply_rotary_emb"]) + ) + get_cached_compilation_config.cache_clear() + + with set_current_vllm_config(vllm_config): + rope = test_case.rope_class(**test_case.rope_kwargs).to(device=device) + + apply_rotary_emb = rope.apply_rotary_emb + + # Verify custom op is enabled + if test_case.expect_forward_native: + assert ( + apply_rotary_emb._forward_method != apply_rotary_emb.forward_native + ), "Test setup error: ApplyRotaryEmb custom op should be enabled" + + # Setup call tracking + call_tracker = {"forward_native_called": False, "forward_called": False} + original_forward_native = apply_rotary_emb.forward_native + original_forward = apply_rotary_emb.forward + + def tracked_forward_native(*args, **kwargs): + call_tracker["forward_native_called"] = True + return original_forward_native(*args, **kwargs) + + def tracked_forward(*args, **kwargs): + call_tracker["forward_called"] = True + return original_forward(*args, **kwargs) + + apply_rotary_emb.forward_native = tracked_forward_native + apply_rotary_emb.forward = tracked_forward + + try: + num_tokens = test_case.positions_shape[-1] + num_q_heads = 8 + num_kv_heads = 2 + head_size = test_case.rope_kwargs["head_size"] + max_position = test_case.rope_kwargs["max_position_embeddings"] + + positions = torch.randint( + 0, max_position // 4, test_case.positions_shape, device=device + ) + query = torch.randn( + num_tokens, num_q_heads * head_size, dtype=torch.bfloat16, device=device + ) + key = torch.randn( + num_tokens, + num_kv_heads * head_size, + dtype=torch.bfloat16, + device=device, + ) + + # Call the method under test + method = getattr(rope, test_case.method_name) + method(positions, query.clone(), key.clone()) + + # Verify expectations + if test_case.expect_forward_native: + assert call_tracker["forward_native_called"], ( + f"{test_case.name} should call ApplyRotaryEmb.forward_native()" + ) + if not test_case.expect_forward: + assert not call_tracker["forward_called"], ( + f"{test_case.name} should NOT call ApplyRotaryEmb.forward(). " + "Bug: when +apply_rotary_emb is enabled, forward_native() " + "incorrectly dispatches to CUDA/HIP kernels." + ) + if test_case.expect_forward: + assert call_tracker["forward_called"], ( + f"{test_case.name} should call ApplyRotaryEmb.forward()" + ) + finally: + apply_rotary_emb.forward_native = original_forward_native + apply_rotary_emb.forward = original_forward + + +@pytest.mark.skipif( + not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests." +) +@pytest.mark.parametrize("test_case", get_test_cases(), ids=lambda tc: tc.name) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_rotary_embedding_dispatch( + test_case: RotaryEmbeddingTestCase, + device: str, +): + """ + Test that RotaryEmbedding classes dispatch to the correct ApplyRotaryEmb method. + + - forward_native methods should call ApplyRotaryEmb.forward_native() + - forward_cuda/forward methods should call ApplyRotaryEmb.forward() + """ + run_dispatch_test(test_case, device) diff --git a/tests/kernels/core/test_fused_quant_layernorm.py b/tests/kernels/core/test_fused_quant_layernorm.py index b5fc653ca7353..094073f5d3f92 100644 --- a/tests/kernels/core/test_fused_quant_layernorm.py +++ b/tests/kernels/core/test_fused_quant_layernorm.py @@ -8,6 +8,12 @@ import torch import vllm._custom_ops as ops from tests.kernels.utils import opcheck from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8, +) +from vllm.model_executor.layers.quantization.utils.int8_utils import ( + per_token_group_quant_int8, +) DTYPES = [torch.bfloat16, torch.float] QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn] @@ -21,6 +27,7 @@ NUM_TOKENS_HIDDEN_SIZES = [ ADD_RESIDUAL = [False, True] SCALE_UBS = [True, False] +GROUP_SIZES = [None, [1, 64], [1, 128]] SEEDS = [0] CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] @@ -45,12 +52,13 @@ def ref_rms_norm( return out, residual -def ref_dynamic_per_token_quant( +def ref_dynamic_per_token_or_block_quant( rms_norm_layer: RMSNorm, x: torch.Tensor, quant_dtype: torch.dtype, residual: torch.Tensor | None, scale_ub: torch.Tensor | None, + group_size: list[int] | None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: if scale_ub is not None: assert quant_dtype == torch.float8_e4m3fn @@ -59,13 +67,24 @@ def ref_dynamic_per_token_quant( torch_out, residual = ref_rms_norm(rms_norm_layer, x, residual) # Quant - if quant_dtype == torch.float8_e4m3fn: - torch_out, scales = ops.scaled_fp8_quant( - torch_out, scale_ub=scale_ub, use_per_token_if_dynamic=True - ) + if group_size is not None: + if quant_dtype == torch.float8_e4m3fn: + torch_out, scales = per_token_group_quant_fp8( + torch_out, group_size=group_size[1], use_ue8m0=False + ) + else: + assert quant_dtype == torch.int8 + torch_out, scales = per_token_group_quant_int8( + torch_out, group_size=group_size[1] + ) else: - assert quant_dtype == torch.int8 - torch_out, scales, _ = ops.scaled_int8_quant(torch_out) + if quant_dtype == torch.float8_e4m3fn: + torch_out, scales = ops.scaled_fp8_quant( + torch_out, scale_ub=scale_ub, use_per_token_if_dynamic=True + ) + else: + assert quant_dtype == torch.int8 + torch_out, scales, _ = ops.scaled_int8_quant(torch_out) return torch_out, scales, residual @@ -76,24 +95,32 @@ def ref_impl( quant_dtype: torch.dtype, residual: torch.Tensor | None, scale_ub: torch.Tensor | None, + group_size: list[int] | None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: - return ref_dynamic_per_token_quant( - rms_norm_layer, x, quant_dtype, residual, scale_ub + return ref_dynamic_per_token_or_block_quant( + rms_norm_layer, x, quant_dtype, residual, scale_ub, group_size ) -def ops_dynamic_per_token_quant( +def ops_dynamic_per_token_or_block_quant( weight: torch.Tensor, x: torch.Tensor, quant_dtype: torch.dtype, residual: torch.Tensor | None, scale_ub: torch.Tensor | None, + group_size: list[int] | None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: if residual is not None: residual = residual.clone() - out, scales = ops.rms_norm_dynamic_per_token_quant( - x, weight, EPS, quant_dtype, scale_ub, residual - ) + if group_size is not None: + out, scales = ops.rms_norm_per_block_quant( + x, weight, EPS, quant_dtype, group_size, scale_ub, residual, True + ) + scales = scales.contiguous() + else: + out, scales = ops.rms_norm_dynamic_per_token_quant( + x, weight, EPS, quant_dtype, scale_ub, residual + ) return out, scales, residual @@ -103,8 +130,11 @@ def ops_impl( quant_dtype: torch.dtype, residual: torch.Tensor | None, scale_ub: torch.Tensor | None, + group_size: list[int] | None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: - return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual, scale_ub) + return ops_dynamic_per_token_or_block_quant( + weight, x, quant_dtype, residual, scale_ub, group_size + ) @pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES) @@ -112,6 +142,7 @@ def ops_impl( @pytest.mark.parametrize("has_scale_ub", SCALE_UBS) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("quant_dtype", QUANT_DTYPES) +@pytest.mark.parametrize("group_size", GROUP_SIZES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() @@ -122,6 +153,7 @@ def test_rms_norm( has_scale_ub: bool, dtype: torch.dtype, quant_dtype: torch.dtype, + group_size: list[int] | None, seed: int, device: str, ) -> None: @@ -130,6 +162,14 @@ def test_rms_norm( torch.cuda.manual_seed(seed) torch.set_default_device(device) + if group_size is not None and hidden_size % group_size[1] != 0: + # skip + return + + if group_size is not None and has_scale_ub: + # blockwise baseline doesn't support scale_ub + return + if has_scale_ub and quant_dtype != torch.float8_e4m3fn: # skip return @@ -150,10 +190,10 @@ def test_rms_norm( scale_ub = None ref_out, ref_scales, ref_residual = ref_impl( - layer, x, quant_dtype, residual, scale_ub + layer, x, quant_dtype, residual, scale_ub, group_size ) ops_out, ops_scales, ops_residual = ops_impl( - layer.weight, x, quant_dtype, residual, scale_ub + layer.weight, x, quant_dtype, residual, scale_ub, group_size ) assert ref_out.dtype == quant_dtype @@ -166,11 +206,15 @@ def test_rms_norm( assert torch.allclose(ref_scales, ops_scales) a = ref_out.to(dtype=torch.float32) b = ops_out.to(dtype=torch.float32) - ok = torch.allclose(a, b) + ok = torch.allclose(a, b, atol=1e-6) if not ok: # fallback: compare dequantized values with relaxed tolerance - a_deq = a * ref_scales.view(-1, 1) - b_deq = b * ops_scales.view(-1, 1) + if group_size is None: + a_deq = a * ref_scales.view(-1, 1) + b_deq = b * ops_scales.view(-1, 1) + else: + a_deq = a * ref_scales.repeat_interleave(group_size[1], dim=1) + b_deq = b * ops_scales.repeat_interleave(group_size[1], dim=1) # NOTE: It is possible that some future test cases trigger this # max diff due to precision issues. If such an error is # encountered, it's recommended to inspect the differences between diff --git a/tests/kernels/core/test_mrope.py b/tests/kernels/core/test_mrope.py index 43b242ab2d586..ba5d593b2d355 100644 --- a/tests/kernels/core/test_mrope.py +++ b/tests/kernels/core/test_mrope.py @@ -113,12 +113,9 @@ def test_mrope( is_neox_style = True max_position = config.max_position_embeddings - partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) - rotary_dim = int(head_dim * partial_rotary_factor) mrope_helper_class = get_rope( head_size=head_dim, - rotary_dim=rotary_dim, max_position=max_position, is_neox_style=is_neox_style, rope_parameters=config.rope_parameters, @@ -184,12 +181,9 @@ def test_mrope_torch_compile_tracing( ) is_neox_style = True max_position = config.max_position_embeddings - partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) - rotary_dim = int(head_dim * partial_rotary_factor) mrope_helper_class = get_rope( head_size=head_dim, - rotary_dim=rotary_dim, max_position=max_position, is_neox_style=is_neox_style, rope_parameters=config.rope_parameters, diff --git a/tests/kernels/core/test_pos_encoding.py b/tests/kernels/core/test_pos_encoding.py index a8ed3825689d3..d18f01314c8f5 100644 --- a/tests/kernels/core/test_pos_encoding.py +++ b/tests/kernels/core/test_pos_encoding.py @@ -83,8 +83,12 @@ def test_rotary_embedding( torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size - rope_parameters = {"rope_type": "default", "rope_theta": rope_theta} - rope = get_rope(head_size, rotary_dim, max_position, is_neox_style, rope_parameters) + rope_parameters = { + "rope_type": "default", + "rope_theta": rope_theta, + "partial_rotary_factor": rotary_dim / head_size, + } + rope = get_rope(head_size, max_position, is_neox_style, rope_parameters) rope = rope.to(dtype=dtype, device=torch.get_default_device()) positions = torch.randint(0, max_position, (batch_size, seq_len)) @@ -150,9 +154,9 @@ def test_rope_module_cache(): if rotary_dim is None: rotary_dim = head_size rope_parameters["rope_theta"] = rope_theta + rope_parameters["partial_rotary_factor"] = rotary_dim / head_size rope = get_rope( head_size, - rotary_dim, max_position, is_neox_style, rope_parameters, @@ -177,9 +181,9 @@ def test_rope_module_cache(): if rotary_dim is None: rotary_dim = head_size rope_parameters["rope_theta"] = rope_theta + rope_parameters["partial_rotary_factor"] = rotary_dim / head_size rope = get_rope( head_size, - rotary_dim, max_position, is_neox_style, rope_parameters, diff --git a/tests/kernels/mamba/test_mamba_ssm.py b/tests/kernels/mamba/test_mamba_ssm.py index 98edc959957d0..50e48aad6ebaa 100644 --- a/tests/kernels/mamba/test_mamba_ssm.py +++ b/tests/kernels/mamba/test_mamba_ssm.py @@ -425,6 +425,80 @@ def test_selective_state_update(dim, dstate, has_z, itype): assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) +@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("has_z", [False, True]) +@pytest.mark.parametrize("dstate", [16, 64]) +@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) +@pytest.mark.parametrize("max_seq_len", [1, 2, 4]) +def test_selective_state_update_varlen(dim, dstate, has_z, itype, max_seq_len): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) + if itype == torch.bfloat16: + rtol, atol = 5e-2, 1.5e-1 + if torch.version.hip: + atol *= 2 + # set seed + current_platform.seed_everything(0) + batch_size = 4 + token_counts = torch.randint(1, max_seq_len + 1, (batch_size,), device=device) + total_tokens = int(token_counts.sum().item()) + cu_seqlens = torch.tensor( + [0] + torch.cumsum(token_counts, dim=0).tolist(), + dtype=torch.int32, + device=device, + ) + state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device) + x = torch.randn(total_tokens, dim, device=device, dtype=itype) + out = torch.empty_like(x) + dt = torch.randn(total_tokens, dim, device=device, dtype=itype) + dt_bias = torch.rand(dim, device=device) - 4.0 + A = -torch.rand(dim, dstate, device=device) - 1.0 + B = torch.randn(total_tokens, dstate, device=device) + C = torch.randn(total_tokens, dstate, device=device) + D = torch.randn(dim, device=device) + z = torch.randn_like(x) if has_z else None + state_ref = state.detach().clone() + selective_state_update( + state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True, + out=out, + cu_seqlens=cu_seqlens, + ) + + out_ref_list = [] + for seq_idx in range(batch_size): + start_idx = cu_seqlens[seq_idx].item() + end_idx = cu_seqlens[seq_idx + 1].item() + num_tokens = end_idx - start_idx + for token_idx in range(num_tokens): + idx = start_idx + token_idx + out_ref_list.append( + selective_state_update_ref( + state_ref[seq_idx : seq_idx + 1], + x[idx : idx + 1], + dt[idx : idx + 1], + A, + B[idx : idx + 1], + C[idx : idx + 1], + D=D, + z=z[idx : idx + 1] if has_z else None, + dt_bias=dt_bias, + dt_softplus=True, + ) + ) + out_ref = torch.cat(out_ref_list, dim=0) + assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + @pytest.mark.parametrize("wtype", [torch.float32]) @pytest.mark.parametrize("itype", [torch.float32]) @pytest.mark.parametrize("seqlen", [1, 256, 1024, 4096]) @@ -766,3 +840,254 @@ def test_selective_state_update_with_heads_with_batch_indices( print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("has_z", [False, True]) +@pytest.mark.parametrize("dstate", [16, 64]) +@pytest.mark.parametrize("dim", [2048, 4096]) +@pytest.mark.parametrize("max_seq_len", [2, 4]) +def test_selective_state_update_with_num_accepted_tokens( + dim, dstate, has_z, itype, max_seq_len +): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) + if itype == torch.bfloat16: + rtol, atol = 5e-2, 1.5e-1 + if torch.version.hip: + atol *= 2 + + current_platform.seed_everything(0) + batch_size = 4 + + tokens_per_seq = torch.randint(1, max_seq_len + 1, (batch_size,), device=device) + total_tokens = int(tokens_per_seq.sum().item()) + + num_accepted_tokens = torch.randint(0, max_seq_len, (batch_size,), device=device) + num_accepted_tokens[0] = 0 # Add edge-case of no accepted tokens + num_accepted_tokens[1] = max_seq_len # Add edge-case of all tokens accepted + + cu_seqlens = torch.tensor( + [0] + torch.cumsum(tokens_per_seq, dim=0).tolist(), + dtype=torch.int32, + device=device, + ) + + total_state_slots = 50 + state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device) + + state_batch_indices = torch.full( + (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device + ) + initial_state_slots = torch.randint( + 0, 15, (batch_size,), device=device, dtype=torch.int32 + ) + for seq_idx in range(batch_size): + token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0) + state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx] + + dst_state_batch_indices = torch.full( + (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device + ) + slot_offset = 15 + dst_slots_map = {} + for seq_idx in range(batch_size): + for token_idx in range(tokens_per_seq[seq_idx].item()): + dst_state_batch_indices[seq_idx, token_idx] = slot_offset + dst_slots_map[(seq_idx, token_idx)] = slot_offset + slot_offset += 1 + + x = torch.randn(total_tokens, dim, device=device, dtype=itype) + out = torch.empty_like(x) + dt = torch.randn(total_tokens, dim, device=device, dtype=itype) + dt_bias = torch.rand(dim, device=device) - 4.0 + A = -torch.rand(dim, dstate, device=device) - 1.0 + B = torch.randn(total_tokens, dstate, device=device) + C = torch.randn(total_tokens, dstate, device=device) + D = torch.randn(dim, device=device) + z = torch.randn_like(x) if has_z else None + + state_ref_intermediate = {} + out_ref_list = [] + + for seq_idx in range(batch_size): + seq_start = cu_seqlens[seq_idx].item() + seq_end = cu_seqlens[seq_idx + 1].item() + num_tokens = seq_end - seq_start + + token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0) + initial_slot = state_batch_indices[seq_idx, token_pos].item() + state_seq = state[initial_slot : initial_slot + 1].clone() + + for token_idx in range(num_tokens): + global_idx = seq_start + token_idx + + out_token = selective_state_update_ref( + state_seq, + x[global_idx : global_idx + 1], + dt[global_idx : global_idx + 1], + A, + B[global_idx : global_idx + 1], + C[global_idx : global_idx + 1], + D=D, + z=z[global_idx : global_idx + 1] if has_z else None, + dt_bias=dt_bias, + dt_softplus=True, + ) + out_ref_list.append(out_token) + state_ref_intermediate[(seq_idx, token_idx)] = state_seq.clone() + + out_ref = torch.cat(out_ref_list, dim=0) + + selective_state_update( + state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True, + out=out, + cu_seqlens=cu_seqlens, + state_batch_indices=state_batch_indices, + dst_state_batch_indices=dst_state_batch_indices, + num_accepted_tokens=num_accepted_tokens, + pad_slot_id=PAD_SLOT_ID, + ) + + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + for seq_idx in range(batch_size): + num_tokens = tokens_per_seq[seq_idx].item() + for token_idx in range(num_tokens): + dst_slot = dst_slots_map[(seq_idx, token_idx)] + state_ref = state_ref_intermediate[(seq_idx, token_idx)].squeeze(0) + assert torch.allclose(state[dst_slot], state_ref, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("has_z", [False, True]) +@pytest.mark.parametrize("dstate", [16, 64]) +@pytest.mark.parametrize("dim", [2048, 4096]) +@pytest.mark.parametrize("max_seq_len", [2, 4]) +def test_selective_state_update_varlen_with_num_accepted( + dim, dstate, has_z, itype, max_seq_len +): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) + if itype == torch.bfloat16: + rtol, atol = 5e-2, 1.5e-1 + if torch.version.hip: + atol *= 2 + + current_platform.seed_everything(0) + batch_size = 4 + + tokens_per_seq = torch.randint(1, max_seq_len + 1, (batch_size,), device=device) + total_tokens = int(tokens_per_seq.sum().item()) + + num_accepted_tokens = torch.randint(0, max_seq_len, (batch_size,), device=device) + num_accepted_tokens[0] = 0 # Add edge-case of no accepted tokens + num_accepted_tokens[1] = max_seq_len # Add edge-case of all tokens accepted + + cu_seqlens = torch.tensor( + [0] + torch.cumsum(tokens_per_seq, dim=0).tolist(), + dtype=torch.int32, + device=device, + ) + + total_state_slots = 50 + state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device) + + state_batch_indices = torch.full( + (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device + ) + + initial_state_slots = torch.randint( + 0, 15, (batch_size,), device=device, dtype=torch.int32 + ) + for seq_idx in range(batch_size): + token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0) + state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx] + + dst_state_batch_indices = torch.full( + (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device + ) + + slot_offset = 15 + dst_slots_map = {} + for seq_idx in range(batch_size): + for token_idx in range(tokens_per_seq[seq_idx].item()): + dst_state_batch_indices[seq_idx, token_idx] = slot_offset + dst_slots_map[(seq_idx, token_idx)] = slot_offset + slot_offset += 1 + + x = torch.randn(total_tokens, dim, device=device, dtype=itype) + out = torch.empty_like(x) + dt = torch.randn(total_tokens, dim, device=device, dtype=itype) + dt_bias = torch.rand(dim, device=device) - 4.0 + A = -torch.rand(dim, dstate, device=device) - 1.0 + B = torch.randn(total_tokens, dstate, device=device) + C = torch.randn(total_tokens, dstate, device=device) + D = torch.randn(dim, device=device) + z = torch.randn_like(x) if has_z else None + + state_ref_intermediate = {} + + for seq_idx in range(batch_size): + seq_start = cu_seqlens[seq_idx].item() + seq_end = cu_seqlens[seq_idx + 1].item() + num_tokens = seq_end - seq_start + + token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0) + initial_slot = state_batch_indices[seq_idx, token_pos].item() + state_seq = state[initial_slot : initial_slot + 1].clone() + + for token_idx in range(num_tokens): + global_idx = seq_start + token_idx + + selective_state_update_ref( + state_seq, + x[global_idx : global_idx + 1], + dt[global_idx : global_idx + 1], + A, + B[global_idx : global_idx + 1], + C[global_idx : global_idx + 1], + D=D, + z=z[global_idx : global_idx + 1] if has_z else None, + dt_bias=dt_bias, + dt_softplus=True, + ) + + state_ref_intermediate[(seq_idx, token_idx)] = state_seq.clone() + + selective_state_update( + state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True, + out=out, + cu_seqlens=cu_seqlens, + state_batch_indices=state_batch_indices, + dst_state_batch_indices=dst_state_batch_indices, + num_accepted_tokens=num_accepted_tokens, + pad_slot_id=PAD_SLOT_ID, + ) + + for seq_idx in range(batch_size): + num_tokens = tokens_per_seq[seq_idx].item() + + for token_idx in range(num_tokens): + dst_slot = dst_slots_map[(seq_idx, token_idx)] + state_ref = state_ref_intermediate[(seq_idx, token_idx)].squeeze(0) + + assert torch.allclose(state[dst_slot], state_ref, rtol=rtol, atol=atol) diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py index d95c22fdf0a5b..6078ce44cee9f 100644 --- a/tests/kernels/moe/modular_kernel_tools/common.py +++ b/tests/kernels/moe/modular_kernel_tools/common.py @@ -594,7 +594,8 @@ def make_modular_kernel( ) modular_kernel = mk.FusedMoEModularKernel( - prepare_finalize=prepare_finalize, fused_experts=fused_experts + prepare_finalize=prepare_finalize, + fused_experts=fused_experts, ) return modular_kernel diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index d79fdfbe07af3..99b168dc75548 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -13,9 +13,6 @@ from vllm.model_executor.layers.fused_moe.all2all_utils import ( from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( BatchedDeepGemmExperts, ) -from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( - BatchedTritonOrDeepGemmExperts, -) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, @@ -286,16 +283,6 @@ if has_deep_gemm() and is_deep_gemm_supported(): needs_matching_quant=False, needs_deep_gemm=True, ) - register_experts( - BatchedTritonOrDeepGemmExperts, - batched_format, - common_float_and_int_types, - blocked_quantization_support=True, - supports_chunking=False, - supports_expert_map=False, - needs_matching_quant=True, - needs_deep_gemm=True, - ) register_experts( TritonOrDeepGemmExperts, standard_format, @@ -457,10 +444,6 @@ def make_fused_experts( kwargs = batch_kwargs | quant_kwargs print(f"Making BatchedTritonExperts {kwargs} ...") experts = BatchedTritonExperts(**kwargs) - elif fused_experts_type == BatchedTritonOrDeepGemmExperts: - kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs - print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...") - experts = BatchedTritonOrDeepGemmExperts(**kwargs) elif fused_experts_type == DeepGemmExperts: print(f"Making DeepGemmExperts {quant_config} ...") experts = DeepGemmExperts(quant_config) diff --git a/tests/kernels/moe/test_batched_deepgemm.py b/tests/kernels/moe/test_batched_deepgemm.py index 59cecd60d3d61..0ba3d8d4c958e 100644 --- a/tests/kernels/moe/test_batched_deepgemm.py +++ b/tests/kernels/moe/test_batched_deepgemm.py @@ -27,7 +27,7 @@ BLOCK_SIZE = [128, 128] @pytest.mark.parametrize("N", [512, 1024]) # intermediate dim per expert @pytest.mark.parametrize("topk", [2, 4]) def test_batched_deepgemm_vs_triton( - E: int, T: int, K: int, N: int, topk: int, monkeypatch + E: int, T: int, K: int, N: int, topk: int, monkeypatch, workspace_init ): """Compare BatchedDeepGemmExperts to BatchedTritonExperts.""" diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index dab1207d78031..2ef170f1ab308 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -248,6 +248,7 @@ def test_fused_moe_batched_experts( per_act_token_quant: bool, block_shape: list[int] | None, input_scales: bool, + workspace_init, ): """Note: float8_e4m3fn is not supported on CUDA architecture < 89, and those tests will be skipped on unsupported hardware.""" diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py index b0ff1e64e3219..53a03f48e24ee 100644 --- a/tests/kernels/moe/test_block_fp8.py +++ b/tests/kernels/moe/test_block_fp8.py @@ -137,7 +137,7 @@ def setup_cuda(): @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() def test_w8a8_block_fp8_fused_moe( - M, N, K, E, topk, block_size, dtype, seed, monkeypatch + M, N, K, E, topk, block_size, dtype, seed, monkeypatch, workspace_init ): if topk > E: pytest.skip(f"Skipping test; topk={topk} > E={E}") diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index c15837f145705..0160694d7bb54 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -274,6 +274,7 @@ def test_cutlass_moe_8_bit_no_graph( per_act_token: bool, per_out_ch: bool, monkeypatch, + workspace_init, ep_size: int | None = None, ): current_platform.seed_everything(7) @@ -329,6 +330,7 @@ def test_cutlass_moe_8_bit_cuda_graph( per_act_token: bool, per_out_ch: bool, monkeypatch, + workspace_init, ): current_platform.seed_everything(7) monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") @@ -385,9 +387,19 @@ def test_cutlass_moe_8_bit_EP( per_out_channel: bool, ep_size: int, monkeypatch, + workspace_init, ): test_cutlass_moe_8_bit_no_graph( - m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size + m, + n, + k, + e, + topk, + per_act_token, + per_out_channel, + monkeypatch, + workspace_init, + ep_size, ) @@ -419,9 +431,19 @@ def test_cutlass_moe_8_bit_EP_large( per_out_channel: bool, ep_size: int, monkeypatch, + workspace_init, ): test_cutlass_moe_8_bit_no_graph( - m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size + m, + n, + k, + e, + topk, + per_act_token, + per_out_channel, + monkeypatch, + workspace_init, + ep_size, ) @@ -445,6 +467,7 @@ def test_run_cutlass_moe_fp8( per_act_token: bool, per_out_channel: bool, ep_size: int, + workspace_init, ): current_platform.seed_everything(7) with set_current_vllm_config(vllm_config): diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 455ecacef5ec3..f427734ef09e2 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -29,6 +29,7 @@ from vllm.utils.deep_gemm import ( is_deep_gemm_supported, ) from vllm.utils.import_utils import has_deep_ep, has_deep_gemm +from vllm.v1.worker.workspace import init_workspace_manager from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch @@ -363,6 +364,9 @@ def _test_deepep_deepgemm_moe( w1_scale: torch.Tensor, w2_scale: torch.Tensor, ): + device = torch.device(f"cuda:{pgi.local_rank}") + init_workspace_manager(device) + current_platform.seed_everything(pgi.rank) w1 = w1.to(device=torch.cuda.current_device()) @@ -445,6 +449,7 @@ def test_ht_deepep_deepgemm_moe( topk: int, world_dp_size: tuple[int, int], disable_deepgemm_ue8m0, + workspace_init, ): """ Tests for High-Throughput DeepEP + DeepGemm integration. @@ -518,6 +523,7 @@ def test_ll_deepep_deepgemm_moe( block_size: list[int], world_dp_size: tuple[int, int], disable_deepgemm_ue8m0, + workspace_init, ): """ Tests for Low-Latency DeepEP + DeepGemm integration. diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py index d78b8250463a9..e698ca92a1515 100644 --- a/tests/kernels/moe/test_deepep_moe.py +++ b/tests/kernels/moe/test_deepep_moe.py @@ -22,6 +22,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( ) from vllm.platforms import current_platform from vllm.utils.import_utils import has_deep_ep +from vllm.v1.worker.workspace import init_workspace_manager from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch @@ -342,6 +343,9 @@ def _deep_ep_moe( use_fp8_dispatch: bool, per_act_token_quant: bool, ): + device = torch.device(f"cuda:{pgi.local_rank}") + init_workspace_manager(device) + if not low_latency_mode: assert not use_fp8_dispatch, ( "FP8 dispatch interface is available only in low-latency mode" @@ -437,6 +441,7 @@ def test_deep_ep_moe( topk: int, world_dp_size: tuple[int, int], per_act_token_quant: bool, + workspace_init, ): low_latency_mode = False use_fp8_dispatch = False @@ -492,6 +497,7 @@ def test_low_latency_deep_ep_moe( topk: int, world_dp_size: tuple[int, int], use_fp8_dispatch: bool, + workspace_init, ): low_latency_mode = True diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py index 9b1054f7d0ab8..442b561f8f315 100644 --- a/tests/kernels/moe/test_deepgemm.py +++ b/tests/kernels/moe/test_deepgemm.py @@ -143,7 +143,7 @@ NUM_EXPERTS = [32] @pytest.mark.parametrize("topk", TOPKS) @pytest.mark.parametrize("num_experts", NUM_EXPERTS) @pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels") -def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch): +def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch, workspace_init): with monkeypatch.context() as mp: mp.setenv("VLLM_USE_DEEP_GEMM", "1") diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index a6977f222408d..bf4ef2d30466b 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -5,6 +5,7 @@ from dataclasses import dataclass import pytest import torch +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, @@ -107,6 +108,19 @@ class TestData: layer.w2_input_scale = a2_scale layer.w13_weight_scale = w13_weight_scale layer.w2_weight_scale = w2_weight_scale + # Setup dummy config. + layer.moe_parallel_config = mk.FusedMoEParallelConfig( + tp_size=1, + pcp_size=1, + dp_size=1, + ep_size=1, + tp_rank=1, + pcp_rank=1, + dp_rank=1, + ep_rank=1, + use_ep=False, + all2all_backend="naive", + ) register_moe_scaling_factors(layer) @@ -206,6 +220,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( topk: int, activation: str, monkeypatch, + workspace_init, ): current_platform.seed_everything(7) monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") diff --git a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py index b2be03ecee2f1..133a8a4a30a60 100644 --- a/tests/kernels/moe/test_flashinfer_moe.py +++ b/tests/kernels/moe/test_flashinfer_moe.py @@ -51,7 +51,14 @@ MNK_FACTORS = [ @pytest.mark.parametrize("activation", ["silu_and_mul", "relu2"]) @torch.inference_mode() def test_flashinfer_fp4_moe_no_graph( - m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, activation: str + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, + activation: str, + workspace_init, ): current_platform.seed_everything(7) with set_current_vllm_config( diff --git a/tests/kernels/moe/test_gpt_oss_triton_kernels.py b/tests/kernels/moe/test_gpt_oss_triton_kernels.py index 98e80ec029777..384f43db479b5 100644 --- a/tests/kernels/moe/test_gpt_oss_triton_kernels.py +++ b/tests/kernels/moe/test_gpt_oss_triton_kernels.py @@ -269,7 +269,7 @@ class Case: ) @pytest.mark.parametrize("num_token", [2]) @pytest.mark.parametrize("tp", [1, 2, 4, 8]) -def test_equiv(num_token, a_dtype, w_dtype, tp): +def test_equiv(num_token, a_dtype, w_dtype, tp, workspace_init): from triton_kernels.tensor_details import layout if not hasattr(layout, "make_default_matmul_mxfp4_w_layout"): diff --git a/tests/kernels/moe/test_modular_kernel_combinations.py b/tests/kernels/moe/test_modular_kernel_combinations.py index 2a30ef2355529..6ebf1016c166c 100644 --- a/tests/kernels/moe/test_modular_kernel_combinations.py +++ b/tests/kernels/moe/test_modular_kernel_combinations.py @@ -16,6 +16,7 @@ from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx from vllm.utils.torch_utils import cuda_device_count_stateless +from vllm.v1.worker.workspace import init_workspace_manager from .modular_kernel_tools.common import ( Config, @@ -77,6 +78,10 @@ def rank_worker( weights: WeightTensors, verbose: bool, ): + # Initialize workspace manager in child process + device = torch.device(f"cuda:{pgi.local_rank}") + init_workspace_manager(device) + current_platform.seed_everything(pgi.rank) # sanity check @@ -300,6 +305,7 @@ def test_modular_kernel_combinations_singlegpu( chunk_size: int | None, world_size: int, pytestconfig, + workspace_init, ): """Note: float8_e4m3fn is not supported on CUDA architecture < 89, and those tests will be skipped on unsupported hardware.""" diff --git a/tests/kernels/moe/test_modular_oai_triton_moe.py b/tests/kernels/moe/test_modular_oai_triton_moe.py index c8616f13bbf85..1abb08f878b2b 100644 --- a/tests/kernels/moe/test_modular_oai_triton_moe.py +++ b/tests/kernels/moe/test_modular_oai_triton_moe.py @@ -209,6 +209,7 @@ def test_oai_triton_moe( num_experts: int, topk: int, unfused: bool, + workspace_init, ): current_platform.seed_everything(0) ( diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index bacf6f37f2b08..ce99d9691fdc8 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -231,6 +231,7 @@ def test_fused_moe( padding: bool, chunk_size: int, monkeypatch, + workspace_init, ): current_platform.seed_everything(7) @@ -955,9 +956,22 @@ def test_fused_marlin_moe_with_bias(m): torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0) -def test_moe_align_block_size_opcheck(): +@pytest.mark.parametrize("ep_size", [1, 2]) +def test_moe_align_block_size_opcheck(ep_size): num_experts = 4 block_size = 4 + + expert_map = None + if ep_size != 1: + local_num_experts = num_experts // ep_size + expert_ids = torch.randint( + 0, num_experts, (local_num_experts,), device="cuda", dtype=torch.int32 + ) + expert_map = torch.full((num_experts,), -1, device="cuda", dtype=torch.int32) + expert_map[expert_ids] = torch.arange( + local_num_experts, device="cuda", dtype=torch.int32 + ) + topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda") max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) @@ -980,6 +994,7 @@ def test_moe_align_block_size_opcheck(): sorted_ids, expert_ids, num_tokens_post_pad, + expert_map, ), ) diff --git a/tests/kernels/moe/test_moe_align_block_size.py b/tests/kernels/moe/test_moe_align_block_size.py index 8975f00bd4c6e..1abfc11fb460e 100644 --- a/tests/kernels/moe/test_moe_align_block_size.py +++ b/tests/kernels/moe/test_moe_align_block_size.py @@ -106,6 +106,8 @@ def torch_moe_align_block_size( max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) if pad_sorted_ids: max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + if topk_ids.numel() < num_experts: + max_num_tokens_padded = topk_ids.numel() * block_size flattened_token_indices = torch.arange( topk_ids.numel(), device=topk_ids.device, dtype=torch.int32 @@ -126,6 +128,8 @@ def torch_moe_align_block_size( ) for expert_id in range(num_experts): original_count = expert_token_counts[expert_id] + if expert_map is not None and expert_map[expert_id] == -1: + continue if original_count > 0: expert_padded_counts[expert_id] = ( (original_count + block_size - 1) // block_size @@ -143,6 +147,9 @@ def torch_moe_align_block_size( current_pos = 0 current_block = 0 for expert_id in range(num_experts): + if expert_map is not None and expert_map[expert_id] == -1: + continue + expert_mask = sorted_expert_ids == expert_id expert_tokens = sorted_token_indices[expert_mask] num_expert_tokens = expert_tokens.shape[0] @@ -153,7 +160,13 @@ def torch_moe_align_block_size( ) expert_blocks_needed = expert_padded_counts[expert_id] // block_size - expert_ids[current_block : current_block + expert_blocks_needed] = expert_id + + expert_id_new = expert_id + if expert_map is not None: + expert_id_new = expert_map[expert_id] + expert_ids[current_block : current_block + expert_blocks_needed] = ( + expert_id_new + ) current_pos += expert_padded_counts[expert_id] current_block += expert_blocks_needed @@ -163,8 +176,6 @@ def torch_moe_align_block_size( [total_padded_tokens], dtype=torch.int32, device=topk_ids.device ) - if expert_map is not None: - expert_ids = expert_map[expert_ids] return sorted_token_ids, expert_ids, num_tokens_post_pad @@ -229,9 +240,9 @@ def test_moe_align_block_size( ) -@pytest.mark.parametrize("m", [16, 32]) +@pytest.mark.parametrize("m", [16, 32, 2048]) @pytest.mark.parametrize("topk", [2, 4]) -@pytest.mark.parametrize("num_experts", [8]) +@pytest.mark.parametrize("num_experts", [8, 64]) @pytest.mark.parametrize("block_size", [64]) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") def test_moe_align_block_size_with_expert_map( @@ -253,6 +264,7 @@ def test_moe_align_block_size_with_expert_map( block_size=block_size, num_experts=num_experts, expert_map=expert_map, + ignore_invalid_experts=True, ) golden_sorted_ids, golden_expert_ids, golden_num_tokens = ( torch_moe_align_block_size( diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py index aa544fe0e0f63..e67bd76a16181 100644 --- a/tests/kernels/moe/test_nvfp4_moe.py +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -40,7 +40,7 @@ MNK_FACTORS = [ @pytest.mark.parametrize("dtype", [torch.bfloat16]) @torch.inference_mode() def test_cutlass_fp4_moe_no_graph( - m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype + m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, workspace_init ): current_platform.seed_everything(7) with set_current_vllm_config( diff --git a/tests/kernels/moe/test_ocp_mx_moe.py b/tests/kernels/moe/test_ocp_mx_moe.py index 91b508d4163cc..8fe471d124f43 100644 --- a/tests/kernels/moe/test_ocp_mx_moe.py +++ b/tests/kernels/moe/test_ocp_mx_moe.py @@ -17,7 +17,7 @@ QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse( ) >= version.parse("0.8.99") TRTLLM_GEN_MXFP4_AVAILABLE = ( - current_platform.is_cuda() and current_platform.is_device_capability(100) + current_platform.is_cuda() and current_platform.is_device_capability_family(100) ) HOPPER_MXFP4_BF16_AVAILABLE = ( @@ -70,12 +70,12 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase): f"{torch.cuda.device_count()}" ) - # `cuda_graph_sizes=[16]` to reduce load time. + # `cudagraph_capture_sizes=[16]` to reduce load time. with vllm_runner( model_case.model_id, tensor_parallel_size=model_case.tp, load_format="dummy", - cuda_graph_sizes=[16], + cudagraph_capture_sizes=[16], ) as llm: # Disabled as check_model is broken: https://github.com/vllm-project/vllm/pull/18465#issuecomment-3329880562 # def check_model(model): @@ -799,7 +799,7 @@ def test_flashinfer_cutlass_mxfp4_fused_moe( @pytest.mark.skipif( not ( current_platform.is_cuda() - and current_platform.is_device_capability(100) + and current_platform.is_device_capability_family(100) and has_flashinfer() ), reason="NVIDIA GPU sm100 and flashinfer are required for this test", diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index f671b23d300ce..35e554e16cb38 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -46,6 +46,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( ) from vllm.platforms import current_platform from vllm.utils.math_utils import round_up +from vllm.v1.worker.workspace import init_workspace_manager from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch @@ -181,6 +182,7 @@ def test_fused_moe_batched_experts( e: int, topk: int, dtype: torch.dtype, + workspace_init, ): current_platform.seed_everything(7) @@ -863,6 +865,9 @@ def _pplx_test_loop( make_weights: bool, test_fn: Callable, ): + device = torch.device(f"cuda:{pgi.local_rank}") + init_workspace_manager(device) + def format_result(msg, ex=None): if ex is not None: x = str(ex) diff --git a/tests/kernels/moe/test_silu_mul_per_token_group_quant_fp8_colmajor.py b/tests/kernels/moe/test_silu_mul_per_token_group_quant_fp8_colmajor.py new file mode 100644 index 0000000000000..e4617072cd52c --- /dev/null +++ b/tests/kernels/moe/test_silu_mul_per_token_group_quant_fp8_colmajor.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _per_token_group_quant_fp8_colmajor, + silu_mul_per_token_group_quant_fp8_colmajor, +) +from vllm.platforms import current_platform +from vllm.triton_utils import triton +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used + +FLOAT8_DTYPE = torch.float8_e4m3fn +GROUP_SIZE = 128 + + +def reference_quant(x: torch.Tensor, use_ue8m0: bool): + """ + Reference triton quant kernel from, + vllm.model_executor.layers.quantization.utils.fp8_utils + """ + + x_q = torch.empty_like(x, device=x.device, dtype=FLOAT8_DTYPE) + + # Allocate the scale tensor in column-major format. + shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1] + x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) + + M = x.numel() // GROUP_SIZE + N = GROUP_SIZE + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + + finfo = torch.finfo(FLOAT8_DTYPE) + fp8_min = finfo.min + fp8_max = finfo.max + + _per_token_group_quant_fp8_colmajor[(M,)]( + x, + x_q, + x_s, + GROUP_SIZE, + x.shape[1], + x.stride(0), + x_s.stride(1), + eps=1e-10, + fp8_min=fp8_min, + fp8_max=fp8_max, + use_ue8m0=use_ue8m0, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + return x_q, x_s + + +def reference(x: torch.Tensor, use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]: + T, N = x.size() + ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device="cuda") + torch.ops._C.silu_and_mul(ref_act_out, x) + return reference_quant(ref_act_out, use_ue8m0) + + +@pytest.mark.parametrize("T", [128, 256, 512]) +@pytest.mark.parametrize("N", [128 * 2, 256 * 2, 768 * 2, 2048 * 2, 7168 * 2]) +def test_silu_mul_fp8_quant_deep_gemm(T: int, N: int): + current_platform.seed_everything(42) + + input = torch.rand((T, N), dtype=torch.bfloat16, device="cuda") + + use_ue8m0 = is_deep_gemm_e8m0_used() + + # Test + output, output_scales = silu_mul_per_token_group_quant_fp8_colmajor( + input, use_ue8m0=use_ue8m0 + ) + + # Reference + ref_output, ref_output_scales = reference(input, use_ue8m0) + + torch.testing.assert_close(output.to(torch.float32), ref_output.to(torch.float32)) + torch.testing.assert_close(output_scales, ref_output_scales) diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index 830d43569e98b..7927bd0d200d8 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -30,16 +30,11 @@ def ref_dynamic_per_token_quant( if quant_dtype == torch.int8 else torch.finfo(quant_dtype) ) - qtype_traits_max = ( - ROCM_FP8FNUZ_MAX - if current_platform.is_rocm() and current_platform.is_fp8_fnuz() - else qtype_traits.max - ) - qtype_traits_min = ( - -ROCM_FP8FNUZ_MAX - if current_platform.is_rocm() and current_platform.is_fp8_fnuz() - else qtype_traits.min + use_fp8fnuz = ( + current_platform.is_fp8_fnuz() and quant_dtype == current_platform.fp8_dtype() ) + qtype_traits_max = ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.max + qtype_traits_min = -ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.min qtype_max = as_float32_tensor(qtype_traits_max) s_1 = as_float32_tensor(1.0) s_512 = as_float32_tensor(512.0) @@ -103,7 +98,7 @@ def ref_dynamic_per_tensor_fp8_quant( .clamp(fp8_traits_min, fp8_traits_max) .to(FP8_DTYPE) ) - return ref_out, ref_scale.view((1, 1)) + return ref_out, ref_scale.view(1) def native_w8a8_block_matmul( diff --git a/tests/kernels/quantization/test_awq.py b/tests/kernels/quantization/test_awq.py index efb62ca3799a9..3bf59dea30972 100644 --- a/tests/kernels/quantization/test_awq.py +++ b/tests/kernels/quantization/test_awq.py @@ -41,9 +41,9 @@ def test_awq_gemm_opcheck(monkeypatch: pytest.MonkeyPatch): qweight = torch.randint( -2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32 ) - scales = torch.randint( + scales = torch.empty((64, 2048), device="cuda", dtype=torch.float16) + qzeros = torch.randint( -2000000000, 2000000000, (64, 256), device="cuda", dtype=torch.int32 ) - qzeros = torch.empty((64, 2048), device="cuda", dtype=torch.float16) split_k_iters = 8 - opcheck(torch.ops._C.awq_gemm, (input, qweight, qzeros, scales, split_k_iters)) + opcheck(torch.ops._C.awq_gemm, (input, qweight, scales, qzeros, split_k_iters)) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index d0e4f6554a91f..32c77b9a01ece 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -54,6 +54,10 @@ def setup_cuda(): torch.set_default_device("cuda") +@pytest.mark.skipif( + current_platform.is_fp8_fnuz(), + reason="This platform supports e4m3fnuz, not e4m3fn.", +) @pytest.mark.parametrize( "num_tokens,d,dtype,group_size,seed", itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS), @@ -78,14 +82,14 @@ def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed): def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): torch.manual_seed(seed) factor_for_scale = 1e-2 - fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_info = torch.finfo(current_platform.fp8_dtype()) fp8_max, fp8_min = fp8_info.max, fp8_info.min A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max - A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(current_platform.fp8_dtype()) B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max - B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(current_platform.fp8_dtype()) block_n, block_k = block_size[0], block_size[1] n_tiles = (N + block_n - 1) // block_n @@ -103,6 +107,9 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): assert rel_diff < 0.001 +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="CUTLASS only supported on CUDA platform." +) @torch.inference_mode() def test_w8a8_block_fp8_cutlass_matmul(): # Test simple case where weight.shape % 128 != 0, @@ -151,6 +158,10 @@ def test_w8a8_block_fp8_cutlass_matmul(): assert rel_diff < 0.001 +@pytest.mark.skipif( + current_platform.is_fp8_fnuz(), + reason="This platform supports e4m3fnuz, not e4m3fn.", +) @pytest.mark.parametrize( "M,N,K,block_size,out_dtype,seed", itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS), diff --git a/tests/kernels/quantization/test_cutlass_scaled_mm.py b/tests/kernels/quantization/test_cutlass_scaled_mm.py index de595b0a34e46..bc4744df7e69e 100644 --- a/tests/kernels/quantization/test_cutlass_scaled_mm.py +++ b/tests/kernels/quantization/test_cutlass_scaled_mm.py @@ -15,6 +15,9 @@ from vllm import _custom_ops as ops from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv +if not current_platform.is_cuda(): + pytest.skip("These tests use CUTLASS which requires CUDA", allow_module_level=True) + MNK_FACTORS = [ (1, 256, 128), (1, 16384, 1024), diff --git a/tests/kernels/quantization/test_cutlass_w4a8.py b/tests/kernels/quantization/test_cutlass_w4a8.py index 465e24fd7eb97..8cfc993fe8e82 100644 --- a/tests/kernels/quantization/test_cutlass_w4a8.py +++ b/tests/kernels/quantization/test_cutlass_w4a8.py @@ -12,12 +12,18 @@ import torch from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.quant_utils import ( + convert_packed_uint4b8_to_signed_int4_inplace, + pack_cols, pack_rows, quantize_weights, + unpack_quantized_values_into_int32, ) from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types +if not current_platform.is_cuda(): + pytest.skip("These tests use CUTLASS which requires CUDA", allow_module_level=True) + # TODO: in future PR refactor this and `is_quant_method_supported` in the kernel # unit tests to a common utility function. Currently the use of # `is_quant_method_supported` conflates kernels with quantization methods @@ -167,8 +173,7 @@ def create_test_tensors( # for the practical use case we need per-tok scales for fp8 activations w_tok_s = torch.randn((m,), device="cuda", dtype=types.token_scale_type) - # weights are already per-group quantized, use placeholder here - w_ch_s = torch.ones((n,), device="cuda", dtype=types.channel_scale_type) + w_ch_s = torch.randn((n,), device="cuda", dtype=types.channel_scale_type) return Tensors( w_ref=w_ref, @@ -211,7 +216,7 @@ def mm_test_helper( print(output_ref) torch.testing.assert_close( - output, output_ref.to(output.dtype), rtol=1e-3, atol=1e-3 + output, output_ref.to(output.dtype), rtol=1e-2, atol=1e-2 ) @@ -257,7 +262,7 @@ def test_w4a8_cuda_graph(): ) w_tok_s = torch.randn((m,), device="cuda", dtype=torch.float32) - w_ch_s = torch.ones((n,), device="cuda", dtype=torch.float32) + w_ch_s = torch.randn((n,), device="cuda", dtype=torch.float32) # Construct a trivial model with a single layer that calls the kernel model = W4A8Layer( @@ -287,4 +292,38 @@ def test_w4a8_cuda_graph(): output.zero_() g.replay() - torch.testing.assert_close(output, output_ref, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(output, output_ref, rtol=1e-2, atol=1e-2) + + +@pytest.mark.skipif( + not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type." +) +@pytest.mark.parametrize("shape", MNK_SHAPES) +def test_convert_packed_uint4b8_to_signed_int4_inplace(shape): + """ + The W4A16 checkpoints encode the weights as int4b8 packed to int32. + The CUTLASS kernels expect signed int4 packed to int32. + This tests checks that the runtime int4b8 -> signed int4 conversion + matches the offline conversion step exactly. + """ + _, N, K = shape + # random weights packed to int32 + t = torch.randint( + low=torch.iinfo(torch.int32).min, + high=torch.iinfo(torch.int32).max + 1, + size=(N, K // 8), + dtype=torch.int32, + device="cuda", + ) + + # compute reference + unpacked = unpack_quantized_values_into_int32( + t.clone(), scalar_types.uint4b8, packed_dim=1 + ) + unpacked = unpacked - 8 # int4b8 -> signed int4 + ref = pack_cols(unpacked & 0x0F, 4, *unpacked.shape) + + out = convert_packed_uint4b8_to_signed_int4_inplace(t.clone()) + + assert torch.equal(ref, out) + assert not torch.equal(ref, t) diff --git a/tests/kernels/quantization/test_cutlass_w4a8_moe.py b/tests/kernels/quantization/test_cutlass_w4a8_moe.py new file mode 100644 index 0000000000000..a855f7333b617 --- /dev/null +++ b/tests/kernels/quantization/test_cutlass_w4a8_moe.py @@ -0,0 +1,342 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Tests for the CUTLASS-based W4A8 grouped GEMM kernel and the full MoE layer. +""" + +import random +from dataclasses import dataclass + +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_rows, + quantize_weights, +) +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types + +IS_SUPPORTED_BY_GPU = ( + current_platform.is_cuda() and current_platform.get_device_capability()[0] >= 9 +) + + +def to_fp8(tensor: torch.Tensor) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + return tensor.clamp(min=finfo.min, max=finfo.max).to(dtype=torch.float8_e4m3fn) + + +def cutlass_quantize( + atype: torch.dtype, + w: torch.Tensor, + wtype: ScalarType, + stype: torch.dtype | None, + group_size: int | None, + zero_points: bool = False, +): + """ + Quantize weights into W4 and compute reference dequantized weights. + + Encoding/reordering of weights and packing of scales is deferred + until after all experts are combined. + """ + assert wtype.is_integer(), "TODO: support floating point weights" + + w_ref, w_q, w_s, w_zp = quantize_weights( + w, wtype, group_size=group_size, zero_points=zero_points + ) + + # Since scales are later cast to fp8, recompute w_ref in atype here. + w_ref = ( + w_q.to(torch.float32) + * w_s.to(atype).to(torch.float32).repeat_interleave(group_size, dim=0) + ).to(atype) + + # Bit mask prevents sign extension of int4 when packing. + w_q = pack_rows(w_q & 0x0F, wtype.size_bits, *w_q.shape) + # Make weights row-major (N, K). + w_q = w_q.t().contiguous() + + return w_ref, w_q, w_s.to(atype), w_zp + + +def cutlass_preprocess( + w_q_experts: list[torch.Tensor], w_s_experts: list[torch.Tensor] +): + """ + Reorder/encode expert weights and pack scales. + + Returns: + w_q_packed: Packed/encoded int4 weights for all experts. + w_s_packed: Packed fp8 scales for all experts. + packed_layout: Layout/stride metadata for grouped GEMM. + """ + w_s_packed = ops.cutlass_pack_scale_fp8(torch.stack(w_s_experts)) + w_q_packed, packed_layout = ops.cutlass_encode_and_reorder_int4b_grouped( + torch.stack(w_q_experts) + ) # expects dim 3 + return w_q_packed, w_s_packed, packed_layout + + +GROUP_SIZE = 128 +# (num_experts, N, K) +TEST_SHAPES = [ + (8, 512, 2048), + (8, 2048, 2048), + (64, 512, 1024), + (64, 2048, 2048), + (4, 2048, 768), + (8, 768, 2048), + (64, 1536, 2048), + (128, 8192, 4096), # test overflow int32 +] +ALIGNMENT = 16 # torch._scaled_mm alignment for M, needed for reference check + + +@dataclass +class MoETestSetup: + num_experts: int + K: int + N: int + Ms: list[int] + M_full: int + a: torch.Tensor + a_ref: torch.Tensor + a_strides: torch.Tensor + out: torch.Tensor + c_strides: torch.Tensor + per_tok_scales: torch.Tensor + per_chan_scales: torch.Tensor + w_refs: list[torch.Tensor] + w_q_packed: torch.Tensor + w_s_packed: torch.Tensor + problem_sizes: torch.Tensor + expert_offsets: torch.Tensor + b_strides: torch.Tensor + group_scale_strides: torch.Tensor + + +def make_moe_test_setup( + num_experts: int, + K: int, + N: int, + *, + alignment: int = ALIGNMENT, + max_blocks: int = 64, + device: str = "cuda", + random_zero: bool = False, +) -> MoETestSetup: + """Create a full set of tensors for testing cutlass_w4a8_moe_mm.""" + + assert K % GROUP_SIZE == 0 + # Token counts per expert (multiples of `alignment`). + Ms = [alignment * random.randint(1, max_blocks) for _ in range(num_experts)] + + # set random experts to 0 tokens + if random_zero and num_experts > 1: + num_zero = max(1, num_experts // 8) + zero_indices = random.sample(range(num_experts), k=num_zero) + for idx in zero_indices: + Ms[idx] = 0 + + M_full = sum(Ms) + assert M_full > 0 + + # Activations. + a = to_fp8(torch.randn((M_full, K), device=device)) + a_ref = a.to(torch.float32) + a_strides = torch.full((num_experts,), K, dtype=torch.int64, device=device) + + # Output buffer. + out = torch.empty((M_full, N), dtype=torch.bfloat16, device=device) + c_strides = torch.full((num_experts,), N, dtype=torch.int64, device=device) + + # Channel/token scales. + per_tok_scales = torch.randn((M_full, 1), dtype=torch.float32, device=device) + per_chan_scales = torch.randn( + (num_experts, N, 1), dtype=torch.float32, device=device + ) + + # Expert weights and scales. + wtype = scalar_types.int4 + atype = stype = torch.float8_e4m3fn + w_refs, w_qs, w_ss = [], [], [] + for _ in range(num_experts): + b = to_fp8(torch.randn((K, N), device=device)) + w_ref, w_q, w_s, _ = cutlass_quantize( + atype, b.to(torch.float16), wtype, stype, GROUP_SIZE, zero_points=False + ) + w_refs.append(w_ref) + w_qs.append(w_q) + w_ss.append(w_s) + + w_q_packed, w_s_packed, packed_layout = cutlass_preprocess(w_qs, w_ss) + + problem_sizes = torch.tensor( + [[N, M, K] for M in Ms], dtype=torch.int32, device=device + ) + + expert_offsets = torch.cat( + [ + torch.tensor([0], dtype=torch.int64), + torch.cumsum(torch.tensor(Ms, dtype=torch.int64), dim=0)[:-1], + ] + ).to(device=device) + + # B strides and group scale strides. + b_strides = packed_layout + group_scale_strides = torch.zeros( + (num_experts, 2), dtype=torch.int64, device=device + ) + group_scale_strides[:, 0] = N + + return MoETestSetup( + num_experts=num_experts, + K=K, + N=N, + Ms=Ms, + M_full=M_full, + a=a, + a_ref=a_ref, + a_strides=a_strides, + out=out, + c_strides=c_strides, + per_tok_scales=per_tok_scales, + per_chan_scales=per_chan_scales, + w_refs=w_refs, + w_q_packed=w_q_packed, + w_s_packed=w_s_packed, + problem_sizes=problem_sizes, + expert_offsets=expert_offsets, + b_strides=b_strides, + group_scale_strides=group_scale_strides, + ) + + +def compute_moe_reference_output(setup: MoETestSetup) -> torch.Tensor: + """Compute reference output using torch._scaled_mm per expert.""" + out_ref = torch.empty_like(setup.out) + + ends = torch.cumsum(torch.tensor(setup.Ms), 0).tolist() + starts = setup.expert_offsets.cpu().tolist() + + for i in range(setup.num_experts): + start, end = starts[i], ends[i] + if start == end: + continue + + out_ref_i = torch._scaled_mm( + setup.a_ref[start:end].to(torch.float8_e4m3fn), + setup.w_refs[i].to(torch.float8_e4m3fn).t().contiguous().t(), + setup.per_tok_scales[start:end], # (M, 1) + setup.per_chan_scales[i].reshape(1, -1), # (1, N) + out_dtype=torch.bfloat16, + use_fast_accum=True, + ) + out_ref[start:end] = out_ref_i + + return out_ref + + +@pytest.mark.skipif( + not IS_SUPPORTED_BY_GPU, + reason="W4A8 Grouped GEMM is not supported on this GPU type.", +) +@pytest.mark.parametrize("shape", TEST_SHAPES) +@pytest.mark.parametrize("random_zero", [True, False]) +def test_cutlass_w4a8_moe_mm_end_to_end(shape, random_zero): + num_experts, N, K = shape + current_platform.seed_everything(42) + setup = make_moe_test_setup( + num_experts=num_experts, K=K, N=N, max_blocks=64, random_zero=random_zero + ) + + ops.cutlass_w4a8_moe_mm( + setup.out, + setup.a, + setup.w_q_packed, + setup.per_tok_scales, + setup.per_chan_scales, + setup.w_s_packed, + GROUP_SIZE, + setup.expert_offsets, + setup.problem_sizes, + setup.a_strides, + setup.b_strides, + setup.c_strides, + setup.group_scale_strides, + ) + torch.cuda.synchronize() + + out_ref = compute_moe_reference_output(setup) + torch.testing.assert_close(setup.out, out_ref, rtol=1e-2, atol=1e-2) + + +class W4A8MoELayer(torch.nn.Module): + """ + Minimal wrapper module to test cuda graphs + """ + + def __init__(self, setup: MoETestSetup): + super().__init__() + self.setup = setup + + def forward(self, a: torch.Tensor) -> torch.Tensor: + s = self.setup + ops.cutlass_w4a8_moe_mm( + s.out, + a, + s.w_q_packed, + s.per_tok_scales, + s.per_chan_scales, + s.w_s_packed, + GROUP_SIZE, + s.expert_offsets, + s.problem_sizes, + s.a_strides, + s.b_strides, + s.c_strides, + s.group_scale_strides, + ) + return s.out + + +@pytest.mark.skipif( + not IS_SUPPORTED_BY_GPU, + reason="W4A8 Grouped GEMM is not supported on this GPU type.", +) +def test_cutlass_w4a8_moe_mm_cuda_graph(): + current_platform.seed_everything(42) + # Fixed config for CUDA graph test (single parameter point). + num_experts = 8 + K = 512 + N = 2048 + + setup = make_moe_test_setup( + num_experts=num_experts, + K=K, + N=N, + max_blocks=32, + ) + + # Construct model that calls the grouped GEMM kernel. + model = W4A8MoELayer(setup) + + # Build reference output once. + out_ref = compute_moe_reference_output(setup) + + # Capture and run the model in a CUDA graph. + a_static = setup.a.clone() # static input tensor for graph replay + + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + out_static = model(a_static) + + out_static.zero_() + g.replay() + + torch.testing.assert_close(out_static, out_ref, rtol=1e-2, atol=1e-2) diff --git a/tests/kernels/quantization/test_fp8_quant_group.py b/tests/kernels/quantization/test_fp8_quant_group.py index 6628ac650fd5f..f5e1cde94b6e9 100644 --- a/tests/kernels/quantization/test_fp8_quant_group.py +++ b/tests/kernels/quantization/test_fp8_quant_group.py @@ -62,7 +62,7 @@ def test_quantfp8_group_functionality( assert scales_col.stride(1) == batch_size # Test column-major scales consistency - assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8) + torch.testing.assert_close(scales_col, scales_native, rtol=1e-9, atol=1e-8) # 3. Test CUDA implementation (only for divisible dimensions) if is_divisible: @@ -71,7 +71,7 @@ def test_quantfp8_group_functionality( assert scales_cuda.shape == (batch_size, expected_num_groups) # Verify CUDA/native consistency - assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8) + torch.testing.assert_close(scales_cuda, scales_native, rtol=2e-7, atol=2e-8) # Quantized values should mostly match diff_count = (x_quant_cuda != x_quant_native).sum().item() diff --git a/tests/kernels/quantization/test_hadacore.py b/tests/kernels/quantization/test_hadacore.py index 3ccee9db048cf..7a5c7fbd55f72 100644 --- a/tests/kernels/quantization/test_hadacore.py +++ b/tests/kernels/quantization/test_hadacore.py @@ -8,6 +8,13 @@ import torch from compressed_tensors.transform import deterministic_hadamard_matrix from vllm import _custom_ops as ops +from vllm.platforms import current_platform + +if current_platform.is_rocm(): + pytest.skip( + "These tests require hadacore_transform, not supported on ROCm.", + allow_module_level=True, + ) @pytest.mark.parametrize("batch_size", [1, 32]) diff --git a/tests/kernels/quantization/test_machete_mm.py b/tests/kernels/quantization/test_machete_mm.py index efa81de158d38..7f4ce2a085807 100644 --- a/tests/kernels/quantization/test_machete_mm.py +++ b/tests/kernels/quantization/test_machete_mm.py @@ -23,6 +23,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types +if current_platform.is_rocm(): + pytest.skip( + "These tests require machete_prepack_B, not supported on ROCm.", + allow_module_level=True, + ) + CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] # TODO: in future PR refactor this and `is_quant_method_supported` in the kernel diff --git a/tests/kernels/quantization/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py index 59516db1b115d..995e777bb5e8b 100644 --- a/tests/kernels/quantization/test_marlin_gemm.py +++ b/tests/kernels/quantization/test_marlin_gemm.py @@ -56,6 +56,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.platforms import current_platform from vllm.scalar_type import scalar_types +if current_platform.is_rocm(): + pytest.skip( + "These tests require gptq_marlin_repack," + "marlin_int4_fp8_preprocess, gptq_marlin_24_gemm," + "or gptq_marlin_gemm which are not supported on ROCm.", + allow_module_level=True, + ) + ACT_ORDER_OPTS = [False, True] K_FULL_OPTS = [False, True] USE_ATOMIC_ADD_OPTS = [False, True] diff --git a/tests/kernels/quantization/test_scaled_mm_kernel_selection.py b/tests/kernels/quantization/test_scaled_mm_kernel_selection.py new file mode 100644 index 0000000000000..2ed55931c8164 --- /dev/null +++ b/tests/kernels/quantization/test_scaled_mm_kernel_selection.py @@ -0,0 +1,91 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for ScaledMM kernel selection logic (CPU-only) + +Run `pytest tests/kernels/quantization/test_scaled_mm_kernel_selection.py`. +""" + +import inspect +from abc import ABC + +import pytest + +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + ScaledMMLinearLayerConfig, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import ( + AiterScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import ( + CPUScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + ScaledMMLinearKernel, +) + +pytestmark = pytest.mark.cpu_test + + +def test_is_supported_is_abstract(): + """Test that is_supported() is properly defined as abstract.""" + assert issubclass(ScaledMMLinearKernel, ABC) + assert hasattr(ScaledMMLinearKernel, "is_supported") + + +def test_cpu_kernel_implements_is_supported(): + """Test that CPUScaledMMLinearKernel implements is_supported() method.""" + assert hasattr(CPUScaledMMLinearKernel, "is_supported"), ( + "CPUScaledMMLinearKernel missing is_supported() method" + ) + # Verify it's a classmethod by checking if it can be called with the class + # and by checking the method type + assert inspect.ismethod(CPUScaledMMLinearKernel.is_supported) or inspect.isfunction( + CPUScaledMMLinearKernel.is_supported + ), "CPUScaledMMLinearKernel.is_supported() should be a classmethod" + # Verify it can be called as a classmethod + result, reason = CPUScaledMMLinearKernel.is_supported() + assert isinstance(result, bool), "is_supported() should return a bool" + assert reason is None or isinstance(reason, str), "reason should be str or None" + + +def test_aiter_kernel_implements_is_supported(): + """Test that AiterScaledMMLinearKernel implements is_supported() method.""" + assert hasattr(AiterScaledMMLinearKernel, "is_supported"), ( + "AiterScaledMMLinearKernel missing is_supported() method" + ) + # Verify it's a classmethod by checking if it can be called with the class + # and by checking the method type + assert inspect.ismethod( + AiterScaledMMLinearKernel.is_supported + ) or inspect.isfunction(AiterScaledMMLinearKernel.is_supported), ( + "AiterScaledMMLinearKernel.is_supported() should be a classmethod" + ) + # Verify it can be called as a classmethod + # (will return False on CPU, which is expected) + result, reason = AiterScaledMMLinearKernel.is_supported() + assert isinstance(result, bool), "is_supported() should return a bool" + assert reason is None or isinstance(reason, str), "reason should be str or None" + # On CPU, it should return False with a reason about requiring ROCm + # This validates the method works correctly even on non-ROCm platforms + + +def test_cpu_kernel_accepts_all_configs(): + """Test that CPUScaledMMLinearKernel accepts all config combinations.""" + configs = [ + ScaledMMLinearLayerConfig( + is_channelwise=False, + is_static_input_scheme=True, + input_symmetric=True, + ), + ScaledMMLinearLayerConfig( + is_channelwise=True, + is_static_input_scheme=False, + input_symmetric=False, + ), + ] + + for config in configs: + can_impl, reason = CPUScaledMMLinearKernel.can_implement(config) + assert can_impl, ( + f"CPUScaledMMLinearKernel should accept config {config}: {reason}" + ) diff --git a/tests/kernels/test_top_k_per_row.py b/tests/kernels/test_top_k_per_row.py index cadda27b49e9c..3bf69389753e3 100644 --- a/tests/kernels/test_top_k_per_row.py +++ b/tests/kernels/test_top_k_per_row.py @@ -9,23 +9,45 @@ from vllm.platforms import current_platform # Test parameters NUM_ROWS = [1, 32, 2050] -TOP_K_VALUES = [2048] -BATCH_SIZE = [1, 2, 4, 2048, 4096] -NEXT_N = [1, 2, 4, 8] +TOP_K_VALUES = [2048, 3000] +BATCH_SIZE = [1, 2, 2048] +NEXT_N = [1, 8] +DATA_GENERATION = ["random", "10LSBits"] def create_random_logits( row_starts: torch.Tensor, row_ends: torch.Tensor, - vocab_size: int, dtype: torch.dtype, seed: int, + data_generation: str, ) -> torch.Tensor: """Create random logits tensor for testing.""" torch.manual_seed(seed) np.random.seed(seed) # Generate logits with some structure to make testing more meaningful - logits = torch.randn(row_starts.shape[0], max(row_ends), dtype=dtype, device="cuda") + if data_generation == "random": + logits = torch.randn( + row_starts.shape[0], max(row_ends), dtype=dtype, device="cuda" + ) + elif data_generation == "10LSBits": + top_22_bits_mask = 0xFFFFFC00 + last_10_bits_mask = 0x000003FF + fixed_top_22_bits = 0x3F900000 + # Generate random bits for the last 10 bits + random_bottom_bits = torch.randint( + 0, + 2**10, + (row_starts.shape[0], max(row_ends)), + dtype=torch.int32, + device="cuda", + ) + # Combine: fixed top 22 bits with random last 10 bits + logits_bits = (fixed_top_22_bits & top_22_bits_mask) | ( + random_bottom_bits & last_10_bits_mask + ) + logits = logits_bits.view(dtype) + for i, end in enumerate(row_ends): logits[i, end:] = float("-inf") return logits @@ -113,13 +135,13 @@ def test_top_k_per_row( # Create test data vocab_size = 20000 row_starts, row_ends = create_row_boundaries(num_rows, vocab_size) - logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42) + logits = create_random_logits(row_starts, row_ends, torch.float32, 42, "random") # Create output tensors indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda") # Run CUDA implementation - torch.ops._C.top_k_per_row( + torch.ops._C.top_k_per_row_prefill( logits, row_starts, row_ends, @@ -127,6 +149,7 @@ def test_top_k_per_row( num_rows, logits.stride(0), logits.stride(1), + top_k, ) # Run reference implementation @@ -139,27 +162,23 @@ def test_top_k_per_row( # Compare results assert compare_top_k_results( logits, indices, torch_indices, row_starts, row_ends, top_k - ), "CUDA top_k_per_row results don't match torch.topk" + ), "CUDA top_k_per_row_prefill results don't match torch.topk" -@pytest.mark.parametrize("top_k", TOP_K_VALUES) -@pytest.mark.parametrize("batch_size", BATCH_SIZE) -@pytest.mark.parametrize("next_n", NEXT_N) -@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA") -@torch.inference_mode() -def test_top_k_per_row_decode( +def _run_top_k_per_row_decode_test( top_k: int, batch_size: int, next_n: int, + vocab_size: int, + data_generation: str, ) -> None: """ - Test top_k_per_row with seq_lens tensor. + Helper function to run top_k_per_row_decode test with given parameters. """ torch.set_default_device("cuda:0") # Create test data num_rows = batch_size * next_n - vocab_size = 20000 seq_lens = torch.randint( vocab_size, (batch_size,), dtype=torch.int32, device="cuda" ) @@ -167,7 +186,9 @@ def test_top_k_per_row_decode( row_indices = torch.arange(num_rows, device="cuda") // next_n next_n_offset = torch.arange(num_rows, device="cuda") % next_n row_ends = seq_lens[row_indices] - next_n + next_n_offset + 1 - logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42) + logits = create_random_logits( + row_starts, row_ends, torch.float32, 42, data_generation + ) # Create output tensors indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda") @@ -181,6 +202,7 @@ def test_top_k_per_row_decode( num_rows, logits.stride(0), logits.stride(1), + top_k, ) torch.cuda.synchronize() @@ -195,4 +217,41 @@ def test_top_k_per_row_decode( # Compare results assert compare_top_k_results( logits, indices, torch_indices, row_starts, row_ends, top_k - ), "CUDA top_k_per_row results don't match torch.topk" + ), "CUDA top_k_per_row_decode results don't match torch.topk" + + +@pytest.mark.parametrize("top_k", TOP_K_VALUES) +@pytest.mark.parametrize("batch_size", BATCH_SIZE) +@pytest.mark.parametrize("next_n", NEXT_N) +@pytest.mark.parametrize("data_generation", DATA_GENERATION) +@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA") +@torch.inference_mode() +def test_top_k_per_row_decode( + top_k: int, + batch_size: int, + next_n: int, + data_generation: str, +) -> None: + """ + Test top_k_per_row with seq_lens tensor. + """ + vocab_size = 20000 + _run_top_k_per_row_decode_test( + top_k, batch_size, next_n, vocab_size, data_generation + ) + + +@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA") +@torch.inference_mode() +def test_top_k_per_row_decode_large_vocab_size() -> None: + """ + Test top_k_per_row_decode with large vocabulary size. + """ + top_k = 2048 + batch_size = 2 + next_n = 2 + vocab_size = 300000 + data_generation = "random" + _run_top_k_per_row_decode_test( + top_k, batch_size, next_n, vocab_size, data_generation + ) diff --git a/tests/kv_transfer/test_lookup_buffer.py b/tests/kv_transfer/test_lookup_buffer.py deleted file mode 100644 index a61ccef700624..0000000000000 --- a/tests/kv_transfer/test_lookup_buffer.py +++ /dev/null @@ -1,160 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os -import random - -import torch -from tqdm import tqdm - -from vllm.config import KVTransferConfig -from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import SimpleBuffer -from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe - -# TODO: the test depends on a lot of fields in the current implementation. -# We should have standard interface instead direct field access - - -def test_run(my_rank, buffer, device): - # buffer should be empty in the beginning - if my_rank == 0: - assert buffer.buffer_size == 0 - assert len(buffer.buffer) == 0 - - print(f"My rank: {my_rank}, device: {device}") - - # insert - tokens = torch.tensor([1, 2, 3]).to(device) - roi = tokens > 0 - if my_rank == 0: - key = 2.0 * torch.ones([5, 6]).to(device) - value = 3.0 * torch.ones([5, 6]).to(device) - - placeholder = torch.tensor([1]).to(device) - - buffer.insert(tokens, roi, key, value, placeholder) - - torch.distributed.barrier() - - # drop_select - if my_rank == 1: - tok, roi_, key, value, hidden = buffer.drop_select(tokens, roi) - assert torch.allclose(tokens, tok) - assert torch.allclose(roi, roi_) - assert torch.allclose(key, 2.0 * torch.ones([5, 6], device=device)) - assert torch.allclose(value, 3.0 * torch.ones([5, 6], device=device)) - torch.distributed.barrier() - - if my_rank == 0: - assert buffer.buffer_size == 0 - assert len(buffer.buffer) == 0 - - print(f"My rank: {my_rank}, Test run passed!") - - -def stress_test(my_rank, buf, device): - torch.distributed.barrier() - torch.manual_seed(100) - - reqs = [ - ( - torch.rand(100).to(device), # tokens - torch.ones(100).bool().to(device), # roi - torch.rand(100).to(device), # key - torch.rand(100).to(device), # value - torch.rand(100).to(device), # hidden - ) - for i in tqdm(range(200)) - ] - - random.seed(my_rank) - random.shuffle(reqs) - - torch.distributed.barrier() - - n = 0 - - # the buffer size can only store 100 reqs - # so the sender will occasionally block to wait for the receiver. - for req in tqdm(reqs): - if my_rank == 0: - buf.insert(*req) - else: - tok, roi, k, v, h = req - tok_, roi_, k_, v_, h_ = buf.drop_select(tok, roi) - - if tok_ is None: - assert roi_ is None - assert k_ is None - assert v_ is None - assert h_ is None - n += 1 - else: - assert torch.allclose(tok, tok_) - assert torch.allclose(roi, roi_) - assert torch.allclose(k, k_) - assert torch.allclose(v, v_) - assert torch.allclose(h, h_) - print(f"Rank {my_rank} done") - torch.distributed.barrier() - - if my_rank == 0: - x = torch.tensor([0]) - torch.distributed.recv(x, 1) - # the # of None received is the kv that are not selected - assert x.item() == len(buf.buffer) - # and the size of the buffer should be 2000 * buffer len - print(buf.buffer_size) - assert buf.buffer_size == 1700 * len(buf.buffer) - else: - torch.distributed.send(torch.tensor([n]), 0) - - print(f"My rank: {my_rank}, Passed stress test!") - - -if __name__ == "__main__": - my_rank = int(os.environ["RANK"]) - - torch.distributed.init_process_group( - backend="gloo", - init_method="tcp://localhost:12398", - world_size=2, - rank=my_rank, - ) - - print(f"initialized! My rank is {my_rank}") - - config = KVTransferConfig( - kv_connector="P2pNcclConnector", - kv_buffer_device="cuda", - kv_buffer_size=1e9, - kv_rank=my_rank, - kv_role="kv_both", # this arg doesn't matter in this test - kv_parallel_size=2, - kv_ip="127.0.0.1", - kv_port=12345, - ) - - data_pipe = PyNcclPipe( - local_rank=my_rank, - config=config, - device="cuda", - port_offset=0, - ) - cpu_pipe = PyNcclPipe( - local_rank=my_rank, - config=config, - device="cpu", - port_offset=1, - ) - - buffer = SimpleBuffer(cpu_pipe, data_pipe, 170000) - - test_run(my_rank, buffer, data_pipe.device) - - stress_test(my_rank, buffer, data_pipe.device) - - buffer.close() - data_pipe.close() - cpu_pipe.close() - print("Done") diff --git a/tests/kv_transfer/test_lookup_buffer.sh b/tests/kv_transfer/test_lookup_buffer.sh deleted file mode 100644 index f2aeaee9ca6d5..0000000000000 --- a/tests/kv_transfer/test_lookup_buffer.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -RANK=0 python3 test_lookup_buffer.py & -PID0=$! -RANK=1 python3 test_lookup_buffer.py & -PID1=$! - -wait $PID0 -wait $PID1 diff --git a/tests/kv_transfer/test_module.py b/tests/kv_transfer/test_module.py deleted file mode 100644 index b9a28e4bceb7c..0000000000000 --- a/tests/kv_transfer/test_module.py +++ /dev/null @@ -1,62 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import subprocess -import sys - -import pytest -import torch - - -def run_python_script(script_name, timeout): - script_name = f"kv_transfer/{script_name}" - try: - # Start both processes asynchronously using Popen - process0 = subprocess.Popen( - [sys.executable, script_name], - env={"RANK": "0"}, # Set the RANK environment variable for process 0 - stdout=sys.stdout, # Pipe stdout to current stdout - stderr=sys.stderr, # Pipe stderr to current stderr - ) - - process1 = subprocess.Popen( - [sys.executable, script_name], - env={"RANK": "1"}, # Set the RANK environment variable for process 1 - stdout=sys.stdout, # Pipe stdout to current stdout - stderr=sys.stderr, # Pipe stderr to current stderr - ) - - # Wait for both processes to complete, with a timeout - process0.wait(timeout=timeout) - process1.wait(timeout=timeout) - - # Check the return status of both processes - if process0.returncode != 0: - pytest.fail(f"Test {script_name} failed for RANK=0, {process0.returncode}") - if process1.returncode != 0: - pytest.fail(f"Test {script_name} failed for RANK=1, {process1.returncode}") - - except subprocess.TimeoutExpired: - # If either process times out, terminate both and fail the test - process0.terminate() - process1.terminate() - pytest.fail(f"Test {script_name} timed out") - except Exception as e: - pytest.fail(f"Test {script_name} failed with error: {str(e)}") - - -# Define the test cases using pytest's parametrize -@pytest.mark.parametrize( - "script_name,timeout", - [ - ("test_lookup_buffer.py", 60), # Second test case with a 60-second timeout - ("test_send_recv.py", 120), # First test case with a 120-second timeout - ], -) -def test_run_python_script(script_name, timeout): - # Check the number of GPUs - if torch.cuda.device_count() < 2: - pytest.skip(f"Skipping test {script_name} because <2 GPUs are available") - - # Run the test if there are at least 2 GPUs - run_python_script(script_name, timeout) diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py deleted file mode 100644 index 5762224eff76d..0000000000000 --- a/tests/kv_transfer/test_send_recv.py +++ /dev/null @@ -1,154 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os -import time - -import torch -from tqdm import tqdm - -from vllm.config import KVTransferConfig -from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe - - -def test_run(my_rank, pipe): - print(f"rank {my_rank} test_run starts....") - # test run - x = torch.tensor([1]).to(pipe.device) - y = torch.tensor([[2.0, 3.0, 4.0, 8.0]]).to(pipe.device) - if my_rank == 0: - pipe.send_tensor(x) - print(f"rank {my_rank} sent tensor x") - pipe.send_tensor(y) - print(f"rank {my_rank} sent tensor y") - x2 = pipe.recv_tensor() - print(f"rank {my_rank} received x2 = ", x2) - y2 = pipe.recv_tensor() - print(f"rank {my_rank} received y2 = ", y2) - - else: - x2 = pipe.recv_tensor() - print(f"rank {my_rank} received x2 = ", x2) - y2 = pipe.recv_tensor() - print(f"rank {my_rank} received y2 = ", y2) - pipe.send_tensor(x) - print(f"rank {my_rank} sent tensor x") - pipe.send_tensor(y) - print(f"rank {my_rank} sent tensor y") - - assert torch.allclose(x, x2) - assert torch.allclose(y, y2) - - print(f"rank {my_rank} test_run passed!") - - -def stress_test(my_rank, pipe): - print(f"rank {my_rank} stress_test starts....") - - tensors: list[torch.Tensor] = [] - - torch.distributed.barrier() - torch.manual_seed(0) - - for i in tqdm(range(500)): - mean = torch.rand(1).item() * 100 - std = torch.rand(1).item() * 100 - size = torch.randint(900, 1000, (2,)) - x = torch.normal(mean * 1.0, std * 1.0, size=size.tolist()).to(pipe.device) - - # 5% probability of sending a None - if torch.rand(1).item() < 0.05: - tensors.append(None) - tensors.append(None) - tensors.append(None) - else: - tensors.append(x) - tensors.append(x.mean().unsqueeze(0)) - tensors.append(x.std().unsqueeze(0)) - - torch.distributed.barrier() - - for i in tqdm(range(500)): - if my_rank == int((i % 10) > 3): - pipe.send_tensor(tensors[3 * i]) - pipe.send_tensor(tensors[3 * i + 1]) - pipe.send_tensor(tensors[3 * i + 2]) - else: - x = pipe.recv_tensor() - mean = pipe.recv_tensor() - std = pipe.recv_tensor() - - if x is None: - assert mean is None - assert std is None - else: - assert torch.allclose(x, tensors[3 * i]) - assert x.mean() == mean[0] - assert x.std() == std[0] - - torch.distributed.barrier() - - -def latency_test(my_rank, pipe, nelement, ntensor): - latencies = [] - - torch.distributed.barrier() - - for i in tqdm(range(500)): - tensors = [] - - if my_rank == 0: - # create tensor - tensors = [torch.rand(nelement).to(pipe.device) for _ in range(ntensor)] - - torch.distributed.barrier() - - if my_rank == 0: - t = torch.tensor([time.time()], dtype=torch.float64).to(pipe.device) - for tensor in tensors: - pipe.send_tensor(tensor) - pipe.send_tensor(t) - else: - for _ in range(ntensor): - pipe.recv_tensor() - t = pipe.recv_tensor() - latencies.append(time.time() - t.item()) - - torch.distributed.barrier() - - print("Latency test passed.") - print("Latency:", torch.tensor(latencies).mean().item() * 1000, "ms") - - -if __name__ == "__main__": - my_rank = int(os.environ["RANK"]) - - torch.distributed.init_process_group( - backend="gloo", - init_method="tcp://localhost:12398", - world_size=2, - rank=my_rank, - ) - - config = KVTransferConfig( - kv_connector="P2pNcclConnector", - kv_buffer_device="cuda", - kv_buffer_size=1e9, - kv_rank=my_rank, - kv_role="kv_both", # this arg doesn't matter in this test - kv_parallel_size=2, - kv_ip="127.0.0.1", - kv_port=12345, - ) - - pipe = PyNcclPipe( - local_rank=my_rank, - config=config, - ) - - test_run(my_rank, pipe) - - stress_test(my_rank, pipe) - - # Use this function if you want to test the latency of pipe impl. - # latency_test(my_rank, pipe, 1024 * 8 * 128, 80) diff --git a/tests/kv_transfer/test_send_recv.sh b/tests/kv_transfer/test_send_recv.sh deleted file mode 100644 index 54e0604806841..0000000000000 --- a/tests/kv_transfer/test_send_recv.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash - -RANK=0 python3 test_send_recv.py & -PID0=$! -RANK=1 python3 test_send_recv.py & -PID1=$! - -wait $PID0 -wait $PID1 diff --git a/tests/lora/test_default_mm_loras.py b/tests/lora/test_default_mm_loras.py index 407b29fdd1d58..1d16862b30e52 100644 --- a/tests/lora/test_default_mm_loras.py +++ b/tests/lora/test_default_mm_loras.py @@ -13,6 +13,7 @@ from huggingface_hub import snapshot_download from vllm.lora.request import LoRARequest from ..conftest import AudioTestAssets, VllmRunner +from ..utils import create_new_process_for_each_test MODEL_PATH = snapshot_download("microsoft/Phi-4-multimodal-instruct") AUDIO_LORA_PATH = os.path.join(MODEL_PATH, "speech-lora") @@ -60,6 +61,7 @@ def run_test(vllm_runner, audio_assets, lora_request, expected_suffix, **kwargs) assert vllm_outputs_with_default_lora[-1][-1][-1].endswith(expected_suffix) +@create_new_process_for_each_test() def test_active_default_mm_lora( vllm_runner: type[VllmRunner], audio_assets: AudioTestAssets, @@ -74,6 +76,7 @@ def test_active_default_mm_lora( ) +@create_new_process_for_each_test() def test_inactive_default_mm_lora( vllm_runner: type[VllmRunner], audio_assets: AudioTestAssets, @@ -89,6 +92,7 @@ def test_inactive_default_mm_lora( ) +@create_new_process_for_each_test() def test_default_mm_lora_succeeds_with_redundant_lora_request( vllm_runner: type[VllmRunner], audio_assets: AudioTestAssets, @@ -103,6 +107,7 @@ def test_default_mm_lora_succeeds_with_redundant_lora_request( ) +@create_new_process_for_each_test() def test_default_mm_lora_fails_with_overridden_lora_request( vllm_runner: type[VllmRunner], audio_assets: AudioTestAssets, @@ -118,6 +123,7 @@ def test_default_mm_lora_fails_with_overridden_lora_request( ) +@create_new_process_for_each_test() def test_default_mm_lora_does_not_expand_string_reqs(vllm_runner): class MockEngineException(Exception): pass diff --git a/tests/lora/test_gptoss_tp.py b/tests/lora/test_gptoss_tp.py index f4269750feb6b..2fa61f280587f 100644 --- a/tests/lora/test_gptoss_tp.py +++ b/tests/lora/test_gptoss_tp.py @@ -76,6 +76,8 @@ def test_gpt_oss_lora(gptoss20b_lora_files): enable_lora=True, max_loras=4, max_lora_rank=8, + max_num_seqs=2, + max_num_batched_tokens=2048, compilation_config=vllm.config.CompilationConfig( # Avoid OOM cudagraph_specialize_lora=False, ), @@ -94,8 +96,10 @@ def test_gpt_oss_lora_tp2(gptoss20b_lora_files, fully_sharded_loras): enable_lora=True, max_loras=2, max_lora_rank=8, - max_num_seqs=16, + max_num_seqs=2, + max_num_batched_tokens=2048, tensor_parallel_size=2, + gpu_memory_utilization=0.8, fully_sharded_loras=fully_sharded_loras, compilation_config=vllm.config.CompilationConfig( # Avoid OOM cudagraph_specialize_lora=False, diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 9df3a07a9e5e9..47d1fcfe9a0c7 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -28,7 +28,7 @@ from vllm.lora.layers import ( RowParallelLinearWithShardedLoRA, VocabParallelEmbeddingWithLoRA, ) -from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights +from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.punica_wrapper import get_punica_wrapper from vllm.model_executor.layers.linear import ( ColumnParallelLinear, diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index 18704fa6e45de..483235ff51291 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -76,11 +76,18 @@ def do_sample( if lora_id else None, ) - # Print the outputs. + lora_request = LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None generated_texts: list[str] = [] for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text + # The output should include correct lora_request info + if lora_request is not None: + assert output.lora_request.lora_name == lora_request.lora_name + assert output.lora_request.lora_int_id == lora_request.lora_int_id + assert output.lora_request.lora_path == lora_request.lora_path + else: + assert output.lora_request is None generated_texts.append(generated_text) print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") return generated_texts diff --git a/tests/lora/test_lora_checkpoints.py b/tests/lora/test_lora_checkpoints.py index e9653a2fedfaf..e6816e83da001 100644 --- a/tests/lora/test_lora_checkpoints.py +++ b/tests/lora/test_lora_checkpoints.py @@ -3,7 +3,7 @@ import pytest -from vllm.lora.models import LoRAModel +from vllm.lora.lora_model import LoRAModel from vllm.lora.peft_helper import PEFTHelper from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM from vllm.model_executor.models.utils import WeightsMapper diff --git a/tests/lora/test_lora_huggingface.py b/tests/lora/test_lora_huggingface.py index 3348d2f8ce654..7c7f4eb4b626b 100644 --- a/tests/lora/test_lora_huggingface.py +++ b/tests/lora/test_lora_huggingface.py @@ -3,7 +3,7 @@ import pytest -from vllm.lora.models import LoRAModel +from vllm.lora.lora_model import LoRAModel from vllm.lora.peft_helper import PEFTHelper from vllm.lora.utils import get_adapter_absolute_path from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 081f14d6fabfb..50f17ced5dd74 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -15,10 +15,10 @@ from vllm.lora.layers import ( MergedColumnParallelLinearWithLoRA, RowParallelLinearWithLoRA, ) +from vllm.lora.lora_model import LoRAModel from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights -from vllm.lora.models import ( +from vllm.lora.model_manager import ( LoRAMapping, - LoRAModel, LoRAModelManager, LRUCacheLoRAModelManager, ) diff --git a/tests/lora/test_moe_lora_align_sum.py b/tests/lora/test_moe_lora_align_sum.py index 72f1d759f1e7a..3a17f3eba6e8b 100644 --- a/tests/lora/test_moe_lora_align_sum.py +++ b/tests/lora/test_moe_lora_align_sum.py @@ -32,7 +32,7 @@ def sample_data(num_experts, max_loras, num_tokens, topk_num): @pytest.mark.parametrize("num_tokens", [100, 200, 1024, 4096]) # 81920 @pytest.mark.parametrize("topk_num", [6]) -@pytest.mark.parametrize("num_experts", [64, 128]) +@pytest.mark.parametrize("num_experts", [64, 128, 256, 512]) @pytest.mark.parametrize("max_loras", [2, 32]) @pytest.mark.parametrize("block_size", [16]) def test_moe_lora_align_block_size( diff --git a/tests/lora/test_utils.py b/tests/lora/test_utils.py index eb026c2ec0209..bec12eeeb48d5 100644 --- a/tests/lora/test_utils.py +++ b/tests/lora/test_utils.py @@ -3,7 +3,7 @@ from collections import OrderedDict from typing import NamedTuple -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from huggingface_hub.utils import HfHubHTTPError @@ -194,5 +194,8 @@ def test_get_adapter_absolute_path_huggingface_error( # Hugging Face model identifier with download error path = "org/repo" mock_exist.return_value = False - mock_snapshot_download.side_effect = HfHubHTTPError("failed to query model info") + mock_snapshot_download.side_effect = HfHubHTTPError( + "failed to query model info", + response=MagicMock(), + ) assert get_adapter_absolute_path(path) == path diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index b163559a9414d..445aaf9cb7d1e 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -16,7 +16,7 @@ from vllm.config import ( ) from vllm.config.load import LoadConfig from vllm.config.lora import LoRAConfig -from vllm.lora.models import LoRAMapping +from vllm.lora.model_manager import LoRAMapping from vllm.lora.request import LoRARequest from vllm.v1.worker.gpu_worker import Worker @@ -33,14 +33,16 @@ def test_worker_apply_lora(qwen3_lora_files): lora_requests, lora_mapping ) + model_config = ModelConfig( + MODEL_PATH, + seed=0, + dtype="float16", + max_model_len=127, + enforce_eager=True, + ) + vllm_config = VllmConfig( - model_config=ModelConfig( - MODEL_PATH, - seed=0, - dtype="float16", - max_model_len=127, - enforce_eager=True, - ), + model_config=model_config, load_config=LoadConfig( download_dir=None, load_format="dummy", @@ -50,7 +52,14 @@ def test_worker_apply_lora(qwen3_lora_files): tensor_parallel_size=1, data_parallel_size=1, ), - scheduler_config=SchedulerConfig("generate", 32, 32, 32), + scheduler_config=SchedulerConfig( + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, + runner_type="generate", + max_num_batched_tokens=32, + max_num_seqs=32, + max_num_partial_prefills=32, + ), device_config=DeviceConfig("cuda"), cache_config=CacheConfig( block_size=16, diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py b/tests/model_executor/model_loader/runai_streamer_loader/__init__.py similarity index 100% rename from vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py rename to tests/model_executor/model_loader/runai_streamer_loader/__init__.py diff --git a/tests/model_executor/model_loader/runai_streamer_loader/conftest.py b/tests/model_executor/model_loader/runai_streamer_loader/conftest.py new file mode 100644 index 0000000000000..9a022f6bbd9d1 --- /dev/null +++ b/tests/model_executor/model_loader/runai_streamer_loader/conftest.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port +from vllm.v1.executor import UniProcExecutor +from vllm.v1.worker.worker_base import WorkerWrapperBase + + +# This is a dummy executor for patching in test_runai_model_streamer_s3.py. +# We cannot use vllm_runner fixture here, because it spawns worker process. +# The worker process reimports the patched entities, and the patch is not applied. +class RunaiDummyExecutor(UniProcExecutor): + def _init_executor(self) -> None: + distributed_init_method = get_distributed_init_method(get_ip(), get_open_port()) + + local_rank = 0 + rank = 0 + is_driver_worker = True + + device_info = self.vllm_config.device_config.device.__str__().split(":") + if len(device_info) > 1: + local_rank = int(device_info[1]) + + worker_rpc_kwargs = dict( + vllm_config=self.vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=is_driver_worker, + ) + + wrapper_kwargs = { + "vllm_config": self.vllm_config, + } + + self.driver_worker = WorkerWrapperBase(**wrapper_kwargs) + + self.collective_rpc("init_worker", args=([worker_rpc_kwargs],)) + self.collective_rpc("init_device") diff --git a/tests/model_executor/model_loader/runai_model_streamer/test_runai_model_streamer_loader.py b/tests/model_executor/model_loader/runai_streamer_loader/test_runai_model_streamer_loader.py similarity index 100% rename from tests/model_executor/model_loader/runai_model_streamer/test_runai_model_streamer_loader.py rename to tests/model_executor/model_loader/runai_streamer_loader/test_runai_model_streamer_loader.py diff --git a/tests/model_executor/model_loader/runai_streamer_loader/test_runai_model_streamer_s3.py b/tests/model_executor/model_loader/runai_streamer_loader/test_runai_model_streamer_s3.py new file mode 100644 index 0000000000000..d60c9ba64cbdb --- /dev/null +++ b/tests/model_executor/model_loader/runai_streamer_loader/test_runai_model_streamer_s3.py @@ -0,0 +1,52 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from pathlib import Path + +from huggingface_hub import snapshot_download +from runai_model_streamer.safetensors_streamer.streamer_mock import StreamerPatcher + +from vllm.engine.arg_utils import EngineArgs + +from .conftest import RunaiDummyExecutor + +load_format = "runai_streamer" +test_model = "openai-community/gpt2" + + +def test_runai_model_loader_download_files_s3_mocked_with_patch( + vllm_runner, + tmp_path: Path, + monkeypatch, +): + patcher = StreamerPatcher(str(tmp_path)) + + test_mock_s3_model = "s3://my-mock-bucket/gpt2/" + + # Download model from HF + mock_model_dir = f"{tmp_path}/gpt2" + snapshot_download(repo_id=test_model, local_dir=mock_model_dir) + + monkeypatch.setattr( + "vllm.transformers_utils.runai_utils.runai_list_safetensors", + patcher.shim_list_safetensors, + ) + monkeypatch.setattr( + "vllm.transformers_utils.runai_utils.runai_pull_files", + patcher.shim_pull_files, + ) + monkeypatch.setattr( + "vllm.model_executor.model_loader.weight_utils.SafetensorsStreamer", + patcher.create_mock_streamer, + ) + + engine_args = EngineArgs( + model=test_mock_s3_model, + load_format=load_format, + tensor_parallel_size=1, + ) + + vllm_config = engine_args.create_engine_config() + + executor = RunaiDummyExecutor(vllm_config) + executor.driver_worker.load_model() diff --git a/tests/model_executor/model_loader/runai_model_streamer/test_runai_utils.py b/tests/model_executor/model_loader/runai_streamer_loader/test_runai_utils.py similarity index 100% rename from tests/model_executor/model_loader/runai_model_streamer/test_runai_utils.py rename to tests/model_executor/model_loader/runai_streamer_loader/test_runai_utils.py diff --git a/tests/model_executor/model_loader/runai_model_streamer/test_weight_utils.py b/tests/model_executor/model_loader/runai_streamer_loader/test_weight_utils.py similarity index 100% rename from tests/model_executor/model_loader/runai_model_streamer/test_weight_utils.py rename to tests/model_executor/model_loader/runai_streamer_loader/test_weight_utils.py diff --git a/tests/models/fixtures/audioflamingo3/expected_results_batched.json b/tests/models/fixtures/audioflamingo3/expected_results_batched.json new file mode 100644 index 0000000000000..4dbb107edccb7 --- /dev/null +++ b/tests/models/fixtures/audioflamingo3/expected_results_batched.json @@ -0,0 +1 @@ +{"transcriptions": ["There is no clear relationship between the barking and the music, as they seem to be independent of each other.", "(B) To indicate that language cannot express clearly, satirizing the inversion of black and white in the world"], "token_ids": [[3862, 374, 902, 2797, 5025, 1948, 279, 293, 33452, 323, 279, 4627, 11, 438, 807, 2803, 311, 387, 9489, 315, 1817, 1008, 13, 151645], [5349, 8, 2014, 13216, 429, 4128, 4157, 3158, 9355, 11, 7578, 404, 4849, 279, 46488, 315, 3691, 323, 4158, 304, 279, 1879, 151645, 151671]]} \ No newline at end of file diff --git a/tests/models/fixtures/audioflamingo3/expected_results_single.json b/tests/models/fixtures/audioflamingo3/expected_results_single.json new file mode 100644 index 0000000000000..be9233467a20e --- /dev/null +++ b/tests/models/fixtures/audioflamingo3/expected_results_single.json @@ -0,0 +1 @@ +{"transcriptions": ["The content of the input audio is 'you can ask why over and over and over again forever even if one day we explain every physical interaction and scientific law and hope and dream and regret with a single elegant equation'."], "token_ids": [[785, 2213, 315, 279, 1946, 7699, 374, 364, 9330, 646, 2548, 3170, 916, 323, 916, 323, 916, 1549, 15683, 1496, 421, 825, 1899, 582, 10339, 1449, 6961, 16230, 323, 12344, 2329, 323, 3900, 323, 7904, 323, 22231, 448, 264, 3175, 25777, 23606, 4427, 151645]]} \ No newline at end of file diff --git a/tests/models/language/generation/test_mistral.py b/tests/models/language/generation/test_mistral.py index 1377776a6d84b..0ef4ba2577724 100644 --- a/tests/models/language/generation/test_mistral.py +++ b/tests/models/language/generation/test_mistral.py @@ -5,12 +5,12 @@ import json import pytest -from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( +from vllm.sampling_params import SamplingParams +from vllm.tokenizers.mistral import MistralTokenizer +from vllm.tool_parsers.mistral_tool_parser import ( MistralToolCall, MistralToolParser, ) -from vllm.sampling_params import SamplingParams -from vllm.tokenizers import MistralTokenizer from ...utils import check_logprobs_close @@ -315,3 +315,38 @@ def test_mistral_function_call_nested_json(): assert json.loads(parsed.tool_calls[0].function.arguments) == args_dict # No additional content outside the tool call should be returned. assert parsed.content is None + + # multiple calls + multiple_args_dict = [ + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + "sub_dict": {"foo": "bar", "inner": {"x": 1, "y": 2}}, + }, + {}, + {"a": 0}, + {"a": 1, "b": "c"}, + ] + names = ["get_current_weather", "get_current_weather_2", "random", "random_2"] + + model_output = "".join( + [ + f"{parser.bot_token}{name}{json.dumps(args)}" + for name, args in zip(names, multiple_args_dict) + ] + ) + + parsed = parser.extract_tool_calls(model_output, None) + + # Assertions: the tool call is detected and the full nested JSON is parsed + # without truncation. + assert parsed.tools_called + assert len(parsed.tool_calls) == len(multiple_args_dict) + + for i, tool_call in enumerate(parsed.tool_calls): + assert MistralToolCall.is_valid_id(tool_call.id) + assert tool_call.function.name == names[i] + assert json.loads(tool_call.function.arguments) == multiple_args_dict[i] + # No additional content outside the tool call should be returned. + assert parsed.content is None diff --git a/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py b/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py new file mode 100644 index 0000000000000..c259c532220b2 --- /dev/null +++ b/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +from transformers import AutoModel + +from tests.models.utils import check_embeddings_close +from vllm import TokensPrompt + + +@pytest.mark.parametrize( + "model", + ["Qwen/Qwen3-Embedding-0.6B"], +) +@torch.inference_mode +def test_embed_models(hf_runner, vllm_runner, model: str): + chunk_size = 10 + n_prompt_tokens = [55, 56, 57] + token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens] + + with vllm_runner( + model, + runner="pooling", + max_model_len=128, + max_num_batched_tokens=chunk_size, + enforce_eager=True, + # `enable_chunked_prefill`: Set to `False` instead of `None` in VllmRunner + enable_chunked_prefill=True, + enable_prefix_caching=True, + ) as vllm_model: + vllm_outputs = vllm_model.token_embed( + [TokensPrompt(prompt_token_ids=t) for t in token_prompts], + ) + + with hf_runner( + model, + auto_cls=AutoModel, + ) as hf_model: + hf_outputs = [] + for token_prompt in token_prompts: + inputs = hf_model.wrap_device({"input_ids": torch.tensor([token_prompt])}) + input_ids = inputs["input_ids"] + output = hf_model.model(input_ids) + hf_outputs.append(output.last_hidden_state.cpu().float()[0]) + + for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): + check_embeddings_close( + embeddings_0_lst=hf_output, + embeddings_1_lst=vllm_output, + name_0="hf", + name_1="vllm", + tol=1e-2, + ) diff --git a/tests/models/language/pooling/test_extract_hidden_states.py b/tests/models/language/pooling/test_extract_hidden_states.py index 0d41b93233d5a..488b27e2da0f1 100644 --- a/tests/models/language/pooling/test_extract_hidden_states.py +++ b/tests/models/language/pooling/test_extract_hidden_states.py @@ -20,7 +20,6 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str): max_model_len=128, enforce_eager=True, runner="pooling", - enable_chunked_prefill=False, enable_prefix_caching=True, ) as vllm_model: pooling_outputs = vllm_model.llm.encode( diff --git a/tests/models/language/pooling/test_mm_classifier_conversion.py b/tests/models/language/pooling/test_mm_classifier_conversion.py index a31a771238e26..d50ee85b9fd2b 100644 --- a/tests/models/language/pooling/test_mm_classifier_conversion.py +++ b/tests/models/language/pooling/test_mm_classifier_conversion.py @@ -17,7 +17,6 @@ def test_idefics_multimodal( with vllm_runner( model_name="HuggingFaceM4/Idefics3-8B-Llama3", runner="pooling", - task="classify", convert="classify", load_format="dummy", max_model_len=512, @@ -86,7 +85,6 @@ def test_gemma_multimodal( with vllm_runner( model_name="google/gemma-3-4b-it", runner="pooling", - task="classify", convert="classify", load_format="auto", hf_overrides=update_config, diff --git a/tests/models/language/pooling/test_token_classification.py b/tests/models/language/pooling/test_token_classification.py index 2dfc0072126bc..64d42432c74b9 100644 --- a/tests/models/language/pooling/test_token_classification.py +++ b/tests/models/language/pooling/test_token_classification.py @@ -68,3 +68,34 @@ def test_modernbert_models( hf_output = torch.tensor(hf_output).cpu().float() vllm_output = torch.tensor(vllm_output).cpu().float() assert torch.allclose(hf_output, vllm_output, atol=1e-2) + + +@pytest.mark.parametrize("model", ["bd2lcco/Qwen3-0.6B-finetuned"]) +@pytest.mark.parametrize("dtype", ["float"]) +@torch.inference_mode +def test_auto_conversion( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + with vllm_runner(model, max_model_len=1024, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.token_classify(example_prompts) + + with hf_runner( + model, dtype=dtype, auto_cls=AutoModelForTokenClassification + ) as hf_model: + tokenizer = hf_model.tokenizer + hf_outputs = [] + for prompt in example_prompts: + inputs = tokenizer([prompt], return_tensors="pt") + inputs = hf_model.wrap_device(inputs) + output = hf_model.model(**inputs) + hf_outputs.append(softmax(output.logits[0])) + + # check logits difference + for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): + hf_output = torch.tensor(hf_output).cpu().float() + vllm_output = torch.tensor(vllm_output).cpu().float() + assert torch.allclose(hf_output, vllm_output, atol=1e-2) diff --git a/tests/models/multimodal/generation/conftest.py b/tests/models/multimodal/generation/conftest.py new file mode 100644 index 0000000000000..26f8586742cea --- /dev/null +++ b/tests/models/multimodal/generation/conftest.py @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Pytest configuration for vLLM tests.""" + +import warnings + +import torch + +from vllm.platforms import current_platform + + +def pytest_configure(config): + """Disable Flash/MemEfficient SDP on ROCm to avoid HF + Transformers accuracy issues. + """ + if not current_platform.is_rocm(): + return + + skip_patterns = ["test_granite_speech.py"] + if any(pattern in str(arg) for arg in config.args for pattern in skip_patterns): + # Skip disabling SDP for Granite Speech tests on ROCm + return + + # Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers + # accuracy issues + # TODO: Remove once ROCm SDP accuracy issues are resolved on HuggingFace + torch.backends.cuda.enable_flash_sdp(False) + torch.backends.cuda.enable_mem_efficient_sdp(False) + torch.backends.cuda.enable_math_sdp(True) + warnings.warn( + "ROCm: Disabled flash_sdp and mem_efficient_sdp, enabled math_sdp " + "to avoid HuggingFace Transformers accuracy issues", + UserWarning, + stacklevel=1, + ) diff --git a/tests/models/multimodal/generation/test_audioflamingo3.py b/tests/models/multimodal/generation/test_audioflamingo3.py new file mode 100644 index 0000000000000..d14291a62c346 --- /dev/null +++ b/tests/models/multimodal/generation/test_audioflamingo3.py @@ -0,0 +1,142 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The vLLM team. +# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights +# reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os + +import pytest + +from tests.models.registry import HF_EXAMPLE_MODELS +from vllm import LLM, SamplingParams + +MODEL_NAME = "nvidia/audio-flamingo-3-hf" + + +def get_fixture_path(filename): + return os.path.join( + os.path.dirname(__file__), "../../fixtures/audioflamingo3", filename + ) + + +@pytest.fixture(scope="module") +def llm(): + # Check if the model is supported by the current transformers version + model_info = HF_EXAMPLE_MODELS.get_hf_info("AudioFlamingo3ForConditionalGeneration") + model_info.check_transformers_version(on_fail="skip") + + try: + llm = LLM( + model=MODEL_NAME, + trust_remote_code=True, + dtype="bfloat16", + enforce_eager=True, + limit_mm_per_prompt={"audio": 1}, + ) + return llm + except Exception as e: + pytest.skip(f"Failed to load model {MODEL_NAME}: {e}") + + +def test_single_generation(llm): + fixture_path = get_fixture_path("expected_results_single.json") + if not os.path.exists(fixture_path): + pytest.skip(f"Fixture not found: {fixture_path}") + + with open(fixture_path) as f: + expected = json.load(f) + + audio_url = "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/Why_do_we_ask_questions_converted.wav" + + messages = [ + { + "role": "user", + "content": [ + {"type": "audio_url", "audio_url": {"url": audio_url}}, + {"type": "text", "text": "Transcribe the input speech."}, + ], + } + ] + + sampling_params = SamplingParams(temperature=0.0, max_tokens=128) + + outputs = llm.chat( + messages=messages, + sampling_params=sampling_params, + ) + generated_text = outputs[0].outputs[0].text.strip() + + expected_text = expected["transcriptions"][0] + + assert expected_text in generated_text or generated_text in expected_text + + +def test_batched_generation(llm): + fixture_path = get_fixture_path("expected_results_batched.json") + if not os.path.exists(fixture_path): + pytest.skip(f"Fixture not found: {fixture_path}") + + with open(fixture_path) as f: + expected = json.load(f) + + items = [ + { + "audio_url": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/dogs_barking_in_sync_with_the_music.wav", + "question": "What is surprising about the relationship " + "between the barking and the music?", + "expected_idx": 0, + }, + { + "audio_url": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/Ch6Ae9DT6Ko_00-04-03_00-04-31.wav", + "question": ( + "Why is the philosopher's name mentioned in the lyrics? " + "(A) To express a sense of nostalgia " + "(B) To indicate that language cannot express clearly, " + "satirizing the inversion of black and white in the world " + "(C) To add depth and complexity to the lyrics " + "(D) To showcase the wisdom and influence of the philosopher" + ), + "expected_idx": 1, + }, + ] + + conversations = [] + for item in items: + messages = [ + { + "role": "user", + "content": [ + {"type": "audio_url", "audio_url": {"url": item["audio_url"]}}, + {"type": "text", "text": item["question"]}, + ], + } + ] + conversations.append(messages) + + sampling_params = SamplingParams(temperature=0.0, max_tokens=128) + + outputs = llm.chat( + messages=conversations, + sampling_params=sampling_params, + ) + + for i, output in enumerate(outputs): + generated_text = output.outputs[0].text.strip() + expected_text = expected["transcriptions"][i] + + assert expected_text in generated_text or generated_text in expected_text diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index deaeea059ccaf..c5a0b6748f797 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -137,7 +137,7 @@ VLM_TEST_SETTINGS = { max_num_seqs=2, auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, - image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], + image_size_factors=[(0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), "qwen2_5_omni": VLMTestInfo( @@ -152,7 +152,7 @@ VLM_TEST_SETTINGS = { auto_cls=AutoModelForTextToWaveform, vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, patch_hf_runner=model_utils.qwen2_5_omni_patch_hf_runner, - image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], + image_size_factors=[(0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), "qwen3_vl": VLMTestInfo( @@ -173,7 +173,7 @@ VLM_TEST_SETTINGS = { auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, patch_hf_runner=model_utils.qwen3_vl_patch_hf_runner, - image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], + image_size_factors=[(0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], marks=[ pytest.mark.core_model, ], @@ -278,7 +278,7 @@ VLM_TEST_SETTINGS = { marks=[large_gpu_mark(min_gb=64)], ), "aya_vision": VLMTestInfo( - models=["CohereForAI/aya-vision-8b"], + models=["CohereLabs/aya-vision-8b"], test_type=(VLMTestType.IMAGE), prompt_formatter=lambda img_prompt: f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{img_prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", # noqa: E501 single_image_prompts=IMAGE_ASSETS.prompts( @@ -294,7 +294,7 @@ VLM_TEST_SETTINGS = { vllm_runner_kwargs={"mm_processor_kwargs": {"crop_to_patches": True}}, ), "aya_vision-multi_image": VLMTestInfo( - models=["CohereForAI/aya-vision-8b"], + models=["CohereLabs/aya-vision-8b"], test_type=(VLMTestType.MULTI_IMAGE), prompt_formatter=lambda img_prompt: f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{img_prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", # noqa: E501 single_image_prompts=IMAGE_ASSETS.prompts( @@ -350,7 +350,7 @@ VLM_TEST_SETTINGS = { patch_hf_runner=model_utils.deepseekvl2_patch_hf_runner, hf_output_post_proc=model_utils.deepseekvl2_trunc_hf_output, stop_str=["<|end▁of▁sentence|>", "<|begin▁of▁sentence|>"], - image_size_factors=[(), (1.0,), (1.0, 1.0, 1.0), (0.1, 0.5, 1.0)], + image_size_factors=[(1.0,), (1.0, 1.0, 1.0), (0.1, 0.5, 1.0)], ), "fuyu": VLMTestInfo( models=["adept/fuyu-8b"], @@ -382,7 +382,6 @@ VLM_TEST_SETTINGS = { auto_cls=AutoModelForImageTextToText, vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}}, patch_hf_runner=model_utils.gemma3_patch_hf_runner, - num_logprobs=10, ), "glm4v": VLMTestInfo( models=["zai-org/glm-4v-9b"], @@ -403,12 +402,13 @@ VLM_TEST_SETTINGS = { # So, we need to reduce the number of tokens for the test to pass. max_tokens=8, num_logprobs=10, + auto_cls=AutoModelForCausalLM, marks=[large_gpu_mark(min_gb=32)], ), "glm4_1v": VLMTestInfo( models=["zai-org/GLM-4.1V-9B-Thinking"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|assistant|>", + prompt_formatter=lambda img_prompt: f"[gMASK]<|user|>\n{img_prompt}<|assistant|>\n", # noqa: E501 img_idx_to_prompt=lambda idx: "<|begin_of_image|><|image|><|end_of_image|>", video_idx_to_prompt=lambda idx: "<|begin_of_video|><|video|><|end_of_video|>", max_model_len=2048, @@ -423,6 +423,7 @@ VLM_TEST_SETTINGS = { models=["zai-org/GLM-4.1V-9B-Thinking"], # GLM4.1V require include video metadata for input test_type=VLMTestType.CUSTOM_INPUTS, + prompt_formatter=lambda vid_prompt: f"[gMASK]<|user|>\n{vid_prompt}<|assistant|>\n", # noqa: E501 max_model_len=4096, max_num_seqs=2, auto_cls=AutoModelForImageTextToText, @@ -707,7 +708,7 @@ VLM_TEST_SETTINGS = { max_model_len=8192, max_num_seqs=2, auto_cls=AutoModelForCausalLM, - image_size_factors=[(), (0.25,)], + image_size_factors=[(0.25,)], marks=[ pytest.mark.skipif( Version(TRANSFORMERS_VERSION) == Version("4.57.3"), @@ -737,7 +738,13 @@ VLM_TEST_SETTINGS = { max_model_len=8192, max_num_seqs=2, auto_cls=AutoModelForImageTextToText, - marks=[large_gpu_mark(min_gb=48)], + marks=[ + large_gpu_mark(min_gb=48), + pytest.mark.skipif( + current_platform.is_rocm(), + reason="Model produces a vector of <UNK> output in HF on ROCm", + ), + ], ), "qwen_vl": VLMTestInfo( models=["Qwen/Qwen-VL"], @@ -760,7 +767,7 @@ VLM_TEST_SETTINGS = { max_num_seqs=2, auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, - image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], + image_size_factors=[(0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], marks=[pytest.mark.cpu_model], ), "skywork_r1v": VLMTestInfo( @@ -812,7 +819,7 @@ VLM_TEST_SETTINGS = { max_model_len=4096, max_num_seqs=2, auto_cls=AutoModelForImageTextToText, - image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], + image_size_factors=[(0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], marks=[pytest.mark.skip("Model initialization hangs")], ), ### Tensor parallel / multi-gpu broadcast tests diff --git a/tests/models/multimodal/generation/test_granite_speech.py b/tests/models/multimodal/generation/test_granite_speech.py index e39dfc888779e..f528a993f8551 100644 --- a/tests/models/multimodal/generation/test_granite_speech.py +++ b/tests/models/multimodal/generation/test_granite_speech.py @@ -8,6 +8,7 @@ from transformers import AutoModelForSpeechSeq2Seq from vllm.logprobs import SampleLogprobs from vllm.lora.request import LoRARequest +from vllm.platforms import current_platform from ....conftest import AudioTestAssets, HfRunner, PromptAudioInput, VllmRunner from ...registry import HF_EXAMPLE_MODELS @@ -34,6 +35,12 @@ audio_lora_path = MODEL_NAME models = [MODEL_NAME] +@pytest.fixture(autouse=True) +def set_attention_backend_for_rocm(monkeypatch): + if current_platform.is_rocm(): + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN") + + def run_test( hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], @@ -111,8 +118,12 @@ def run_test( @pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_model_len", [2048]) +@pytest.mark.parametrize( + "dtype", ["float16"] if current_platform.is_rocm() else ["bfloat16"] +) +@pytest.mark.parametrize( + "max_model_len", [512] if current_platform.is_rocm() else [2048] +) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) def test_models( diff --git a/tests/models/multimodal/generation/test_phi4_multimodal.py b/tests/models/multimodal/generation/test_phi4_multimodal.py deleted file mode 100644 index 62456221711ed..0000000000000 --- a/tests/models/multimodal/generation/test_phi4_multimodal.py +++ /dev/null @@ -1,281 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os -from collections.abc import Sequence - -import librosa -import pytest -from huggingface_hub import snapshot_download - -from vllm.assets.image import ImageAsset -from vllm.lora.request import LoRARequest -from vllm.multimodal.image import rescale_image_size - -from ....conftest import ( - IMAGE_ASSETS, - HfRunner, - PromptAudioInput, - PromptImageInput, - VllmRunner, -) -from ....utils import large_gpu_test -from ...utils import check_logprobs_close - -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts( - { - "stop_sign": "<|user|>\n<|image|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n", # noqa: E501 - "cherry_blossom": "<|user|>\n<|image|>\nPlease infer the season with reason in details.<|end|>\n<|assistant|>\n", # noqa: E501 - } -) -HF_MULTIIMAGE_IMAGE_PROMPT = ( - "<|user|>\n<|image|>\n<|image|>\nDescribe these images.<|end|>\n<|assistant|>\n" # noqa: E501 -) - -model_path = snapshot_download( - "microsoft/Phi-4-multimodal-instruct", revision="refs/pr/70" -) -# Since the vision-lora and speech-lora co-exist with the base model, -# we have to manually specify the path of the lora weights. -vision_lora_path = os.path.join(model_path, "vision-lora") -speech_question = os.path.join( - model_path, "examples", "what_is_shown_in_this_image.wav" -) -models = [model_path] - -target_dtype = "half" - - -def run_test( - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - inputs: Sequence[tuple[list[str], PromptImageInput, PromptAudioInput | None]], - model: str, - *, - max_model_len: int, - dtype: str, - max_tokens: int, - num_logprobs: int, - mm_limit: int, - tensor_parallel_size: int, - distributed_executor_backend: str | None = None, -): - """Inference result should be the same between hf and vllm. - - All the image fixtures for the test are from IMAGE_ASSETS. - For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalDataDict objects - and corresponding MultiModalConfig as input. - Note, the text input is also adjusted to abide by vllm contract. - The text output is sanitized to be able to compare with hf. - """ - # NOTE: take care of the order. run vLLM first, and then run HF. - # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with fork method (the default method). - # max_model_len should be greater than image_feature_size - with vllm_runner( - model, - task="generate", - max_model_len=max_model_len, - max_num_seqs=2, - dtype=dtype, - limit_mm_per_prompt={"image": mm_limit}, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enable_lora=True, - max_lora_rank=320, - gpu_memory_utilization=0.8, # set to 0.8 to avoid OOM in CI - enforce_eager=True, - trust_remote_code=False, - ) as vllm_model: - lora_request = LoRARequest("vision", 1, vision_lora_path) - vllm_outputs_per_case = [ - vllm_model.generate_greedy_logprobs( - prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images, - audios=audios, - lora_request=lora_request, - ) - for prompts, images, audios in inputs - ] - - with hf_runner(model, dtype=dtype) as hf_model: - hf_model.model.load_adapter( - vision_lora_path, - adapter_name="vision", - ) - hf_processor = hf_model.processor - eos_token_id = hf_processor.tokenizer.eos_token_id - hf_outputs_per_case = [ - hf_model.generate_greedy_logprobs_limit( - prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images, - audios=audios, - eos_token_id=eos_token_id, - ) - for prompts, images, audios in inputs - ] - - for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, vllm_outputs_per_case): - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) - - -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize( - "size_factors", - [ - # No image - [], - # Single-scale - [1.0], - # Single-scale, batched - [1.0, 1.0, 1.0], - # Multi-scale - [0.25, 0.5, 1.0], - ], -) -@pytest.mark.parametrize("dtype", [target_dtype]) -@pytest.mark.parametrize("max_model_len", [12800]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [10]) -def test_models( - hf_runner, - vllm_runner, - image_assets, - model, - size_factors, - dtype: str, - max_model_len: int, - max_tokens: int, - num_logprobs: int, -) -> None: - images = [asset.pil_image for asset in image_assets] - - inputs_per_image = [ - ( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - None, - ) - for image, prompt in zip(images, HF_IMAGE_PROMPTS) - ] - - run_test( - hf_runner, - vllm_runner, - inputs_per_image, - model, - dtype=dtype, - max_model_len=max_model_len, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - mm_limit=1, - tensor_parallel_size=1, - ) - - -@large_gpu_test(min_gb=48) -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize( - "size_factors", - [ - # No image - # [], - # Single-scale - [1.0], - # Single-scale, batched - [1.0, 1.0, 1.0], - # Multi-scale - [0.25, 0.5, 1.0], - ], -) -@pytest.mark.parametrize("dtype", [target_dtype]) -@pytest.mark.parametrize("max_model_len", [25600]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [10]) -def test_multi_images_models( - hf_runner, - vllm_runner, - image_assets, - model, - size_factors, - dtype: str, - max_model_len: int, - max_tokens: int, - num_logprobs: int, -) -> None: - images = [asset.pil_image for asset in image_assets] - - inputs_per_case = [ - ( - [HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors], - [ - [rescale_image_size(image, factor) for image in images] - for factor in size_factors - ], - None, - ), - ] - - run_test( - hf_runner, - vllm_runner, - inputs_per_case, - model, - dtype=dtype, - max_model_len=max_model_len, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - mm_limit=2, - tensor_parallel_size=1, - ) - - -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("dtype", [target_dtype]) -@pytest.mark.parametrize("max_model_len", [12800]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [10]) -def test_vision_speech_models( - hf_runner, - vllm_runner, - model, - dtype: str, - max_model_len: int, - max_tokens: int, - num_logprobs: int, -) -> None: - # use the example speech question so that the model outputs are reasonable - audio = librosa.load(speech_question, sr=16000) - image = ImageAsset("cherry_blossom").pil_image.convert("RGB") - - inputs_vision_speech = [ - ( - ["<|user|><|image|><|audio|><|end|><|assistant|>"], - [image], - [audio], - ), - ] - - run_test( - hf_runner, - vllm_runner, - inputs_vision_speech, - model, - dtype=dtype, - max_model_len=max_model_len, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - mm_limit=1, - tensor_parallel_size=1, - ) diff --git a/tests/models/multimodal/generation/test_pixtral.py b/tests/models/multimodal/generation/test_pixtral.py index 3cad2c43d5623..375099f4365ac 100644 --- a/tests/models/multimodal/generation/test_pixtral.py +++ b/tests/models/multimodal/generation/test_pixtral.py @@ -15,6 +15,7 @@ from transformers import AutoProcessor from vllm import SamplingParams, TextPrompt, TokensPrompt from vllm.logprobs import Logprob, SampleLogprobs from vllm.multimodal import MultiModalDataBuiltins +from vllm.platforms import current_platform from ....utils import VLLM_PATH, large_gpu_test from ...utils import check_logprobs_close @@ -165,6 +166,15 @@ def load_outputs_w_logprobs(filename: "StrPath") -> OutputsLogprobs: def test_chat( vllm_runner, max_model_len: int, model: str, dtype: str, local_asset_server ) -> None: + if ( + model == MISTRAL_SMALL_3_1_ID + and max_model_len == 65536 + and current_platform.is_rocm() + ): + pytest.skip( + "OOM on ROCm: 24B model with 65536 context length exceeds GPU memory" + ) + EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT[model]) with vllm_runner( model, diff --git a/tests/models/multimodal/generation/test_vit_backend_functionality.py b/tests/models/multimodal/generation/test_vit_backend_functionality.py new file mode 100644 index 0000000000000..a4e4ce312ddd4 --- /dev/null +++ b/tests/models/multimodal/generation/test_vit_backend_functionality.py @@ -0,0 +1,435 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Consolidated test for ViT attention backend functionality across multiple models. + +This test validates that each multimodal model can successfully generate outputs +using different ViT attention backends. Tests are parametrized by model and backend. +""" + +from dataclasses import asdict +from typing import Any + +import pytest +from transformers import AutoProcessor + +from vllm import LLM, EngineArgs, SamplingParams +from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.multimodal.utils import encode_image_base64 +from vllm.multimodal.video import sample_frames_from_video +from vllm.platforms import current_platform + +from ....utils import create_new_process_for_each_test +from ...utils import dummy_hf_overrides + +# Dots.OCR prompt from official repository +# https://github.com/rednote-hilab/dots.ocr/blob/d72d1d8c5bdd0362eb264f714cdbd1e5daa7cdff/dots_ocr/utils/prompts.py#L3 +# ruff: noqa: E501 +DOTS_OCR_PROMPT = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox. + +1. Bbox format: [x1, y1, x2, y2] + +2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title']. + +3. Text Extraction & Formatting Rules: + - Picture: For the 'Picture' category, the text field should be omitted. + - Formula: Format its text as LaTeX. + - Table: Format its text as HTML. + - All Others (Text, Title, etc.): Format their text as Markdown. + +4. Constraints: + - The output text must be the original text from the image, with no translation. + - All layout elements must be sorted according to human reading order. + +5. Final Output: The entire output must be a single JSON object. +""" + +VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>" + + +# Model configurations +MODEL_CONFIGS: dict[str, dict[str, Any]] = { + "dots_ocr": { + "model_name": "rednote-hilab/dots.ocr", + "interface": "llm_chat", + "max_model_len": 32768, + "max_num_seqs": 1, + "limit_mm_per_prompt": {"image": 1}, + "sampling_params": { + "temperature": 0.1, + "max_tokens": 16384, + "top_p": 0.9, + "stop_token_ids": None, + }, + "use_specific_image": "stop_sign", + "prompt_builder": "build_dots_ocr_prompt", + "output_validator": lambda x: len(x) > 10 and "stop" in x.lower(), + }, + "ernie45_vl": { + "model_name": "baidu/ERNIE-4.5-VL-28B-A3B-PT", + "interface": "llm_generate", + "max_model_len": 16384, + "max_num_seqs": 2, + "sampling_params": { + "temperature": 0.0, + "max_tokens": 256, + "stop_token_ids": None, + }, + "use_processor": True, + "question": "What is the content of each image?", + }, + "glm4_1v": { + "model_name": "zai-org/GLM-4.1V-9B-Thinking", + "interface": "llm_generate", + "max_model_len": 32768, + "max_num_seqs": 2, + "sampling_params": { + "temperature": 0.0, + "max_tokens": 256, + "stop_token_ids": None, + }, + "use_processor": True, + "question": "What is the content of each image?", + }, + "keye_vl": { + "model_name": "Kwai-Keye/Keye-VL-8B-Preview", + "interface": "llm_generate", + "max_model_len": 8192, + "max_num_seqs": 5, + "sampling_params": { + "temperature": 0.0, + "max_tokens": 256, + "stop_token_ids": None, + }, + "supported_backends": { + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, + }, + "use_processor": True, + "question": "What is the content of each image?", + }, + "ovis2_5": { + "model_name": "AIDC-AI/Ovis2.5-2B", + "interface": "llm_generate", + "max_model_len": 8192, + "max_num_seqs": 2, + "sampling_params": { + "temperature": 0.0, + "max_tokens": 256, + "stop_token_ids": None, + }, + "prompt_builder": "build_ovis_prompt", + "question": "What is the content of each image?", + }, + "qwen2_5_vl": { + "model_name": "Qwen/Qwen2.5-VL-3B-Instruct", + "interface": "vllm_runner", + "media_type": "video", + "max_model_len": 4000, + "max_num_seqs": 1, + "limit_mm_per_prompt": {"video": 1}, + "sampling_params": { + "max_tokens": 128, + }, + "runner_kwargs": { + "runner": "generate", + "dtype": "bfloat16", + }, + "video_params": { + "num_frames": 16, + "pruning_rates": [0.0, 0.75], + }, + }, + "qwen2_5_omni": { + "model_name": "Qwen/Qwen2.5-Omni-3B", + "interface": "llm_generate", + "max_model_len": 32768, + "max_num_seqs": 2, + "limit_mm_per_prompt": {"image": 3, "video": 3, "audio": 3}, + "sampling_params": { + "temperature": 0.6, + "top_p": 0.95, + "top_k": 20, + "max_tokens": 16384, + }, + "use_processor": True, + "question": "What is the content of each image?", + }, + "qwen3_omni": { + "model_name": "Qwen/Qwen3-Omni-30B-A3B-Instruct", + "interface": "llm_generate", + "max_model_len": 32768, + "max_num_seqs": 2, + "limit_mm_per_prompt": {"image": 3, "video": 3, "audio": 3}, + "sampling_params": { + "temperature": 0.6, + "top_p": 0.95, + "top_k": 20, + "max_tokens": 16384, + }, + "use_processor": True, + "question": "What is the content of each image?", + }, +} + + +# Prompt builder functions +def build_dots_ocr_prompt(images, config): + """Build Dots.OCR specific prompt with OCR instructions.""" + # Use only stop_sign image for Dots.OCR + image = images[0] # Already filtered to stop_sign + + image_url = f"data:image/jpeg;base64,{encode_image_base64(image)}" + + placeholders = [{"type": "image_url", "image_url": {"url": image_url}}] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + { + "type": "text", + "text": f"<|img|><|imgpad|><|endofimg|>{DOTS_OCR_PROMPT}", + }, + ], + }, + ] + + return messages + + +def build_processor_prompt(images, config): + """Build prompt using AutoProcessor.apply_chat_template().""" + processor = AutoProcessor.from_pretrained( + config["model_name"], trust_remote_code=True + ) + + image_urls = [ + f"data:image/jpeg;base64,{encode_image_base64(img)}" for img in images + ] + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": config["question"]}, + ], + }, + ] + + return processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + +def build_ovis_prompt(images, config): + """Build Ovis2.5 specific prompt with custom format.""" + image_urls = [ + f"data:image/jpeg;base64,{encode_image_base64(img)}" for img in images + ] + + placeholders = "\n".join( + f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1) + ) + + return ( + f"<|im_start|>user\n\n{placeholders}\n{config['question']}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + +def build_qwen2_5_video_prompt(): + """Build Qwen2.5-VL video prompt with EVS placeholder.""" + return ( + f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + f"<|im_start|>user\n{VIDEO_PLACEHOLDER}" + "Describe this video with a short sentence (no more than 20 words)" + "<|im_end|><|im_start|>assistant\n" + ) + + +# Handler functions +def run_llm_generate_test(config, mm_encoder_attn_backend, image_assets): + """Standard LLM.generate() interface handler.""" + images = [asset.pil_image for asset in image_assets] + + # Build prompt + if config.get("use_processor"): + prompt = build_processor_prompt(images, config) + else: + prompt_builder_name = config.get("prompt_builder", "build_ovis_prompt") + prompt_builder = globals()[prompt_builder_name] + prompt = prompt_builder(images, config) + + # Determine limit_mm_per_prompt + limit_mm_per_prompt = config.get("limit_mm_per_prompt", {"image": len(images)}) + + # Create engine + engine_args = EngineArgs( + model=config["model_name"], + trust_remote_code=True, + max_model_len=config["max_model_len"], + max_num_seqs=config["max_num_seqs"], + limit_mm_per_prompt=limit_mm_per_prompt, + mm_encoder_attn_backend=mm_encoder_attn_backend, + hf_overrides=dummy_hf_overrides, + load_format="dummy", + ) + + engine_dict = asdict(engine_args) | {"seed": 42} + llm = LLM(**engine_dict) + + # Generate + sampling_params = SamplingParams(**config["sampling_params"]) + outputs = llm.generate( + { + "prompt": prompt, + "multi_modal_data": {"image": images}, + }, + sampling_params=sampling_params, + ) + + # Validate + for o in outputs: + generated_text = o.outputs[0].text + validator = config.get("output_validator", lambda x: len(x) > 10) + assert validator(generated_text), ( + f"Validation failed for {config['model_name']}: {generated_text}" + ) + + +def run_llm_chat_test(config, mm_encoder_attn_backend, image_assets): + """LLM.chat() interface handler for Dots.OCR.""" + # Filter to stop_sign image only + stop_sign_image = [ + asset.pil_image for asset in image_assets if asset.name == "stop_sign" + ][0] + + # Build messages + messages = build_dots_ocr_prompt([stop_sign_image], config) + + # Create engine + engine_args = EngineArgs( + model=config["model_name"], + trust_remote_code=True, + max_model_len=config["max_model_len"], + max_num_seqs=config["max_num_seqs"], + limit_mm_per_prompt=config["limit_mm_per_prompt"], + mm_encoder_attn_backend=mm_encoder_attn_backend, + hf_overrides=dummy_hf_overrides, + load_format="dummy", + ) + + engine_dict = asdict(engine_args) | {"seed": 42} + llm = LLM(**engine_dict) + + # Generate using chat + sampling_params = SamplingParams(**config["sampling_params"]) + outputs = llm.chat(messages=messages, sampling_params=sampling_params) + + # Validate + for o in outputs: + generated_text = o.outputs[0].text + validator = config.get("output_validator", lambda x: len(x) > 10) + assert validator(generated_text), ( + f"Validation failed for {config['model_name']}: {generated_text}" + ) + + +def run_video_test(config, mm_encoder_attn_backend, video_assets, vllm_runner): + """Video test with EVS (Efficient Video Sampling) handler.""" + for pruning_rate in config["video_params"]["pruning_rates"]: + num_frames = config["video_params"]["num_frames"] + + # Sample frames from video + sampled_vids = [ + sample_frames_from_video(asset.np_ndarrays, num_frames) + for asset in video_assets + ] + + # Build prompt and prepare video + prompt = build_qwen2_5_video_prompt() + prompts = [prompt] + videos = [sampled_vids[0]] + + # Run with vllm_runner context manager + with vllm_runner( + config["model_name"], + max_model_len=config["max_model_len"], + max_num_seqs=config["max_num_seqs"], + limit_mm_per_prompt=config["limit_mm_per_prompt"], + tensor_parallel_size=1, + video_pruning_rate=pruning_rate, + mm_encoder_attn_backend=mm_encoder_attn_backend, + hf_overrides=dummy_hf_overrides, + load_format="dummy", + **config["runner_kwargs"], + ) as vllm_model: + outputs = vllm_model.generate_greedy( + prompts, + config["sampling_params"]["max_tokens"], + videos=videos, + ) + + # Validate output + assert len(outputs) == 1, f"Expected 1 output, got {len(outputs)}" + output_ids, output_text = outputs[0] + assert len(output_ids) > 0, "Generated no output IDs" + assert len(output_text) > 0, "Generated empty text" + assert isinstance(output_text, str), ( + f"Output is not string: {type(output_text)}" + ) + + +# Main test function +@pytest.mark.parametrize("model_key", list(MODEL_CONFIGS.keys())) +@pytest.mark.parametrize( + "mm_encoder_attn_backend", + [None] + current_platform.get_supported_vit_attn_backends(), +) +@pytest.mark.skip(reason="Broken test due to memory segmentation fault") +@create_new_process_for_each_test() +def test_vit_backend_functionality( + model_key: str, + mm_encoder_attn_backend: AttentionBackendEnum | None, + image_assets, + video_assets, + vllm_runner, + request, +): + """Test ViT attention backend functionality for multimodal models. + + This test validates that each model can successfully generate outputs + using different ViT attention backends. The test: + 1. Filters unsupported backends per model + 2. Applies appropriate GPU marks + 3. Routes to the correct test handler based on interface + 4. Validates output meets minimum requirements + """ + config = MODEL_CONFIGS[model_key] + + # Step 1: Backend filtering + if ( + "supported_backends" in config + and mm_encoder_attn_backend is not None + and mm_encoder_attn_backend not in config["supported_backends"] + ): + pytest.skip( + f"{model_key} does not support {mm_encoder_attn_backend} backend now." + ) + + # Step 2: Apply GPU marks dynamically + if "gpu_marks" in config: + for mark in config["gpu_marks"]: + request.applymarker(mark) + + # Step 3: Route to appropriate handler + if config.get("media_type") == "video": + run_video_test(config, mm_encoder_attn_backend, video_assets, vllm_runner) + elif config["interface"] == "llm_chat": + run_llm_chat_test(config, mm_encoder_attn_backend, image_assets) + elif config["interface"] == "llm_generate": + run_llm_generate_test(config, mm_encoder_attn_backend, image_assets) + else: + raise ValueError(f"Unknown interface: {config['interface']}") diff --git a/tests/models/multimodal/generation/test_voxtral.py b/tests/models/multimodal/generation/test_voxtral.py index 9e9087cb0fc4d..0eaef49e2395c 100644 --- a/tests/models/multimodal/generation/test_voxtral.py +++ b/tests/models/multimodal/generation/test_voxtral.py @@ -9,7 +9,7 @@ from mistral_common.audio import Audio from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk from mistral_common.protocol.instruct.messages import UserMessage -from vllm.tokenizers import MistralTokenizer +from vllm.tokenizers.mistral import MistralTokenizer from ....conftest import AudioTestAssets from ....utils import RemoteOpenAIServer diff --git a/tests/models/multimodal/generation/test_whisper.py b/tests/models/multimodal/generation/test_whisper.py index eca2b61e37d53..b206995a9cecc 100644 --- a/tests/models/multimodal/generation/test_whisper.py +++ b/tests/models/multimodal/generation/test_whisper.py @@ -1,131 +1,146 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Sequence +from typing import Any + +import librosa import pytest +from transformers import AutoModelForSpeechSeq2Seq -from vllm import SamplingParams from vllm.assets.audio import AudioAsset +from vllm.platforms import current_platform -from ....conftest import VllmRunner +from ....conftest import HfRunner, PromptAudioInput, VllmRunner from ....utils import create_new_process_for_each_test, multi_gpu_test +from ...registry import HF_EXAMPLE_MODELS +from ...utils import check_logprobs_close -PROMPTS = [ - { - "prompt": "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", - "multi_modal_data": { - "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate, - }, - }, - { # Test explicit encoder/decoder prompt - "encoder_prompt": { - "prompt": "", - "multi_modal_data": { - "audio": AudioAsset("winning_call").audio_and_sample_rate, - }, - }, - "decoder_prompt": "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", - }, -] +VLLM_PROMPT = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>" +HF_PROMPT = "" +# Whisper expects 16kHz audio +WHISPER_SAMPLE_RATE = 16000 -EXPECTED = { - "openai/whisper-tiny": [ - " He has birth words I spoke in the original corner of that. And a" - " little piece of black coat poetry. Mary had a little sandwich," - " sweet, with white and snow. And everyone had it very went the last" - " would sure to go.", - " >> And the old one, fit John the way to Edgar Martinez. >> One more" - " to line down the field line for our base camp. Here comes joy. Here" - " is June and the third base. They're going to wave him in. The throw" - " to the plate will be late. The Mariners are going to play for the" - " American League Championship. I don't believe it. It just continues" - " by all five.", - ], - "openai/whisper-small": [ - " The first words I spoke in the original pornograph. A little piece" - " of practical poetry. Mary had a little lamb, its fleece was quite a" - " slow, and everywhere that Mary went the lamb was sure to go.", - " And the old one pitch on the way to Edgar Martinez one month. Here" - " comes joy. Here is Junior to third base. They're gonna wave him" - " in. The throw to the plate will be late. The Mariners are going to" - " play for the American League Championship. I don't believe it. It" - " just continues. My, oh my.", - ], - "openai/whisper-medium": [ - " The first words I spoke in the original phonograph, a little piece" - " of practical poetry. Mary had a little lamb, its fleece was quite as" - " slow, and everywhere that Mary went the lamb was sure to go.", - " And the 0-1 pitch on the way to Edgar Martinez swung on the line" - " down the left field line for Obeyshev. Here comes Joy. Here is" - " Jorgen at third base. They're going to wave him in. The throw to the" - " plate will be late. The Mariners are going to play for the American" - " League Championship. I don't believe it. It just continues. My, oh" - " my.", - ], - "openai/whisper-large-v3": [ - " The first words I spoke in the original phonograph, a little piece" - " of practical poetry. Mary had a little lamb, its feet were quite as" - " slow, and everywhere that Mary went, the lamb was sure to go.", - " And the 0-1 pitch on the way to Edgar Martinez. Swung on the line." - " Now the left field line for a base hit. Here comes Joy. Here is" - " Junior to third base. They're going to wave him in. The throw to the" - " plate will be late. The Mariners are going to play for the American" - " League Championship. I don't believe it. It just continues. My, oh," - " my.", - ], - "openai/whisper-large-v3-turbo": [ - " The first words I spoke in the original phonograph, a little piece" - " of practical poetry. Mary had a little lamb, its streets were quite" - " as slow, and everywhere that Mary went the lamb was sure to go.", - " And the 0-1 pitch on the way to Edgar Martinez. Swung on the line" - " down the left field line for a base hit. Here comes Joy. Here is" - " Junior to third base. They're going to wave him in. The throw to the" - " plate will be late. The Mariners are going to play for the American" - " League Championship. I don't believe it. It just continues. My, oh," - " my.", - ], -} + +@pytest.fixture(autouse=True) +def use_spawn_for_whisper(monkeypatch): + """Whisper has issues with forked workers, use spawn instead.""" + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") def run_test( + hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], + inputs: Sequence[tuple[list[str], list[str], PromptAudioInput]], model: str, *, + max_model_len: int, + dtype: str, + max_tokens: int, + num_logprobs: int, tensor_parallel_size: int, distributed_executor_backend: str | None = None, + enforce_eager: bool = True, ) -> None: - prompt_list = PROMPTS * 10 - expected_list = EXPECTED[model] * 10 + """Inference result should be the same between hf and vllm. + All the audio fixtures for the test are from AudioAsset. + For huggingface runner, we provide the audio as input. + For vllm runner, we provide MultiModalDataDict objects + and corresponding MultiModalConfig as input. + """ with vllm_runner( model, - dtype="half", - max_model_len=448, + dtype=dtype, + max_model_len=max_model_len, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, + limit_mm_per_prompt={"audio": 2}, + enforce_eager=enforce_eager, + disable_custom_all_reduce=True, ) as vllm_model: - llm = vllm_model.llm + vllm_outputs_per_case = [ + vllm_model.generate_greedy_logprobs( + vllm_prompts, + max_tokens, + num_logprobs=num_logprobs, + audios=audios, + ) + for vllm_prompts, _, audios in inputs + ] - sampling_params = SamplingParams( - temperature=0, - top_p=1.0, - max_tokens=200, + with hf_runner(model, dtype=dtype, auto_cls=AutoModelForSpeechSeq2Seq) as hf_model: + hf_outputs_per_case = [ + hf_model.generate_greedy_logprobs_limit( + hf_prompts, + max_tokens, + num_logprobs=num_logprobs, + audios=audios, + ) + for _, hf_prompts, audios in inputs + ] + + for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, vllm_outputs_per_case): + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", ) - outputs = llm.generate(prompt_list, sampling_params) - for output, expected in zip(outputs, expected_list): - print(output.outputs[0].text) - assert output.outputs[0].text == expected +@pytest.fixture +def input_audios() -> list[tuple[list[str], list[str], list[tuple[Any, int]]]]: + audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")] + inputs = [] + for asset in audio_assets: + audio, orig_sr = asset.audio_and_sample_rate + # Resample to Whisper's expected sample rate (16kHz) + if orig_sr != WHISPER_SAMPLE_RATE: + audio = librosa.resample( + audio, orig_sr=orig_sr, target_sr=WHISPER_SAMPLE_RATE + ) + # vLLM prompts, HF prompts, audio inputs + inputs.append(([VLLM_PROMPT], [HF_PROMPT], [(audio, WHISPER_SAMPLE_RATE)])) + return inputs + + +def check_model_available(model: str) -> None: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") @pytest.mark.core_model +@pytest.mark.cpu_model @pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"]) -@create_new_process_for_each_test() -def test_models(vllm_runner, model) -> None: +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("enforce_eager", [True, False]) +@create_new_process_for_each_test("spawn") +def test_models( + hf_runner, + vllm_runner, + model: str, + dtype: str, + num_logprobs: int, + input_audios, + enforce_eager: bool, +) -> None: + check_model_available(model) + if current_platform.is_cpu() and not enforce_eager: + pytest.skip("Skipping test for CPU with non-eager mode") run_test( + hf_runner, vllm_runner, + input_audios, model, + dtype=dtype, + max_model_len=448, + max_tokens=200, + num_logprobs=num_logprobs, tensor_parallel_size=1, + enforce_eager=enforce_eager, ) @@ -133,15 +148,31 @@ def test_models(vllm_runner, model) -> None: @pytest.mark.core_model @pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"]) @pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) -@create_new_process_for_each_test() +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [200]) +@pytest.mark.parametrize("num_logprobs", [5]) +@create_new_process_for_each_test("spawn") def test_models_distributed( + hf_runner, vllm_runner, - model, - distributed_executor_backend, + model: str, + distributed_executor_backend: str, + dtype: str, + max_tokens: int, + num_logprobs: int, + input_audios, ) -> None: + check_model_available(model) run_test( + hf_runner, vllm_runner, + input_audios, model, + dtype=dtype, + max_model_len=448, + max_tokens=max_tokens, + num_logprobs=num_logprobs, tensor_parallel_size=2, distributed_executor_backend=distributed_executor_backend, + enforce_eager=False, ) diff --git a/tests/models/multimodal/generation/vlm_utils/case_filtering.py b/tests/models/multimodal/generation/vlm_utils/case_filtering.py index d42150bcbf672..116eead7a70ad 100644 --- a/tests/models/multimodal/generation/vlm_utils/case_filtering.py +++ b/tests/models/multimodal/generation/vlm_utils/case_filtering.py @@ -62,6 +62,65 @@ def get_filtered_test_settings( return matching_tests +def get_model_type_cases( + model_type: str, + test_info: VLMTestInfo, + test_type: VLMTestType, +): + # Ensure that something is wrapped as an iterable it's not already + ensure_wrapped = lambda e: e if isinstance(e, (list, tuple)) else (e,) + + # This is essentially the same as nesting a bunch of mark.parametrize + # decorators, but we do it programmatically to allow overrides for on + # a per-model basis, while still being able to execute each of these + # as individual test cases in pytest. + iter_kwargs = OrderedDict( + [ + ("model", ensure_wrapped(test_info.models)), + ("max_tokens", ensure_wrapped(test_info.max_tokens)), + ("num_logprobs", ensure_wrapped(test_info.num_logprobs)), + ("dtype", ensure_wrapped(test_info.dtype)), + ( + "distributed_executor_backend", + ensure_wrapped(test_info.distributed_executor_backend), + ), + ] + ) + + # num_frames is video only + if test_type == VLMTestType.VIDEO: + iter_kwargs["num_video_frames"] = ensure_wrapped(test_info.num_video_frames) + iter_kwargs["needs_video_metadata"] = ensure_wrapped( + test_info.needs_video_metadata + ) + + # No sizes passed for custom inputs, since inputs are directly provided + if test_type not in ( + VLMTestType.CUSTOM_INPUTS, + VLMTestType.AUDIO, + ): + wrapped_sizes = get_wrapped_test_sizes(test_info, test_type) + if wrapped_sizes is None: + raise ValueError(f"Sizes must be set for test type {test_type}") + iter_kwargs["size_wrapper"] = wrapped_sizes + + # Otherwise expand the custom test options instead + elif test_type == VLMTestType.CUSTOM_INPUTS: + if test_info.custom_test_opts is None: + raise ValueError("Test has type CUSTOM_INPUTS, but none given") + iter_kwargs["custom_test_opts"] = test_info.custom_test_opts + + # Wrap all model cases in a pytest parameter & pass marks through + return [ + pytest.param( + model_type, + ExpandableVLMTestArgs(**{k: v for k, v in zip(iter_kwargs.keys(), case)}), + marks=test_info.marks if test_info.marks is not None else [], + ) + for case in list(itertools.product(*iter_kwargs.values())) + ] + + def get_parametrized_options( test_settings: dict[str, VLMTestInfo], test_type: VLMTestType, @@ -76,64 +135,11 @@ def get_parametrized_options( test_settings, test_type, create_new_process_for_each_test ) - # Ensure that something is wrapped as an iterable it's not already - ensure_wrapped = lambda e: e if isinstance(e, (list, tuple)) else (e,) - - def get_model_type_cases(model_type: str, test_info: VLMTestInfo): - # This is essentially the same as nesting a bunch of mark.parametrize - # decorators, but we do it programmatically to allow overrides for on - # a per-model basis, while still being able to execute each of these - # as individual test cases in pytest. - iter_kwargs = OrderedDict( - [ - ("model", ensure_wrapped(test_info.models)), - ("max_tokens", ensure_wrapped(test_info.max_tokens)), - ("num_logprobs", ensure_wrapped(test_info.num_logprobs)), - ("dtype", ensure_wrapped(test_info.dtype)), - ( - "distributed_executor_backend", - ensure_wrapped(test_info.distributed_executor_backend), - ), - ] - ) - - # num_frames is video only - if test_type == VLMTestType.VIDEO: - iter_kwargs["num_video_frames"] = ensure_wrapped(test_info.num_video_frames) - iter_kwargs["needs_video_metadata"] = ensure_wrapped( - test_info.needs_video_metadata - ) - - # No sizes passed for custom inputs, since inputs are directly provided - if test_type not in (VLMTestType.CUSTOM_INPUTS, VLMTestType.AUDIO): - wrapped_sizes = get_wrapped_test_sizes(test_info, test_type) - if wrapped_sizes is None: - raise ValueError(f"Sizes must be set for test type {test_type}") - iter_kwargs["size_wrapper"] = wrapped_sizes - - # Otherwise expand the custom test options instead - elif test_type == VLMTestType.CUSTOM_INPUTS: - if test_info.custom_test_opts is None: - raise ValueError("Test has type CUSTOM_INPUTS, but none given") - iter_kwargs["custom_test_opts"] = test_info.custom_test_opts - - # Wrap all model cases in a pytest parameter & pass marks through - return [ - pytest.param( - model_type, - ExpandableVLMTestArgs( - **{k: v for k, v in zip(iter_kwargs.keys(), case)} - ), - marks=test_info.marks if test_info.marks is not None else [], - ) - for case in list(itertools.product(*iter_kwargs.values())) - ] - # Get a list per model type, where each entry contains a tuple of all of # that model type's cases, then flatten them into the top level so that # we can consume them in one mark.parametrize call. cases_by_model_type = [ - get_model_type_cases(model_type, test_info) + get_model_type_cases(model_type, test_info, test_type) for model_type, test_info in matching_tests.items() ] return list(itertools.chain(*cases_by_model_type)) diff --git a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py index 8c9c390911bdc..84109233685bb 100644 --- a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py +++ b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py @@ -140,7 +140,7 @@ def video_with_metadata_glm4_1v(): metadata = VIDEO_ASSETS[0].metadata question = "Describe the video." video_prompt = "<|begin_of_video|><|video|><|end_of_video|>" - formatted_prompt = f"<|user|>\n{video_prompt}{question}<|assistant|>\n" + formatted_prompt = f"[gMASK]<|user|>\n{video_prompt}{question}<|assistant|>\n" scales = [0.1, 0.2, 0.25] video_input = [ diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index 87cd5c3cd3554..b2c62fbd119cc 100644 --- a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -25,6 +25,7 @@ from transformers import ( from transformers.video_utils import VideoMetadata from vllm.logprobs import SampleLogprobs +from vllm.platforms import current_platform from vllm.utils.collection_utils import is_list_of from .....conftest import HfRunner, ImageAsset, ImageTestAssets @@ -366,6 +367,40 @@ def gemma3_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOut def glm4v_patch_hf_runner(hf_model: HfRunner) -> HfRunner: """Patches and returns an instance of the HfRunner to use for GLM4V.""" + if current_platform.is_rocm(): + import types + + config = hf_model.model.config + if hasattr(config, "num_layers") and not hasattr(config, "num_hidden_layers"): + config.num_hidden_layers = config.num_layers + config.output_hidden_states = True + + def patched_prepare_cache( + self, generation_config, model_kwargs, *args, **kwargs + ): + model_kwargs["past_key_values"] = None + model_kwargs["use_cache"] = False + return model_kwargs + + hf_model.model._prepare_cache_for_generation = types.MethodType( + patched_prepare_cache, hf_model.model + ) + original_generate = hf_model.model.generate + + def patched_generate(*args, **kwargs): + kwargs["output_hidden_states"] = True + kwargs["return_dict_in_generate"] = True + return original_generate(*args, **kwargs) + + hf_model.model.generate = patched_generate + original_forward = hf_model.model.forward + + def patched_forward(*args, **kwargs): + kwargs["output_hidden_states"] = True + return original_forward(*args, **kwargs) + + hf_model.model.forward = patched_forward + hf_processor = hf_model.processor def processor(*args, text="", images=None, **kwargs): @@ -406,7 +441,15 @@ def glm4_1v_patch_hf_runner(hf_model: HfRunner) -> HfRunner: if videos is not None and is_list_of(videos, tuple): # If videos is a list of tuples, we assume each tuple contains # (video_array, metadata) as in the case of GLM4.1V. - video_metadata = [[VideoMetadata(**video[1])] for video in videos] + # Filter out 'do_sample_frames' as it's not a valid VideoMetadata arg + video_metadata = [ + [ + VideoMetadata( + **{k: v for k, v in video[1].items() if k != "do_sample_frames"} + ) + ] + for video in videos + ] videos = [[video[0]] for video in videos] else: video_metadata = None diff --git a/tests/models/multimodal/generation/vlm_utils/types.py b/tests/models/multimodal/generation/vlm_utils/types.py index 0c03c84497125..ae2f754813590 100644 --- a/tests/models/multimodal/generation/vlm_utils/types.py +++ b/tests/models/multimodal/generation/vlm_utils/types.py @@ -50,8 +50,8 @@ MULTI_IMAGE_BASE_PROMPT = f"Image-1: {TEST_IMG_PLACEHOLDER}Image-2: {TEST_IMG_PL VIDEO_BASE_PROMPT = f"{TEST_VIDEO_PLACEHOLDER}Why is this video funny?" -IMAGE_SIZE_FACTORS = [(), (1.0,), (1.0, 1.0, 1.0), (0.25, 0.5, 1.0)] -EMBEDDING_SIZE_FACTORS = [(), (1.0,), (1.0, 1.0, 1.0)] +IMAGE_SIZE_FACTORS = [(1.0,), (1.0, 1.0, 1.0), (0.25, 0.5, 1.0)] +EMBEDDING_SIZE_FACTORS = [(1.0,), (1.0, 1.0, 1.0)] RunnerOutput = tuple[list[int], str, SampleLogprobs | None] diff --git a/tests/models/multimodal/pooling/conftest.py b/tests/models/multimodal/pooling/conftest.py new file mode 100644 index 0000000000000..c5f40cb42ca2a --- /dev/null +++ b/tests/models/multimodal/pooling/conftest.py @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Pytest configuration for vLLM pooling tests.""" + +import os +import warnings + +from vllm.platforms import current_platform + + +def pytest_collection_modifyitems(config, items): + """Set FLEX_ATTENTION backend for SigLIP tests on ROCm.""" + if not current_platform.is_rocm(): + return + + siglip_tests = [item for item in items if "test_siglip" in item.nodeid] + + if siglip_tests: + os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION" + warnings.warn( + "ROCm: Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION for SigLIP tests", + UserWarning, + stacklevel=1, + ) diff --git a/tests/models/multimodal/pooling/test_siglip.py b/tests/models/multimodal/pooling/test_siglip.py index 92ae115a19831..72886cbf7f323 100644 --- a/tests/models/multimodal/pooling/test_siglip.py +++ b/tests/models/multimodal/pooling/test_siglip.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + import pytest from transformers import SiglipModel @@ -35,7 +37,11 @@ def _run_test( model: str, *, dtype: str, + tokenization_kwargs: dict[str, Any] | None = None, ) -> None: + if tokenization_kwargs is None: + tokenization_kwargs = {} + with vllm_runner( model, runner="pooling", @@ -44,10 +50,14 @@ def _run_test( max_model_len=64, gpu_memory_utilization=0.7, ) as vllm_model: - vllm_outputs = vllm_model.embed(input_texts, images=input_images) + vllm_outputs = vllm_model.embed( + input_texts, images=input_images, tokenization_kwargs=tokenization_kwargs + ) with hf_runner(model, dtype=dtype, auto_cls=SiglipModel) as hf_model: - all_inputs = hf_model.get_inputs(input_texts, images=input_images) + all_inputs = hf_model.get_inputs( + input_texts, images=input_images, tokenization_kwargs=tokenization_kwargs + ) all_outputs = [] for inputs in all_inputs: @@ -94,6 +104,10 @@ def test_models_text( input_images, # type: ignore model, dtype=dtype, + tokenization_kwargs={ + "padding": "max_length", + "max_length": 64, + }, # siglip2 was trained with this padding setting. ) diff --git a/tests/models/multimodal/processing/test_audioflamingo3.py b/tests/models/multimodal/processing/test_audioflamingo3.py new file mode 100644 index 0000000000000..d7c00516ffead --- /dev/null +++ b/tests/models/multimodal/processing/test_audioflamingo3.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The vLLM team. +# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights +# reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock + +import numpy as np +import pytest +import torch +from transformers import PretrainedConfig + +from tests.models.registry import HF_EXAMPLE_MODELS + + +class MockAudioFlamingo3Config(PretrainedConfig): + model_type = "audioflamingo3" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.audio_config = PretrainedConfig() + self.text_config = PretrainedConfig() + + +class MockAudioFlamingo3Processor: + def __init__(self): + self.audio_token = "<sound>" + self.audio_token_id = 12345 + self.feature_extractor = MockFeatureExtractor() + + def __call__(self, text=None, audios=None, **kwargs): + return {"input_ids": [1, 2, 3], "input_features": [np.zeros((3000, 80))]} + + +class MockFeatureExtractor: + def __init__(self): + self.sampling_rate = 16000 + self.chunk_length = 30 + + +@pytest.fixture +def mock_ctx(): + config = MockAudioFlamingo3Config() + + ctx = MagicMock() + ctx.get_hf_config.return_value = config + ctx.get_hf_processor.return_value = MockAudioFlamingo3Processor() + ctx.model_config.hf_config = config + return ctx + + +@pytest.fixture(autouse=True) +def check_transformers_version(): + # Check if the model is supported by the current transformers version + model_info = HF_EXAMPLE_MODELS.get_hf_info("AudioFlamingo3ForConditionalGeneration") + model_info.check_transformers_version(on_fail="skip") + + +def test_audio_chunk_counting(mock_ctx): + from vllm.model_executor.models.audioflamingo3 import ( + AudioFlamingo3DummyInputsBuilder, + AudioFlamingo3MultiModalProcessor, + AudioFlamingo3ProcessingInfo, + ) + + info = AudioFlamingo3ProcessingInfo(mock_ctx) + processor = AudioFlamingo3MultiModalProcessor( + info, AudioFlamingo3DummyInputsBuilder(info) + ) + + sr = 16000 + audio_1 = np.zeros(30 * sr) + audio_2 = np.zeros(45 * sr) + + mm_data = {"audio": [audio_1, audio_2]} + prompt = "<|user|>Listen.<|end|>" + + from vllm.multimodal.processing import BaseMultiModalProcessor + + def mock_base_call(self, prompt, mm_data, mm_kwargs, tok_kwargs): + return {"input_ids": [1, 2, 3], "input_features": torch.randn(1, 80, 3000)} + + with pytest.MonkeyPatch.context() as mp: + mp.setattr(BaseMultiModalProcessor, "_call_hf_processor", mock_base_call) + + processed = processor._call_hf_processor(prompt, mm_data, {}, {}) + + chunk_counts = processed["chunk_counts"] + + assert chunk_counts[0].item() == 1 + assert chunk_counts[1].item() == 2 + assert len(chunk_counts) == 2 + + +def test_dummy_data_generation(mock_ctx): + from vllm.model_executor.models.audioflamingo3 import ( + AudioFlamingo3DummyInputsBuilder, + AudioFlamingo3ProcessingInfo, + ) + + info = AudioFlamingo3ProcessingInfo(mock_ctx) + builder = AudioFlamingo3DummyInputsBuilder(info) + + mm_counts = {"audio": 2} + dummy_data = builder.get_dummy_mm_data(100, mm_counts, None) + + assert "audio" in dummy_data + assert len(dummy_data["audio"]) == 2 + + expected_len = 600 * 16000 + assert len(dummy_data["audio"][0]) == expected_len diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index c39e522100901..67861ebfc44e4 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -20,13 +20,10 @@ from vllm.config.multimodal import ( ) from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict from vllm.multimodal.cache import MultiModalProcessorOnlyCache -from vllm.multimodal.inputs import MultiModalInputs +from vllm.multimodal.inputs import MultiModalInputs, batched_tensors_equal from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext -from vllm.tokenizers import MistralTokenizer -from vllm.transformers_utils.tokenizer import ( - cached_tokenizer_from_config, - encode_tokens, -) +from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config +from vllm.tokenizers.mistral import MistralTokenizer from ....multimodal.utils import random_audio, random_image, random_video from ...registry import ( @@ -154,7 +151,7 @@ def get_text_token_prompts( mm_data: MultiModalDataDict, ): dummy_inputs = processor.dummy_inputs - tokenizer = processor.info.get_tokenizer() + tokenizer: TokenizerLike = processor.info.get_tokenizer() model_config = processor.info.ctx.model_config model_type = model_config.hf_config.model_type @@ -191,10 +188,9 @@ def get_text_token_prompts( assert isinstance(inputs.prompt, str) text_prompt = inputs.prompt - token_prompt = encode_tokens( - tokenizer, + token_prompt = tokenizer.encode( text_prompt, - add_special_tokens=_ADD_SPECIAL_TOKENS_OVERRIDES.get(model_type), + add_special_tokens=_ADD_SPECIAL_TOKENS_OVERRIDES.get(model_type, True), ) return text_prompt, token_prompt @@ -397,28 +393,6 @@ def test_processing_correctness( ) -# Phi4MultimodalForCausalLM share same model repo with original format -# Phi4MMForCausalLM, so we add it as a separate test case -# Remove this test after conversion PR merged: -# https://huggingface.co/microsoft/Phi-4-multimodal-instruct/discussions/70 -@pytest.mark.parametrize("model_arch", ["Phi4MultimodalForCausalLM"]) -@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) -@pytest.mark.parametrize("num_batches", [32]) -@pytest.mark.parametrize("simplify_rate", [1.0]) -def test_processing_correctness_phi4_multimodal( - model_arch: str, - hit_rate: float, - num_batches: int, - simplify_rate: float, -): - _test_processing_correctness( - model_arch, - hit_rate=hit_rate, - num_batches=num_batches, - simplify_rate=simplify_rate, - ) - - def _assert_inputs_equal( a: MultiModalInputs, b: MultiModalInputs, @@ -441,4 +415,4 @@ def _assert_inputs_equal( a_data.pop(key, None) b_data.pop(key, None) - assert a_data == b_data, msg + assert batched_tensors_equal(a_data, b_data), msg diff --git a/tests/models/multimodal/processing/test_gemma3.py b/tests/models/multimodal/processing/test_gemma3.py new file mode 100644 index 0000000000000..32a459ee8cdfb --- /dev/null +++ b/tests/models/multimodal/processing/test_gemma3.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.multimodal import MULTIMODAL_REGISTRY + +from ....conftest import ImageTestAssets +from ...utils import build_model_context + + +@pytest.mark.parametrize("model_id", ["google/gemma-3-4b-it"]) +def test_get_image_size_with_most_features( + image_assets: ImageTestAssets, model_id: str +): + ctx = build_model_context( + model_id, + mm_processor_kwargs={"do_pan_and_scan": True}, + limit_mm_per_prompt={"image": 1}, + ) + processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) + + hf_processor_mm_kwargs: dict[str, object] = {} + hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs) + + max_image_size = processor.info.get_image_size_with_most_features() + max_tokens = processor.info.get_num_image_tokens( + image_width=max_image_size.width, + image_height=max_image_size.height, + processor=hf_processor, + ) + + prompt = "<start_of_image>" + image_seq_length = hf_processor.image_seq_length + + for asset in image_assets: + mm_data = {"image": [asset.pil_image]} + processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs) + mm_kwargs_data = processed_inputs["mm_kwargs"].get_data() + num_patches_tensor = mm_kwargs_data["num_patches"] + tokens = int(num_patches_tensor.item()) * image_seq_length + assert tokens <= max_tokens diff --git a/tests/models/multimodal/processing/test_glm4_1v.py b/tests/models/multimodal/processing/test_glm4_1v.py index 553a5f719bd35..51071c93531de 100644 --- a/tests/models/multimodal/processing/test_glm4_1v.py +++ b/tests/models/multimodal/processing/test_glm4_1v.py @@ -5,6 +5,7 @@ import pytest from vllm.assets.video import VideoAsset from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import batched_tensors_equal from vllm.multimodal.video import OpenCVDynamicVideoBackend, OpenCVVideoBackend from ...utils import build_model_context @@ -103,7 +104,7 @@ def test_video_loader_consistency( dynamic_outputs = processor.apply(prompt, dynamic_mm_data, hf_processor_mm_kwargs) assert static_outputs["prompt_token_ids"] == dynamic_outputs["prompt_token_ids"] - assert ( - static_outputs["mm_kwargs"].get_data() - == dynamic_outputs["mm_kwargs"].get_data() + assert batched_tensors_equal( + static_outputs["mm_kwargs"].get_data(), + dynamic_outputs["mm_kwargs"].get_data(), ) diff --git a/tests/models/multimodal/processing/test_llama4.py b/tests/models/multimodal/processing/test_llama4.py index 4c0791ea3cece..b73246b68b36a 100644 --- a/tests/models/multimodal/processing/test_llama4.py +++ b/tests/models/multimodal/processing/test_llama4.py @@ -5,7 +5,6 @@ import pytest from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.transformers_utils.tokenizer import encode_tokens from ....conftest import ImageTestAssets from ...utils import build_model_context @@ -48,7 +47,7 @@ def test_processor_override( ] } if tokenized_prompt: - prompt = encode_tokens(tokenizer, prompt) + prompt = tokenizer.encode(prompt) processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) mm_data = processed_inputs["mm_kwargs"].get_data() diff --git a/tests/models/multimodal/processing/test_qwen2_vl.py b/tests/models/multimodal/processing/test_qwen2_vl.py index 9f4cdb6789b2c..20beaa6011b8f 100644 --- a/tests/models/multimodal/processing/test_qwen2_vl.py +++ b/tests/models/multimodal/processing/test_qwen2_vl.py @@ -53,3 +53,38 @@ def test_processor_override( assert img_tok_count == expected_toks_per_img * num_imgs assert pixel_shape[0] == expected_pixels_shape[0] * num_imgs assert pixel_shape[1] == expected_pixels_shape[1] + + +@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"]) +@pytest.mark.parametrize("max_pixels", [1280 * 28 * 28, 1283 * 28 * 28]) +def test_get_image_size_with_most_features( + image_assets: ImageTestAssets, + model_id: str, + max_pixels: int, +): + ctx = build_model_context( + model_id, + mm_processor_kwargs={"max_pixels": max_pixels}, + limit_mm_per_prompt={"image": 1}, + ) + processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) + + hf_processor_mm_kwargs: dict[str, object] = {} + hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs) + merge_size = processor.info.get_hf_config().vision_config.spatial_merge_size + + max_image_size = processor.info.get_image_size_with_most_features() + max_tokens = processor.info.get_num_image_tokens( + image_width=max_image_size.width, + image_height=max_image_size.height, + image_processor=hf_processor.image_processor, + ) + + prompt = "<|vision_start|><|image_pad|><|vision_end|>" + for asset in image_assets: + mm_data = {"image": [asset.pil_image]} + processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs) + grid_thw = processed_inputs["mm_kwargs"].get_data()["image_grid_thw"].tolist() + t, h, w = grid_thw[0] + tokens = (t * h * w) // (merge_size**2) + assert tokens < max_tokens diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py index 66a3fbe11b6a5..cb875436857cf 100644 --- a/tests/models/multimodal/processing/test_tensor_schema.py +++ b/tests/models/multimodal/processing/test_tensor_schema.py @@ -8,6 +8,7 @@ from typing import Any, TypeAlias import numpy as np import pytest +import torch import torch.nn as nn from PIL import Image @@ -31,10 +32,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.platforms import current_platform -from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config +from vllm.tokenizers import cached_tokenizer_from_config from vllm.utils.collection_utils import is_list_of from vllm.utils.torch_utils import set_default_torch_dtype +from ....utils import create_new_process_for_each_test from ...registry import HF_EXAMPLE_MODELS from ...utils import dummy_hf_overrides from .test_common import get_model_ids_to_test, get_text_token_prompts @@ -130,13 +132,13 @@ def create_batched_mm_kwargs( hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, tokenization_kwargs=processor_inputs.tokenization_kwargs, )["mm_kwargs"].require_data() - items = [item for modality in supported_mm_limits for item in mm_kwargs[modality]] + return group_mm_kwargs_by_modality( - items, - merge_by_field_config=model_cls.merge_by_field_config, + [item for modality in supported_mm_limits for item in mm_kwargs[modality]] ) +# TODO(Isotr0py): Don't initalize model during test @contextmanager def initialize_dummy_model( model_cls: type[nn.Module], @@ -151,16 +153,21 @@ def initialize_dummy_model( backend="nccl", ) initialize_model_parallel(tensor_model_parallel_size=1) + + current_device = torch.get_default_device() vllm_config = VllmConfig(model_config=model_config) with set_current_vllm_config(vllm_config=vllm_config): with set_default_torch_dtype(model_config.dtype): + torch.set_default_device(current_platform.device_type) model = model_cls(vllm_config=vllm_config) + torch.set_default_device(current_device) yield model del model cleanup_dist_env_and_memory() +@create_new_process_for_each_test() @pytest.mark.parametrize("model_id", get_model_ids_to_test()) def test_model_tensor_schema(model_id: str): model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) diff --git a/tests/models/quantization/test_gguf.py b/tests/models/quantization/test_gguf.py index 3b9597507ac1b..064ca94f3cbac 100644 --- a/tests/models/quantization/test_gguf.py +++ b/tests/models/quantization/test_gguf.py @@ -47,6 +47,12 @@ QWEN2_CONFIG = GGUFTestConfig( gguf_filename="qwen2.5-1.5b-instruct-q6_k.gguf", ) +QWEN3_CONFIG = GGUFTestConfig( + original_model="Qwen/Qwen3-0.6B", + gguf_repo="unsloth/Qwen3-0.6B-GGUF", + gguf_filename="Qwen3-0.6B-BF16.gguf", +) + PHI3_CONFIG = GGUFTestConfig( original_model="microsoft/Phi-3.5-mini-instruct", gguf_repo="bartowski/Phi-3.5-mini-instruct-GGUF", @@ -87,6 +93,7 @@ GEMMA3_CONFIG = GGUFTestConfig( MODELS = [ # LLAMA_CONFIG, # broken: https://github.com/vllm-project/vllm/issues/19458 QWEN2_CONFIG, + QWEN3_CONFIG, PHI3_CONFIG, GPT2_CONFIG, STABLELM_CONFIG, diff --git a/tests/models/registry.py b/tests/models/registry.py index d90f3a4d4f781..c5d72b5d581b9 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -173,10 +173,7 @@ class _HfExamplesInfo: _TEXT_GENERATION_EXAMPLE_MODELS = { # [Decoder-only] - "AfmoeForCausalLM": _HfExamplesInfo( - "arcee-ai/Trinity-Nano", - is_available_online=False, - ), + "AfmoeForCausalLM": _HfExamplesInfo("arcee-ai/Trinity-Nano-Preview"), "ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B-Instruct-2509"), "AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B", trust_remote_code=True), "AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B", trust_remote_code=True), @@ -211,10 +208,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { trust_remote_code=True, ), "CohereForCausalLM": _HfExamplesInfo( - "CohereForAI/c4ai-command-r-v01", trust_remote_code=True + "CohereLabs/c4ai-command-r-v01", trust_remote_code=True ), "Cohere2ForCausalLM": _HfExamplesInfo( - "CohereForAI/c4ai-command-r7b-12-2024", + "CohereLabs/c4ai-command-r7b-12-2024", trust_remote_code=True, ), "CwmForCausalLM": _HfExamplesInfo("facebook/cwm", min_transformers_version="4.58"), @@ -358,6 +355,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { trust_remote_code=True, ), "MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"), + "MistralLarge3ForCausalLM": _HfExamplesInfo( + "mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4" + ), "MixtralForCausalLM": _HfExamplesInfo( "mistralai/Mixtral-8x7B-Instruct-v0.1", {"tiny": "TitanML/tiny-mixtral"}, @@ -413,7 +413,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { trust_remote_code=True, ), "Qwen2ForCausalLM": _HfExamplesInfo( - "Qwen/Qwen2-0.5B-Instruct", extras={"2.5": "Qwen/Qwen2.5-0.5B-Instruct"} + "Qwen/Qwen2-0.5B-Instruct", + extras={ + "2.5": "Qwen/Qwen2.5-0.5B-Instruct", + "2.5-1.5B": "Qwen/Qwen2.5-1.5B-Instruct", + }, ), "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"), "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"), @@ -569,12 +573,17 @@ _AUTOMATIC_CONVERTED_MODELS = { "Qwen3ForSequenceClassification": _HfExamplesInfo( "tomaarsen/Qwen3-Reranker-0.6B-seq-cls" ), + "Qwen3ForTokenClassification": _HfExamplesInfo("bd2lcco/Qwen3-0.6B-finetuned"), } _MULTIMODAL_EXAMPLE_MODELS = { # [Decoder-only] "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"), - "AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"), + "AudioFlamingo3ForConditionalGeneration": _HfExamplesInfo( + "nvidia/audio-flamingo-3-hf", min_transformers_version="5.0.0.dev" + ), + "AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereLabs/aya-vision-8b"), + "BagelForConditionalGeneration": _HfExamplesInfo("ByteDance-Seed/BAGEL-7B-MoT"), "BeeForConditionalGeneration": _HfExamplesInfo( "Open-Bee/Bee-8B-RL", trust_remote_code=True, @@ -631,7 +640,7 @@ _MULTIMODAL_EXAMPLE_MODELS = { ), "HunYuanVLForConditionalGeneration": _HfExamplesInfo( "tencent/HunyuanOCR", - is_available_online=False, + hf_overrides={"num_experts": 0}, ), "Idefics3ForConditionalGeneration": _HfExamplesInfo( "HuggingFaceM4/Idefics3-8B-Llama3", @@ -664,10 +673,13 @@ _MULTIMODAL_EXAMPLE_MODELS = { "moonshotai/Kimi-VL-A3B-Instruct", extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, trust_remote_code=True, + max_transformers_version="4.53.3", + transformers_version_reason="HF model uses deprecated transformers API " + "(PytorchGELUTanh, DynamicCache.seen_tokens, and more). See: " + "https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/discussions/31", ), "LightOnOCRForConditionalGeneration": _HfExamplesInfo( - "lightonai/LightOnOCR-1B", - is_available_online=False, + "lightonai/LightOnOCR-1B-1025" ), "Llama4ForConditionalGeneration": _HfExamplesInfo( "meta-llama/Llama-4-Scout-17B-16E-Instruct", @@ -764,12 +776,12 @@ _MULTIMODAL_EXAMPLE_MODELS = { "Phi4MMForCausalLM": _HfExamplesInfo( "microsoft/Phi-4-multimodal-instruct", trust_remote_code=True ), - "Phi4MultimodalForCausalLM": _HfExamplesInfo( - "microsoft/Phi-4-multimodal-instruct", - revision="refs/pr/70", - ), "PixtralForConditionalGeneration": _HfExamplesInfo( "mistralai/Pixtral-12B-2409", + extras={ + "mistral-large-3": "mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4", + "ministral-3": "mistralai/Ministral-3-3B-Instruct-2512", + }, tokenizer_mode="mistral", ), "QwenVLForConditionalGeneration": _HfExamplesInfo( @@ -822,7 +834,10 @@ _MULTIMODAL_EXAMPLE_MODELS = { "TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b"), "Tarsier2ForConditionalGeneration": _HfExamplesInfo( "omni-research/Tarsier2-Recap-7b", - hf_overrides={"architectures": ["Tarsier2ForConditionalGeneration"]}, + hf_overrides={ + "architectures": ["Tarsier2ForConditionalGeneration"], + "model_type": "tarsier2", + }, ), "VoxtralForConditionalGeneration": _HfExamplesInfo( "mistralai/Voxtral-Mini-3B-2507", @@ -830,7 +845,10 @@ _MULTIMODAL_EXAMPLE_MODELS = { is_available_online=False, ), # [Encoder-decoder] - "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), + "WhisperForConditionalGeneration": _HfExamplesInfo( + "openai/whisper-large-v3-turbo", + extras={"v3": "openai/whisper-large-v3"}, + ), # [Cross-encoder] "JinaVLForRanking": _HfExamplesInfo("jinaai/jina-reranker-m0"), } @@ -870,6 +888,12 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { use_original_num_layers=True, max_model_len=10240, ), + "EagleMistralLarge3ForCausalLM": _HfExamplesInfo( + "mistralai/Mistral-Large-3-675B-Instruct-2512", + speculative_model="mistralai/Mistral-Large-3-675B-Instruct-2512-Eagle", + # TODO: revert once figuring out OOM in CI + is_available_online=False, + ), "LlamaForCausalLMEagle3": _HfExamplesInfo( "Qwen/Qwen3-8B", trust_remote_code=True, diff --git a/tests/models/test_gguf_download.py b/tests/models/test_gguf_download.py index 155768ac9bff7..b1674cdf77178 100644 --- a/tests/models/test_gguf_download.py +++ b/tests/models/test_gguf_download.py @@ -203,7 +203,7 @@ class TestGGUFModelLoader: @patch("vllm.config.model.get_hf_image_processor_config", return_value=None) @patch("vllm.config.model.get_config") @patch("vllm.config.model.is_gguf", return_value=False) - @patch("vllm.transformers_utils.utils.check_gguf_file", return_value=False) + @patch("vllm.transformers_utils.gguf_utils.check_gguf_file", return_value=False) @patch("os.path.isfile", return_value=False) def test_prepare_weights_invalid_format( self, diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index 9017a0fd91407..a089696e10ffc 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -13,7 +13,6 @@ from vllm.model_executor.models import ( ) from vllm.model_executor.models.adapters import ( as_embedding_model, - as_reward_model, as_seq_cls_model, ) from vllm.model_executor.models.registry import ( @@ -46,7 +45,6 @@ def test_registry_imports(model_arch): # All vLLM models should be convertible to a pooling model assert is_pooling_model(as_seq_cls_model(model_cls)) assert is_pooling_model(as_embedding_model(model_cls)) - assert is_pooling_model(as_reward_model(model_cls)) if model_arch in _MULTIMODAL_MODELS: assert supports_multimodal(model_cls) diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index ae5befd2c00b7..c642ff1ee4384 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -59,10 +59,6 @@ def check_implementation( ) -@pytest.mark.skipif( - current_platform.is_rocm(), - reason="Llama-3.2-1B-Instruct, Ilama-3.2-1B produce memory access fault.", -) @pytest.mark.parametrize( "model,model_impl", [ diff --git a/tests/models/utils.py b/tests/models/utils.py index 9843887a13204..d84b4b820533e 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -13,7 +13,7 @@ from transformers import PretrainedConfig from vllm.config.model import ModelConfig, ModelDType, RunnerOption from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs from vllm.multimodal.processing import InputProcessingContext -from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config +from vllm.tokenizers import cached_tokenizer_from_config from .. import ci_envs from .registry import HF_EXAMPLE_MODELS diff --git a/tests/multimodal/test_cache.py b/tests/multimodal/test_cache.py index 531674c30f55f..e641b1111abaf 100644 --- a/tests/multimodal/test_cache.py +++ b/tests/multimodal/test_cache.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import multiprocessing as mp import numpy as np import pytest @@ -8,9 +9,16 @@ import torch from vllm.config import ModelConfig, ParallelConfig, VllmConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import ( + BaseMultiModalProcessorCache, + BaseMultiModalReceiverCache, MultiModalCache, + MultiModalProcessorCacheInItem, MultiModalProcessorCacheItem, MultiModalProcessorCacheItemMetadata, + MultiModalProcessorSenderCache, + MultiModalReceiverCache, + ShmObjectStoreReceiverCache, + ShmObjectStoreSenderCache, engine_receiver_cache_from_config, processor_cache_from_config, ) @@ -22,6 +30,7 @@ from vllm.multimodal.inputs import ( MultiModalSharedField, ) from vllm.multimodal.processing import PromptInsertion +from vllm.utils.mem_constants import GiB_bytes, MiB_bytes pytestmark = pytest.mark.cpu_test @@ -42,7 +51,7 @@ def _dummy_elem( modality=modality, key=key, data=data, - field=MultiModalSharedField(1), + field=MultiModalSharedField(batch_size=1), ) @@ -76,12 +85,6 @@ def _dummy_items( (_dummy_item("a", {"a1": 100}), 100), (_dummy_item("a", {"a1": 100, "a2": 110}), 210), (_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501 - ( - _dummy_items( - {"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}} - ).get_data(), - 460, - ), # noqa: E501 ], ) def test_cache_item_size(item, expected_size): @@ -98,6 +101,9 @@ def test_cache_item_size(item, expected_size): cache[""] = MultiModalProcessorCacheItemMetadata(item, [prompt_update]) assert cache.currsize == expected_size + cache[""] = item.get_data() + assert cache.currsize == expected_size + def _create_vllm_config( *, @@ -144,8 +150,7 @@ def _compare_caches( MultiModalHasher.hash_kwargs(item=item.get_data()) for item in all_items ] - # Should not be used since there is nothing to convert to text - prompt_update = PromptInsertion("dummy", "target", "insertion") + prompt_update = PromptInsertion("dummy", "target", "insertion").resolve(0) for it in range(n_iter): num_items_to_select = rng.randint(0, max_items_per_iter) @@ -159,10 +164,11 @@ def _compare_caches( else: for _ in range(is_cached_calls_per_iter): cache_0_p0.is_cached(selected_hashes) + cache_0_p0_out = [ item for item, _ in cache_0_p0.get_and_update( - [(item, prompt_update.content) for item in selected_items], + [(item, [prompt_update]) for item in selected_items], selected_hashes, ) ] @@ -172,10 +178,11 @@ def _compare_caches( else: for _ in range(is_cached_calls_per_iter): cache_1_p0.is_cached(selected_hashes) + cache_1_p0_out = [ item for item, _ in cache_1_p0.get_and_update( - [(item, prompt_update.content) for item in selected_items], + [(item, [prompt_update]) for item in selected_items], selected_hashes, ) ] @@ -225,3 +232,289 @@ def test_ipc_enable_disable_consistency(is_cached_calls_per_iter): vllm_config_ipc_enabled, is_cached_calls_per_iter=is_cached_calls_per_iter, ) + + +def _run_test_cache_eviction_lru( + p0_cache: BaseMultiModalProcessorCache, + p1_cache: BaseMultiModalReceiverCache, + base_item_size: int, +): + request1_hashes = [ + "image_A", + "image_B", + "image_C", + ] + request1_items = { + h: MultiModalKwargsItem.dummy(h, nbytes=2 * base_item_size) + for h in request1_hashes + } + + request2_hashes = ["image_D", "image_E", "image_A", "image_C"] + request2_items = { + h: MultiModalKwargsItem.dummy(h, nbytes=1 * base_item_size) + for h in request2_hashes + } + + ########################## + # STEP 1: Request 1 send + ########################## + sender_is_cached_item_req1 = p0_cache.is_cached(request1_hashes) + # Cache is empty + assert sender_is_cached_item_req1 == [False, False, False] + + # Touch all mm hash for P0 Cache before process + for mm_hash in request1_hashes: + p0_cache.touch_sender_cache_item(mm_hash) + + ########################### + # Process request 1 for P0 Cache + ########################### + item_tuple: MultiModalProcessorCacheInItem + for i, h in enumerate(request1_hashes): + # Use precomputed cache state + is_cached = sender_is_cached_item_req1[i] + item_tuple = (request1_items[h], []) if not is_cached else None + print(f"Request 1: key={h} | cached={is_cached}") + + p0_cache.get_and_update_item(item_tuple, h) + + ########################### + # Process request 1 for P1 Cache + ########################### + # Touch all mm hash for P1 Cache before process + for mm_hash in request1_hashes: + p1_cache.touch_receiver_cache_item(mm_hash) + + for h in request1_hashes: + p1_cache.get_and_update_item(request1_items[h], h) + + expected_hashes = ["image_A", "image_B", "image_C"] + assert list(p0_cache._cache.order) == expected_hashes + + ########################## + # STEP 2: Request 2 send + ########################## + sender_is_cached_item_req2 = p0_cache.is_cached(request2_hashes) + assert sender_is_cached_item_req2 == [False, False, True, True] + + # Touch all mm hash for P0 Cache before process + for mm_hash in request2_hashes: + p0_cache.touch_sender_cache_item(mm_hash) + + ########################### + # Process request 2 for P0 Cache + ########################### + for i, h in enumerate(request2_hashes): + # Use precomputed cache state again + is_cached = sender_is_cached_item_req2[i] + item_tuple = (request2_items[h], []) if not is_cached else None + print(f"Request 2: key={h} | cached={is_cached}") + + p0_cache.get_and_update_item(item_tuple, h) + + ########################### + # Process request 2 for P1 Cache + ########################### + + # Touch all mm hash for P1 Cache before process + for mm_hash in request2_hashes: + p1_cache.touch_receiver_cache_item(mm_hash) + + for h in request2_hashes: + p1_cache.get_and_update_item(request2_items[h], h) + + expected_hashes = ["image_D", "image_E", "image_A", "image_C"] + assert list(p0_cache._cache.order) == expected_hashes + + +def test_cache_eviction_lru_cache(): + model_config = ModelConfig( + model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf", + mm_processor_cache_gb=6 / GiB_bytes, + ) + sender_cache = MultiModalProcessorSenderCache(model_config) + receiver_cache = MultiModalReceiverCache(model_config) + + _run_test_cache_eviction_lru(sender_cache, receiver_cache, base_item_size=1) + + +# This test verifies shared-memory cache eviction behavior across processor (p0) +# and receiver (p1) caches. +# Flow summary: +# 1. Request 1 adds images A, B, C — completely filling the cache. +# 2. Request 2 tries to add image_G and image_A, but image_G cannot be added because +# cache is full and A is protected from eviction — cache remains unchanged. +# 3. Request 3 adds image_G, image_H, image_I and image_B +# this time, image_A is evicted, freeing 5MB space +# and image_G, image_H successfully fits, +# image_B is protected from eviction then image_i cannot be added. +# This proving normal eviction and reuse behavior. +def _run_test_cache_eviction_shm( + p0_cache: BaseMultiModalProcessorCache, + p1_cache: BaseMultiModalReceiverCache, + base_item_size: int, +): + request1_hashes = ["image_A", "image_B", "image_C"] + request1_items = { + h: MultiModalKwargsItem.dummy(h, nbytes=5 * base_item_size) + for h in request1_hashes + } + request1_items_p0_result = [] + + request2_hashes = ["image_G", "image_A"] + request2_items = { + h: MultiModalKwargsItem.dummy( + h, nbytes=(5 if h in request1_hashes else 2) * base_item_size + ) + for h in request2_hashes + } + request2_items_p0_result = [] + + request3_hashes = ["image_G", "image_H", "image_I", "image_B"] + request3_items = { + h: MultiModalKwargsItem.dummy( + h, nbytes=(5 if h in request1_hashes else 2) * base_item_size + ) + for h in request3_hashes + } + request3_items_p0_result = [] + + ########################## + # STEP 1: Request 1 send + # This will fill up the cache + ########################## + sender_is_cached_item_req1 = p0_cache.is_cached(request1_hashes) + # Cache is empty + assert sender_is_cached_item_req1 == [False, False, False] + + # Touch all mm hash for P0 Cache before process + for mm_hash in request1_hashes: + p0_cache.touch_sender_cache_item(mm_hash) + + ########################### + # Process request 1 for P0 Cache + ########################### + item_tuple: MultiModalProcessorCacheInItem + for i, h in enumerate(request1_hashes): + # Use precomputed cache state + is_cached = sender_is_cached_item_req1[i] + item_tuple = (request1_items[h], []) if not is_cached else None + print(f"Request 1: key={h} | cached={is_cached}") + + p0_result = p0_cache.get_and_update_item(item_tuple, h) + # Only get mm item, ignore prompt update result + request1_items_p0_result.append(p0_result[0]) + + ########################### + # Process request 1 for P1 Cache + ########################### + # Touch all mm hash for P1 Cache before process + for mm_hash, mm_item in zip(request1_hashes, request1_items_p0_result): + p1_cache.touch_receiver_cache_item(mm_hash, mm_item) + + for mm_hash, mm_item in zip(request1_hashes, request1_items_p0_result): + p1_cache.get_and_update_item(mm_item, mm_hash) + + expected_hashes = ["image_A", "image_B", "image_C"] + assert list(p0_cache._shm_cache.key_index.keys()) == expected_hashes + + ########################## + # STEP 2: Request 2 send + # There is no eviction because image_A is protected + # No new item can add to cache + ########################## + sender_is_cached_item_req2 = p0_cache.is_cached(request2_hashes) + assert sender_is_cached_item_req2 == [False, True] + + # Touch all mm hash for P0 Cache before process + for mm_hash in request2_hashes: + p0_cache.touch_sender_cache_item(mm_hash) + + ########################### + # Process request 2 for P0 Cache + ########################### + for i, h in enumerate(request2_hashes): + # Use precomputed cache state again + is_cached = sender_is_cached_item_req2[i] + item_tuple = (request2_items[h], []) if not is_cached else None + print(f"Request 2: key={h} | cached={is_cached}") + + p0_result = p0_cache.get_and_update_item(item_tuple, h) + # Only get mm item, ignore prompt update result + request2_items_p0_result.append(p0_result[0]) + + # image_A cannot be evict then + # image_G will fail to allocate anyway and image_A still in cache + assert p0_cache.is_cached(request2_hashes) == [False, True] + + ########################### + # Process request 2 for P1 Cache + ########################### + + # Touch all mm hash for P1 Cache before process + for mm_hash, mm_item in zip(request2_hashes, request2_items_p0_result): + p1_cache.touch_receiver_cache_item(mm_hash, mm_item) + + for mm_hash, mm_item in zip(request2_hashes, request2_items_p0_result): + p1_cache.get_and_update_item(mm_item, mm_hash) + + # Prove that cache state is unchanged + expected_hashes = ["image_A", "image_B", "image_C"] + assert list(p0_cache._shm_cache.key_index.keys()) == expected_hashes + + ########################## + # STEP 3: Request 3 send + ########################## + ##### Prove that cache eviction work normally + sender_is_cached_item_req3 = p0_cache.is_cached(request3_hashes) + assert sender_is_cached_item_req3 == [False, False, False, True] + + # Touch all mm hash for P0 Cache before process + for mm_hash in request3_hashes: + p0_cache.touch_sender_cache_item(mm_hash) + + ########################### + # Process request 3 for P0 Cache + ########################### + for i, h in enumerate(request3_hashes): + # Use precomputed cache state again + is_cached = sender_is_cached_item_req3[i] + item_tuple = (request3_items[h], []) if not is_cached else None + print(f"Request 3: key={h} | cached={is_cached}") + p0_result = p0_cache.get_and_update_item(item_tuple, h) + # Only get mm item, ignore prompt update result + request3_items_p0_result.append(p0_result[0]) + + # image_A got evict and image_G add to cache + # image_B is still protected + # image_G, image_H fit but image_I cannot fit + assert p0_cache.is_cached(request3_hashes) == [True, True, False, True] + + ########################### + # Process request 3 for P1 Cache + ########################### + + # Touch all mm hash for P1 Cache before process + for mm_hash, mm_item in zip(request3_hashes, request3_items_p0_result): + p1_cache.touch_receiver_cache_item(mm_hash, mm_item) + + for mm_hash, mm_item in zip(request3_hashes, request3_items_p0_result): + p1_cache.get_and_update_item(mm_item, mm_hash) + + expected_hashes = ["image_B", "image_C", "image_G", "image_H"] + assert list(p0_cache._shm_cache.key_index.keys()) == expected_hashes + + +def test_cache_eviction_shm_cache(): + vllm_config = VllmConfig( + model_config=ModelConfig( + model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf", + mm_processor_cache_type="shm", + mm_shm_cache_max_object_size_mb=6, + mm_processor_cache_gb=15.2 * MiB_bytes / GiB_bytes, + ), + ) + sender_cache = ShmObjectStoreSenderCache(vllm_config) + receiver_cache = ShmObjectStoreReceiverCache(vllm_config, mp.Lock()) + + _run_test_cache_eviction_shm(sender_cache, receiver_cache, base_item_size=MiB_bytes) diff --git a/tests/multimodal/test_inputs.py b/tests/multimodal/test_inputs.py deleted file mode 100644 index 88e92bee3a292..0000000000000 --- a/tests/multimodal/test_inputs.py +++ /dev/null @@ -1,91 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -import torch - -from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors - -pytestmark = pytest.mark.cpu_test - - -def assert_nested_tensors_equal(expected: NestedTensors, actual: NestedTensors): - assert type(expected) == type(actual) # noqa: E721 - if isinstance(expected, torch.Tensor): - assert torch.equal(expected, actual) - else: - for expected_item, actual_item in zip(expected, actual): - assert_nested_tensors_equal(expected_item, actual_item) - - -def assert_multimodal_inputs_equal( - expected: MultiModalKwargs, actual: MultiModalKwargs -): - assert set(expected.keys()) == set(actual.keys()) - for key in expected: - assert_nested_tensors_equal(expected[key], actual[key]) - - -def test_multimodal_input_batch_single_tensor(): - t = torch.rand([1, 2]) - result = MultiModalKwargs.batch([{"image": t}]) - assert_multimodal_inputs_equal(result, {"image": t.unsqueeze(0)}) - - -def test_multimodal_input_batch_multiple_tensors(): - a = torch.rand([1, 1, 2]) - b = torch.rand([1, 1, 2]) - c = torch.rand([1, 1, 2]) - result = MultiModalKwargs.batch([{"image": a}, {"image": b}, {"image": c}]) - assert_multimodal_inputs_equal(result, {"image": torch.stack([a, b, c])}) - - -def test_multimodal_input_batch_multiple_heterogeneous_tensors(): - a = torch.rand([1, 2, 2]) - b = torch.rand([1, 3, 2]) - c = torch.rand([1, 4, 2]) - result = MultiModalKwargs.batch([{"image": a}, {"image": b}, {"image": c}]) - assert_multimodal_inputs_equal(result, {"image": [a, b, c]}) - - -def test_multimodal_input_batch_nested_tensors(): - a = torch.rand([2, 3]) - b = torch.rand([2, 3]) - c = torch.rand([2, 3]) - result = MultiModalKwargs.batch([{"image": [a]}, {"image": [b]}, {"image": [c]}]) - assert_multimodal_inputs_equal( - result, {"image": torch.stack([a.unsqueeze(0), b.unsqueeze(0), c.unsqueeze(0)])} - ) - - -def test_multimodal_input_batch_heterogeneous_lists(): - a = torch.rand([1, 2, 3]) - b = torch.rand([1, 2, 3]) - c = torch.rand([1, 2, 3]) - result = MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c]}]) - assert_multimodal_inputs_equal( - result, {"image": [torch.stack([a, b]), c.unsqueeze(0)]} - ) - - -def test_multimodal_input_batch_multiple_batchable_lists(): - a = torch.rand([1, 2, 3]) - b = torch.rand([1, 2, 3]) - c = torch.rand([1, 2, 3]) - d = torch.rand([1, 2, 3]) - result = MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c, d]}]) - assert_multimodal_inputs_equal( - result, {"image": torch.stack([torch.stack([a, b]), torch.stack([c, d])])} - ) - - -def test_multimodal_input_batch_mixed_stacking_depths(): - a = torch.rand([1, 2, 3]) - b = torch.rand([1, 3, 3]) - c = torch.rand([1, 4, 3]) - - result = MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c]}]) - assert_multimodal_inputs_equal(result, {"image": [[a, b], c.unsqueeze(0)]}) - - result = MultiModalKwargs.batch([{"image": [a]}, {"image": [b, c]}]) - assert_multimodal_inputs_equal(result, {"image": [a.unsqueeze(0), [b, c]]}) diff --git a/tests/multimodal/test_sparse_tensor_validation_unit.py b/tests/multimodal/test_sparse_tensor_validation_unit.py new file mode 100644 index 0000000000000..2eec8ea8283a2 --- /dev/null +++ b/tests/multimodal/test_sparse_tensor_validation_unit.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Unit tests for sparse tensor validation. + +Simple, fast unit tests that can run without server fixtures. +Run with: pytest tests/multimodal/test_sparse_tensor_validation_unit.py -v +""" + +import io + +import pytest +import torch + + +class TestSparseTensorValidationContextManager: + """Test that torch.sparse.check_sparse_tensor_invariants() works as expected.""" + + def test_valid_sparse_tensor_passes(self): + """Valid sparse tensors should pass validation.""" + indices = torch.tensor([[0, 1], [0, 1]]) + values = torch.tensor([1.0, 2.0]) + shape = (2, 2) + + with torch.sparse.check_sparse_tensor_invariants(): + tensor = torch.sparse_coo_tensor(indices, values, shape) + dense = tensor.to_dense() + + assert dense.shape == shape + + def test_out_of_bounds_indices_rejected(self): + """Sparse tensors with out-of-bounds indices should be rejected.""" + indices = torch.tensor([[5], [5]]) # Out of bounds for 2x2 + values = torch.tensor([1.0]) + shape = (2, 2) + + with pytest.raises(RuntimeError) as exc_info: # noqa: SIM117 + with torch.sparse.check_sparse_tensor_invariants(): + tensor = torch.sparse_coo_tensor(indices, values, shape) + tensor.to_dense() + + assert ( + "index" in str(exc_info.value).lower() + or "bound" in str(exc_info.value).lower() + ) + + def test_negative_indices_rejected(self): + """Sparse tensors with negative indices should be rejected.""" + indices = torch.tensor([[-1], [0]]) + values = torch.tensor([1.0]) + shape = (2, 2) + + with pytest.raises(RuntimeError): # noqa: SIM117 + with torch.sparse.check_sparse_tensor_invariants(): + tensor = torch.sparse_coo_tensor(indices, values, shape) + tensor.to_dense() + + def test_without_context_manager_allows_invalid(self): + """ + WITHOUT validation, invalid tensors may not immediately error. + + This demonstrates the vulnerability: PyTorch 2.8.0+ doesn't validate + by default, which can lead to memory corruption. + """ + indices = torch.tensor([[100], [100]]) # Way out of bounds + values = torch.tensor([1.0]) + shape = (2, 2) + + # Without validation context, this might create an invalid tensor + # (actual behavior depends on PyTorch version) + tensor = torch.sparse_coo_tensor(indices, values, shape) + + # The tensor object is created, but it's invalid + assert tensor.is_sparse + + +class TestTorchLoadWithValidation: + """Test torch.load() with sparse tensor validation.""" + + def test_load_valid_sparse_tensor_with_validation(self): + """Valid sparse tensors should load successfully with validation.""" + # Create and save a valid sparse tensor + indices = torch.tensor([[0, 1], [0, 1]]) + values = torch.tensor([1.0, 2.0]) + tensor = torch.sparse_coo_tensor(indices, values, (2, 2)) + + buffer = io.BytesIO() + torch.save(tensor, buffer) + buffer.seek(0) + + # Load with validation + with torch.sparse.check_sparse_tensor_invariants(): + loaded = torch.load(buffer, weights_only=True) + dense = loaded.to_dense() + + assert dense.shape == (2, 2) + + def test_load_invalid_sparse_tensor_rejected(self): + """Invalid sparse tensors should be caught when loaded with validation.""" + # Create an invalid sparse tensor (out of bounds) + indices = torch.tensor([[10], [10]]) + values = torch.tensor([1.0]) + tensor = torch.sparse_coo_tensor(indices, values, (2, 2)) + + buffer = io.BytesIO() + torch.save(tensor, buffer) + buffer.seek(0) + + # Load with validation - should fail on to_dense() + with pytest.raises(RuntimeError): # noqa: SIM117 + with torch.sparse.check_sparse_tensor_invariants(): + loaded = torch.load(buffer, weights_only=True) + loaded.to_dense() + + def test_load_dense_tensor_unaffected(self): + """Dense tensors should work normally with the validation context.""" + # Create and save a dense tensor + tensor = torch.randn(10, 20) + + buffer = io.BytesIO() + torch.save(tensor, buffer) + buffer.seek(0) + + # Load with validation (should have no effect on dense tensors) + with torch.sparse.check_sparse_tensor_invariants(): + loaded = torch.load(buffer, weights_only=True) + + assert loaded.shape == (10, 20) + assert not loaded.is_sparse + + +if __name__ == "__main__": + # Allow running directly for quick testing + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index 639e290406fe2..636cd0ffd445e 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio import base64 import mimetypes import os @@ -186,6 +187,7 @@ async def test_fetch_image_error_conversion(): connector.fetch_image(broken_img) +@pytest.mark.flaky(reruns=3, reruns_delay=5) @pytest.mark.asyncio @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) @pytest.mark.parametrize("num_frames", [-1, 32, 1800]) @@ -198,8 +200,12 @@ async def test_fetch_video_http(video_url: str, num_frames: int): } ) - video_sync, metadata_sync = connector.fetch_video(video_url) - video_async, metadata_async = await connector.fetch_video_async(video_url) + try: + video_sync, metadata_sync = connector.fetch_video(video_url) + video_async, metadata_async = await connector.fetch_video_async(video_url) + except (TimeoutError, asyncio.TimeoutError) as e: + pytest.skip(f"Timeout fetching video (CI network flakiness): {e}") + assert np.array_equal(video_sync, video_async) assert metadata_sync == metadata_async diff --git a/tests/multimodal/test_video.py b/tests/multimodal/test_video.py index 6ed21de368ac3..eccaa53ea1004 100644 --- a/tests/multimodal/test_video.py +++ b/tests/multimodal/test_video.py @@ -147,7 +147,7 @@ def test_video_backend_handles_broken_frames(monkeypatch: pytest.MonkeyPatch): """ Regression test for handling videos with broken frames. This test uses a pre-corrupted video file (assets/corrupted.mp4) that - contains broken/unreadable frames to verify the video loader handles + contains broken frames to verify the video loader handles them gracefully without crashing and returns accurate metadata. """ with monkeypatch.context() as m: @@ -177,3 +177,125 @@ def test_video_backend_handles_broken_frames(monkeypatch: pytest.MonkeyPatch): f"Expected fewer than {metadata['total_num_frames']} frames, " f"but loaded {frames.shape[0]} frames" ) + + +@VIDEO_LOADER_REGISTRY.register("test_video_backend_override_1") +class TestVideoBackendOverride1(VideoLoader): + """Test loader that returns FAKE_OUTPUT_1 to verify backend selection.""" + + @classmethod + def load_bytes( + cls, data: bytes, num_frames: int = -1, **kwargs + ) -> tuple[npt.NDArray, dict]: + return FAKE_OUTPUT_1, {"video_backend": "test_video_backend_override_1"} + + +@VIDEO_LOADER_REGISTRY.register("test_video_backend_override_2") +class TestVideoBackendOverride2(VideoLoader): + """Test loader that returns FAKE_OUTPUT_2 to verify backend selection.""" + + @classmethod + def load_bytes( + cls, data: bytes, num_frames: int = -1, **kwargs + ) -> tuple[npt.NDArray, dict]: + return FAKE_OUTPUT_2, {"video_backend": "test_video_backend_override_2"} + + +def test_video_media_io_backend_kwarg_override(monkeypatch: pytest.MonkeyPatch): + """ + Test that video_backend kwarg can override the VLLM_VIDEO_LOADER_BACKEND + environment variable. + + This allows users to dynamically select a different video backend + via --media-io-kwargs without changing the global env var, which is + useful when plugins set a default backend but a specific request + needs a different one. + """ + with monkeypatch.context() as m: + # Set the env var to one backend + m.setenv("VLLM_VIDEO_LOADER_BACKEND", "test_video_backend_override_1") + + imageio = ImageMediaIO() + + # Without video_backend kwarg, should use env var backend + videoio_default = VideoMediaIO(imageio, num_frames=10) + frames_default, metadata_default = videoio_default.load_bytes(b"test") + np.testing.assert_array_equal(frames_default, FAKE_OUTPUT_1) + assert metadata_default["video_backend"] == "test_video_backend_override_1" + + # With video_backend kwarg, should override env var + videoio_override = VideoMediaIO( + imageio, num_frames=10, video_backend="test_video_backend_override_2" + ) + frames_override, metadata_override = videoio_override.load_bytes(b"test") + np.testing.assert_array_equal(frames_override, FAKE_OUTPUT_2) + assert metadata_override["video_backend"] == "test_video_backend_override_2" + + +def test_video_media_io_backend_kwarg_not_passed_to_loader( + monkeypatch: pytest.MonkeyPatch, +): + """ + Test that video_backend kwarg is consumed by VideoMediaIO and NOT passed + through to the underlying video loader's load_bytes method. + + This ensures the kwarg is properly popped from kwargs before forwarding. + """ + + @VIDEO_LOADER_REGISTRY.register("test_reject_video_backend_kwarg") + class RejectVideoBackendKwargLoader(VideoLoader): + """Test loader that fails if video_backend is passed through.""" + + @classmethod + def load_bytes( + cls, data: bytes, num_frames: int = -1, **kwargs + ) -> tuple[npt.NDArray, dict]: + # This should never receive video_backend in kwargs + if "video_backend" in kwargs: + raise AssertionError( + "video_backend should be consumed by VideoMediaIO, " + "not passed to loader" + ) + return FAKE_OUTPUT_1, {"received_kwargs": list(kwargs.keys())} + + with monkeypatch.context() as m: + m.setenv("VLLM_VIDEO_LOADER_BACKEND", "test_reject_video_backend_kwarg") + + imageio = ImageMediaIO() + + # Even when video_backend is provided, it should NOT be passed to loader + videoio = VideoMediaIO( + imageio, + num_frames=10, + video_backend="test_reject_video_backend_kwarg", + other_kwarg="should_pass_through", + ) + + # This should NOT raise AssertionError + frames, metadata = videoio.load_bytes(b"test") + np.testing.assert_array_equal(frames, FAKE_OUTPUT_1) + # Verify other kwargs are still passed through + assert "other_kwarg" in metadata["received_kwargs"] + + +def test_video_media_io_backend_env_var_fallback(monkeypatch: pytest.MonkeyPatch): + """ + Test that when video_backend kwarg is None or not provided, + VideoMediaIO falls back to VLLM_VIDEO_LOADER_BACKEND env var. + """ + with monkeypatch.context() as m: + m.setenv("VLLM_VIDEO_LOADER_BACKEND", "test_video_backend_override_2") + + imageio = ImageMediaIO() + + # Explicit None should fall back to env var + videoio_none = VideoMediaIO(imageio, num_frames=10, video_backend=None) + frames_none, metadata_none = videoio_none.load_bytes(b"test") + np.testing.assert_array_equal(frames_none, FAKE_OUTPUT_2) + assert metadata_none["video_backend"] == "test_video_backend_override_2" + + # Not providing video_backend should also fall back to env var + videoio_missing = VideoMediaIO(imageio, num_frames=10) + frames_missing, metadata_missing = videoio_missing.load_bytes(b"test") + np.testing.assert_array_equal(frames_missing, FAKE_OUTPUT_2) + assert metadata_missing["video_backend"] == "test_video_backend_override_2" diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py index a80617a366cab..8448003e70531 100644 --- a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py @@ -30,5 +30,6 @@ class DummyPlatform(Platform): use_mla, has_sink, use_sparse, + use_mm_prefix, ): return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501 diff --git a/tests/quantization/test_blackwell_moe.py b/tests/quantization/test_blackwell_moe.py index 8dd4551ff4b96..a43d2abfdd8b8 100644 --- a/tests/quantization/test_blackwell_moe.py +++ b/tests/quantization/test_blackwell_moe.py @@ -10,9 +10,9 @@ import pytest from tests.utils import RemoteOpenAIServer from vllm.platforms import current_platform -if not current_platform.is_device_capability(100): +if not current_platform.is_device_capability_family(100): pytest.skip( - "This test only runs on Blackwell GPUs (SM100).", allow_module_level=True + "This test only runs on Blackwell GPUs (SM10x).", allow_module_level=True ) diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index 7bcac9ad768e7..62203186510ce 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -10,10 +10,14 @@ import torch from tests.quantization.utils import is_quant_method_supported from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.quantization.fp8 import ( + Fp8Config, Fp8KVCacheMethod, Fp8LinearMethod, + Fp8MoEMethod, ) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.platforms import current_platform MODELS = [ @@ -261,3 +265,87 @@ def test_scaled_fp8_quant(dtype) -> None: torch.narrow(y_nc_pad, 0, 0, x_nc.shape[0]), inv_scale_nc, dtype ), ) + + +@pytest.mark.parametrize("method_cls", [Fp8LinearMethod, Fp8MoEMethod]) +# FP8 weight reloading does not support online quantization +@pytest.mark.parametrize("is_checkpoint_fp8_serialized", [True]) # skip False +@pytest.mark.parametrize("weight_block_size", [None, [1, 1]]) +# any postprocessing that is applied to the weights such as padding and repacking +# (excluding device sharding) must also be applied to the reloaded weights +# +# this is the case for marlin as well as per-tensor Fp8MoEMethod +@pytest.mark.parametrize("use_marlin", [False]) # skip True +def test_fp8_reloading( + method_cls, is_checkpoint_fp8_serialized, weight_block_size, use_marlin, dist_init +): + if is_checkpoint_fp8_serialized is False: + pytest.skip("FP8 weight reloading does not support online quantization") + + if method_cls is Fp8MoEMethod and weight_block_size is None: + pytest.skip( + "FP8 Tensor weight reloading does not support fusing w13_weight_scale. " + "If this is your use case, consider using a restore function like #26327" + ) + + with torch.device("cuda:0"): + config = Fp8Config( + is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, + weight_block_size=weight_block_size, + ) + + if method_cls is Fp8LinearMethod: + layer = torch.nn.Linear(1, 1) + method = method_cls(config) + method.create_weights( + layer=layer, + input_size_per_partition=1, + output_partition_sizes=[1], + input_size=1, + output_size=1, + params_dtype=torch.bfloat16, + weight_loader=default_weight_loader, + ) + + else: + layer = FusedMoE( + num_experts=1, + top_k=1, + hidden_size=1, + intermediate_size=1, + ) + method = method_cls(config, layer) + method.create_weights( + layer=layer, + num_experts=1, + hidden_size=1, + intermediate_size_per_partition=1, + params_dtype=torch.bfloat16, + weight_loader=default_weight_loader, + ) + + method.use_marlin = use_marlin + + # capture weights format during loading + original_metadata = [ + (name, param.shape, getattr(param, "weight_loader", default_weight_loader)) + for name, param in layer.named_parameters() + ] + + # test loading + for name, shape, _ in original_metadata: + param = getattr(layer, name) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, torch.zeros(shape)) # cannot use empty + + method.process_weights_after_loading(layer) + + # test reloading works after loading + # assuming that no reshaping occurred + for name, shape, original_weight_loader in original_metadata: + param = getattr(layer, name) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + assert weight_loader is original_weight_loader + weight_loader(param, torch.zeros(shape)) # cannot use empty + + method.process_weights_after_loading(layer) diff --git a/tests/quantization/test_quark.py b/tests/quantization/test_quark.py index 334f9a65e4c03..0ff6e8407ce67 100644 --- a/tests/quantization/test_quark.py +++ b/tests/quantization/test_quark.py @@ -212,11 +212,11 @@ def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int): task = "wikitext" rtol = 0.1 - # Smaller cuda_graph_sizes to speed up the test. + # Smaller cudagraph_capture_sizes to speed up the test. results = lm_eval.simple_evaluate( model="vllm", model_args=config.get_model_args( - tp_size=tp_size, kwargs={"cuda_graph_sizes": [16]} + tp_size=tp_size, kwargs={"cudagraph_capture_sizes": [16]} ), tasks=task, batch_size=64, diff --git a/tests/reasoning/test_base_thinking_reasoning_parser.py b/tests/reasoning/test_base_thinking_reasoning_parser.py index d31b1c7d169b7..165e91a2c79f2 100644 --- a/tests/reasoning/test_base_thinking_reasoning_parser.py +++ b/tests/reasoning/test_base_thinking_reasoning_parser.py @@ -112,7 +112,7 @@ class TestBaseThinkingReasoningParserMethods: """Test the is_reasoning_end method.""" parser = TestThinkingReasoningParser(test_tokenizer) end_token_id = parser.end_token_id - + start_token_id = parser.start_token_id # Test with end token present assert parser.is_reasoning_end([1, 2, end_token_id, 4]) is True @@ -122,6 +122,51 @@ class TestBaseThinkingReasoningParserMethods: # Test with empty list assert parser.is_reasoning_end([]) is False + # Test with interleaved thinking + assert parser.is_reasoning_end([1, start_token_id, 2, end_token_id]) is True + assert parser.is_reasoning_end([1, start_token_id, 2, 3]) is False + assert ( + parser.is_reasoning_end( + [1, start_token_id, 2, end_token_id, 2, 2, start_token_id] + ) + is False + ) + + def test_is_reasoning_end_streaming(self, test_tokenizer): + """Test the is_reasoning_end_streaming method.""" + parser = TestThinkingReasoningParser(test_tokenizer) + end_token_id = parser.end_token_id + start_token_id = parser.start_token_id + + assert ( + parser.is_reasoning_end_streaming([1, 2, end_token_id], [end_token_id]) + is True + ) + assert parser.is_reasoning_end_streaming([1, 2, 3, 4], [4]) is False + assert parser.is_reasoning_end_streaming([], []) is False + assert ( + parser.is_reasoning_end_streaming( + [1, start_token_id, 2, end_token_id], [end_token_id] + ) + is True + ) + assert ( + parser.is_reasoning_end_streaming([1, start_token_id, 2, 3], [3]) is False + ) + assert ( + parser.is_reasoning_end_streaming( + [1, start_token_id, 2, end_token_id, 2, start_token_id, 2], + [2], + ) + is False + ) + assert ( + parser.is_reasoning_end_streaming( + [1, start_token_id, 2, end_token_id, 2, 2], [2] + ) + is False + ) + def test_extract_content_ids(self, test_tokenizer): """Test the extract_content_ids method.""" parser = TestThinkingReasoningParser(test_tokenizer) diff --git a/tests/reasoning/test_deepseekv3_reasoning_parser.py b/tests/reasoning/test_deepseekv3_reasoning_parser.py index 6e8f0e8dcc9b9..874fdef778110 100644 --- a/tests/reasoning/test_deepseekv3_reasoning_parser.py +++ b/tests/reasoning/test_deepseekv3_reasoning_parser.py @@ -40,6 +40,7 @@ def test_identity_reasoning_parser_basic(tokenizer): input_tokens = tokenizer.tokenize(input_text) input_ids = tokenizer.convert_tokens_to_ids(input_tokens) assert parser.is_reasoning_end(input_ids) is True + assert parser.is_reasoning_end_streaming(input_ids, input_ids) is True # Test extract_content_ids returns all input_ids assert parser.extract_content_ids(input_ids) == input_ids diff --git a/tests/reasoning/test_holo2_reasoning_parser.py b/tests/reasoning/test_holo2_reasoning_parser.py new file mode 100644 index 0000000000000..438bb2e957b85 --- /dev/null +++ b/tests/reasoning/test_holo2_reasoning_parser.py @@ -0,0 +1,188 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +from transformers import AutoTokenizer + +from tests.reasoning.utils import run_reasoning_extraction +from vllm.reasoning import ReasoningParser, ReasoningParserManager +from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser +from vllm.reasoning.holo2_reasoning_parser import Holo2ReasoningParser +from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser + +REASONING_MODEL_NAME = "HCompany/Holo2-4B" + + +@pytest.fixture(scope="module") +def tokenizer(): + return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME) + + +@pytest.mark.parametrize( + "thinking,expected_parser_type", + [ + (True, DeepSeekR1ReasoningParser), + (False, IdentityReasoningParser), + ], +) +def test_parser_selection(tokenizer, thinking, expected_parser_type): + parser = Holo2ReasoningParser( + tokenizer, + chat_template_kwargs={ + "thinking": thinking, + }, + ) + + assert isinstance(parser._parser, expected_parser_type) + + +def test_holo2_default_parser_is_deepseekr1(tokenizer): + parser = Holo2ReasoningParser(tokenizer) + + assert isinstance(parser._parser, DeepSeekR1ReasoningParser) + + +def test_holo2_supports_structured_output(tokenizer): + # Structured output manager uses the reasoning parser to check if the + # reasoning content is ended before applying the grammar. The main function + # used is is_reasoning_end. This test checks if the parser is able to + # correctly identify the end of the reasoning content. + + # important to not pass chat_template_kwargs here as it is done in the + # StructuredOutputManager + parser = Holo2ReasoningParser(tokenizer) + + end_token_id = tokenizer.encode("</think>", add_special_tokens=False)[0] + + assert parser.is_reasoning_end([1, 2, 4, end_token_id]) + assert not parser.is_reasoning_end([1, 2, 4]) + assert parser.is_reasoning_end([1, 2, 4, end_token_id, 5]) + + +# thinking is True, non-streaming +WITH_THINK = { + "output": "This is a reasoning section</think>This is the rest", + "reasoning": "This is a reasoning section", + "content": "This is the rest", +} +# thinking is True, streaming +WITH_THINK_STREAM = { + "output": "This is a reasoning section</think>This is the rest", + "reasoning": "This is a reasoning section", + "content": "This is the rest", +} +# thinking is False, non-streaming +THINKING_DISABLED = { + "output": "This is the rest", + "reasoning": None, + "content": "This is the rest", +} +# thinking is False, streaming +THINKING_DISABLED_STREAM = { + "output": "This is the rest", + "reasoning": None, + "content": "This is the rest", +} +# thinking is False but the model output </think>, non-streaming +THINKING_DISABLED_WITH_CLOSE_TAG = { + "output": "</think>This is the rest", + "reasoning": None, + "content": "</think>This is the rest", +} +# thinking is False but the model output </think>, streaming +THINKING_DISABLED_WITH_CLOSE_TAG_STREAM = { + "output": "some text</think>This is the rest", + "reasoning": None, + "content": "some text</think>This is the rest", +} +COMPLETE_REASONING = { + "output": "This is a reasoning section</think>", + "reasoning": "This is a reasoning section", + "content": None, +} + +TEST_CASES = [ + pytest.param( + False, + WITH_THINK, + None, + id="with_think", + ), + pytest.param( + True, + WITH_THINK_STREAM, + None, + id="with_think_stream", + ), + pytest.param( + False, + WITH_THINK, + {"thinking": True}, + id="with_think_enabled", + ), + pytest.param( + True, + WITH_THINK_STREAM, + {"thinking": True}, + id="with_think_stream_enabled", + ), + pytest.param( + False, + THINKING_DISABLED, + {"thinking": False}, + id="thinking_disabled", + ), + pytest.param( + True, + THINKING_DISABLED_STREAM, + {"thinking": False}, + id="thinking_disabled_stream", + ), + pytest.param( + False, + THINKING_DISABLED_WITH_CLOSE_TAG, + {"thinking": False}, + id="thinking_disabled_with_close_tag", + ), + pytest.param( + True, + THINKING_DISABLED_WITH_CLOSE_TAG_STREAM, + {"thinking": False}, + id="thinking_disabled_with_close_tag_stream", + ), + pytest.param( + False, + COMPLETE_REASONING, + None, + id="complete_reasoning", + ), + pytest.param( + True, + COMPLETE_REASONING, + None, + id="complete_reasoning_stream", + ), +] + + +@pytest.mark.parametrize("streaming, param_dict, chat_template_kwargs", TEST_CASES) +def test_reasoning( + streaming: bool, + param_dict: dict, + chat_template_kwargs: dict | None, + tokenizer, +): + output = tokenizer.tokenize(param_dict["output"]) + output_tokens: list[str] = [ + tokenizer.convert_tokens_to_string([token]) for token in output + ] + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser("holo2")( + tokenizer, + chat_template_kwargs=chat_template_kwargs, + ) + + reasoning, content = run_reasoning_extraction( + parser, output_tokens, streaming=streaming + ) + + assert reasoning == param_dict["reasoning"] + assert content == param_dict["content"] diff --git a/tests/reasoning/test_minimax_m2_append_reasoning_parser.py b/tests/reasoning/test_minimax_m2_append_reasoning_parser.py new file mode 100644 index 0000000000000..eefe5e3eff74c --- /dev/null +++ b/tests/reasoning/test_minimax_m2_append_reasoning_parser.py @@ -0,0 +1,195 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +from transformers import AutoTokenizer + +from tests.reasoning.utils import run_reasoning_extraction +from vllm.reasoning import ReasoningParser, ReasoningParserManager + +parser_name = "minimax_m2_append_think" +end_token = "</think>" + +# MiniMax M2 model path +REASONING_MODEL_NAME = "MiniMaxAI/MiniMax-M2" + + +@pytest.fixture(scope="module") +def minimax_m2_tokenizer(): + return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME) + + +# ============================================================================= +# MiniMaxM2AppendThinkReasoningParser behavior: +# - Prepends <think> to the beginning of the output +# - Does NOT separate reasoning and content +# - Returns everything as content (with <think> prepended) +# - reasoning is always None +# +# This parser is used when you want to keep the raw output with <think> added +# ============================================================================= + +# Case: simple output with end token +SIMPLE_OUTPUT = { + "output": "This is reasoning</think>This is response", + "reasoning": None, + "content": "<think>This is reasoning</think>This is response", + "is_reasoning_end": True, +} + +# Case: output without end token (reasoning in progress) +NO_END_TOKEN = { + "output": "This is reasoning in progress", + "reasoning": None, + "content": "<think>This is reasoning in progress", + "is_reasoning_end": False, +} + +# Case: only end token +ONLY_END_TOKEN = { + "output": "</think>This is response", + "reasoning": None, + "content": "<think></think>This is response", + "is_reasoning_end": True, +} + +# Case: multiple lines +MULTIPLE_LINES = { + "output": "Line 1\nLine 2</think>Response 1\nResponse 2", + "reasoning": None, + "content": "<think>Line 1\nLine 2</think>Response 1\nResponse 2", + "is_reasoning_end": True, +} + +# Case: empty output (non-streaming prepends <think>) +EMPTY = { + "output": "", + "reasoning": None, + "content": "<think>", + "is_reasoning_end": False, +} + +# Case: empty output streaming (no tokens = no output) +EMPTY_STREAMING = { + "output": "", + "reasoning": None, + "content": None, + "is_reasoning_end": False, +} + +# Case: special characters +SPECIAL_CHARS = { + "output": "Let me think... 1+1=2</think>Yes!", + "reasoning": None, + "content": "<think>Let me think... 1+1=2</think>Yes!", + "is_reasoning_end": True, +} + +# Case: code in output +CODE_OUTPUT = { + "output": "```python\nprint('hi')\n```</think>Here's the code.", + "reasoning": None, + "content": "<think>```python\nprint('hi')\n```</think>Here's the code.", + "is_reasoning_end": True, +} + +TEST_CASES = [ + pytest.param( + False, + SIMPLE_OUTPUT, + id="simple_output", + ), + pytest.param( + True, + SIMPLE_OUTPUT, + id="simple_output_streaming", + ), + pytest.param( + False, + NO_END_TOKEN, + id="no_end_token", + ), + pytest.param( + True, + NO_END_TOKEN, + id="no_end_token_streaming", + ), + pytest.param( + False, + ONLY_END_TOKEN, + id="only_end_token", + ), + pytest.param( + True, + ONLY_END_TOKEN, + id="only_end_token_streaming", + ), + pytest.param( + False, + MULTIPLE_LINES, + id="multiple_lines", + ), + pytest.param( + True, + MULTIPLE_LINES, + id="multiple_lines_streaming", + ), + pytest.param( + False, + EMPTY, + id="empty", + ), + pytest.param( + True, + EMPTY_STREAMING, + id="empty_streaming", + ), + pytest.param( + False, + SPECIAL_CHARS, + id="special_chars", + ), + pytest.param( + True, + SPECIAL_CHARS, + id="special_chars_streaming", + ), + pytest.param( + False, + CODE_OUTPUT, + id="code_output", + ), + pytest.param( + True, + CODE_OUTPUT, + id="code_output_streaming", + ), +] + + +@pytest.mark.parametrize("streaming, param_dict", TEST_CASES) +def test_reasoning( + streaming: bool, + param_dict: dict, + minimax_m2_tokenizer, +): + output = minimax_m2_tokenizer.tokenize(param_dict["output"]) + # decode everything to tokens + output_tokens: list[str] = [ + minimax_m2_tokenizer.convert_tokens_to_string([token]) for token in output + ] + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)( + minimax_m2_tokenizer + ) + + reasoning, content = run_reasoning_extraction( + parser, output_tokens, streaming=streaming + ) + + assert reasoning == param_dict["reasoning"] + assert content == param_dict["content"] + + # Test is_reasoning_end + output_ids = minimax_m2_tokenizer.convert_tokens_to_ids(output) + is_reasoning_end = parser.is_reasoning_end(output_ids) + assert is_reasoning_end == param_dict["is_reasoning_end"] diff --git a/tests/reasoning/test_minimax_m2_reasoning_parser.py b/tests/reasoning/test_minimax_m2_reasoning_parser.py new file mode 100644 index 0000000000000..0d1056894c6ae --- /dev/null +++ b/tests/reasoning/test_minimax_m2_reasoning_parser.py @@ -0,0 +1,230 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +from transformers import AutoTokenizer + +from tests.reasoning.utils import run_reasoning_extraction +from vllm.reasoning import ReasoningParser, ReasoningParserManager + +parser_name = "minimax_m2" +end_token = "</think>" + +# MiniMax M2 model path +REASONING_MODEL_NAME = "MiniMaxAI/MiniMax-M2" + + +@pytest.fixture(scope="module") +def minimax_m2_tokenizer(): + return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME) + + +# ============================================================================= +# MiniMax M2 specific behavior: +# - Model does NOT generate <think> start token +# - Model only generates </think> end token +# - All content before </think> is reasoning +# - All content after </think> is the actual response (content) +# ============================================================================= + +# Case: reasoning + end token + content (typical case) +SIMPLE_REASONING = { + "output": "This is a reasoning section</think>This is the rest", + "reasoning": "This is a reasoning section", + "content": "This is the rest", + "is_reasoning_end": True, +} + +# Case: reasoning + end token only (no content after) +COMPLETE_REASONING = { + "output": "This is a reasoning section</think>", + "reasoning": "This is a reasoning section", + "content": None, + "is_reasoning_end": True, +} + +# Case: no end token yet (streaming in progress, all is reasoning) +NO_END_TOKEN = { + "output": "This is reasoning in progress", + "reasoning": "This is reasoning in progress", + "content": None, + "is_reasoning_end": False, +} + +# Case: multiple lines of reasoning +MULTIPLE_LINES = { + "output": "First line\nSecond line</think>Response first line\nResponse second", + "reasoning": "First line\nSecond line", + "content": "Response first line\nResponse second", + "is_reasoning_end": True, +} + +# Case: only end token (empty reasoning, immediate response) +SHORTEST_REASONING_NO_STREAMING = { + "output": "</think>This is the response", + "reasoning": "", + "content": "This is the response", + "is_reasoning_end": True, +} + +# Case: only end token streaming (reasoning is None because it's just the token) +SHORTEST_REASONING_STREAMING = { + "output": "</think>This is the response", + "reasoning": None, + "content": "This is the response", + "is_reasoning_end": True, +} + +# Case: empty output +EMPTY = { + "output": "", + "reasoning": "", + "content": None, + "is_reasoning_end": False, +} + +# Case: empty streaming +EMPTY_STREAMING = { + "output": "", + "reasoning": None, + "content": None, + "is_reasoning_end": False, +} + +# Case: long reasoning with special characters +SPECIAL_CHARS = { + "output": "Let me think... 1+1=2, right?</think>Yes, 1+1=2.", + "reasoning": "Let me think... 1+1=2, right?", + "content": "Yes, 1+1=2.", + "is_reasoning_end": True, +} + +# Case: reasoning with code blocks +CODE_IN_REASONING = { + "output": "```python\nprint('hello')\n```</think>Here is the code.", + "reasoning": "```python\nprint('hello')\n```", + "content": "Here is the code.", + "is_reasoning_end": True, +} + +TEST_CASES = [ + # Core cases: no start token (MiniMax M2 actual behavior) + pytest.param( + False, + SIMPLE_REASONING, + id="simple_reasoning", + ), + pytest.param( + True, + SIMPLE_REASONING, + id="simple_reasoning_streaming", + ), + pytest.param( + False, + COMPLETE_REASONING, + id="complete_reasoning", + ), + pytest.param( + True, + COMPLETE_REASONING, + id="complete_reasoning_streaming", + ), + pytest.param( + False, + NO_END_TOKEN, + id="no_end_token", + ), + pytest.param( + True, + NO_END_TOKEN, + id="no_end_token_streaming", + ), + pytest.param( + False, + MULTIPLE_LINES, + id="multiple_lines", + ), + pytest.param( + True, + MULTIPLE_LINES, + id="multiple_lines_streaming", + ), + pytest.param( + False, + SHORTEST_REASONING_NO_STREAMING, + id="shortest_reasoning", + ), + pytest.param( + True, + SHORTEST_REASONING_STREAMING, + id="shortest_reasoning_streaming", + ), + pytest.param( + False, + EMPTY, + id="empty", + ), + pytest.param( + True, + EMPTY_STREAMING, + id="empty_streaming", + ), + pytest.param( + False, + SPECIAL_CHARS, + id="special_chars", + ), + pytest.param( + True, + SPECIAL_CHARS, + id="special_chars_streaming", + ), + pytest.param( + False, + CODE_IN_REASONING, + id="code_in_reasoning", + ), + pytest.param( + True, + CODE_IN_REASONING, + id="code_in_reasoning_streaming", + ), +] + + +@pytest.mark.parametrize("streaming, param_dict", TEST_CASES) +def test_reasoning( + streaming: bool, + param_dict: dict, + minimax_m2_tokenizer, +): + output = minimax_m2_tokenizer.tokenize(param_dict["output"]) + # decode everything to tokens + output_tokens: list[str] = [ + minimax_m2_tokenizer.convert_tokens_to_string([token]) for token in output + ] + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)( + minimax_m2_tokenizer + ) + + reasoning, content = run_reasoning_extraction( + parser, output_tokens, streaming=streaming + ) + + assert reasoning == param_dict["reasoning"] + assert content == param_dict["content"] + + # Test is_reasoning_end + output_ids = minimax_m2_tokenizer.convert_tokens_to_ids(output) + is_reasoning_end = parser.is_reasoning_end(output_ids) + assert is_reasoning_end == param_dict["is_reasoning_end"] + + # Test extract_content + if param_dict["content"] is not None: + content = parser.extract_content_ids(output_ids) + assert content == minimax_m2_tokenizer.convert_tokens_to_ids( + minimax_m2_tokenizer.tokenize(param_dict["content"]) + ) + else: + content = parser.extract_content_ids(output) + assert content == [] diff --git a/tests/reasoning/test_mistral_reasoning_parser.py b/tests/reasoning/test_mistral_reasoning_parser.py index 0fe315c2567f9..d6da723f80b08 100644 --- a/tests/reasoning/test_mistral_reasoning_parser.py +++ b/tests/reasoning/test_mistral_reasoning_parser.py @@ -5,7 +5,7 @@ import pytest from tests.reasoning.utils import run_reasoning_extraction_mistral from vllm.reasoning import ReasoningParser, ReasoningParserManager -from vllm.tokenizers import MistralTokenizer +from vllm.tokenizers.mistral import MistralTokenizer parser_name = "mistral" @@ -18,47 +18,53 @@ def mistral_tokenizer(): return mistral_tokenizer -SIMPLE_REASONING = { +INVALID_SIMPLE_REASONING = { "output": "This is a reasoning section[/THINK]This is the rest", - "reasoning": "This is a reasoning section", - "content": "This is the rest", - "is_reasoning_end": True, + "reasoning": None, + "content": "This is a reasoning sectionThis is the rest", + "is_reasoning_end": False, } -COMPLETE_REASONING = { +INVALID_COMPLETE_REASONING = { "output": "This is a reasoning section[/THINK]", - "reasoning": "This is a reasoning section", - "content": None, - "is_reasoning_end": True, + "reasoning": None, + "content": "This is a reasoning section", + "is_reasoning_end": False, } NO_CONTENT = { - "output": "This is content", - "reasoning": "This is content", + "output": "[THINK]This is reasoning", + "reasoning": "This is reasoning", "content": None, "is_reasoning_end": False, } +NO_REASONING = { + "output": "This is content", + "reasoning": None, + "content": "This is content", + "is_reasoning_end": False, +} NO_REASONING_STREAMING = { "output": "This is a reasoning section", - "reasoning": "This is a reasoning section", - "content": None, + "reasoning": None, + "content": "This is a reasoning section", "is_reasoning_end": False, } -MULTIPLE_LINES = { +INVALID_MULTIPLE_LINES = { "output": "This\nThat[/THINK]This is the rest\nThat", - "reasoning": "This\nThat", - "content": "This is the rest\nThat", - "is_reasoning_end": True, + "reasoning": None, + "content": "This\nThatThis is the rest\nThat", + "is_reasoning_end": False, } -SHORTEST_REASONING_NO_STREAMING = { - "output": "[/THINK]This is the rest", - "reasoning": "", - "content": "This is the rest", - "is_reasoning_end": True, -} -SHORTEST_REASONING = { +INVALID_SHORTEST_REASONING_NO_STREAMING = { "output": "[/THINK]This is the rest", "reasoning": None, "content": "This is the rest", - "is_reasoning_end": True, + "is_reasoning_end": False, +} +INVALID_SHORTEST_REASONING = { + "output": "[/THINK]This is the rest", + "reasoning": None, + "content": "This is the rest", + "is_reasoning_end": False, } REASONING_WITH_THINK = { "output": "[THINK]This is a reasoning section[/THINK]This is the rest", @@ -78,17 +84,17 @@ MULTIPLE_LINES_WITH_THINK = { "content": "This is the rest\nThat", "is_reasoning_end": True, } -SHORTEST_REASONING_NO_STREAMING_WITH_THINK = { - "output": "[/THINK]This is the rest", - "reasoning": "", - "content": "This is the rest", - "is_reasoning_end": True, -} -SHORTEST_REASONING_WITH_THINK = { +INVALID_SHORTEST_REASONING_NO_STREAMING_WITH_THINK = { "output": "[/THINK]This is the rest", "reasoning": None, "content": "This is the rest", - "is_reasoning_end": True, + "is_reasoning_end": False, +} +INVALID_SHORTEST_REASONING_WITH_THINK = { + "output": "[/THINK]This is the rest", + "reasoning": None, + "content": "This is the rest", + "is_reasoning_end": False, } THINK_NO_END = { "output": "[THINK]This is a reasoning section", @@ -98,8 +104,8 @@ THINK_NO_END = { } EMPTY = { "output": "", - "reasoning": "", - "content": None, + "reasoning": None, + "content": "", "is_reasoning_end": False, } EMPTY_STREAMING = { @@ -109,47 +115,48 @@ EMPTY_STREAMING = { "is_reasoning_end": False, } NEW_LINE = { - "output": "\n[THINK]This is a reasoning section[/THINK]\nThis is the rest", + "output": "Before\n[THINK]This is a reasoning section[/THINK]\nThis is the rest", "reasoning": "This is a reasoning section", - "content": "\nThis is the rest", + "content": "Before\n\nThis is the rest", "is_reasoning_end": True, } -# Streaming cannot handle new lines at the beginning of the output -# because we need to support [THINK]...[/THINK] and [/THINK]... -# We cannot know if the text before [THINK] is reasoning content -# or not. NEW_LINE_STREAMING = { - "output": "\n[THINK]This is a reasoning section[/THINK]\nThis is the rest", - "reasoning": "\nThis is a reasoning section", - "content": "\nThis is the rest", + "output": "Before\n[THINK]This is a reasoning section[/THINK]\nThis is the rest", + "reasoning": "This is a reasoning section", + "content": "Before\n\nThis is the rest", "is_reasoning_end": True, } TEST_CASES = [ pytest.param( False, - SIMPLE_REASONING, - id="simple_reasoning", + INVALID_SIMPLE_REASONING, + id="invalid_simple_reasoning", ), pytest.param( True, - SIMPLE_REASONING, - id="simple_reasoning_streaming", + INVALID_SIMPLE_REASONING, + id="invalid_simple_reasoning_streaming", ), pytest.param( False, - COMPLETE_REASONING, - id="complete_reasoning", + INVALID_COMPLETE_REASONING, + id="invalid_complete_reasoning", ), pytest.param( True, - COMPLETE_REASONING, - id="complete_reasoning_streaming", + INVALID_COMPLETE_REASONING, + id="invalid_complete_reasoning_streaming", ), pytest.param( False, NO_CONTENT, - id="no_content_token", + id="no_content", + ), + pytest.param( + False, + NO_REASONING, + id="no_reasoning", ), pytest.param( True, @@ -158,23 +165,23 @@ TEST_CASES = [ ), pytest.param( False, - MULTIPLE_LINES, - id="multiple_lines", + INVALID_MULTIPLE_LINES, + id="invalid_multiple_lines", ), pytest.param( True, - MULTIPLE_LINES, - id="multiple_lines_streaming", + INVALID_MULTIPLE_LINES, + id="invalid_multiple_lines_streaming", ), pytest.param( True, - SHORTEST_REASONING, - id="shortest", + INVALID_SHORTEST_REASONING, + id="invalid_shortest", ), pytest.param( False, - SHORTEST_REASONING_NO_STREAMING, - id="shortest_streaming", + INVALID_SHORTEST_REASONING_NO_STREAMING, + id="invalid_shortest_streaming", ), pytest.param( False, @@ -208,13 +215,13 @@ TEST_CASES = [ ), pytest.param( False, - SHORTEST_REASONING_NO_STREAMING_WITH_THINK, - id="shortest_with_think", + INVALID_SHORTEST_REASONING_NO_STREAMING_WITH_THINK, + id="invalid_shortest_with_think", ), pytest.param( True, - SHORTEST_REASONING_WITH_THINK, - id="shortest_with_think_streaming", + INVALID_SHORTEST_REASONING_WITH_THINK, + id="invalid_shortest_with_think_streaming", ), pytest.param( False, @@ -316,10 +323,26 @@ def test_mistral_reasoning( # Test extract_content if param_dict["content"] is not None: - content = parser.extract_content_ids(output_tokens) - assert content == mistral_tokenizer.tokenizer.encode( - param_dict["content"], bos=False, eos=False + # Handle the case where there are tokens outputted before Thinking. + # This should not occur if the model is well trained and prompted. + if "[THINK]" in param_dict["output"] and not param_dict["output"].startswith( + "[THINK]" + ): + before_content = param_dict["output"].split("[THINK]")[0] + before_token_ids = mistral_tokenizer.tokenizer.encode( + before_content, bos=False, eos=False + ) + left_to_encode = param_dict["content"][len(before_content) :] + # Normal situation. + else: + before_token_ids = [] + left_to_encode = param_dict["content"] + + content_tokens = parser.extract_content_ids(output_tokens) + expected_token_ids = before_token_ids + mistral_tokenizer.tokenizer.encode( + left_to_encode, bos=False, eos=False ) + assert content_tokens == expected_token_ids else: content = parser.extract_content_ids(output_tokens) assert content == [] diff --git a/tests/reasoning/utils.py b/tests/reasoning/utils.py index 695312a0cadfe..a020fb8e97161 100644 --- a/tests/reasoning/utils.py +++ b/tests/reasoning/utils.py @@ -4,7 +4,7 @@ from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage from vllm.reasoning import ReasoningParser -from vllm.tokenizers import MistralTokenizer +from vllm.tokenizers.mistral import MistralTokenizer class StreamingReasoningReconstructor: diff --git a/tests/standalone_tests/lazy_imports.py b/tests/standalone_tests/lazy_imports.py index ddcdd2a51ab9f..fff5c54f276d3 100644 --- a/tests/standalone_tests/lazy_imports.py +++ b/tests/standalone_tests/lazy_imports.py @@ -5,9 +5,6 @@ # The utility function cannot be placed in `vllm.utils` # this needs to be a standalone script import sys -from contextlib import nullcontext - -from vllm_test_utils import BlameResult, blame # List of modules that should not be imported too early. # Lazy import `torch._inductor.async_compile` to avoid creating @@ -16,26 +13,10 @@ from vllm_test_utils import BlameResult, blame # `cv2` can easily mess up the environment. module_names = ["torch._inductor.async_compile", "cv2"] +# set all modules in `module_names` to be None. +# if we import any modules during `import vllm`, there would be a +# hard error and nice stacktrace on the first import. +for module_name in module_names: + sys.modules[module_name] = None # type: ignore[assignment] -def any_module_imported(): - return any(module_name in sys.modules for module_name in module_names) - - -# In CI, we only check finally if the module is imported. -# If it is indeed imported, we can rerun the test with `use_blame=True`, -# which will trace every function call to find the first import location, -# and help find the root cause. -# We don't run it in CI by default because it is slow. -use_blame = False -context = blame(any_module_imported) if use_blame else nullcontext() -with context as result: - import vllm # noqa - -if use_blame: - assert isinstance(result, BlameResult) - print(f"the first import location is:\n{result.trace_stack}") - -assert not any_module_imported(), ( - f"Some the modules in {module_names} are imported. To see the first" - f" import location, run the test with `use_blame=True`." -) +import vllm # noqa diff --git a/tests/standalone_tests/python_only_compile.sh b/tests/standalone_tests/python_only_compile.sh index 7cc5ef6596490..2017e34030d60 100644 --- a/tests/standalone_tests/python_only_compile.sh +++ b/tests/standalone_tests/python_only_compile.sh @@ -3,6 +3,43 @@ # for users who do not have any compilers installed on their system set -e + +merge_base_commit=$(git merge-base HEAD origin/main) +echo "INFO: current merge base commit with main: $merge_base_commit" +git show --oneline -s $merge_base_commit + +# test whether the metadata.json url is valid, retry each 3 minutes up to 5 times +# this avoids cumbersome error messages & manual retries in case the precompiled wheel +# for the given commit is still being built in the release pipeline +meta_json_url="https://wheels.vllm.ai/$merge_base_commit/vllm/metadata.json" +echo "INFO: will use metadata.json from $meta_json_url" + +for i in {1..5}; do + echo "Checking metadata.json URL (attempt $i)..." + if curl --fail "$meta_json_url" > metadata.json; then + echo "INFO: metadata.json URL is valid." + # check whether it is valid json by python + if python3 -m json.tool metadata.json; then + echo "INFO: metadata.json is valid JSON. Proceeding with the test." + else + echo "CRITICAL: metadata.json exists but is not valid JSON, please do report in #sig-ci channel!" + exit 1 + fi + break + fi + # failure handling + if [ $i -eq 5 ]; then + echo "ERROR: metadata.json URL is still not valid after 5 attempts." + echo "ERROR: Please check whether the precompiled wheel for commit $merge_base_commit exists." + echo " NOTE: If $merge_base_commit is a new commit on main, maybe try again after its release pipeline finishes." + echo " NOTE: If it fails, please report in #sig-ci channel." + exit 1 + else + echo "WARNING: metadata.json URL is not valid. Retrying in 3 minutes..." + sleep 180 + fi +done + set -x cd /vllm-workspace/ @@ -18,13 +55,13 @@ apt autoremove -y echo 'import os; os.system("touch /tmp/changed.file")' >> vllm/__init__.py -VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL=1 VLLM_USE_PRECOMPILED=1 pip3 install -vvv -e . +VLLM_PRECOMPILED_WHEEL_COMMIT=$merge_base_commit VLLM_USE_PRECOMPILED=1 pip3 install -vvv -e . # Run the script python3 -c 'import vllm' # Check if the clangd log file was created if [ ! -f /tmp/changed.file ]; then - echo "changed.file was not created, python only compilation failed" + echo "ERROR: changed.file was not created, python only compilation failed" exit 1 fi diff --git a/tests/test_config.py b/tests/test_config.py index 112b02edd0389..ee706ab3d9c87 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -6,12 +6,14 @@ from dataclasses import MISSING, Field, asdict, dataclass, field from unittest.mock import patch import pytest +from pydantic import ValidationError from vllm.compilation.backends import VllmBackend from vllm.config import ( CompilationConfig, ModelConfig, PoolerConfig, + SchedulerConfig, VllmConfig, update_config, ) @@ -87,64 +89,6 @@ def test_update_config(): new_config3 = update_config(config3, {"a": "new_value"}) -# Can remove once --task option is fully deprecated -@pytest.mark.parametrize( - ("model_id", "expected_runner_type", "expected_convert_type", "expected_task"), - [ - ("distilbert/distilgpt2", "generate", "none", "generate"), - ("intfloat/multilingual-e5-small", "pooling", "none", "embed"), - ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify", "classify"), - ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "none", "classify"), - ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "none", "reward"), - ("openai/whisper-small", "generate", "none", "transcription"), - ], -) -def test_auto_task( - model_id, expected_runner_type, expected_convert_type, expected_task -): - config = ModelConfig(model_id, task="auto") - - assert config.runner_type == expected_runner_type - assert config.convert_type == expected_convert_type - - -# Can remove once --task option is fully deprecated -@pytest.mark.parametrize( - ("model_id", "expected_runner_type", "expected_convert_type", "expected_task"), - [ - ("distilbert/distilgpt2", "pooling", "embed", "embed"), - ("intfloat/multilingual-e5-small", "pooling", "embed", "embed"), - ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify", "classify"), - ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify", "classify"), - ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "embed", "embed"), - ("openai/whisper-small", "pooling", "embed", "embed"), - ], -) -def test_score_task( - model_id, expected_runner_type, expected_convert_type, expected_task -): - config = ModelConfig(model_id, task="score") - - assert config.runner_type == expected_runner_type - assert config.convert_type == expected_convert_type - - -# Can remove once --task option is fully deprecated -@pytest.mark.parametrize( - ("model_id", "expected_runner_type", "expected_convert_type", "expected_task"), - [ - ("openai/whisper-small", "generate", "none", "transcription"), - ], -) -def test_transcription_task( - model_id, expected_runner_type, expected_convert_type, expected_task -): - config = ModelConfig(model_id, task="transcription") - - assert config.runner_type == expected_runner_type - assert config.convert_type == expected_convert_type - - @pytest.mark.parametrize( ("model_id", "expected_runner_type", "expected_convert_type"), [ @@ -627,8 +571,8 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files): ( "internlm/internlm2-1_8b-reward", "decoder", - False, - "Pooling models with all pooling does not support chunked prefill.", + True, + "Pooling models with causal attn and all pooling support chunked prefill.", ), ( "BAAI/bge-base-en", @@ -716,7 +660,7 @@ def test_is_chunked_prefill_supported( ): model_config = ModelConfig(model_id, trust_remote_code=True) assert model_config.attn_type == expected_attn_type - with caplog_vllm.at_level(level=logging.DEBUG): + with caplog_vllm.at_level(level=logging.DEBUG, logger="vllm"): assert model_config.is_chunked_prefill_supported == expected_result assert reason in caplog_vllm.text @@ -746,8 +690,8 @@ def test_is_chunked_prefill_supported( ( "internlm/internlm2-1_8b-reward", "decoder", - False, - "Pooling models with all pooling does not support prefix caching.", + True, + "Pooling models with causal attn and all pooling support prefix caching.", ), ( "BAAI/bge-base-en", @@ -835,7 +779,7 @@ def test_is_prefix_caching_supported( ): model_config = ModelConfig(model_id, trust_remote_code=True) assert model_config.attn_type == expected_attn_type - with caplog_vllm.at_level(level=logging.DEBUG): + with caplog_vllm.at_level(level=logging.DEBUG, logger="vllm"): assert model_config.is_prefix_caching_supported == expected_result assert reason in caplog_vllm.text @@ -1021,17 +965,17 @@ def test_vllm_config_explicit_overrides(): assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE # Explicit pass config flags to override defaults - pass_config = PassConfig(enable_noop=True, enable_attn_fusion=True) + pass_config = PassConfig(eliminate_noops=True, fuse_attn_quant=True) compilation_config = CompilationConfig(pass_config=pass_config) config = VllmConfig( optimization_level=OptimizationLevel.O0, compilation_config=compilation_config, ) - assert config.compilation_config.pass_config.enable_noop is True - assert config.compilation_config.pass_config.enable_attn_fusion is True + assert config.compilation_config.pass_config.eliminate_noops is True + assert config.compilation_config.pass_config.fuse_attn_quant is True # Explicit cudagraph mode override on quantized model at O2 - pass_config = PassConfig(enable_async_tp=True) + pass_config = PassConfig(fuse_gemm_comms=True) compilation_config = CompilationConfig( cudagraph_mode=CUDAGraphMode.NONE, pass_config=pass_config ) @@ -1041,7 +985,7 @@ def test_vllm_config_explicit_overrides(): compilation_config=compilation_config, ) assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE - assert config.compilation_config.pass_config.enable_async_tp is True + assert config.compilation_config.pass_config.fuse_gemm_comms is True # Mode should still use default for O2 assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE @@ -1083,7 +1027,7 @@ def test_vllm_config_explicit_overrides(): ) # Override one field but not others - pass_config = PassConfig(enable_noop=False) + pass_config = PassConfig(eliminate_noops=False) compilation_config = CompilationConfig(pass_config=pass_config) config = VllmConfig( model_config=regular_model, @@ -1091,7 +1035,18 @@ def test_vllm_config_explicit_overrides(): compilation_config=compilation_config, ) # Explicit override should be respected - assert config.compilation_config.pass_config.enable_noop is False + assert config.compilation_config.pass_config.eliminate_noops is False # Other fields should still use defaults assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE + + +def test_scheduler_config_init(): + with pytest.raises(ValidationError): + # Positional InitVars missing + # (InitVars cannot have defaults otherwise they will become attributes) + SchedulerConfig() + + with pytest.raises(AttributeError): + # InitVar does not become an attribute + print(SchedulerConfig.default_factory().max_model_len) diff --git a/tests/test_envs.py b/tests/test_envs.py index 6a9835a68e7e2..b6b7cf38d4abc 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -8,6 +8,7 @@ import pytest import vllm.envs as envs from vllm.envs import ( + disable_envs_cache, enable_envs_cache, env_list_with_choices, env_set_with_choices, @@ -57,6 +58,43 @@ def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch): envs.__getattr__ = envs.__getattr__.__wrapped__ +def test_getattr_with_reset(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("VLLM_HOST_IP", "1.1.1.1") + # __getattr__ is not decorated with functools.cache + assert not hasattr(envs.__getattr__, "cache_info") + + # Enable envs cache and ignore ongoing environment changes + enable_envs_cache() + assert envs.VLLM_HOST_IP == "1.1.1.1" + # With cache enabled, the environment variable value is cached and unchanged + monkeypatch.setenv("VLLM_HOST_IP", "2.2.2.2") + assert envs.VLLM_HOST_IP == "1.1.1.1" + + disable_envs_cache() + assert envs.VLLM_HOST_IP == "2.2.2.2" + # After cache disabled, the environment variable value would be synced + # with os.environ + monkeypatch.setenv("VLLM_HOST_IP", "3.3.3.3") + assert envs.VLLM_HOST_IP == "3.3.3.3" + + +def test_is_envs_cache_enabled() -> None: + assert not envs._is_envs_cache_enabled() + enable_envs_cache() + assert envs._is_envs_cache_enabled() + + # Only wrap one-layer of cache, so we only need to + # call disable once to reset. + enable_envs_cache() + enable_envs_cache() + enable_envs_cache() + disable_envs_cache() + assert not envs._is_envs_cache_enabled() + + disable_envs_cache() + assert not envs._is_envs_cache_enabled() + + class TestEnvWithChoices: """Test cases for env_with_choices function.""" @@ -365,3 +403,54 @@ class TestEnvSetWithChoices: with patch.dict(os.environ, {"TEST_ENV": "option1,option1,option2"}): env_func = env_set_with_choices("TEST_ENV", [], ["option1", "option2"]) assert env_func() == {"option1", "option2"} + + +class TestVllmConfigureLogging: + """Test cases for VLLM_CONFIGURE_LOGGING environment variable.""" + + def test_configure_logging_defaults_to_true(self): + """Test that VLLM_CONFIGURE_LOGGING defaults to True when not set.""" + # Ensure the env var is not set + with patch.dict(os.environ, {}, clear=False): + if "VLLM_CONFIGURE_LOGGING" in os.environ: + del os.environ["VLLM_CONFIGURE_LOGGING"] + + # Clear cache if it exists + if hasattr(envs.__getattr__, "cache_clear"): + envs.__getattr__.cache_clear() + + result = envs.VLLM_CONFIGURE_LOGGING + assert result is True + assert isinstance(result, bool) + + def test_configure_logging_with_zero_string(self): + """Test that VLLM_CONFIGURE_LOGGING='0' evaluates to False.""" + with patch.dict(os.environ, {"VLLM_CONFIGURE_LOGGING": "0"}): + # Clear cache if it exists + if hasattr(envs.__getattr__, "cache_clear"): + envs.__getattr__.cache_clear() + + result = envs.VLLM_CONFIGURE_LOGGING + assert result is False + assert isinstance(result, bool) + + def test_configure_logging_with_one_string(self): + """Test that VLLM_CONFIGURE_LOGGING='1' evaluates to True.""" + with patch.dict(os.environ, {"VLLM_CONFIGURE_LOGGING": "1"}): + # Clear cache if it exists + if hasattr(envs.__getattr__, "cache_clear"): + envs.__getattr__.cache_clear() + + result = envs.VLLM_CONFIGURE_LOGGING + assert result is True + assert isinstance(result, bool) + + def test_configure_logging_with_invalid_value_raises_error(self): + """Test that invalid VLLM_CONFIGURE_LOGGING value raises ValueError.""" + with patch.dict(os.environ, {"VLLM_CONFIGURE_LOGGING": "invalid"}): + # Clear cache if it exists + if hasattr(envs.__getattr__, "cache_clear"): + envs.__getattr__.cache_clear() + + with pytest.raises(ValueError, match="invalid literal for int"): + _ = envs.VLLM_CONFIGURE_LOGGING diff --git a/tests/test_inputs.py b/tests/test_inputs.py index b1fb4e06a6906..073be24a4a072 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -7,7 +7,7 @@ from vllm.config import ModelConfig from vllm.inputs import zip_enc_dec_prompts from vllm.inputs.parse import parse_raw_prompts from vllm.inputs.preprocess import InputPreprocessor -from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs +from vllm.tokenizers import cached_tokenizer_from_config pytestmark = pytest.mark.cpu_test @@ -34,6 +34,13 @@ INPUTS_SLICES = [ ] +# Test that a nested mixed-type list of lists raises a TypeError. +@pytest.mark.parametrize("invalid_input", [[[1, 2], ["foo", "bar"]]]) +def test_invalid_input_raise_type_error(invalid_input): + with pytest.raises(TypeError): + parse_raw_prompts(invalid_input) + + def test_parse_raw_single_batch_empty(): with pytest.raises(ValueError, match="at least one prompt"): parse_raw_prompts([]) @@ -108,7 +115,7 @@ def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs): ) def test_preprocessor_always_mm_code_path(model_id, prompt): model_config = ModelConfig(model=model_id) - tokenizer = init_tokenizer_from_configs(model_config) + tokenizer = cached_tokenizer_from_config(model_config) input_preprocessor = InputPreprocessor(model_config, tokenizer) # HF processor adds sep token diff --git a/tests/test_logger.py b/tests/test_logger.py index 8900e9c2a1e69..b4f44f52d4df9 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -57,7 +57,7 @@ def test_default_vllm_root_logger_configuration(monkeypatch): _configure_vllm_root_logger() logger = logging.getLogger("vllm") - assert logger.level == logging.DEBUG + assert logger.level == logging.INFO assert not logger.propagate handler = logger.handlers[0] @@ -524,7 +524,7 @@ def mp_function(**kwargs): def test_caplog_mp_fork(caplog_vllm, caplog_mp_fork): - with caplog_vllm.at_level(logging.DEBUG), caplog_mp_fork(): + with caplog_vllm.at_level(logging.DEBUG, logger="vllm"), caplog_mp_fork(): import multiprocessing ctx = multiprocessing.get_context("fork") diff --git a/tests/tokenizers_/test_basic.py b/tests/tokenizers_/test_basic.py index 1fca633cc5cd7..0510261eacde7 100644 --- a/tests/tokenizers_/test_basic.py +++ b/tests/tokenizers_/test_basic.py @@ -3,39 +3,39 @@ from typing import _get_protocol_attrs # type: ignore import pytest -from transformers import PreTrainedTokenizerBase +from transformers import ( + PreTrainedTokenizer, + PreTrainedTokenizerBase, + PreTrainedTokenizerFast, +) -from vllm.tokenizers import TokenizerLike -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.tokenizers import TokenizerLike, get_tokenizer +from vllm.tokenizers.mistral import MistralTokenizer def _get_missing_attrs(obj: object, target: type): return [k for k in _get_protocol_attrs(target) if not hasattr(obj, k)] +def _assert_tokenizer_like(tokenizer: object): + missing_attrs = _get_missing_attrs(tokenizer, TokenizerLike) + assert not missing_attrs, f"Missing attrs: {missing_attrs}" + + def test_tokenizer_like_protocol(): - assert not ( - missing_attrs := _get_missing_attrs( - get_tokenizer("gpt2", use_fast=False), - TokenizerLike, - ) - ), f"Missing attrs: {missing_attrs}" + tokenizer = get_tokenizer("gpt2", use_fast=False) + assert isinstance(tokenizer, PreTrainedTokenizer) + _assert_tokenizer_like(tokenizer) - assert not ( - missing_attrs := _get_missing_attrs( - get_tokenizer("gpt2", use_fast=True), - TokenizerLike, - ) - ), f"Missing attrs: {missing_attrs}" + tokenizer = get_tokenizer("gpt2", use_fast=True) + assert isinstance(tokenizer, PreTrainedTokenizerFast) + _assert_tokenizer_like(tokenizer) - assert not ( - missing_attrs := _get_missing_attrs( - get_tokenizer( - "mistralai/Mistral-7B-Instruct-v0.3", tokenizer_mode="mistral" - ), - TokenizerLike, - ) - ), f"Missing attrs: {missing_attrs}" + tokenizer = get_tokenizer( + "mistralai/Mistral-7B-Instruct-v0.3", tokenizer_mode="mistral" + ) + assert isinstance(tokenizer, MistralTokenizer) + _assert_tokenizer_like(tokenizer) @pytest.mark.parametrize("tokenizer_name", ["facebook/opt-125m", "gpt2"]) diff --git a/tests/tokenizers_/test_detokenize.py b/tests/tokenizers_/test_detokenize.py index ae1d6b0956722..d307993d04df9 100644 --- a/tests/tokenizers_/test_detokenize.py +++ b/tests/tokenizers_/test_detokenize.py @@ -8,7 +8,7 @@ import pytest from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.sampling_params import SamplingParams -from vllm.tokenizers import MistralTokenizer +from vllm.tokenizers.mistral import MistralTokenizer from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.detokenizer import ( FastIncrementalDetokenizer, diff --git a/tests/tokenizers_/test_mistral.py b/tests/tokenizers_/test_mistral.py index 92efac86dff29..faff611502652 100644 --- a/tests/tokenizers_/test_mistral.py +++ b/tests/tokenizers_/test_mistral.py @@ -91,6 +91,118 @@ from vllm.tokenizers.mistral import ( ], ), ), + ( + { + "messages": [ + { + "role": "user", + "content": "What is the current local date and time?", + } + ], + "tools": [ + { + "type": "function", + "function": { + "description": "Fetch the current local date and time.", + "unsupported_field": False, + "name": "get_current_time", + "parameters": {}, + }, + }, + { + "type": "function", + "function": { + "description": "Fetch the current local date and time.", + "unsupported_field2": False, + "name": "get_current_time", + "parameters": {}, + }, + }, + ], + }, + ( + [ + { + "role": "user", + "content": "What is the current local date and time?", + } + ], + [ + { + "type": "function", + "function": { + "description": "Fetch the current local date and time.", + "name": "get_current_time", + "parameters": {}, + }, + }, + { + "type": "function", + "function": { + "description": "Fetch the current local date and time.", + "name": "get_current_time", + "parameters": {}, + }, + }, + ], + ), + ), + ( + { + "messages": [ + { + "role": "user", + "content": "What is the current local date and time?", + } + ], + "tools": [ + { + "type": "function", + "unsupported_field": False, + "function": { + "description": "Fetch the current local date and time.", + "name": "get_current_time", + "parameters": {}, + }, + }, + { + "type": "function", + "unsupported_field2": False, + "function": { + "description": "Fetch the current local date and time 2.", + "name": "get_current_time2", + "parameters": {"a": "1"}, + }, + }, + ], + }, + ( + [ + { + "role": "user", + "content": "What is the current local date and time?", + } + ], + [ + { + "type": "function", + "function": { + "description": "Fetch the current local date and time.", + "name": "get_current_time", + "parameters": {}, + }, + }, + { + "type": "function", + "function": { + "description": "Fetch the current local date and time 2.", + "name": "get_current_time2", + "parameters": {"a": "1"}, + }, + }, + ], + ), + ), ], ) def test_prepare_apply_chat_template_tools_and_messages( @@ -1108,13 +1220,6 @@ class TestMistralTokenizer: ) == expected_tokens[mistral_tokenizer.is_tekken] ) - assert ( - mistral_tokenizer.decode( - ids[mistral_tokenizer.is_tekken], - skip_special_tokens=skip_special_tokens, - ) - == expected_tokens[mistral_tokenizer.is_tekken] - ) def test_decode_empty( self, @@ -1140,6 +1245,45 @@ class TestMistralTokenizer: == "<s>" ) + @pytest.mark.parametrize( + "skip_special_tokens,expected_tokens", + ( + ( + False, + ( + ["<s>[INST]▁Hello▁world▁![/INST]▁Hello</s>"], + ["<s>[INST]Hello world ![/INST]Hello</s>"], + ), + ), + (True, (["Hello world ! Hello"], ["Hello world !Hello"])), + ), + ) + def test_batch_decode( + self, + mistral_tokenizer: MistralTokenizer, + skip_special_tokens: bool, + expected_tokens: tuple[str, str], + ): + ids = ( + [[1, 3, 23325, 2294, 1686, 4, 23325, 2]], + [[1, 3, 22177, 4304, 2662, 4, 22177, 2]], + ) + assert ( + mistral_tokenizer.batch_decode( + ids[mistral_tokenizer.is_tekken], + skip_special_tokens=skip_special_tokens, + ) + == expected_tokens[mistral_tokenizer.is_tekken] + ) + + def test_batch_decode_empty( + self, + mistral_tokenizer: MistralTokenizer, + ): + assert mistral_tokenizer.batch_decode( + [[]], + ) == [""] + def test_convert_tokens_to_string(self, mistral_tokenizer: MistralTokenizer): tokens = ( [ diff --git a/tests/tokenizers_/test_registry.py b/tests/tokenizers_/test_registry.py index 57b6a14a54b3f..546f38b078dde 100644 --- a/tests/tokenizers_/test_registry.py +++ b/tests/tokenizers_/test_registry.py @@ -2,8 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from pathlib import Path -from vllm.tokenizers import TokenizerLike, TokenizerRegistry -from vllm.transformers_utils.tokenizer import get_tokenizer +import pytest + +from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.registry import ( + TokenizerRegistry, + get_tokenizer, + resolve_tokenizer_args, +) class TestTokenizer(TokenizerLike): @@ -41,10 +47,22 @@ class TestTokenizer(TokenizerLike): return True +@pytest.mark.parametrize("runner_type", ["generate", "pooling"]) +def test_resolve_tokenizer_args_idempotent(runner_type): + tokenizer_mode, tokenizer_name, args, kwargs = resolve_tokenizer_args( + "facebook/opt-125m", + runner_type=runner_type, + ) + + assert (tokenizer_mode, tokenizer_name, args, kwargs) == resolve_tokenizer_args( + tokenizer_name, *args, **kwargs + ) + + def test_customized_tokenizer(): TokenizerRegistry.register("test_tokenizer", __name__, TestTokenizer.__name__) - tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer", "abc") + tokenizer = TokenizerRegistry.load_tokenizer("test_tokenizer", "abc") assert isinstance(tokenizer, TestTokenizer) assert tokenizer.path_or_repo_id == "abc" assert tokenizer.bos_token_id == 0 diff --git a/vllm/distributed/kv_transfer/kv_pipe/__init__.py b/tests/tool_parsers/__init__.py similarity index 100% rename from vllm/distributed/kv_transfer/kv_pipe/__init__.py rename to tests/tool_parsers/__init__.py diff --git a/tests/tool_use/test_deepseekv31_tool_parser.py b/tests/tool_parsers/test_deepseekv31_tool_parser.py similarity index 93% rename from tests/tool_use/test_deepseekv31_tool_parser.py rename to tests/tool_parsers/test_deepseekv31_tool_parser.py index db5168071fbce..69a4cc8b989c5 100644 --- a/tests/tool_use/test_deepseekv31_tool_parser.py +++ b/tests/tool_parsers/test_deepseekv31_tool_parser.py @@ -3,10 +3,10 @@ import pytest -from vllm.entrypoints.openai.tool_parsers.deepseekv31_tool_parser import ( +from vllm.tokenizers import get_tokenizer +from vllm.tool_parsers.deepseekv31_tool_parser import ( DeepSeekV31ToolParser, ) -from vllm.transformers_utils.tokenizer import get_tokenizer MODEL = "deepseek-ai/DeepSeek-V3.1" diff --git a/tests/tool_use/test_ernie45_moe_tool_parser.py b/tests/tool_parsers/test_ernie45_moe_tool_parser.py similarity index 98% rename from tests/tool_use/test_ernie45_moe_tool_parser.py rename to tests/tool_parsers/test_ernie45_moe_tool_parser.py index 8fbbbba325385..533bd1ec3dfff 100644 --- a/tests/tool_use/test_ernie45_moe_tool_parser.py +++ b/tests/tool_parsers/test_ernie45_moe_tool_parser.py @@ -13,10 +13,9 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.ernie45_tool_parser import Ernie45ToolParser -from vllm.tokenizers import TokenizerLike +from vllm.tokenizers import TokenizerLike, get_tokenizer from vllm.tokenizers.detokenizer_utils import detokenize_incrementally -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.tool_parsers.ernie45_tool_parser import Ernie45ToolParser # Use a common model that is likely to be available MODEL = "baidu/ERNIE-4.5-21B-A3B-Thinking" diff --git a/tests/tool_use/test_glm4_moe_tool_parser.py b/tests/tool_parsers/test_glm4_moe_tool_parser.py similarity index 98% rename from tests/tool_use/test_glm4_moe_tool_parser.py rename to tests/tool_parsers/test_glm4_moe_tool_parser.py index f545f52c02dcb..52f5a9198e9b4 100644 --- a/tests/tool_use/test_glm4_moe_tool_parser.py +++ b/tests/tool_parsers/test_glm4_moe_tool_parser.py @@ -7,12 +7,10 @@ import json import pytest from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall -from vllm.entrypoints.openai.tool_parsers.glm4_moe_tool_parser import ( +from vllm.tokenizers import get_tokenizer +from vllm.tool_parsers.glm4_moe_tool_parser import ( Glm4MoeModelToolParser, ) -from vllm.transformers_utils.tokenizer import get_tokenizer - -pytestmark = pytest.mark.cpu_test pytest.skip("skip glm4_moe parser test", allow_module_level=True) # Use a common model that is likely to be available diff --git a/tests/tool_use/test_jamba_tool_parser.py b/tests/tool_parsers/test_jamba_tool_parser.py similarity index 98% rename from tests/tool_use/test_jamba_tool_parser.py rename to tests/tool_parsers/test_jamba_tool_parser.py index c7ca024f3a767..ccad16ae2f6b6 100644 --- a/tests/tool_use/test_jamba_tool_parser.py +++ b/tests/tool_parsers/test_jamba_tool_parser.py @@ -9,12 +9,9 @@ import pytest from partial_json_parser.core.options import Allow from vllm.entrypoints.openai.protocol import DeltaMessage, FunctionCall, ToolCall -from vllm.entrypoints.openai.tool_parsers.jamba_tool_parser import JambaToolParser -from vllm.tokenizers import TokenizerLike +from vllm.tokenizers import TokenizerLike, get_tokenizer from vllm.tokenizers.detokenizer_utils import detokenize_incrementally -from vllm.transformers_utils.tokenizer import get_tokenizer - -pytestmark = pytest.mark.cpu_test +from vllm.tool_parsers.jamba_tool_parser import JambaToolParser MODEL = "ai21labs/Jamba-tiny-dev" diff --git a/tests/tool_use/test_kimi_k2_tool_parser.py b/tests/tool_parsers/test_kimi_k2_tool_parser.py similarity index 99% rename from tests/tool_use/test_kimi_k2_tool_parser.py rename to tests/tool_parsers/test_kimi_k2_tool_parser.py index 3a48b5206141d..d02f53c34b455 100644 --- a/tests/tool_use/test_kimi_k2_tool_parser.py +++ b/tests/tool_parsers/test_kimi_k2_tool_parser.py @@ -7,10 +7,8 @@ import json import pytest from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall -from vllm.entrypoints.openai.tool_parsers.kimi_k2_tool_parser import KimiK2ToolParser -from vllm.transformers_utils.tokenizer import get_tokenizer - -pytestmark = pytest.mark.cpu_test +from vllm.tokenizers import get_tokenizer +from vllm.tool_parsers.kimi_k2_tool_parser import KimiK2ToolParser # Use a common model that is likely to be available MODEL = "moonshotai/Kimi-K2-Instruct" diff --git a/tests/tool_use/test_minimax_tool_parser.py b/tests/tool_parsers/test_minimax_tool_parser.py similarity index 99% rename from tests/tool_use/test_minimax_tool_parser.py rename to tests/tool_parsers/test_minimax_tool_parser.py index 4332984083dab..28cfc4ea7a175 100644 --- a/tests/tool_use/test_minimax_tool_parser.py +++ b/tests/tool_parsers/test_minimax_tool_parser.py @@ -12,10 +12,8 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.minimax_tool_parser import MinimaxToolParser -from vllm.transformers_utils.tokenizer import get_tokenizer - -pytestmark = pytest.mark.cpu_test +from vllm.tokenizers import get_tokenizer +from vllm.tool_parsers.minimax_tool_parser import MinimaxToolParser # Use a common model that is likely to be available MODEL = "MiniMaxAi/MiniMax-M1-40k" diff --git a/tests/tool_parsers/test_mistral_tool_parser.py b/tests/tool_parsers/test_mistral_tool_parser.py new file mode 100644 index 0000000000000..9400a67267f4c --- /dev/null +++ b/tests/tool_parsers/test_mistral_tool_parser.py @@ -0,0 +1,860 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +from collections.abc import Generator + +import partial_json_parser +import pytest +from mistral_common.protocol.instruct.messages import AssistantMessage +from mistral_common.protocol.instruct.request import InstructRequest +from mistral_common.protocol.instruct.tool_calls import FunctionCall, ToolCall +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import DeltaMessage, DeltaToolCall +from vllm.tokenizers import TokenizerLike, get_tokenizer +from vllm.tokenizers.detokenizer_utils import detokenize_incrementally +from vllm.tokenizers.mistral import MistralTokenizer +from vllm.tool_parsers.mistral_tool_parser import MistralToolParser + + +@pytest.fixture(scope="module") +def mistral_pre_v11_tokenizer(): + MODEL = "mistralai/Mistral-7B-Instruct-v0.3" + return get_tokenizer(tokenizer_name=MODEL) + + +@pytest.fixture(scope="module") +def mistral_tokenizer(): + MODEL = "mistralai/Mistral-Small-3.2-24B-Instruct-2506" + return get_tokenizer(tokenizer_name=MODEL, tokenizer_mode="mistral") + + +@pytest.fixture +def mistral_pre_v11_tool_parser(mistral_pre_v11_tokenizer): + return MistralToolParser(mistral_pre_v11_tokenizer) + + +@pytest.fixture +def mistral_tool_parser(mistral_tokenizer): + return MistralToolParser(mistral_tokenizer) + + +def assert_tool_calls( + actual_tool_calls: list[ToolCall] | list[DeltaToolCall], + expected_tool_calls: list[ToolCall], +): + assert len(actual_tool_calls) == len(expected_tool_calls) + + for actual_tool_call, expected_tool_call in zip( + actual_tool_calls, expected_tool_calls + ): + assert isinstance(actual_tool_call.id, str) + assert len(actual_tool_call.id) == 9 + + if isinstance(actual_tool_call, ToolCall): + assert actual_tool_call.type == "function" + elif isinstance(actual_tool_call, DeltaToolCall): + assert actual_tool_call.function is not None + assert actual_tool_call.function.name is not None + assert actual_tool_call.function.arguments is not None + assert actual_tool_call.function is not None + assert actual_tool_call.function.name == expected_tool_call.function.name, ( + f"got wrong function name:${actual_tool_call.function.name}" + ) + assert ( + actual_tool_call.function.arguments == expected_tool_call.function.arguments + ), f"got wrong function argument:${actual_tool_call.function.arguments}" + + +def fix_tool_call_tokenization( + tokens: list[int], + mistral_tool_parser: MistralToolParser, + mistral_tokenizer: TokenizerLike, +): + """ + Replaces the textual token sequence for [TOOL_CALLS] + with its single special token ID. + """ + textual_tool_call_token_ids = mistral_tokenizer.encode( + text=mistral_tool_parser.bot_token, + add_special_tokens=False, + ) + # textual_tool_call_token_ids must not contain special tokens like bos, eos etc + special_tool_call_token_ids = [mistral_tool_parser.bot_token_id] + + # If the input is too short to contain the sequence, no replacement is possible + if not tokens or len(tokens) < len(textual_tool_call_token_ids): + return tokens + + result_tokens = [] + i = 0 + target_len = len(textual_tool_call_token_ids) + + while i < len(tokens): + # Check if the slice from the current position matches the target sequence + if tokens[i : i + target_len] == textual_tool_call_token_ids: + # If it matches, add the replacement and jump the index forward + result_tokens.extend(special_tool_call_token_ids) + i += target_len + else: + # Otherwise, just add the current token and move to the next one + result_tokens.append(tokens[i]) + i += 1 + + return result_tokens + + +def stream_delta_message_generator( + mistral_tool_parser: MistralToolParser, + mistral_tokenizer: TokenizerLike, + model_output: str | None, + tools: list[tuple[str, str]] | None, +) -> Generator[DeltaMessage, None, None]: + if ( + isinstance(mistral_tokenizer, MistralTokenizer) + and mistral_tokenizer.version >= 11 + ): + # With the newer versions of the tokenizer, + # we cannot tokenize free text + # so we need to create a list of messages to get tokenized + assert tools is not None + assistant_msg = AssistantMessage( + tool_calls=[ + ToolCall( + function=FunctionCall( + name=name, + arguments=arg, + ) + ) + for (name, arg) in tools + ], + ) + request = InstructRequest( + messages=[assistant_msg], + ) + all_token_ids = mistral_tokenizer.instruct.encode_instruct(request).tokens + else: + # Older versions of the tokenizer are + # able to encode directly the model's output (free text) into tokens + assert model_output is not None + all_token_ids = mistral_tokenizer.encode(model_output, add_special_tokens=False) + + all_token_ids = fix_tool_call_tokenization( + all_token_ids, mistral_tool_parser, mistral_tokenizer + ) + + previous_text = "" + previous_tokens = None + prefix_offset = 0 + read_offset = 0 + for i, delta_token in enumerate(all_token_ids): + delta_token_ids = [delta_token] + previous_token_ids = all_token_ids[:i] + current_token_ids = all_token_ids[: i + 1] + + (new_tokens, delta_text, new_prefix_offset, new_read_offset) = ( + detokenize_incrementally( + tokenizer=mistral_tokenizer, + all_input_ids=current_token_ids, + prev_tokens=previous_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=isinstance(mistral_tokenizer, MistralTokenizer), + spaces_between_special_tokens=True, + ) + ) + + current_text = previous_text + delta_text + + delta_message = mistral_tool_parser.extract_tool_calls_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + request=None, # type: ignore[arg-type] + ) + if delta_message: + yield delta_message + + previous_text = current_text + previous_tokens = ( + previous_tokens + new_tokens if previous_tokens else new_tokens + ) + prefix_offset = new_prefix_offset + read_offset = new_read_offset + + +def test_extract_tool_calls_no_tools(mistral_pre_v11_tool_parser): + model_output = "This is a test" + extracted_tool_calls = mistral_pre_v11_tool_parser.extract_tool_calls( + model_output, request=None + ) # type: ignore[arg-type] + assert not extracted_tool_calls.tools_called + assert extracted_tool_calls.tool_calls == [] + assert extracted_tool_calls.content == model_output + + +@pytest.mark.parametrize( + ids=[ + "single_tool_add", + "single_tool_weather", + "argument_before_name", + "argument_before_name_and_name_in_argument", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ( + """[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add", arguments=json.dumps({"a": 3.5, "b": 4}) + ) + ) + ], + None, + ), + ( + """[TOOL_CALLS] [{"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "San Francisco", "state": "CA", "unit": "celsius"} + ), + ) + ) + ], + None, + ), + ( + """[TOOL_CALLS] [{"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "San Francisco", "state": "CA", "unit": "celsius"} + ), + ) + ) + ], + None, + ), + ( + """[TOOL_CALLS] [{"arguments":{"name": "John Doe"}, "name": "get_age"}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_age", + arguments=json.dumps( + { + "name": "John Doe", + } + ), + ) + ) + ], + None, + ), + ], +) +def test_extract_tool_calls_pre_v11_tokenizer( + mistral_pre_v11_tool_parser, model_output, expected_tool_calls, expected_content +): + extracted_tool_calls = mistral_pre_v11_tool_parser.extract_tool_calls( + model_output, request=None + ) # type: ignore[arg-type] + assert extracted_tool_calls.tools_called + + assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) + + assert extracted_tool_calls.content == expected_content + + +@pytest.mark.parametrize( + ids=[ + "single_tool_add", + "single_tool_weather", + "multiple_tool_calls", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ( + """[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add_this_and_that", + arguments=json.dumps({"a": 3.5, "b": 4}), + ) + ) + ], + None, + ), + ( + """[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "San Francisco", "state": "CA", "unit": "celsius"} + ), + ) + ) + ], + None, + ), + ( + """[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add", arguments=json.dumps({"a": 3.5, "b": 4}) + ) + ), + ToolCall( + function=FunctionCall( + name="multiply", arguments=json.dumps({"a": 3, "b": 6}) + ) + ), + ], + None, + ), + ], +) +def test_extract_tool_calls( + mistral_tool_parser, model_output, expected_tool_calls, expected_content +): + extracted_tool_calls = mistral_tool_parser.extract_tool_calls( + model_output, request=None + ) # type: ignore[arg-type] + assert extracted_tool_calls.tools_called + + assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) + + assert extracted_tool_calls.content == expected_content + + +def _test_extract_tool_calls_streaming( + tool_parser, tokenizer, model_output, tools, expected_tool_calls, expected_content +): + other_content: str = "" + function_names: list[str] = [] + function_args_strs: list[str] = [] + tool_call_idx: int = -1 + tool_call_ids: list[str | None] = [] + + for delta_message in stream_delta_message_generator( + tool_parser, tokenizer, model_output, tools + ): + # role should never be streamed from tool parser + assert not delta_message.role + + if delta_message.content: + other_content += delta_message.content + + streamed_tool_calls = delta_message.tool_calls + + if streamed_tool_calls and len(streamed_tool_calls) > 0: + # make sure only one diff is present - correct even for parallel + assert len(streamed_tool_calls) == 1 + tool_call = streamed_tool_calls[0] + + assert len(tool_parser.prev_tool_call_arr) > 0 + + # if a new tool is being called, set up empty arguments + if tool_call.index != tool_call_idx: + tool_call_idx = tool_call.index + function_args_strs.append("") + tool_call_ids.append(None) + + # if a tool call ID is streamed, make sure one hasn't been already + if tool_call.id and not tool_call_ids[tool_call.index]: + tool_call_ids[tool_call.index] = tool_call.id + + # if parts of the function start being streamed + if tool_call.function: + # if the function name is defined, set it. it should be streamed + # IN ENTIRETY, exactly one time. + if tool_call.function.name: + assert isinstance(tool_call.function.name, str) + function_names.append(tool_call.function.name) + + if tool_call.function.arguments: + # make sure they're a string and then add them to the list + assert isinstance(tool_call.function.arguments, str) + + function_args_strs[tool_call.index] += tool_call.function.arguments + + assert other_content == expected_content + + actual_tool_calls = [ + ToolCall( + id=tool_call_id, + function=FunctionCall( + name=function_name, + arguments=partial_json_parser.ensure_json( + function_args_str, Allow.OBJ | Allow.STR + ), + ), + ) + for tool_call_id, function_name, function_args_str in zip( + tool_call_ids, function_names, function_args_strs + ) + ] + assert_tool_calls(actual_tool_calls, expected_tool_calls) + + +@pytest.mark.parametrize( + ids=[ + "no_tools", + "single_tool_add", + "single_tool_add_strings", + "single_tool_weather", + "argument_before_name", + "argument_before_name_and_name_in_argument", + "multiple_tools", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ("""This is a test""", [], """This is a test"""), + ( + """[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add", arguments=json.dumps({"a": 3, "b": 4}) + ) + ) + ], + "", + ), + ( + """[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add", arguments=json.dumps({"a": "3", "b": "4"}) + ) + ) + ], + "", + ), + ( + """[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "San Francisco", "state": "CA", "unit": "celsius"} + ), + ) + ) + ], + "", + ), + ( + """[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "San Francisco", "state": "CA", "unit": "celsius"} + ), + ) + ) + ], + "", + ), + ( + """[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_age", + arguments=json.dumps( + { + "name": "John Doe", + } + ), + ) + ) + ], + "", + ), + ( + """[TOOL_CALLS] [{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add", arguments=json.dumps({"a": 3.5, "b": 4}) + ) + ), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "San Francisco", "state": "CA", "unit": "celsius"} + ), + ) + ), + ], + "", + ), + ], +) +def test_extract_tool_calls_streaming_pre_v11_tokenizer( + mistral_pre_v11_tool_parser, + mistral_pre_v11_tokenizer, + model_output, + expected_tool_calls, + expected_content, +): + _test_extract_tool_calls_streaming( + mistral_pre_v11_tool_parser, + mistral_pre_v11_tokenizer, + model_output, + None, + expected_tool_calls, + expected_content, + ) + + +@pytest.mark.parametrize( + ids=[ + "single_tool_add", + "single_tool_add_strings", + "multiple_tools", + ], + argnames=["tools", "expected_tool_calls", "expected_content"], + argvalues=[ + ( + [("add", '{"a": 3, "b": 4}')], + # [TOOL_CALLS]add{"a": 3, "b": 4} + [ + ToolCall( + function=FunctionCall( + name="add", arguments=json.dumps({"a": 3, "b": 4}) + ) + ) + ], + "", + ), + ( + [("add_two_strings", '{"a": "3", "b": "4"}')], + # [TOOL_CALLS]add_two_strings{"a": "3", "b": "4"} + [ + ToolCall( + function=FunctionCall( + name="add_two_strings", + arguments=json.dumps({"a": "3", "b": "4"}), + ) + ) + ], + "", + ), + ( + [ + ("add", '{"a": 3.5, "b": 4}'), + ( + "get_current_weather", + '{"city": "San Francisco", "state": "CA", "unit": "celsius"}', # noqa: E501 + ), + ], + # [TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"} # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add", arguments=json.dumps({"a": 3.5, "b": 4}) + ) + ), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "San Francisco", "state": "CA", "unit": "celsius"} + ), + ) + ), + ], + "", + ), + ], +) +def test_extract_tool_calls_streaming( + mistral_tool_parser, + mistral_tokenizer, + tools, + expected_tool_calls, + expected_content, +): + _test_extract_tool_calls_streaming( + mistral_tool_parser, + mistral_tokenizer, + None, + tools, + expected_tool_calls, + expected_content, + ) + + +@pytest.mark.parametrize( + ids=[ + "single_tool_add", + "single_tool_weather", + "multiple_tool_calls", + "content_before_tool", + "complex", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ( + """[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add_this_and_that", + arguments=json.dumps({"a": 3.5, "b": 4}), + ) + ) + ], + "", + ), + ( + """[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "San Francisco", "state": "CA", "unit": "celsius"} + ), + ) + ) + ], + "", + ), + ( + """[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add", arguments=json.dumps({"a": 3.5, "b": 4}) + ) + ), + ToolCall( + function=FunctionCall( + name="multiply", arguments=json.dumps({"a": 3, "b": 6}) + ) + ), + ], + "", + ), + ( + # Additional content should not be after the tool calls + """bla[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add_this_and_that", + arguments=json.dumps({"a": 3.5, "b": 4}), + ) + ) + ], + "bla", + ), + ( + # Complex + """[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="bash", + arguments=json.dumps( + {"command": "print(\"hello world!\")\nre.compile(r'{}')"} + ), + ) + ) + ], + "", + ), + ], +) +def test_extract_tool_calls_streaming_one_chunk( + mistral_tool_parser, + mistral_tokenizer, + model_output, + expected_tool_calls, + expected_content, +): + if isinstance(mistral_tokenizer, MistralTokenizer): + all_token_ids = mistral_tokenizer.encode(model_output) + else: + all_token_ids = mistral_tokenizer.encode(model_output, add_special_tokens=False) + all_token_ids = fix_tool_call_tokenization( + all_token_ids, mistral_tool_parser, mistral_tokenizer + ) + + delta_message = mistral_tool_parser.extract_tool_calls_streaming( + previous_text="", + current_text=model_output, + delta_text=model_output, + previous_token_ids=[], + current_token_ids=all_token_ids, + delta_token_ids=all_token_ids, + request=None, + ) # type: ignore[arg-type] + assert isinstance(delta_message, DeltaMessage) + assert len(delta_message.tool_calls) == len(expected_tool_calls) + + assert_tool_calls(delta_message.tool_calls, expected_tool_calls) + + if delta_message.content is None: + assert expected_content == "" + else: + assert delta_message.content == expected_content + + +@pytest.mark.parametrize( + ids=[ + "no_tools", + "single_tool_add", + "single_tool_add_strings", + "single_tool_weather", + "argument_before_name", + "argument_before_name_and_name_in_argument", + "multiple_tools", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ("""This is a test""", [], """This is a test"""), + ( + """[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add", arguments=json.dumps({"a": 3, "b": 4}) + ) + ) + ], + "", + ), + ( + """[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add", arguments=json.dumps({"a": "3", "b": "4"}) + ) + ) + ], + "", + ), + ( + """[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "San Francisco", "state": "CA", "unit": "celsius"} + ), + ) + ) + ], + "", + ), + ( + """[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "San Francisco", "state": "CA", "unit": "celsius"} + ), + ) + ) + ], + "", + ), + ( + """[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_age", + arguments=json.dumps( + { + "name": "John Doe", + } + ), + ) + ) + ], + "", + ), + ( + """[TOOL_CALLS] [{"arguments": {"a": 3.5, "b": 4}, "name": "add"}, {"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add", arguments=json.dumps({"a": 3.5, "b": 4}) + ) + ), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "San Francisco", "state": "CA", "unit": "celsius"} + ), + ) + ), + ], + "", + ), + ], +) +def test_extract_tool_calls_streaming_pre_v11_tokenizer_one_chunk( + mistral_pre_v11_tool_parser, + mistral_pre_v11_tokenizer, + model_output, + expected_tool_calls, + expected_content, +): + if isinstance(mistral_pre_v11_tokenizer, MistralTokenizer): + all_token_ids = mistral_pre_v11_tokenizer.encode(model_output) + else: + all_token_ids = mistral_pre_v11_tokenizer.encode( + model_output, add_special_tokens=False + ) + all_token_ids = fix_tool_call_tokenization( + all_token_ids, mistral_pre_v11_tool_parser, mistral_pre_v11_tokenizer + ) + + delta_message = mistral_pre_v11_tool_parser.extract_tool_calls_streaming( + previous_text="", + current_text=model_output, + delta_text=model_output, + previous_token_ids=[], + current_token_ids=all_token_ids, + delta_token_ids=all_token_ids, + request=None, + ) # type: ignore[arg-type] + assert isinstance(delta_message, DeltaMessage) + assert len(delta_message.tool_calls) == len(expected_tool_calls) + + assert_tool_calls(delta_message.tool_calls, expected_tool_calls) + + if delta_message.content is None: + assert expected_content == "" + else: + assert delta_message.content == expected_content diff --git a/tests/tool_use/test_openai_tool_parser.py b/tests/tool_parsers/test_openai_tool_parser.py similarity index 98% rename from tests/tool_use/test_openai_tool_parser.py rename to tests/tool_parsers/test_openai_tool_parser.py index c874a9601ae70..44b8c92745e91 100644 --- a/tests/tool_use/test_openai_tool_parser.py +++ b/tests/tool_parsers/test_openai_tool_parser.py @@ -15,8 +15,8 @@ from openai_harmony import ( ) from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall -from vllm.entrypoints.openai.tool_parsers.openai_tool_parser import OpenAIToolParser -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.tokenizers import get_tokenizer +from vllm.tool_parsers.openai_tool_parser import OpenAIToolParser MODEL = "gpt2" diff --git a/tests/tool_use/test_qwen3coder_tool_parser.py b/tests/tool_parsers/test_qwen3coder_tool_parser.py similarity index 98% rename from tests/tool_use/test_qwen3coder_tool_parser.py rename to tests/tool_parsers/test_qwen3coder_tool_parser.py index 864bb0d0c06c2..3a0a612d7fbfd 100644 --- a/tests/tool_use/test_qwen3coder_tool_parser.py +++ b/tests/tool_parsers/test_qwen3coder_tool_parser.py @@ -13,15 +13,12 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import ( +from vllm.tokenizers import TokenizerLike, get_tokenizer +from vllm.tokenizers.detokenizer_utils import detokenize_incrementally +from vllm.tool_parsers.qwen3coder_tool_parser import ( Qwen3CoderToolParser, ) -from vllm.entrypoints.openai.tool_parsers.qwen3xml_tool_parser import Qwen3XMLToolParser -from vllm.tokenizers import TokenizerLike -from vllm.tokenizers.detokenizer_utils import detokenize_incrementally -from vllm.transformers_utils.tokenizer import get_tokenizer - -pytestmark = pytest.mark.cpu_test +from vllm.tool_parsers.qwen3xml_tool_parser import Qwen3XMLToolParser MODEL = "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8" diff --git a/tests/tool_use/test_seed_oss_tool_parser.py b/tests/tool_parsers/test_seed_oss_tool_parser.py similarity index 99% rename from tests/tool_use/test_seed_oss_tool_parser.py rename to tests/tool_parsers/test_seed_oss_tool_parser.py index d94df61128c9c..c7f595830f34b 100644 --- a/tests/tool_use/test_seed_oss_tool_parser.py +++ b/tests/tool_parsers/test_seed_oss_tool_parser.py @@ -14,12 +14,9 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.seed_oss_tool_parser import SeedOssToolParser -from vllm.tokenizers import TokenizerLike +from vllm.tokenizers import TokenizerLike, get_tokenizer from vllm.tokenizers.detokenizer_utils import detokenize_incrementally -from vllm.transformers_utils.tokenizer import get_tokenizer - -pytestmark = pytest.mark.cpu_test +from vllm.tool_parsers.seed_oss_tool_parser import SeedOssToolParser # Use a common model that is likely to be available MODEL = "ByteDance-Seed/Seed-OSS-36B-Instruct" diff --git a/tests/tool_use/test_xlam_tool_parser.py b/tests/tool_parsers/test_xlam_tool_parser.py similarity index 98% rename from tests/tool_use/test_xlam_tool_parser.py rename to tests/tool_parsers/test_xlam_tool_parser.py index fdcdd4038131a..380792a9926a4 100644 --- a/tests/tool_use/test_xlam_tool_parser.py +++ b/tests/tool_parsers/test_xlam_tool_parser.py @@ -12,12 +12,9 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.xlam_tool_parser import xLAMToolParser -from vllm.tokenizers import TokenizerLike +from vllm.tokenizers import TokenizerLike, get_tokenizer from vllm.tokenizers.detokenizer_utils import detokenize_incrementally -from vllm.transformers_utils.tokenizer import get_tokenizer - -pytestmark = pytest.mark.cpu_test +from vllm.tool_parsers.xlam_tool_parser import xLAMToolParser # Use a common model that is likely to be available MODEL = "Salesforce/Llama-xLAM-2-8B-fc-r" diff --git a/tests/tool_use/test_tool_choice_required.py b/tests/tool_use/test_tool_choice_required.py index d5572cfbebe3c..35ed8d215f73a 100644 --- a/tests/tool_use/test_tool_choice_required.py +++ b/tests/tool_use/test_tool_choice_required.py @@ -12,7 +12,7 @@ from vllm.entrypoints.openai.protocol import ( ChatCompletionToolsParam, ) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat -from vllm.entrypoints.openai.tool_parsers.utils import get_json_schema_from_tools +from vllm.tool_parsers.utils import get_json_schema_from_tools pytestmark = pytest.mark.cpu_test diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index 7584b903156b7..de7284a309c53 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -123,7 +123,7 @@ CONFIGS: dict[str, ServerConfig] = { "supports_parallel": True, "extended": True, }, - "mistral": { + "mistral-7b": { "model": "mistralai/Mistral-7B-Instruct-v0.3", "arguments": [ "--enforce-eager", @@ -145,6 +145,32 @@ CONFIGS: dict[str, ServerConfig] = { "call the tool. Otherwise, answer the user's query directly " "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " "to the user's question - just respond to it normally.", + "supports_parallel": True, + }, + "mistral-small-3.2": { + "model": "mistralai/Mistral-Small-3.2-24B-Instruct-2506", + "arguments": [ + "--enforce-eager", + "--no-enable-prefix-caching", + "--tool-call-parser", + "mistral", + "--tokenizer-mode", + "mistral", + "--config-format", + "mistral", + "--load-format", + "mistral", + "--tensor-parallel-size", + "4", + '--ignore-patterns="consolidated.safetensors"', + ], + "system_prompt": "You are a helpful assistant with access to tools. If a tool" + " that you have would be helpful to answer a user query, " + "call the tool. Otherwise, answer the user's query directly " + "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " + "to the user's question - just respond to it normally.", + "supports_parallel": True, + "extended": True, }, # FIXME: This test currently fails, need to debug why. # "granite20b": { diff --git a/tests/transformers_utils/test_config.py b/tests/transformers_utils/test_config.py index 7b56c9f0189d4..85680c41ed74d 100644 --- a/tests/transformers_utils/test_config.py +++ b/tests/transformers_utils/test_config.py @@ -6,8 +6,8 @@ only get the `eos_token_id` from the tokenizer as defined by `vllm.LLMEngine._get_eos_token_id`. """ +from vllm.tokenizers import get_tokenizer from vllm.transformers_utils.config import try_get_generation_config -from vllm.transformers_utils.tokenizer import get_tokenizer def test_get_llama3_eos_token(): diff --git a/tests/transformers_utils/test_utils.py b/tests/transformers_utils/test_utils.py index a8d0b9be9ec29..0a6a65b4133c9 100644 --- a/tests/transformers_utils/test_utils.py +++ b/tests/transformers_utils/test_utils.py @@ -5,13 +5,15 @@ from unittest.mock import patch import pytest +from vllm.transformers_utils.gguf_utils import ( + is_gguf, + is_remote_gguf, + split_remote_gguf, +) from vllm.transformers_utils.utils import ( is_cloud_storage, is_gcs, - is_gguf, - is_remote_gguf, is_s3, - split_remote_gguf, ) @@ -132,7 +134,7 @@ class TestSplitRemoteGGUF: class TestIsGGUF: """Test is_gguf utility function.""" - @patch("vllm.transformers_utils.utils.check_gguf_file", return_value=True) + @patch("vllm.transformers_utils.gguf_utils.check_gguf_file", return_value=True) def test_is_gguf_with_local_file(self, mock_check_gguf): """Test is_gguf with local GGUF file.""" assert is_gguf("/path/to/model.gguf") @@ -149,7 +151,7 @@ class TestIsGGUF: assert not is_gguf("repo/model:quant") assert not is_gguf("repo/model:INVALID") - @patch("vllm.transformers_utils.utils.check_gguf_file", return_value=False) + @patch("vllm.transformers_utils.gguf_utils.check_gguf_file", return_value=False) def test_is_gguf_false(self, mock_check_gguf): """Test is_gguf returns False for non-GGUF models.""" assert not is_gguf("unsloth/Qwen3-0.6B") diff --git a/tests/utils.py b/tests/utils.py index 9565b0ff06e36..d8102331b3612 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -44,7 +44,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.cli.serve import ServeSubcommand from vllm.model_executor.model_loader import get_model_loader from vllm.platforms import current_platform -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.tokenizers import get_tokenizer from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.mem_constants import GB_bytes from vllm.utils.network_utils import get_open_port @@ -119,7 +119,7 @@ class RemoteOpenAIServer: vllm_serve_args: list[str], *, env_dict: dict[str, str] | None = None, - seed: int | None = 0, + seed: int = 0, auto_port: bool = True, max_wait_seconds: float | None = None, override_hf_configs: dict[str, Any] | None = None, @@ -283,7 +283,7 @@ class RemoteOpenAIServerCustom(RemoteOpenAIServer): child_process_fxn: Callable[[dict[str, str] | None, str, list[str]], None], *, env_dict: dict[str, str] | None = None, - seed: int | None = 0, + seed: int = 0, auto_port: bool = True, max_wait_seconds: float | None = None, ) -> None: @@ -1225,9 +1225,9 @@ def get_attn_backend_list_based_on_platform() -> list[str]: try: import aiter # noqa: F401 - attn_backend_list.append("FLASH_ATTN") + attn_backend_list.append("ROCM_AITER_FA") except Exception: - print("Skip FLASH_ATTN on ROCm as aiter is not installed") + print("Skip ROCM_AITER_FA on ROCm as aiter is not installed") return attn_backend_list elif current_platform.is_xpu(): diff --git a/tests/utils_/test_argparse_utils.py b/tests/utils_/test_argparse_utils.py index 2d969b8c9347d..fbc278404e3f0 100644 --- a/tests/utils_/test_argparse_utils.py +++ b/tests/utils_/test_argparse_utils.py @@ -458,25 +458,3 @@ def test_flat_product(): (3, 4, "a", 5, 6), (3, 4, "b", 5, 6), ] - - -def test_o_legacy_syntax_deprecation(caplog_vllm): - """Test that -O.* dotted syntax emits warnings and converts correctly to -cc syntax.""" - parser = FlexibleArgumentParser() - parser.add_argument("-cc", "--compilation-config", type=json.loads) - - # Test that -O.backend gets converted correctly AND emits warning - args = parser.parse_args(["-O.backend=eager"]) - assert args.compilation_config == {"backend": "eager"} - - # Check that deprecation warning was logged - assert len(caplog_vllm.records) >= 1 - assert ( - "The -O.* dotted syntax for --compilation-config is deprecated" - in caplog_vllm.text - ) - - # Test that -O.mode gets converted correctly - # Note: warning_once won't emit again in same session - args = parser.parse_args(["-O.mode=2"]) - assert args.compilation_config == {"mode": 2} diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index b46002c5fa8ff..e7ec8380e0a84 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -74,6 +74,9 @@ BATCH_SPECS = { ), "large_decode": BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32), "large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8), + "mixed_large": BatchSpec( + seq_lens=[1024, 2048, 4096, 1024, 2048, 4096], query_lens=[1, 1, 1, 32, 32, 32] + ), "single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]), "single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]), } @@ -587,7 +590,14 @@ SLIDING_WINDOW_BACKENDS_TO_TEST = [ @pytest.mark.parametrize( "batch_spec_name", - ["small_decode", "small_prefill", "mixed_medium", "large_decode", "large_prefill"], + [ + "small_decode", + "small_prefill", + "mixed_medium", + "large_decode", + "large_prefill", + "mixed_large", + ], ) @pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"]) @pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4]) diff --git a/tests/v1/attention/test_attention_splitting.py b/tests/v1/attention/test_attention_splitting.py index 1cbd0fe56be6d..734819fcdca83 100644 --- a/tests/v1/attention/test_attention_splitting.py +++ b/tests/v1/attention/test_attention_splitting.py @@ -13,7 +13,7 @@ from vllm.v1.attention.backends.utils import ( split_attn_metadata, split_decodes_and_prefills, ) -from vllm.v1.worker.ubatch_utils import create_ubatch_slices +from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices @pytest.fixture @@ -154,7 +154,10 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata): def apply_split_decodes_and_prefills( - query_lens: list[int], decode_threshold: int, require_uniform: bool + query_lens: list[int], + decode_threshold: int, + require_uniform: bool, + padded_num_tokens: int | None = None, ): """Helper function to apply split_decodes_and_prefills and return the results.""" @@ -165,6 +168,10 @@ def apply_split_decodes_and_prefills( block_size=16, device=device, ) + + if padded_num_tokens is not None: + common_metadata.num_actual_tokens = padded_num_tokens + return split_decodes_and_prefills( common_metadata, decode_threshold=decode_threshold, @@ -271,6 +278,22 @@ def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes(): assert num_prefill_tokens == (sum(query_lens) - 2) # rest of the tokens +def test_split_decodes_and_prefills_uniform_padded_batch_all_same(): + """uniform batch where all query lengths are identical with 0 length padded reqs.""" + # All query lengths are 2, with decode_threshold=3 (so 2 <= 3) + # This triggers the padded uniform path at line 891 + query_lens = [2, 2, 2, 0] + padded_num_tokens = 8 + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + apply_split_decodes_and_prefills(query_lens, 3, True, padded_num_tokens) + ) + # With uniform batch, all requests are treated as decodes + assert num_decodes == 4 + assert num_prefills == 0 + assert num_decode_tokens == padded_num_tokens + assert num_prefill_tokens == 0 + + @pytest.mark.parametrize( "seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs", [ @@ -294,8 +317,15 @@ def test_prefill_split_across_ubatches( qsl_np = common.query_start_loc_cpu.numpy() num_tokens = common.num_actual_tokens - ubatch_slices = create_ubatch_slices(num_scheduled_tokens, split_point) - assert len(ubatch_slices) == 2 + ubatch_slices, _ = maybe_create_ubatch_slices( + True, + num_scheduled_tokens, + num_tokens, + batch_spec.batch_size, + split_point=split_point, + num_ubatches=2, + ) + assert ubatch_slices is not None and len(ubatch_slices) == 2 first_meta = _make_metadata_with_slice(ubatch_slices[0], common) second_meta = _make_metadata_with_slice(ubatch_slices[1], common) diff --git a/tests/v1/attention/test_sparse_mla_backends.py b/tests/v1/attention/test_sparse_mla_backends.py index b34d587eb362d..8049347280c5a 100644 --- a/tests/v1/attention/test_sparse_mla_backends.py +++ b/tests/v1/attention/test_sparse_mla_backends.py @@ -22,10 +22,14 @@ from tests.v1.attention.utils import ( ) from vllm import _custom_ops as ops from vllm.attention.ops import flashmla +from vllm.config import set_current_vllm_config from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.utils.math_utils import cdiv -from vllm.v1.attention.backends.mla.flashmla_sparse import FlashMLASparseBackend -from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks +from vllm.v1.attention.backends.mla.flashmla_sparse import ( + FlashMLASparseBackend, + triton_convert_req_index_to_global_index, +) +from vllm.v1.attention.backends.utils import split_prefill_chunks SPARSE_BACKEND_BATCH_SPECS = { name: BATCH_SPECS[name] @@ -114,8 +118,12 @@ def _quantize_dequantize_fp8_ds_mla( @pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys())) @pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"]) @pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4]) +@pytest.mark.skipif( + torch.cuda.get_device_capability() < (9, 0), + reason="FlashMLASparseBackend requires CUDA 9.0 or higher", +) def test_sparse_backend_decode_correctness( - dist_init, batch_name, kv_cache_dtype, tensor_parallel_size + dist_init, batch_name, kv_cache_dtype, tensor_parallel_size, workspace_init ): if not torch.cuda.is_available(): pytest.skip("CUDA is required for sparse MLA decode test") @@ -320,28 +328,29 @@ def test_sparse_backend_decode_correctness( mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous()) impl_cls = FlashMLASparseBackend.get_impl_cls() - impl = impl_cls( - num_heads=num_heads, - head_size=head_size, - scale=scale, - num_kv_heads=1, - alibi_slopes=None, - sliding_window=None, - kv_cache_dtype=vllm_config.cache_config.cache_dtype, - logits_soft_cap=None, - attn_type="decoder", - kv_sharing_target_layer_name=None, - q_lora_rank=None, - kv_lora_rank=kv_lora_rank, - qk_nope_head_dim=qk_nope_head_dim, - qk_rope_head_dim=qk_rope_head_dim, - qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, - v_head_dim=v_head_dim, - kv_b_proj=mock_kv_b_proj, - indexer=mock_indexer, - ) + with set_current_vllm_config(vllm_config): + impl = impl_cls( + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=1, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype=vllm_config.cache_config.cache_dtype, + logits_soft_cap=None, + attn_type="decoder", + kv_sharing_target_layer_name=None, + q_lora_rank=None, + kv_lora_rank=kv_lora_rank, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, + v_head_dim=v_head_dim, + kv_b_proj=mock_kv_b_proj, + indexer=mock_indexer, + ) - impl.process_weights_after_loading(dtype) + impl.process_weights_after_loading(dtype) layer = MockAttentionLayer(device) out_buffer = torch.empty( @@ -366,22 +375,192 @@ def test_sparse_backend_decode_correctness( torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.5, atol=0.5) +def _triton_convert_reference_impl( + req_ids: torch.Tensor, + block_table: torch.Tensor, + token_indices: torch.Tensor, + block_size: int, + num_topk_tokens: int, + HAS_PREFILL_WORKSPACE: bool = False, + prefill_workspace_request_ids: torch.Tensor | None = None, + prefill_workspace_starts: torch.Tensor | None = None, +) -> torch.Tensor: + """Reference implementation for triton_convert_req_index_to_global_index.""" + num_tokens = req_ids.shape[0] + max_blocks_per_req = block_table.shape[1] + result = torch.empty( + num_tokens, num_topk_tokens, dtype=torch.int32, device=req_ids.device + ) + + for token_id in range(num_tokens): + req_id = req_ids[token_id].item() + + # Determine if this token uses workspace or paged cache + use_prefill_workspace = False + workspace_start = 0 + if HAS_PREFILL_WORKSPACE and prefill_workspace_request_ids is not None: + assert prefill_workspace_starts is not None + prefill_req_id = prefill_workspace_request_ids[token_id].item() + if prefill_req_id >= 0: + use_prefill_workspace = True + workspace_start = prefill_workspace_starts[prefill_req_id].item() + + for idx_id in range(num_topk_tokens): + token_idx = token_indices[token_id, idx_id].item() + + if token_idx == -1: + result[token_id, idx_id] = -1 + elif use_prefill_workspace: + # Prefill + using prefill workspace: map to workspace offset + result[token_id, idx_id] = workspace_start + token_idx + else: + # Decode: map to paged cache + block_id = token_idx // block_size + if block_id >= max_blocks_per_req: + result[token_id, idx_id] = -1 + else: + block_num = block_table[req_id, block_id].item() + offset = token_idx % block_size + result[token_id, idx_id] = block_num * block_size + offset + + return result + + +@pytest.mark.parametrize("block_size", [16, 64, 128]) +@pytest.mark.parametrize("num_topk_tokens", [128, 256, 512]) +@pytest.mark.skipif( + torch.cuda.get_device_capability() < (9, 0), + reason="FlashMLASparseBackend requires CUDA 9.0 or higher", +) +def test_triton_convert_req_index_to_global_index_decode_only( + block_size, num_topk_tokens +): + device = torch.device("cuda") + num_tokens = 8 + num_requests = 4 + max_blocks_per_req = 10 + + req_id = torch.randint( + 0, num_requests, (num_tokens,), dtype=torch.int32, device=device + ) + block_table = torch.randint( + 0, 100, (num_requests, max_blocks_per_req), dtype=torch.int32, device=device + ) + + token_indices = torch.randint( + 0, + block_size * max_blocks_per_req, + (num_tokens, num_topk_tokens), + dtype=torch.int32, + device=device, + ) + + # Set some to -1 to test masking + token_indices[0, :10] = -1 + token_indices[3, 50:60] = -1 + + # Set some to out of bounds + token_indices[2, 100:110] = max_blocks_per_req * block_size + token_indices[6, 150:160] = max_blocks_per_req * block_size + + result = triton_convert_req_index_to_global_index( + req_id, + block_table, + token_indices, + BLOCK_SIZE=block_size, + NUM_TOPK_TOKENS=num_topk_tokens, + ) + + reference_result = _triton_convert_reference_impl( + req_id, + block_table, + token_indices, + block_size, + num_topk_tokens, + ) + + torch.testing.assert_close(result, reference_result, rtol=0, atol=0) + + +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.skipif( + torch.cuda.get_device_capability() < (9, 0), + reason="FlashMLASparseBackend requires CUDA 9.0 or higher", +) +def test_triton_convert_req_index_to_global_index_with_prefill_workspace(block_size): + device = torch.device("cuda") + num_requests = 4 + max_blocks_per_req = 8 + num_topk_tokens = 128 + + # First 6 tokens are decode (reqs 0, 1), last 6 are prefill (reqs 2, 3) + req_id = torch.tensor( + [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], dtype=torch.int32, device=device + ) + prefill_workspace_request_ids = torch.tensor( + [-1, -1, -1, -1, -1, -1, 0, 0, 0, 1, 1, 1], dtype=torch.int32, device=device + ) + + # Workspace starts for the 2 prefill reqs: req 2 starts at 0, req 3 starts at 100 + prefill_workspace_starts = torch.tensor([0, 100], dtype=torch.int32, device=device) + + block_table = torch.randint( + 0, 50, (num_requests, max_blocks_per_req), dtype=torch.int32, device=device + ) + token_indices = torch.randint( + 0, + block_size * max_blocks_per_req, + (req_id.shape[0], num_topk_tokens), + dtype=torch.int32, + device=device, + ) + + # Set some to -1 to test masking + token_indices[0, :10] = -1 + token_indices[3, 50:60] = -1 + + # Set some to out of bounds + token_indices[2, 100:110] = max_blocks_per_req * block_size + token_indices[6, 150:160] = max_blocks_per_req * block_size + + result = triton_convert_req_index_to_global_index( + req_id, + block_table, + token_indices, + BLOCK_SIZE=block_size, + NUM_TOPK_TOKENS=num_topk_tokens, + HAS_PREFILL_WORKSPACE=True, + prefill_workspace_request_ids=prefill_workspace_request_ids, + prefill_workspace_starts=prefill_workspace_starts, + ) + + reference_result = _triton_convert_reference_impl( + req_id, + block_table, + token_indices, + block_size, + num_topk_tokens, + HAS_PREFILL_WORKSPACE=True, + prefill_workspace_request_ids=prefill_workspace_request_ids, + prefill_workspace_starts=prefill_workspace_starts, + ) + + torch.testing.assert_close(result, reference_result, rtol=0, atol=0) + + @pytest.mark.parametrize( - "seq_lens,max_buf,start,expected", + "seq_lens,max_buf,expected", [ # Basic split: totals per chunk ≤ max_buf - (torch.tensor([2, 3, 4, 2]), 5, 0, [(0, 2), (2, 3), (3, 4)]), - # Non-zero start index - (torch.tensor([2, 3, 4, 2]), 5, 1, [(1, 2), (2, 3), (3, 4)]), - # Exact fits should split between items when adding the next would - # overflow - (torch.tensor([5, 5, 5]), 5, 0, [(0, 1), (1, 2), (2, 3)]), + (torch.tensor([2, 3, 4, 2]), 5, [(0, 2), (2, 3), (3, 4)]), + # Exact fits should split between items when adding the next would overflow + (torch.tensor([5, 5, 5]), 5, [(0, 1), (1, 2), (2, 3)]), # All requests fit in a single chunk - (torch.tensor([1, 1, 1]), 10, 0, [(0, 3)]), - # Large buffer with non-zero start - (torch.tensor([4, 4, 4]), 100, 1, [(1, 3)]), + (torch.tensor([1, 1, 1]), 10, [(0, 3)]), + # Large buffer + (torch.tensor([4, 4, 4]), 100, [(0, 3)]), ], ) -def test_split_prefill_chunks(seq_lens, max_buf, start, expected): - out = split_prefill_chunks(seq_lens, max_buf, start) +def test_split_prefill_chunks(seq_lens, max_buf, expected): + out = split_prefill_chunks(seq_lens, max_buf) assert out == expected diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index df3d53332c7cd..4dcaf9d908690 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -106,8 +106,8 @@ def create_common_attn_metadata( query_start_loc=query_start_loc, query_start_loc_cpu=query_start_loc_cpu, seq_lens=seq_lens, - seq_lens_cpu=seq_lens_cpu, - num_computed_tokens_cpu=num_computed_tokens_cpu, + _seq_lens_cpu=seq_lens_cpu, + _num_computed_tokens_cpu=num_computed_tokens_cpu, num_reqs=batch_spec.batch_size, num_actual_tokens=num_tokens, max_query_len=max_query_len, @@ -185,6 +185,8 @@ def create_vllm_config( max_num_seqs=max_num_seqs, max_num_batched_tokens=max_num_batched_tokens, enable_chunked_prefill=enable_chunked_prefill, + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, ) device_config = DeviceConfig() diff --git a/tests/v1/core/test_kv_cache_metrics.py b/tests/v1/core/test_kv_cache_metrics.py new file mode 100644 index 0000000000000..9e16aa64ab6af --- /dev/null +++ b/tests/v1/core/test_kv_cache_metrics.py @@ -0,0 +1,224 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from unittest.mock import patch + +import pytest + +from vllm.v1.core.kv_cache_metrics import ( + BlockMetricsState, + KVCacheMetricsCollector, +) +from vllm.v1.core.kv_cache_utils import KVCacheBlock + + +class TestBlockMetricsState: + def test_init(self): + with patch("time.monotonic_ns", return_value=1000000000): + state = BlockMetricsState() + assert state.birth_time_ns == 1000000000 + assert state.last_access_ns == 1000000000 + assert len(state.access_history) == 0 + + def test_access_tracking(self): + with patch("time.monotonic_ns", return_value=1000000000): + state = BlockMetricsState() + + with patch("time.monotonic_ns", return_value=2000000000): + state.record_access() + + assert state.last_access_ns == 2000000000 + assert list(state.access_history) == [2000000000] + + def test_ring_buffer_wraps_at_4(self): + with patch("time.monotonic_ns", return_value=1000000000): + state = BlockMetricsState() + + for i in range(5): + t = 1000000000 + (i + 1) * 1000000000 + with patch("time.monotonic_ns", return_value=t): + state.record_access() + + assert len(state.access_history) == 4 + assert list(state.access_history) == [ + 3000000000, + 4000000000, + 5000000000, + 6000000000, + ] + + def test_lifetime(self): + with patch("time.monotonic_ns", return_value=1000000000): + state = BlockMetricsState() + with patch("time.monotonic_ns", return_value=6500000000): + assert abs(state.get_lifetime_seconds() - 5.5) < 0.001 + + def test_idle_time(self): + with patch("time.monotonic_ns", return_value=1000000000): + state = BlockMetricsState() + state.last_access_ns = 2000000000 + with patch("time.monotonic_ns", return_value=5200000000): + assert abs(state.get_idle_time_seconds() - 3.2) < 0.001 + + def test_reuse_gaps(self): + with patch("time.monotonic_ns", return_value=1000000000): + state = BlockMetricsState() + + base = 1000000000 + for offset in [0, 1.5, 3.0, 5.5]: + state.access_history.append(base + int(offset * 1e9)) + + gaps = state.get_reuse_gaps_seconds() + assert len(gaps) == 3 + assert gaps[0] == 1.5 and gaps[1] == 1.5 and gaps[2] == 2.5 + + def test_ring_wrap_only_gives_3_gaps(self): + # 5 accesses in size-4 buffer = 3 gaps + with patch("time.monotonic_ns", return_value=1000000000): + state = BlockMetricsState() + + for i in range(5): + state.access_history.append(1000000000 + i * 1000000000) + + assert len(state.get_reuse_gaps_seconds()) == 3 + + +class TestKVCacheMetricsCollector: + def test_sample_rate_validation(self): + with pytest.raises(AssertionError): + KVCacheMetricsCollector(sample_rate=-0.1) + with pytest.raises(AssertionError): + KVCacheMetricsCollector(sample_rate=1.5) + with pytest.raises(AssertionError): + KVCacheMetricsCollector(sample_rate=0.0) + + def test_sampling(self): + c = KVCacheMetricsCollector(sample_rate=1.0) + assert sum(1 for _ in range(100) if c.should_sample_block()) == 100 + + c = KVCacheMetricsCollector(sample_rate=0.5) + samples = sum(1 for _ in range(1000) if c.should_sample_block()) + assert 400 < samples < 600 + + def test_alloc(self): + c = KVCacheMetricsCollector(sample_rate=1.0) + + blocks = [KVCacheBlock(block_id=i) for i in range(5)] + with patch("time.monotonic_ns", return_value=1000000000): + for block in blocks: + c.on_block_allocated(block) + + assert len(c.block_metrics) == 5 + + def test_access(self): + c = KVCacheMetricsCollector(sample_rate=1.0) + block = KVCacheBlock(block_id=0) + + with patch("time.monotonic_ns", return_value=1000000000): + c.on_block_allocated(block) + + for i in range(3): + t = 1000000000 + (i + 1) * 1000000000 + with patch("time.monotonic_ns", return_value=t): + c.on_block_accessed(block) + + assert len(c.block_metrics[0].access_history) == 3 + + def test_evict_no_accesses(self): + # lifetime should equal idle if never accessed + c = KVCacheMetricsCollector(sample_rate=1.0) + + block = KVCacheBlock(block_id=0) + with patch("time.monotonic_ns", return_value=1000000000): + c.on_block_allocated(block) + + with patch("time.monotonic_ns", return_value=6000000000): + c.on_block_evicted(block) + + events = c.drain_events() + assert len(events) == 1 + assert abs(events[0].lifetime_seconds - 5.0) < 0.001 + assert abs(events[0].idle_seconds - 5.0) < 0.001 + + def test_evict(self): + c = KVCacheMetricsCollector(sample_rate=1.0) + + block = KVCacheBlock(block_id=0) + with patch("time.monotonic_ns", return_value=1000000000): + c.on_block_allocated(block) + + with patch("time.monotonic_ns", return_value=2000000000): + c.on_block_accessed(block) + with patch("time.monotonic_ns", return_value=3000000000): + c.on_block_accessed(block) + + with patch("time.monotonic_ns", return_value=4000000000): + c.on_block_evicted(block) + + events = c.drain_events() + assert len(events) == 1 + sample = events[0] + assert abs(sample.lifetime_seconds - 3.0) < 0.001 + assert abs(sample.idle_seconds - 1.0) < 0.001 + assert sample.reuse_gaps_seconds == (1.0,) + assert 0 not in c.block_metrics + + def test_reset(self): + c = KVCacheMetricsCollector(sample_rate=1.0) + + with patch("time.monotonic_ns", return_value=1000000000): + for i in range(5): + c.on_block_allocated(KVCacheBlock(block_id=i)) + + assert len(c.block_metrics) == 5 + c.reset() + assert len(c.block_metrics) == 0 + + with patch("time.monotonic_ns", return_value=2000000000): + c.on_block_allocated(KVCacheBlock(block_id=10)) + assert 10 in c.block_metrics + + def test_huge_time_jump(self): + c = KVCacheMetricsCollector(sample_rate=1.0) + + block = KVCacheBlock(block_id=0) + with patch("time.monotonic_ns", return_value=1000000000): + c.on_block_allocated(block) + + with patch("time.monotonic_ns", return_value=9999999999999999): + c.on_block_evicted(block) + + events = c.drain_events() + assert len(events) == 1 + assert events[0].lifetime_seconds > 0 + + +def test_kv_cache_metrics_collector_smoke() -> None: + """Simple smoke test for KVCacheMetricsCollector on CPU.""" + collector = KVCacheMetricsCollector(sample_rate=1.0) + block = KVCacheBlock(block_id=123) + + # Allocate at t = 1.0s. + with patch("time.monotonic_ns", return_value=1_000_000_000): + collector.on_block_allocated(block) + + # Access at t = 2.0s and t = 3.0s. + with patch("time.monotonic_ns", return_value=2_000_000_000): + collector.on_block_accessed(block) + with patch("time.monotonic_ns", return_value=3_000_000_000): + collector.on_block_accessed(block) + + # Evict at t = 4.0s. + with patch("time.monotonic_ns", return_value=4_000_000_000): + collector.on_block_evicted(block) + + events = collector.drain_events() + assert len(events) == 1 + + event = events[0] + # Lifetime: 1.0s → 4.0s. + assert abs(event.lifetime_seconds - 3.0) < 1e-6 + # Idle: last access at 3.0s, evicted at 4.0s. + assert abs(event.idle_seconds - 1.0) < 1e-6 + # One reuse gap between the two accesses. + assert event.reuse_gaps_seconds == (1.0,) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 58a7a2692bfc8..fd5cf6d3e74aa 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -1128,7 +1128,11 @@ def test_estimate_max_model_len(model_id, max_model_len, want_estimated_max_len) dtype="float16", max_model_len=max_model_len, ) - scheduler_config = SchedulerConfig(max_num_batched_tokens=32768) + scheduler_config = SchedulerConfig( + max_num_batched_tokens=32768, + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, + ) vllm_config = VllmConfig( model_config=model_config, @@ -1163,7 +1167,10 @@ def test_get_max_concurrency_for_kv_cache_config(): max_model_len=max_model_len, ) scheduler_config = SchedulerConfig( - max_num_batched_tokens=1024, enable_chunked_prefill=True + max_num_batched_tokens=1024, + enable_chunked_prefill=True, + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, ) vllm_config = VllmConfig( diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 64fd5ab1dd9aa..0880a17c78d40 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -45,7 +45,7 @@ pytestmark = pytest.mark.cpu_test def _auto_init_hash_fn(request): hash_fn: Callable if "hash_fn" in request.fixturenames: - hash_fn = init_none_hash(request.getfixturevalue("hash_fn")) + hash_fn = request.getfixturevalue("hash_fn") else: hash_fn = sha256 init_none_hash(hash_fn) diff --git a/tests/v1/core/test_priority_scheduler_random.py b/tests/v1/core/test_priority_scheduler_random.py index b4805be802723..429b179b61dce 100644 --- a/tests/v1/core/test_priority_scheduler_random.py +++ b/tests/v1/core/test_priority_scheduler_random.py @@ -219,7 +219,17 @@ def test_priority_scheduling_blast( vllm_config=scheduler.vllm_config, ) scheduler.add_request(req) - + num_initial_requests = 2 + for _ in range(num_initial_requests): + req = _create_random_request( + max_tokens_range=(1, max_output_tokens), + num_tokens_range=(1, max_input_tokens), + arrival_time_range=(0, 0), + priority_range=(4, 4), + num_mm_item_range=(0, 2), + vllm_config=scheduler.vllm_config, + ) + scheduler.add_request(req) for _ in range(20000): if len(scheduler.waiting) == 0: num_new_requests = random.randint(0, 2) diff --git a/tests/v1/core/test_reset_prefix_cache_e2e.py b/tests/v1/core/test_reset_prefix_cache_e2e.py new file mode 100644 index 0000000000000..b80789945d2fc --- /dev/null +++ b/tests/v1/core/test_reset_prefix_cache_e2e.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm import EngineArgs, LLMEngine, SamplingParams + +PROMPTS = [ + "A robot may not injure a human being ", + "To be or not to be,", + "What is the meaning of life?", + "What does the fox say? " * 20, # Test long prompt +] + + +def test_reset_prefix_cache_e2e(monkeypatch): + # "spawn" is required for test to be deterministic + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + engine_args = EngineArgs( + model="Qwen/Qwen3-0.6B", + gpu_memory_utilization=0.2, + async_scheduling=True, + max_num_batched_tokens=32, + max_model_len=2048, + compilation_config={"mode": 0}, + dtype="float16", + ) + engine = LLMEngine.from_engine_args(engine_args) + sampling_params = SamplingParams( + temperature=0.0, + max_tokens=16, + ) + + # No preempt case: + for i, prompt in enumerate(PROMPTS): + engine.add_request("ground_truth_" + str(i), prompt, sampling_params) + + ground_truth_results = {} + while engine.has_unfinished_requests(): + request_outputs = engine.step() + for request_output in request_outputs: + if request_output.finished: + ground_truth_results[request_output.request_id] = request_output + + # Preempt case: + for i, prompt in enumerate(PROMPTS): + engine.add_request("preempted_" + str(i), prompt, sampling_params) + + step_id = 0 + preempted_results = {} + while engine.has_unfinished_requests(): + if step_id == 10: + engine.reset_prefix_cache(reset_running_requests=True) + + request_outputs = engine.step() + + for request_output in request_outputs: + if request_output.finished: + preempted_results[request_output.request_id] = request_output + step_id += 1 + + for i in range(len(PROMPTS)): + assert ( + ground_truth_results["ground_truth_" + str(i)].outputs[0].text + == preempted_results["preempted_" + str(i)].outputs[0].text + ), ( + f"ground_truth_results['ground_truth_{i}'].outputs[0].text=" + f"{ground_truth_results['ground_truth_' + str(i)].outputs[0].text} " + f"preempted_results['preempted_{i}'].outputs[0].text=" + f"{preempted_results['preempted_' + str(i)].outputs[0].text}" + ) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index fe4153e609971..1999e9f6c3b99 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -728,6 +728,37 @@ def test_preempt_during_execution(): assert requests[1].output_token_ids[0] == 42 +def test_scheduler_reset_prefix_cache(): + scheduler = create_scheduler(enable_prefix_caching=True) + requests = create_requests(num_requests=10) + for request in requests: + scheduler.add_request(request) + + # Initial scheduling, requests should be at the running state now + _ = scheduler.schedule() + + # Verify requests moved from waiting to running + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == len(requests) + for i, request in enumerate(requests): + assert scheduler.running[i] == request + + # Reset prefix cache should fail since there are still running requests + # and they are taking KV cache + assert not scheduler.reset_prefix_cache() + + # Reset prefix cache with reset_running_requests=True. All running requests + # Should be pushed back to the waiting queue and kv cache should be freed + assert scheduler.reset_prefix_cache(reset_running_requests=True) + + # Verify requests moved from running to waiting + assert len(scheduler.waiting) == len(requests) + assert len(scheduler.running) == 0 + + for i, request in enumerate(requests): + assert scheduler.waiting[i] == request + + # Note - these test cases mirror some of those in test_rejection_sampler.py @pytest.mark.parametrize( "spec_tokens,output_tokens,expected", @@ -1477,6 +1508,12 @@ def create_scheduler_with_priority( Returns: {class}`Scheduler` instance with priority scheduling """ + model_config = ModelConfig( + model=model, + trust_remote_code=True, + dtype="float16", + seed=42, + ) if max_model_len is None: max_model_len = max_num_batched_tokens scheduler_config = SchedulerConfig( @@ -1486,14 +1523,9 @@ def create_scheduler_with_priority( long_prefill_token_threshold=long_prefill_token_threshold, disable_chunked_mm_input=disable_chunked_mm_input, enable_chunked_prefill=True, + is_encoder_decoder=model_config.is_encoder_decoder, policy="priority", # Enable priority scheduling ) - model_config = ModelConfig( - model=model, - trust_remote_code=True, - dtype="float16", - seed=42, - ) # Cache config, optionally force APC cache_config = CacheConfig( block_size=block_size, @@ -1504,7 +1536,7 @@ def create_scheduler_with_priority( ) kv_transfer_config = ( KVTransferConfig( - kv_connector="SharedStorageConnector", + kv_connector="ExampleConnector", kv_role="kv_both", kv_connector_extra_config={"shared_storage_path": "local_storage"}, ) @@ -1520,7 +1552,7 @@ def create_scheduler_with_priority( ec_transfer_config = ( ECTransferConfig( - ec_connector="ECSharedStorageConnector", + ec_connector="ECExampleConnector", ec_role=ec_role, ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test"}, ) @@ -2381,7 +2413,7 @@ def _assert_right_ec_connector_metadata( metadata_dict = {mm_data.mm_hash: mm_data for mm_data in metadata.mm_datas} # Check all required identifiers exist in metadata; and no extra - # In ECSharedStorageConnector format + # In ECExampleConnector format # NOTE: even having same identifier, the mm_features can be different # since their mm_position can be in different offsets, etc identifiers_dict = {f.identifier for f in mm_features_list} diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py index 7537c7a60476b..531b9c595b04d 100644 --- a/tests/v1/core/utils.py +++ b/tests/v1/core/utils.py @@ -69,6 +69,13 @@ def create_scheduler( Returns: {class}`Scheduler` instance """ + model_config = ModelConfig( + model=model, + trust_remote_code=True, + dtype="float16", + seed=42, + skip_tokenizer_init=skip_tokenizer_init, + ) if max_model_len is None: max_model_len = max_num_batched_tokens scheduler_config = SchedulerConfig( @@ -79,13 +86,7 @@ def create_scheduler( disable_chunked_mm_input=disable_chunked_mm_input, enable_chunked_prefill=enable_chunked_prefill, async_scheduling=async_scheduling, - ) - model_config = ModelConfig( - model=model, - trust_remote_code=True, - dtype="float16", - seed=42, - skip_tokenizer_init=skip_tokenizer_init, + is_encoder_decoder=model_config.is_encoder_decoder, ) # Cache config, optionally force APC cache_config = CacheConfig( @@ -107,7 +108,7 @@ def create_scheduler( ) elif use_kv_connector: kv_transfer_config = KVTransferConfig( - kv_connector="SharedStorageConnector", + kv_connector="ExampleConnector", kv_role="kv_both", kv_connector_extra_config={"shared_storage_path": "local_storage"}, ) @@ -120,7 +121,7 @@ def create_scheduler( ec_transfer_config = ( ECTransferConfig( - ec_connector="ECSharedStorageConnector", + ec_connector="ECExampleConnector", ec_role=ec_role, ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test"}, ) diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py index 314e7094ef97f..0e71d6c63ce68 100644 --- a/tests/v1/cudagraph/test_cudagraph_dispatch.py +++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py @@ -40,7 +40,9 @@ def _create_vllm_config( ) -> MagicMock: mock_config = MagicMock(spec=VllmConfig) mock_config.compilation_config = compilation_config - mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs) + mock_config.scheduler_config = SchedulerConfig.default_factory( + max_num_seqs=max_num_seqs, + ) mock_config.parallel_config = ParallelConfig() mock_config.speculative_config = None # No speculative decoding if not lora_config: @@ -159,10 +161,10 @@ class TestCudagraphDispatcher: assert rt_mode == CUDAGraphMode.NONE assert key == BatchDescriptor(num_tokens=15) - # 4. Cascade attention should have a fall back mode + # 4. disable_full should have a fall back mode (e.g., cascade attention) desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False) rt_mode, key = dispatcher.dispatch( - num_tokens=8, uniform_decode=False, has_lora=False, use_cascade_attn=True + num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True ) if "PIECEWISE" in cudagraph_mode_str: # string contains check assert rt_mode == CUDAGraphMode.PIECEWISE diff --git a/tests/v1/cudagraph/test_cudagraph_mode.py b/tests/v1/cudagraph/test_cudagraph_mode.py index 12621d493e549..b1895e83b8b37 100644 --- a/tests/v1/cudagraph/test_cudagraph_mode.py +++ b/tests/v1/cudagraph/test_cudagraph_mode.py @@ -100,32 +100,20 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte # test cudagraph_mode with different compilation mode. # (backend_name, cudagraph_mode, compilation_mode, supported) -if current_platform.is_rocm(): - combo_cases_2 = [ - ("RocmAttn", "FULL", CompilationMode.NONE, True), - ("RocmAttn", "FULL", CompilationMode.VLLM_COMPILE, True), - ("RocmAttn", "PIECEWISE", CompilationMode.NONE, False), - ("RocmAttn", "PIECEWISE", CompilationMode.VLLM_COMPILE, True), - ("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.NONE, False), - ("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True), - ("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.NONE, True), - ("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True), - ("RocmAttn", "NONE", CompilationMode.NONE, True), - ("RocmAttn", "NONE", CompilationMode.VLLM_COMPILE, True), - ] -else: - combo_cases_2 = [ - ("FA2", "FULL", CompilationMode.NONE, True), - ("FA2", "FULL", CompilationMode.VLLM_COMPILE, True), - ("FA2", "PIECEWISE", CompilationMode.NONE, True), - ("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True), - ("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, True), - ("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True), - ("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True), - ("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True), - ("FA2", "NONE", CompilationMode.NONE, True), - ("FA2", "NONE", CompilationMode.VLLM_COMPILE, True), - ] +attn_backend = "RocmAttn" if current_platform.is_rocm() else "FA2" + +combo_cases_2 = [ + (attn_backend, "FULL", CompilationMode.NONE, True), + (attn_backend, "FULL", CompilationMode.VLLM_COMPILE, True), + (attn_backend, "PIECEWISE", CompilationMode.NONE, True), + (attn_backend, "PIECEWISE", CompilationMode.VLLM_COMPILE, True), + (attn_backend, "FULL_AND_PIECEWISE", CompilationMode.NONE, True), + (attn_backend, "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True), + (attn_backend, "FULL_DECODE_ONLY", CompilationMode.NONE, True), + (attn_backend, "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True), + (attn_backend, "NONE", CompilationMode.NONE, True), + (attn_backend, "NONE", CompilationMode.VLLM_COMPILE, True), +] @pytest.mark.parametrize( diff --git a/tests/v1/determinism/test_batch_invariance.py b/tests/v1/determinism/test_batch_invariance.py index 4311547baccf1..1c45e7fe366ff 100644 --- a/tests/v1/determinism/test_batch_invariance.py +++ b/tests/v1/determinism/test_batch_invariance.py @@ -10,6 +10,7 @@ from utils import ( BACKENDS, _extract_step_logprobs, _random_prompt, + is_device_capability_below_90, resolve_model_name, skip_unsupported, ) @@ -17,6 +18,8 @@ from utils import ( import vllm.model_executor.layers.batch_invariant as batch_invariant from vllm import LLM, SamplingParams +IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90() + @skip_unsupported @pytest.mark.timeout(1000) @@ -185,11 +188,12 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( llm = LLM( model=model_name, tensor_parallel_size=tp_size, - enable_prefix_caching=False, + # enable_prefix_caching=False, max_num_seqs=32, max_model_len=8192, dtype="bfloat16", # not everything is supported gpu_memory_utilization=0.9, + enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, ) # Use more realistic prompts for better token generation @@ -394,6 +398,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch): max_model_len=2048, dtype="bfloat16", enable_prefix_caching=False, + enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, ) prompt = "the capital of france is" @@ -457,10 +462,10 @@ def test_logprobs_without_batch_invariance_should_fail( llm = LLM( model=model_name, tensor_parallel_size=tp_size, - enable_prefix_caching=False, max_num_seqs=32, max_model_len=8192, dtype="bfloat16", + enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, ) # build ragged prompts to change shapes significantly across BS=1 vs BS=N @@ -681,10 +686,10 @@ def test_decode_logprobs_match_prefill_logprobs( llm = LLM( model=model_name, tensor_parallel_size=tp_size, - enable_prefix_caching=False, max_num_seqs=32, max_model_len=8192, dtype="bfloat16", + enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, ) # Use a few test prompts @@ -929,6 +934,7 @@ def LLM_with_max_seqs( dtype="bfloat16", tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")), enable_prefix_caching=False, + enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90, # Enable for MOE models # enable_expert_parallel=True, ) diff --git a/tests/v1/determinism/test_online_batch_invariance.py b/tests/v1/determinism/test_online_batch_invariance.py index d74b435797f8f..5e3b997364949 100644 --- a/tests/v1/determinism/test_online_batch_invariance.py +++ b/tests/v1/determinism/test_online_batch_invariance.py @@ -153,7 +153,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( } tp_size = os.getenv("VLLM_TP_SIZE", "1") - server_args: list[str] = [] + server_args: list[str] = [ + "--max-model-len=8192", + "--max-num-seqs=32", + ] if tp_size: server_args += ["-tp", tp_size] diff --git a/tests/v1/determinism/utils.py b/tests/v1/determinism/utils.py index 0d7da107728b4..a8013ed229cfc 100644 --- a/tests/v1/determinism/utils.py +++ b/tests/v1/determinism/utils.py @@ -11,12 +11,15 @@ from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer skip_unsupported = pytest.mark.skipif( - not (current_platform.is_cuda() and current_platform.has_device_capability(90)), - reason="Requires CUDA and >= Hopper (SM90)", + not (current_platform.is_cuda() and current_platform.has_device_capability(80)), + # Supports testing on Ampere and Ada Lovelace devices. + # Note: For devices with SM < 90, batch invariance does not support CUDA Graphs. + reason="Requires CUDA and >= Ampere (SM80)", ) BACKENDS: list[str] = [ "FLASH_ATTN", + "TRITON_MLA", ] if has_flashinfer(): @@ -96,3 +99,7 @@ def _extract_step_logprobs(request_output): return t, inner.token_ids return None, None + + +def is_device_capability_below_90() -> bool: + return not current_platform.has_device_capability(90) diff --git a/tests/v1/distributed/test_dbo.py b/tests/v1/distributed/test_dbo.py index 16f154d196ba5..e5cbe1ce85e96 100644 --- a/tests/v1/distributed/test_dbo.py +++ b/tests/v1/distributed/test_dbo.py @@ -9,9 +9,22 @@ correctly with the DeepSeek-V2-Lite model using GSM8K evaluation. """ import pytest +import torch from tests.evals.gsm8k.gsm8k_eval import evaluate_gsm8k from tests.utils import RemoteOpenAIServer +from vllm.utils.import_utils import has_deep_ep + +# Detect Blackwell / B200 (compute capability 10.x) +try: + if torch.cuda.is_available(): + cap = torch.cuda.get_device_capability(0) + IS_BLACKWELL = cap[0] >= 10 + else: + IS_BLACKWELL = False +except Exception: + # Be conservative: if we can't detect, don't xfail by default + IS_BLACKWELL = False MODEL_NAME = "deepseek-ai/DeepSeek-V2-Lite-Chat" DP_SIZE = 2 @@ -32,7 +45,15 @@ DEEPEP_BACKENDS = [ ] +@pytest.mark.skipif(not has_deep_ep(), reason="These tests require deep_ep to run") @pytest.mark.parametrize("all2all_backend", DEEPEP_BACKENDS) +@pytest.mark.xfail( + IS_BLACKWELL, + reason=( + "Temporary: DBO accuracy unstable on Blackwell " + "(doesn't meet expectation of MIN_ACCURACY = 0.62)" + ), +) def test_dbo_dp_ep_gsm8k(all2all_backend: str, num_gpus_available): """ Test DBO with DP+EP using GSM8K evaluation. diff --git a/tests/v1/e2e/test_async_scheduling.py b/tests/v1/e2e/test_async_scheduling.py index 945276376d665..5cef9b33c9984 100644 --- a/tests/v1/e2e/test_async_scheduling.py +++ b/tests/v1/e2e/test_async_scheduling.py @@ -8,6 +8,7 @@ import torch._dynamo.config as dynamo_config from vllm import SamplingParams from vllm.logprobs import Logprob +from vllm.platforms import current_platform from vllm.sampling_params import StructuredOutputsParams from vllm.v1.metrics.reader import Metric @@ -70,6 +71,18 @@ def test_without_spec_decoding( (True, "uni", True, None, True), ] + if current_platform.is_rocm(): + # On ROCm, Only test with structured_outputs (deterministic) + # and skip chunk_prefill (more variable). + test_configs = [ + cfg + for cfg in test_configs + if not cfg[4] # skip chunk_prefill=True + ] + test_sampling_params = [ + p for p in test_sampling_params if p.get("structured_outputs") is not None + ] + run_tests(monkeypatch, MODEL, test_configs, test_sampling_params) @@ -108,7 +121,14 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): (True, "uni", True, spec_config_short, True), ] - run_tests(monkeypatch, MTP_MODEL, test_configs, test_sampling_params) + # On ROCm, use TRITON_ATTN + float32 for better numerical consistency + run_tests( + monkeypatch, + MTP_MODEL, + test_configs, + test_sampling_params, + is_testing_with_spec_decoding=True, + ) @dynamo_config.patch(cache_size_limit=16) @@ -117,13 +137,23 @@ def run_tests( model: str, test_configs: list[tuple], test_sampling_params: list[dict[str, Any]], + is_testing_with_spec_decoding: bool = False, ): """Test consistency of combos of async scheduling, preemption, uni/multiproc executor with spec decoding.""" with monkeypatch.context() as m: # avoid precision errors - m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION") + if current_platform.is_rocm(): + if is_testing_with_spec_decoding: + # Use TRITON_ATTN for spec decoding test for consistency + m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN") + else: + m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_FA") + else: + m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION") + # lock matmul precision to full FP32 (IEEE) + m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "ieee") # m.setenv("VLLM_BATCH_INVARIANT", "1") outputs: list[tuple[str, list, list]] = [] for n, ( @@ -143,6 +173,7 @@ def run_tests( async_scheduling, spec_config, test_prefill_chunking=test_prefill_chunking, + is_testing_with_spec_decoding=is_testing_with_spec_decoding, ) outputs.append(test_results) @@ -172,17 +203,34 @@ def run_tests( name_0=f"baseline=[{baseline_config}], params={params}", name_1=f"config=[{test_config}], params={params}", ) - assert _all_logprobs_match(base_logprobs, test_logprobs) + + # On ROCm with TRITON_ATTN (spec decoding test), skip strict + # logprobs comparison when logprobs are requested + skip_logprobs_check = ( + current_platform.is_rocm() + and params.get("logprobs") + and is_testing_with_spec_decoding + ) + if not skip_logprobs_check: + assert _all_logprobs_match(base_logprobs, test_logprobs) if ( base_acceptance_rate is not None and test_acceptance_rate is not None ): if "spec_mml=None" in test_config: + # Preemption causes more variance in acceptance rates + if ( + current_platform.is_rocm() + and "preemption=True" in test_config + ): + tolerance = 0.10 + else: + tolerance = 0.05 assert ( test_acceptance_rate > base_acceptance_rate or test_acceptance_rate - == pytest.approx(base_acceptance_rate, rel=5e-2) + == pytest.approx(base_acceptance_rate, rel=tolerance) ) else: # Currently the reported acceptance rate is expected to be @@ -213,6 +261,7 @@ def run_test( async_scheduling: bool, spec_config: dict[str, Any] | None, test_prefill_chunking: bool, + is_testing_with_spec_decoding: bool = False, ): spec_decoding = spec_config is not None cache_arg: dict[str, Any] = ( @@ -231,6 +280,15 @@ def run_test( print("-" * 80) print(f"---- TESTING {test_str}: {test_config}") print("-" * 80) + + # On ROCm: use float16 for first test (ROCM_AITER_FA), but float32 for + # spec decoding test (TRITON_ATTN) for better precision. + # On others: always use float32. + if current_platform.is_rocm() and not is_testing_with_spec_decoding: + dtype = "float16" + else: + dtype = "float32" + with VllmRunner( model, max_model_len=512, @@ -240,7 +298,7 @@ def run_test( # enforce_eager=True, async_scheduling=async_scheduling, distributed_executor_backend=executor, - dtype="float32", # avoid precision errors + dtype=dtype, speculative_config=spec_config, disable_log_stats=False, **cache_arg, @@ -300,11 +358,21 @@ def _all_logprobs_match(req_a, req_b) -> bool: def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool: - return len(lps_a) == len(lps_b) and all( - a.decoded_token == b.decoded_token - and a.rank == b.rank - and a.logprob == pytest.approx(b.logprob, rel=1e-3, abs=1e-6) - for a, b in ((lps_a[x], lps_b[x]) for x in lps_a) + if current_platform.is_rocm(): + # ROCm has higher numerical variance + # due to use of float16. + rel_tol, abs_tol = 5e-2, 1e-5 + else: + rel_tol, abs_tol = 1e-3, 1e-6 + return ( + len(lps_a) == len(lps_b) + and lps_a.keys() == lps_b.keys() + and all( + a.decoded_token == b.decoded_token + and a.rank == b.rank + and a.logprob == pytest.approx(b.logprob, rel=rel_tol, abs=abs_tol) + for a, b in ((lps_a[x], lps_b[x]) for x in lps_a) + ) ) diff --git a/tests/v1/e2e/test_async_spec_decode.py b/tests/v1/e2e/test_async_spec_decode.py new file mode 100644 index 0000000000000..561f37a52d573 --- /dev/null +++ b/tests/v1/e2e/test_async_spec_decode.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Test that verifies no implicit GPU-CPU synchronization occurs during +speculative decoding generation under expected conditions. +""" + +import multiprocessing +import sys +import traceback + +import pytest +import torch + + +@pytest.fixture +def sync_tracker(): + """ + Fixture that patches CommonAttentionMetadata.seq_lens_cpu to detect + lazy init syncs. Prints stack traces immediately when syncs occur. + """ + from vllm.v1.attention.backends.utils import CommonAttentionMetadata + + # Shared counter for cross-process communication (inherited by fork) + sync_count = multiprocessing.Value("i", 0) + + # Save original property + original_prop = CommonAttentionMetadata.seq_lens_cpu + original_fget = original_prop.fget + + # Create tracking wrapper + def tracking_seq_lens_cpu(self): + if self._seq_lens_cpu is None: + # Increment counter + with sync_count.get_lock(): + sync_count.value += 1 + count = sync_count.value + # Print stack trace immediately (shows in subprocess output) + print(f"\n{'=' * 60}", file=sys.stderr) + print(f"SYNC #{count}: seq_lens_cpu lazy init triggered!", file=sys.stderr) + print(f"{'=' * 60}", file=sys.stderr) + traceback.print_stack(file=sys.stderr) + print(f"{'=' * 60}\n", file=sys.stderr) + sys.stderr.flush() + return original_fget(self) + + # Apply patch + CommonAttentionMetadata.seq_lens_cpu = property(tracking_seq_lens_cpu) + + class SyncTracker: + @property + def count(self) -> int: + return sync_count.value + + def assert_no_sync(self, msg: str = ""): + count = sync_count.value + assert count == 0, ( + f"Unexpected GPU-CPU sync: seq_lens_cpu lazy init triggered " + f"{count} times. See stack traces above. {msg}" + ) + + yield SyncTracker() + + # Restore original property + CommonAttentionMetadata.seq_lens_cpu = original_prop + torch._dynamo.reset() + + +# Test configurations: (model, spec_model, method, num_spec_tokens, backend_env) +SPEC_DECODE_CONFIGS = [ + pytest.param( + "meta-llama/Llama-3.2-1B-Instruct", + "nm-testing/Llama3_2_1B_speculator.eagle3", + "eagle3", + 2, + id="eagle3-llama", + ), + pytest.param( + "eagle618/deepseek-v3-random", + "eagle618/eagle-deepseek-v3-random", + "eagle", + 2, + id="eagle-mla-deepseek", + ), +] + + +@pytest.mark.parametrize( + "model,spec_model,method,num_spec_tokens", + SPEC_DECODE_CONFIGS, +) +def test_no_sync_with_spec_decode( + sync_tracker, + model: str, + spec_model: str, + method: str, + num_spec_tokens: int, +): + """ + Test that no implicit GPU-CPU sync occurs during speculative decoding + generation. + """ + # Import vLLM AFTER sync_tracker fixture has applied the patch + from vllm import LLM, SamplingParams + from vllm.distributed import cleanup_dist_env_and_memory + + llm = LLM( + model=model, + max_model_len=256, + speculative_config={ + "method": method, + "num_speculative_tokens": num_spec_tokens, + "model": spec_model, + }, + enforce_eager=True, + async_scheduling=True, + ) + + outputs = llm.generate( + ["Hello, my name is"], + SamplingParams(temperature=0, max_tokens=10), + ) + + assert len(outputs) == 1 + assert len(outputs[0].outputs[0].text) > 0 + + del llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() + + sync_tracker.assert_no_sync() diff --git a/tests/v1/e2e/test_correctness_sliding_window.py b/tests/v1/e2e/test_correctness_sliding_window.py index 71b0e86c75c18..b6a78eaa09209 100644 --- a/tests/v1/e2e/test_correctness_sliding_window.py +++ b/tests/v1/e2e/test_correctness_sliding_window.py @@ -5,6 +5,7 @@ from dataclasses import dataclass import pytest from vllm import LLM, SamplingParams +from vllm.platforms import current_platform from ...utils import check_answers, prep_prompts @@ -40,10 +41,17 @@ def test_sliding_window_retrieval( If we tell it upfront which we are going to be looking for, then it answers correctly (mostly). """ + # NOTE: For ROCm, we have to enforce eager mode to use custom kernel + # implementation of GELU with tanh approximation, as PyTorch's native + # implementation is currently unstable with torch.compile and produces garbage. + enforce_eager = current_platform.is_rocm() + test_config = model_config[model] llm = LLM( - model=model, disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager + model=model, + disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager, + enforce_eager=enforce_eager, ) sampling_params = SamplingParams(temperature=0.0, max_tokens=100) diff --git a/tests/v1/e2e/test_kv_sharing_fast_prefill.py b/tests/v1/e2e/test_kv_sharing_fast_prefill.py index 2778b0c5e5670..f895fb72e94a1 100644 --- a/tests/v1/e2e/test_kv_sharing_fast_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_fast_prefill.py @@ -7,6 +7,7 @@ import pytest from vllm import LLM, SamplingParams from vllm.config import CompilationConfig, CompilationMode +from vllm.platforms import current_platform from ...utils import check_answers, fork_new_process_for_each_test, prep_prompts @@ -43,15 +44,26 @@ def test_prompts(): return prompts -@fork_new_process_for_each_test +use_fork_for_test = ( + fork_new_process_for_each_test if not current_platform.is_rocm() else lambda x: x +) + + +@use_fork_for_test @pytest.mark.parametrize("kv_sharing_fast_prefill", [False, True]) @pytest.mark.parametrize("enforce_eager", [True, False]) def test_kv_sharing_fast_prefill( monkeypatch: pytest.MonkeyPatch, kv_sharing_fast_prefill: bool, enforce_eager: bool, - test_prompts: list[str], ): + if not enforce_eager and current_platform.is_rocm(): + # Relevant context: https://github.com/vllm-project/vllm/pull/29244 + pytest.skip( + "ROCm: torch.compile produces incorrect output for gemma-3n's GELU " + "with tanh approximation. Use enforce_eager=True instead." + ) + sampling_params = SamplingParams(temperature=0.0, max_tokens=100) compilation_config = CompilationConfig( # This allows vLLM compilation backend to handle allocating and @@ -65,7 +77,11 @@ def test_kv_sharing_fast_prefill( with monkeypatch.context() as m: # Make scheduling deterministic for reproducibility - m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + if current_platform.is_rocm(): + # Use spawn to prevent cuda re-initialization error + m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + else: + m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") prompts, answer, indices = prep_prompts(batch_size) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 3a25f7411eecd..fcfc8bdce12e9 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -16,6 +16,16 @@ from vllm.platforms import current_platform MTP_SIMILARITY_RATE = 0.8 +def _skip_if_insufficient_gpus_for_tp(tp_size: int): + """Skip test if available GPUs < tp_size on ROCm.""" + if current_platform.is_rocm(): + available_gpus = torch.cuda.device_count() + if available_gpus < tp_size: + pytest.skip( + f"Test requires {tp_size} GPUs, but only {available_gpus} available" + ) + + def get_test_prompts(mm_enabled: bool): prompt_types = ["repeat", "sentence"] if mm_enabled: @@ -191,8 +201,8 @@ def test_suffix_decoding_acceptance( # Expect the acceptance rate to improve. assert first_accept_rate < last_accept_rate - # Heuristic: expect at least 85% acceptance rate at the end. - assert last_accept_rate > 0.85 + # Heuristic: expect at least 80.0% acceptance rate at the end. + assert last_accept_rate > 0.80 del spec_llm torch.cuda.empty_cache() @@ -280,9 +290,20 @@ def test_speculators_model_integration( @pytest.mark.parametrize( - ["model_setup", "mm_enabled", "enable_chunked_prefill"], + ["model_setup", "mm_enabled", "enable_chunked_prefill", "model_impl"], [ - (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False, False), + ( + ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), + False, + False, + "auto", + ), + ( + ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), + False, + False, + "transformers", + ), pytest.param( ( "eagle3", @@ -292,6 +313,7 @@ def test_speculators_model_integration( ), False, False, + "auto", marks=pytest.mark.skip( reason="architecture of its eagle3 is LlamaForCausalLMEagle3" ), @@ -305,6 +327,7 @@ def test_speculators_model_integration( ), False, False, + "auto", marks=pytest.mark.skip( reason="Skipping due to its head_dim not being a a multiple of 32" ), @@ -318,6 +341,7 @@ def test_speculators_model_integration( ), False, True, + "auto", marks=large_gpu_mark(min_gb=40), ), # works on 4x H100 ( @@ -329,6 +353,7 @@ def test_speculators_model_integration( ), False, False, + "auto", ), pytest.param( ( @@ -339,6 +364,7 @@ def test_speculators_model_integration( ), False, False, + "auto", marks=large_gpu_mark(min_gb=80), ), # works on 4x H100 pytest.param( @@ -350,6 +376,7 @@ def test_speculators_model_integration( ), True, True, + "auto", marks=large_gpu_mark(min_gb=80), ), # works on 4x H100 ( @@ -361,10 +388,12 @@ def test_speculators_model_integration( ), False, False, + "auto", ), ], ids=[ "qwen3_eagle3", + "qwen3_eagle3-transformers", "qwen3_vl_eagle3", "qwen2_5_vl_eagle3", "llama3_eagle", @@ -381,6 +410,7 @@ def test_eagle_correctness( model_setup: tuple[str, str, str, int], mm_enabled: bool, enable_chunked_prefill: bool, + model_impl: str, attn_backend: str, ): if attn_backend == "TREE_ATTN": @@ -389,6 +419,17 @@ def test_eagle_correctness( "TREE_ATTN is flaky in the test disable for now until it can be " "resolved (see https://github.com/vllm-project/vllm/issues/22922)" ) + if model_impl == "transformers": + import transformers + from packaging.version import Version + + installed = Version(transformers.__version__) + required = Version("5.0.0.dev") + if installed < required: + pytest.skip( + "Eagle3 with the Transformers modeling backend requires " + f"transformers>={required}, but got {installed}" + ) # Generate test prompts inside the function instead of using fixture test_prompts = get_test_prompts(mm_enabled) @@ -402,7 +443,11 @@ def test_eagle_correctness( # Scout requires default backend selection # because vision encoder has head_dim 88 being incompatible # with FLASH_ATTN and needs to fall back to Flex Attn - pass + + # pass if not ROCm + if current_platform.is_rocm(): + # TODO: Enable Flex Attn for spec_decode on ROCm + pytest.skip("Flex Attn for spec_decode not supported on ROCm currently") else: m.setenv("VLLM_MLA_DISABLE", "1") m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) @@ -413,10 +458,15 @@ def test_eagle_correctness( "multi-token eagle spec decode on current platform" ) - if attn_backend == "FLASH_ATTN" and current_platform.is_rocm(): - m.setenv("VLLM_ROCM_USE_AITER", "1") + if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm(): + if "deepseek" in model_setup[1].lower(): + pytest.skip("ROCM_AITER_FA for deepseek not supported on ROCm platform") + else: + m.setenv("VLLM_ROCM_USE_AITER", "1") method, model_name, spec_model_name, tp_size = model_setup + _skip_if_insufficient_gpus_for_tp(tp_size) + max_model_len = 2048 max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len @@ -441,6 +491,7 @@ def test_eagle_correctness( max_model_len=max_model_len, max_num_batched_tokens=max_num_batched_tokens, enable_chunked_prefill=enable_chunked_prefill, + model_impl=model_impl, ) spec_outputs = spec_llm.chat(test_prompts, sampling_config) matches = 0 @@ -486,6 +537,7 @@ def test_mtp_correctness( m.setenv("VLLM_MLA_DISABLE", "1") method, model_name, tp_size = model_setup + _skip_if_insufficient_gpus_for_tp(tp_size) ref_llm = LLM( model=model_name, diff --git a/tests/v1/ec_connector/integration/run_epd_correctness_test.sh b/tests/v1/ec_connector/integration/run_epd_correctness_test.sh index 55dd39c0a957f..0c2666306558c 100644 --- a/tests/v1/ec_connector/integration/run_epd_correctness_test.sh +++ b/tests/v1/ec_connector/integration/run_epd_correctness_test.sh @@ -148,7 +148,7 @@ run_epd_1e_1pd() { --max-num-seqs 128 \ --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ --ec-transfer-config '{ - "ec_connector": "ECSharedStorageConnector", + "ec_connector": "ECExampleConnector", "ec_role": "ec_producer", "ec_connector_extra_config": { "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" @@ -167,7 +167,7 @@ run_epd_1e_1pd() { --max-num-seqs 128 \ --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ --ec-transfer-config '{ - "ec_connector": "ECSharedStorageConnector", + "ec_connector": "ECExampleConnector", "ec_role": "ec_consumer", "ec_connector_extra_config": { "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" @@ -348,7 +348,7 @@ run_epd_1e_1p_1d() { --max-num-seqs 128 \ --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ --ec-transfer-config '{ - "ec_connector": "ECSharedStorageConnector", + "ec_connector": "ECExampleConnector", "ec_role": "ec_producer", "ec_connector_extra_config": { "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" @@ -369,7 +369,7 @@ run_epd_1e_1p_1d() { --max-num-seqs 128 \ --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \ --ec-transfer-config '{ - "ec_connector": "ECSharedStorageConnector", + "ec_connector": "ECExampleConnector", "ec_role": "ec_consumer", "ec_connector_extra_config": { "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'" diff --git a/tests/v1/ec_connector/unit/test_ec_shared_storage_connector.py b/tests/v1/ec_connector/unit/test_ec_example_connector.py similarity index 90% rename from tests/v1/ec_connector/unit/test_ec_shared_storage_connector.py rename to tests/v1/ec_connector/unit/test_ec_example_connector.py index a58daa2628e21..7e9eb21310031 100644 --- a/tests/v1/ec_connector/unit/test_ec_shared_storage_connector.py +++ b/tests/v1/ec_connector/unit/test_ec_example_connector.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ -Unit tests for ECSharedStorageConnector. +Unit tests for ECExampleConnector. """ import os @@ -13,9 +13,9 @@ import torch from vllm.config import VllmConfig from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorRole -from vllm.distributed.ec_transfer.ec_connector.shared_storage_connector import ( - ECSharedStorageConnector, - ECSharedStorageConnectorMetadata, +from vllm.distributed.ec_transfer.ec_connector.example_connector import ( + ECExampleConnector, + ECExampleConnectorMetadata, MMMeta, ) from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange @@ -81,12 +81,12 @@ def mock_request_with_3_mm(): # ------------------ Unit Tests ------------------ # -class TestECSharedStorageConnectorBasics: +class TestECExampleConnectorBasics: """Test basic EC connector functionality.""" def test_initialization_producer(self, mock_vllm_config_producer, temp_storage): """Test connector initializes correctly as producer.""" - connector = ECSharedStorageConnector( + connector = ECExampleConnector( vllm_config=mock_vllm_config_producer, role=ECConnectorRole.SCHEDULER, ) @@ -98,7 +98,7 @@ class TestECSharedStorageConnectorBasics: def test_initialization_consumer(self, mock_vllm_config_consumer, temp_storage): """Test connector initializes correctly as consumer.""" - connector = ECSharedStorageConnector( + connector = ECExampleConnector( vllm_config=mock_vllm_config_consumer, role=ECConnectorRole.WORKER, ) @@ -109,11 +109,11 @@ class TestECSharedStorageConnectorBasics: def test_role_assignment(self, mock_vllm_config_producer): """Test role is correctly assigned.""" - scheduler_connector = ECSharedStorageConnector( + scheduler_connector = ECExampleConnector( vllm_config=mock_vllm_config_producer, role=ECConnectorRole.SCHEDULER, ) - worker_connector = ECSharedStorageConnector( + worker_connector = ECExampleConnector( vllm_config=mock_vllm_config_producer, role=ECConnectorRole.WORKER, ) @@ -133,7 +133,7 @@ class TestCacheExistence: ): """Test has_caches returns True when all 3 caches exist.""" # Test for producer first - producer = ECSharedStorageConnector( + producer = ECExampleConnector( vllm_config=mock_vllm_config_producer, role=ECConnectorRole.SCHEDULER, ) @@ -154,7 +154,7 @@ class TestCacheExistence: assert all(producer_result), f"Expected all True, got {producer_result}" # Also test consumer can check if cache exists - consumer = ECSharedStorageConnector( + consumer = ECExampleConnector( vllm_config=mock_vllm_config_consumer, role=ECConnectorRole.SCHEDULER, ) @@ -170,7 +170,7 @@ class TestCacheExistence: self, mock_vllm_config_producer, mock_request_with_3_mm ): """Test has_caches returns False when no caches exist.""" - connector = ECSharedStorageConnector( + connector = ECExampleConnector( vllm_config=mock_vllm_config_producer, role=ECConnectorRole.SCHEDULER, ) @@ -186,7 +186,7 @@ class TestCacheExistence: self, mock_vllm_config_producer, mock_request_with_3_mm ): """Test has_caches with some caches existing (1 of 3).""" - connector = ECSharedStorageConnector( + connector = ECExampleConnector( vllm_config=mock_vllm_config_producer, role=ECConnectorRole.SCHEDULER, ) @@ -213,7 +213,7 @@ class TestStateManagement: self, mock_vllm_config_producer, mock_request_with_3_mm ): """Test state update after allocation for 3 MM items.""" - connector = ECSharedStorageConnector( + connector = ECExampleConnector( vllm_config=mock_vllm_config_producer, role=ECConnectorRole.SCHEDULER, ) @@ -238,7 +238,7 @@ class TestStateManagement: self, mock_vllm_config_producer, mock_request_with_3_mm ): """Test metadata building for 3 MM items.""" - connector = ECSharedStorageConnector( + connector = ECExampleConnector( vllm_config=mock_vllm_config_producer, role=ECConnectorRole.SCHEDULER, ) @@ -252,7 +252,7 @@ class TestStateManagement: metadata = connector.build_connector_meta(scheduler_output) # Assert - assert isinstance(metadata, ECSharedStorageConnectorMetadata) + assert isinstance(metadata, ECExampleConnectorMetadata) assert len(metadata.mm_datas) == 3 assert metadata.mm_datas[0].mm_hash == "img_hash_1" assert metadata.mm_datas[0].num_token == 100 @@ -266,7 +266,7 @@ class TestStateManagement: def test_build_connector_meta_empty(self, mock_vllm_config_producer): """Test metadata building with empty state.""" - connector = ECSharedStorageConnector( + connector = ECExampleConnector( vllm_config=mock_vllm_config_producer, role=ECConnectorRole.SCHEDULER, ) @@ -274,14 +274,14 @@ class TestStateManagement: scheduler_output = Mock(spec=SchedulerOutput) metadata = connector.build_connector_meta(scheduler_output) - assert isinstance(metadata, ECSharedStorageConnectorMetadata) + assert isinstance(metadata, ECExampleConnectorMetadata) assert len(metadata.mm_datas) == 0 def test_state_cleared_after_metadata_build( self, mock_vllm_config_producer, mock_request_with_3_mm ): """Test that state is properly cleared after building metadata.""" - connector = ECSharedStorageConnector( + connector = ECExampleConnector( vllm_config=mock_vllm_config_producer, role=ECConnectorRole.SCHEDULER, ) @@ -310,7 +310,7 @@ class TestCacheSaving: self, mock_vllm_config_producer, mock_request_with_3_mm, temp_storage ): """Test cache saving as producer for 3 different MM items.""" - connector = ECSharedStorageConnector( + connector = ECExampleConnector( vllm_config=mock_vllm_config_producer, role=ECConnectorRole.WORKER, ) @@ -336,7 +336,7 @@ class TestCacheSaving: def test_save_caches_consumer_skips(self, mock_vllm_config_consumer): """Test cache saving is skipped for consumer.""" - connector = ECSharedStorageConnector( + connector = ECExampleConnector( vllm_config=mock_vllm_config_consumer, role=ECConnectorRole.WORKER, ) @@ -366,7 +366,7 @@ class TestCacheLoading: ): """Test consumer loads 3 caches from storage.""" # First, create producer to save caches - producer = ECSharedStorageConnector( + producer = ECExampleConnector( vllm_config=mock_vllm_config_producer, role=ECConnectorRole.WORKER, ) @@ -379,13 +379,13 @@ class TestCacheLoading: producer.save_caches(saved_caches, mm_hash) # Now consumer loads - consumer = ECSharedStorageConnector( + consumer = ECExampleConnector( vllm_config=mock_vllm_config_consumer, role=ECConnectorRole.WORKER, ) # Setup metadata for all 3 - metadata = ECSharedStorageConnectorMetadata() + metadata = ECExampleConnectorMetadata() for mm_hash in mm_hashes: metadata.add_mm_data(MMMeta.make_meta(mm_hash, 100)) consumer.bind_connector_metadata(metadata) @@ -410,7 +410,7 @@ class TestCacheLoading: ): """Test cache loading skips already cached items.""" # Setup: producer saves cache - producer = ECSharedStorageConnector( + producer = ECExampleConnector( vllm_config=mock_vllm_config_producer, role=ECConnectorRole.WORKER, ) @@ -420,12 +420,12 @@ class TestCacheLoading: producer.save_caches({mm_hash: saved_cache}, mm_hash) # Consumer setup - consumer = ECSharedStorageConnector( + consumer = ECExampleConnector( vllm_config=mock_vllm_config_consumer, role=ECConnectorRole.WORKER, ) - metadata = ECSharedStorageConnectorMetadata() + metadata = ECExampleConnectorMetadata() metadata.add_mm_data(MMMeta.make_meta(mm_hash, 100)) consumer.bind_connector_metadata(metadata) @@ -444,13 +444,13 @@ class TestCacheLoading: def test_start_load_caches_empty_metadata(self, mock_vllm_config_consumer): """Test loading with empty metadata does nothing.""" - consumer = ECSharedStorageConnector( + consumer = ECExampleConnector( vllm_config=mock_vllm_config_consumer, role=ECConnectorRole.WORKER, ) # Setup empty metadata - metadata = ECSharedStorageConnectorMetadata() + metadata = ECExampleConnectorMetadata() consumer.bind_connector_metadata(metadata) # Load (should not raise) @@ -466,7 +466,7 @@ class TestFilenameGeneration: def test_generate_foldername(self, mock_vllm_config_producer, temp_storage): """Test folder name generation.""" - connector = ECSharedStorageConnector( + connector = ECExampleConnector( vllm_config=mock_vllm_config_producer, role=ECConnectorRole.WORKER, ) @@ -479,7 +479,7 @@ class TestFilenameGeneration: def test_generate_filename(self, mock_vllm_config_producer, temp_storage): """Test filename generation.""" - connector = ECSharedStorageConnector( + connector = ECExampleConnector( vllm_config=mock_vllm_config_producer, role=ECConnectorRole.WORKER, ) @@ -493,7 +493,7 @@ class TestFilenameGeneration: def test_generate_filename_consistency(self, mock_vllm_config_producer): """Test filename generation is consistent.""" - connector = ECSharedStorageConnector( + connector = ECExampleConnector( vllm_config=mock_vllm_config_producer, role=ECConnectorRole.WORKER, ) @@ -510,12 +510,12 @@ class TestMetadataBindingLifecycle: def test_bind_connector_metadata(self, mock_vllm_config_consumer): """Test binding connector metadata.""" - connector = ECSharedStorageConnector( + connector = ECExampleConnector( vllm_config=mock_vllm_config_consumer, role=ECConnectorRole.WORKER, ) - metadata = ECSharedStorageConnectorMetadata() + metadata = ECExampleConnectorMetadata() metadata.add_mm_data(MMMeta.make_meta("hash_1", 100)) connector.bind_connector_metadata(metadata) @@ -524,12 +524,12 @@ class TestMetadataBindingLifecycle: def test_clear_connector_metadata(self, mock_vllm_config_consumer): """Test clearing connector metadata.""" - connector = ECSharedStorageConnector( + connector = ECExampleConnector( vllm_config=mock_vllm_config_consumer, role=ECConnectorRole.WORKER, ) - metadata = ECSharedStorageConnectorMetadata() + metadata = ECExampleConnectorMetadata() connector.bind_connector_metadata(metadata) connector.clear_connector_metadata() @@ -538,12 +538,12 @@ class TestMetadataBindingLifecycle: def test_get_connector_metadata(self, mock_vllm_config_consumer): """Test getting connector metadata.""" - connector = ECSharedStorageConnector( + connector = ECExampleConnector( vllm_config=mock_vllm_config_consumer, role=ECConnectorRole.WORKER, ) - metadata = ECSharedStorageConnectorMetadata() + metadata = ECExampleConnectorMetadata() connector.bind_connector_metadata(metadata) retrieved = connector._get_connector_metadata() @@ -552,7 +552,7 @@ class TestMetadataBindingLifecycle: def test_get_connector_metadata_not_set(self, mock_vllm_config_consumer): """Test getting metadata when not set raises.""" - connector = ECSharedStorageConnector( + connector = ECExampleConnector( vllm_config=mock_vllm_config_consumer, role=ECConnectorRole.WORKER, ) @@ -566,7 +566,7 @@ class TestEdgeCases: def test_save_empty_cache(self, mock_vllm_config_producer): """Test saving empty tensor.""" - connector = ECSharedStorageConnector( + connector = ECExampleConnector( vllm_config=mock_vllm_config_producer, role=ECConnectorRole.WORKER, ) @@ -579,12 +579,12 @@ class TestEdgeCases: def test_load_nonexistent_cache(self, mock_vllm_config_consumer): """Test loading cache that doesn't exist raises error.""" - connector = ECSharedStorageConnector( + connector = ECExampleConnector( vllm_config=mock_vllm_config_consumer, role=ECConnectorRole.WORKER, ) - metadata = ECSharedStorageConnectorMetadata() + metadata = ECExampleConnectorMetadata() metadata.add_mm_data(MMMeta.make_meta("nonexistent_hash", 100)) connector.bind_connector_metadata(metadata) @@ -596,7 +596,7 @@ class TestEdgeCases: def test_has_caches_empty_request(self, mock_vllm_config_producer): """Test has_caches with request that has no MM data.""" - connector = ECSharedStorageConnector( + connector = ECExampleConnector( vllm_config=mock_vllm_config_producer, role=ECConnectorRole.SCHEDULER, ) diff --git a/tests/v1/engine/test_abort_final_step.py b/tests/v1/engine/test_abort_final_step.py new file mode 100644 index 0000000000000..560c5c2b1e300 --- /dev/null +++ b/tests/v1/engine/test_abort_final_step.py @@ -0,0 +1,311 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Test for the fix in PR #29987: Eagerly abort cancelled final-step requests. + +This test verifies that when a request is aborted during its final execution +step (when it would naturally complete), it is properly marked as aborted +rather than being treated as normally completed. + +The test uses a dummy KV connector to verify that the connector receives +the correct finish status (FINISHED_ABORTED, not FINISHED_LENGTH_CAPPED). +""" + +import asyncio +import tempfile +import time +from pathlib import Path +from typing import Any +from unittest.mock import patch + +import pytest + +from vllm import SamplingParams +from vllm.config import KVTransferConfig, VllmConfig +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.platforms import current_platform +from vllm.sampling_params import RequestOutputKind +from vllm.utils.torch_utils import set_default_torch_num_threads +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.request import Request + +if not current_platform.is_cuda(): + pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True) + +TEXT_PROMPT = "Hello" + + +class DummyKVConnectorMetadata(KVConnectorMetadata): + """Dummy metadata for the test connector.""" + + def __init__(self): + self.requests: list = [] + + +class DummyKVConnector(KVConnectorBase_V1): + """ + Dummy KV connector that captures request finish statuses to a file. + This is used to verify the fix - without the fix, a request aborted + during its final step would be captured as FINISHED_LENGTH_CAPPED + instead of FINISHED_ABORTED. + + The connector runs in a separate process, so we write statuses to a file + that can be read by the test process. + """ + + def __init__( + self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: KVCacheConfig | None = None, + ): + super().__init__(vllm_config, role, kv_cache_config) + # Get the status file path from extra config + extra_config = vllm_config.kv_transfer_config.kv_connector_extra_config or {} + self.status_file = extra_config.get("status_file") + # Log that we were initialized + if self.status_file: + try: + with open(self.status_file, "a") as f: + f.write(f"INIT:{role.name}\n") + except Exception: + pass + + def get_num_new_matched_tokens( + self, + request: Request, + num_computed_tokens: int, + ) -> tuple[int | None, bool]: + return (0, False) + + def update_state_after_alloc( + self, + request: Request, + blocks: Any, + num_external_tokens: int, + ): + pass + + def build_connector_meta( + self, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: + return DummyKVConnectorMetadata() + + def request_finished( + self, + request: Request, + block_ids: list[int], + ) -> tuple[bool, dict[str, Any] | None]: + """Capture the request status when finished by writing to a file.""" + if self.status_file: + try: + with open(self.status_file, "a") as f: + # Write the status name (e.g., "FINISHED_ABORTED") + f.write(f"{request.status.name}\n") + except Exception as e: + # Log but don't fail - this is just test instrumentation + print(f"[DummyKVConnector] Failed to write status: {e}") + return False, None + + def start_load_kv(self, forward_context: Any, **kwargs: Any) -> None: + pass + + def wait_for_layer_load(self, layer_name: str) -> None: + pass + + def save_kv_layer( + self, + layer_name: str, + kv_layer: Any, + attn_metadata: Any, + **kwargs: Any, + ) -> None: + pass + + def wait_for_save(self): + pass + + +# Register the dummy connector +KVConnectorFactory.register_connector( + "DummyKVConnector", __name__, DummyKVConnector.__name__ +) + + +@pytest.mark.parametrize("async_scheduling", [False, True]) +@pytest.mark.asyncio +async def test_abort_during_final_step(async_scheduling: bool): + """ + Test that a request aborted during its final execution step is treated as + aborted rather than completed. + + This test: + 1. Monkeypatches execute_model to wait for a file to be deleted + 2. Configures a dummy KV connector to capture finish statuses + 3. Starts a request with max_tokens=1 (will complete on first decode step) + 4. Aborts the request, then deletes the file to unblock execute_model + 5. Verifies the KV connector received FINISHED_ABORTED not FINISHED_LENGTH_CAPPED + + See https://github.com/vllm-project/vllm/pull/29987. + + Without the fix, the KV connector would see FINISHED_LENGTH_CAPPED because + update_from_output() would mark the request as completed before processing + the abort. This causes KV cache blocks to not be freed properly in + disaggregated prefill scenarios. + + With the fix, _process_aborts_queue() runs before update_from_output(), so the + abort takes precedence and the KV connector sees FINISHED_ABORTED. + """ + + # Create three temporary files: + # 1. ready_file: deleted by execute_model to signal it has started + # 2. block_file: execute_model waits for this to be deleted + # 3. status_file: KV connector writes finish statuses here + with tempfile.NamedTemporaryFile(delete=False) as f: + ready_file = Path(f.name) + with tempfile.NamedTemporaryFile(delete=False) as f2: + block_file = Path(f2.name) + with tempfile.NamedTemporaryFile(delete=False, mode="w") as f3: + status_file = Path(f3.name) + + try: + # Get the original execute_model method + from vllm.v1.worker.gpu_worker import Worker + + original_execute_model = Worker.execute_model + + def execute_model_with_wait(self, scheduler_output): + # Signal that execute_model has been called by deleting ready_file + if ready_file.exists(): + ready_file.unlink() + + # Wait for the block file to be deleted (triggered from test after abort) + # This runs in the worker process (after fork), so we poll the filesystem + while block_file.exists(): + time.sleep(0.01) + return original_execute_model(self, scheduler_output) + + # Patch execute_model to inject the wait + # This happens before the worker process is forked, so the patch applies there + with patch.object(Worker, "execute_model", execute_model_with_wait): + request_id = "test-abort-final-step" + + # Configure engine with dummy KV connector + # Pass the status file path so the connector can write to it + kv_transfer_config = KVTransferConfig( + kv_connector="DummyKVConnector", + kv_role="kv_both", + kv_connector_extra_config={"status_file": str(status_file)}, + ) + engine_args = AsyncEngineArgs( + model="meta-llama/Llama-3.2-1B-Instruct", + enforce_eager=True, + async_scheduling=async_scheduling, + kv_transfer_config=kv_transfer_config, + ) + + with set_default_torch_num_threads(1): + engine = AsyncLLM.from_engine_args(engine_args) + + try: + # Create a request that will complete after just 1 token + sampling_params = SamplingParams( + max_tokens=1, + ignore_eos=True, + output_kind=RequestOutputKind.DELTA, + ) + + # Start generation in a task + outputs = [] + + async def generate(): + async for output in engine.generate( + request_id=request_id, + prompt=TEXT_PROMPT, + sampling_params=sampling_params, + ): + outputs.append(output) + + gen_task = asyncio.create_task(generate()) + + # Wait for execute_model to signal it has started (with timeout) + timeout = 5.0 # 5 second timeout + start_time = time.time() + while ready_file.exists(): + if time.time() - start_time > timeout: + raise TimeoutError( + "Timeout waiting for execute_model to start. " + "The monkeypatch may not be working correctly, " + "for example if spawn was used instead of fork." + ) + await asyncio.sleep(0.01) + + # Abort the request while execute_model is blocked + await engine.abort(request_id) + + # Now unblock execute_model by deleting the file + # The abort should be processed before the model output + block_file.unlink() + + # Wait for generation to complete + await gen_task + + # Give the scheduler a moment to finish cleanup + await asyncio.sleep(0.1) + + # Verify we got output + assert len(outputs) > 0, "Should have received at least one output" + + # The final output should have finish_reason="abort" + final_output = outputs[-1] + assert final_output.finished, ( + "Final output should be marked as finished" + ) + assert final_output.outputs[0].finish_reason == "abort", ( + f"Expected finish_reason='abort' but got " + f"'{final_output.outputs[0].finish_reason}'. " + ) + + with open(status_file) as f4: + status_lines = f4.read().strip().split("\n") + # Filter for actual finish statuses (not INIT or empty lines) + captured_statuses = [ + line + for line in status_lines + if line and line.startswith("FINISHED_") + ] + + assert len(captured_statuses) >= 1, ( + f"Expected at least 1 captured finish status, got " + f"{len(captured_statuses)}. File content: {status_lines}" + ) + + assert "FINISHED_ABORTED" in captured_statuses, ( + f"KV connector should see FINISHED_ABORTED but got " + f"{captured_statuses}. " + ) + + # Verify cleanup + assert not engine.output_processor.has_unfinished_requests() + + finally: + # Shutdown the engine + engine.shutdown() + + finally: + # Clean up temporary files if they still exist + if ready_file.exists(): + ready_file.unlink() + if block_file.exists(): + block_file.unlink() + if status_file.exists(): + status_file.unlink() diff --git a/tests/v1/engine/test_engine_args.py b/tests/v1/engine/test_engine_args.py index e96759ed66a79..527a56ff49eec 100644 --- a/tests/v1/engine/test_engine_args.py +++ b/tests/v1/engine/test_engine_args.py @@ -9,6 +9,7 @@ from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs from vllm.usage.usage_lib import UsageContext from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.hashing import _xxhash def test_prefix_caching_from_cli(): @@ -48,6 +49,21 @@ def test_prefix_caching_from_cli(): args = parser.parse_args(["--prefix-caching-hash-algo", "invalid"]) +@pytest.mark.skipif(_xxhash is None, reason="xxhash not installed") +def test_prefix_caching_xxhash_from_cli(): + parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) + + # set hash algorithm to xxhash (pickle) + args = parser.parse_args(["--prefix-caching-hash-algo", "xxhash"]) + vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config() + assert vllm_config.cache_config.prefix_caching_hash_algo == "xxhash" + + # set hash algorithm to xxhash_cbor + args = parser.parse_args(["--prefix-caching-hash-algo", "xxhash_cbor"]) + vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config() + assert vllm_config.cache_config.prefix_caching_hash_algo == "xxhash_cbor" + + def test_defaults_with_usage_context(): engine_args = EngineArgs(model="facebook/opt-125m") vllm_config: VllmConfig = engine_args.create_engine_config(UsageContext.LLM_CLASS) diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 3ba8ab26f5522..5fa16897b4e0c 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -484,12 +484,6 @@ def test_encoder_instance_zero_kv_cache( vision encoder, so they don't need KV cache for text generation. """ # Form vllm config - scheduler_config = SchedulerConfig( - max_num_seqs=10, - max_num_batched_tokens=512, - max_model_len=512, - disable_hybrid_kv_cache_manager=True, - ) model_config = ModelConfig( model="llava-hf/llava-1.5-7b-hf", # Multimodal model enforce_eager=True, @@ -497,6 +491,13 @@ def test_encoder_instance_zero_kv_cache( dtype="float16", seed=42, ) + scheduler_config = SchedulerConfig( + max_num_seqs=10, + max_num_batched_tokens=512, + max_model_len=512, + disable_hybrid_kv_cache_manager=True, + is_encoder_decoder=model_config.is_encoder_decoder, + ) cache_config = CacheConfig( block_size=16, gpu_memory_utilization=gpu_memory_utilization, @@ -506,7 +507,7 @@ def test_encoder_instance_zero_kv_cache( ) kv_transfer_config = ( KVTransferConfig( - kv_connector="SharedStorageConnector", + kv_connector="ExampleConnector", kv_role="kv_both", kv_connector_extra_config={"shared_storage_path": "local_storage"}, ) @@ -514,7 +515,7 @@ def test_encoder_instance_zero_kv_cache( else None ) ec_transfer_config = ECTransferConfig( - ec_connector="ECSharedStorageConnector", + ec_connector="ECExampleConnector", ec_role=ec_role, ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test_encoder"}, ) diff --git a/tests/v1/engine/test_init_error_messaging.py b/tests/v1/engine/test_init_error_messaging.py new file mode 100644 index 0000000000000..bc23a68f9deb1 --- /dev/null +++ b/tests/v1/engine/test_init_error_messaging.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.v1.core.kv_cache_utils import check_enough_kv_cache_memory +from vllm.v1.kv_cache_interface import FullAttentionSpec + + +def test_kv_cache_oom_no_memory(): + from unittest.mock import MagicMock + + config = MagicMock() + config.model_config.max_model_len = 2048 + + spec = { + "layer_0": FullAttentionSpec( + block_size=16, + num_kv_heads=8, + head_size=128, + dtype="float16", + ) + } + + with pytest.raises(ValueError): + check_enough_kv_cache_memory(config, spec, 0) + + +def test_kv_cache_oom_insufficient_memory(monkeypatch): + from unittest.mock import MagicMock + + config = MagicMock() + config.model_config.max_model_len = 2048 + config.cache_config.block_size = 16 + config.parallel_config.tensor_parallel_size = 1 + config.parallel_config.pipeline_parallel_size = 1 + config.parallel_config.decode_context_parallel_size = 1 + + monkeypatch.setattr( + "vllm.v1.core.kv_cache_utils.max_memory_usage_bytes", + lambda c, s: 100 * 1024**3, # 100 GiB + ) + + spec = { + "layer_0": FullAttentionSpec( + block_size=16, + num_kv_heads=8, + head_size=128, + dtype="float16", + ) + } + + with pytest.raises(ValueError): + check_enough_kv_cache_memory(config, spec, 1024**3) # 1 GiB diff --git a/tests/v1/entrypoints/conftest.py b/tests/v1/entrypoints/conftest.py index 40b9d1fe850c6..bc9674ee86cf8 100644 --- a/tests/v1/entrypoints/conftest.py +++ b/tests/v1/entrypoints/conftest.py @@ -76,6 +76,8 @@ def sample_json_schema(): }, "required": ["name", "age", "skills", "grade", "email", "work_history"], "additionalProperties": False, + "minProperties": 1, + "maxProperties": 10, } @@ -96,6 +98,9 @@ def unsupported_json_schema(): }, "required": ["score", "tags"], "additionalProperties": False, + "patternProperties": { + "^score$": {"type": "integer"}, + }, } diff --git a/tests/v1/entrypoints/openai/test_completion.py b/tests/v1/entrypoints/openai/test_completion.py index 736ccbefbc4da..ddab006d0d31a 100644 --- a/tests/v1/entrypoints/openai/test_completion.py +++ b/tests/v1/entrypoints/openai/test_completion.py @@ -9,7 +9,7 @@ import regex as re from openai import BadRequestError from tests.utils import RemoteOpenAIServer -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.tokenizers import get_tokenizer # any model with a chat template should work here MODEL_NAME = "facebook/opt-125m" diff --git a/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py b/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py index 276de2ff8e2cd..b30556fbc81fb 100644 --- a/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py +++ b/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import base64 -import io import json import openai # use the official client for correctness check @@ -13,6 +11,7 @@ from transformers import AutoConfig from tests.conftest import ImageTestAssets from tests.utils import RemoteOpenAIServer +from vllm.utils.serial_utils import tensor2base64 # any model with a chat template should work here MODEL_NAME = "llava-hf/llava-1.5-7b-hf" @@ -50,18 +49,6 @@ async def client_with_image_embeds(server_with_image_embeds): yield async_client -def encode_image_embedding_to_base64(image_embedding) -> str: - """ - Encode image embedding to base64 string - """ - buffer = io.BytesIO() - torch.save(image_embedding, buffer) - buffer.seek(0) - binary_data = buffer.read() - base64_image_embedding = base64.b64encode(binary_data).decode("utf-8") - return base64_image_embedding - - @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("dtype", [torch.half, torch.float16, torch.float32]) @@ -73,7 +60,7 @@ async def test_completions_with_image_embeds( ): # Test case: Single image embeds input image_embeds = image_assets[0].image_embeds.to(dtype=dtype) - base64_image_embedding = encode_image_embedding_to_base64(image_embeds) + base64_image_embedding = tensor2base64(image_embeds) chat_completion = await client_with_image_embeds.chat.completions.create( messages=[ {"role": "system", "content": "You are a helpful assistant."}, diff --git a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py index 5768fcdb57ceb..b92d3fcd6fb8b 100644 --- a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py +++ b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py @@ -30,7 +30,14 @@ async def lifespan(app: FastAPI): prefiller_base_url = f"http://{host}:{port}/v1" app.state.prefill_clients.append( { - "client": httpx.AsyncClient(timeout=None, base_url=prefiller_base_url), + "client": httpx.AsyncClient( + timeout=None, + base_url=prefiller_base_url, + limits=httpx.Limits( + max_connections=None, + max_keepalive_connections=None, + ), + ), "host": host, "port": port, "id": i, @@ -42,7 +49,14 @@ async def lifespan(app: FastAPI): decoder_base_url = f"http://{host}:{port}/v1" app.state.decode_clients.append( { - "client": httpx.AsyncClient(timeout=None, base_url=decoder_base_url), + "client": httpx.AsyncClient( + timeout=None, + base_url=decoder_base_url, + limits=httpx.Limits( + max_connections=None, + max_keepalive_connections=None, + ), + ), "host": host, "port": port, "id": i, @@ -169,6 +183,10 @@ async def send_request_to_service( ) response.raise_for_status() + # read/consume the response body to release the connection + # otherwise, it would http.ReadError + await response.aread() + return response @@ -206,6 +224,7 @@ async def _handle_completions(api: str, request: Request): # Extract the needed fields response_json = response.json() + await response.aclose() # CRITICAL: Release connection back to pool kv_transfer_params = response_json.get("kv_transfer_params", {}) if kv_transfer_params: req_data["kv_transfer_params"] = kv_transfer_params diff --git a/tests/v1/kv_connector/unit/test_backwards_compatibility.py b/tests/v1/kv_connector/unit/test_backwards_compatibility.py index 7cd23805c599d..0d29ca5fca5e5 100644 --- a/tests/v1/kv_connector/unit/test_backwards_compatibility.py +++ b/tests/v1/kv_connector/unit/test_backwards_compatibility.py @@ -218,12 +218,12 @@ def test_internal_connector_uses_new_signature(): Test that internal connectors (registered in factory) always use the new signature and get kv_cache_config. """ - from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( - SharedStorageConnector, + from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import ( + ExampleConnector, ) vllm_config = create_vllm_config() - vllm_config.kv_transfer_config.kv_connector = "SharedStorageConnector" + vllm_config.kv_transfer_config.kv_connector = "ExampleConnector" scheduler = create_scheduler(vllm_config) kv_cache_config = scheduler.kv_cache_config @@ -233,7 +233,7 @@ def test_internal_connector_uses_new_signature(): ) assert connector is not None - assert isinstance(connector, SharedStorageConnector) + assert isinstance(connector, ExampleConnector) assert connector._kv_cache_config is not None assert connector._kv_cache_config == kv_cache_config diff --git a/tests/v1/kv_connector/unit/test_cache_pollution_prevention.py b/tests/v1/kv_connector/unit/test_cache_pollution_prevention.py new file mode 100644 index 0000000000000..ec3fb8231e19e --- /dev/null +++ b/tests/v1/kv_connector/unit/test_cache_pollution_prevention.py @@ -0,0 +1,163 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +test that invalid blocks are evicted from prefix cache to prevent pollution. + +verifies that when sync-loading fails, invalid blocks are removed from the +prefix cache hash table so future requests cannot match and reuse corrupted data. +""" + +from collections.abc import Callable +from unittest.mock import Mock + +import pytest + +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.request import Request, RequestStatus + +from .utils import ( + create_model_runner_output, + create_request, + create_scheduler, + create_vllm_config, +) + +pytestmark = pytest.mark.cpu_test + + +def _make_get_num_new_matched_tokens( + req_num_new_matched_tokens: dict[str, int], + async_load: bool, +) -> Callable[[Request, int], tuple[int, bool]]: + def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]: + value = req_num_new_matched_tokens.get(request.request_id, 0) + return value, async_load + + return get_num_new_matched_tokens + + +@pytest.fixture +def fail_scheduler(): + """scheduler with kv_load_failure_policy='fail'""" + vllm_config = create_vllm_config() + vllm_config.kv_transfer_config.kv_load_failure_policy = "fail" + return create_scheduler(vllm_config) + + +def test_invalid_blocks_evicted_prevents_cache_pollution( + fail_scheduler: Scheduler, +): + """ + verify invalid blocks are evicted to prevent future cache hits. + + scenario: + 1. request 1 loads externally-computed blocks (sync mode) + 2. some blocks fail to load and are marked invalid + 3. with fail policy, invalid blocks should be evicted from prefix cache + 4. request is marked as FINISHED_ERROR + """ + num_prompt_blocks = 100 + num_external_computed_blocks = 99 + invalid_block_idx = 50 + + num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size + num_external_computed_tokens = ( + num_external_computed_blocks * fail_scheduler.block_size + ) + + # request 1: will have invalid blocks + request1 = create_request(num_tokens=num_prompt_tokens, request_id=1) + fail_scheduler.add_request(request=request1) + + req_num_new_matched_tokens = { + request1.request_id: num_external_computed_tokens, + } + + # mock connector indicating sync load + fail_scheduler.connector = Mock() + fail_scheduler.connector.get_num_new_matched_tokens.side_effect = ( + _make_get_num_new_matched_tokens(req_num_new_matched_tokens, False) + ) + fail_scheduler.connector.request_finished.return_value = (False, None) + fail_scheduler.connector.take_events.return_value = () + + scheduler_output = fail_scheduler.schedule() + + # request should be running with sync KV load + assert len(fail_scheduler.running) == 1 + assert request1.status == RequestStatus.RUNNING + + # get allocated block IDs + req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0] + invalid_block_id = req_block_ids[invalid_block_idx] + invalid_block_ids = {invalid_block_id} + + # get the block object to verify eviction later + block = fail_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id] + + # cache the blocks to simulate they've been computed and cached + # (in real scenario blocks would be cached after compute) + fail_scheduler.kv_cache_manager.cache_blocks(request1, num_external_computed_tokens) + + # verify block has a hash (is cached) before reporting invalid blocks + assert block.block_hash is not None, ( + f"block {invalid_block_id} should be cached (have a hash) before " + f"eviction test, but hash is None" + ) + + # report invalid blocks + model_runner_output = create_model_runner_output( + [request1], + invalid_block_ids=invalid_block_ids, + use_eos=False, + ) + + fail_scheduler.update_from_output(scheduler_output, model_runner_output) + + # verify request finished with error (fail policy) + assert request1.status == RequestStatus.FINISHED_ERROR + + # critical assertion: invalid block and all subsequent blocks should be evicted + # all blocks from invalid_block_idx onwards become invalid since they were + # computed based on the failed block + for idx in range(invalid_block_idx, len(req_block_ids)): + block_id = req_block_ids[idx] + block_obj = fail_scheduler.kv_cache_manager.block_pool.blocks[block_id] + assert block_obj.block_hash is None, ( + f"block {block_id} at index {idx} should have been evicted " + f"(hash reset to None), but hash is {block_obj.block_hash}. " + f"All blocks from index {invalid_block_idx} onwards should be evicted " + f"since they depend on the invalid block at index {invalid_block_idx}." + ) + + # verify cache contains exactly the valid blocks (before first affected block) + # and none of the invalid blocks (from first affected block onwards) + + # valid blocks: all blocks before invalid_block_idx should be cached + for idx in range(invalid_block_idx): + block_id = req_block_ids[idx] + block_obj = fail_scheduler.kv_cache_manager.block_pool.blocks[block_id] + assert block_obj.block_hash is not None, ( + f"valid block {block_id} at index {idx} should still be cached " + f"(have a hash), but hash is None. Only blocks from index " + f"{invalid_block_idx} onwards should be evicted." + ) + + # invalid blocks: verify they're not in the cached_block_hash_to_block map + cached_blocks = ( + fail_scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block + ) + cached_block_ids = { + b.block_id + for blocks_val in cached_blocks._cache.values() + for b in ( + [blocks_val] if not isinstance(blocks_val, dict) else blocks_val.values() + ) + } + + for idx in range(invalid_block_idx, len(req_block_ids)): + block_id = req_block_ids[idx] + assert block_id not in cached_block_ids, ( + f"invalid block {block_id} at index {idx} should not be in cache hash table" + ) diff --git a/tests/v1/kv_connector/unit/test_error_propagation.py b/tests/v1/kv_connector/unit/test_error_propagation.py new file mode 100644 index 0000000000000..20e181f379f5c --- /dev/null +++ b/tests/v1/kv_connector/unit/test_error_propagation.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Callable +from unittest.mock import Mock + +import pytest + +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.request import FinishReason, Request, RequestStatus + +from .utils import ( + create_model_runner_output, + create_request, + create_scheduler, + create_vllm_config, +) + +pytestmark = pytest.mark.cpu_test + + +def _make_get_num_new_matched_tokens( + req_num_new_matched_tokens: dict[str, int], + async_load: bool, +) -> Callable[[Request, int], tuple[int, bool]]: + def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]: + value = req_num_new_matched_tokens.get(request.request_id, 0) + return value, async_load + + return get_num_new_matched_tokens + + +@pytest.fixture +def fail_scheduler(): + """scheduler with kv_load_failure_policy='fail'""" + vllm_config = create_vllm_config() + vllm_config.kv_transfer_config.kv_load_failure_policy = "fail" + return create_scheduler(vllm_config) + + +def test_error_propagation_sync_load(fail_scheduler: Scheduler): + """test invalid_block_ids with fail policy -> FINISHED_ERROR (sync load)""" + num_prompt_blocks = 100 + num_external_computed_blocks = 99 + invalid_block_idx = 50 + + num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size + num_external_computed_tokens = ( + num_external_computed_blocks * fail_scheduler.block_size + ) + + request = create_request(num_tokens=num_prompt_tokens) + fail_scheduler.add_request(request=request) + + req_num_new_matched_tokens = { + request.request_id: num_external_computed_tokens, + } + + fail_scheduler.connector = Mock() + fail_scheduler.connector.get_num_new_matched_tokens.side_effect = ( + _make_get_num_new_matched_tokens(req_num_new_matched_tokens, False) + ) + fail_scheduler.connector.request_finished.return_value = (False, None) + fail_scheduler.connector.take_events.return_value = () + + scheduler_output = fail_scheduler.schedule() + + assert len(fail_scheduler.running) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 1 + assert fail_scheduler.connector.get_num_new_matched_tokens.call_count == 1 + + req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0] + invalid_block_ids = {req_block_ids[invalid_block_idx]} + model_runner_output = create_model_runner_output( + [request], + invalid_block_ids=invalid_block_ids, + use_eos=True, + ) + + outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output) + + assert request.status == RequestStatus.FINISHED_ERROR + assert request.get_finished_reason() == FinishReason.ERROR + + assert len(outputs) == 1 + engine_outputs = next(iter(outputs.values())) + assert len(engine_outputs.outputs) == 1 + output = engine_outputs.outputs[0] + assert output.request_id == request.request_id + assert output.finish_reason == FinishReason.ERROR + + assert len(fail_scheduler.running) == 0 + + +def test_error_propagation_async_load(fail_scheduler: Scheduler): + """test invalid_block_ids with fail policy -> FINISHED_ERROR (async load)""" + num_prompt_blocks = 100 + num_external_computed_blocks = 99 + invalid_block_idx = 50 + + num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size + num_external_computed_tokens = ( + num_external_computed_blocks * fail_scheduler.block_size + ) + + request = create_request(num_tokens=num_prompt_tokens) + fail_scheduler.add_request(request=request) + + req_num_new_matched_tokens = { + request.request_id: num_external_computed_tokens, + } + + fail_scheduler.connector = Mock() + fail_scheduler.connector.get_num_new_matched_tokens.side_effect = ( + _make_get_num_new_matched_tokens(req_num_new_matched_tokens, True) + ) + fail_scheduler.connector.request_finished.return_value = (False, None) + fail_scheduler.connector.take_events.return_value = () + + scheduler_output = fail_scheduler.schedule() + + assert len(fail_scheduler.waiting) == 1 + assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS + assert request.num_computed_tokens == 0 + + (req_block_ids,) = fail_scheduler.kv_cache_manager.get_block_ids(request.request_id) + invalid_block_ids = {req_block_ids[invalid_block_idx]} + model_runner_output = create_model_runner_output( + reqs=[], + finished_recving=set(), + invalid_block_ids=invalid_block_ids, + use_eos=True, + ) + + outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output) + + assert request.status == RequestStatus.FINISHED_ERROR + assert request.get_finished_reason() == FinishReason.ERROR + + assert len(outputs) == 1 + engine_outputs = next(iter(outputs.values())) + assert len(engine_outputs.outputs) == 1 + output = engine_outputs.outputs[0] + assert output.request_id == request.request_id + assert output.finish_reason == FinishReason.ERROR + + assert len(fail_scheduler.waiting) == 0 diff --git a/tests/v1/kv_connector/unit/test_shared_storage_connector.py b/tests/v1/kv_connector/unit/test_example_connector.py similarity index 94% rename from tests/v1/kv_connector/unit/test_shared_storage_connector.py rename to tests/v1/kv_connector/unit/test_example_connector.py index e7013a794a8c6..75edb79fb4af4 100644 --- a/tests/v1/kv_connector/unit/test_shared_storage_connector.py +++ b/tests/v1/kv_connector/unit/test_example_connector.py @@ -3,12 +3,14 @@ from dataclasses import asdict from typing import NamedTuple +import pytest from PIL import Image from vllm import LLM, EngineArgs, SamplingParams from vllm.assets.image import ImageAsset from vllm.config import KVTransferConfig from vllm.multimodal.utils import encode_image_base64 +from vllm.platforms import current_platform MODEL_NAME = "RedHatAI/Qwen2.5-VL-3B-Instruct-quantized.w8a8" @@ -108,18 +110,25 @@ def process_prompt(processor, llm: LLM, question: str, image_urls: list[Image]): print("-" * 50) +@pytest.mark.skipif( + current_platform.is_rocm(), + reason=( + "hipErrorLaunchFailure when running this test, see issue:" + "https://github.com/ROCm/pytorch/issues/2822" + ), +) def test_shared_storage_connector_hashes(tmp_path): """ - Tests that SharedStorageConnector saves KV to the storage locations + Tests that ExampleConnector saves KV to the storage locations with proper hashes; that are unique for inputs with identical text but different images (same size), or same multiple images but different orders. """ # Using tmp_path as the storage path to store KV print(f"KV storage path at: {str(tmp_path)}") - # Configure the SharedStorageConnector + # Configure the ExampleConnector kv_transfer_config = KVTransferConfig( - kv_connector="SharedStorageConnector", + kv_connector="ExampleConnector", kv_role="kv_both", kv_connector_extra_config={"shared_storage_path": str(tmp_path)}, ) diff --git a/tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py b/tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py new file mode 100644 index 0000000000000..940f3a98308b6 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py @@ -0,0 +1,454 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Tests for correctness in invalid block handling. + +These tests verify correct behavior in three scenarios: +1. Sync recompute case: Blocks should not be freed for running requests + that need to recompute invalid blocks +2. Sync fail case: Invalid blocks must be evicted from cache when request fails +3. Async recompute case: Invalid blocks should not be cached after transfer +""" + +from collections.abc import Callable +from unittest.mock import Mock + +import pytest + +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.request import FinishReason, Request, RequestStatus + +from .utils import ( + create_model_runner_output, + create_request, + create_scheduler, + create_vllm_config, +) + +pytestmark = pytest.mark.cpu_test + + +def _make_get_num_new_matched_tokens( + req_num_new_matched_tokens: dict[str, int], + async_load: bool, +) -> Callable[[Request, int], tuple[int, bool]]: + def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]: + value = req_num_new_matched_tokens.get(request.request_id, 0) + return value, async_load + + return get_num_new_matched_tokens + + +@pytest.fixture +def fail_scheduler(): + """scheduler with kv_load_failure_policy='fail'""" + vllm_config = create_vllm_config() + vllm_config.kv_transfer_config.kv_load_failure_policy = "fail" + return create_scheduler(vllm_config) + + +@pytest.fixture +def recompute_scheduler(): + """scheduler with kv_load_failure_policy='recompute'""" + vllm_config = create_vllm_config() + vllm_config.kv_transfer_config.kv_load_failure_policy = "recompute" + return create_scheduler(vllm_config) + + +def test_sync_recompute_blocks_not_freed_for_running_requests( + recompute_scheduler: Scheduler, +): + """ + Test sync recompute case - blocks must not be freed for running requests. + + When a running request has invalid blocks and retry_policy is 'recompute': + 1. Request should remain in RUNNING state + 2. num_computed_tokens should be truncated to invalid block boundary + 3. Blocks should NOT be freed (request still needs them for recomputation) + 4. Request should remain in scheduler.requests and scheduler.running + """ + num_prompt_blocks = 100 + num_external_computed_blocks = 99 + invalid_block_idx = 50 + + num_prompt_tokens = num_prompt_blocks * recompute_scheduler.block_size + num_external_computed_tokens = ( + num_external_computed_blocks * recompute_scheduler.block_size + ) + + request = create_request(num_tokens=num_prompt_tokens) + recompute_scheduler.add_request(request=request) + + req_num_new_matched_tokens = { + request.request_id: num_external_computed_tokens, + } + + # mock connector indicating sync load + recompute_scheduler.connector = Mock() + recompute_scheduler.connector.get_num_new_matched_tokens.side_effect = ( + _make_get_num_new_matched_tokens(req_num_new_matched_tokens, False) + ) + recompute_scheduler.connector.request_finished.return_value = (False, None) + recompute_scheduler.connector.take_events.return_value = () + + scheduler_output = recompute_scheduler.schedule() + + # request should be running with sync KV load + assert len(recompute_scheduler.running) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 1 + assert request.status == RequestStatus.RUNNING + + # get the allocated block IDs before invalid blocks are reported + req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0] + invalid_block_ids = {req_block_ids[invalid_block_idx]} + + # store original num_computed_tokens for comparison + original_num_computed_tokens = request.num_computed_tokens + + model_runner_output = create_model_runner_output( + [request], + invalid_block_ids=invalid_block_ids, + use_eos=False, # not finished - should continue running + ) + + outputs = recompute_scheduler.update_from_output( + scheduler_output, model_runner_output + ) + + # critical assertions for recompute case: + + # 1. request should still be RUNNING (not finished, not aborted) + assert request.status == RequestStatus.RUNNING, ( + f"Request should remain RUNNING for recompute, got {request.status}" + ) + + # 2. num_computed_tokens should be truncated to first invalid block + expected_truncated_tokens = invalid_block_idx * recompute_scheduler.block_size + assert request.num_computed_tokens == expected_truncated_tokens, ( + f"num_computed_tokens should be truncated to {expected_truncated_tokens}, " + f"got {request.num_computed_tokens}" + ) + assert request.num_computed_tokens < original_num_computed_tokens, ( + "num_computed_tokens should be reduced after invalid block detection" + ) + + # 3. no output should be generated (request is still running) + # the request should be skipped in the output loop + assert len(outputs) == 0 or request.request_id not in [ + out.request_id for outs in outputs.values() for out in outs.outputs + ], "No output should be generated for recompute requests" + + # 4. request should still be in running queue + assert request in recompute_scheduler.running, ( + "Request should remain in running queue for recomputation" + ) + + # 5. request should still be in scheduler.requests (not deleted) + assert request.request_id in recompute_scheduler.requests, ( + "Request should not be deleted from scheduler.requests" + ) + + # 6. blocks should NOT be freed - verify blocks are still allocated + try: + allocated_blocks = recompute_scheduler.kv_cache_manager.get_block_ids( + request.request_id + ) + assert allocated_blocks is not None + assert len(allocated_blocks[0]) > 0, ( + "Blocks should still be allocated for recomputation" + ) + except KeyError: + pytest.fail( + "Blocks were freed incorrectly! Running requests need their blocks " + "to recompute invalid portions." + ) + + # 7. verify request can be rescheduled in next step + scheduler_output_2 = recompute_scheduler.schedule() + + # request should appear in the new schedule to recompute invalid blocks + scheduled_req_ids = [ + req.request_id for req in scheduler_output_2.scheduled_new_reqs + ] + if scheduler_output_2.num_scheduled_tokens: + scheduled_req_ids.extend(scheduler_output_2.num_scheduled_tokens.keys()) + + assert ( + request.request_id in scheduled_req_ids or len(recompute_scheduler.running) > 0 + ), "Request should be reschedulable for recomputation" + + +def test_sync_fail_invalid_blocks_evicted(fail_scheduler: Scheduler): + """ + Test sync fail case - invalid blocks must be evicted from cache. + + When a request fails with policy='fail' and has invalid blocks from sync loading: + 1. Request should be finished with FINISHED_ERROR + 2. Invalid blocks should be evicted from the KV cache + 3. Valid blocks (if shared) should remain in cache + 4. Future requests should not reuse the invalid blocks + + This test verifies that invalid blocks are properly evicted to prevent + cache corruption and reuse of invalid data. + """ + num_prompt_blocks = 100 + num_external_computed_blocks = 99 + invalid_block_idx = 50 + + num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size + num_external_computed_tokens = ( + num_external_computed_blocks * fail_scheduler.block_size + ) + + request = create_request(num_tokens=num_prompt_tokens) + fail_scheduler.add_request(request=request) + + req_num_new_matched_tokens = { + request.request_id: num_external_computed_tokens, + } + + # mock connector indicating sync load + fail_scheduler.connector = Mock() + fail_scheduler.connector.get_num_new_matched_tokens.side_effect = ( + _make_get_num_new_matched_tokens(req_num_new_matched_tokens, False) + ) + fail_scheduler.connector.request_finished.return_value = (False, None) + fail_scheduler.connector.take_events.return_value = () + + scheduler_output = fail_scheduler.schedule() + + # request should be running with sync KV load + assert len(fail_scheduler.running) == 1 + assert request.status == RequestStatus.RUNNING + + # get allocated block IDs + req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0] + invalid_block_id = req_block_ids[invalid_block_idx] + invalid_block_ids = {invalid_block_id} + + # verify the block is in the block pool before we report it as invalid + block = fail_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id] + assert block is not None + + # report invalid blocks - request should fail + model_runner_output = create_model_runner_output( + [request], + invalid_block_ids=invalid_block_ids, + use_eos=True, + ) + + outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output) + + # verify request is finished with error + assert request.status == RequestStatus.FINISHED_ERROR + assert request.get_finished_reason() == FinishReason.ERROR + + # verify output is generated + assert len(outputs) == 1 + engine_outputs = next(iter(outputs.values())) + assert len(engine_outputs.outputs) == 1 + output = engine_outputs.outputs[0] + assert output.request_id == request.request_id + assert output.finish_reason == FinishReason.ERROR + + # verify the request was removed from scheduler + assert request.request_id not in fail_scheduler.requests + assert len(fail_scheduler.running) == 0 + + # critical: verify invalid block was actually freed from cache + # this is the key assertion - the invalid block should no longer be + # tracked by the KV cache manager for this request + # if it's still there, a future request could reuse the invalid data + try: + block_ids = fail_scheduler.kv_cache_manager.get_block_ids(request.request_id) + # if we get here, check if blocks were actually freed + if block_ids is not None and len(block_ids[0]) > 0: + pytest.fail( + f"Invalid blocks still tracked for finished request! " + f"Request {request.request_id} should have been freed but " + f"still has {len(block_ids[0])} blocks allocated." + ) + # blocks list exists but is empty - this is fine, they were freed + except KeyError: + # expected - request completely removed from tracking + pass + + # critical: verify invalid block was evicted from prefix cache + # the block should no longer have a hash (hash is reset on eviction) + assert block.block_hash is None, ( + f"Invalid block {invalid_block_id} should have been evicted from cache " + f"(hash should be None), but hash is still {block.block_hash}" + ) + + +def test_async_recompute_blocks_not_cached_when_invalid( + recompute_scheduler: Scheduler, +): + """ + Test async recompute case - invalid blocks not cached after transfer. + + When async KV loading has invalid blocks and retry_policy is 'recompute': + 1. Blocks are allocated but not cached yet + 2. When async transfer completes, only valid blocks should be cached + 3. Invalid blocks should never enter the prefix cache + + This test verifies correctness, the failed_recving_kv_req_ids protection + ensures only valid blocks are cached when the transfer completes, and we + only evict blocks from cache that are already hashed in the block table. + """ + from unittest.mock import patch + + num_prompt_blocks = 100 + num_external_computed_blocks = 99 + invalid_block_idx = 50 + + num_prompt_tokens = num_prompt_blocks * recompute_scheduler.block_size + num_external_computed_tokens = ( + num_external_computed_blocks * recompute_scheduler.block_size + ) + + request = create_request(num_tokens=num_prompt_tokens) + recompute_scheduler.add_request(request=request) + + req_num_new_matched_tokens = { + request.request_id: num_external_computed_tokens, + } + + # mock connector indicating async load + recompute_scheduler.connector = Mock() + recompute_scheduler.connector.get_num_new_matched_tokens.side_effect = ( + _make_get_num_new_matched_tokens(req_num_new_matched_tokens, True) + ) + recompute_scheduler.connector.request_finished.return_value = (False, None) + recompute_scheduler.connector.take_events.return_value = () + + scheduler_output = recompute_scheduler.schedule() + + # request should be waiting for remote KVs + assert len(recompute_scheduler.waiting) == 1 + assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS + assert request.num_computed_tokens == 0 + + # get the allocated block IDs + (req_block_ids,) = recompute_scheduler.kv_cache_manager.get_block_ids( + request.request_id + ) + invalid_block_id = req_block_ids[invalid_block_idx] + invalid_block_ids = {invalid_block_id} + + # get the block object to verify it's not cached yet and stays uncached + block = recompute_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id] + + # verify block has no hash before invalid blocks are reported + assert block.block_hash is None, ( + "Async loading blocks should not be cached yet (no hash)" + ) + + # report invalid blocks (transfer not finished yet) + model_runner_output = create_model_runner_output( + reqs=[], + finished_recving=None, # transfer NOT finished + invalid_block_ids=invalid_block_ids, + use_eos=False, + ) + + # critical: spy on evict_blocks to verify it's NOT called for async blocks + original_evict_blocks = recompute_scheduler.kv_cache_manager.evict_blocks + evict_blocks_calls = [] + + def evict_blocks_spy(block_ids): + evict_blocks_calls.append(set(block_ids)) + return original_evict_blocks(block_ids) + + with patch.object( + recompute_scheduler.kv_cache_manager, "evict_blocks", evict_blocks_spy + ): + recompute_scheduler.update_from_output(scheduler_output, model_runner_output) + + # verify evict_blocks was NOT called (async blocks excluded from eviction) + assert len(evict_blocks_calls) == 0, ( + f"evict_blocks should not be called for async-only invalid blocks, " + f"but was called {len(evict_blocks_calls)} time(s) with {evict_blocks_calls}" + ) + + # request should still be waiting (not finished with error due to recompute policy) + assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS + assert request.request_id in recompute_scheduler.failed_recving_kv_req_ids + + # verify num_computed_tokens was truncated to before invalid block + expected_valid_tokens = invalid_block_idx * recompute_scheduler.block_size + assert request.num_computed_tokens == expected_valid_tokens + + # verify invalid block still has no hash (was not evicted) + assert block.block_hash is None, ( + f"Async loading blocks shouldn't be cached or evicted. " + f"Block {invalid_block_id} hash should be None but is {block.block_hash}" + ) + + # now simulate async transfer completing + model_runner_output_2 = create_model_runner_output( + reqs=[], + finished_recving={request.request_id}, + invalid_block_ids=None, + use_eos=False, + ) + + recompute_scheduler.update_from_output(scheduler_output, model_runner_output_2) + + # verify request is now marked as finished receiving and ready to be processed + assert request.request_id in recompute_scheduler.finished_recving_kv_req_ids + assert request.request_id in recompute_scheduler.failed_recving_kv_req_ids + + # critical: verify invalid block still has no hash before recompute + # the async transfer invalid data was never cached + assert block.block_hash is None, ( + f"Invalid block {invalid_block_id} should not be cached before recompute " + f"(hash should be None), but hash is {block.block_hash}" + ) + + # critical end-to-end test: spy on cache_blocks to verify it's called with + # the truncated num_computed_tokens value + original_cache_blocks = recompute_scheduler.kv_cache_manager.cache_blocks + cache_blocks_calls = [] + + def cache_blocks_spy(req, num_tokens): + cache_blocks_calls.append((req.request_id, num_tokens)) + return original_cache_blocks(req, num_tokens) + + with patch.object( + recompute_scheduler.kv_cache_manager, "cache_blocks", cache_blocks_spy + ): + # call schedule() again - this triggers _update_waiting_for_remote_kv() + # which should call cache_blocks with the truncated value + recompute_scheduler.schedule() + + # verify cache_blocks was called with the truncated value + assert len(cache_blocks_calls) == 1, ( + f"cache_blocks should be called exactly once, " + f"got {len(cache_blocks_calls)} calls" + ) + cached_req_id, cached_num_tokens = cache_blocks_calls[0] + assert cached_req_id == request.request_id + assert cached_num_tokens == expected_valid_tokens, ( + f"cache_blocks should be called with truncated value {expected_valid_tokens}, " + f"but was called with {cached_num_tokens}" + ) + + # request should now be RUNNING (scheduled immediately after transfer completes) + # the flow is: WAITING_FOR_REMOTE_KVS -> WAITING -> RUNNING in same schedule() call + assert request.status == RequestStatus.RUNNING + + # num_computed_tokens should be >= expected_valid_tokens because the scheduler + # will schedule additional new tokens (up to max_num_batched_tokens) for the request + assert request.num_computed_tokens >= expected_valid_tokens, ( + f"num_computed_tokens should be at least {expected_valid_tokens}, " + f"got {request.num_computed_tokens}" + ) + + # request should no longer be in the failed/finished receiving sets + assert request.request_id not in recompute_scheduler.failed_recving_kv_req_ids + assert request.request_id not in recompute_scheduler.finished_recving_kv_req_ids + + # request should be in the running queue + assert request in recompute_scheduler.running diff --git a/tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py b/tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py index d0a6eeae6286d..4ba6b2201d0e2 100644 --- a/tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py +++ b/tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa: E501 - SharedStorageConnectorMetadata, +from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import ( # noqa: E501 + ExampleConnectorMetadata, ) from vllm.distributed.kv_transfer.kv_transfer_state import ( ensure_kv_transfer_initialized, @@ -11,7 +11,7 @@ from vllm.distributed.kv_transfer.kv_transfer_state import ( from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin -# Importing utils registers TestSharedStorageConnector with the factory +# Importing utils registers TestExampleConnector with the factory from .utils import create_vllm_config @@ -26,13 +26,13 @@ def _make_empty_scheduler_output(): num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - kv_connector_metadata=SharedStorageConnectorMetadata(), + kv_connector_metadata=ExampleConnectorMetadata(), ) def test_kv_connector_mixin_clears_metadata(): vllm_config = create_vllm_config() - vllm_config.kv_transfer_config.kv_connector = "TestSharedStorageConnector" + vllm_config.kv_transfer_config.kv_connector = "TestExampleConnector" vllm_config.kv_transfer_config.kv_role = "kv_both" vllm_config.kv_transfer_config.kv_connector_extra_config["name"] = "unit" diff --git a/tests/v1/kv_connector/unit/test_lmcache_connector.py b/tests/v1/kv_connector/unit/test_lmcache_connector.py new file mode 100644 index 0000000000000..6a8cfc71a67a6 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_lmcache_connector.py @@ -0,0 +1,756 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from unittest.mock import MagicMock + +import pytest + +from vllm.distributed.kv_events import BlockStored +from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import ( + LMCacheConnectorV1, + LMCacheKVEvents, +) +from vllm.v1.outputs import KVConnectorOutput + + +@pytest.fixture +def mock_lmcache_engine_event(): + """Create a mock event object that mimics what the lmcache engine returns.""" + + class MockEvent: + def __init__( + self, + block_hashes, + parent_block_hash, + token_ids, + lora_id, + block_size, + medium, + ): + self.block_hashes = block_hashes + self.parent_block_hash = parent_block_hash + self.token_ids = token_ids + self.lora_id = lora_id + self.block_size = block_size + self.medium = medium + + return MockEvent( + block_hashes=["hash1", "hash2"], + parent_block_hash="parent_hash", + token_ids=[1, 2, 3, 4], + lora_id=None, + block_size=16, + medium="GPU", + ) + + +@pytest.fixture +def mock_connector(): + """Create a mock LMCacheConnectorV1 instance with mocked dependencies.""" + connector = MagicMock(spec=LMCacheConnectorV1) + connector._kv_cache_events = None + connector._lmcache_engine = MagicMock() + + # Make the methods use the real implementation + connector.get_kv_connector_kv_cache_events = ( + LMCacheConnectorV1.get_kv_connector_kv_cache_events.__get__( + connector, LMCacheConnectorV1 + ) + ) + connector.update_connector_output = ( + LMCacheConnectorV1.update_connector_output.__get__( + connector, LMCacheConnectorV1 + ) + ) + connector.take_events = LMCacheConnectorV1.take_events.__get__( + connector, LMCacheConnectorV1 + ) + + return connector + + +class TestGetKVConnectorKVCacheEvents: + """Test get_kv_connector_kv_cache_events method.""" + + def test_returns_none_when_no_events(self, mock_connector): + """Test that None is returned when lmcache engine has no events.""" + mock_connector._lmcache_engine.get_kv_events.return_value = None + + result = mock_connector.get_kv_connector_kv_cache_events() + + assert result is None + mock_connector._lmcache_engine.get_kv_events.assert_called_once() + + def test_returns_none_when_empty_list(self, mock_connector): + """Test that None is returned when lmcache engine returns empty list.""" + mock_connector._lmcache_engine.get_kv_events.return_value = [] + + result = mock_connector.get_kv_connector_kv_cache_events() + + assert result is None + + def test_converts_single_event(self, mock_connector, mock_lmcache_engine_event): + """Test conversion of a single event from lmcache engine format.""" + mock_connector._lmcache_engine.get_kv_events.return_value = [ + mock_lmcache_engine_event + ] + + result = mock_connector.get_kv_connector_kv_cache_events() + + assert result is not None + assert isinstance(result, LMCacheKVEvents) + assert result.get_number_of_workers() == 1 + + events = result.get_all_events() + assert len(events) == 1 + assert isinstance(events[0], BlockStored) + assert events[0].block_hashes == ["hash1", "hash2"] + assert events[0].parent_block_hash == "parent_hash" + assert events[0].token_ids == [1, 2, 3, 4] + assert events[0].lora_id is None + assert events[0].block_size == 16 + assert events[0].medium == "GPU" + + def test_converts_multiple_events(self, mock_connector): + """Test conversion of multiple events from lmcache engine format.""" + + class MockEvent: + def __init__(self, i): + self.block_hashes = [f"hash{i}"] + self.parent_block_hash = f"parent{i}" + self.token_ids = [i] + self.lora_id = None + self.block_size = 16 + self.medium = "GPU" + + events = [MockEvent(i) for i in range(5)] + mock_connector._lmcache_engine.get_kv_events.return_value = events + + result = mock_connector.get_kv_connector_kv_cache_events() + + assert result is not None + assert isinstance(result, LMCacheKVEvents) + + converted_events = result.get_all_events() + assert len(converted_events) == 5 + + for i, event in enumerate(converted_events): + assert isinstance(event, BlockStored) + assert event.block_hashes == [f"hash{i}"] + assert event.parent_block_hash == f"parent{i}" + assert event.token_ids == [i] + + def test_preserves_event_attributes(self, mock_connector): + """Test that all event attributes are correctly preserved.""" + + class MockEventWithLora: + def __init__(self): + self.block_hashes = ["hash_a", "hash_b", "hash_c"] + self.parent_block_hash = "parent_xyz" + self.token_ids = [100, 200, 300] + self.lora_id = 42 + self.block_size = 32 + self.medium = "DISK" + + mock_connector._lmcache_engine.get_kv_events.return_value = [ + MockEventWithLora() + ] + + result = mock_connector.get_kv_connector_kv_cache_events() + + events = result.get_all_events() + event = events[0] + + assert event.block_hashes == ["hash_a", "hash_b", "hash_c"] + assert event.parent_block_hash == "parent_xyz" + assert event.token_ids == [100, 200, 300] + assert event.lora_id == 42 + assert event.block_size == 32 + assert event.medium == "DISK" + + def test_handles_none_parent_block_hash(self, mock_connector): + """Test handling of events with None parent_block_hash.""" + + class MockEventNoParent: + def __init__(self): + self.block_hashes = ["hash1"] + self.parent_block_hash = None + self.token_ids = [1, 2] + self.lora_id = None + self.block_size = 16 + self.medium = "GPU" + + mock_connector._lmcache_engine.get_kv_events.return_value = [ + MockEventNoParent() + ] + + result = mock_connector.get_kv_connector_kv_cache_events() + + events = result.get_all_events() + assert events[0].parent_block_hash is None + + +class TestUpdateConnectorOutput: + """Test update_connector_output method.""" + + def test_does_nothing_when_kv_cache_events_is_none(self, mock_connector): + """Test that method returns early when kv_cache_events is None.""" + connector_output = KVConnectorOutput(kv_cache_events=None) + + mock_connector.update_connector_output(connector_output) + + assert mock_connector._kv_cache_events is None + + def test_does_nothing_when_kv_cache_events_is_not_lmcache_kv_events( + self, mock_connector + ): + """Test that method returns early when kv_cache_events is not + LMCacheKVEvents.""" + # Create a mock object that is not LMCacheKVEvents + fake_events = MagicMock() + connector_output = KVConnectorOutput(kv_cache_events=fake_events) + + mock_connector.update_connector_output(connector_output) + + assert mock_connector._kv_cache_events is None + + def test_sets_kv_cache_events_when_none(self, mock_connector): + """Test that _kv_cache_events is set when it was None.""" + kv_events = LMCacheKVEvents(num_workers=1) + event = BlockStored( + block_hashes=["hash1"], + parent_block_hash=None, + token_ids=[1, 2], + block_size=16, + lora_id=None, + medium="GPU", + ) + kv_events.add_events([event]) + + connector_output = KVConnectorOutput(kv_cache_events=kv_events) + + mock_connector.update_connector_output(connector_output) + + assert mock_connector._kv_cache_events is kv_events + + def test_adds_events_when_kv_cache_events_already_exists(self, mock_connector): + """Test that events are added when _kv_cache_events already exists.""" + # Set up existing events + existing_events = LMCacheKVEvents(num_workers=2) + event1 = BlockStored( + block_hashes=["hash1"], + parent_block_hash=None, + token_ids=[1], + block_size=16, + lora_id=None, + medium="GPU", + ) + existing_events.add_events([event1]) + existing_events.add_events([event1]) # Simulate 2 workers reporting + + mock_connector._kv_cache_events = existing_events + + # Create new events to add + new_events = LMCacheKVEvents(num_workers=1) + event2 = BlockStored( + block_hashes=["hash2"], + parent_block_hash=None, + token_ids=[2], + block_size=16, + lora_id=None, + medium="GPU", + ) + new_events.add_events([event2]) + + connector_output = KVConnectorOutput(kv_cache_events=new_events) + + mock_connector.update_connector_output(connector_output) + + # Check that events were added + all_events = mock_connector._kv_cache_events.get_all_events() + assert len(all_events) == 3 # 2 from existing + 1 from new + assert event1 in all_events + assert event2 in all_events + + def test_increments_workers_when_kv_cache_events_already_exists( + self, mock_connector + ): + """Test that worker count is incremented correctly.""" + # Set up existing events with 2 workers + existing_events = LMCacheKVEvents(num_workers=2) + mock_connector._kv_cache_events = existing_events + + # Create new events from 3 workers + new_events = LMCacheKVEvents(num_workers=3) + event = BlockStored( + block_hashes=["hash1"], + parent_block_hash=None, + token_ids=[1], + block_size=16, + lora_id=None, + medium="GPU", + ) + new_events.add_events([event]) + + connector_output = KVConnectorOutput(kv_cache_events=new_events) + + mock_connector.update_connector_output(connector_output) + + # Worker count should be 2 + 3 = 5 + assert mock_connector._kv_cache_events.get_number_of_workers() == 5 + + def test_multiple_updates(self, mock_connector): + """Test multiple consecutive updates.""" + # First update + events1 = LMCacheKVEvents(num_workers=1) + event1 = BlockStored( + block_hashes=["hash1"], + parent_block_hash=None, + token_ids=[1], + block_size=16, + lora_id=None, + medium="GPU", + ) + events1.add_events([event1]) + output1 = KVConnectorOutput(kv_cache_events=events1) + mock_connector.update_connector_output(output1) + + # Second update + events2 = LMCacheKVEvents(num_workers=2) + event2 = BlockStored( + block_hashes=["hash2"], + parent_block_hash=None, + token_ids=[2], + block_size=16, + lora_id=None, + medium="GPU", + ) + events2.add_events([event2]) + output2 = KVConnectorOutput(kv_cache_events=events2) + mock_connector.update_connector_output(output2) + + # Third update + events3 = LMCacheKVEvents(num_workers=1) + event3 = BlockStored( + block_hashes=["hash3"], + parent_block_hash=None, + token_ids=[3], + block_size=16, + lora_id=None, + medium="GPU", + ) + events3.add_events([event3]) + output3 = KVConnectorOutput(kv_cache_events=events3) + mock_connector.update_connector_output(output3) + + # Check final state + all_events = mock_connector._kv_cache_events.get_all_events() + assert len(all_events) == 3 + assert mock_connector._kv_cache_events.get_number_of_workers() == 4 # 1+2+1 + + def test_updates_with_empty_events(self, mock_connector): + """Test updating with empty event lists.""" + # First update with actual events + events1 = LMCacheKVEvents(num_workers=1) + event1 = BlockStored( + block_hashes=["hash1"], + parent_block_hash=None, + token_ids=[1], + block_size=16, + lora_id=None, + medium="GPU", + ) + events1.add_events([event1]) + output1 = KVConnectorOutput(kv_cache_events=events1) + mock_connector.update_connector_output(output1) + + # Second update with empty events + events2 = LMCacheKVEvents(num_workers=2) + # No events added + output2 = KVConnectorOutput(kv_cache_events=events2) + mock_connector.update_connector_output(output2) + + # Should still have the original event + all_events = mock_connector._kv_cache_events.get_all_events() + assert len(all_events) == 1 + assert mock_connector._kv_cache_events.get_number_of_workers() == 3 + + +class TestTakeEvents: + """Test take_events method.""" + + def test_yields_nothing_when_kv_cache_events_is_none(self, mock_connector): + """Test that nothing is yielded when _kv_cache_events is None.""" + mock_connector._kv_cache_events = None + + events = list(mock_connector.take_events()) + + assert events == [] + + def test_yields_events_and_clears(self, mock_connector): + """Test that events are yielded and then cleared.""" + # Set up events + kv_events = LMCacheKVEvents(num_workers=1) + event1 = BlockStored( + block_hashes=["hash1"], + parent_block_hash=None, + token_ids=[1], + block_size=16, + lora_id=None, + medium="GPU", + ) + event2 = BlockStored( + block_hashes=["hash2"], + parent_block_hash=None, + token_ids=[2], + block_size=16, + lora_id=None, + medium="GPU", + ) + kv_events.add_events([event1, event2]) + mock_connector._kv_cache_events = kv_events + + # Take events + events = list(mock_connector.take_events()) + + # Check that events were yielded + assert len(events) == 2 + assert event1 in events + assert event2 in events + + # Check that _kv_cache_events was cleared + assert mock_connector._kv_cache_events is None + + def test_aggregates_before_yielding(self, mock_connector): + """Test that events are aggregated before yielding.""" + # Set up events from multiple workers + kv_events = LMCacheKVEvents(num_workers=3) + common_event = BlockStored( + block_hashes=["hash_common"], + parent_block_hash=None, + token_ids=[1], + block_size=16, + lora_id=None, + medium="GPU", + ) + uncommon_event = BlockStored( + block_hashes=["hash_uncommon"], + parent_block_hash=None, + token_ids=[2], + block_size=16, + lora_id=None, + medium="GPU", + ) + + # All 3 workers report common_event + kv_events.add_events([common_event]) + kv_events.add_events([common_event]) + kv_events.add_events([common_event]) + + # Only 1 worker reports uncommon_event + kv_events.add_events([uncommon_event]) + + mock_connector._kv_cache_events = kv_events + + # Take events + events = list(mock_connector.take_events()) + + # Only the common event should be yielded + assert len(events) == 1 + assert events[0] == common_event + + def test_multiple_take_events_calls(self, mock_connector): + """Test calling take_events multiple times.""" + # First call with events + kv_events1 = LMCacheKVEvents(num_workers=1) + event1 = BlockStored( + block_hashes=["hash1"], + parent_block_hash=None, + token_ids=[1], + block_size=16, + lora_id=None, + medium="GPU", + ) + kv_events1.add_events([event1]) + mock_connector._kv_cache_events = kv_events1 + + events1 = list(mock_connector.take_events()) + assert len(events1) == 1 + assert events1[0] == event1 + assert mock_connector._kv_cache_events is None + + # Second call with no events + events2 = list(mock_connector.take_events()) + assert events2 == [] + + # Third call after adding new events + kv_events2 = LMCacheKVEvents(num_workers=1) + event2 = BlockStored( + block_hashes=["hash2"], + parent_block_hash=None, + token_ids=[2], + block_size=16, + lora_id=None, + medium="GPU", + ) + kv_events2.add_events([event2]) + mock_connector._kv_cache_events = kv_events2 + + events3 = list(mock_connector.take_events()) + assert len(events3) == 1 + assert events3[0] == event2 + + def test_yields_empty_after_aggregation_removes_all(self, mock_connector): + """Test that nothing is yielded if aggregation removes all events.""" + # Set up events from 2 workers with no common events + kv_events = LMCacheKVEvents(num_workers=2) + event1 = BlockStored( + block_hashes=["hash1"], + parent_block_hash=None, + token_ids=[1], + block_size=16, + lora_id=None, + medium="GPU", + ) + event2 = BlockStored( + block_hashes=["hash2"], + parent_block_hash=None, + token_ids=[2], + block_size=16, + lora_id=None, + medium="GPU", + ) + + # Worker 1 reports event1 + kv_events.add_events([event1]) + # Worker 2 reports event2 + kv_events.add_events([event2]) + + mock_connector._kv_cache_events = kv_events + + # Take events + events = list(mock_connector.take_events()) + + # No common events, so nothing should be yielded + assert events == [] + assert mock_connector._kv_cache_events is None + + +class TestIntegrationScenarios: + """Test integration scenarios.""" + + def test_full_workflow(self, mock_connector, mock_lmcache_engine_event): + """Test a complete workflow from getting events to taking them.""" + # Step 1: Get events from lmcache engine + mock_connector._lmcache_engine.get_kv_events.return_value = [ + mock_lmcache_engine_event + ] + kv_events = mock_connector.get_kv_connector_kv_cache_events() + + assert kv_events is not None + assert len(kv_events.get_all_events()) == 1 + + # Step 2: Update connector output (simulate receiving from worker) + output1 = KVConnectorOutput(kv_cache_events=kv_events) + mock_connector.update_connector_output(output1) + + assert mock_connector._kv_cache_events is not None + + # Step 3: Take events + taken_events = list(mock_connector.take_events()) + + assert len(taken_events) == 1 + assert mock_connector._kv_cache_events is None + + def test_multiple_workers_workflow(self, mock_connector): + """Test workflow with multiple workers.""" + + class MockEvent: + def __init__(self, hash_val): + self.block_hashes = [hash_val] + self.parent_block_hash = None + self.token_ids = [1] + self.lora_id = None + self.block_size = 16 + self.medium = "GPU" + + # Worker 1 + mock_connector._lmcache_engine.get_kv_events.return_value = [ + MockEvent("hash_common"), + MockEvent("hash_worker1"), + ] + kv_events1 = mock_connector.get_kv_connector_kv_cache_events() + output1 = KVConnectorOutput(kv_cache_events=kv_events1) + mock_connector.update_connector_output(output1) + + # Worker 2 + mock_connector._lmcache_engine.get_kv_events.return_value = [ + MockEvent("hash_common"), + MockEvent("hash_worker2"), + ] + kv_events2 = mock_connector.get_kv_connector_kv_cache_events() + output2 = KVConnectorOutput(kv_cache_events=kv_events2) + mock_connector.update_connector_output(output2) + + # Take events (should only get common events) + taken_events = list(mock_connector.take_events()) + + # With aggregation, only events reported by both workers should be present + # In this case, hash_common was reported by both + event_hashes = [e.block_hashes[0] for e in taken_events] + assert "hash_common" in event_hashes + + def test_empty_workflow(self, mock_connector): + """Test workflow when there are no events at any stage.""" + # Get events returns None + mock_connector._lmcache_engine.get_kv_events.return_value = None + kv_events = mock_connector.get_kv_connector_kv_cache_events() + + assert kv_events is None + + # Update with None + output = KVConnectorOutput(kv_cache_events=None) + mock_connector.update_connector_output(output) + + # Take events + taken_events = list(mock_connector.take_events()) + + assert taken_events == [] + assert mock_connector._kv_cache_events is None + + def test_repeated_cycles(self, mock_connector): + """Test multiple cycles of the complete workflow.""" + + class MockEvent: + def __init__(self, cycle_num): + self.block_hashes = [f"hash_cycle_{cycle_num}"] + self.parent_block_hash = None + self.token_ids = [cycle_num] + self.lora_id = None + self.block_size = 16 + self.medium = "GPU" + + for cycle in range(3): + # Get events + mock_connector._lmcache_engine.get_kv_events.return_value = [ + MockEvent(cycle) + ] + kv_events = mock_connector.get_kv_connector_kv_cache_events() + + # Update + output = KVConnectorOutput(kv_cache_events=kv_events) + mock_connector.update_connector_output(output) + + # Take + taken_events = list(mock_connector.take_events()) + + # Verify + assert len(taken_events) == 1 + assert taken_events[0].block_hashes[0] == f"hash_cycle_{cycle}" + assert mock_connector._kv_cache_events is None + + def test_lmcache_kv_events_aggregation(self): + """ + Test LMCacheKVEvents aggregation across TP ranks using + KVOutputAggregator (used by MultiprocExecutor). + """ + from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator + from vllm.v1.outputs import ModelRunnerOutput + + # Create KVOutputAggregator for 3 workers (simulating TP=3) + aggregator = KVOutputAggregator(expected_finished_count=3) + + # Define common and unique events + common_event = BlockStored( + block_hashes=["hash_common"], + parent_block_hash="parent_common", + token_ids=[1, 2, 3], + block_size=16, + lora_id=None, + medium="GPU", + ) + + worker1_unique_event = BlockStored( + block_hashes=["hash_worker1"], + parent_block_hash="parent_w1", + token_ids=[4, 5], + block_size=16, + lora_id=None, + medium="GPU", + ) + + worker2_unique_event = BlockStored( + block_hashes=["hash_worker2"], + parent_block_hash="parent_w2", + token_ids=[6, 7], + block_size=16, + lora_id=None, + medium="GPU", + ) + + worker3_unique_event = BlockStored( + block_hashes=["hash_worker3"], + parent_block_hash="parent_w3", + token_ids=[8, 9], + block_size=16, + lora_id=None, + medium="GPU", + ) + + # Create events for each worker + # Worker 0: reports common event and its unique event + worker0_events = LMCacheKVEvents(num_workers=1) + worker0_events.add_events([common_event, worker1_unique_event]) + + # Worker 1: reports common event and its unique event + worker1_events = LMCacheKVEvents(num_workers=1) + worker1_events.add_events([common_event, worker2_unique_event]) + + # Worker 2: reports common event and its unique event + worker2_events = LMCacheKVEvents(num_workers=1) + worker2_events.add_events([common_event, worker3_unique_event]) + + # Create ModelRunnerOutput instances for each worker + worker_outputs = [] + for i, worker_events in enumerate( + [worker0_events, worker1_events, worker2_events] + ): + output = ModelRunnerOutput( + req_ids=[f"req_{i}"], + req_id_to_index={f"req_{i}": 0}, + sampled_token_ids=[[123]], # dummy token + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[None], + kv_connector_output=KVConnectorOutput( + finished_sending=set([f"req_{i}_send"]) + if i < 2 + else None, # Workers 0,1 finished sending + finished_recving=set([f"req_{i}_recv"]) + if i > 0 + else None, # Workers 1,2 finished receiving + kv_cache_events=worker_events, + ), + ) + worker_outputs.append(output) + + # Use the real aggregation mechanism (like MultiprocExecutor.execute_model) + aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0) + kv_cache_events = aggregated_output.kv_connector_output.kv_cache_events + + assert isinstance(kv_cache_events, LMCacheKVEvents) + + # After aggregation, events should be combined from all workers + # The aggregator doesn't automatically aggregate events, so we need to call + # aggregate() to get only common events + kv_cache_events.aggregate() + aggregated_events = kv_cache_events.get_all_events() + + # Only the common event should remain after aggregation + # because it's the only event reported by all 3 workers + assert len(aggregated_events) == 1 + assert aggregated_events[0] == common_event + + # Verify the common event properties + assert aggregated_events[0].block_hashes == ["hash_common"] + assert aggregated_events[0].parent_block_hash == "parent_common" + assert aggregated_events[0].token_ids == [1, 2, 3] diff --git a/tests/v1/kv_connector/unit/test_lmcache_integration.py b/tests/v1/kv_connector/unit/test_lmcache_integration.py index 33418edc325af..cfe8d810cf98a 100644 --- a/tests/v1/kv_connector/unit/test_lmcache_integration.py +++ b/tests/v1/kv_connector/unit/test_lmcache_integration.py @@ -64,22 +64,6 @@ def test_multimodal_interface(): assumes(PlaceholderRange, "offset") assumes(PlaceholderRange, "length") - # test a minimal case - import torch - - from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils import ( - apply_mm_hashes_to_token_ids, - ) - - token_ids = torch.arange(10, dtype=torch.long) - mm_hashes = ["0000", "1111"] # hex repr of 0 and 4369 - mm_positions = [ - PlaceholderRange(offset=0, length=4), - PlaceholderRange(offset=5, length=4), - ] - apply_mm_hashes_to_token_ids(token_ids, mm_hashes, mm_positions) - assert token_ids.tolist() == [0, 0, 0, 0, 4, 4369, 4369, 4369, 4369, 9] - @pytest.mark.skipif( current_platform.is_rocm(), reason="Requires libcudart.so, not available on ROCm" @@ -122,16 +106,6 @@ def test_config_interface(): assumes(CacheConfig, "block_size") assumes(CacheConfig, "gpu_memory_utilization") - # mla metadata minimal cases - from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils import ( - mla_enabled, - ) - - model_config = ModelConfig(model="deepseek-ai/DeepSeek-R1") - assert mla_enabled(model_config) - model_config = ModelConfig(model="Qwen/Qwen3-0.6B") - assert not mla_enabled(model_config) - # kv metadata minimal case from vllm.utils.torch_utils import get_kv_cache_torch_dtype @@ -139,7 +113,7 @@ def test_config_interface(): parallel_config = ParallelConfig() cache_config = CacheConfig(cache_dtype="bfloat16") kv_dtype = get_kv_cache_torch_dtype(cache_config.cache_dtype, model_config.dtype) - use_mla = mla_enabled(model_config) + use_mla = False chunk_size = 256 num_layer = model_config.get_num_layers(parallel_config) num_kv_head = model_config.get_num_kv_heads(parallel_config) @@ -184,43 +158,11 @@ def test_request_interface(): assumes(req, "num_tokens") assumes(req, "kv_transfer_params", is_instance_of=(dict, NoneType)) - from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItem + from vllm.multimodal.inputs import MultiModalFeatureSpec assumes(MultiModalFeatureSpec, "identifier") assumes(MultiModalFeatureSpec, "mm_position") - # minimal case: - from vllm.multimodal.inputs import PlaceholderRange - - request = Request( - request_id="test_request", - prompt_token_ids=[1, 2, 3], - sampling_params=SamplingParams(max_tokens=10), - pooling_params=None, - eos_token_id=100, - lora_request=None, - mm_features=[ - MultiModalFeatureSpec( - modality="image", - identifier="0000", - data=MultiModalKwargsItem.dummy("dummy_m"), - mm_position=PlaceholderRange(offset=0, length=10), - ) - ], - ) - - from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils import ( - extract_mm_features, - ) - - mm_hashes, mm_positions = extract_mm_features(request) - assert isinstance(mm_hashes, list) - assert len(mm_hashes) == 1 - assert isinstance(mm_positions, list) - assert len(mm_positions) == 1 - assert mm_positions[0].offset == 0 - assert mm_positions[0].length == 10 - def test_new_request_interface(): # protect against interface changes diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py index ffa7d884d2762..9b6d52e7c294d 100644 --- a/tests/v1/kv_connector/unit/test_multi_connector.py +++ b/tests/v1/kv_connector/unit/test_multi_connector.py @@ -77,9 +77,9 @@ def _compare_directories(dir1: Path, dir2: Path) -> bool: "https://github.com/ROCm/pytorch/issues/2822" ), ) -def test_multi_shared_storage_connector_consistency(): +def test_multi_example_connector_consistency(): """ - Tests that MultiConnector with two SharedStorageConnectors saves + Tests that MultiConnector with two ExampleConnectors saves identical KV cache data to separate storage locations. """ storage_1_path = Path("storage_1/") @@ -89,14 +89,14 @@ def test_multi_shared_storage_connector_consistency(): storage_1_path.mkdir() storage_2_path.mkdir() - # Configure MultiConnector with two SharedStorageConnectors + # Configure MultiConnector with two ExampleConnectors kv_transfer_config = KVTransferConfig( kv_connector="MultiConnector", kv_role="kv_both", kv_connector_extra_config={ "connectors": [ { - "kv_connector": "TestSharedStorageConnector", + "kv_connector": "TestExampleConnector", "kv_role": "kv_both", "kv_connector_extra_config": { "shared_storage_path": str(storage_1_path), @@ -105,7 +105,7 @@ def test_multi_shared_storage_connector_consistency(): "kv_connector_module_path": "tests.v1.kv_connector.unit.utils", }, { - "kv_connector": "TestSharedStorageConnector", + "kv_connector": "TestExampleConnector", "kv_role": "kv_both", "kv_connector_extra_config": { "shared_storage_path": str(storage_2_path), @@ -427,7 +427,7 @@ class TestMultiConnectorStats: def test_build_kv_connector_stats_skips_connectors_without_custom_stats(self): """Test that connectors without custom stats (return None) are skipped.""" - # SharedStorageConnector doesn't override build_kv_connector_stats, + # ExampleConnector doesn't override build_kv_connector_stats, # so it returns None and should be skipped serialized_data = { "NixlConnector": { @@ -440,7 +440,7 @@ class TestMultiConnectorStats: "num_failed_notifications": [], } }, - "SharedStorageConnector": {"data": {"some_field": [1, 2, 3]}}, + "ExampleConnector": {"data": {"some_field": [1, 2, 3]}}, } stats = MultiConnector.build_kv_connector_stats(data=serialized_data) @@ -451,8 +451,8 @@ class TestMultiConnectorStats: assert len(stats.data) == 1 assert "NixlConnector" in stats.data assert isinstance(stats.data["NixlConnector"], NixlKVConnectorStats) - # SharedStorageConnector should be skipped (returns None) - assert "SharedStorageConnector" not in stats.data + # ExampleConnector should be skipped (returns None) + assert "ExampleConnector" not in stats.data def test_build_kv_connector_stats_handles_malformed_data(self): """Test that malformed data raises appropriate errors.""" @@ -527,13 +527,13 @@ class TestMultiConnectorStats: ) stats2 = MultiKVConnectorStats( - data={"SharedStorageConnector": KVConnectorStats(data={"field": [1, 2]})} + data={"ExampleConnector": KVConnectorStats(data={"field": [1, 2]})} ) result = stats1.aggregate(stats2) assert "NixlConnector" in result.data - assert "SharedStorageConnector" in result.data + assert "ExampleConnector" in result.data def test_reduce(self): """Test that reduce() correctly reduces all nested connector stats.""" diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index b7d7a10057b8b..66804fa671c7c 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -9,8 +9,10 @@ import textwrap import time import uuid from collections import defaultdict -from unittest.mock import patch +from typing import Any +from unittest.mock import MagicMock, patch +import msgspec import pytest import ray import torch @@ -18,6 +20,7 @@ import torch from vllm import LLM from vllm.config import KVTransferConfig from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator +from vllm.distributed.kv_transfer.kv_connector.v1 import nixl_connector from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( MultiKVConnectorStats, @@ -29,13 +32,16 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( NixlConnectorMetadata, NixlConnectorScheduler, NixlConnectorWorker, + NixlHandshakePayload, NixlKVConnectorStats, + compute_nixl_compatibility_hash, ) from vllm.distributed.kv_transfer.kv_transfer_state import ( ensure_kv_transfer_shutdown, has_kv_transfer_group, ) from vllm.forward_context import ForwardContext +from vllm.platforms import current_platform from vllm.platforms.interface import Platform from vllm.sampling_params import SamplingParams from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend @@ -317,13 +323,19 @@ def test_kv_transfer_handshake(dist_init): } prefill_connector.register_kv_caches(kv_caches) - # Simulate EngineCore initialization that would - # gather connector metadata from all workers, the scheduler connector - # expects metadata to be in dict[int, KVConnectorHandshakeMetadata], - # where the first key is the dp_rank, the second key is the tp_rank. - metadata = {0: prefill_connector.get_handshake_metadata()} + # Simulate EngineCore initialization that would gather connector + # metadata from all workers + metadata = prefill_connector.get_handshake_metadata() + + # metadata is a NixlHandshakePayload, decode it to get NixlAgentMetadata + decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) + expected_agent_metadata = decoder.decode(metadata.agent_metadata_bytes) + + # The scheduler connector expects metadata to be in + # dict[int, KVConnectorHandshakeMetadata], where the first key is + # the dp_rank, the second key is the tp_rank. scheduler_connector = scheduler.get_kv_connector() - scheduler_connector.set_xfer_handshake_metadata(metadata) + scheduler_connector.set_xfer_handshake_metadata({0: metadata}) # Simulate a request that finishes prefill, which returns # corresponding NixlConnectorMetadata for decode instance. @@ -362,9 +374,9 @@ def test_kv_transfer_handshake(dist_init): ) received_metadata = mock_add_remote_agent.call_args.args + assert received_metadata[0] == expected_agent_metadata assert received_metadata[1] == 0 # remote_tp_rank assert received_metadata[2] == 1 # remote_tp_size - assert metadata[0] == received_metadata[0] # Need to shutdown the background thread to release NIXL side channel port scheduler_connector.shutdown() @@ -403,7 +415,6 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): device_id=0, num_blocks=1, block_lens=self.block_len_per_layer, - attn_backend_name=self.backend_name, # `self.kv_cache_layout` is only forced to HND when vllm engine # is started. We mock HND here. kv_cache_layout="HND", @@ -450,7 +461,7 @@ class TestNixlHandshake: metadata = NixlConnectorMetadata() if num_xfers > 0: num_xfers -= 1 - metadata.add_new_req( + metadata.add_new_req_to_recv( request_id=request_id, local_block_ids=[num_xfers + 1, num_xfers + 2, num_xfers + 3], kv_transfer_params={ @@ -460,6 +471,7 @@ class TestNixlHandshake: num_xfers + 6, ], "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_request_id": f"prefill-{request_id}", "remote_host": "localhost", "remote_port": 1234, "remote_tp_size": 1, @@ -520,12 +532,13 @@ class TestNixlHandshake: vllm_config, connector.engine_id ) metadata = NixlConnectorMetadata() - metadata.add_new_req( + metadata.add_new_req_to_recv( request_id="id", local_block_ids=[1, 2, 3], kv_transfer_params={ "remote_block_ids": [4, 5, 6], "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_request_id": "prefill-id", "remote_host": "localhost", "remote_port": 1234, "remote_tp_size": prefill_tp_size, @@ -575,12 +588,13 @@ class TestNixlHandshake: metadata = NixlConnectorMetadata() total_reqs = 5 for i in range(total_reqs): - metadata.add_new_req( + metadata.add_new_req_to_recv( request_id=f"id_{i}", local_block_ids=[1, 2, 3], kv_transfer_params={ "remote_block_ids": [4, 5, 6], "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_request_id": f"prefill-id-{i}", "remote_host": "localhost", "remote_port": 1234, "remote_tp_size": 1, @@ -651,7 +665,6 @@ class TestNixlHandshake: device_id=0, num_blocks=1, block_lens=worker.block_len_per_layer, - attn_backend_name=worker.backend_name, kv_cache_layout=mismatched_layout, block_size=worker.block_size, ) @@ -706,7 +719,6 @@ class TestNixlHandshake: num_blocks=1, # prefill TP=1, decode TP=2, remote block_lens is double to local block_lens=[i * 2 for i in worker.block_len_per_layer], - attn_backend_name=worker.backend_name, kv_cache_layout="HND", block_size=worker.block_size, ) @@ -740,12 +752,13 @@ def test_kv_connector_stats(dist_init): # Create transfer metadata request_id = "test_req_for_stats" metadata = NixlConnectorMetadata() - metadata.add_new_req( + metadata.add_new_req_to_recv( request_id=request_id, local_block_ids=[1, 2, 3], kv_transfer_params={ "remote_block_ids": [4, 5, 6], "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_request_id": f"prefill-{request_id}", "remote_host": "localhost", "remote_port": 1234, "remote_tp_size": 1, @@ -1099,7 +1112,26 @@ def _run_abort_timeout_test(llm: LLM, timeout: int): llm.llm_engine.engine_core.shutdown() -@pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "TRITON_ATTN"]) +@pytest.mark.parametrize( + "attn_backend", + [ + pytest.param( + "FLASH_ATTN", + marks=pytest.mark.skipif( + current_platform.is_rocm(), + reason="Attention backend FLASH_ATTN is not supported on ROCm", + ), + ), + pytest.param( + "ROCM_ATTN", + marks=pytest.mark.skipif( + not current_platform.is_rocm(), + reason="Attention backend ROCM_ATTN is only supported on ROCm", + ), + ), + "TRITON_ATTN", + ], +) def test_register_kv_caches(dist_init, attn_backend, monkeypatch): """ Test that register_kv_caches() properly calls nixl_wrapper methods with @@ -1121,6 +1153,10 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch): from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend backend_cls = FlashAttentionBackend + elif attn_backend == "ROCM_ATTN": + from vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend + + backend_cls = RocmAttentionBackend else: # TRITON_ATTN from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend @@ -1139,25 +1175,43 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch): } # Store tensor info for validation - expected_tensor_size = shared_tensor[0].element_size() * shared_tensor[0].numel() - expected_base_addrs = [ - shared_tensor[0].data_ptr(), - shared_tensor[1].data_ptr(), - unique_tensor[0].data_ptr(), - unique_tensor[1].data_ptr(), - ] + test_shape = backend_cls.get_kv_cache_shape( + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + ) + is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1 + + if is_blocks_first: + expected_tensor_size = shared_tensor.element_size() * shared_tensor.numel() + expected_base_addrs = [ + shared_tensor.data_ptr(), + unique_tensor.data_ptr(), + ] + expected_num_entries = 2 + else: + expected_tensor_size = ( + shared_tensor[0].element_size() * shared_tensor[0].numel() + ) + expected_base_addrs = [ + shared_tensor[0].data_ptr(), + shared_tensor[1].data_ptr(), + unique_tensor[0].data_ptr(), + unique_tensor[1].data_ptr(), + ] + expected_num_entries = 4 + + nixl_module = "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector" with ( - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper" - ) as mock_nixl_wrapper, - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event" - ), - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread" - ) as mock_thread, - ): # noqa: E501 + patch(f"{nixl_module}.NixlWrapper") as mock_nixl_wrapper, + patch(f"{nixl_module}.threading.Event"), + patch(f"{nixl_module}.threading.Thread") as mock_thread, + patch(f"{nixl_module}.get_attn_backend") as mock_get_attn_backend, + ): + # Ensure get_attn_backend returns the correct value due to + # _cached_get_attn_backend returning the backend from previous + # test run if not mocking. + mock_get_attn_backend.return_value = backend_cls + # Create connector connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) connector.connector_worker = FakeNixlConnectorWorker( @@ -1168,6 +1222,9 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch): mock_wrapper_instance = mock_nixl_wrapper.return_value connector.connector_worker.nixl_wrapper = mock_wrapper_instance + # Appease NixlHandshakePayload encoding with some bytes + mock_wrapper_instance.get_agent_metadata.return_value = b"fake_agent_metadata" + # Reassure the shutdown() check that the thread is terminated mock_thread.return_value.is_alive.return_value = False @@ -1177,7 +1234,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch): # Verify get_reg_descs was called with caches_data assert mock_wrapper_instance.get_reg_descs.called caches_data, _ = mock_wrapper_instance.get_reg_descs.call_args[0] - assert len(caches_data) == 4 + assert len(caches_data) == expected_num_entries for i, cache_entry in enumerate(caches_data): base_addr, size, _tp_rank, _ = cache_entry @@ -1199,7 +1256,12 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch): f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}" ) - expected_block_len = expected_tensor_size // 2 + num_blocks = 2 + if is_blocks_first: + expected_block_len = expected_tensor_size // num_blocks // 2 + else: + expected_block_len = expected_tensor_size // num_blocks + for i, block_entry in enumerate(blocks_data): block_start_addr, block_len, tp_rank = block_entry assert block_len == expected_block_len, ( @@ -1296,7 +1358,7 @@ def test_shutdown_cleans_up_resources(dist_init): patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent, patch.object(nixl_wrapper, "deregister_memory") as mock_dereg, ): - worker._recving_transfers = {"req1": [(123, time.perf_counter())]} + worker._recving_transfers = {"req1": [123]} worker.src_xfer_side_handle = 456 worker.dst_xfer_side_handles = {"engine1": 789} worker._remote_agents = {"engine1": {0: "agent1"}} @@ -1453,12 +1515,13 @@ def test_handshake_failure_returns_finished(dist_init): request_id = "test_handshake_fail" metadata = NixlConnectorMetadata() - metadata.add_new_req( + metadata.add_new_req_to_recv( request_id=request_id, local_block_ids=[1, 2, 3], kv_transfer_params={ "remote_block_ids": [4, 5, 6], "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_request_id": f"prefill-{request_id}", "remote_host": "localhost", "remote_port": 1234, "remote_tp_size": 1, @@ -1502,12 +1565,13 @@ def test_transfer_setup_failure_returns_finished(dist_init): request_id = "test_transfer_fail" metadata = NixlConnectorMetadata() - metadata.add_new_req( + metadata.add_new_req_to_recv( request_id=request_id, local_block_ids=[7, 8, 9], kv_transfer_params={ "remote_block_ids": [10, 11, 12], "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_request_id": f"prefill-{request_id}", "remote_host": "localhost", "remote_port": 1234, "remote_tp_size": 1, @@ -1534,3 +1598,194 @@ def test_transfer_setup_failure_returns_finished(dist_init): # ensure request appears in get_finished _, done_recving = connector.get_finished(finished_req_ids=set()) assert request_id in done_recving + + +@pytest.mark.parametrize( + "mismatch_type,config_overrides,version_override,should_fail,enforce_handshake_compat", + [ + ("vllm_version", {}, {"vllm_version": "0.6.1"}, True, True), + ("nixl_connector_version", {}, {"connector_version": 37}, True, True), + ("model_name", {"model": "facebook/opt-350m"}, {}, True, True), + ("dtype", {"dtype": "bfloat16"}, {}, True, True), + ("cache_dtype", {"cache_dtype": "fp8"}, {}, True, True), + ("num_kv_heads", {"hf_overrides": {"num_key_value_heads": 8}}, {}, True, True), + ( + "num_hidden_layers", + {"hf_overrides": {"num_hidden_layers": 24}}, + {}, + True, + True, + ), + ("hidden_size", {"hf_overrides": {"hidden_size": 1536}}, {}, True, True), + ("block_size", {"block_size": 8}, {}, False, True), + ("matching_config", {}, {}, False, True), + ("escape_hatch", {"model": "facebook/opt-350m"}, {}, False, False), + ], +) +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, +) +def test_compatibility_hash_validation( + dist_init, + mismatch_type, + config_overrides, + version_override, + should_fail, + enforce_handshake_compat, +): + """ + Test NIXL compatibility hash validation during handshake. + + Parameters: + mismatch_type: description of what is being tested + config_overrides: dict of config to override for the remote instance + version_override: version dict e.g. {"vllm_version": "0.6.1"} + should_fail: whether the handshake should fail + enforce_handshake_compat: whether to enforce compatibility checking + """ + local_vllm_config = create_vllm_config( + model="facebook/opt-125m", + block_size=16, + kv_connector_extra_config={ + "enforce_handshake_compat": enforce_handshake_compat + }, + ) + decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER) + decode_worker = decode_connector.connector_worker + + remote_config_params: dict[str, Any] = { + "model": "facebook/opt-125m", + "block_size": 16, + **config_overrides, + } + remote_vllm_config = create_vllm_config(**remote_config_params) + + with contextlib.ExitStack() as stack: + if "vllm_version" in version_override: + stack.enter_context( + patch("vllm.__version__", version_override["vllm_version"]) + ) + elif "connector_version" in version_override: + stack.enter_context( + patch.object( + nixl_connector, + "NIXL_CONNECTOR_VERSION", + version_override["connector_version"], + ) + ) + remote_hash = compute_nixl_compatibility_hash( + remote_vllm_config, decode_worker.backend_name + ) + + prefill_block_size = config_overrides.get("block_size", 16) + prefill_metadata = NixlAgentMetadata( + engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + agent_metadata=FakeNixlWrapper.AGENT_METADATA, + kv_caches_base_addr=[0], + device_id=0, + num_blocks=1, + block_lens=[4096 * prefill_block_size], # slot_size * block_size + kv_cache_layout="HND", + block_size=prefill_block_size, + ) + handshake_payload = NixlHandshakePayload( + compatibility_hash=remote_hash, + agent_metadata_bytes=msgspec.msgpack.encode(prefill_metadata), + ) + + # Mock ZMQ socket to return our handshake payload + mock_socket = MagicMock() + mock_socket.recv.return_value = msgspec.msgpack.encode(handshake_payload) + + # Mock add_remote_agent to avoid actual NIXL operations + # Patch zmq_ctx to return our mock socket + with ( + patch.object(decode_worker, "add_remote_agent", return_value="fake_agent"), + patch.object(nixl_connector, "zmq_ctx") as mock_zmq_ctx, + ): + mock_zmq_ctx.return_value.__enter__.return_value = mock_socket + + if should_fail: + with pytest.raises(RuntimeError, match="compatibility hash mismatch"): + decode_worker._nixl_handshake( + host="localhost", + port=1234, + remote_tp_size=1, + expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + ) + else: + result = decode_worker._nixl_handshake( + host="localhost", + port=1234, + remote_tp_size=1, + expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + ) + # Verify handshake returned agent mapping + assert isinstance(result, dict) + assert len(result) == 1 + + +@pytest.mark.parametrize( + "error_scenario", + [ + "handshake_decode_error", + "handshake_validation_error", + "metadata_decode_error", + "metadata_validation_error", + ], +) +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, +) +def test_handshake_decode_errors(dist_init, error_scenario): + """ + Test that msgspec decode errors are properly handled during handshake. + + Tests both DecodeError and ValidationError for both decoders: + - NixlHandshakePayload decoder + - NixlAgentMetadata decoder + """ + local_vllm_config = create_vllm_config( + model="facebook/opt-125m", + block_size=16, + ) + decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER) + decode_worker = decode_connector.connector_worker + + if error_scenario == "handshake_decode_error": + msg_bytes = b"this is not valid msgpack data" + elif error_scenario == "handshake_validation_error": + msg_bytes = msgspec.msgpack.encode({"wrong_field": "value"}) + elif error_scenario == "metadata_decode_error": + valid_handshake = NixlHandshakePayload( + compatibility_hash=decode_worker.compat_hash, + agent_metadata_bytes=b"invalid msgpack for metadata", + ) + msg_bytes = msgspec.msgpack.encode(valid_handshake) + + elif error_scenario == "metadata_validation_error": + valid_handshake = NixlHandshakePayload( + compatibility_hash=decode_worker.compat_hash, + agent_metadata_bytes=msgspec.msgpack.encode({"missing": "fields"}), + ) + msg_bytes = msgspec.msgpack.encode(valid_handshake) + else: + raise AssertionError(f"{error_scenario} not a valid scenario") + + mock_socket = MagicMock() + mock_socket.recv.return_value = msg_bytes + with ( + patch.object(decode_worker, "add_remote_agent", return_value="fake_agent"), + patch.object(nixl_connector, "zmq_ctx") as mock_zmq_ctx, + ): + mock_zmq_ctx.return_value.__enter__.return_value = mock_socket + + with pytest.raises(RuntimeError): + decode_worker._nixl_handshake( + host="localhost", + port=1234, + remote_tp_size=1, + expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + ) diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index f35f91bb3adf8..5cdb1f84b30d4 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -24,8 +24,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorMetadata, KVConnectorRole, ) -from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa - SharedStorageConnector, +from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import ( # noqa + ExampleConnector, ) from vllm.utils.hashing import sha256 from vllm.v1.core.kv_cache_manager import KVCacheBlocks @@ -90,32 +90,39 @@ def create_vllm_config( max_model_len: int = 10000, enable_chunked_prefill: bool = True, enable_permute_local_kv: bool = False, + kv_connector_extra_config: dict[str, Any] | None = None, + dtype: str = "float16", + cache_dtype: str = "auto", + hf_overrides: dict[str, Any] | None = None, ) -> VllmConfig: """Initialize VllmConfig For Testing.""" + model_config = ModelConfig( + model=model, + trust_remote_code=True, + dtype=dtype, + seed=42, + hf_overrides=hf_overrides or {}, + ) scheduler_config = SchedulerConfig( max_num_seqs=max_num_seqs, max_num_batched_tokens=max_num_batched_tokens, max_model_len=max_model_len, enable_chunked_prefill=enable_chunked_prefill, - ) - model_config = ModelConfig( - model=model, - trust_remote_code=True, - dtype="float16", - seed=42, + is_encoder_decoder=model_config.is_encoder_decoder, ) # Cache config, optionally force APC cache_config = CacheConfig( block_size=block_size, gpu_memory_utilization=0.9, swap_space=0, - cache_dtype="auto", + cache_dtype=cache_dtype, enable_prefix_caching=True, ) kv_transfer_config = KVTransferConfig( kv_connector="NixlConnector", kv_role="kv_both", enable_permute_local_kv=enable_permute_local_kv, + kv_connector_extra_config=kv_connector_extra_config or {}, ) return VllmConfig( scheduler_config=scheduler_config, @@ -187,6 +194,7 @@ def create_request( do_remote_prefill=True, do_remote_decode=False, remote_engine_id="my-engine-id", + remote_request_id=f"prefill-{request_id}", remote_block_ids=list(range(num_remote_blocks)), remote_host="my-host", remote_port=1234, @@ -256,10 +264,10 @@ def create_model_runner_output( ) -class TestSharedStorageConnector(SharedStorageConnector): +class TestExampleConnector(ExampleConnector): def __init__(self, config: VllmConfig, role, kv_cache_config): self.name = config.kv_transfer_config.kv_connector_extra_config["name"] - self._connector = SharedStorageConnector(config, role) + self._connector = ExampleConnector(config, role) self.call_record: dict[str, int] = defaultdict(int) # Use a unique temp file per connector self._event_file = ( @@ -386,7 +394,7 @@ class MockKVConnector(KVConnectorBase_V1): KVConnectorFactory.register_connector( - "TestSharedStorageConnector", __name__, TestSharedStorageConnector.__name__ + "TestExampleConnector", __name__, TestExampleConnector.__name__ ) KVConnectorFactory.register_connector( diff --git a/tests/v1/kv_offload/test_cpu_gpu.py b/tests/v1/kv_offload/test_cpu_gpu.py index a248104e16d2d..3516c0013879d 100644 --- a/tests/v1/kv_offload/test_cpu_gpu.py +++ b/tests/v1/kv_offload/test_cpu_gpu.py @@ -9,7 +9,7 @@ import torch from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec -from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandler +from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandlers BACKENDS_TO_TEST = [FlashAttentionBackend] @@ -82,7 +82,7 @@ def test_transfer( # create handler cpu_block_size = gpu_blocks_per_cpu_block * gpu_block_size - handler = CpuGpuOffloadingHandler( + handlers = CpuGpuOffloadingHandlers( attn_backends=attn_backends, gpu_block_size=gpu_block_size, cpu_block_size=cpu_block_size, @@ -112,8 +112,7 @@ def test_transfer( # set transfer direction if gpu_to_cpu: - src_kv_caches = handler.gpu_tensors - dst_kv_caches = handler.cpu_tensors + handler = handlers.gpu_to_cpu_handler src_spec_class = GPULoadStoreSpec dst_spec_class = CPULoadStoreSpec src_blocks = gpu_blocks @@ -122,8 +121,7 @@ def test_transfer( dst_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size dst_size_in_gpu_blocks = num_cpu_blocks * gpu_blocks_per_cpu_block else: - src_kv_caches = handler.cpu_tensors - dst_kv_caches = handler.gpu_tensors + handler = handlers.cpu_to_gpu_handler src_spec_class = CPULoadStoreSpec dst_spec_class = GPULoadStoreSpec src_blocks = cpu_blocks @@ -144,12 +142,12 @@ def test_transfer( dst_spec = dst_spec_class(dst_blocks) # clone src and dst tensors before transfer - orig_src_caches = [x.clone() for x in src_kv_caches] - orig_dst_caches = [x.clone() for x in dst_kv_caches] + orig_src_caches = [x.clone() for x in handler.src_tensors] + orig_dst_caches = [x.clone() for x in handler.dst_tensors] # call transfer function assert handler.transfer_async(1, (src_spec, dst_spec)) - assert set(handler.transfer_events.keys()) == {1} + assert set({x[0] for x in handler._transfers}) == {1} # wait for transfer to complete end_time = time.time() + 10 @@ -161,15 +159,15 @@ def test_transfer( time.sleep(0.1) # verify src tensors did not change - for orig_tensor, tensor in zip(orig_src_caches, src_kv_caches): + for orig_tensor, tensor in zip(orig_src_caches, handler.src_tensors): assert torch.equal(orig_tensor, tensor) # verify dst tensors for dst_block in range(dst_size_in_gpu_blocks): src_block_candidate = dst_to_src.get(dst_block) for src_cache, dst_cache, orig_dst_cache, kv_dim in zip( - src_kv_caches, - dst_kv_caches, + handler.src_tensors, + handler.dst_tensors, orig_dst_caches, handler.kv_dim_before_num_blocks, ): diff --git a/tests/v1/logits_processors/test_custom_offline.py b/tests/v1/logits_processors/test_custom_offline.py index 1899737737f4b..e3ddb6138cfdd 100644 --- a/tests/v1/logits_processors/test_custom_offline.py +++ b/tests/v1/logits_processors/test_custom_offline.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import random -import sys from typing import Any import pytest @@ -10,7 +9,6 @@ from tests.utils import create_new_process_for_each_test from tests.v1.logits_processors.utils import ( DUMMY_LOGITPROC_ARG, DUMMY_LOGITPROC_FQCN, - DUMMY_LOGITPROC_MODULE, MAX_TOKENS, MODEL_NAME, POOLING_MODEL_NAME, @@ -18,7 +16,6 @@ from tests.v1.logits_processors.utils import ( CustomLogitprocSource, DummyLogitsProcessor, WrappedPerReqLogitsProcessor, - dummy_module, prompts, ) from tests.v1.logits_processors.utils import entry_points as fake_entry_points @@ -162,8 +159,6 @@ def test_custom_logitsprocs(monkeypatch, logitproc_source: CustomLogitprocSource kwargs: dict[str, list[str | type[LogitsProcessor]]] = {} if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN: # Scenario: load logitproc based on fully-qualified class name (FQCN) - # Inject dummy module which defines logitproc - sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN] elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS: # Scenario: load logitproc from provided class object diff --git a/tests/v1/logits_processors/test_custom_online.py b/tests/v1/logits_processors/test_custom_online.py index 3e0bb02ed68be..3dc6b89790157 100644 --- a/tests/v1/logits_processors/test_custom_online.py +++ b/tests/v1/logits_processors/test_custom_online.py @@ -14,11 +14,9 @@ from tests.utils import RemoteOpenAIServerCustom, create_new_process_for_each_te from tests.v1.logits_processors.utils import ( DUMMY_LOGITPROC_ARG, DUMMY_LOGITPROC_FQCN, - DUMMY_LOGITPROC_MODULE, MAX_TOKENS, MODEL_NAME, TEMP_GREEDY, - dummy_module, prompts, ) from tests.v1.logits_processors.utils import entry_points as fake_entry_points @@ -47,20 +45,14 @@ def _server_with_logitproc_entrypoint( main.main() -def _server_with_logitproc_module( +def _server_with_logitproc_fqcn( env_dict: dict[str, str] | None, model: str, vllm_serve_args: list[str], ) -> None: """Start vLLM server, inject module with dummy logitproc""" - - # Patch `modules` to inject dummy logitproc module from vllm.entrypoints.cli import main - sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module - - # fork is required for workers to see entrypoint patch - os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "fork" if env_dict is not None: os.environ.update(env_dict) @@ -99,7 +91,7 @@ def server(default_server_args, request, monkeypatch): if request.param: # Launch server, append FQCN argument, inject dummy logitproc module args = default_server_args + request.param - _server_fxn = _server_with_logitproc_module + _server_fxn = _server_with_logitproc_fqcn else: # Launch server, inject dummy logitproc entrypoint args = default_server_args diff --git a/tests/v1/logits_processors/utils.py b/tests/v1/logits_processors/utils.py index b8548bc319554..e54da72e5e2ed 100644 --- a/tests/v1/logits_processors/utils.py +++ b/tests/v1/logits_processors/utils.py @@ -27,7 +27,7 @@ DUMMY_LOGITPROC_ARG = "target_token" TEMP_GREEDY = 0.0 MAX_TOKENS = 20 DUMMY_LOGITPROC_ENTRYPOINT = "dummy_logitproc" -DUMMY_LOGITPROC_MODULE = "DummyModule" +DUMMY_LOGITPROC_MODULE = "tests.v1.logits_processors.utils" DUMMY_LOGITPROC_FQCN = f"{DUMMY_LOGITPROC_MODULE}:DummyLogitsProcessor" diff --git a/tests/v1/metrics/test_stats.py b/tests/v1/metrics/test_stats.py index 48067def8357e..7d902bbc6fc24 100644 --- a/tests/v1/metrics/test_stats.py +++ b/tests/v1/metrics/test_stats.py @@ -1,8 +1,109 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.v1.metrics.stats import IterationStats +from vllm.v1.engine import FinishReason +from vllm.v1.metrics.stats import IterationStats, RequestStateStats def test_iteration_stats_repr(): iteration_stats = IterationStats() assert repr(iteration_stats).startswith("IterationStats(") + + +def test_prefill_kv_computed_with_cache(): + """Test that prefill KV compute correctly excludes cached tokens.""" + iteration_stats = IterationStats() + req_stats = RequestStateStats(arrival_time=0.0) + req_stats.scheduled_ts = 0.1 + req_stats.first_token_ts = 0.5 + req_stats.last_token_ts = 5.0 + req_stats.num_generation_tokens = 50 + + # Case 1: With prefix cache (1200 tokens cached) + iteration_stats.update_from_finished_request( + finish_reason=FinishReason.STOP, + num_prompt_tokens=10000, + max_tokens_param=100, + req_stats=req_stats, + num_cached_tokens=1200, + ) + + finished_req = iteration_stats.finished_requests[0] + assert finished_req.num_prompt_tokens == 10000 + assert finished_req.num_cached_tokens == 1200 + + # Verify calculation: prefill KV = prompt tokens - cached tokens + prefill_kv_computed = finished_req.num_prompt_tokens - max( + finished_req.num_cached_tokens, 0 + ) + assert prefill_kv_computed == 8800 # 10000 - 1200 + + +def test_prefill_kv_computed_no_cache(): + """Test prefill KV compute without prefix caching.""" + iteration_stats = IterationStats() + req_stats = RequestStateStats(arrival_time=0.0) + req_stats.scheduled_ts = 0.1 + req_stats.first_token_ts = 0.5 + req_stats.last_token_ts = 2.0 + req_stats.num_generation_tokens = 10 + + # Case 2: No prefix cache + iteration_stats.update_from_finished_request( + finish_reason=FinishReason.STOP, + num_prompt_tokens=2000, + max_tokens_param=100, + req_stats=req_stats, + num_cached_tokens=0, + ) + + finished_req = iteration_stats.finished_requests[0] + assert finished_req.num_prompt_tokens == 2000 + assert finished_req.num_cached_tokens == 0 + + # Verify calculation: prefill KV = full prompt when no cache + prefill_kv_computed = finished_req.num_prompt_tokens - max( + finished_req.num_cached_tokens, 0 + ) + assert prefill_kv_computed == 2000 + + +def test_prefill_kv_computed_edge_cases(): + """Test edge cases for prefill KV compute calculation.""" + iteration_stats = IterationStats() + req_stats = RequestStateStats(arrival_time=0.0) + req_stats.scheduled_ts = 0.1 + req_stats.first_token_ts = 0.5 + req_stats.last_token_ts = 1.0 + req_stats.num_generation_tokens = 1 + + # Case 3: Negative num_cached_tokens (shouldn't happen, but handle gracefully) + iteration_stats.update_from_finished_request( + finish_reason=FinishReason.STOP, + num_prompt_tokens=100, + max_tokens_param=10, + req_stats=req_stats, + num_cached_tokens=-1, + ) + + finished_req = iteration_stats.finished_requests[0] + # max() should handle negative values + prefill_kv_computed = finished_req.num_prompt_tokens - max( + finished_req.num_cached_tokens, 0 + ) + assert prefill_kv_computed == 100 # Should treat negative as 0 + + # Case 4: All tokens cached (shouldn't happen in practice) + iteration_stats2 = IterationStats() + iteration_stats2.update_from_finished_request( + finish_reason=FinishReason.STOP, + num_prompt_tokens=100, + max_tokens_param=10, + req_stats=req_stats, + num_cached_tokens=100, + ) + + finished_req2 = iteration_stats2.finished_requests[0] + prefill_kv_computed2 = finished_req2.num_prompt_tokens - max( + finished_req2.num_cached_tokens, 0 + ) + assert prefill_kv_computed2 == 0 # All cached, nothing computed diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index c89c33be80c10..76a0e8e25a4ae 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -528,9 +528,11 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode): ), ], ) +@pytest.mark.parametrize("top_logprobs", [0, 3]) def test_spec_decode_logprobs( logprobs_mode: LogprobsMode, model_setup: tuple[str, str, str], + top_logprobs: int, ): """Spec decode logprobs should match those of the base model. @@ -543,7 +545,7 @@ def test_spec_decode_logprobs( prompt = "Hello world " * 50 sampling_params = SamplingParams( - temperature=0, logprobs=3, max_tokens=10, ignore_eos=False + temperature=0, logprobs=top_logprobs, max_tokens=10, ignore_eos=False ) method, model_name, spec_model_name = model_setup max_model_len = 256 diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index bf7726ebf907f..61caffee45daf 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -111,7 +111,7 @@ def create_sampling_metadata( top_p=top_p, top_k=top_k, generators=generators, - max_num_logprobs=0, + max_num_logprobs=None, no_penalties=no_penalties, prompt_token_ids=prompt_token_ids, frequency_penalties=frequency_penalties, diff --git a/tests/v1/sample/test_sampling_params_e2e.py b/tests/v1/sample/test_sampling_params_e2e.py index 1684252174d3d..a75a37befe0e1 100644 --- a/tests/v1/sample/test_sampling_params_e2e.py +++ b/tests/v1/sample/test_sampling_params_e2e.py @@ -106,6 +106,25 @@ def test_detokenize_false(llm): def test_bad_words(llm): """Check that we respect bad words.""" + tokenizer = llm.get_tokenizer() + + def contains_bad_word(text: str, tokens: list[int], bad_word: str) -> bool: + """Check if word appears in BOTH text and token sequence.""" + if bad_word not in text: + return False + + for add_prefix_space in [False, True]: + prefix = " " if add_prefix_space else "" + bad_words_token = tokenizer.encode( + prefix + bad_word.lstrip(), add_special_tokens=False + ) + if not bad_words_token: + continue + for i in range(len(tokens) - len(bad_words_token) + 1): + if tokens[i : i + len(bad_words_token)] == bad_words_token: + return True + return False + output = llm.generate(PROMPT, SamplingParams(temperature=0)) split_text = output[0].outputs[0].text.split() @@ -113,14 +132,16 @@ def test_bad_words(llm): params = SamplingParams(temperature=0, bad_words=[bad_words_1]) output = llm.generate(PROMPT, params) new_text = output[0].outputs[0].text - assert bad_words_1 not in new_text + new_tokens = output[0].outputs[0].token_ids + assert not contains_bad_word(new_text, new_tokens, bad_words_1) bad_words_2 = new_text.split()[-1] params = SamplingParams(temperature=0, bad_words=[bad_words_1, bad_words_2]) output = llm.generate(PROMPT, params) new_text = output[0].outputs[0].text - assert bad_words_1 not in new_text - assert bad_words_2 not in new_text + new_tokens = output[0].outputs[0].token_ids + assert not contains_bad_word(new_text, new_tokens, bad_words_1) + assert not contains_bad_word(new_text, new_tokens, bad_words_2) def test_logits_processor(llm): diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 9436ab471c21b..55e9b4d0660f5 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -66,7 +66,10 @@ def _create_proposer( device_config=DeviceConfig(device=current_platform.device_type), parallel_config=ParallelConfig(), load_config=LoadConfig(), - scheduler_config=SchedulerConfig(), + scheduler_config=SchedulerConfig( + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, + ), ) return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type) @@ -336,7 +339,7 @@ def test_load_model( "multi-token eagle spec decode on current platform" ) - if attn_backend == "FLASH_ATTN" and current_platform.is_rocm(): + if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm(): monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") # Setup draft model mock @@ -431,7 +434,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): "because it requires special input mocking." ) - if attn_backend == "FLASH_ATTN" and current_platform.is_rocm(): + if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm(): monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") # Use GPU device @@ -538,6 +541,10 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): attn_metadata_builder_cls, _ = try_get_attention_backend( AttentionBackendEnum.TREE_ATTN ) + elif attn_backend == "ROCM_AITER_FA": + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.ROCM_AITER_FA + ) else: raise ValueError(f"Unsupported attention backend: {attn_backend}") diff --git a/tests/v1/spec_decode/test_max_len.py b/tests/v1/spec_decode/test_max_len.py index fa1d0437f7c71..15a6bd2659ea9 100644 --- a/tests/v1/spec_decode/test_max_len.py +++ b/tests/v1/spec_decode/test_max_len.py @@ -47,7 +47,7 @@ def test_eagle_max_len( "multi-token eagle spec decode on current platform" ) - if attn_backend == "FLASH_ATTN" and current_platform.is_rocm(): + if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm(): m.setenv("VLLM_ROCM_USE_AITER", "1") llm = LLM( @@ -82,7 +82,7 @@ def test_eagle_max_len( len(o.prompt_token_ids) < 80 < len(o.prompt_token_ids) + len(o.outputs[0].token_ids) - < 200 + <= 200 ), ( "This test is only meaningful if the output " "is longer than the eagle max length" diff --git a/tests/v1/spec_decode/test_mtp.py b/tests/v1/spec_decode/test_mtp.py index c5c0491abaf7c..3b8813ceb818a 100644 --- a/tests/v1/spec_decode/test_mtp.py +++ b/tests/v1/spec_decode/test_mtp.py @@ -51,7 +51,10 @@ def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer: device_config=DeviceConfig(device=current_platform.device_type), parallel_config=ParallelConfig(), load_config=LoadConfig(), - scheduler_config=SchedulerConfig(), + scheduler_config=SchedulerConfig( + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, + ), ) return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type) diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index 692c39282c372..6bc412abe8695 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -2,7 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import numpy as np -from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig +from vllm.config import ( + ModelConfig, + SpeculativeConfig, + VllmConfig, +) from vllm.v1.spec_decode.ngram_proposer import ( NgramProposer, _find_longest_matched_ngram_and_propose_tokens, @@ -167,6 +171,34 @@ def test_ngram_proposer(): assert np.array_equal(result[0], np.array([3, 1])) assert np.array_equal(result[1], np.array([])) + # Test non-contiguous indices: requests 0 and 2 need proposals, + # request 1 is in prefill + proposer = get_ngram_proposer(min_n=2, max_n=2, k=2) + max_model_len = 20 + token_ids_cpu = np.zeros((3, max_model_len), dtype=np.int32) + token_ids_cpu[0, :5] = [1, 2, 3, 1, 2] + token_ids_cpu[1, :3] = [4, 5, 6] + token_ids_cpu[2, :5] = [7, 8, 9, 7, 8] + num_tokens_no_spec = np.array([5, 3, 5], dtype=np.int32) + sampled_token_ids = [[2], [], [8]] # Empty list for request 1 simulates prefill + result = proposer.propose( + sampled_token_ids=sampled_token_ids, + req_ids=["0", "1", "2"], + num_tokens_no_spec=num_tokens_no_spec, + token_ids_cpu=token_ids_cpu, + spec_decode_unsupported_reqs=(), + ) + assert len(result) == 3 + assert np.array_equal(result[0], [3, 1]) + assert len(result[1]) == 0 + assert np.array_equal(result[2], [9, 7]) + # Verify internal arrays written to correct indices + assert proposer.valid_ngram_num_drafts[0] == 2 + assert proposer.valid_ngram_num_drafts[1] == 0 + assert proposer.valid_ngram_num_drafts[2] == 2 + assert np.array_equal(proposer.valid_ngram_draft[0, :2], [3, 1]) + assert np.array_equal(proposer.valid_ngram_draft[2, :2], [9, 7]) + # test if 0 threads available: can happen if TP size > CPU count ngram_proposer = get_ngram_proposer(min_n=2, max_n=2, k=2) ngram_proposer.num_numba_thread_available = 0 diff --git a/tests/v1/spec_decode/test_speculators_eagle3.py b/tests/v1/spec_decode/test_speculators_eagle3.py index 5ce6e1593b5c1..9a252cfffc8f0 100644 --- a/tests/v1/spec_decode/test_speculators_eagle3.py +++ b/tests/v1/spec_decode/test_speculators_eagle3.py @@ -5,6 +5,7 @@ import torch from vllm.config import SpeculativeConfig from vllm.model_executor.models.interfaces import supports_eagle3 +from vllm.platforms import current_platform @pytest.mark.parametrize( @@ -21,6 +22,10 @@ from vllm.model_executor.models.interfaces import supports_eagle3 pytest.param( "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16", id="qwen3-eagle3-speculator-w4a16-verifier", + marks=pytest.mark.skipif( + current_platform.is_rocm(), + reason="The tests are skipped on rocm platform.", + ), ), ], ) diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py index a4ee53008ce82..0afeeb8914b87 100644 --- a/tests/v1/spec_decode/test_tree_attention.py +++ b/tests/v1/spec_decode/test_tree_attention.py @@ -88,8 +88,8 @@ def forward_attention( query_start_loc=query_start_loc, query_start_loc_cpu=query_start_loc.cpu(), seq_lens=seq_lens, - seq_lens_cpu=seq_lens.cpu(), - num_computed_tokens_cpu=context_lens.cpu(), + _seq_lens_cpu=seq_lens.cpu(), + _num_computed_tokens_cpu=context_lens.cpu(), num_reqs=batch_size, num_actual_tokens=num_actual_tokens, max_query_len=max_query_len, diff --git a/tests/v1/structured_output/test_backend_guidance.py b/tests/v1/structured_output/test_backend_guidance.py index 771076186a3b4..4c01560fc88c3 100644 --- a/tests/v1/structured_output/test_backend_guidance.py +++ b/tests/v1/structured_output/test_backend_guidance.py @@ -1,9 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import time +from concurrent.futures import Future + +import pytest from transformers import AutoTokenizer from vllm.config import StructuredOutputsConfig, VllmConfig from vllm.config.model import ModelConfig +from vllm.config.parallel import ParallelConfig from vllm.config.speculative import SpeculativeConfig from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.v1.request import Request @@ -116,3 +121,72 @@ def test_grammar_bitmask_with_specdec(): ) # EOS not the final token grammar_bitmask(request, prompt[i:]) # EOS not present grammar_bitmask(request, prompt[i:] + [tokenizer.eos_token_id]) + + +@pytest.mark.parametrize("async_grammar", [True, False]) +def test_grammar_init_async_and_sync(async_grammar): + """Test grammar initialization works correctly in both async and sync modes. + + This test validates that the distributed_executor_backend config option + correctly controls whether grammar compilation happens asynchronously + (via executor.submit) or synchronously. When set to "external_launcher", + grammar compilation is synchronous to avoid deadlocks. + """ + tokenizer = AutoTokenizer.from_pretrained(TOKENIZER) + prompt = tokenizer.encode('{"a": "b"}') + + # Use "external_launcher" for sync mode, None for async mode + executor_backend = None if async_grammar else "external_launcher" + vllm_config = VllmConfig( + model_config=ModelConfig(tokenizer=TOKENIZER), + structured_outputs_config=StructuredOutputsConfig(backend="guidance"), + parallel_config=ParallelConfig(distributed_executor_backend=executor_backend), + ) + structured_output_manager = StructuredOutputManager(vllm_config) + + sampling_params = SamplingParams( + structured_outputs=StructuredOutputsParams( + json='{"type": "object"}', + ), + ) + sampling_params.structured_outputs._backend = "guidance" + + request = Request( + "test_request", + prompt_token_ids=prompt, + sampling_params=sampling_params, + pooling_params=None, + eos_token_id=tokenizer.eos_token_id, + ) + + structured_output_manager.grammar_init(request) + + # Check the internal _grammar type immediately after init + # Before _check_grammar_completion is called, async mode should have a Future + raw_grammar = request.structured_output_request._grammar + if async_grammar: + assert isinstance(raw_grammar, Future), ( + "Async mode should store a Future before completion" + ) + else: + assert not isinstance(raw_grammar, Future), ( + "Sync mode should store the grammar directly, not a Future" + ) + + # Wait for grammar to be ready (handles both async and sync cases) + start_time = time.time() + while not request.structured_output_request._check_grammar_completion(): + if time.time() - start_time > 5: # 5-second timeout + pytest.fail("Grammar compilation timed out") + time.sleep(0.01) + + # After completion, _grammar should no longer be a Future + assert not isinstance(request.structured_output_request._grammar, Future) + + # Verify grammar is properly initialized and functional + grammar = request.structured_output_request.grammar + assert grammar is not None + assert not grammar.is_terminated() + + # Verify the grammar can accept valid tokens + assert grammar.accept_tokens(request.request_id, prompt) diff --git a/tests/v1/structured_output/test_reasoning_structured_output.py b/tests/v1/structured_output/test_reasoning_structured_output.py index 70047a993c3f9..ba52af3ad604d 100644 --- a/tests/v1/structured_output/test_reasoning_structured_output.py +++ b/tests/v1/structured_output/test_reasoning_structured_output.py @@ -70,6 +70,7 @@ class TestReasoningStructuredOutput: request.use_structured_output = True request.prompt_token_ids = [1, 2, 3, 4, 5] request.all_token_ids = [1, 2, 3, 4, 5, 6, 7, 8] + request.num_computed_tokens = 5 return request def test_should_fill_bitmask_with_enable_in_reasoning( diff --git a/tests/v1/structured_output/test_utils.py b/tests/v1/structured_output/test_utils.py index 513a21dd6bb39..c026ab0e4e785 100644 --- a/tests/v1/structured_output/test_utils.py +++ b/tests/v1/structured_output/test_utils.py @@ -44,8 +44,6 @@ def unsupported_array_schemas(): @pytest.fixture def unsupported_object_schemas(): return [ - {"type": "object", "minProperties": 1}, - {"type": "object", "maxProperties": 5}, {"type": "object", "propertyNames": {"pattern": "^[a-z]+$"}}, {"type": "object", "patternProperties": {"^S": {"type": "string"}}}, ] @@ -79,6 +77,8 @@ def supported_schema(): }, }, }, + "minProperties": 1, + "maxProperties": 100, } diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index 00749c5415c8e..dbbbfce97d286 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -104,22 +104,31 @@ class MyRequest(msgspec.Struct): def test_multimodal_kwargs(): e1 = MultiModalFieldElem( - "audio", "a0", torch.zeros(1000, dtype=torch.bfloat16), MultiModalBatchedField() + "audio", + "a0", + torch.zeros(1000, dtype=torch.bfloat16), + MultiModalBatchedField(), ) e2 = MultiModalFieldElem( "video", "v0", [torch.zeros(1000, dtype=torch.int8) for _ in range(4)], - MultiModalFlatField([[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]], 0), + MultiModalFlatField( + slices=[[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]], + dim=0, + ), ) e3 = MultiModalFieldElem( - "image", "i0", torch.zeros(1000, dtype=torch.int32), MultiModalSharedField(4) + "image", + "i0", + torch.zeros(1000, dtype=torch.int32), + MultiModalSharedField(batch_size=4), ) e4 = MultiModalFieldElem( "image", "i1", torch.zeros(1000, dtype=torch.int32), - MultiModalFlatField([slice(1, 2, 3), slice(4, 5, 6)], 2), + MultiModalFlatField(slices=[slice(1, 2, 3), slice(4, 5, 6)], dim=2), ) audio = MultiModalKwargsItem.from_elems([e1]) video = MultiModalKwargsItem.from_elems([e2]) @@ -138,8 +147,8 @@ def test_multimodal_kwargs(): total_len = sum(memoryview(x).cast("B").nbytes for x in encoded) - # expected total encoding length, should be 14306, +-20 for minor changes - assert 14275 <= total_len <= 14325 + # expected total encoding length, should be 14395, +-20 for minor changes + assert 14375 <= total_len <= 14425 decoded = decoder.decode(encoded).mm[0] assert isinstance(decoded, MultiModalKwargsItems) diff --git a/tests/v1/tpu/test_perf.py b/tests/v1/tpu/test_perf.py index e230491cddb01..e62b969fe3b95 100644 --- a/tests/v1/tpu/test_perf.py +++ b/tests/v1/tpu/test_perf.py @@ -14,7 +14,7 @@ import pytest from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.tokenizers import get_tokenizer if TYPE_CHECKING: from tests.conftest import VllmRunner diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py index 7b3a07b4e12a5..cfc06666e7984 100644 --- a/tests/v1/tpu/worker/test_tpu_model_runner.py +++ b/tests/v1/tpu/worker/test_tpu_model_runner.py @@ -26,16 +26,17 @@ from vllm.v1.worker.tpu_model_runner import ( def get_vllm_config(): - scheduler_config = SchedulerConfig( - max_num_seqs=10, - max_num_batched_tokens=512, - max_model_len=512, - ) model_config = ModelConfig( model="facebook/opt-125m", dtype="bfloat16", # TPUs typically use bfloat16 seed=42, ) + scheduler_config = SchedulerConfig( + max_num_seqs=10, + max_num_batched_tokens=512, + max_model_len=512, + is_encoder_decoder=model_config.is_encoder_decoder, + ) cache_config = CacheConfig( block_size=16, gpu_memory_utilization=0.9, diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 89669ee8b71a0..7b8c4268a5237 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -6,8 +6,10 @@ import pytest import torch from vllm.attention.backends.abstract import MultipleOf +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layer import Attention from vllm.config import ( + AttentionConfig, CacheConfig, ModelConfig, ParallelConfig, @@ -79,16 +81,17 @@ def initialize_kv_cache(runner: GPUModelRunner): def get_vllm_config(): - scheduler_config = SchedulerConfig( - max_num_seqs=10, - max_num_batched_tokens=512, - max_model_len=512, - ) model_config = ModelConfig( model="facebook/opt-125m", dtype="float16", seed=42, ) + scheduler_config = SchedulerConfig( + max_num_seqs=10, + max_num_batched_tokens=512, + max_model_len=512, + is_encoder_decoder=model_config.is_encoder_decoder, + ) cache_config = CacheConfig( block_size=BLOCK_SIZE, gpu_memory_utilization=0.9, @@ -760,7 +763,11 @@ def test_init_kv_cache_with_kv_sharing_valid(): assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[1] == layer_1 -def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): +@pytest.mark.skipif( + current_platform.is_rocm(), + reason="Attention backend FLASHINFER is not supported on ROCm.", +) +def test_hybrid_attention_mamba_tensor_shapes(): """ The GPU model runner creates different views into the KVCacheTensors for the attention and mamba layers @@ -784,14 +791,15 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): initialize_model_parallel(tensor_model_parallel_size=1) torch.set_default_dtype(torch.float16) + model_config = ModelConfig( + model="ibm-granite/granite-4.0-tiny-preview", + dtype="float16", + ) scheduler_config = SchedulerConfig( max_num_seqs=10, max_num_batched_tokens=512, max_model_len=512, - ) - model_config = ModelConfig( - model="ibm-granite/granite-4.0-tiny-preview", - dtype="float16", + is_encoder_decoder=model_config.is_encoder_decoder, ) cache_config = CacheConfig( block_size=BLOCK_SIZE, @@ -800,11 +808,13 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): cache_dtype="auto", ) parallel_config = ParallelConfig() + attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASHINFER) vllm_config = VllmConfig( model_config=model_config, cache_config=cache_config, scheduler_config=scheduler_config, parallel_config=parallel_config, + attention_config=attention_config, ) layer_0 = "model.layers.0.self_attn.attn" @@ -814,8 +824,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): layer_4 = "model.layers.4.mixer" layer_5 = "model.layers.5.mixer" - with set_current_vllm_config(vllm_config), monkeypatch.context() as m: - m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") + with set_current_vllm_config(vllm_config): hf_config = vllm_config.model_config.hf_config fwd_context = {} for key in [layer_0, layer_1]: @@ -845,10 +854,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): ) # suppress var not used error assert fwd_context is not None - vllm_ctx = vllm_config.compilation_config.static_forward_context - - with monkeypatch.context() as m: - m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") + vllm_ctx = vllm_config.compilation_config.static_forward_context runner = GPUModelRunner(vllm_config, DEVICE) kv_cache_spec = runner.get_kv_cache_spec() @@ -859,94 +865,94 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): )[0] runner.initialize_kv_cache(kv_cache_config) - # random partition of blocks - # blocks0 will be assigned to attention layers - # blocks1 will be assigned to mamba layers - num_blocks = kv_cache_config.num_blocks - ind = np.arange(num_blocks) - np.random.shuffle(ind) - blocks0, blocks1 = ind[: (num_blocks // 2)], ind[(num_blocks // 2) :] + # random partition of blocks + # blocks0 will be assigned to attention layers + # blocks1 will be assigned to mamba layers + num_blocks = kv_cache_config.num_blocks + ind = np.arange(num_blocks) + np.random.shuffle(ind) + blocks0, blocks1 = ind[: (num_blocks // 2)], ind[(num_blocks // 2) :] - attn_shape = vllm_ctx[layer_0].kv_cache[0].shape - conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape - ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape + attn_shape = vllm_ctx[layer_0].kv_cache[0].shape + conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape + ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape - # assert we are using FlashInfer - assert attn_shape[0] % num_blocks == 0 - block_split_ratio = attn_shape[0] // num_blocks + # assert we are using FlashInfer + assert attn_shape[0] % num_blocks == 0 + block_split_ratio = attn_shape[0] // num_blocks - # use small blocks for testing to avoid memory issues - test_block_size = min(2, len(blocks0), len(blocks1)) + # use small blocks for testing to avoid memory issues + test_block_size = min(2, len(blocks0), len(blocks1)) - # use non-overlapping blocks to avoid data contamination - # Split kernel blocks: first half for attention, second half for mamba - mid_point = num_blocks // 2 + # use non-overlapping blocks to avoid data contamination + # Split kernel blocks: first half for attention, second half for mamba + mid_point = num_blocks // 2 - # attention uses kernel blocks from first half (mapped to logical blocks) - kv_blocks_for_attention = np.array([0, 1])[:test_block_size] + # attention uses kernel blocks from first half (mapped to logical blocks) + kv_blocks_for_attention = np.array([0, 1])[:test_block_size] - # mamba uses kernel blocks from second half - kv_blocks_for_mamba = np.array([mid_point, mid_point + 1])[:test_block_size] + # mamba uses kernel blocks from second half + kv_blocks_for_mamba = np.array([mid_point, mid_point + 1])[:test_block_size] - # create small constant tensors for testing with corrected shapes - # attention: [block_size, ...] starting from dimension 2 - attn_constant_shape = attn_shape[2:] - conv_constant_shape = conv_shape[1:] - ssm_constant_shape = ssm_shape[1:] + # create small constant tensors for testing with corrected shapes + # attention: [block_size, ...] starting from dimension 2 + attn_constant_shape = attn_shape[2:] + conv_constant_shape = conv_shape[1:] + ssm_constant_shape = ssm_shape[1:] - attn_blocks_constant = torch.full( - (test_block_size, *attn_constant_shape), device=DEVICE, fill_value=3.33 - ) - conv_blocks_constant = torch.full( - (test_block_size, *conv_constant_shape), device=DEVICE, fill_value=6.66 - ) - ssm_blocks_constant = torch.full( - (test_block_size, *ssm_constant_shape), device=DEVICE, fill_value=9.99 - ) + attn_blocks_constant = torch.full( + (test_block_size, *attn_constant_shape), device=DEVICE, fill_value=3.33 + ) + conv_blocks_constant = torch.full( + (test_block_size, *conv_constant_shape), device=DEVICE, fill_value=6.66 + ) + ssm_blocks_constant = torch.full( + (test_block_size, *ssm_constant_shape), device=DEVICE, fill_value=9.99 + ) - # Fill attention blocks with constants using kv block indices - kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio + # Fill attention blocks with constants using kv block indices + kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio - for layer in [layer_0, layer_1]: - # attention: kv_cache[0][kernel_block_idx, kv_idx, ...] - for i, kernel_block in enumerate(kernel_blocks_for_attention): - vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i] + for layer in [layer_0, layer_1]: + # attention: kv_cache[0][kernel_block_idx, kv_idx, ...] + for i, kernel_block in enumerate(kernel_blocks_for_attention): + vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i] - # fill mamba blocks with constants using kernel block indices - for layer in [layer_2, layer_3, layer_4, layer_5]: - # mamba: kv_cache[0][component][kernel_block_idx, ...] - for i, kv_block in enumerate(kv_blocks_for_mamba): - vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i] - vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i] + # fill mamba blocks with constants using kernel block indices + for layer in [layer_2, layer_3, layer_4, layer_5]: + # mamba: kv_cache[0][component][kernel_block_idx, ...] + for i, kv_block in enumerate(kv_blocks_for_mamba): + vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i] + vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i] - # verify attention and mamba contents are correct - for layer in [layer_0, layer_1]: - for i, kernel_block in enumerate(kernel_blocks_for_attention): - actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :] - expected = attn_blocks_constant[i] + # verify attention and mamba contents are correct + for layer in [layer_0, layer_1]: + for i, kernel_block in enumerate(kernel_blocks_for_attention): + actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :] + expected = attn_blocks_constant[i] - # Check K and V separately - assert torch.equal(actual_kv[0], expected) - assert torch.equal(actual_kv[1], expected) + # Check K and V separately + assert torch.equal(actual_kv[0], expected) + assert torch.equal(actual_kv[1], expected) - for layer in [layer_2, layer_3, layer_4, layer_5]: - for i, kv_block in enumerate(kv_blocks_for_mamba): - actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :] - actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :] - expected_conv = conv_blocks_constant[i] - expected_ssm = ssm_blocks_constant[i] + for layer in [layer_2, layer_3, layer_4, layer_5]: + for i, kv_block in enumerate(kv_blocks_for_mamba): + actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :] + actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :] + expected_conv = conv_blocks_constant[i] + expected_ssm = ssm_blocks_constant[i] - assert torch.equal(actual_conv, expected_conv) - assert torch.equal(actual_ssm, expected_ssm) + assert torch.equal(actual_conv, expected_conv) + assert torch.equal(actual_ssm, expected_ssm) - for layer in [layer_2, layer_3, layer_4, layer_5]: - for i, kv_block in enumerate(kv_blocks_for_mamba): - actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :] - actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :] - expected_conv = conv_blocks_constant[i] - expected_ssm = ssm_blocks_constant[i] - assert torch.equal(actual_conv, expected_conv) - assert torch.equal(actual_ssm, expected_ssm) + for layer in [layer_2, layer_3, layer_4, layer_5]: + for i, kv_block in enumerate(kv_blocks_for_mamba): + actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :] + actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :] + expected_conv = conv_blocks_constant[i] + expected_ssm = ssm_blocks_constant[i] + assert torch.equal(actual_conv, expected_conv) + assert torch.equal(actual_ssm, expected_ssm) def test_hybrid_block_table_initialization(): diff --git a/tests/v1/worker/test_gpu_profiler.py b/tests/v1/worker/test_gpu_profiler.py index f7255fae05a4e..933ea42f18cd5 100644 --- a/tests/v1/worker/test_gpu_profiler.py +++ b/tests/v1/worker/test_gpu_profiler.py @@ -2,8 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -import vllm.envs as envs -from vllm.profiler.gpu_profiler import WorkerProfiler +from vllm.config import ProfilerConfig +from vllm.profiler.wrapper import WorkerProfiler class ConcreteWorkerProfiler(WorkerProfiler): @@ -11,11 +11,11 @@ class ConcreteWorkerProfiler(WorkerProfiler): A basic implementation of a worker profiler for testing purposes. """ - def __init__(self): + def __init__(self, profiler_config: ProfilerConfig): self.start_call_count = 0 self.stop_call_count = 0 self.should_fail_start = False - super().__init__() + super().__init__(profiler_config) def _start(self) -> None: if self.should_fail_start: @@ -26,17 +26,19 @@ class ConcreteWorkerProfiler(WorkerProfiler): self.stop_call_count += 1 -@pytest.fixture(autouse=True) -def reset_mocks(): - """Fixture to reset mocks and env variables before each test.""" - envs.VLLM_PROFILER_DELAY_ITERS = 0 - envs.VLLM_PROFILER_MAX_ITERS = 0 +@pytest.fixture +def default_profiler_config(): + return ProfilerConfig( + profiler="torch", + torch_profiler_dir="/tmp/mock", + delay_iterations=0, + max_iterations=0, + ) -def test_immediate_start_stop(): +def test_immediate_start_stop(default_profiler_config): """Test standard start without delay.""" - profiler = ConcreteWorkerProfiler() - + profiler = ConcreteWorkerProfiler(default_profiler_config) profiler.start() assert profiler._running is True assert profiler._active is True @@ -48,10 +50,10 @@ def test_immediate_start_stop(): assert profiler.stop_call_count == 1 -def test_delayed_start(): +def test_delayed_start(default_profiler_config): """Test that profiler waits for N steps before actually starting.""" - envs.VLLM_PROFILER_DELAY_ITERS = 2 - profiler = ConcreteWorkerProfiler() + default_profiler_config.delay_iterations = 2 + profiler = ConcreteWorkerProfiler(default_profiler_config) # User requests start profiler.start() @@ -71,10 +73,10 @@ def test_delayed_start(): assert profiler.start_call_count == 1 -def test_max_iterations(): +def test_max_iterations(default_profiler_config): """Test that profiler stops automatically after max iterations.""" - envs.VLLM_PROFILER_MAX_ITERS = 2 - profiler = ConcreteWorkerProfiler() + default_profiler_config.max_iterations = 2 + profiler = ConcreteWorkerProfiler(default_profiler_config) profiler.start() assert profiler._running is True @@ -95,12 +97,11 @@ def test_max_iterations(): assert profiler.stop_call_count == 1 -def test_delayed_start_and_max_iters(): +def test_delayed_start_and_max_iters(default_profiler_config): """Test combined delayed start and max iterations.""" - envs.VLLM_PROFILER_DELAY_ITERS = 2 - envs.VLLM_PROFILER_MAX_ITERS = 2 - profiler = ConcreteWorkerProfiler() - + default_profiler_config.delay_iterations = 2 + default_profiler_config.max_iterations = 2 + profiler = ConcreteWorkerProfiler(default_profiler_config) profiler.start() # Step 1 @@ -127,9 +128,9 @@ def test_delayed_start_and_max_iters(): assert profiler.stop_call_count == 1 -def test_idempotency(): +def test_idempotency(default_profiler_config): """Test that calling start/stop multiple times doesn't break logic.""" - profiler = ConcreteWorkerProfiler() + profiler = ConcreteWorkerProfiler(default_profiler_config) # Double Start profiler.start() @@ -142,10 +143,10 @@ def test_idempotency(): assert profiler.stop_call_count == 1 # Should only stop once -def test_step_inactive(): +def test_step_inactive(default_profiler_config): """Test that stepping while inactive does nothing.""" - envs.VLLM_PROFILER_DELAY_ITERS = 2 - profiler = ConcreteWorkerProfiler() + default_profiler_config.delay_iterations = 2 + profiler = ConcreteWorkerProfiler(default_profiler_config) # Not started yet profiler.step() @@ -155,9 +156,9 @@ def test_step_inactive(): assert profiler.start_call_count == 0 -def test_start_failure(): +def test_start_failure(default_profiler_config): """Test behavior when the underlying _start method raises exception.""" - profiler = ConcreteWorkerProfiler() + profiler = ConcreteWorkerProfiler(default_profiler_config) profiler.should_fail_start = True profiler.start() @@ -168,9 +169,9 @@ def test_start_failure(): assert profiler.start_call_count == 0 # Logic failed inside start -def test_shutdown(): +def test_shutdown(default_profiler_config): """Test that shutdown calls stop only if running.""" - profiler = ConcreteWorkerProfiler() + profiler = ConcreteWorkerProfiler(default_profiler_config) # Case 1: Not running profiler.shutdown() @@ -182,10 +183,10 @@ def test_shutdown(): assert profiler.stop_call_count == 1 -def test_mixed_delay_and_stop(): +def test_mixed_delay_and_stop(default_profiler_config): """Test manual stop during the delay period.""" - envs.VLLM_PROFILER_DELAY_ITERS = 5 - profiler = ConcreteWorkerProfiler() + default_profiler_config.delay_iterations = 5 + profiler = ConcreteWorkerProfiler(default_profiler_config) profiler.start() profiler.step() diff --git a/tools/ep_kernels/README.md b/tools/ep_kernels/README.md index 85e9d2a4f8129..ab0e358802bf8 100644 --- a/tools/ep_kernels/README.md +++ b/tools/ep_kernels/README.md @@ -7,7 +7,7 @@ Here we break down the requirements in 2 steps: 1. Build and install the Python libraries (both [pplx-kernels](https://github.com/ppl-ai/pplx-kernels) and [DeepEP](https://github.com/deepseek-ai/DeepEP)), including necessary dependencies like NVSHMEM. This step does not require any privileged access. Any user can do this. 2. Configure NVIDIA driver to enable IBGDA. This step requires root access, and must be done on the host machine. -2 is necessary for multi-node deployment. +Step 2 is necessary for multi-node deployment. All scripts accept a positional argument as workspace path for staging the build, defaulting to `$(pwd)/ep_kernels_workspace`. @@ -23,6 +23,6 @@ TORCH_CUDA_ARCH_LIST="10.0" bash install_python_libraries.sh Additional step for multi-node deployment: ```bash -sudo bash configure_system_drivers.sh +sudo bash configure_system_drivers.sh # update-initramfs can take several minutes sudo reboot # Reboot is required to load the new driver ``` diff --git a/tools/ep_kernels/install_python_libraries.sh b/tools/ep_kernels/install_python_libraries.sh index 88be5cd778fff..1bb7fd8345238 100755 --- a/tools/ep_kernels/install_python_libraries.sh +++ b/tools/ep_kernels/install_python_libraries.sh @@ -10,9 +10,10 @@ set -ex CUDA_HOME=${CUDA_HOME:-/usr/local/cuda} PPLX_COMMIT_HASH=${PPLX_COMMIT_HASH:-"12cecfd"} DEEPEP_COMMIT_HASH=${DEEPEP_COMMIT_HASH:-"73b6ea4"} -NVSHMEM_VER=3.3.9 +NVSHMEM_VER=3.3.24 # Suppports both CUDA 12 and 13 WORKSPACE=${WORKSPACE:-$(pwd)/ep_kernels_workspace} MODE=${MODE:-install} +CUDA_VERSION_MAJOR=$(${CUDA_HOME}/bin/nvcc --version | egrep -o "release [0-9]+" | cut -d ' ' -f 2) # Parse arguments while [[ $# -gt 0 ]]; do @@ -75,11 +76,9 @@ ARCH=$(uname -m) case "${ARCH,,}" in x86_64|amd64) NVSHMEM_SUBDIR="linux-x86_64" - NVSHMEM_FILE="libnvshmem-linux-x86_64-${NVSHMEM_VER}_cuda12-archive.tar.xz" ;; aarch64|arm64) NVSHMEM_SUBDIR="linux-sbsa" - NVSHMEM_FILE="libnvshmem-linux-sbsa-${NVSHMEM_VER}_cuda12-archive.tar.xz" ;; *) echo "Unsupported architecture: ${ARCH}" >&2 @@ -87,6 +86,7 @@ case "${ARCH,,}" in ;; esac +NVSHMEM_FILE="libnvshmem-${NVSHMEM_SUBDIR}-${NVSHMEM_VER}_cuda${CUDA_VERSION_MAJOR}-archive.tar.xz" NVSHMEM_URL="https://developer.download.nvidia.com/compute/nvshmem/redist/libnvshmem/${NVSHMEM_SUBDIR}/${NVSHMEM_FILE}" pushd "$WORKSPACE" @@ -142,13 +142,6 @@ clone_repo() { fi } -deepep_cuda13_patch() { - cuda_version_major=$(${CUDA_HOME}/bin/nvcc --version | egrep -o "release [0-9]+" | cut -d ' ' -f 2) - if [ ${cuda_version_major} -ge 13 ]; then - sed -i "s|f'{nvshmem_dir}/include']|f'{nvshmem_dir}/include', '${CUDA_HOME}/include/cccl']|" "setup.py" - fi -} - do_build() { local repo=$1 local name=$2 @@ -160,8 +153,9 @@ do_build() { clone_repo "$repo" "$name" "$key" "$commit" cd "$name" - if [ "$name" == "DeepEP" ]; then - deepep_cuda13_patch + # DeepEP CUDA 13 patch + if [[ "$name" == "DeepEP" && "${CUDA_VERSION_MAJOR}" -ge 13 ]]; then + sed -i "s|f'{nvshmem_dir}/include']|f'{nvshmem_dir}/include', '${CUDA_HOME}/include/cccl']|" "setup.py" fi if [ "$MODE" = "install" ]; then diff --git a/tools/pre_commit/mypy.py b/tools/pre_commit/mypy.py index 724b393044266..3f7e0a069f869 100755 --- a/tools/pre_commit/mypy.py +++ b/tools/pre_commit/mypy.py @@ -43,6 +43,7 @@ FILES = [ "vllm/worker", "vllm/v1/core", "vllm/v1/engine", + "vllm/v1/executor", "vllm/v1/metrics", "vllm/v1/pool", "vllm/v1/sample", @@ -60,7 +61,6 @@ SEPARATE_GROUPS = [ "vllm/model_executor", # v1 related "vllm/v1/attention", - "vllm/v1/executor", "vllm/v1/kv_offload", "vllm/v1/spec_decode", "vllm/v1/structured_output", diff --git a/use_existing_torch.py b/use_existing_torch.py index fd4caa69ec9c1..e2d3f2ec81956 100644 --- a/use_existing_torch.py +++ b/use_existing_torch.py @@ -3,9 +3,7 @@ import glob -requires_files = glob.glob("requirements/*.txt") -requires_files += ["pyproject.toml"] -for file in requires_files: +for file in (*glob.glob("requirements/*.txt"), "pyproject.toml"): print(f">>> cleaning {file}") with open(file) as f: lines = f.readlines() @@ -17,5 +15,4 @@ for file in requires_files: f.write(line) else: print(line.strip()) - print(f"<<< done cleaning {file}") - print() + print(f"<<< done cleaning {file}\n") diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index a8f472d147a0d..c32bf04c71c1f 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -9,6 +9,8 @@ import vllm.envs as envs from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer +_FP8_DTYPE = current_platform.fp8_dtype() + def is_aiter_found() -> bool: from importlib.util import find_spec @@ -22,6 +24,15 @@ def is_aiter_found() -> bool: # we keep this global outside to not cause torch compile breaks. IS_AITER_FOUND = is_aiter_found() +# Can't use dtypes.fp8 directly inside an op +# because it returns wrong result on gfx942. +# This is a workaround to get the correct FP8 dtype. +# This might because that the get_gfx() is wrapped as a custom op. +if IS_AITER_FOUND: + from aiter import dtypes + + AITER_FP8_DTYPE = dtypes.fp8 + def if_aiter_supported(func: Callable) -> Callable: """Decorator that only executes the function if @@ -43,36 +54,6 @@ def if_aiter_supported(func: Callable) -> Callable: return wrapper -def _rocm_aiter_group_fp8_quant_impl( - x: torch.Tensor, - group_size: int, -) -> tuple[torch.Tensor, torch.Tensor]: - assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size" - from aiter import QuantType, dtypes, get_hip_quant - - aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128) - return aiter_per1x128_quant(x.contiguous(), quant_dtype=dtypes.fp8) - - -def _rocm_aiter_group_fp8_quant_fake( - x: torch.Tensor, - group_size: int, -) -> tuple[torch.Tensor, torch.Tensor]: - from aiter import dtypes - - M, N = x.shape - x_fp8 = torch.empty((M, N), dtype=dtypes.fp8, device=x.device) - out_bs = torch.empty( - ( - M, - (N + group_size - 1) // group_size, - ), - dtype=torch.float32, - device=x.device, - ) - return x_fp8, out_bs - - def _rocm_aiter_fused_moe_impl( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -283,6 +264,28 @@ def _rocm_aiter_grouped_topk_fake( pass +# Cache whether aiter supports FP8 MLA parameters +_AITER_MLA_SUPPORTS_FP8: bool | None = None + + +def _check_aiter_mla_fp8_support() -> bool: + """Check if aiter.mla.mla_decode_fwd supports q_scale and kv_scale parameters.""" + global _AITER_MLA_SUPPORTS_FP8 + if _AITER_MLA_SUPPORTS_FP8 is None: + try: + import inspect + + from aiter.mla import mla_decode_fwd + + sig = inspect.signature(mla_decode_fwd) + _AITER_MLA_SUPPORTS_FP8 = ( + "q_scale" in sig.parameters and "kv_scale" in sig.parameters + ) + except Exception: + _AITER_MLA_SUPPORTS_FP8 = False + return _AITER_MLA_SUPPORTS_FP8 + + def _rocm_aiter_mla_decode_fwd_impl( q: torch.Tensor, kv_buffer: torch.Tensor, @@ -299,6 +302,16 @@ def _rocm_aiter_mla_decode_fwd_impl( ) -> None: from aiter.mla import mla_decode_fwd + kwargs = { + "sm_scale": sm_scale, + "logit_cap": logit_cap, + } + + # Only pass q_scale and kv_scale if the aiter library supports them + if _check_aiter_mla_fp8_support(): + kwargs["q_scale"] = q_scale + kwargs["kv_scale"] = kv_scale + mla_decode_fwd( q, kv_buffer.view(-1, 1, 1, q.shape[-1]), @@ -308,10 +321,7 @@ def _rocm_aiter_mla_decode_fwd_impl( kv_indices, kv_last_page_lens, max_seqlen_qo, - sm_scale=sm_scale, - logit_cap=logit_cap, - q_scale=q_scale, - kv_scale=kv_scale, + **kwargs, ) @@ -438,53 +448,324 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake( return torch.empty_like(x), torch.empty_like(residual) +def _rocm_aiter_per_tensor_quant_impl( + x: torch.Tensor, + quant_dtype: torch.dtype, + scale: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + from aiter.ops.quant import per_tensor_quant_hip + + return per_tensor_quant_hip(x, scale, quant_dtype) + + +def _rocm_aiter_per_tensor_quant_fake( + x: torch.Tensor, + quant_dtype: torch.dtype, + scale: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + return torch.empty_like(x, dtype=quant_dtype), torch.empty( + 1, dtype=torch.float32, device=x.device + ) + + +def _rocm_aiter_per_token_quant_impl( + x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None +) -> tuple[torch.Tensor, torch.Tensor]: + from aiter.ops.quant import dynamic_per_token_scaled_quant + + assert quant_dtype in [torch.int8, _FP8_DTYPE] + + out_shape = x.shape + out = torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device) + if scale is None: + scale = torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device) + dynamic_per_token_scaled_quant( + out, + x, + scale, + scale_ub=None, + shuffle_scale=False, + num_rows=None, + num_rows_factor=1, + ) + return out, scale + + +def _rocm_aiter_per_token_quant_fake( + x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None +) -> tuple[torch.Tensor, torch.Tensor]: + out_shape = x.shape + return ( + torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device), + torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device), + ) + + +def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant + + (x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant( + x, + weight, + variance_epsilon, + None, + None, + None, + group_size=group_size, + dtype_quant=AITER_FP8_DTYPE, + res1=residual, + ) + return (x_quant, x_quant_scales, res) + + +def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + M, N = x.shape + scale_shape = (M, (N + group_size - 1) // group_size) + return ( + torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device), + torch.empty(scale_shape, dtype=torch.float32, device=x.device), + torch.empty_like(residual, device=residual.device), + ) + + +def _rocm_aiter_rmsnorm_fp8_group_quant_impl( + x: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant + + (x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant( + x, + weight, + variance_epsilon, + None, + None, + None, + group_size=group_size, + dtype_quant=AITER_FP8_DTYPE, + res1=None, + ) + return (x_quant, x_quant_scales) + + +def _rocm_aiter_rmsnorm_fp8_group_quant_fake( + x: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + M, N = x.shape + scale_shape = (M, (N + group_size - 1) // group_size) + return ( + torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device), + torch.empty(scale_shape, dtype=torch.float32, device=x.device), + ) + + +def _rocm_aiter_group_fp8_quant_impl( + x: torch.Tensor, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size" + from aiter import QuantType, get_hip_quant + + aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128) + return aiter_per1x128_quant(x.contiguous(), quant_dtype=AITER_FP8_DTYPE) + + +def _rocm_aiter_group_fp8_quant_fake( + x: torch.Tensor, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + M, N = x.shape + x_fp8 = torch.empty((M, N), dtype=AITER_FP8_DTYPE, device=x.device) + out_bs = torch.empty( + ( + M, + (N + group_size - 1) // group_size, + ), + dtype=torch.float32, + device=x.device, + ) + return x_fp8, out_bs + + +def _rocm_aiter_act_mul_and_fp8_group_quant_impl( + x: torch.Tensor, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + from aiter.ops.triton.activation import act_mul_and_fp8_group_quant + + return act_mul_and_fp8_group_quant( + x, + activation="silu", + group_size=group_size, + dtype_quant=AITER_FP8_DTYPE, + ) + + +def _rocm_aiter_act_mul_and_fp8_group_quant_fake( + x: torch.Tensor, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + M, N = x.shape + assert N % 2 == 0 + N_half = N // 2 + x_fp8 = torch.empty((M, N_half), dtype=AITER_FP8_DTYPE, device=x.device) + out_bs = torch.empty( + ( + M, + (N_half + group_size - 1) // group_size, + ), + dtype=torch.float32, + device=x.device, + ) + return x_fp8, out_bs + + # Global flag to ensure ops are registered only once _OPS_REGISTERED = False class rocm_aiter_ops: + """ROCm AITER operations wrapper for AMD GPU acceleration in vLLM. + + This class centralizes the import and registration of AITER ops, + and provides a unified interface for checking if AITER is enabled. + Operations are only available on supported gfx9 + architectures when aiter is installed. + + The class uses environment variables to control which features are enabled, + allowing fine-grained control over which AITER optimizations are used. + + Environment Variables: + VLLM_ROCM_USE_AITER: Main toggle for all AITER operations. + VLLM_ROCM_USE_AITER_LINEAR: Controls GEMM and quantization ops. + VLLM_ROCM_USE_AITER_RMSNORM: Controls RMSNorm operations. + VLLM_ROCM_USE_AITER_MOE: Controls MoE (Mixture of Experts) ops. + VLLM_ROCM_USE_AITER_MLA: Controls MLA (Multi-head Latent Attention) ops. + VLLM_ROCM_USE_AITER_MHA: Controls MHA ops including flash_attn_varlen. + VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: Controls Triton unified attention. + VLLM_ROCM_USE_AITER_FP8BMM: Controls FP8 batched matrix multiply. + VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: Controls FP4 assembly GEMM. + VLLM_ROCM_USE_AITER_TRITON_ROPE: Controls Triton rotary embeddings. + VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: Controls shared expert fusion. + VLLM_ROCM_USE_AITER_TRITON_GEMM: Controls Triton unquantized GEMM. + + Note: + The environment variables are assigned when the module is imported, + so you can't change the environment variables after the module is imported. + This is done out of performance consideration. Accessing environment variables + is expensive as described in issue https://github.com/vllm-project/vllm/issues/17067 + so we don't want to do it repeatedly, especially in the hot path (the forward pass). + You can call the refresh_env_variables() function to reload the env variables + after monkey patching the env variables in the unit test. + + Check Functions: + All check functions (is_*_enabled) are decorated with @if_aiter_supported, + which verifies: (1) platform is ROCm, (2) device arch is gfx9, and + (3) aiter library is installed. The check function then also verifies + the corresponding environment variable is enabled. + i.e. ___ + is_enabled() == current_platform.is_rocm() and | checked by + current_platform.is_on_gfx9() and | @if_aiter_supported + IS_AITER_FOUND and _______________| + cls._AITER_ENABLED -----> Check by the logic in `is_enabled()` + + Example: + from vllm._aiter_ops import rocm_aiter_ops + + # Check if aiter is enabled before using operations + if rocm_aiter_ops.is_enabled(): + result = rocm_aiter_ops.rms_norm(x, weight, epsilon) + + Operations: + - RMS normalization: rms_norm, rms_norm2d_with_add + - GEMM operations: gemm_a8w8, gemm_a8w8_blockscale + - Fused MoE: fused_moe, asm_moe_tkw1 + - Routing: topk_softmax, biased_grouped_topk, grouped_topk + - MLA decode: mla_decode_fwd + - Quantization: per_tensor_quant, per_token_quant, group_fp8_quant + - Triton ops: triton_rotary_embed, triton_fp8_bmm, triton_gemm_a8w8_blockscale + """ + + # Check if the env variable is set _AITER_ENABLED = envs.VLLM_ROCM_USE_AITER _LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR _RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM _FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE _MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA - _PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN _MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA _TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION + # TODO: Consolidate under _LINEAR_ENABLED _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM + # TODO: Consolidate under _LINEAR_ENABLED _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM + # TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE _TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE _MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS + # TODO: Consolidate under _LINEAR_ENABLED _TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM + @classmethod + def refresh_env_variables(cls): + """ + Since the environment variables are assigned when the module is imported, + This is a helper function to reload all the env variables from + the environment variables. + for example, after monkey patching the env variables in the unit test, + you can call this function to reload the env variables. + """ + cls._AITER_ENABLED = envs.VLLM_ROCM_USE_AITER + cls._LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR + cls._RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM + cls._FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE + cls._MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA + cls._MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA + cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION + cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM + cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM + cls._TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE + cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS + cls._TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM + @classmethod @if_aiter_supported def is_enabled(cls) -> bool: - """Verifies device specs and availability of aiter main env variable.""" return cls._AITER_ENABLED @classmethod @if_aiter_supported def is_linear_enabled(cls) -> bool: - """ "Verifies device specs and availability of env variable.""" return cls._AITER_ENABLED and cls._LINEAR_ENABLED @classmethod @if_aiter_supported def is_linear_fp8_enaled(cls) -> bool: - """ "Verifies device specs and availability of env variable.""" - return cls.is_linear_enabled() and current_platform.is_fp8_fnuz() + return cls.is_linear_enabled() @classmethod @if_aiter_supported def is_rmsnorm_enabled(cls) -> bool: - """ "Verifies device specs and availability of env variable.""" return cls._AITER_ENABLED and cls._RMSNORM_ENABLED @classmethod @if_aiter_supported def is_fused_moe_enabled(cls) -> bool: - """ "Verifies device specs and availability of env variable.""" return cls._AITER_ENABLED and cls._FMOE_ENABLED @classmethod @@ -495,25 +776,16 @@ class rocm_aiter_ops: @classmethod @if_aiter_supported def is_mla_enabled(cls) -> bool: - """ "Verifies device specs and availability of env variable.""" return cls._AITER_ENABLED and cls._MLA_ENABLED @classmethod @if_aiter_supported def is_mha_enabled(cls) -> bool: - """ "Verifies device specs and availability of env variable.""" return cls._AITER_ENABLED and cls._MHA_ENABLED - @classmethod - @if_aiter_supported - def is_pa_attn_enabled(cls) -> bool: - """ "Verifies device specs and availability of env variable.""" - return cls._AITER_ENABLED and cls._PG_ATTN_ENABLED - @classmethod @if_aiter_supported def is_triton_unified_attn_enabled(cls) -> bool: - """ "Verifies device specs and availability of env variable.""" return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED @classmethod @@ -548,14 +820,6 @@ class rocm_aiter_ops: ) # register all the custom ops here - direct_register_custom_op( - op_name="rocm_aiter_group_fp8_quant", - op_func=_rocm_aiter_group_fp8_quant_impl, - mutates_args=[], - fake_impl=_rocm_aiter_group_fp8_quant_fake, - dispatch_key=current_platform.dispatch_key, - ) - direct_register_custom_op( op_name="rocm_aiter_asm_moe_tkw1", op_func=_rocm_aiter_asm_moe_tkw1_impl, @@ -615,27 +879,62 @@ class rocm_aiter_ops: direct_register_custom_op( op_name="rocm_aiter_gemm_a8w8_blockscale", op_func=_rocm_aiter_gemm_a8w8_blockscale_impl, - mutates_args=[], fake_impl=_rocm_aiter_gemm_a8w8_blockscale_fake, - dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( op_name="rocm_aiter_rms_norm", op_func=_rocm_aiter_rms_norm_impl, - mutates_args=[], fake_impl=_rocm_aiter_rms_norm_fake, - dispatch_key=current_platform.dispatch_key, ) direct_register_custom_op( op_name="rocm_aiter_rmsnorm2d_fwd_with_add", op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl, - mutates_args=[], fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake, dispatch_key=current_platform.dispatch_key, ) + direct_register_custom_op( + op_name="rocm_aiter_rmsnorm_fp8_group_quant", + op_func=_rocm_aiter_rmsnorm_fp8_group_quant_impl, + fake_impl=_rocm_aiter_rmsnorm_fp8_group_quant_fake, + ) + + direct_register_custom_op( + op_name="rocm_aiter_rmsnorm_with_add_fp8_group_quant", + op_func=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl, + fake_impl=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake, + ) + + direct_register_custom_op( + op_name="rocm_aiter_act_mul_and_fp8_group_quant", + op_func=_rocm_aiter_act_mul_and_fp8_group_quant_impl, + fake_impl=_rocm_aiter_act_mul_and_fp8_group_quant_fake, + ) + + direct_register_custom_op( + op_name="rocm_aiter_group_fp8_quant", + op_func=_rocm_aiter_group_fp8_quant_impl, + fake_impl=_rocm_aiter_group_fp8_quant_fake, + ) + + direct_register_custom_op( + op_name="rocm_aiter_per_tensor_quant", + op_func=_rocm_aiter_per_tensor_quant_impl, + mutates_args=[], + fake_impl=_rocm_aiter_per_tensor_quant_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_per_token_quant", + op_func=_rocm_aiter_per_token_quant_impl, + mutates_args=["scale"], + fake_impl=_rocm_aiter_per_token_quant_fake, + dispatch_key=current_platform.dispatch_key, + ) + _OPS_REGISTERED = True @staticmethod @@ -830,6 +1129,22 @@ class rocm_aiter_ops: kv_scale=kv_scale, ) + @staticmethod + def per_tensor_quant( + x: torch.Tensor, + quant_dtype: torch.dtype, + scale: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + return torch.ops.vllm.rocm_aiter_per_tensor_quant(x, quant_dtype, scale) + + @staticmethod + def per_token_quant( + x: torch.Tensor, + quant_dtype: torch.dtype, + scale: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + return torch.ops.vllm.rocm_aiter_per_token_quant(x, quant_dtype, scale) + @staticmethod def triton_fp4_gemm_dynamic_qaunt( x: torch.Tensor, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index e60158898685a..2319655008c50 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -436,6 +436,46 @@ def rms_norm_dynamic_per_token_quant( return output, scales +# fused quant layer norm ops blocked +def rms_norm_per_block_quant( + input: torch.Tensor, + weight: torch.Tensor, + epsilon: float, + quant_dtype: torch.dtype, + group_size: list[int], + scale_ub: torch.Tensor | None = None, + residual: torch.Tensor | None = None, + is_scale_transposed: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + assert len(group_size) == 2 + output = torch.empty_like(input, dtype=quant_dtype) + if is_scale_transposed: + scales = torch.empty( + (input.shape[-1] // group_size[1], input.numel() // input.shape[-1]), + device=input.device, + dtype=torch.float32, + ).transpose(0, 1) + else: + scales = torch.empty( + (input.numel() // input.shape[-1], input.shape[-1] // group_size[1]), + device=input.device, + dtype=torch.float32, + ) + + torch.ops._C.rms_norm_per_block_quant( + output, + input, + weight, + scales, + epsilon, + scale_ub, + residual, + group_size[1], + is_scale_transposed, + ) + return output, scales + + # quantization ops # awq def awq_dequantize( @@ -458,15 +498,15 @@ def awq_dequantize( def awq_gemm( input: torch.Tensor, qweight: torch.Tensor, - qzeros: torch.Tensor, scales: torch.Tensor, + qzeros: torch.Tensor, split_k_iters: int, ) -> torch.Tensor: if envs.VLLM_USE_TRITON_AWQ: from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton - return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters) - return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters) + return awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters) + return torch.ops._C.awq_gemm(input, qweight, scales, qzeros, split_k_iters) # gptq @@ -592,8 +632,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): def _awq_gemm_fake( input: torch.Tensor, qweight: torch.Tensor, - qzeros: torch.Tensor, scales: torch.Tensor, + qzeros: torch.Tensor, split_k_iters: torch.SymInt, ) -> torch.Tensor: num_in_feats = input.size(0) @@ -655,6 +695,10 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): def cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor: return torch.empty_like(b, memory_format=torch.contiguous_format) + @register_fake("_C::cutlass_encode_and_reorder_int4b_grouped") + def cutlass_encode_and_reorder_int4b_grouped_fake(b: torch.Tensor) -> torch.Tensor: + return torch.empty_like(b, memory_format=torch.contiguous_format) + if hasattr(torch.ops._C, "allspark_w8a16_gemm"): @@ -1018,6 +1062,7 @@ def get_cutlass_moe_mm_problem_sizes( n: int, k: int, blockscale_offsets: torch.Tensor | None = None, + force_swap_ab: bool | None = None, ): """ Compute only the per-expert problem sizes needed by the two grouped matrix @@ -1027,9 +1072,20 @@ def get_cutlass_moe_mm_problem_sizes( - problem_sizes1, problem_sizes2: M×N×K sizes of each expert's multiplication for the two grouped MMs used in the fused MoE operation. + Optional: + - force_swap_ab: If set to True or False, explicitly enable or disable the + A/B input swap optimization. If None (default), the swap + is selected automatically based on tensor sizes. """ return torch.ops._C.get_cutlass_moe_mm_problem_sizes( - topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k, blockscale_offsets + topk_ids, + problem_sizes1, + problem_sizes2, + num_experts, + n, + k, + blockscale_offsets, + force_swap_ab, ) @@ -1417,6 +1473,78 @@ def cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor: return torch.ops._C.cutlass_encode_and_reorder_int4b(b) +def cutlass_w4a8_moe_mm( + out_tensors: torch.Tensor, + a_tensors: torch.Tensor, + b_tensors: torch.Tensor, + a_scales: torch.Tensor, + b_scales: torch.Tensor, + b_group_scales: torch.Tensor, + b_group_size: int, + expert_offsets: torch.Tensor, + problem_sizes: torch.Tensor, + a_strides: torch.Tensor, + b_strides: torch.Tensor, + c_strides: torch.Tensor, + group_scale_strides: torch.Tensor, + maybe_schedule: str | None = None, +): + """ + Executes the CUTLASS-based fused-MoE grouped matrix multiplication for the + W4A8 quantization scheme. Uses group-wise quantization (INT4 -> FP8) + and both per-channel + per-token scaling in the epilogue. + + Args: + out_tensors: + Output buffer for all experts (updated in-place). + a_tensors: + FP8 (E4M3FN) activations for all experts. + b_tensors: + INT4-packed weight matrix for all experts, packed to INT32 + a_scales: + Per-token FP8 activation scales, applied in the epilogue. + b_scales: + Per-channel FP8 weight scales for each expert, applied in the epilogue. + b_group_scales: + FP8 scale values for group-wise INT4 weight blocks. + b_group_size: + Number of elements grouped under each entry of b_group_scales. + expert_offsets: + Cumulative token offsets + problem_sizes: + Per-expert (M, N, K) GEMM sizes used by the grouped GEMM launcher. + a/b/c/group_scale_strides: + Strides describing the memory layout of the input tensors. + maybe_schedule: + Optional override to choose a specific kernel or epilogue schedule. + + Returns: + out_tensors updated in-place with the dequantized INT4xFP8 grouped GEMM result. + """ + return torch.ops._C.cutlass_w4a8_moe_mm( + out_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + b_group_scales, + b_group_size, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + c_strides, + group_scale_strides, + maybe_schedule, + ) + + +def cutlass_encode_and_reorder_int4b_grouped( + b_tensors: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + return torch.ops._C.cutlass_encode_and_reorder_int4b_grouped(b_tensors) + + if hasattr(torch.ops._C, "permute_cols"): @register_fake("_C::permute_cols") @@ -1598,7 +1726,7 @@ def scaled_fp8_quant( output, input, scale, scale_ub ) else: - scale = torch.empty((1, 1), device=input.device, dtype=torch.float32) + scale = torch.empty(1, device=input.device, dtype=torch.float32) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) else: assert scale.numel() == 1, f"{scale.shape}" @@ -1877,6 +2005,7 @@ def moe_align_block_size( sorted_token_ids: torch.Tensor, experts_ids: torch.Tensor, num_tokens_post_pad: torch.Tensor, + expert_map: torch.Tensor | None = None, ) -> None: torch.ops._moe_C.moe_align_block_size( topk_ids, @@ -1885,6 +2014,7 @@ def moe_align_block_size( sorted_token_ids, experts_ids, num_tokens_post_pad, + expert_map, ) @@ -1919,6 +2049,7 @@ def moe_lora_align_block_size( num_tokens_post_pad: torch.Tensor, adapter_enabled: torch.Tensor, lora_ids: torch.Tensor, + expert_map: torch.Tensor | None = None, ) -> None: torch.ops._moe_C.moe_lora_align_block_size( topk_ids, @@ -1933,6 +2064,7 @@ def moe_lora_align_block_size( num_tokens_post_pad, adapter_enabled, lora_ids, + expert_map, ) @@ -2271,6 +2403,29 @@ def cp_gather_cache( ) +def cp_gather_and_upconvert_fp8_kv_cache( + src_cache: torch.Tensor, + dst: torch.Tensor, + block_table: torch.Tensor, + seq_lens: torch.Tensor, + workspace_starts: torch.Tensor, + batch_size: int, +) -> None: + """Gather and upconvert FP8 KV cache to BF16 workspace. + + Args: + src_cache: FP8 KV cache [num_blocks, block_size, 656] + dst: BF16 output workspace [total_tokens, 576] + block_table: Block indices [num_reqs, max_blocks] + seq_lens: Sequence lengths [num_reqs] + workspace_starts: Workspace start offsets [num_reqs] + batch_size: Number of requests + """ + torch.ops._C_cache_ops.cp_gather_and_upconvert_fp8_kv_cache( + src_cache, dst, block_table, seq_lens, workspace_starts, batch_size + ) + + def indexer_k_quant_and_cache( k: torch.Tensor, kv_cache: torch.Tensor, diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index c290670eeacb0..025ede1eb0a4e 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -166,6 +166,10 @@ class AttentionBackend(ABC): def supports_sink(cls) -> bool: return False + @classmethod + def supports_mm_prefix(cls) -> bool: + return False + @classmethod def is_sparse(cls) -> bool: return False @@ -207,6 +211,7 @@ class AttentionBackend(ABC): use_mla: bool, has_sink: bool, use_sparse: bool, + use_mm_prefix: bool, device_capability: "DeviceCapability", attn_type: str, ) -> list[str]: @@ -219,6 +224,10 @@ class AttentionBackend(ABC): invalid_reasons.append("kv_cache_dtype not supported") if not cls.supports_block_size(block_size): invalid_reasons.append("block_size not supported") + if use_mm_prefix and not cls.supports_mm_prefix(): + invalid_reasons.append( + "partial multimodal token full attention not supported" + ) if use_mla != cls.is_mla(): if use_mla: invalid_reasons.append("MLA not supported") @@ -285,10 +294,26 @@ class AttentionImpl(ABC, Generic[T]): # Some features like decode context parallelism require the softmax lse. can_return_lse_for_decode: bool = False + # Whether the attention impl supports Prefill Context Parallelism. + supports_pcp: bool = False + # Whether the attention impl(or ops) supports MTP + # when cp_kv_cache_interleave_size > 1 + supports_mtp_with_cp_non_trivial_interleave_size: bool = False + # some attention backends might not always want to return lse # even if they can return lse (for efficiency reasons) need_to_return_lse_for_decode: bool = False + # Whether this attention implementation supports pre-quantized query input. + # When True, the attention layer will quantize queries before passing them + # to this backend, allowing torch.compile to fuse the quantization with + # previous operations. This is typically supported when using FP8 KV cache + # with compatible attention kernels (e.g., TRT-LLM). + # Subclasses should set this in __init__. + # TODO add support to more backends: + # https://github.com/vllm-project/vllm/issues/25584 + supports_quant_query_input: bool = False + dcp_world_size: int dcp_rank: int @@ -368,22 +393,6 @@ class AttentionImpl(ABC, Generic[T]): """ return False - def supports_quant_query_input(self) -> bool: - """ - Check if this attention implementation supports pre-quantized query input. - - When True, the attention layer will quantize queries before passing them - to this backend, allowing torch.compile to fuse the quantization with - previous operations. This is typically supported when using FP8 KV cache - with compatible attention kernels (e.g., TRT-LLM). - TODO add support to more backends: - https://github.com/vllm-project/vllm/issues/25584 - - Returns: - bool: True if the implementation can accept pre-quantized queries. - """ - return False - def process_weights_after_loading(self, act_dtype: torch.dtype): pass diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index 125e4e3827747..eaa0fa1d5db39 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -252,35 +252,3 @@ def register_backend( return lambda x: x return decorator - - -# Backwards compatibility alias for plugins -class _BackendMeta(type): - """Metaclass to provide deprecation warnings when accessing _Backend.""" - - def __getattribute__(cls, name: str): - if name not in ("__class__", "__mro__", "__name__"): - logger.warning( - "_Backend has been renamed to AttentionBackendEnum. " - "Please update your code to use AttentionBackendEnum instead. " - "_Backend will be removed in a future release." - ) - return getattr(AttentionBackendEnum, name) - - def __getitem__(cls, name: str): - logger.warning( - "_Backend has been renamed to AttentionBackendEnum. " - "Please update your code to use AttentionBackendEnum instead. " - "_Backend will be removed in a future release." - ) - return AttentionBackendEnum[name] - - -class _Backend(metaclass=_BackendMeta): - """Deprecated: Use AttentionBackendEnum instead. - - This class is provided for backwards compatibility with plugins - and will be removed in a future release. - """ - - pass diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index da5a626171298..7ef77db8fbb5b 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer.""" -from collections.abc import Callable +import functools from typing import cast import torch @@ -16,7 +16,9 @@ from vllm.attention.backends.abstract import ( MLAAttentionImpl, ) from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.attention.layers.mm_encoder_attention import maybe_get_vit_flash_attn_backend from vllm.attention.selector import get_attn_backend +from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer from vllm.config import CacheConfig, get_current_vllm_config @@ -25,6 +27,7 @@ from vllm.config.vllm import VllmConfig from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant from vllm.model_executor.layers.linear import ( ColumnParallelLinear, UnquantizedLinearMethod, @@ -46,55 +49,9 @@ from vllm.v1.kv_cache_interface import ( SlidingWindowSpec, ) -if current_platform.is_rocm(): - from vllm.platforms.rocm import on_gfx9 -else: - on_gfx9 = lambda *args, **kwargs: False - - -FP8_DTYPE = current_platform.fp8_dtype() logger = init_logger(__name__) -def maybe_get_vit_flash_attn_backend( - attn_backend: AttentionBackendEnum, - attn_backend_override: AttentionBackendEnum | None = None, -) -> tuple[AttentionBackendEnum, Callable | None]: - if current_platform.is_rocm(): - if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): - attn_backend = AttentionBackendEnum.ROCM_AITER_FA - elif ( - attn_backend_override is None - and on_gfx9() - and attn_backend == AttentionBackendEnum.FLASH_ATTN - ): - pass - else: - return AttentionBackendEnum.TORCH_SDPA, None - elif current_platform.is_cuda(): - pass - elif current_platform.is_xpu(): - assert attn_backend == AttentionBackendEnum.FLASH_ATTN, ( - "XPU platform only supports FLASH_ATTN as vision attention backend." - ) - pass - else: - return AttentionBackendEnum.TORCH_SDPA, None - - if attn_backend in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.ROCM_AITER_FA, - }: - if attn_backend == AttentionBackendEnum.ROCM_AITER_FA: - from aiter import flash_attn_varlen_func - else: - from vllm.attention.utils.fa_utils import flash_attn_varlen_func - else: - flash_attn_varlen_func = None - - return attn_backend, flash_attn_varlen_func - - def _init_kv_cache_quant( layer: nn.Module, quant_config: QuantizationConfig | None, @@ -230,6 +187,10 @@ class Attention(nn.Module, AttentionLayerBase): self.sliding_window = sliding_window self.has_sink = extra_impl_args.get("sinks") is not None + # NOTE: model_config may be None during certain tests + model_config = vllm_config.model_config + self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm + # During model initialization, the default dtype is set as the model # weight and activation dtype. dtype = torch.get_default_dtype() @@ -241,11 +202,30 @@ class Attention(nn.Module, AttentionLayerBase): block_size, use_mla=False, has_sink=self.has_sink, + use_mm_prefix=self.use_mm_prefix, attn_type=attn_type, ) else: self.attn_backend = attn_backend + # prefix caching + batch invariance is currently not supported for + # FLASHINFER and TRITON_MLA. + if ( + cache_config is not None + and cache_config.enable_prefix_caching + and vllm_is_batch_invariant() + and ( + self.attn_backend.get_name() == "FLASHINFER" + or self.attn_backend.get_name() == "TRITON_MLA" + ) + ): + logger.warning_once( + "Disabling prefix caching for FLASHINFER/TRITON_MLA " + "with batch invariance, as it is not yet supported.", + scope="local", + ) + cache_config.enable_prefix_caching = False + impl_cls = self.attn_backend.get_impl_cls() self.impl = impl_cls( num_heads, @@ -303,7 +283,7 @@ class Attention(nn.Module, AttentionLayerBase): self.query_quant = None if ( self.kv_cache_dtype.startswith("fp8") - and self.impl.supports_quant_query_input() + and self.impl.supports_quant_query_input ): self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) @@ -338,7 +318,7 @@ class Attention(nn.Module, AttentionLayerBase): assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"} # check if query quantization is supported - if self.impl.supports_quant_query_input(): + if self.impl.supports_quant_query_input: query, _ = self.query_quant(query, self._q_scale) if self.use_output: @@ -467,29 +447,15 @@ class MultiHeadAttention(nn.Module): attn_backend_override = None if multimodal_config is not None: attn_backend_override = multimodal_config.mm_encoder_attn_backend - backend = get_vit_attn_backend( + + self.attn_backend = get_vit_attn_backend( head_size=head_size, dtype=dtype, attn_backend_override=attn_backend_override, ) - self.attn_backend = ( - backend - if backend - in { - AttentionBackendEnum.TORCH_SDPA, - AttentionBackendEnum.PALLAS, - AttentionBackendEnum.ROCM_AITER_FA, - AttentionBackendEnum.FLASH_ATTN, - } - else AttentionBackendEnum.TORCH_SDPA - ) - - self.attn_backend, self._flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - attn_backend_override=attn_backend_override, - ) + self._flash_attn_varlen_func = maybe_get_vit_flash_attn_backend( + self.attn_backend, ) self.is_flash_attn_backend = self.attn_backend in { @@ -497,6 +463,17 @@ class MultiHeadAttention(nn.Module): AttentionBackendEnum.ROCM_AITER_FA, } + self.fa_version = None + if ( + self.attn_backend == AttentionBackendEnum.FLASH_ATTN + and current_platform.is_cuda() + ): + self.fa_version = get_flash_attn_version() + assert self._flash_attn_varlen_func is not None + self._flash_attn_varlen_func = functools.partial( + self._flash_attn_varlen_func, fa_version=self.fa_version + ) + logger.info_once( f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder." ) @@ -623,6 +600,23 @@ class MLAAttention(nn.Module, AttentionLayerBase): use_mla=True, use_sparse=use_sparse, ) + + if ( + cache_config is not None + and cache_config.enable_prefix_caching + and vllm_is_batch_invariant() + and ( + self.attn_backend.get_name() == "TRITON_MLA" + or self.attn_backend.get_name() == "FLASHINFER" + ) + ): + logger.warning_once( + "Disabling prefix caching for TRITON_MLA / FLASHINFER " + "with batch invariance, as it is not yet supported.", + scope="local", + ) + cache_config.enable_prefix_caching = False + impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls()) self.impl = impl_cls( num_heads=self.num_heads, diff --git a/vllm/attention/layers/cross_attention.py b/vllm/attention/layers/cross_attention.py index 068fd0a0eb7d0..cfd203bdd37b9 100644 --- a/vllm/attention/layers/cross_attention.py +++ b/vllm/attention/layers/cross_attention.py @@ -103,7 +103,7 @@ def create_cross_attention_backend( # needed here to know how many tokens to attend to from the cached # cross-attention KV cache. new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens - new_metadata.seq_lens_cpu = torch.from_numpy( + new_metadata._seq_lens_cpu = torch.from_numpy( common_attn_metadata.encoder_seq_lens_cpu ) diff --git a/vllm/attention/layers/mm_encoder_attention.py b/vllm/attention/layers/mm_encoder_attention.py new file mode 100644 index 0000000000000..c9107ebcab856 --- /dev/null +++ b/vllm/attention/layers/mm_encoder_attention.py @@ -0,0 +1,284 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Callable + +import torch + +from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.attention.ops.vit_attn_wrappers import ( + vit_flash_attn_wrapper, + vit_torch_sdpa_wrapper, +) +from vllm.config import MultiModalConfig +from vllm.logger import init_logger +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.models.vision import get_vit_attn_backend + +logger = init_logger(__name__) + + +def maybe_get_vit_flash_attn_backend( + attn_backend: AttentionBackendEnum | None, +) -> Callable | None: + # At this point, + # we already have the attn_backend, + # overriding logic is done in the platform-specific implementation. + # so we don't need to override backend here. + # Just return the attn_backend and flash_attn_varlen_func. + + if attn_backend == AttentionBackendEnum.FLASH_ATTN: + from vllm.attention.utils.fa_utils import flash_attn_varlen_func + elif attn_backend == AttentionBackendEnum.ROCM_AITER_FA: + from aiter import flash_attn_varlen_func + else: + flash_attn_varlen_func = None + + # if attn_backend is TORCH_SDPA, + # it will reach here and the flash_attn_varlen_func will be None. + return flash_attn_varlen_func + + +@CustomOp.register("mm_encoder_attn") +class MMEncoderAttention(CustomOp): + """Multi-headed attention without any cache, used for multimodal encoder.""" + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float | None = None, + num_kv_heads: int | None = None, + prefix: str = "", + multimodal_config: MultiModalConfig | None = None, + ) -> None: + """ + Args: + num_heads: number of attention heads per partition. + head_size: hidden_size per attention head. + scale: scale factor. + num_kv_heads: number of kv heads. + prefix: This has no effect, it is only here to make it easier to + swap between Attention and MultiHeadAttention + multimodal_config: configs for multi-modal. + """ + super().__init__() + + self.num_heads = num_heads + self.head_size = head_size + self.scale = scale + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.layer_name = prefix + + assert self.num_heads % self.num_kv_heads == 0, ( + f"num_heads ({self.num_heads}) is not " + f"divisible by num_kv_heads ({self.num_kv_heads})" + ) + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + # During model initialization, the default dtype is set as the model + # weight and activation dtype. + dtype = torch.get_default_dtype() + + # Try to get vision attention backend from multimodal_config. + attn_backend_override = None + if multimodal_config is not None: + attn_backend_override = multimodal_config.mm_encoder_attn_backend + + # Get device-specific vision attention backend. + self.attn_backend = get_vit_attn_backend( + head_size=head_size, + dtype=dtype, + attn_backend_override=attn_backend_override, + ) + + self.is_flash_attn_backend = self.attn_backend in { + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, + } + + self.flash_attn_varlen_func = maybe_get_vit_flash_attn_backend( + self.attn_backend, + ) + + logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.") + + @classmethod + def enabled(cls) -> bool: + return True + + def reshape_qkv_to_4d( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + bsz: int, + q_len: int, + kv_len: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Reshape query, key, value to 4D tensors: + (batch_size, seq_len, num_heads, head_size) + """ + query = query.view(bsz, q_len, self.num_heads, self.head_size) + key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size) + value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size) + + if (num_repeat := self.num_queries_per_kv) > 1: + # Handle MQA and GQA + key = torch.repeat_interleave(key, num_repeat, dim=2) + value = torch.repeat_interleave(value, num_repeat, dim=2) + + return query, key, value + + def reshape_qkv_to_3d( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + bsz: int, + q_len: int, + kv_len: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Reshape query, key, value to 3D tensors: + (batch_size * seq_len, num_heads, head_size) + """ + query = query.view(bsz * q_len, self.num_heads, self.head_size) + key = key.view(bsz * kv_len, self.num_kv_heads, self.head_size) + value = value.view(bsz * kv_len, self.num_kv_heads, self.head_size) + + if (num_repeat := self.num_queries_per_kv) > 1: + # Handle MQA and GQA + key = torch.repeat_interleave(key, num_repeat, dim=1) + value = torch.repeat_interleave(value, num_repeat, dim=1) + + return query, key, value + + def _forward_sdpa( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + ) -> torch.Tensor: + # TODO(Isotr0py): Migrate MultiHeadAttention + assert cu_seqlens is not None + + bsz, q_len = query.size()[:2] + kv_len = key.size(1) + + query, key, value = self.reshape_qkv_to_4d( + query, key, value, bsz, q_len, kv_len + ) + + output = vit_torch_sdpa_wrapper( + q=query, + k=key, + v=value, + cu_seqlens=cu_seqlens, + ) + return output + + def _forward_fa( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention + ) -> torch.Tensor: + assert self.flash_attn_varlen_func is not None, ( + "Flash attention function is not set." + ) + # # TODO(Isotr0py): Migrate MultiHeadAttention + assert cu_seqlens is not None and max_seqlen is not None + + bsz = query.shape[0] + + output = vit_flash_attn_wrapper( + q=query, + k=key, + v=value, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + batch_size=bsz, + is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA), + ) + return output + + def forward_native( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention + ) -> torch.Tensor: + return self._forward_sdpa(query, key, value, cu_seqlens) + + def forward_cuda( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention + ) -> torch.Tensor: + if self.is_flash_attn_backend: + return self._forward_fa(query, key, value, cu_seqlens, max_seqlen) + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: + return self._forward_sdpa(query, key, value, cu_seqlens) + else: + raise ValueError( + f"Unsupported multi-modal encoder attention backend for CUDA: " + f"{self.attn_backend}." + ) + + def forward_cpu( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention + ) -> torch.Tensor: + return self._forward_sdpa(query, key, value, cu_seqlens) + + def forward_xpu( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention + ) -> torch.Tensor: + assert self.is_flash_attn_backend, ( + "XPU only supports FLASH_ATTN for vision attention." + ) + return self._forward_fa(query, key, value, cu_seqlens, max_seqlen) + + def forward_tpu( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention + ) -> torch.Tensor: + assert self.attn_backend == AttentionBackendEnum.PALLAS, ( + f"MMEncoderAttention on TPU only supports PALLAS backend, " + f"but got {self.attn_backend}." + ) + if cu_seqlens is None: + query, key, value = (x.transpose(1, 2) for x in (query, key, value)) + from torch_xla.experimental.custom_kernel import flash_attention + + out = flash_attention(query, key, value, sm_scale=self.scale) + out = out.transpose(1, 2) + return out + logger.warning_once( + "PALLAS backend with cu_seqlens is not supported for ViT yet. ", + "Falling back to SDPA implementation.", + ) + return self._forward_sdpa(query, key, value, cu_seqlens) diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 565be1c39bec1..a1877bb4429b9 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -355,7 +355,7 @@ def kernel_unified_attention_2d( @triton.jit def kernel_unified_attention_3d( segm_output_ptr, - # [num_tokens, num_query_heads, num_segments, head_size] + # [num_tokens, num_query_heads, num_segments, head_size_padded] segm_max_ptr, # [num_tokens, num_query_heads, num_segments] segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments] query_ptr, # [num_tokens, num_query_heads, head_size] @@ -749,6 +749,11 @@ def unified_attention( q_descale, k_descale, v_descale, + seq_threshold_3D=None, + num_par_softmax_segments=None, + softmax_segm_output=None, + softmax_segm_max=None, + softmax_segm_expsum=None, alibi_slopes=None, output_scale=None, qq_bias=None, @@ -793,8 +798,19 @@ def unified_attention( TILE_SIZE_PREFILL = 32 TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32 - # if batch contains a prefill - if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128: + # Launch the 2D kernel if + # 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or + # 2. The batch includes at least one prefill request, or + # 3. The number of sequences exceeds the configured threshold + if ( + seq_threshold_3D is None + or num_par_softmax_segments is None + or softmax_segm_output is None + or softmax_segm_max is None + or softmax_segm_expsum is None + or max_seqlen_q > 1 + or num_seqs > seq_threshold_3D + ): kernel_unified_attention_2d[ ( total_num_q_blocks, @@ -847,37 +863,12 @@ def unified_attention( USE_FP8=output_scale is not None, ) else: - # for initial version, NUM_SEGMENTS = 16 is chosen as a default - # value that showed good performance in tests - NUM_SEGMENTS = 16 - - segm_output = torch.empty( - q.shape[0], - num_query_heads, - NUM_SEGMENTS, - triton.next_power_of_2(head_size), - dtype=torch.float32, - device=q.device, - ) - segm_max = torch.empty( - q.shape[0], - num_query_heads, - NUM_SEGMENTS, - dtype=torch.float32, - device=q.device, - ) - segm_expsum = torch.empty( - q.shape[0], - num_query_heads, - NUM_SEGMENTS, - dtype=torch.float32, - device=q.device, - ) - - kernel_unified_attention_3d[(total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)]( - segm_output_ptr=segm_output, - segm_max_ptr=segm_max, - segm_expsum_ptr=segm_expsum, + kernel_unified_attention_3d[ + (total_num_q_blocks, num_kv_heads, num_par_softmax_segments) + ]( + segm_output_ptr=softmax_segm_output, + segm_max_ptr=softmax_segm_max, + segm_expsum_ptr=softmax_segm_expsum, query_ptr=q, key_cache_ptr=k, value_cache_ptr=v, @@ -917,13 +908,13 @@ def unified_attention( BLOCK_Q=BLOCK_Q, num_seqs=num_seqs, BLOCK_M=BLOCK_M, - NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments, ) reduce_segments[(q.shape[0], num_query_heads)]( output_ptr=out, - segm_output_ptr=segm_output, - segm_max_ptr=segm_max, - segm_expsum_ptr=segm_expsum, + segm_output_ptr=softmax_segm_output, + segm_max_ptr=softmax_segm_max, + segm_expsum_ptr=softmax_segm_expsum, seq_lens_ptr=seqused_k, num_seqs=num_seqs, num_query_heads=num_query_heads, @@ -936,6 +927,6 @@ def unified_attention( HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), query_start_len_ptr=cu_seqlens_q, BLOCK_Q=BLOCK_Q, - NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments, USE_FP8=output_scale is not None, ) diff --git a/vllm/attention/ops/vit_attn_wrappers.py b/vllm/attention/ops/vit_attn_wrappers.py index d9f15f1e42858..46c7d83dfa5c2 100644 --- a/vllm/attention/ops/vit_attn_wrappers.py +++ b/vllm/attention/ops/vit_attn_wrappers.py @@ -44,9 +44,7 @@ def flash_attn_maxseqlen_wrapper( dropout_p=0.0, causal=False, ) - context_layer = einops.rearrange( - output, "(b s) h d -> s b (h d)", b=batch_size - ).contiguous() + context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size) return context_layer @@ -59,8 +57,7 @@ def flash_attn_maxseqlen_wrapper_fake( batch_size: int, is_rocm_aiter: bool, ) -> torch.Tensor: - b, s, h, d = q.shape - return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device) + return torch.empty_like(q) direct_register_custom_op( @@ -93,12 +90,12 @@ def torch_sdpa_wrapper( cu_seqlens: torch.Tensor, ) -> torch.Tensor: outputs = [] - for i in range(1, len(cu_seqlens)): - start_idx = cu_seqlens[i - 1] - end_idx = cu_seqlens[i] - q_i = q[:, start_idx:end_idx] - k_i = k[:, start_idx:end_idx] - v_i = v[:, start_idx:end_idx] + + lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + q_chunks = torch.split(q, lens, dim=1) + k_chunks = torch.split(k, lens, dim=1) + v_chunks = torch.split(v, lens, dim=1) + for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks): q_i, k_i, v_i = ( einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i] ) @@ -106,7 +103,6 @@ def torch_sdpa_wrapper( output_i = einops.rearrange(output_i, "b h s d -> b s h d ") outputs.append(output_i) context_layer = torch.cat(outputs, dim=1) - context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous() return context_layer @@ -116,8 +112,7 @@ def torch_sdpa_wrapper_fake( v: torch.Tensor, cu_seqlens: torch.Tensor, ) -> torch.Tensor: - b, s, h, d = q.shape - return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device) + return torch.empty_like(q) direct_register_custom_op( diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index a7190df3c4f10..e66f698add99d 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -1,20 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import inspect -import os -from collections.abc import Generator -from contextlib import contextmanager from functools import cache -from typing import cast, get_args +from typing import NamedTuple, cast, get_args import torch -import vllm.envs as envs -from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.backends.registry import ( MAMBA_TYPE_TO_BACKEND_MAP, - AttentionBackendEnum, MambaAttentionBackendEnum, ) from vllm.config.cache import CacheDType @@ -24,58 +18,29 @@ from vllm.utils.import_utils import resolve_obj_by_qualname logger = init_logger(__name__) -def get_env_variable_attn_backend() -> AttentionBackendEnum | None: - """ - Get the backend override specified by the vLLM attention - backend environment variable, if one is specified. +class AttentionSelectorConfig(NamedTuple): + head_size: int + dtype: torch.dtype + kv_cache_dtype: CacheDType | None + block_size: int | None + use_mla: bool = False + has_sink: bool = False + use_sparse: bool = False + use_mm_prefix: bool = False + attn_type: str = AttentionType.DECODER - Returns: - - * AttentionBackendEnum value if an override is specified - * None otherwise - """ - backend_name = os.environ.get("VLLM_ATTENTION_BACKEND") - if backend_name is None: - return None - if backend_name == "XFORMERS": - raise ValueError( - "Attention backend 'XFORMERS' has been removed (See PR #29262 for " - "details). Please select a supported attention backend." + def __repr__(self): + return ( + f"AttentionSelectorConfig(head_size={self.head_size}, " + f"dtype={self.dtype}, " + f"kv_cache_dtype={self.kv_cache_dtype}, " + f"block_size={self.block_size}, " + f"use_mla={self.use_mla}, " + f"has_sink={self.has_sink}, " + f"use_sparse={self.use_sparse}, " + f"use_mm_prefix={self.use_mm_prefix}, " + f"attn_type={self.attn_type})" ) - return AttentionBackendEnum[backend_name] - - -# Global state allows a particular choice of backend -# to be forced, overriding the logic which auto-selects -# a backend based on system & workload configuration -# (default behavior if this variable is None) -# -# THIS SELECTION TAKES PRECEDENCE OVER THE -# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE -forced_attn_backend: AttentionBackendEnum | None = None - - -def global_force_attn_backend(attn_backend: AttentionBackendEnum | None) -> None: - """ - Force all attention operations to use a specified backend. - - Passing `None` for the argument re-enables automatic - backend selection., - - Arguments: - - * attn_backend: backend selection (None to revert to auto) - """ - global forced_attn_backend - forced_attn_backend = attn_backend - - -def get_global_forced_attn_backend() -> AttentionBackendEnum | None: - """ - Get the currently-forced choice of attention backend, - or None if auto-selection is currently enabled. - """ - return forced_attn_backend def get_attn_backend( @@ -86,6 +51,7 @@ def get_attn_backend( use_mla: bool = False, has_sink: bool = False, use_sparse: bool = False, + use_mm_prefix: bool = False, attn_type: str | None = None, ) -> type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" @@ -97,7 +63,12 @@ def get_attn_backend( f"Valid values are: {valid_cache_dtypes}" ) - return _cached_get_attn_backend( + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() + backend_enum = vllm_config.attention_config.backend + + attn_selector_config = AttentionSelectorConfig( head_size=head_size, dtype=dtype, kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype), @@ -105,87 +76,27 @@ def get_attn_backend( use_mla=use_mla, has_sink=has_sink, use_sparse=use_sparse, - attn_type=attn_type, + use_mm_prefix=use_mm_prefix, + attn_type=attn_type or AttentionType.DECODER, + ) + + return _cached_get_attn_backend( + backend=backend_enum, + attn_selector_config=attn_selector_config, ) @cache def _cached_get_attn_backend( - head_size: int, - dtype: torch.dtype, - kv_cache_dtype: CacheDType | None, - block_size: int | None, - use_mla: bool = False, - has_sink: bool = False, - use_sparse: bool = False, - attn_type: str | None = None, + backend, + attn_selector_config: AttentionSelectorConfig, ) -> type[AttentionBackend]: - # Check whether a particular choice of backend was - # previously forced. - # - # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND - # ENVIRONMENT VARIABLE. - selected_backend = None - backend_by_global_setting: AttentionBackendEnum | None = ( - get_global_forced_attn_backend() - ) - if backend_by_global_setting is not None: - selected_backend = backend_by_global_setting - else: - # Check the environment variable and override if specified - backend_by_env_var: str | None = envs.VLLM_ATTENTION_BACKEND - if backend_by_env_var is not None: - if backend_by_env_var.endswith("_VLLM_V1"): - logger.warning( - "The suffix '_VLLM_V1' in the environment variable " - "VLLM_ATTENTION_BACKEND is no longer necessary as " - "V0 backends have been deprecated. " - "Please remove this suffix from your " - "environment variable setting.", - ) - backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1") - try: - selected_backend = AttentionBackendEnum[backend_by_env_var] - except KeyError as e: - raise ValueError( - f"Invalid attention backend: '{backend_by_env_var}'. Valid " - f"backends are: {list(AttentionBackendEnum.__members__.keys())}" - ) from e - - # get device-specific attn_backend from vllm.platforms import current_platform - sig = inspect.signature(current_platform.get_attn_backend_cls) - if "use_v1" in sig.parameters: - logger.warning_once( - "use_v1 parameter for get_attn_backend_cls is deprecated and will " - "be removed in v0.13.0 or v1.0.0, whichever is soonest. Please " - "remove it from your plugin code." - ) - attention_cls = current_platform.get_attn_backend_cls( - selected_backend, - head_size, - dtype, - kv_cache_dtype, - block_size, - True, # use_v1 - use_mla, - has_sink, - use_sparse, - attn_type, - ) - else: - attention_cls = current_platform.get_attn_backend_cls( - selected_backend, - head_size, - dtype, - kv_cache_dtype, - block_size, - use_mla, - has_sink, - use_sparse, - attn_type, - ) + attention_cls = current_platform.get_attn_backend_cls( + backend, + attn_selector_config=attn_selector_config, + ) if not attention_cls: raise ValueError( f"Invalid attention backend for {current_platform.device_name}" @@ -232,37 +143,3 @@ def _cached_get_mamba_attn_backend( mamba_attn_backend = selected_backend.get_class() return mamba_attn_backend - - -@contextmanager -def global_force_attn_backend_context_manager( - attn_backend: AttentionBackendEnum, -) -> Generator[None, None, None]: - """ - Globally force a vLLM attention backend override within a - context manager, reverting the global attention backend - override to its prior state upon exiting the context - manager. - - Arguments: - - * attn_backend: attention backend to force - - Returns: - - * Generator - """ - - # Save the current state of the global backend override (if any) - original_value = get_global_forced_attn_backend() - - # Globally force the new backend override - global_force_attn_backend(attn_backend) - - # Yield control back to the enclosed code block - try: - yield - finally: - # Revert the original global backend override, if any - global_force_attn_backend(original_value) - _cached_get_attn_backend.cache_clear() diff --git a/vllm/attention/utils/fa_utils.py b/vllm/attention/utils/fa_utils.py index 8a46587473e43..e38c88f4838d1 100644 --- a/vllm/attention/utils/fa_utils.py +++ b/vllm/attention/utils/fa_utils.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm import envs from vllm.logger import init_logger from vllm.platforms import current_platform @@ -49,10 +48,12 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None: 3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2 ) - # 2. override if passed by environment - if envs.VLLM_FLASH_ATTN_VERSION is not None: - assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3] - fa_version = envs.VLLM_FLASH_ATTN_VERSION + # 2. override if passed by environment or config + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() + if vllm_config.attention_config.flash_attn_version is not None: + fa_version = vllm_config.attention_config.flash_attn_version # 3. fallback for unsupported combinations if device_capability.major == 10 and fa_version == 3: diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index ec9b0fd6e969c..49ee0faf049d1 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -32,7 +32,6 @@ from typing import Any, cast import numpy as np from PIL import Image -from transformers import PreTrainedTokenizerBase from typing_extensions import deprecated from vllm.lora.request import LoRARequest @@ -189,7 +188,7 @@ class BenchmarkDataset(ABC): @abstractmethod def sample( self, - tokenizer: PreTrainedTokenizerBase, + tokenizer: TokenizerLike, num_requests: int, request_id_prefix: str = "", no_oversample: bool = False, @@ -201,7 +200,7 @@ class BenchmarkDataset(ABC): for generating a list of SampleRequest objects. Args: - tokenizer (PreTrainedTokenizerBase): The tokenizer to be used + tokenizer (TokenizerLike): The tokenizer to be used for processing the dataset's text. num_requests (int): The number of sample requests to generate. request_id_prefix (str): The prefix of request_id. @@ -380,7 +379,7 @@ def process_video(video: Any) -> Mapping[str, Any]: def gen_prompt_decode_to_target_len( - tokenizer: PreTrainedTokenizerBase, + tokenizer: TokenizerLike, token_sequence: list[int], target_token_len: int, max_retry: int = 10, @@ -468,7 +467,7 @@ class RandomDataset(BenchmarkDataset): def sample( self, - tokenizer: PreTrainedTokenizerBase, + tokenizer: TokenizerLike, num_requests: int, request_id_prefix: str = "", no_oversample: bool = False, @@ -580,7 +579,7 @@ class RandomDataset(BenchmarkDataset): range_ratio: float, input_len: int, output_len: int, - tokenizer: PreTrainedTokenizerBase, + tokenizer: TokenizerLike, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """ Get the sampling parameters for the dataset. @@ -626,7 +625,7 @@ class RandomDataset(BenchmarkDataset): def generate_token_sequence( self, *, - tokenizer: PreTrainedTokenizerBase, + tokenizer: TokenizerLike, prefix_token_ids: list[int], prefix_len: int, vocab_size: int, @@ -686,7 +685,7 @@ class RandomDatasetForReranking(RandomDataset): def sample( self, - tokenizer: PreTrainedTokenizerBase, + tokenizer: TokenizerLike, num_requests: int, request_id_prefix: str = "", range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO, @@ -716,7 +715,11 @@ class RandomDatasetForReranking(RandomDataset): doc_lens, _, doc_offsets = self.get_sampling_params( num_requests, range_ratio, doc_len_param, 0, tokenizer ) + vocab_size = tokenizer.vocab_size + prohibited_tokens = tokenizer.all_special_ids + all_tokens = np.arange(vocab_size) + allowed_tokens = np.array(list(set(all_tokens) - set(prohibited_tokens))) query_prompt, query_input_len, token_mismatch_total = ( self.generate_token_sequence( @@ -727,6 +730,7 @@ class RandomDatasetForReranking(RandomDataset): input_len=query_len, offset=int(query_offsets[0]), index=0, + allowed_tokens=allowed_tokens, ) ) @@ -740,6 +744,7 @@ class RandomDatasetForReranking(RandomDataset): input_len=int(doc_lens[i]), offset=int(doc_offsets[i]), index=i + 1, + allowed_tokens=allowed_tokens, ) token_mismatch_total += token_mismatch requests.append((prompt, total_input_len)) @@ -1077,7 +1082,7 @@ class RandomMultiModalDataset(RandomDataset): def sample( self, - tokenizer: PreTrainedTokenizerBase, + tokenizer: TokenizerLike, num_requests: int, request_id_prefix: str = "", no_oversample: bool = False, @@ -1231,7 +1236,7 @@ class ShareGPTDataset(BenchmarkDataset): def sample( self, - tokenizer: PreTrainedTokenizerBase, + tokenizer: TokenizerLike, num_requests: int, lora_path: str | None = None, max_loras: int | None = None, @@ -1633,7 +1638,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser): ) -def get_samples(args, tokenizer) -> list[SampleRequest]: +def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]: if not hasattr(args, "request_id_prefix"): args.request_id_prefix = "" @@ -1842,6 +1847,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: random_seed=args.seed, dataset_path=args.dataset_path, disable_shuffle=args.disable_shuffle, + prefix_len=args.common_prefix_len, ).sample( tokenizer=tokenizer, num_requests=args.num_prompts, @@ -1970,7 +1976,7 @@ class CustomDataset(BenchmarkDataset): def sample( self, - tokenizer: PreTrainedTokenizerBase, + tokenizer: TokenizerLike, num_requests: int, lora_path: str | None = None, max_loras: int | None = None, @@ -2100,7 +2106,7 @@ class SonnetDataset(BenchmarkDataset): def sample( self, - tokenizer, + tokenizer: TokenizerLike, num_requests: int, prefix_len: int = DEFAULT_PREFIX_LEN, input_len: int = DEFAULT_INPUT_LEN, @@ -2201,7 +2207,7 @@ class BurstGPTDataset(BenchmarkDataset): def sample( self, - tokenizer: PreTrainedTokenizerBase, + tokenizer: TokenizerLike, num_requests: int, max_loras: int | None = None, lora_path: str | None = None, @@ -2286,7 +2292,7 @@ class ConversationDataset(HuggingFaceDataset): def sample( self, - tokenizer: PreTrainedTokenizerBase, + tokenizer: TokenizerLike, num_requests: int, output_len: int | None = None, enable_multimodal_chat: bool = False, @@ -2346,7 +2352,7 @@ class MultiModalConversationDataset(HuggingFaceDataset): def sample( self, - tokenizer: PreTrainedTokenizerBase, + tokenizer: TokenizerLike, num_requests: int, output_len: int | None = None, enable_multimodal_chat: bool = False, @@ -2415,7 +2421,7 @@ class VisionArenaDataset(HuggingFaceDataset): def sample( self, - tokenizer: PreTrainedTokenizerBase, + tokenizer: TokenizerLike, num_requests: int, output_len: int | None = None, enable_multimodal_chat: bool = False, @@ -2469,7 +2475,7 @@ class MMVUDataset(HuggingFaceDataset): def sample( self, - tokenizer: PreTrainedTokenizerBase, + tokenizer: TokenizerLike, num_requests: int, output_len: int | None = None, enable_multimodal_chat: bool = False, @@ -2530,7 +2536,7 @@ class InstructCoderDataset(HuggingFaceDataset): def sample( self, - tokenizer: PreTrainedTokenizerBase, + tokenizer: TokenizerLike, num_requests: int, output_len: int | None = None, enable_multimodal_chat: bool = False, @@ -2594,7 +2600,7 @@ class MTBenchDataset(HuggingFaceDataset): def sample( self, - tokenizer: PreTrainedTokenizerBase, + tokenizer: TokenizerLike, num_requests: int, output_len: int | None = None, enable_multimodal_chat: bool = False, @@ -2660,7 +2666,7 @@ class BlazeditDataset(HuggingFaceDataset): def sample( self, - tokenizer: PreTrainedTokenizerBase, + tokenizer: TokenizerLike, num_requests: int, output_len: int | None = None, skip_chat_template: bool = False, @@ -2741,7 +2747,7 @@ class AIMODataset(HuggingFaceDataset): def sample( self, - tokenizer: PreTrainedTokenizerBase, + tokenizer: TokenizerLike, num_requests: int, output_len: int | None = None, request_id_prefix: str = "", @@ -2851,7 +2857,7 @@ class NextEditPredictionDataset(HuggingFaceDataset): def sample( self, - tokenizer: PreTrainedTokenizerBase, + tokenizer: TokenizerLike, num_requests: int, request_id_prefix: str = "", no_oversample: bool = False, @@ -2923,7 +2929,7 @@ class ASRDataset(HuggingFaceDataset): def sample( self, - tokenizer: PreTrainedTokenizerBase, + tokenizer: TokenizerLike, num_requests: int, output_len: int | None = None, request_id_prefix: str = "", @@ -3001,7 +3007,7 @@ class MLPerfDataset(HuggingFaceDataset): def sample( self, - tokenizer: PreTrainedTokenizerBase, + tokenizer: TokenizerLike, num_requests: int, output_len: int | None = None, request_id_prefix: str = "", @@ -3080,7 +3086,7 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset): def sample( self, - tokenizer: PreTrainedTokenizerBase, + tokenizer: TokenizerLike, num_requests: int, prefix_len: int = DEFAULT_PREFIX_LEN, suffix_len: int = DEFAULT_SUFFIX_LEN, @@ -3166,7 +3172,7 @@ class MMStarDataset(HuggingFaceDataset): def sample( self, - tokenizer: PreTrainedTokenizerBase, + tokenizer: TokenizerLike, num_requests: int, output_len: int | None = None, enable_multimodal_chat: bool = False, diff --git a/vllm/benchmarks/latency.py b/vllm/benchmarks/latency.py index b4f1751837f48..99c1c846f19af 100644 --- a/vllm/benchmarks/latency.py +++ b/vllm/benchmarks/latency.py @@ -12,7 +12,6 @@ from typing import Any import numpy as np from tqdm import tqdm -import vllm.envs as envs from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json from vllm.engine.arg_utils import EngineArgs from vllm.inputs import PromptType @@ -79,12 +78,11 @@ def add_cli_args(parser: argparse.ArgumentParser): def main(args: argparse.Namespace): - if args.profile and not envs.VLLM_TORCH_PROFILER_DIR: - raise OSError( - "The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. " - "Please set it to a valid path to use torch profiler." - ) engine_args = EngineArgs.from_cli_args(args) + if args.profile and not engine_args.profiler_config.profiler == "torch": + raise ValueError( + "The torch profiler is not enabled. Please provide profiler_config." + ) # Lazy import to avoid importing LLM when the bench command is not selected. from vllm import LLM, SamplingParams @@ -144,7 +142,7 @@ def main(args: argparse.Namespace): run_to_completion(profile_dir=None) if args.profile: - profile_dir = envs.VLLM_TORCH_PROFILER_DIR + profile_dir = engine_args.profiler_config.torch_profiler_dir print(f"Profiling (results will be saved to '{profile_dir}')...") run_to_completion(profile_dir=profile_dir) return diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index 519303c0bfa0a..f5d8ea5a975a9 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -36,7 +36,6 @@ from typing import Any, Literal import aiohttp import numpy as np from tqdm.asyncio import tqdm -from transformers import PreTrainedTokenizerBase from vllm.benchmarks.datasets import SampleRequest, add_dataset_parser, get_samples from vllm.benchmarks.lib.endpoint_request_func import ( @@ -47,7 +46,7 @@ from vllm.benchmarks.lib.endpoint_request_func import ( ) from vllm.benchmarks.lib.ready_checker import wait_for_endpoint from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.tokenizers import TokenizerLike, get_tokenizer from vllm.utils.gc_utils import freeze_gc_heap from vllm.utils.network_utils import join_host_port @@ -236,7 +235,9 @@ async def get_request( def calculate_metrics_for_embeddings( - outputs: list[RequestFuncOutput], dur_s: float, selected_percentiles: list[float] + outputs: list[RequestFuncOutput], + dur_s: float, + selected_percentiles: list[float], ) -> EmbedBenchmarkMetrics: """Calculate the metrics for the embedding requests. @@ -286,7 +287,7 @@ def calculate_metrics( input_requests: list[SampleRequest], outputs: list[RequestFuncOutput], dur_s: float, - tokenizer: PreTrainedTokenizerBase, + tokenizer: TokenizerLike, selected_percentiles: list[float], goodput_config_dict: dict[str, float], ) -> tuple[BenchmarkMetrics, list[int]]: @@ -489,7 +490,7 @@ async def benchmark( base_url: str, model_id: str, model_name: str, - tokenizer: PreTrainedTokenizerBase, + tokenizer: TokenizerLike, input_requests: list[SampleRequest], logprobs: int | None, request_rate: float, @@ -789,7 +790,7 @@ async def benchmark( ) print( "{:<40} {:<10.2f}".format( - "Total Token throughput (tok/s):", metrics.total_token_throughput + "Total token throughput (tok/s):", metrics.total_token_throughput ) ) @@ -1032,6 +1033,19 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 ) + parser.add_argument( + "--tokenizer-mode", + type=str, + default="auto", + help="""Tokenizer mode:\n + - "auto" will use the tokenizer from `mistral_common` for Mistral models + if available, otherwise it will use the "hf" tokenizer.\n + - "hf" will use the fast tokenizer if available.\n + - "slow" will always use the slow tokenizer.\n + - "mistral" will always use the tokenizer from `mistral_common`.\n + - "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n + - Other custom values can be supported via plugins.""", + ) parser.add_argument("--use-beam-search", action="store_true") parser.add_argument( "--logprobs", @@ -1085,8 +1099,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--profile", action="store_true", - help="Use Torch Profiler. The endpoint must be launched with " - "VLLM_TORCH_PROFILER_DIR to enable profiler.", + help="Use vLLM Profiling. --profiler-config must be provided on the server.", ) parser.add_argument( "--save-result", @@ -1221,17 +1234,11 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Repetition penalty sampling parameter. Only has effect on " "openai-compatible backends.", ) - - parser.add_argument( - "--tokenizer-mode", - type=str, - default="auto", - choices=["auto", "slow", "mistral", "custom"], - help='The tokenizer mode.\n\n* "auto" will use the ' - 'fast tokenizer if available.\n* "slow" will ' - "always use the slow tokenizer. \n* " - '"mistral" will always use the `mistral_common` tokenizer. \n*' - '"custom" will use --tokenizer to select the preregistered tokenizer.', + sampling_group.add_argument( + "--common-prefix-len", + type=int, + default=None, + help="Common prefix length shared by all prompts (used by random dataset)", ) parser.add_argument( diff --git a/vllm/benchmarks/startup.py b/vllm/benchmarks/startup.py new file mode 100644 index 0000000000000..086f7bf62f838 --- /dev/null +++ b/vllm/benchmarks/startup.py @@ -0,0 +1,326 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Benchmark the cold and warm startup time of vLLM models. + +This script measures total startup time (including model loading, compilation, +and cache operations) for both cold and warm scenarios: +- Cold startup: Fresh start with no caches (temporary cache directories) +- Warm startup: Using cached compilation and model info +""" + +import argparse +import dataclasses +import json +import multiprocessing +import os +import shutil +import tempfile +import time +from contextlib import contextmanager +from typing import Any + +import numpy as np +from tqdm import tqdm + +from vllm.benchmarks.lib.utils import ( + convert_to_pytorch_benchmark_format, + write_to_json, +) +from vllm.engine.arg_utils import EngineArgs + + +@contextmanager +def cold_startup(): + """ + Context manager to measure cold startup time: + 1. Uses a temporary directory for vLLM cache to avoid any pollution + between cold startup iterations. + 2. Uses inductor's fresh_cache to clear torch.compile caches. + """ + from torch._inductor.utils import fresh_cache + + # Use temporary directory for caching to avoid any pollution between cold startups + original_cache_root = os.environ.get("VLLM_CACHE_ROOT") + temp_cache_dir = tempfile.mkdtemp(prefix="vllm_startup_bench_cold_") + try: + os.environ["VLLM_CACHE_ROOT"] = temp_cache_dir + with fresh_cache(): + yield + finally: + # Clean up temporary cache directory + shutil.rmtree(temp_cache_dir, ignore_errors=True) + if original_cache_root: + os.environ["VLLM_CACHE_ROOT"] = original_cache_root + else: + os.environ.pop("VLLM_CACHE_ROOT", None) + + +def run_startup_in_subprocess(engine_args_dict, result_queue): + """ + Run LLM startup in a subprocess and return timing metrics via a queue. + This ensures complete isolation between iterations. + """ + try: + # Import inside the subprocess to avoid issues with forking + from vllm import LLM + from vllm.engine.arg_utils import EngineArgs + + engine_args = EngineArgs(**engine_args_dict) + + # Measure total startup time + start_time = time.perf_counter() + + llm = LLM(**dataclasses.asdict(engine_args)) + + total_startup_time = time.perf_counter() - start_time + + # Extract compilation time if available + compilation_time = 0.0 + if hasattr(llm.llm_engine, "vllm_config"): + vllm_config = llm.llm_engine.vllm_config + if ( + hasattr(vllm_config, "compilation_config") + and vllm_config.compilation_config is not None + ): + compilation_time = vllm_config.compilation_config.compilation_time + + result_queue.put( + { + "total_startup_time": total_startup_time, + "compilation_time": compilation_time, + } + ) + + except Exception as e: + result_queue.put(None) + result_queue.put(str(e)) + + +def save_to_pytorch_benchmark_format( + args: argparse.Namespace, results: dict[str, Any] +) -> None: + base_name = os.path.splitext(args.output_json)[0] + + cold_startup_records = convert_to_pytorch_benchmark_format( + args=args, + metrics={ + "avg_cold_startup_time": results["avg_cold_startup_time"], + }, + extra_info={ + "cold_startup_times": results["cold_startup_times"], + "cold_startup_percentiles": results["cold_startup_percentiles"], + }, + ) + if cold_startup_records: + write_to_json(f"{base_name}.cold_startup.pytorch.json", cold_startup_records) + + cold_compilation_records = convert_to_pytorch_benchmark_format( + args=args, + metrics={ + "avg_cold_compilation_time": results["avg_cold_compilation_time"], + }, + extra_info={ + "cold_compilation_times": results["cold_compilation_times"], + "cold_compilation_percentiles": results["cold_compilation_percentiles"], + }, + ) + if cold_compilation_records: + write_to_json( + f"{base_name}.cold_compilation.pytorch.json", cold_compilation_records + ) + + warm_startup_records = convert_to_pytorch_benchmark_format( + args=args, + metrics={ + "avg_warm_startup_time": results["avg_warm_startup_time"], + }, + extra_info={ + "warm_startup_times": results["warm_startup_times"], + "warm_startup_percentiles": results["warm_startup_percentiles"], + }, + ) + if warm_startup_records: + write_to_json(f"{base_name}.warm_startup.pytorch.json", warm_startup_records) + + warm_compilation_records = convert_to_pytorch_benchmark_format( + args=args, + metrics={ + "avg_warm_compilation_time": results["avg_warm_compilation_time"], + }, + extra_info={ + "warm_compilation_times": results["warm_compilation_times"], + "warm_compilation_percentiles": results["warm_compilation_percentiles"], + }, + ) + if warm_compilation_records: + write_to_json( + f"{base_name}.warm_compilation.pytorch.json", warm_compilation_records + ) + + +def add_cli_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-iters-cold", + type=int, + default=5, + help="Number of cold startup iterations.", + ) + parser.add_argument( + "--num-iters-warmup", + type=int, + default=3, + help="Number of warmup iterations before benchmarking warm startups.", + ) + parser.add_argument( + "--num-iters-warm", + type=int, + default=5, + help="Number of warm startup iterations.", + ) + parser.add_argument( + "--output-json", + type=str, + default=None, + help="Path to save the startup time results in JSON format.", + ) + + parser = EngineArgs.add_cli_args(parser) + return parser + + +def main(args: argparse.Namespace): + # Set multiprocessing start method to 'spawn' for clean process isolation + # This ensures each subprocess starts fresh without inheriting state + multiprocessing.set_start_method("spawn", force=True) + + engine_args = EngineArgs.from_cli_args(args) + + def create_llm_and_measure_startup(): + """ + Create LLM instance in a subprocess and measure startup time. + Returns timing metrics, using subprocess for complete isolation. + """ + # Convert engine_args to dictionary for pickling + engine_args_dict = dataclasses.asdict(engine_args) + + # Create a queue for inter-process communication + result_queue = multiprocessing.Queue() + process = multiprocessing.Process( + target=run_startup_in_subprocess, + args=( + engine_args_dict, + result_queue, + ), + ) + process.start() + process.join() + + if not result_queue.empty(): + result = result_queue.get() + if result is None: + if not result_queue.empty(): + error_msg = result_queue.get() + raise RuntimeError(f"Subprocess failed: {error_msg}") + else: + raise RuntimeError("Subprocess failed with unknown error") + return result + else: + raise RuntimeError("Subprocess did not return a result") + + os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" + print("Setting VLLM_ENABLE_V1_MULTIPROCESSING=0 to collect startup metrics.\n") + + print("Measuring cold startup time...\n") + cold_startup_times = [] + cold_compilation_times = [] + for i in tqdm(range(args.num_iters_cold), desc="Cold startup iterations"): + with cold_startup(): + metrics = create_llm_and_measure_startup() + cold_startup_times.append(metrics["total_startup_time"]) + cold_compilation_times.append(metrics["compilation_time"]) + + # Warmup for warm startup + print("\nWarming up for warm startup measurement...\n") + for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"): + create_llm_and_measure_startup() + + print("\nMeasuring warm startup time...\n") + warm_startup_times = [] + warm_compilation_times = [] + for i in tqdm(range(args.num_iters_warm), desc="Warm startup iterations"): + metrics = create_llm_and_measure_startup() + warm_startup_times.append(metrics["total_startup_time"]) + warm_compilation_times.append(metrics["compilation_time"]) + + # Calculate statistics + cold_startup_array = np.array(cold_startup_times) + cold_compilation_array = np.array(cold_compilation_times) + warm_startup_array = np.array(warm_startup_times) + warm_compilation_array = np.array(warm_compilation_times) + + avg_cold_startup = np.mean(cold_startup_array) + avg_cold_compilation = np.mean(cold_compilation_array) + avg_warm_startup = np.mean(warm_startup_array) + avg_warm_compilation = np.mean(warm_compilation_array) + + percentages = [10, 25, 50, 75, 90, 99] + cold_startup_percentiles = np.percentile(cold_startup_array, percentages) + cold_compilation_percentiles = np.percentile(cold_compilation_array, percentages) + warm_startup_percentiles = np.percentile(warm_startup_array, percentages) + warm_compilation_percentiles = np.percentile(warm_compilation_array, percentages) + + print("\n" + "=" * 60) + print("STARTUP TIME BENCHMARK RESULTS") + print("=" * 60) + + # Cold startup statistics + print("\nCOLD STARTUP:") + print(f"Avg total startup time: {avg_cold_startup:.2f} seconds") + print(f"Avg compilation time: {avg_cold_compilation:.2f} seconds") + print("Startup time percentiles:") + for percentage, percentile in zip(percentages, cold_startup_percentiles): + print(f" {percentage}%: {percentile:.2f} seconds") + print("Compilation time percentiles:") + for percentage, percentile in zip(percentages, cold_compilation_percentiles): + print(f" {percentage}%: {percentile:.2f} seconds") + + # Warm startup statistics + print("\nWARM STARTUP:") + print(f"Avg total startup time: {avg_warm_startup:.2f} seconds") + print(f"Avg compilation time: {avg_warm_compilation:.2f} seconds") + print("Startup time percentiles:") + for percentage, percentile in zip(percentages, warm_startup_percentiles): + print(f" {percentage}%: {percentile:.2f} seconds") + print("Compilation time percentiles:") + for percentage, percentile in zip(percentages, warm_compilation_percentiles): + print(f" {percentage}%: {percentile:.2f} seconds") + + print("=" * 60) + + # Output JSON results if specified + if args.output_json: + results = { + "avg_cold_startup_time": float(avg_cold_startup), + "avg_cold_compilation_time": float(avg_cold_compilation), + "cold_startup_times": cold_startup_times, + "cold_compilation_times": cold_compilation_times, + "cold_startup_percentiles": dict( + zip(percentages, cold_startup_percentiles.tolist()) + ), + "cold_compilation_percentiles": dict( + zip(percentages, cold_compilation_percentiles.tolist()) + ), + "avg_warm_startup_time": float(avg_warm_startup), + "avg_warm_compilation_time": float(avg_warm_compilation), + "warm_startup_times": warm_startup_times, + "warm_compilation_times": warm_compilation_times, + "warm_startup_percentiles": dict( + zip(percentages, warm_startup_percentiles.tolist()) + ), + "warm_compilation_percentiles": dict( + zip(percentages, warm_compilation_percentiles.tolist()) + ), + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + save_to_pytorch_benchmark_format(args, results) diff --git a/vllm/benchmarks/sweep/param_sweep.py b/vllm/benchmarks/sweep/param_sweep.py index 986561ed8502a..a438a328880fd 100644 --- a/vllm/benchmarks/sweep/param_sweep.py +++ b/vllm/benchmarks/sweep/param_sweep.py @@ -9,8 +9,26 @@ class ParameterSweep(list["ParameterSweepItem"]): @classmethod def read_json(cls, filepath: os.PathLike): with open(filepath, "rb") as f: - records = json.load(f) + data = json.load(f) + # Support both list and dict formats + if isinstance(data, dict): + return cls.read_from_dict(data) + + return cls.from_records(data) + + @classmethod + def read_from_dict(cls, data: dict[str, dict[str, object]]): + """ + Read parameter sweep from a dict format where keys are names. + + Example: + { + "experiment1": {"max_tokens": 100, "temperature": 0.7}, + "experiment2": {"max_tokens": 200, "temperature": 0.9} + } + """ + records = [{"_benchmark_name": name, **params} for name, params in data.items()] return cls.from_records(records) @classmethod @@ -21,6 +39,15 @@ class ParameterSweep(list["ParameterSweepItem"]): f"but found type: {type(records)}" ) + # Validate that all _benchmark_name values are unique if provided + names = [r["_benchmark_name"] for r in records if "_benchmark_name" in r] + if names and len(names) != len(set(names)): + duplicates = [name for name in names if names.count(name) > 1] + raise ValueError( + f"Duplicate _benchmark_name values found: {set(duplicates)}. " + f"All _benchmark_name values must be unique." + ) + return cls(ParameterSweepItem.from_record(record) for record in records) @@ -38,6 +65,18 @@ class ParameterSweepItem(dict[str, object]): def __or__(self, other: dict[str, Any]): return type(self)(super().__or__(other)) + @property + def name(self) -> str: + """ + Get the name for this parameter sweep item. + + Returns the '_benchmark_name' field if present, otherwise returns a text + representation of all parameters. + """ + if "_benchmark_name" in self: + return self["_benchmark_name"] + return self.as_text(sep="-") + # In JSON, we prefer "_" def _iter_param_key_candidates(self, param_key: str): # Inner config arguments are not converted by the CLI @@ -63,29 +102,57 @@ class ParameterSweepItem(dict[str, object]): def has_param(self, param_key: str) -> bool: return any(k in self for k in self._iter_param_key_candidates(param_key)) + def _normalize_cmd_kv_pair(self, k: str, v: object) -> list[str]: + """ + Normalize a key-value pair into command-line arguments. + + Returns a list containing either: + - A single element for boolean flags (e.g., ['--flag'] or ['--flag=true']) + - Two elements for key-value pairs (e.g., ['--key', 'value']) + """ + if isinstance(v, bool): + # For nested params (containing "."), use =true/false syntax + if "." in k: + return [f"{self._normalize_cmd_key(k)}={'true' if v else 'false'}"] + else: + return [self._normalize_cmd_key(k if v else "no-" + k)] + else: + return [self._normalize_cmd_key(k), str(v)] + def apply_to_cmd(self, cmd: list[str]) -> list[str]: cmd = list(cmd) for k, v in self.items(): + # Skip the '_benchmark_name' field, not a parameter + if k == "_benchmark_name": + continue + + # Serialize dict values as JSON + if isinstance(v, dict): + v = json.dumps(v) + for k_candidate in self._iter_cmd_key_candidates(k): try: k_idx = cmd.index(k_candidate) - if isinstance(v, bool): - cmd[k_idx] = self._normalize_cmd_key(k if v else "no-" + k) + # Replace existing parameter + normalized = self._normalize_cmd_kv_pair(k, v) + if len(normalized) == 1: + # Boolean flag + cmd[k_idx] = normalized[0] else: - cmd[k_idx + 1] = str(v) + # Key-value pair + cmd[k_idx] = normalized[0] + cmd[k_idx + 1] = normalized[1] break except ValueError: continue else: - if isinstance(v, bool): - cmd.append(self._normalize_cmd_key(k if v else "no-" + k)) - else: - cmd.extend([self._normalize_cmd_key(k), str(v)]) + # Add new parameter + cmd.extend(self._normalize_cmd_kv_pair(k, v)) return cmd def as_text(self, sep: str = ", ") -> str: - return sep.join(f"{k}={v}" for k, v in self.items()) + return sep.join(f"{k}={v}" for k, v in self.items() if k != "_benchmark_name") diff --git a/vllm/benchmarks/sweep/plot.py b/vllm/benchmarks/sweep/plot.py index 9947d6170d891..163d517931342 100644 --- a/vllm/benchmarks/sweep/plot.py +++ b/vllm/benchmarks/sweep/plot.py @@ -65,6 +65,18 @@ class PlotEqualTo(PlotFilterBase): return df[df[self.var] == target] +@dataclass +class PlotNotEqualTo(PlotFilterBase): + @override + def apply(self, df: "pd.DataFrame") -> "pd.DataFrame": + try: + target = float(self.target) + except ValueError: + target = self.target + + return df[df[self.var] != target] + + @dataclass class PlotLessThan(PlotFilterBase): @override @@ -96,6 +108,7 @@ class PlotGreaterThanOrEqualTo(PlotFilterBase): # NOTE: The ordering is important! Match longer op_keys first PLOT_FILTERS: dict[str, type[PlotFilterBase]] = { "==": PlotEqualTo, + "!=": PlotNotEqualTo, "<=": PlotLessThanOrEqualTo, ">=": PlotGreaterThanOrEqualTo, "<": PlotLessThan, @@ -167,6 +180,27 @@ def _json_load_bytes(path: Path) -> list[dict[str, object]]: return json.load(f) +def _convert_inf_nan_strings(data: list[dict[str, object]]) -> list[dict[str, object]]: + """ + Convert string values "inf", "-inf", and "nan" to their float equivalents. + + This handles the case where JSON serialization represents inf/nan as strings. + """ + converted_data = [] + for record in data: + converted_record = {} + for key, value in record.items(): + if isinstance(value, str): + if value in ["inf", "-inf", "nan"]: + converted_record[key] = float(value) + else: + converted_record[key] = value + else: + converted_record[key] = value + converted_data.append(converted_record) + return converted_data + + def _get_metric(run_data: dict[str, object], metric_key: str): try: return run_data[metric_key] @@ -178,12 +212,15 @@ def _get_group(run_data: dict[str, object], group_keys: list[str]): return tuple((k, str(_get_metric(run_data, k))) for k in group_keys) -def _get_fig_path(fig_dir: Path, group: tuple[tuple[str, str], ...]): +def _get_fig_path(fig_dir: Path, group: tuple[tuple[str, str], ...], fig_name: str): parts = list[str]() + + # Start with figure name (always provided, defaults to "FIGURE") + parts.append(fig_name) + + # Always append group data if present if group: - parts.extend(("FIGURE-", *(f"{k}={v}" for k, v in group))) - else: - parts.append("figure") + parts.extend(f"{k}={v}" for k, v in group) return fig_dir / sanitize_filename("-".join(parts) + ".png") @@ -217,6 +254,10 @@ def _plot_fig( scale_x: str | None, scale_y: str | None, dry_run: bool, + fig_name: str, + error_bars: bool, + fig_height: float, + fig_dpi: int, ): fig_group, fig_data = fig_group_data @@ -230,7 +271,7 @@ def _plot_fig( for _, row_data in row_groups ) - fig_path = _get_fig_path(fig_dir, fig_group) + fig_path = _get_fig_path(fig_dir, fig_group, fig_name) print("[BEGIN FIGURE]") print(f"Group: {dict(fig_group)}") @@ -241,6 +282,8 @@ def _plot_fig( print("[END FIGURE]") return + # Convert string "inf", "-inf", and "nan" to their float equivalents + fig_data = _convert_inf_nan_strings(fig_data) df = pd.DataFrame.from_records(fig_data) if var_x not in df.columns: @@ -275,6 +318,10 @@ def _plot_fig( df = filter_by.apply(df) df = bin_by.apply(df) + # Sort by curve_by columns alphabetically for consistent legend ordering + if curve_by: + df = df.sort_values(by=curve_by) + df["row_group"] = ( pd.concat( [k + "=" + df[k].astype(str) for k in row_by], @@ -293,7 +340,7 @@ def _plot_fig( else "(All)" ) - g = sns.FacetGrid(df, row="row_group", col="col_group") + g = sns.FacetGrid(df, row="row_group", col="col_group", height=fig_height) if row_by and col_by: g.set_titles("{row_name}\n{col_name}") @@ -320,6 +367,7 @@ def _plot_fig( style=style, size=size, markers=True, + errorbar="sd" if error_bars else None, ) g.add_legend(title=hue) @@ -339,11 +387,12 @@ def _plot_fig( y=var_y, hue="curve_group", markers=True, + errorbar="sd" if error_bars else None, ) g.add_legend() - g.savefig(fig_path) + g.savefig(fig_path, dpi=fig_dpi) plt.close(g.figure) print("[END FIGURE]") @@ -364,6 +413,10 @@ def plot( scale_x: str | None, scale_y: str | None, dry_run: bool, + fig_name: str = "FIGURE", + error_bars: bool = True, + fig_height: float = 6.4, + fig_dpi: int = 300, ): all_data = [ run_data @@ -398,6 +451,10 @@ def plot( scale_x=scale_x, scale_y=scale_y, dry_run=dry_run, + fig_name=fig_name, + error_bars=error_bars, + fig_height=fig_height, + fig_dpi=fig_dpi, ), fig_groups, ) @@ -419,6 +476,10 @@ class SweepPlotArgs: scale_x: str | None scale_y: str | None dry_run: bool + fig_name: str = "FIGURE" + error_bars: bool = True + fig_height: float = 6.4 + fig_dpi: int = 300 parser_name: ClassVar[str] = "plot" parser_help: ClassVar[str] = "Plot performance curves from parameter sweep results." @@ -448,6 +509,10 @@ class SweepPlotArgs: scale_x=args.scale_x, scale_y=args.scale_y, dry_run=args.dry_run, + fig_name=args.fig_name, + error_bars=not args.no_error_bars, + fig_height=args.fig_height, + fig_dpi=args.fig_dpi, ) @classmethod @@ -541,6 +606,32 @@ class SweepPlotArgs: "Currently only accepts string values such as 'log' and 'sqrt'. " "See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html", ) + parser.add_argument( + "--fig-name", + type=str, + default="FIGURE", + help="Name prefix for the output figure file. " + "Group data is always appended when present. " + "Default: 'FIGURE'. Example: --fig-name my_performance_plot", + ) + parser.add_argument( + "--no-error-bars", + action="store_true", + help="If set, disables error bars on the plot. " + "By default, error bars are shown.", + ) + parser.add_argument( + "--fig-height", + type=float, + default=6.4, + help="Height of each subplot in inches. Default: 6.4", + ) + parser.add_argument( + "--fig-dpi", + type=int, + default=300, + help="Resolution of the output figure in dots per inch. Default: 300", + ) parser.add_argument( "--dry-run", action="store_true", @@ -566,6 +657,10 @@ def run_main(args: SweepPlotArgs): scale_x=args.scale_x, scale_y=args.scale_y, dry_run=args.dry_run, + fig_name=args.fig_name, + error_bars=args.error_bars, + fig_height=args.fig_height, + fig_dpi=args.fig_dpi, ) diff --git a/vllm/benchmarks/sweep/serve.py b/vllm/benchmarks/sweep/serve.py index 1298e4acbd87d..6626707cf2a52 100644 --- a/vllm/benchmarks/sweep/serve.py +++ b/vllm/benchmarks/sweep/serve.py @@ -138,9 +138,9 @@ def _get_comb_base_path( ): parts = list[str]() if serve_comb: - parts.extend(("SERVE-", serve_comb.as_text(sep="-"))) + parts.extend(("SERVE-", serve_comb.name)) if bench_comb: - parts.extend(("BENCH-", bench_comb.as_text(sep="-"))) + parts.extend(("BENCH-", bench_comb.name)) return output_dir / sanitize_filename("-".join(parts)) @@ -345,8 +345,9 @@ class SweepServeArgs: "--serve-params", type=str, default=None, - help="Path to JSON file containing a list of parameter combinations " - "for the `vllm serve` command. " + help="Path to JSON file containing parameter combinations " + "for the `vllm serve` command. Can be either a list of dicts or a dict " + "where keys are benchmark names. " "If both `serve_params` and `bench_params` are given, " "this script will iterate over their Cartesian product.", ) @@ -354,8 +355,9 @@ class SweepServeArgs: "--bench-params", type=str, default=None, - help="Path to JSON file containing a list of parameter combinations " - "for the `vllm bench serve` command. " + help="Path to JSON file containing parameter combinations " + "for the `vllm bench serve` command. Can be either a list of dicts or " + "a dict where keys are benchmark names. " "If both `serve_params` and `bench_params` are given, " "this script will iterate over their Cartesian product.", ) diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py index 23b5faa1b2c32..37b8952a350b4 100644 --- a/vllm/benchmarks/throughput.py +++ b/vllm/benchmarks/throughput.py @@ -14,7 +14,7 @@ from typing import Any import torch import uvloop from tqdm import tqdm -from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase +from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase from vllm.benchmarks.datasets import ( AIMODataset, @@ -35,6 +35,7 @@ from vllm.inputs import TextPrompt, TokensPrompt from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams +from vllm.tokenizers import TokenizerLike, get_tokenizer from vllm.utils.async_utils import merge_async_iterators @@ -246,12 +247,15 @@ async def run_vllm_async( def run_hf( requests: list[SampleRequest], model: str, - tokenizer: PreTrainedTokenizerBase, + tokenizer: TokenizerLike, n: int, max_batch_size: int, trust_remote_code: bool, disable_detokenize: bool = False, ) -> float: + assert isinstance(tokenizer, PreTrainedTokenizerBase), ( + "the hf backend only supports HF tokenizers" + ) llm = AutoModelForCausalLM.from_pretrained( model, dtype=torch.float16, trust_remote_code=trust_remote_code ) @@ -342,7 +346,10 @@ def get_requests(args, tokenizer): "output_len": args.output_len, } - if args.dataset_path is None or args.dataset_name == "random": + if args.dataset_name == "random" or ( + args.dataset_path is None + and args.dataset_name not in {"prefix_repetition", "random-mm", "random-rerank"} + ): sample_kwargs["range_ratio"] = args.random_range_ratio sample_kwargs["prefix_len"] = args.prefix_len dataset_cls = RandomDataset @@ -651,8 +658,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--profile", action="store_true", default=False, - help="Use Torch Profiler. The env variable " - "VLLM_TORCH_PROFILER_DIR must be set to enable profiler.", + help="Use vLLM Profiling. --profiler-config must be provided on the server.", ) # prefix repetition dataset @@ -692,15 +698,21 @@ def add_cli_args(parser: argparse.ArgumentParser): def main(args: argparse.Namespace): - if args.tokenizer is None: - args.tokenizer = args.model validate_args(args) if args.seed is None: args.seed = 0 random.seed(args.seed) # Sample the requests. - tokenizer = AutoTokenizer.from_pretrained( - args.tokenizer, trust_remote_code=args.trust_remote_code + if ( + args.backend == "hf" or args.backend == "mii" + ) and args.tokenizer_mode == "auto": + # mistral_common tokenizer is only supported on vllm and vllm-chat backends; + # for hf and mii backends, we use hf tokenizer + args.tokenizer_mode = "hf" + tokenizer = get_tokenizer( + args.tokenizer, + tokenizer_mode=args.tokenizer_mode, + trust_remote_code=args.trust_remote_code, ) requests = get_requests(args, tokenizer) is_multi_modal = any(request.multi_modal_data is not None for request in requests) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 1773913d0b6c6..a1eec7d74483f 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -26,7 +26,8 @@ from vllm.compilation.partition_rules import ( should_split, ) from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig -from vllm.config.utils import hash_factors +from vllm.config.compilation import DynamicShapesType +from vllm.config.utils import Range, hash_factors from vllm.logger import init_logger from vllm.logging_utils import lazy from vllm.platforms import current_platform @@ -90,7 +91,7 @@ class CompilerManager: """ def __init__(self, compilation_config: CompilationConfig): - self.cache: dict[tuple[int | None, int, str], Any] = dict() + self.cache: dict[tuple[Range, int, str], Any] = dict() self.is_cache_updated = False self.compilation_config = compilation_config self.compiler = make_compiler(compilation_config) @@ -99,11 +100,11 @@ class CompilerManager: return self.compiler.compute_hash(vllm_config) @contextmanager - def compile_context(self, runtime_shape: int | None = None): + def compile_context(self, compile_range: Range): """Provide compilation context for the duration of compilation to set any torch global properties we want to scope to a single Inductor compilation (e.g. partition rules, pass context).""" - with pass_context(runtime_shape): + with pass_context(compile_range): if self.compilation_config.use_inductor_graph_partition: with inductor_partition_rule_context( self.compilation_config.splitting_ops @@ -140,7 +141,25 @@ class CompilerManager: # we use ast.literal_eval to parse the data # because it is a safe way to parse Python literals. # do not use eval(), it is unsafe. - self.cache = ast.literal_eval(f.read()) + cache = ast.literal_eval(f.read()) + + def check_type(value, ty): + if not isinstance(value, ty): + raise TypeError(f"Expected {ty} but got {type(value)} for {value}") + + def parse_key(key: Any) -> tuple[Range, int, str]: + range_tuple, graph_index, compiler_name = key + check_type(graph_index, int) + check_type(compiler_name, str) + if isinstance(range_tuple, tuple): + start, end = range_tuple + check_type(start, int) + check_type(end, int) + range_tuple = Range(start=start, end=end) + check_type(range_tuple, Range) + return range_tuple, graph_index, compiler_name + + self.cache = {parse_key(key): value for key, value in cache.items()} self.compiler.initialize_cache( cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix @@ -159,29 +178,21 @@ class CompilerManager: graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: Range, ) -> Callable | None: - if (runtime_shape, graph_index, self.compiler.name) not in self.cache: + if (compile_range, graph_index, self.compiler.name) not in self.cache: return None - handle = self.cache[(runtime_shape, graph_index, self.compiler.name)] + handle = self.cache[(compile_range, graph_index, self.compiler.name)] compiled_graph = self.compiler.load( - handle, graph, example_inputs, graph_index, runtime_shape + handle, graph, example_inputs, graph_index, compile_range + ) + logger.debug( + "Directly load the %s-th graph for compile range %sfrom %s via handle %s", + graph_index, + str(compile_range), + self.compiler.name, + handle, ) - if runtime_shape is None: - logger.debug( - "Directly load the %s-th graph for dynamic shape from %s via handle %s", - graph_index, - self.compiler.name, - handle, - ) - else: - logger.debug( - "Directly load the %s-th graph for shape %s from %s via handle %s", - graph_index, - str(runtime_shape), - self.compiler.name, - handle, - ) return compiled_graph def compile( @@ -190,9 +201,9 @@ class CompilerManager: example_inputs, additional_inductor_config, compilation_config: CompilationConfig, + compile_range: Range, graph_index: int = 0, num_graphs: int = 1, - runtime_shape: int | None = None, ) -> Any: if graph_index == 0: # before compiling the first graph, record the start time @@ -204,7 +215,7 @@ class CompilerManager: compiled_graph = None # try to load from the cache - compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape) + compiled_graph = self.load(graph, example_inputs, graph_index, compile_range) if compiled_graph is not None: if graph_index == num_graphs - 1: # after loading the last graph for this shape, record the time. @@ -212,19 +223,12 @@ class CompilerManager: now = time.time() elapsed = now - compilation_start_time compilation_config.compilation_time += elapsed - if runtime_shape is None: - logger.info( - "Directly load the compiled graph(s) for dynamic shape " - "from the cache, took %.3f s", - elapsed, - ) - else: - logger.info( - "Directly load the compiled graph(s) for shape %s " - "from the cache, took %.3f s", - str(runtime_shape), - elapsed, - ) + logger.info( + "Directly load the compiled graph(s) for compile range %s " + "from the cache, took %.3f s", + str(compile_range), + elapsed, + ) return compiled_graph # no compiler cached the graph, or the cache is disabled, @@ -233,14 +237,15 @@ class CompilerManager: # Let compile_fx generate a key for us maybe_key = None else: - maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}" - - with self.compile_context(runtime_shape): + maybe_key = "artifact_compile_range_" + maybe_key += f"{compile_range.start}_{compile_range.end}" + maybe_key += f"_subgraph_{graph_index}" + with self.compile_context(compile_range): compiled_graph, handle = self.compiler.compile( graph, example_inputs, additional_inductor_config, - runtime_shape, + compile_range, maybe_key, ) @@ -248,55 +253,34 @@ class CompilerManager: # store the artifact in the cache if is_compile_cache_enabled(additional_inductor_config) and handle is not None: - self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle + self.cache[(compile_range, graph_index, self.compiler.name)] = handle compilation_counter.num_cache_entries_updated += 1 self.is_cache_updated = True if graph_index == 0: # adds some info logging for the first graph - if runtime_shape is None: - logger.info_once( - "Cache the graph for dynamic shape for later use", scope="local" - ) - else: - logger.info_once( - "Cache the graph of shape %s for later use", - str(runtime_shape), - scope="local", - ) - if runtime_shape is None: - logger.debug( - "Store the %s-th graph for dynamic shape from %s via handle %s", - graph_index, - self.compiler.name, - handle, - ) - else: - logger.debug( - "Store the %s-th graph for shape %s from %s via handle %s", - graph_index, - str(runtime_shape), - self.compiler.name, - handle, + logger.info_once( + "Cache the graph of compile range %s for later use", + str(compile_range), ) + logger.debug( + "Store the %s-th graph for compile range%s from %s via handle %s", + graph_index, + str(compile_range), + self.compiler.name, + handle, + ) # after compiling the last graph, record the end time if graph_index == num_graphs - 1: now = time.time() elapsed = now - compilation_start_time compilation_config.compilation_time += elapsed - if runtime_shape is None: - logger.info_once( - "Compiling a graph for dynamic shape takes %.2f s", - elapsed, - scope="local", - ) - else: - logger.info_once( - "Compiling a graph for shape %s takes %.2f s", - runtime_shape, - elapsed, - scope="local", - ) + logger.info_once( + "Compiling a graph for compile range %s takes %.2f s", + str(compile_range), + elapsed, + scope="local", + ) return compiled_graph @@ -402,6 +386,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): self.extra_traceback = False def run(self, *args): + # maybe instead just assert inputs are fake? fake_args = [ self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args @@ -416,27 +401,17 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): kwargs: dict[str, Any], ) -> Any: assert isinstance(target, str) + output = super().call_module(target, args, kwargs) if target in self.compile_submod_names: index = self.compile_submod_names.index(target) submod = self.fetch_attr(target) + sym_shape_indices = [ i for i, x in enumerate(args) if isinstance(x, torch.SymInt) ] - global compilation_start_time - compiled_graph_for_dynamic_shape = ( - self.vllm_backend.compiler_manager.compile( - submod, - args, - self.vllm_backend.inductor_config, - self.compilation_config, - graph_index=index, - num_graphs=len(self.compile_submod_names), - runtime_shape=None, - ) - ) # Lazy import here to avoid circular import from .piecewise_backend import PiecewiseBackend @@ -446,7 +421,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): index, len(self.compile_submod_names), sym_shape_indices, - compiled_graph_for_dynamic_shape, self.vllm_backend, ) @@ -489,21 +463,27 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # the tag for the part of model being compiled, # e.g. backbone/eagle_head model_tag: str = "backbone" +model_is_encoder: bool = False @contextmanager -def set_model_tag(tag: str): +def set_model_tag(tag: str, is_encoder: bool = False): """Context manager to set the model tag.""" global model_tag + global model_is_encoder assert tag != model_tag, ( f"Model tag {tag} is the same as the current tag {model_tag}." ) old_tag = model_tag + old_is_encoder = model_is_encoder + model_tag = tag + model_is_encoder = is_encoder try: yield finally: model_tag = old_tag + model_is_encoder = old_is_encoder class VllmBackend: @@ -549,6 +529,9 @@ class VllmBackend: # them, e.g. backbone (default), eagle_head, etc. self.prefix = prefix or model_tag + # Mark compilation for encoder. + self.is_encoder = model_is_encoder + # Passes to run on the graph post-grad. self.pass_manager = resolve_obj_by_qualname( current_platform.get_pass_manager_cls() @@ -586,8 +569,13 @@ class VllmBackend: ) else: # Config should automatically wrap all inductor passes - assert isinstance(self.inductor_config[self.pass_key], InductorPass) - self.pass_manager.add(self.inductor_config[self.pass_key]) + assert isinstance( + self.compilation_config.inductor_compile_config[self.pass_key], + InductorPass, + ) + self.pass_manager.add( + self.compilation_config.inductor_compile_config[self.pass_key] + ) self.inductor_config[self.pass_key] = self.pass_manager def __call__( @@ -746,11 +734,44 @@ class VllmBackend: if not item.is_splitting_graph ] + # Extract fake values from the graph to use them when needed. + all_fake_values = [] + for i in graph.graph.find_nodes(op="placeholder"): + all_fake_values.append(i.meta["example_value"]) + + fake_args = [ + all_fake_values[i] if isinstance(t, torch.Tensor) else t + for i, t in enumerate(example_inputs) + ] + # propagate the split graph to the piecewise backend, # compile submodules with symbolic shapes PiecewiseCompileInterpreter( self.split_gm, submod_names_to_compile, self.vllm_config, self - ).run(*example_inputs) + ).run(*fake_args) + + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode() + + if ( + self.compilation_config.dynamic_shapes_config.evaluate_guards + and self.compilation_config.dynamic_shapes_config.type + == DynamicShapesType.BACKED + ): + from torch.utils._sympy.value_ranges import ValueRanges + + # Drop counter-0/1 specializations guards; for backed dynamic shapes, + # torch.compile will specialize for 0/1 inputs or otherwise guards that + # shape is >= 2. This is because it's really hard not to hit a check + # against 0/1. When we evaluate shape guards, we exclude checking those + # guards (We would fail always otherwise). + + # We avoid that by updating the ranges of backed sizes when the min is + # 2 for any, we assume it's 0. + for s, r in fake_mode.shape_env.var_to_range.items(): + if r.lower == 2: + fake_mode.shape_env.var_to_range[s] = ValueRanges(0, r.upper) graph_path = os.path.join(local_cache_dir, "computation_graph.py") if not os.path.exists(graph_path): @@ -779,15 +800,6 @@ class VllmBackend: graph, example_inputs, self.prefix, self.split_gm ) - # if we need to copy input buffers for cudagraph - from torch._guards import detect_fake_mode - - fake_mode = detect_fake_mode() - fake_args = [ - fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t - for t in example_inputs - ] - # index of tensors that have symbolic shapes (batch size) # for weights and static buffers, they will have concrete shapes. # symbolic shape only happens for input tensors. diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 69d4606d73ebd..57bd94c7e8ad6 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -10,6 +10,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from torch.distributed._symmetric_memory import enable_symm_mem_for_group from vllm.config import VllmConfig +from vllm.config.utils import Range from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, @@ -431,7 +432,7 @@ class AsyncTPPass(VllmPatternMatcherPass): self.dump_patterns(config, self.patterns) - def is_applicable(self, shape: int | None) -> bool: + def is_applicable_for_range(self, compile_range: Range) -> bool: # This pass is applied on top of the sequence parallelism pass. # It inherits the same applicability condition as `SequenceParallelismPass`. # See `SequenceParallelismPass.is_applicable` for more details. @@ -441,7 +442,7 @@ class AsyncTPPass(VllmPatternMatcherPass): ): return True tp_size = get_tensor_model_parallel_world_size() - return shape is not None and shape % tp_size == 0 + return compile_range.is_single_size() and compile_range.end % tp_size == 0 @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): @@ -505,91 +506,60 @@ if flashinfer_comm is not None: num_tokens, hidden_size = allreduce_in.shape element_size = allreduce_in.element_size() current_tensor_size = num_tokens * hidden_size * element_size + max_tensor_size = max_token_num * hidden_size * element_size + assert current_tensor_size <= max_tensor_size, ( + f"Current tensor size {current_tensor_size} is larger than " + f"max token num {max_token_num} * hidden size {hidden_size} * " + f"element size {element_size}" + ) + device_capability = current_platform.get_device_capability().to_int() + # Get one shot input size limit for the current world size + # for the current device capability + max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get( + device_capability, {} + ).get(world_size, None) + # Use one shot if no max size is specified + use_oneshot = ( + max_one_shot_size is None or current_tensor_size <= max_one_shot_size * MiB + ) - if num_tokens <= max_token_num: - device_capability = current_platform.get_device_capability().to_int() - # Get one shot input size limit for the current world size - # for the current device capability - max_one_shot_size_mb = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get( - device_capability, {} - ).get(world_size, None) - # Use one shot if no max size for one shot is specified - use_oneshot = ( - max_one_shot_size_mb is None - or current_tensor_size <= max_one_shot_size_mb * MiB - ) - - assert _FI_WORKSPACE_TENSOR is not None, ( - "Flashinfer must be enabled when using flashinfer" - ) - if norm_out is None: - norm_out = allreduce_in - residual_out = residual - else: - # return residual_out as allreduce_out with zeroed residual_in - # as flashinfer does not support rms_norm - # and allreduce_out together - residual_out = allreduce_in - # For the sizes that are smaller than the max size, - # we only use flashinfer one shot allreduce - flashinfer_comm.trtllm_allreduce_fusion( - allreduce_in=allreduce_in, - token_num=allreduce_in.shape[0], - residual_in=residual, - residual_out=residual_out, - norm_out=norm_out, - rms_gamma=rms_gamma, - rms_eps=rms_eps, - world_rank=world_rank, - world_size=world_size, - hidden_dim=allreduce_in.shape[-1], - workspace_ptrs=_FI_WORKSPACE_TENSOR, - launch_with_pdl=launch_with_pdl, - use_oneshot=use_oneshot, - trigger_completion_at_end=trigger_completion_at_end, - fp32_acc=fp32_acc, - pattern_code=pattern_code, - allreduce_out=None, - quant_out=quant_out, - scale_out=scale_out, - # in vllm we only support swizzled layout - layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, - scale_factor=scale_factor, - ) + assert _FI_WORKSPACE_TENSOR is not None, ( + "Flashinfer must be enabled when using flashinfer" + ) + if norm_out is None: + norm_out = allreduce_in + residual_out = residual else: - allreduce_out = tensor_model_parallel_all_reduce(allreduce_in) - if scale_factor is not None and scale_out is None: - # Do fused rms norm static fp8 quant fused op - if norm_out is None: - torch.ops._C.fused_add_rms_norm_static_fp8_quant( - quant_out, - allreduce_out, - residual, - rms_gamma, - scale_factor, - rms_eps, - ) - else: - torch.ops._C.rms_norm_static_fp8_quant( - quant_out, allreduce_out, rms_gamma, scale_factor, rms_eps - ) - else: - if norm_out is None: - torch.ops._C.fused_add_rms_norm( - allreduce_out, residual, rms_gamma, rms_eps - ) - norm_out = allreduce_out - else: - torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, rms_eps) - if scale_factor is not None and scale_out is not None: - torch.ops._C.scaled_fp4_quant( - quant_out, norm_out, scale_out, scale_factor - ) - if scale_factor is None or norm_out is not None: - # we need to return allreduce output - # in cases of non quant fused AR + RMS norm - # and fused AR + RMS norm + quant without fused add - allreduce_in.copy_(allreduce_out) + # return residual_out as allreduce_out with zeroed residual_in + # as flashinfer does not support rms_norm + # and allreduce_out together + residual_out = allreduce_in + # For the sizes that are smaller than the max size, + # we only use flashinfer one shot allreduce + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=allreduce_in, + token_num=allreduce_in.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + world_rank=world_rank, + world_size=world_size, + hidden_dim=allreduce_in.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + launch_with_pdl=launch_with_pdl, + use_oneshot=use_oneshot, + trigger_completion_at_end=trigger_completion_at_end, + fp32_acc=fp32_acc, + pattern_code=pattern_code, + allreduce_out=None, + quant_out=quant_out, + scale_out=scale_out, + # in vllm we only support swizzled layout + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=scale_factor, + ) def call_trtllm_fused_allreduce_norm_fake( allreduce_in: torch.Tensor, @@ -1106,11 +1076,15 @@ class AllReduceFusionPass(VllmPatternMatcherPass): self.disabled = True self.tp_size = get_tensor_model_parallel_world_size() if self.tp_size <= 1: + logger.warning_once("AllReduce fusion pass is disabled for tp_size <= 1.") return self.patterns: PatternMatcherPass = PatternMatcherPass( pass_name="all_reduce_fusion_pass" ) if config.model_config is None: + logger.warning_once( + "AllReduce fusion pass is disabled for missing model_config." + ) return self.hidden_dim = config.model_config.get_hidden_size() self.group = get_tp_group().device_group @@ -1128,7 +1102,8 @@ class AllReduceFusionPass(VllmPatternMatcherPass): if max_size is None: # Flashinfer doesn't support current world size logger.warning( - "Flashinfer allreduce fusion is not supported for world size %s", + "Flashinfer allreduce fusion is not supported for world size %s" + " or max size is not provided", self.tp_size, ) return @@ -1216,6 +1191,12 @@ class AllReduceFusionPass(VllmPatternMatcherPass): self.disabled = False + def is_applicable_for_range(self, compile_range: Range) -> bool: + if self.disabled: + logger.warning_once("AllReduce fusion pass is disabled.") + return False + return compile_range.end <= self.max_token_num + @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): if self.disabled: diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 7deaba1a99fad..ab56d3561c569 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -15,6 +15,7 @@ import torch.fx as fx import vllm.envs as envs from vllm.compilation.counter import compilation_counter from vllm.config import VllmConfig +from vllm.config.utils import Range from vllm.utils.hashing import safe_hash from vllm.utils.torch_utils import is_torch_equal_or_newer @@ -63,16 +64,16 @@ class CompilerInterface: graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: Range, key: str | None = None, ) -> tuple[Callable | None, Any | None]: """ Compile the graph with the given example inputs and compiler config, - with a runtime shape. If the `runtime_shape` is None, it means - the `example_inputs` have a dynamic shape. Otherwise, the - `runtime_shape` specifies the shape of the inputs. Right now we only - support one variable shape for all inputs, which is the batchsize - (number of tokens) during inference. + with a range. The `compile_range` specifies the range of the inputs, + it could be concrete size (if compile_sizes is provided), e.g. [4, 4] + or a range [5, 8]. + Right now we only support one variable in ranges for all inputs, + which is the batchsize (number of tokens) during inference. Dynamo will make sure `graph(*example_inputs)` is valid. @@ -98,7 +99,7 @@ class CompilerInterface: graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: Range, ) -> Callable: """ Load the compiled function from the handle. @@ -212,20 +213,20 @@ class InductorStandaloneAdaptor(CompilerInterface): graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: Range, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 current_config = {} if compiler_config is not None: current_config.update(compiler_config) - set_inductor_config(current_config, runtime_shape) + set_inductor_config(current_config, compile_range) set_functorch_config() - if isinstance(runtime_shape, int): + if compile_range.is_single_size(): dynamic_shapes = "from_example_inputs" else: - dynamic_shapes = "from_tracing_context" + dynamic_shapes = "from_graph" from torch._inductor import standalone_compile @@ -235,7 +236,6 @@ class InductorStandaloneAdaptor(CompilerInterface): dynamic_shapes=dynamic_shapes, options={"config_patches": current_config}, ) - # Save the compiled artifact to disk in the specified path assert key is not None path = os.path.join(self.cache_dir, key) @@ -251,7 +251,7 @@ class InductorStandaloneAdaptor(CompilerInterface): graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: Range, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -315,7 +315,7 @@ class InductorAdaptor(CompilerInterface): graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: Range, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 @@ -329,7 +329,7 @@ class InductorAdaptor(CompilerInterface): current_config["fx_graph_cache"] = True current_config["fx_graph_remote_cache"] = False - set_inductor_config(current_config, runtime_shape) + set_inductor_config(current_config, compile_range) set_functorch_config() # inductor can inplace modify the graph, so we need to copy it @@ -512,7 +512,7 @@ class InductorAdaptor(CompilerInterface): graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: int | None = None, + compile_range: Range, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -608,9 +608,9 @@ class InductorAdaptor(CompilerInterface): return contextlib.nullcontext() -def set_inductor_config(config, runtime_shape): - if isinstance(runtime_shape, int): - # for a specific batchsize, tuning triton kernel parameters +def set_inductor_config(config, compile_range: Range): + if compile_range.is_single_size(): + # for a specific batch size, tuning triton kernel parameters # can be beneficial config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE config["coordinate_descent_tuning"] = ( @@ -630,7 +630,7 @@ class EagerAdaptor(CompilerInterface): graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: int | None = None, + compile_range: Range, key: str | None = None, ) -> tuple[Callable | None, Any | None]: compilation_counter.num_eager_compiles += 1 diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index a2e0abfebc2c9..0748643a5299f 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses +from collections import Counter from collections.abc import Callable from contextlib import ExitStack from typing import Any @@ -22,6 +23,99 @@ from vllm.utils.torch_utils import weak_ref_tensors logger = init_logger(__name__) +@dataclasses.dataclass(frozen=True) +class CUDAGraphStat: + num_unpadded_tokens: int + num_padded_tokens: int + num_paddings: int + runtime_mode: str + + +class CUDAGraphLogging: + """Aggregate and log cudagraph metrics""" + + COLUMN_HEADERS = [ + "Unpadded Tokens", + "Padded Tokens", + "Num Paddings", + "Runtime Mode", + "Count", + ] + + def __init__(self, cg_mode: CUDAGraphMode, cg_capture_sizes: list[int] | None): + self.reset() + self.cg_mode = str(cg_mode) + self.cg_capture_sizes = str(cg_capture_sizes or []) + + self.settings_header = ( + "**CUDAGraph Config Settings:**\n\n" + f"- Mode: {self.cg_mode}\n" + f"- Capture sizes: {self.cg_capture_sizes}\n\n" + "**CUDAGraph Stats:**\n\n" + ) + + def reset(self): + self.stats = [] + + def observe(self, cudagraph_stat: CUDAGraphStat): + self.stats.append(cudagraph_stat) + + def generate_metric_table(self) -> str: + stats_counts = Counter(self.stats) + + # Convert stats to rows of strings, in descending order of observed frequencies + rows = [] + for stat, count in sorted( + stats_counts.items(), key=lambda item: item[1], reverse=True + ): + rows.append( + [ + str(stat.num_unpadded_tokens), + str(stat.num_padded_tokens), + str(stat.num_paddings), + stat.runtime_mode, + str(count), + ] + ) + + # Calculate column widths (max of header and data) + col_widths = [] + for i, header_text in enumerate(self.COLUMN_HEADERS): + max_width = len(header_text) + for row in rows: + max_width = max(max_width, len(row[i])) + col_widths.append(max_width) + + table_header_list = [ + h.ljust(w) for h, w in zip(self.COLUMN_HEADERS, col_widths) + ] + table_header = "| " + " | ".join(table_header_list) + " |\n" + + table_separator = "|" + "|".join("-" * (w + 2) for w in col_widths) + "|\n" + + # Create data rows with proper alignment + data_rows = [] + for row in rows: + formatted_row = [ + str(val).ljust(width) for val, width in zip(row, col_widths) + ] + data_rows.append("| " + " | ".join(formatted_row) + " |") + + return ( + self.settings_header + + table_header + + table_separator + + "\n".join(data_rows) + + "\n" + ) + + def log(self, log_fn=logger.info): + if not self.stats: + return + log_fn(self.generate_metric_table()) + self.reset() + + @dataclasses.dataclass class CUDAGraphEntry: batch_descriptor: BatchDescriptor diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 6d9da1c488c6d..d1ee995ee8959 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -28,7 +28,7 @@ from vllm.config.compilation import DynamicShapesType from vllm.logger import init_logger from vllm.sequence import IntermediateTensors from vllm.utils.import_utils import resolve_obj_by_qualname -from vllm.utils.torch_utils import supports_dynamo +from vllm.utils.torch_utils import is_torch_equal_or_newer, supports_dynamo from .monitor import start_monitoring_torch_compile @@ -316,7 +316,13 @@ def _support_torch_compile( def _mark_dynamic_inputs(mod, type, *args, **kwargs): def mark_dynamic(arg, dims): if type == DynamicShapesType.UNBACKED: - torch._dynamo.decorators.mark_unbacked(arg, dims) + if is_torch_equal_or_newer("2.10.0.dev"): + for dim in dims: + torch._dynamo.decorators.mark_unbacked( + arg, dim, hint_override=arg.size()[dim] + ) + else: + torch._dynamo.decorators.mark_unbacked(arg, dims) else: torch._dynamo.mark_dynamic(arg, dims) @@ -350,7 +356,13 @@ def _support_torch_compile( if isinstance(arg, torch.Tensor): # In case dims is specified with negative indexing dims = [arg.ndim + dim if dim < 0 else dim for dim in dims] - torch._dynamo.decorators.mark_unbacked(arg, dims) + if is_torch_equal_or_newer("2.10.0.dev"): + for dim in dims: + torch._dynamo.decorators.mark_unbacked( + arg, dim, hint_override=arg.size()[dim] + ) + else: + torch._dynamo.decorators.mark_unbacked(arg, dims) def __call__(self, *args, **kwargs): # torch.compiler.is_compiling() means we are inside the compilation @@ -378,21 +390,12 @@ def _support_torch_compile( serialized backend artifacts), then we need to generate a new AOT compile artifact from scratch. """ - # Validate that AOT compile is not used with unbacked dynamic - # shapes. aot_compile re-allocates backed symbols post dynamo! - if ds_type == DynamicShapesType.UNBACKED: - raise ValueError( - "AOT compilation is not compatible with UNBACKED dynamic shapes. " - "Please use BACKED or BACKED_SIZE_OBLIVIOUS dynamic shapes type " - "when VLLM_USE_AOT_COMPILE is enabled." - ) from .caching import compilation_config_hash_factors factors: list[str] = compilation_config_hash_factors(self.vllm_config) factors.append(_model_hash_key(self.forward)) hash_key = hashlib.sha256(str(factors).encode()).hexdigest() - cache_dir = os.path.join( envs.VLLM_CACHE_ROOT, "torch_aot_compile", @@ -409,9 +412,12 @@ def _support_torch_compile( open(aot_compilation_path, "rb") as f, ): start_monitoring_torch_compile(self.vllm_config) - loaded_fn = torch.compiler.load_compiled_function(f) + loaded_fn = torch.compiler.load_compiled_function( + f, f_globals=self.forward.__globals__ + ) _verify_source_unchanged(loaded_fn.source_info(), self.vllm_config) - loaded_fn.disable_guard_check() + if not self.compilation_config.dynamic_shapes_config.evaluate_guards: + loaded_fn.disable_guard_check() self.aot_compiled_fn = loaded_fn except Exception as e: if os.path.exists(aot_compilation_path): @@ -433,7 +439,6 @@ def _support_torch_compile( return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # This is the path for the first compilation. - # the first compilation needs to have dynamic shapes marked _mark_dynamic_inputs( self, @@ -487,6 +492,12 @@ def _support_torch_compile( if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS: fx_config_patches["backed_size_oblivious"] = True + # Prepare inductor config patches + # assume_32bit_indexing is only available in torch 2.10.0.dev+ + inductor_config_patches = {} + if is_torch_equal_or_newer("2.10.0.dev"): + inductor_config_patches["assume_32bit_indexing"] = True + with ( patch.object( InliningInstructionTranslator, "inline_call_", patched_inline_call @@ -495,6 +506,7 @@ def _support_torch_compile( maybe_use_cudagraph_partition_wrapper(self.vllm_config), torch.fx.experimental._config.patch(**fx_config_patches), _torch27_patch_tensor_subclasses(), + torch._inductor.config.patch(**inductor_config_patches), ): if envs.VLLM_USE_AOT_COMPILE: self.aot_compiled_fn = self.aot_compile(*args, **kwargs) diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 126ad35e527ae..2625562aadd36 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -103,6 +103,19 @@ class FixFunctionalizationPass(VllmInductorPass): ]: mutated_args = {1: "result"} self.defunctionalize(graph, node, mutated_args) + elif ( + hasattr(torch.ops.vllm, "flashinfer_trtllm_fused_allreduce_norm") + and at_target + == torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default + ): + mutated_args = { + 1: "allreduce_in", + 2: "residual", + 3: "norm_out", + 4: "quant_out", + 5: "scale_out", + } + self.defunctionalize(graph, node, mutated_args) # For some reason we need to specify the args for both # silu_and_mul and silu_and_mul_quant. The kwargs # pathway gets the wrong answer. diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 1d6e297b495eb..d121106334cb9 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -15,6 +15,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, QuantKey, ScaleDesc, + kFp8Dynamic64Sym, + kFp8Dynamic128Sym, kFp8DynamicTensorSym, kFp8DynamicTokenSym, kFp8StaticTensorSym, @@ -24,7 +26,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.platforms import current_platform from .inductor_pass import enable_fake_mode -from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm +from .matcher_utils import ( + MatcherFusedAddRMSNorm, + MatcherQuantFP8, + MatcherRMSNorm, +) from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) @@ -58,6 +64,9 @@ QUANT_OPS: dict[QuantKey, OpOverload] = { } if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default +if current_platform.is_cuda(): + QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 + QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 class FusedRMSQuantKey(NamedTuple): @@ -90,11 +99,29 @@ FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = { FusedRMSQuantKey( kFp8DynamicTokenSym, True ): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8Dynamic128Sym, False + ): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8Dynamic128Sym, True + ): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8Dynamic64Sym, False + ): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8Dynamic64Sym, True + ): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501 } class RMSNormQuantPattern: - def __init__(self, epsilon: float, key: FusedRMSQuantKey): + def __init__( + self, + epsilon: float, + key: FusedRMSQuantKey, + has_col_major_scales: bool = False, + is_e8m0: bool = False, + ): self.epsilon = epsilon self.quant_dtype = key.quant.dtype config = get_current_vllm_config() @@ -108,7 +135,9 @@ class RMSNormQuantPattern: if not key.fused_add else MatcherFusedAddRMSNorm(epsilon) ) - self.quant_matcher = MatcherQuantFP8(key.quant) + self.quant_matcher = MatcherQuantFP8( + key.quant, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0 + ) class RMSNormStaticQuantPattern(RMSNormQuantPattern): @@ -218,6 +247,128 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): ) +class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern): + def __init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape, + symmetric=True, + has_col_major_scales: bool = False, + is_e8m0: bool = False, + ): + scale = ScaleDesc(torch.float32, False, group_shape) + key = FusedRMSQuantKey( + fused_add=True, + quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), + ) + self.group_shape = group_shape + self.has_col_major_scales = has_col_major_scales + self.is_e8m0 = is_e8m0 + super().__init__( + epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0 + ) + + def register(self, pm_pass: PatternMatcherPass): + def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor): + result_rms, residual = self.rmsnorm_matcher(input, weight, residual) + result, scale = self.quant_matcher(result_rms) + return result, residual, scale + + def replacement( + input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor + ): + # In case we're matching native rms-norm, conversions might be + # optimized out. We convert here just to be safe. + input = input.to(dtype=self.model_dtype) + + result = torch.empty_like(input, dtype=self.quant_dtype) + scale = self.quant_matcher.make_scale(input, self.has_col_major_scales) + at = auto_functionalized( + self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=residual, + group_size=self.group_shape[1], + is_scale_transposed=self.has_col_major_scales, + ) + + # result, residual, scale + return at[1], at[3], at[2] + + pm.register_replacement( + pattern, + replacement, + self.rmsnorm_matcher.inputs(), + pm.fwd_only, + pm_pass, + ) + + +class RMSNormGroupQuantPattern(RMSNormQuantPattern): + def __init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape, + symmetric=True, + has_col_major_scales: bool = False, + is_e8m0: bool = False, + ): + scale = ScaleDesc(torch.float32, False, group_shape) + key = FusedRMSQuantKey( + fused_add=False, + quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), + ) + self.group_shape = group_shape + super().__init__( + epsilon, key, has_col_major_scales=has_col_major_scales, is_e8m0=is_e8m0 + ) + + def register(self, pm_pass: PatternMatcherPass): + def pattern(input: torch.Tensor, weight: torch.Tensor): + result_rms = self.rmsnorm_matcher(input, weight) + result, scale = self.quant_matcher(result_rms) + return result, scale + + def replacement(input: torch.Tensor, weight: torch.Tensor): + # In case we're matching native rms-norm, conversions might be + # optimized out. We convert here just to be safe. + input = input.to(dtype=self.model_dtype) + + result = torch.empty_like(input, dtype=self.quant_dtype) + scale = self.quant_matcher.make_scale( + input, transposed=self.quant_matcher.has_col_major_scales + ) + at = auto_functionalized( + self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=None, + group_size=self.group_shape[1], + is_scale_transposed=self.quant_matcher.has_col_major_scales, + ) + + # result, scale + return at[1], at[2] + + pm.register_replacement( + pattern, + replacement, + self.rmsnorm_matcher.inputs(), + pm.fwd_only, + pm_pass, + ) + + class RMSNormDynamicQuantPattern(RMSNormQuantPattern): def __init__( self, @@ -356,6 +507,29 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass): # Fuse rms_norm + dynamic per-token fp8 quant RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) + # Only register group quant patterns on CUDA where the C++ op exists + if current_platform.is_cuda(): + for group_shape in [GroupShape(1, 128), GroupShape(1, 64)]: + for has_col_major_scales in [True, False]: + for is_e8m0 in [True, False]: + # Fuse fused_add_rms_norm + fp8 group quant + FusedAddRMSNormGroupQuantPattern( + epsilon, + FP8_DTYPE, + group_shape=group_shape, + has_col_major_scales=has_col_major_scales, + is_e8m0=is_e8m0, + ).register(self.patterns) + + # Fuse rms_norm + fp8 group quant + RMSNormGroupQuantPattern( + epsilon, + FP8_DTYPE, + group_shape=group_shape, + has_col_major_scales=has_col_major_scales, + is_e8m0=is_e8m0, + ).register(self.patterns) + self.dump_patterns(config, self.patterns) @VllmInductorPass.time_and_log @@ -366,9 +540,11 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass): def uuid(self) -> Any: return self.hash_source( self, + RMSNormGroupQuantPattern, RMSNormQuantPattern, RMSNormStaticQuantPattern, RMSNormDynamicQuantPattern, FusedAddRMSNormStaticQuantPattern, FusedAddRMSNormDynamicQuantPattern, + FusedAddRMSNormGroupQuantPattern, ) diff --git a/vllm/compilation/fx_utils.py b/vllm/compilation/fx_utils.py index f2497950fc22f..3650ee6b41745 100644 --- a/vllm/compilation/fx_utils.py +++ b/vllm/compilation/fx_utils.py @@ -75,8 +75,8 @@ def find_op_nodes( return assert isinstance(op, OpOverload) - if not op._schema.is_mutable: - yield from graph.find_nodes(op="call_function", target=op) + + yield from graph.find_nodes(op="call_function", target=op) for n in graph.find_nodes(op="call_function", target=auto_functionalized): if n.args[0] == op: diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 9af635a929b4b..dbf154eeb86a4 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + import functools import hashlib import inspect @@ -8,7 +10,7 @@ import json import types from collections.abc import Callable from contextlib import contextmanager -from typing import Any +from typing import TYPE_CHECKING, Any import torch from torch import fx @@ -16,6 +18,9 @@ from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily from vllm.utils.torch_utils import is_torch_equal_or_newer +if TYPE_CHECKING: + from vllm.config.utils import Range + if is_torch_equal_or_newer("2.6"): from torch._inductor.custom_graph_pass import CustomGraphPass else: @@ -28,8 +33,8 @@ _pass_context = None class PassContext: - def __init__(self, runtime_shape: int | None): - self.runtime_shape = runtime_shape + def __init__(self, compile_range: Range): + self.compile_range: Range = compile_range def get_pass_context() -> PassContext: @@ -39,13 +44,13 @@ def get_pass_context() -> PassContext: @contextmanager -def pass_context(runtime_shape: int | None): +def pass_context(compile_range: Range): """A context manager that stores the current pass context, usually it is a list of sizes to specialize. """ global _pass_context prev_context = _pass_context - _pass_context = PassContext(runtime_shape) + _pass_context = PassContext(compile_range) try: yield finally: @@ -96,7 +101,7 @@ class InductorPass(CustomGraphPass): encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).hexdigest() - def is_applicable(self, shape: int | None): + def is_applicable_for_range(self, compile_range: Range): return True diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index e4cd063d2aee1..ec9ed34f561b4 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -13,6 +13,8 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, _normalize_quant_group_shape, + kFp8Dynamic64Sym, + kFp8Dynamic128Sym, kFp8DynamicTensorSym, kFp8DynamicTokenSym, kFp8StaticTensorSym, @@ -35,6 +37,10 @@ QUANT_OPS: dict[QuantKey, OpOverload] = { if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 +if current_platform.is_cuda(): + QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 + QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 + SILU_MUL_OP = torch.ops._C.silu_and_mul.default @@ -224,7 +230,13 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp): class MatcherQuantFP8(MatcherCustomOp): - def __init__(self, quant_key: QuantKey, enabled: bool | None = None): + def __init__( + self, + quant_key: QuantKey, + enabled: bool | None = None, + has_col_major_scales: bool = False, + is_e8m0: bool = False, + ): if enabled is None: enabled = QuantFP8.enabled() @@ -233,11 +245,19 @@ class MatcherQuantFP8(MatcherCustomOp): assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}" self.QUANT_OP = QUANT_OPS[quant_key] + self.has_col_major_scales = has_col_major_scales + self.is_e8m0 = is_e8m0 + assert quant_key.dtype == current_platform.fp8_dtype(), ( "Only QuantFP8 supported by" ) assert quant_key.scale2 is None - self.quant_fp8 = QuantFP8(quant_key.scale.static, quant_key.scale.group_shape) + self.quant_fp8 = QuantFP8( + quant_key.scale.static, + quant_key.scale.group_shape, + column_major_scales=has_col_major_scales, + use_ue8m0=is_e8m0, + ) def forward_custom( self, @@ -248,6 +268,27 @@ class MatcherQuantFP8(MatcherCustomOp): input.shape, device=input.device, dtype=self.quant_key.dtype ) + if self.quant_key.scale.group_shape.is_per_group(): + assert scale is None + scale = self.make_scale(input, transposed=self.has_col_major_scales) + + finfo = torch.finfo(self.quant_key.dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + _, result, scale = auto_functionalized( + self.QUANT_OP, + input=input, + output_q=result, + output_s=scale, + group_size=self.quant_key.scale.group_shape[1], + eps=1e-10, + fp8_min=fp8_min, + fp8_max=fp8_max, + scale_ue8m0=self.is_e8m0, + ) + return result, scale + if self.quant_key.scale.static: assert scale is not None _, result = auto_functionalized( @@ -269,7 +310,7 @@ class MatcherQuantFP8(MatcherCustomOp): ) -> tuple[torch.Tensor, torch.Tensor]: return self.quant_fp8(input, scale) - def make_scale(self, input: torch.Tensor): + def make_scale(self, input: torch.Tensor, transposed: bool = False): normalized_group_shape = _normalize_quant_group_shape( input, self.quant_key.scale.group_shape ) @@ -277,6 +318,11 @@ class MatcherQuantFP8(MatcherCustomOp): input.shape[0] // normalized_group_shape[0], input.shape[1] // normalized_group_shape[1], ) + if transposed: + scale_shape = tuple(reversed(scale_shape)) + return torch.empty( + scale_shape, device=input.device, dtype=torch.float32 + ).permute(-1, -2) return torch.empty(scale_shape, device=input.device, dtype=torch.float32) diff --git a/vllm/compilation/noop_elimination.py b/vllm/compilation/noop_elimination.py index 42b8d3daac985..06e1771bac960 100644 --- a/vllm/compilation/noop_elimination.py +++ b/vllm/compilation/noop_elimination.py @@ -5,6 +5,7 @@ from collections.abc import Iterable import torch.fx from torch import SymInt +from torch.fx.experimental.symbolic_shapes import statically_known_true from vllm.logger import init_logger @@ -116,12 +117,7 @@ class NoOpEliminationPass(VllmInductorPass): 2. The dimensions both correspond to the same SymInt """ # Case 1 - if isinstance(i_dim, int) and isinstance(dim, int): - return dim == i_dim - # Case 2 - if isinstance(i_dim, SymInt) and isinstance(dim, SymInt): - return dim == i_dim - return False + return statically_known_true(dim == i_dim) def all_dims_equivalent( self, dims: Iterable[int | SymInt], i_dims: Iterable[int | SymInt] diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index fe2547d7fecaf..4ebb386f75ed8 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -5,6 +5,7 @@ import functools from torch import fx as fx from vllm import envs +from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig, set_current_vllm_config from vllm.logger import init_logger from vllm.platforms import current_platform @@ -13,6 +14,12 @@ from vllm.utils.system_utils import set_env_var from .post_cleanup import PostCleanupPass from .vllm_inductor_pass import VllmInductorPass +if rocm_aiter_ops.is_enabled(): + from vllm.compilation.rocm_aiter_fusion import ( + RocmAiterRMSNormFp8GroupQuantFusionPass, + RocmAiterSiluMulFp8GroupQuantFusionPass, + ) + if current_platform.is_cuda_alike(): from .activation_quant_fusion import ActivationQuantFusionPass from .fusion import RMSNormQuantFusionPass @@ -24,7 +31,11 @@ if current_platform.is_cuda(): from .collective_fusion import AllReduceFusionPass, AsyncTPPass from .fix_functionalization import FixFunctionalizationPass -from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context +from .inductor_pass import ( + CustomGraphPass, + InductorPass, + get_pass_context, +) from .noop_elimination import NoOpEliminationPass logger = init_logger(__name__) @@ -70,13 +81,13 @@ class PostGradPassManager(CustomGraphPass): def __call__(self, graph: fx.Graph): VllmInductorPass.dump_prefix = 0 # reset dump index - shape = get_pass_context().runtime_shape + compile_range = get_pass_context().compile_range for pass_ in self.passes: - if pass_.is_applicable(shape): + if pass_.is_applicable_for_range(compile_range): pass_(graph) VllmInductorPass.dump_prefix += 1 else: - logger.debug("Skipping %s with shape %s", pass_, shape) + logger.debug("Skipping %s with compile range %s", pass_, compile_range) # post-cleanup goes before fix_functionalization # because it requires a functional graph @@ -92,22 +103,27 @@ class PostGradPassManager(CustomGraphPass): # Set the current vllm config to allow tracing CustomOp instances with set_current_vllm_config(config, check_compile=False): - if self.pass_config.enable_noop: + if self.pass_config.eliminate_noops: self.passes += [NoOpEliminationPass(config)] - if self.pass_config.enable_sequence_parallelism: + if self.pass_config.enable_sp: self.passes += [SequenceParallelismPass(config)] - if self.pass_config.enable_async_tp: + if self.pass_config.fuse_gemm_comms: self.passes += [AsyncTPPass(config)] - if self.pass_config.enable_fi_allreduce_fusion: + if self.pass_config.fuse_allreduce_rms: self.passes += [AllReduceFusionPass(config)] - if self.pass_config.enable_fusion: + if self.pass_config.fuse_norm_quant: self.passes += [RMSNormQuantFusionPass(config)] + if rocm_aiter_ops.is_enabled(): + self.passes += [RocmAiterRMSNormFp8GroupQuantFusionPass(config)] + if self.pass_config.fuse_act_quant: self.passes += [ActivationQuantFusionPass(config)] + if rocm_aiter_ops.is_enabled(): + self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)] - if self.pass_config.enable_attn_fusion: + if self.pass_config.fuse_attn_quant: self.passes += [AttnFusionPass(config)] if self.pass_config.enable_qk_norm_rope_fusion: @@ -132,4 +148,8 @@ class PostGradPassManager(CustomGraphPass): state["passes"].append(pass_.uuid()) state["passes"].append(self.fix_functionalization.uuid()) + # Include the compile range in the uuid to ensure that inductor + # recompiles the graph for the new dynamic compile range. + state["compile_range"] = str(get_pass_context().compile_range) + return InductorPass.hash_dict(state) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index e535d2c461c6e..58d3e2a14b22a 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -7,18 +7,18 @@ from typing import Any import torch.fx as fx -import vllm.envs as envs from vllm.compilation.backends import VllmBackend from vllm.compilation.monitor import end_monitoring_torch_compile from vllm.config import VllmConfig +from vllm.config.compilation import Range from vllm.logger import init_logger logger = init_logger(__name__) @dataclasses.dataclass -class ConcreteSizeEntry: - runtime_shape: int +class RangeEntry: + compile_range: Range compiled: bool = False runnable: Callable = None # type: ignore @@ -31,7 +31,6 @@ class PiecewiseBackend: piecewise_compile_index: int, total_piecewise_compiles: int, sym_shape_indices: list[int], - compiled_graph_for_general_shape: Callable, vllm_backend: VllmBackend, ): """ @@ -54,68 +53,126 @@ class PiecewiseBackend: self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1 self.is_full_graph = total_piecewise_compiles == 1 + self.is_encoder_compilation = vllm_backend.is_encoder - self.compile_sizes: set[int] = set(self.compilation_config.compile_sizes) + self.compile_ranges = self.compilation_config.get_compile_ranges() + if self.is_encoder_compilation: + # For encoder compilation we use the max int32 value + # to set the upper bound of the compile ranges + max_int32 = 2**31 - 1 + last_compile_range = self.compile_ranges[-1] + assert ( + last_compile_range.end + == vllm_config.scheduler_config.max_num_batched_tokens + ) + self.compile_ranges[-1] = Range( + start=last_compile_range.start, end=max_int32 + ) - self.first_run_finished = False + log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}" + logger.debug_once(log_string) - self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa + self.compile_sizes = self.compilation_config.compile_sizes + log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}" + logger.debug_once(log_string) self.sym_shape_indices = sym_shape_indices - self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" + # the entries for ranges that we need to either + self.range_entries: dict[Range, RangeEntry] = {} - # the entries for different shapes that we need to compile - self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {} - - # to_be_compiled_sizes tracks the remaining sizes to compile, + # to_be_compiled_ranges tracks the remaining ranges to compile, # and updates during the compilation process, so we need to copy it - self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy() + self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges) # We only keep compilation management inside this class directly. - for shape in self.compile_sizes: - self.concrete_size_entries[shape] = ConcreteSizeEntry( - runtime_shape=shape, - runnable=self.compiled_graph_for_general_shape, + for size in self.compile_sizes: + range = Range(start=size, end=size) + if range not in self.compile_ranges: + self.range_entries[range] = RangeEntry( + compile_range=range, + ) + self.to_be_compiled_ranges.add(range) + + for range in self.compile_ranges: + self.range_entries[range] = RangeEntry( + compile_range=range, ) def check_for_ending_compilation(self): - if self.is_last_graph and not self.to_be_compiled_sizes: + if self.is_last_graph and not self.to_be_compiled_ranges: # no specific sizes to compile # save the hash of the inductor graph for the next run self.vllm_backend.compiler_manager.save_to_file() end_monitoring_torch_compile(self.vllm_config) - def __call__(self, *args) -> Any: - if not self.first_run_finished: - self.first_run_finished = True - self.check_for_ending_compilation() - return self.compiled_graph_for_general_shape(*args) + def _fakify_args(self, args: list[Any]) -> list[Any]: + # We need to pass fake example_inputs, otherwise torch.compile + # will fakify the example_inputs potentially causing some non dynamic + # dimension to be be duck shaped to other existing shapes that have hints + # matching their values. + # This is problem because it can lead to unintended specializations! + # if the new wrongly dynamic dim is specialized + # it will force specializing the whole shape + # torch.compile probably should not accept + # non fake tensors as example inputs! + # See issue https://github.com/vllm-project/vllm/issues/27899 + fake_example_inputs = [] + for node in self.graph.graph.nodes: + # All place holders come first + if node.op == "placeholder": + fake_example_inputs.append(node.meta["example_value"]) + else: + break + assert len(fake_example_inputs) == len(args) + return fake_example_inputs - runtime_shape = args[self.sym_shape_indices[0]] + def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: + if not range_entry.compiled: + range_entry.compiled = True + self.to_be_compiled_ranges.remove(range_entry.compile_range) - if runtime_shape not in self.concrete_size_entries: - # we don't need to do anything for this shape - return self.compiled_graph_for_general_shape(*args) - - entry = self.concrete_size_entries[runtime_shape] - - if not entry.compiled: - entry.compiled = True - self.to_be_compiled_sizes.remove(runtime_shape) # args are real arguments - entry.runnable = self.vllm_backend.compiler_manager.compile( + # fakify for range, real args for concrete size. + # For concrete size, we clear the shape env in + # compiler_manager.compile() so no need to fakify. + args = ( + self._fakify_args(args) + if not range_entry.compile_range.is_single_size() + else args + ) + range_entry.runnable = self.vllm_backend.compiler_manager.compile( self.graph, args, self.vllm_backend.inductor_config, self.compilation_config, + compile_range=range_entry.compile_range, graph_index=self.piecewise_compile_index, num_graphs=self.total_piecewise_compiles, - runtime_shape=runtime_shape, ) - # finished compilations for all required shapes - if self.is_last_graph and not self.to_be_compiled_sizes: - self.check_for_ending_compilation() + self.check_for_ending_compilation() - return entry.runnable(*args) + def _find_range_for_shape(self, runtime_shape: int) -> Range | None: + # First we try to find the range entry for the concrete compile size + # If not found, we search for the range entry + # that contains the runtime shape. + if runtime_shape in self.compile_sizes: + return self.range_entries[Range(start=runtime_shape, end=runtime_shape)] + else: + for range in self.compile_ranges: + if runtime_shape in range: + return self.range_entries[range] + return None + + def __call__(self, *args) -> Any: + runtime_shape = args[self.sym_shape_indices[0]] + range_entry = self._find_range_for_shape(runtime_shape) + + assert range_entry is not None, ( + f"Shape out of considered range: {runtime_shape} " + "[1, max_num_batched_tokens]" + ) + + self._maybe_compile_for_range_entry(range_entry, args) + return range_entry.runnable(*args) diff --git a/vllm/compilation/rocm_aiter_fusion.py b/vllm/compilation/rocm_aiter_fusion.py new file mode 100644 index 0000000000000..8b5db9de38181 --- /dev/null +++ b/vllm/compilation/rocm_aiter_fusion.py @@ -0,0 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + +import torch +import torch._inductor.pattern_matcher as pm +from torch import fx +from torch._inductor.pattern_matcher import PatternMatcherPass +from torch._ops import OpOverload + +import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401 +from vllm.compilation.activation_quant_fusion import ActivationQuantPattern +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.platforms import current_platform + +from .fusion import empty_bf16 +from .inductor_pass import enable_fake_mode +from .matcher_utils import MatcherSiluAndMul +from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass + +logger = init_logger(__name__) +FP8_DTYPE = current_platform.fp8_dtype() + +AITER_RMS_GROUP_QUANT_OP = torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default +AITER_RMS_ADD_GROUP_QUANT_OP = ( + torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default +) + +AITER_RMS_OP = torch.ops.vllm.rocm_aiter_rms_norm.default +AITER_RMS_ADD_OP = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default + +AITER_GROUP_FP8_QUANT_OP = torch.ops.vllm.rocm_aiter_group_fp8_quant.default +TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default + +FUSED_SILU_MUL_QUANT_OP = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default + + +class AiterRMSFp8GroupQuantPattern: + """ + This pattern fuses aiter rms_norm & group fp8 quant custom + ops into an aiter rms_norm_group_fp8_quant op. + """ + + def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload): + self.epsilon = epsilon + self.quant_dtype = quant_dtype + self.quant_op = quant_op + + def register(self, pm_pass: PatternMatcherPass): + def pattern( + input: torch.Tensor, + weight: torch.Tensor, + ): + at1 = AITER_RMS_OP(x=input, weight=weight, variance_epsilon=self.epsilon) + + at2 = self.quant_op(at1, 128) + + return at2[0], at2[1] + + def replacement( + input: torch.Tensor, + weight: torch.Tensor, + ): + at = AITER_RMS_GROUP_QUANT_OP( + x=input, + weight=weight, + variance_epsilon=self.epsilon, + group_size=128, + ) + + return at[0], at[1] + + inputs = [ + empty_bf16(5, 4), # input + empty_bf16(1, 5), # weight + ] + + pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) + + +class AiterFusedAddRMSFp8GroupQuantPattern: + """ + This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops + into a aiter rms_norm_with_add_group_fp8_quant op. + """ + + def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload): + self.epsilon = epsilon + self.quant_dtype = quant_dtype + self.quant_op = quant_op + + def register(self, pm_pass: PatternMatcherPass): + def pattern( + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + ): + at1 = AITER_RMS_ADD_OP( + x=input, + residual=residual, + weight=weight, + variance_epsilon=self.epsilon, + ) + + at2 = self.quant_op(at1[0], 128) + + # result, scale, residual + return at2[0], at2[1], at1[1] + + def replacement( + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + ): + at = AITER_RMS_ADD_GROUP_QUANT_OP( + x=input, + residual=residual, + weight=weight, + variance_epsilon=self.epsilon, + group_size=128, + ) + + # result, scale, residual + return at[0], at[1], at[2] + + inputs = [ + empty_bf16(5, 4), # input + empty_bf16(5, 4), # residual + empty_bf16(1, 5), # weight + ] + + pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) + + +class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass): + """ + This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op. + It also supports fused_add_rms_norm. + """ + + @enable_fake_mode + def __init__(self, config: VllmConfig): + super().__init__(config) + + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="rocm_aiter_rms_norm_fp8_group_quant_fusion_pass" + ) + + # Make sure fused add patterns are before simple rms norm, + # as the latter is a subset of the former in torch ops + for epsilon in [1e-5, 1e-6]: + # Fuse rms_norm + dynamic group fp8 quant + for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]: + AiterRMSFp8GroupQuantPattern(epsilon, FP8_DTYPE, quant_op).register( + self.patterns + ) + + AiterFusedAddRMSFp8GroupQuantPattern( + epsilon, FP8_DTYPE, quant_op + ).register(self.patterns) + + self.dump_patterns(config, self.patterns) + + @VllmInductorPass.time_and_log + def __call__(self, graph: fx.Graph): + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) + + def uuid(self) -> Any: + fusion_patterns = [ + AiterRMSFp8GroupQuantPattern, + AiterFusedAddRMSFp8GroupQuantPattern, + ] + return self.hash_source(self, *fusion_patterns) + + +class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern): + """ + This pattern fuses aiter silu_and_mul & group fp8 quant custom + ops into an aiter silu_and_mul_group_fp8_quant op. + """ + + def __init__(self, quant_op: OpOverload): + self.silu_and_mul_matcher = MatcherSiluAndMul() + self.quant_op = quant_op + + def register(self, pm_pass: PatternMatcherPass): + def pattern( + input: torch.Tensor, + ): + at1 = self.silu_and_mul_matcher(input) + at2 = self.quant_op(at1, 128) + return at2[0], at2[1] + + def replacement( + input: torch.Tensor, + ): + at = FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128) + return at[0], at[1] + + inputs = [ + self.silu_and_mul_matcher.inputs()[0], + ] + + pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) + + +class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass): + """ + This pass fuses a pre-defined set of custom ops into fused ops. + It uses the torch pattern matcher to find the patterns and replace them. + + Because patterns can only be registered once, the pass is a singleton. + This will be addressed in a future version of PyTorch: + https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980 + """ + + @enable_fake_mode + def __init__(self, config: VllmConfig): + super().__init__(config) + + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass" + ) + + for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]: + AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns) + + self.dump_patterns(config, self.patterns) + + @VllmInductorPass.time_and_log + def __call__(self, graph: torch.fx.Graph): + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) + + def uuid(self): + fusion_patterns = [ + ActivationQuantPattern, + AiterSiluMulFp8GroupQuantPattern, + ] + return VllmInductorPass.hash_source(self, *fusion_patterns) diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index cf4b8118f6b5c..a4046356bcda0 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -9,6 +9,7 @@ import torch.fx as fx from torch._inductor.pattern_matcher import PatternMatcherPass from vllm.config import VllmConfig +from vllm.config.compilation import Range from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -333,7 +334,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass): self.dump_patterns(config, self.patterns) - def is_applicable(self, shape: int | None) -> bool: + def is_applicable_for_range(self, compile_range: Range) -> bool: # When sequence parallelism is enabled, the residual tensor from RMSNorm # needs to be split along the sequence dimension. However, this dimension # is symbolic during piecewise compilation, and splitting symbolic shapes @@ -353,7 +354,7 @@ class SequenceParallelismPass(VllmPatternMatcherPass): ): return True tp_size = get_tensor_model_parallel_world_size() - return shape is not None and shape % tp_size == 0 + return (compile_range.is_single_size()) and (compile_range.end % tp_size == 0) @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph): diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index b120c85bf232e..02e974b0f9e8c 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -4,7 +4,7 @@ import os import sys from abc import abstractmethod -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from types import CodeType from typing import Any @@ -13,7 +13,9 @@ import torch._C._dynamo.guards import vllm.envs as envs from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config +from vllm.config.compilation import DynamicShapesType from vllm.logger import init_logger +from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context logger = init_logger(__name__) @@ -92,12 +94,29 @@ class TorchCompileWithNoGuardsWrapper: return self.forward(*args, **kwargs) + def _call_with_optional_nvtx_range(self, callable_fn, *args, **kwargs): + if self.layerwise_nvtx_tracing_enabled: + args_list = list(args) + kwargs_dict = dict(kwargs) + with layerwise_nvtx_marker_context( + "Torch Compiled Module (input):{}".format(self.__class__.__name__), + self, + in_tensor=args_list, + kwargs=kwargs_dict, + ) as ctx: + ctx.result = callable_fn(*args, **kwargs) + return ctx.result + return callable_fn(*args, **kwargs) + def __init__(self): self.compiled = False vllm_config = get_current_vllm_config() self.vllm_config = vllm_config mode = vllm_config.compilation_config.mode + self.layerwise_nvtx_tracing_enabled = ( + vllm_config.observability_config.enable_layerwise_nvtx_tracing + ) if mode is None: raise RuntimeError("Compilation mode cannot be NO_COMPILATION") @@ -107,41 +126,69 @@ class TorchCompileWithNoGuardsWrapper: if isinstance(backend, str) and backend == "inductor": options = vllm_config.compilation_config.inductor_compile_config - if mode != CompilationMode.STOCK_TORCH_COMPILE: - # Drop all the guards. - options["guard_filter_fn"] = lambda x: [False for _ in x] - - # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False - from vllm.compilation.decorators import DynamicShapesType + self.first_compile = True + self.evaluate_guards = ( + vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards + ) ds_type = vllm_config.compilation_config.dynamic_shapes_config.type - compiled_ptr: Any = self.forward - if ds_type == DynamicShapesType.UNBACKED: - if envs.VLLM_USE_BYTECODE_HOOK: - # reason is that bytecode does this hack torch._dynamo.eval_frame. - # remove_from_cache(self.original_code_object()) to force a new - # re-compilation. - raise ValueError( - "UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. " + + if mode != CompilationMode.STOCK_TORCH_COMPILE: + # Drop all the guards. + if self.evaluate_guards: + assert not envs.VLLM_USE_BYTECODE_HOOK, ( + "compilation_config.dynamic_shapes_config.evaluate_guards " + "requires VLLM_USE_BYTECODE_HOOK=0. " ) + + if envs.VLLM_USE_AOT_COMPILE: + # disabled until https://github.com/pytorch/pytorch/pull/169239 + # is picked up. + assert ds_type != DynamicShapesType.BACKED, ( + "evaluate_guards for backed shapes requires " + "VLLM_USE_AOT_COMPILE=False. " + ) + + options["guard_filter_fn"] = lambda x: [ + entry.guard_type == "SHAPE_ENV" for entry in x + ] + else: + options["guard_filter_fn"] = lambda x: [False for _ in x] + + compiled_ptr: Any = self.forward + # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False + + if ds_type == DynamicShapesType.UNBACKED: + # reason is that bytecode does torch._dynamo.eval_frame. + # remove_from_cache(self.original_code_object()) to force a new + # re-compilation. And if we use + # compiled_ptr = self.check_invariants_and_forward + # it will reset all entries. + assert not envs.VLLM_USE_BYTECODE_HOOK, ( + "UNBACKED dynamic shapes requires VLLM_USE_BYTECODE_HOOK=0. " + ) + assert not self.evaluate_guards, "UNBACKED dynamic shapes do not add guards" + compiled_ptr = self.check_invariants_and_forward + aot_context = nullcontext() if envs.VLLM_USE_AOT_COMPILE: if hasattr(torch._dynamo.config, "enable_aot_compile"): - torch._dynamo.config.enable_aot_compile = True + aot_context = torch._dynamo.config.patch(enable_aot_compile=True) else: msg = "torch._dynamo.config.enable_aot_compile is not " msg += "available. AOT compile is disabled and please " msg += "upgrade PyTorch version to use AOT compile." logger.warning(msg) - self._compiled_callable = torch.compile( - compiled_ptr, - fullgraph=True, - dynamic=False, - backend=backend, - options=options, - ) + with aot_context: + self._compiled_callable = torch.compile( + compiled_ptr, + fullgraph=True, + dynamic=False, + backend=backend, + options=options, + ) if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE: torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook) @@ -168,13 +215,25 @@ class TorchCompileWithNoGuardsWrapper: # Make sure a compilation is triggered by clearing dynamo # cache. torch._dynamo.eval_frame.remove_from_cache(self.original_code_object()) - return self._compiled_callable(*args, **kwargs) + return self._call_with_optional_nvtx_range( + self._compiled_callable, *args, **kwargs + ) else: with self._dispatch_to_compiled_code(): - return self.forward(*args, **kwargs) + return self._call_with_optional_nvtx_range( + self.forward, *args, **kwargs + ) else: - with _compilation_context(): - return self._compiled_callable(*args, **kwargs) + ctx = ( + nullcontext() + if self.first_compile or not self.evaluate_guards + else torch.compiler.set_stance("fail_on_recompile") + ) + self.first_compile = False + with _compilation_context(), ctx: + return self._call_with_optional_nvtx_range( + self._compiled_callable, *args, **kwargs + ) @abstractmethod def forward(self, *args, **kwargs): ... diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index dd76a722106ef..0e91dd57420a8 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.config.attention import AttentionConfig from vllm.config.cache import CacheConfig from vllm.config.compilation import ( CompilationConfig, @@ -23,6 +24,7 @@ from vllm.config.multimodal import MultiModalConfig from vllm.config.observability import ObservabilityConfig from vllm.config.parallel import EPLBConfig, ParallelConfig from vllm.config.pooler import PoolerConfig +from vllm.config.profiler import ProfilerConfig from vllm.config.scheduler import SchedulerConfig from vllm.config.speculative import SpeculativeConfig from vllm.config.speech_to_text import SpeechToTextConfig @@ -46,6 +48,8 @@ from vllm.config.vllm import ( # __all__ should only contain classes and functions. # Types and globals should be imported from their respective modules. __all__ = [ + # From vllm.config.attention + "AttentionConfig", # From vllm.config.cache "CacheConfig", # From vllm.config.compilation @@ -86,6 +90,8 @@ __all__ = [ "SpeechToTextConfig", # From vllm.config.structured_outputs "StructuredOutputsConfig", + # From vllm.config.profiler + "ProfilerConfig", # From vllm.config.utils "ConfigType", "SupportsMetricsInfo", diff --git a/vllm/config/attention.py b/vllm/config/attention.py new file mode 100644 index 0000000000000..dd62d88826bd6 --- /dev/null +++ b/vllm/config/attention.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Literal + +from pydantic import field_validator +from pydantic.dataclasses import dataclass + +from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.config.utils import config +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@config +@dataclass +class AttentionConfig: + """Configuration for attention mechanisms in vLLM.""" + + backend: AttentionBackendEnum | None = None + """Attention backend to use. If None, will be selected automatically.""" + + flash_attn_version: Literal[2, 3] | None = None + """Force vllm to use a specific flash-attention version (2 or 3). + Only valid when using the flash-attention backend.""" + + use_prefill_decode_attention: bool = False + """Use separate prefill and decode kernels for attention instead of + the unified triton kernel.""" + + flash_attn_max_num_splits_for_cuda_graph: int = 32 + """Flash Attention max number splits for cuda graph decode.""" + + use_cudnn_prefill: bool = False + """Whether to use cudnn prefill.""" + + use_trtllm_ragged_deepseek_prefill: bool = False + """Whether to use TRTLLM ragged deepseek prefill.""" + + use_trtllm_attention: bool | None = None + """If set to True/False, use or don't use the TRTLLM attention backend + in flashinfer. If None, auto-detect the attention backend in flashinfer.""" + + disable_flashinfer_prefill: bool = False + """Whether to disable flashinfer prefill.""" + + disable_flashinfer_q_quantization: bool = False + """If set, when using fp8 kv, do not quantize Q to fp8.""" + + def compute_hash(self) -> str: + """ + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + from vllm.config.utils import get_hash_factors, hash_factors + + ignored_factors: list[str] = [] + factors = get_hash_factors(self, ignored_factors) + return hash_factors(factors) + + @field_validator("backend", mode="before") + @classmethod + def validate_backend_before(cls, value: Any) -> Any: + """Enable parsing of the `backend` enum type from string.""" + if isinstance(value, str): + return AttentionBackendEnum[value.upper()] + return value + + def _set_from_env_if_set(self, field_name: str, env_var_name: str) -> None: + """Set field from env var if set, with deprecation warning.""" + from vllm import envs + + if envs.is_set(env_var_name): + value = getattr(envs, env_var_name) + if field_name == "backend": + value = self.validate_backend_before(value) + setattr(self, field_name, value) + logger.warning_once( + "Using %s environment variable is deprecated and will be removed in " + "v0.14.0 or v1.0.0, whichever is soonest. Please use " + "--attention-config.%s command line argument or " + "AttentionConfig(%s=...) config field instead.", + env_var_name, + field_name, + field_name, + ) + + def __post_init__(self) -> None: + self._set_from_env_if_set("backend", "VLLM_ATTENTION_BACKEND") + self._set_from_env_if_set("flash_attn_version", "VLLM_FLASH_ATTN_VERSION") + self._set_from_env_if_set( + "use_prefill_decode_attention", "VLLM_V1_USE_PREFILL_DECODE_ATTENTION" + ) + self._set_from_env_if_set( + "flash_attn_max_num_splits_for_cuda_graph", + "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH", + ) + self._set_from_env_if_set("use_cudnn_prefill", "VLLM_USE_CUDNN_PREFILL") + self._set_from_env_if_set( + "use_trtllm_ragged_deepseek_prefill", + "VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL", + ) + self._set_from_env_if_set("use_trtllm_attention", "VLLM_USE_TRTLLM_ATTENTION") + self._set_from_env_if_set( + "disable_flashinfer_prefill", "VLLM_DISABLE_FLASHINFER_PREFILL" + ) + self._set_from_env_if_set( + "disable_flashinfer_q_quantization", + "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", + ) diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 00530846fce00..067799a44db30 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -29,8 +29,8 @@ CacheDType = Literal[ "fp8_inc", "fp8_ds_mla", ] -MambaDType = Literal["auto", "float32"] -PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"] +MambaDType = Literal["auto", "float32", "float16"] +PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"] KVOffloadingBackend = Literal["native", "lmcache"] @@ -77,9 +77,21 @@ class CacheConfig: """Whether to enable prefix caching.""" prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256" """Set the hash algorithm for prefix caching:\n - - "sha256" uses Pickle for object serialization before hashing.\n + - "sha256" uses Pickle for object serialization before hashing. This is the + current default, as SHA256 is the most secure choice to avoid potential + hash collisions.\n - "sha256_cbor" provides a reproducible, cross-language compatible hash. It - serializes objects using canonical CBOR and hashes them with SHA-256.""" + serializes objects using canonical CBOR and hashes them with SHA-256.\n + - "xxhash" uses Pickle serialization with xxHash (128-bit) for faster, + non-cryptographic hashing. Requires the optional ``xxhash`` package. + IMPORTANT: Use of a hashing algorithm that is not considered + cryptographically secure theoretically increases the risk of hash collisions, + which can cause undefined behavior or even leak private information in + multi-tenant environments. Even if collisions are still very unlikely, it is + important to consider your security risk tolerance against the performance + benefits before turning this on.\n + - "xxhash_cbor" combines canonical CBOR serialization with xxHash for + reproducible hashing. Requires the optional ``xxhash`` package.""" cpu_offload_gb: float = Field(default=0, ge=0) """The space in GiB to offload to CPU, per GPU. Default is 0, which means no offloading. Intuitively, this argument can be seen as a virtual way to diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index da2c100dae3dc..4a98494b3c7b3 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -4,16 +4,21 @@ import enum from collections import Counter from collections.abc import Callable -from dataclasses import asdict, field +from dataclasses import field from pathlib import Path from typing import TYPE_CHECKING, Any, ClassVar, Literal -from pydantic import Field, TypeAdapter, field_validator +from pydantic import ConfigDict, Field, TypeAdapter, field_validator from pydantic.dataclasses import dataclass import vllm.envs as envs from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass -from vllm.config.utils import config +from vllm.config.utils import ( + Range, + config, + get_hash_factors, + hash_factors, +) from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.import_utils import resolve_obj_by_qualname @@ -91,7 +96,7 @@ class CUDAGraphMode(enum.Enum): @config -@dataclass +@dataclass(config=ConfigDict(extra="forbid")) class PassConfig: """Configuration for custom Inductor passes. @@ -105,18 +110,22 @@ class PassConfig: improper state. """ - enable_fusion: bool = Field(default=None) - """Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass.""" - enable_attn_fusion: bool = Field(default=None) - """Whether to enable the custom attention+quant fusion pass.""" - enable_noop: bool = Field(default=None) - """Whether to enable the custom no-op elimination pass.""" - enable_sequence_parallelism: bool = Field(default=None) - """Whether to enable sequence parallelism.""" - enable_async_tp: bool = Field(default=None) - """Whether to enable async TP.""" - enable_fi_allreduce_fusion: bool = Field(default=None) - """Whether to enable flashinfer allreduce fusion.""" + # New flags + fuse_norm_quant: bool = Field(default=None) + """Fuse the custom RMSNorm + quant ops.""" + fuse_act_quant: bool = Field(default=None) + """Fuse the custom SiluMul + quant ops.""" + fuse_attn_quant: bool = Field(default=None) + """Fuse the custom attention + quant ops.""" + eliminate_noops: bool = Field(default=None) + """Eliminate no-op ops.""" + enable_sp: bool = Field(default=None) + """Enable sequence parallelism.""" + fuse_gemm_comms: bool = Field(default=None) + """Enable async TP.""" + fuse_allreduce_rms: bool = Field(default=None) + """Enable flashinfer allreduce fusion.""" + fi_allreduce_fusion_max_size_mb: float | None = None """The threshold of the communicated tensor sizes under which vllm should use flashinfer fused allreduce. Specified as a @@ -136,7 +145,7 @@ class PassConfig: }, }, where key is the device capability""" enable_qk_norm_rope_fusion: bool = False - """Whether to enable the fused Q/K RMSNorm + RoPE pass.""" + """Enable fused Q/K RMSNorm + RoPE pass.""" # TODO(luka) better pass enabling system. @@ -148,6 +157,9 @@ class PassConfig: """ MiB = 1024 * 1024 + FI_SUPPORTED_WORLD_SIZES = [2, 4, 8] + if world_size not in FI_SUPPORTED_WORLD_SIZES: + return None max_size_mb = self.fi_allreduce_fusion_max_size_mb if max_size_mb is None: max_size_mb = self.default_fi_allreduce_fusion_max_size_mb().get(world_size) @@ -171,15 +183,17 @@ class PassConfig: Any new fields that affect compilation should be added to the hash. Any future fields that don't affect compilation should be excluded. """ - return InductorPass.hash_dict(asdict(self)) + + return hash_factors(get_hash_factors(self, set())) @field_validator( - "enable_fusion", - "enable_attn_fusion", - "enable_noop", - "enable_sequence_parallelism", - "enable_async_tp", - "enable_fi_allreduce_fusion", + "fuse_norm_quant", + "fuse_act_quant", + "fuse_attn_quant", + "eliminate_noops", + "enable_sp", + "fuse_gemm_comms", + "fuse_allreduce_rms", mode="wrap", ) @classmethod @@ -190,18 +204,20 @@ class PassConfig: return handler(value) def __post_init__(self) -> None: - if not self.enable_noop: - if self.enable_fusion: + # Handle deprecation and defaults + + if not self.eliminate_noops: + if self.fuse_norm_quant or self.fuse_act_quant: logger.warning_once( "Fusion enabled but reshape elimination disabled. " "RMSNorm/SiluMul + quant (fp8) fusion might not work" ) - if self.enable_attn_fusion: + if self.fuse_attn_quant: logger.warning_once( "Fusion enabled but reshape elimination disabled. " "Attention + quant (fp8) fusion might not work" ) - if self.enable_fi_allreduce_fusion: + if self.fuse_allreduce_rms: logger.warning_once( "Fusion enabled but reshape elimination disabled. " "Allreduce + rms norm + quant (fp8) fusion might not work" @@ -235,7 +251,7 @@ class DynamicShapesType(str, enum.Enum): @config -@dataclass +@dataclass(config=ConfigDict(extra="forbid")) class DynamicShapesConfig: """Configuration to control/debug torch compile dynamic shapes.""" @@ -249,7 +265,18 @@ class DynamicShapesConfig: backed/unbacked. """ - # TODO add a debug mode to fail + evaluate_guards: bool = False + """ + A debug mode to detect and fail if Dynamo ever specializes a dynamic shape by + guarding on it. When True, dynamic shape guards are not dropped from dynamo. + And a failure will be triggered if a recompilation ever happens due to that. + This mode requires VLLM_USE_BYTECODE_HOOK to be 0. + Enabling this allow observing the dynamic shapes guards in the tlparse + artifacts also. + When type is backed, aot_compile must be disabled for this mode to work. + until this change picked up https://github.com/pytorch/pytorch/pull/169239. + + """ def compute_hash(self) -> str: """ @@ -263,7 +290,7 @@ class DynamicShapesConfig: @config -@dataclass +@dataclass(config=ConfigDict(extra="forbid")) class CompilationConfig: """Configuration for compilation. @@ -293,6 +320,8 @@ class CompilationConfig: [vllm.config.CompilationConfig.cudagraph_copy_inputs] - Inductor compilation: - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes] + - [`compile_ranges_split_points`] + [vllm.config.CompilationConfig.compile_ranges_split_points] - [`inductor_compile_config`] [vllm.config.CompilationConfig.inductor_compile_config] - [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes] @@ -358,8 +387,8 @@ class CompilationConfig: We use string to avoid serialization issues when using compilation in a distributed setting. When the compilation mode is 1 or 2, the backend is used for the compilation directly (it sees the whole graph). When the - compilation mode is 3, the backend supports both whole graph and piecewise - compilation, available backends include eager, inductor, and custom backends, + compilation mode is 3, the backend supports both whole graph and piecewise + compilation, available backends include eager, inductor, and custom backends, the latter of which can be defined via `get_compile_backend`. Furthermore, compilation is only piecewise if splitting ops is set accordingly and use_inductor_graph_partition is off. Note that the default options for @@ -406,6 +435,21 @@ class CompilationConfig: to integers, it also supports "cudagraph_capture_sizes" to specify the sizes for cudagraph capture.""" + compile_ranges_split_points: list[int] | None = None + """Split points that represent compile ranges for inductor. + The compile ranges are + [1, split_points[0]], + [split_points[0] + 1, split_points[1]], ..., + [split_points[-1] + 1, max_num_batched_tokens]. + Compile sizes are also used single element ranges, + the range is represented as [compile_sizes[i], compile_sizes[i]]. + + If a range overlaps with the compile size, graph for compile size + will be prioritized, i.e. if we have a range [1, 8] and a compile size 4, + graph for compile size 4 will be compiled and used instead of the graph + for range [1, 8]. + """ + inductor_compile_config: dict = field(default_factory=dict) """Additional configurations for inductor. - None: use default configurations.""" @@ -854,7 +898,16 @@ class CompilationConfig: # May get recomputed in the model runner if adjustment is needed for spec-decode self.compute_bs_to_padded_graph_size() - def set_splitting_ops_for_v1(self): + def set_splitting_ops_for_v1( + self, all2all_backend: str | None = None, data_parallel_size: int | None = None + ): + # To compatible with OOT hardware plugin platform (for example vllm-ascend) + # which currently only supports sequence parallelism in eager mode. + if self.mode != CompilationMode.VLLM_COMPILE: + if self.splitting_ops is None: + self.splitting_ops = [] + return + # NOTE: this function needs to be called only when mode is # CompilationMode.VLLM_COMPILE assert self.mode == CompilationMode.VLLM_COMPILE, ( @@ -862,58 +915,95 @@ class CompilationConfig: "mode is CompilationMode.VLLM_COMPILE" ) - if self.use_inductor_graph_partition: - self.set_splitting_ops_for_inductor_graph_partition() - return + added_default_splitting_ops = False - if self.pass_config.enable_attn_fusion: - # here use_inductor_graph_partition is False + if self.pass_config.fuse_attn_quant and not self.use_inductor_graph_partition: self.set_splitting_ops_for_attn_fusion() - return + else: + if self.splitting_ops is None: + # NOTE: When using full cudagraph, instead of setting an empty + # list and capture the full cudagraph inside the flattened fx + # graph, we keep the piecewise fx graph structure but capture + # the full cudagraph outside the fx graph. This reduces some + # cpu overhead when the runtime batch_size is not cudagraph + # captured. see https://github.com/vllm-project/vllm/pull/20059 + # for details. Make a copy to avoid mutating the class-level + # list via reference. + self.splitting_ops = list(self._attention_ops) + added_default_splitting_ops = True + elif len(self.splitting_ops) == 0: + if ( + self.cudagraph_mode == CUDAGraphMode.PIECEWISE + or self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE + ): + logger.warning_once( + "Using piecewise cudagraph with empty splitting_ops" + ) + if self.cudagraph_mode == CUDAGraphMode.PIECEWISE: + logger.warning_once( + "Piecewise compilation with empty splitting_ops do not" + "contains piecewise cudagraph. Setting cudagraph_" + "mode to NONE. Hint: If you are using attention " + "backends that support cudagraph, consider manually " + "setting cudagraph_mode to FULL or FULL_DECODE_ONLY " + "to enable full cudagraphs." + ) + self.cudagraph_mode = CUDAGraphMode.NONE + elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE: + logger.warning_once( + "Piecewise compilation with empty splitting_ops do " + "not contains piecewise cudagraph. Setting " + "cudagraph_mode to FULL." + ) + self.cudagraph_mode = CUDAGraphMode.FULL + self.splitting_ops = [] - if self.splitting_ops is None: - # NOTE: When using full cudagraph, instead of setting an empty - # list and capture the full cudagraph inside the flattened fx - # graph, we keep the piecewise fx graph structure but capture - # the full cudagraph outside the fx graph. This reduces some - # cpu overhead when the runtime batch_size is not cudagraph - # captured. see https://github.com/vllm-project/vllm/pull/20059 - # for details. Make a copy to avoid mutating the class-level - # list via reference. - self.splitting_ops = list(self._attention_ops) - elif len(self.splitting_ops) == 0: - logger.warning_once("Using piecewise compilation with empty splitting_ops") - if self.cudagraph_mode == CUDAGraphMode.PIECEWISE: - logger.warning_once( - "Piecewise compilation with empty splitting_ops do not" - "contains piecewise cudagraph. Setting cudagraph_" - "mode to NONE. Hint: If you are using attention backends " - "that support cudagraph, consider manually setting " - "cudagraph_mode to FULL or FULL_DECODE_ONLY to enable " - "full cudagraphs." - ) + # split MoE ops for cudagraph + moe_ops = [ + "vllm::moe_forward", + "vllm::moe_forward_shared", + ] + backend = all2all_backend or envs.VLLM_ALL2ALL_BACKEND + dp_size = data_parallel_size if data_parallel_size is not None else 1 + need_moe_splitting = ( + backend == "deepep_high_throughput" + and dp_size > 1 + # pure attn-fusion without inductor partition deliberately disables + # piecewise graphs and MoE splitting. + and not ( + self.pass_config.fuse_attn_quant + and not self.use_inductor_graph_partition + ) + ) + + if need_moe_splitting and self.cudagraph_mode != CUDAGraphMode.NONE: + # if we just initialized default splitting_ops for this config, + # automatically append the MoE ops + if added_default_splitting_ops: + for op in moe_ops: + if op not in self.splitting_ops: + self.splitting_ops.append(op) + + # make sure MoE ops are split out + if not any(op in self.splitting_ops for op in moe_ops): self.cudagraph_mode = CUDAGraphMode.NONE - elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE: logger.warning_once( - "Piecewise compilation with empty splitting_ops do not " - "contains piecewise cudagraph. Setting cudagraph_mode " - "to FULL." + "DeepEP high throughput backend with data_parallel_size > 1 " + "requires splitting MoE ops from cudagraphs. Please ensure " + "'vllm::moe_forward' or 'vllm::moe_forward_shared' are " + "present in CompilationConfig.splitting_ops." ) - self.cudagraph_mode = CUDAGraphMode.FULL - self.splitting_ops = [] - - def set_splitting_ops_for_inductor_graph_partition(self): - assert self.use_inductor_graph_partition - if self.splitting_ops is None: - self.splitting_ops = list(self._attention_ops) + elif self.cudagraph_mode.has_full_cudagraphs(): + # fall back to piecewise when MoE splitting is required. + self.cudagraph_mode = CUDAGraphMode.PIECEWISE def set_splitting_ops_for_attn_fusion(self): - assert self.pass_config.enable_attn_fusion + assert self.pass_config.fuse_attn_quant if self.splitting_ops is None: self.splitting_ops = [] if self.cudagraph_mode.has_piecewise_cudagraphs(): logger.warning_once( - "enable_attn_fusion is incompatible with piecewise " + "fuse_attn_quant is incompatible with piecewise " "cudagraph when use_inductor_graph_partition is off. " "In this case, splitting_ops will be set to empty " "list, and cudagraph_mode will be set to FULL. " @@ -924,8 +1014,7 @@ class CompilationConfig: self.cudagraph_mode = CUDAGraphMode.FULL assert not self.splitting_ops_contain_attention(), ( - "attention ops should not be in splitting_ops " - "when enable_attn_fusion is True" + "attention ops should not be in splitting_ops when fuse_attn_quant is True" ) def splitting_ops_contain_attention(self) -> bool: @@ -1001,7 +1090,7 @@ class CompilationConfig: self, uniform_decode_query_len: int, tensor_parallel_size: int ): multiple_of = uniform_decode_query_len - if tensor_parallel_size > 1 and self.pass_config.enable_sequence_parallelism: + if tensor_parallel_size > 1 and self.pass_config.enable_sp: multiple_of = max(uniform_decode_query_len, tensor_parallel_size) if ( multiple_of % uniform_decode_query_len != 0 @@ -1061,3 +1150,13 @@ class CompilationConfig: self.bs_to_padded_graph_size[bs] = start else: self.bs_to_padded_graph_size[bs] = end + + def get_compile_ranges(self) -> list[Range]: + """Get the compile ranges for the compilation config.""" + if self.compile_ranges_split_points is None: + return [] + split_points = sorted(set(self.compile_ranges_split_points)) + return [ + Range(start=s + 1, end=e) + for s, e in zip([0] + split_points[:-1], split_points) + ] diff --git a/vllm/config/kv_transfer.py b/vllm/config/kv_transfer.py index 88f8b91c292bb..98cea821c678e 100644 --- a/vllm/config/kv_transfer.py +++ b/vllm/config/kv_transfer.py @@ -64,6 +64,11 @@ class KVTransferConfig: enable_permute_local_kv: bool = False """Experiment feature flag to enable HND to NHD KV Transfer""" + kv_load_failure_policy: Literal["recompute", "fail"] = "recompute" + """Policy for handling KV cache load failures. + 'recompute': reschedule the request to recompute failed blocks (default) + 'fail': immediately fail the request with an error finish reason""" + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/config/model.py b/vllm/config/model.py index ef592ac001535..1de9d15cf8c52 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -4,11 +4,11 @@ import warnings from collections.abc import Callable from dataclasses import InitVar, field -from importlib.util import find_spec +from functools import cached_property from typing import TYPE_CHECKING, Any, Literal, cast, get_args import torch -from pydantic import ConfigDict, SkipValidation, field_validator, model_validator +from pydantic import ConfigDict, Field, field_validator, model_validator from pydantic.dataclasses import dataclass from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from transformers.configuration_utils import ALLOWED_LAYER_TYPES @@ -37,15 +37,13 @@ from vllm.transformers_utils.config import ( uses_xdrope_dim, ) from vllm.transformers_utils.gguf_utils import ( - maybe_patch_hf_config_from_gguf, -) -from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri -from vllm.transformers_utils.utils import ( is_gguf, is_remote_gguf, - maybe_model_redirect, + maybe_patch_hf_config_from_gguf, split_remote_gguf, ) +from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri +from vllm.transformers_utils.utils import maybe_model_redirect from vllm.utils.import_utils import LazyLoader from vllm.utils.torch_utils import common_broadcastable_dtype @@ -75,18 +73,7 @@ logger = init_logger(__name__) RunnerOption = Literal["auto", RunnerType] ConvertType = Literal["none", "embed", "classify", "reward"] ConvertOption = Literal["auto", ConvertType] -TaskOption = Literal[ - "auto", - "generate", - "embedding", - "embed", - "classify", - "score", - "reward", - "transcription", - "draft", -] -TokenizerMode = Literal["auto", "hf", "slow", "mistral"] +TokenizerMode = Literal["auto", "hf", "slow", "mistral", "deepseek_v32"] ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] LogprobsMode = Literal[ "raw_logits", "raw_logprobs", "processed_logits", "processed_logprobs" @@ -95,12 +82,6 @@ HfOverrides = dict[str, Any] | Callable[[PretrainedConfig], PretrainedConfig] ModelImpl = Literal["auto", "vllm", "transformers", "terratorch"] LayerBlockType = Literal["attention", "linear_attention", "mamba"] -_RUNNER_TASKS: dict[RunnerType, list[TaskOption]] = { - "generate": ["generate", "transcription"], - "pooling": ["embedding", "embed", "classify", "score", "reward"], - "draft": ["draft"], -} - _RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = { "generate": [], "pooling": ["embed", "classify", "reward"], @@ -128,21 +109,17 @@ class ModelConfig: """Convert the model using adapters defined in [vllm.model_executor.models.adapters][]. The most common use case is to adapt a text generation model to be used for pooling tasks.""" - task: TaskOption | None = None - """[DEPRECATED] The task to use the model for. If the model supports more - than one model runner, this is used to select which model runner to run. - - Note that the model may support other tasks using the same model runner. - """ - tokenizer: SkipValidation[str] = None # type: ignore + tokenizer: str = Field(default=None) """Name or path of the Hugging Face tokenizer to use. If unspecified, model name or path will be used.""" tokenizer_mode: TokenizerMode | str = "auto" """Tokenizer mode:\n - - "auto" will use "hf" tokenizer if Mistral's tokenizer is not available.\n + - "auto" will use the tokenizer from `mistral_common` for Mistral models + if available, otherwise it will use the "hf" tokenizer.\n - "hf" will use the fast tokenizer if available.\n - "slow" will always use the slow tokenizer.\n - "mistral" will always use the tokenizer from `mistral_common`.\n + - "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n - Other custom values can be supported via plugins.""" trust_remote_code: bool = False """Trust remote code (e.g., from HuggingFace) when downloading the model @@ -187,7 +164,7 @@ class ModelConfig: """The specific revision to use for the tokenizer on the Hugging Face Hub. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.""" - max_model_len: SkipValidation[int] = None # type: ignore + max_model_len: int = Field(default=None, gt=0) """Model context length (prompt and output). If unspecified, will be automatically derived from the model config. @@ -198,7 +175,7 @@ class ModelConfig: - 25.6k -> 25,600""" spec_target_max_model_len: int | None = None """Specify the maximum length for spec decoding draft models.""" - quantization: SkipValidation[QuantizationMethods | None] = None + quantization: QuantizationMethods | str | None = None """Method used to quantize the weights. If `None`, we first check the `quantization_config` attribute in the model config file. If that is `None`, we assume the model weights are not quantized and use `dtype` to @@ -335,7 +312,6 @@ class ModelConfig: ignored_factors = { "runner", "convert", - "task", "tokenizer", "tokenizer_mode", "seed", @@ -468,18 +444,6 @@ class ModelConfig: self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer) - if ( - (backend := envs.VLLM_ATTENTION_BACKEND) - and backend == "FLASHINFER" - and find_spec("flashinfer") is None - ): - raise ValueError( - "VLLM_ATTENTION_BACKEND is set to FLASHINFER, but flashinfer " - "module was not found. See " - "https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile " # noqa: E501 - "for instructions on how to install it." - ) - from vllm.platforms import current_platform if self.override_attention_dtype is not None and not current_platform.is_rocm(): @@ -522,93 +486,6 @@ class ModelConfig: is_generative_model = registry.is_text_generation_model(architectures, self) is_pooling_model = registry.is_pooling_model(architectures, self) - def _task_to_convert(task: TaskOption) -> ConvertType: - if task == "embedding" or task == "embed": - return "embed" - if task == "classify": - return "classify" - if task == "reward": - return "reward" - if task == "score": - new_task = self._get_default_pooling_task(architectures) - return "classify" if new_task == "classify" else "embed" - - return "none" - - if self.task is not None: - runner: RunnerOption = "auto" - convert: ConvertOption = "auto" - msg_prefix = ( - "The 'task' option has been deprecated and will be " - "removed in v0.13.0 or v1.0, whichever comes first." - ) - msg_hint = "Please remove this option." - - is_generative_task = self.task in _RUNNER_TASKS["generate"] - is_pooling_task = self.task in _RUNNER_TASKS["pooling"] - - if is_generative_model and is_pooling_model: - if is_generative_task: - runner = "generate" - convert = "auto" - msg_hint = ( - "Please replace this option with `--runner " - "generate` to continue using this model " - "as a generative model." - ) - elif is_pooling_task: - runner = "pooling" - convert = "auto" - msg_hint = ( - "Please replace this option with `--runner " - "pooling` to continue using this model " - "as a pooling model." - ) - else: # task == "auto" - pass - elif is_generative_model or is_pooling_model: - if is_generative_task: - runner = "generate" - convert = "auto" - msg_hint = "Please remove this option" - elif is_pooling_task: - runner = "pooling" - convert = _task_to_convert(self.task) - msg_hint = ( - "Please replace this option with `--convert " - f"{convert}` to continue using this model " - "as a pooling model." - ) - else: # task == "auto" - pass - else: - # Neither generative nor pooling model - try to convert if possible - if is_pooling_task: - runner = "pooling" - convert = _task_to_convert(self.task) - msg_hint = ( - "Please replace this option with `--runner pooling " - f"--convert {convert}` to continue using this model " - "as a pooling model." - ) - else: - debug_info = { - "architectures": architectures, - "is_generative_model": is_generative_model, - "is_pooling_model": is_pooling_model, - } - raise AssertionError( - "The model should be a generative or " - "pooling model when task is set to " - f"{self.task!r}. Found: {debug_info}" - ) - - self.runner = runner - self.convert = convert - - msg = f"{msg_prefix} {msg_hint}" - warnings.warn(msg, DeprecationWarning, stacklevel=2) - self.runner_type = self._get_runner_type(architectures, self.runner) self.convert_type = self._get_convert_type( architectures, self.runner_type, self.convert @@ -662,6 +539,11 @@ class ModelConfig: self.original_max_model_len = self.max_model_len self.max_model_len = self.get_and_verify_max_len(self.max_model_len) + + if self.is_encoder_decoder: + self.mm_processor_cache_gb = 0 + logger.info("Encoder-decoder model detected, disabling mm processor cache.") + # Init multimodal config if needed if self._model_info.supports_multimodal: if ( @@ -715,6 +597,14 @@ class ModelConfig: self._verify_cuda_graph() self._verify_bnb_config() + @field_validator("tokenizer", "max_model_len", mode="wrap") + @classmethod + def _skip_none_validation(cls, value: Any, handler: Callable) -> Any: + """Skip validation if the value is `None` when initialisation is delayed.""" + if value is None: + return value + return handler(value) + @field_validator("tokenizer_mode", mode="after") def _lowercase_tokenizer_mode(cls, tokenizer_mode: str) -> str: return tokenizer_mode.lower() @@ -728,10 +618,19 @@ class ModelConfig: @model_validator(mode="after") def validate_model_config_after(self: "ModelConfig") -> "ModelConfig": + """Called after __post_init__""" if not isinstance(self.tokenizer, str): - raise ValueError("tokenizer must be a string after __post_init__.") + raise ValueError( + f"tokenizer must be a string, got " + f"{type(self.tokenizer).__name__}: {self.tokenizer!r}. " + "Please provide a valid tokenizer path or HuggingFace model ID." + ) if not isinstance(self.max_model_len, int): - raise ValueError("max_model_len must be an integer after __post_init__.") + raise ValueError( + f"max_model_len must be a positive integer, " + f"got {type(self.max_model_len).__name__}: {self.max_model_len!r}. " + "Example: max_model_len=2048" + ) return self def _get_transformers_backend_cls(self) -> str: @@ -911,6 +810,13 @@ class ModelConfig: runner_type: RunnerType, convert: ConvertOption, ) -> ConvertType: + if convert == "reward": + logger.warning( + "`--convert reward` is deprecated and will be removed in v0.15. " + "Please use `--convert embed` instead." + ) + return "embed" + if convert != "auto": return convert @@ -926,22 +832,6 @@ class ModelConfig: return convert_type - def _get_default_pooling_task( - self, - architectures: list[str], - ) -> Literal["embed", "classify", "reward"]: - if self.registry.is_cross_encoder_model(architectures, self): - return "classify" - - for arch in architectures: - match = try_match_architecture_defaults(arch, runner_type="pooling") - if match: - _, (_, convert_type) = match - assert convert_type != "none" - return convert_type - - return "embed" - def _parse_quant_hf_config(self, hf_config: PretrainedConfig): quant_cfg = getattr(hf_config, "quantization_config", None) if quant_cfg is None: @@ -1230,6 +1120,19 @@ class ModelConfig: ) return False + @cached_property + def is_mm_prefix_lm(self) -> bool: + """Whether to use bidirectional attention for mm positions.""" + MM_PREFIX_LM_MODELS = ( + "gemma3", + # TODO(Isotr0py): Disable paligemma for now before + # we supports soft cap attention for FlexAttention + # "paligemma", + ) + if not hasattr(self.hf_config, "model_type"): + return False + return self.hf_config.model_type in MM_PREFIX_LM_MODELS + def get_head_size(self) -> int: # TODO remove hard code if self.is_deepseek_mla: @@ -1300,7 +1203,15 @@ class ModelConfig: // block.attention.n_heads_in_group ) - raise RuntimeError("Couldn't determine number of kv heads") + raise RuntimeError( + "Could not determine the number of key-value attention heads " + "from model configuration. " + f"Model: {self.model}, Architecture: {self.architectures}. " + "This usually indicates an unsupported model architecture or " + "missing configuration. " + "Please check if your model is supported at: " + "https://docs.vllm.ai/en/latest/models/supported_models.html" + ) if self.is_attention_free: return 0 @@ -1781,20 +1692,22 @@ class ModelConfig: return False elif attn_type == "decoder": pooling_type = self.pooler_config.pooling_type.lower() - if pooling_type in ["all", "mean", "step", "cls"]: + if pooling_type in ["mean", "step", "cls"]: logger.debug( "Pooling models with %s pooling does not " "support chunked prefill.", pooling_type, ) return False - else: - # pooling_type == "last" + elif pooling_type in ["all", "last"]: logger.debug( - "Pooling models with causal attn and last pooling support " - "chunked prefill." + "Pooling models with causal attn and %s pooling support " + "chunked prefill.", + pooling_type, ) return True + else: + raise ValueError(f"{pooling_type=} not supported.") # vllm currently does not have pooling models using hybrid, # attention_free or encoder_decoder attn types. return attn_type != "encoder_decoder" @@ -1818,20 +1731,22 @@ class ModelConfig: return False elif attn_type == "decoder": pooling_type = self.pooler_config.pooling_type.lower() - if pooling_type in ["all", "mean", "step", "cls"]: + if pooling_type in ["mean", "step", "cls"]: logger.debug( "Pooling models with %s pooling does not " "support prefix caching.", pooling_type, ) return False - else: - # pooling_type == "last" + elif pooling_type in ["all", "last"]: logger.debug( - "Pooling models with causal attn and last pooling support " - "prefix caching." + "Pooling models with causal attn and %s pooling support " + "prefix caching.", + pooling_type, ) return True + else: + raise ValueError(f"{pooling_type=} not supported.") # vllm currently does not have pooling models using hybrid, # attention_free or encoder_decoder attn types. return False @@ -1890,12 +1805,13 @@ _SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [ ("ForTextEncoding", ("pooling", "embed")), ("EmbeddingModel", ("pooling", "embed")), ("ForSequenceClassification", ("pooling", "classify")), + ("ForTokenClassification", ("pooling", "classify")), ("ForAudioClassification", ("pooling", "classify")), ("ForImageClassification", ("pooling", "classify")), ("ForVideoClassification", ("pooling", "classify")), ("ClassificationModel", ("pooling", "classify")), - ("ForRewardModeling", ("pooling", "reward")), - ("RewardModel", ("pooling", "reward")), + ("ForRewardModeling", ("pooling", "embed")), + ("RewardModel", ("pooling", "embed")), # Let other `*Model`s take priority ("Model", ("pooling", "embed")), ] diff --git a/vllm/config/observability.py b/vllm/config/observability.py index ff35e12fe20ed..e40bf18a00ce2 100644 --- a/vllm/config/observability.py +++ b/vllm/config/observability.py @@ -5,7 +5,7 @@ from functools import cached_property from typing import Any, Literal, cast from packaging.version import parse -from pydantic import field_validator, model_validator +from pydantic import Field, field_validator, model_validator from pydantic.dataclasses import dataclass from vllm import version @@ -47,6 +47,23 @@ class ObservabilityConfig: Note that collecting detailed timing information for each request can be expensive.""" + kv_cache_metrics: bool = False + """Enable KV cache residency metrics (lifetime, idle time, reuse gaps). + Uses sampling to minimize overhead. + Requires log stats to be enabled (i.e., --disable-log-stats not set).""" + + kv_cache_metrics_sample: float = Field(default=0.01, gt=0, le=1) + """Sampling rate for KV cache metrics (0.0, 1.0]. Default 0.01 = 1% of blocks.""" + + cudagraph_metrics: bool = False + """Enable CUDA graph metrics (number of padded/unpadded tokens, runtime cudagraph + dispatch modes, and their observed frequencies at every logging interval).""" + + enable_layerwise_nvtx_tracing: bool = False + """Enable layerwise NVTX tracing. This traces the execution of each layer or + module in the model and attach informations such as input/output shapes to + nvtx range markers. Noted that this doesn't work with CUDA graphs enabled.""" + @cached_property def collect_model_forward_time(self) -> bool: """Whether to collect model forward time for the request.""" diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 4a8c8bc17cfc3..3fe066ec32505 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -35,6 +35,7 @@ logger = init_logger(__name__) ExpertPlacementStrategy = Literal["linear", "round_robin"] DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"] DataParallelBackend = Literal["ray", "mp"] +EPLBPolicyOption = Literal["default"] @config @@ -65,6 +66,9 @@ class EPLBConfig: Whether to use non-blocking EPLB. """ + policy: EPLBPolicyOption = "default" + """The policy type for expert parallel load balancing (EPLB).""" + @config @dataclass @@ -152,6 +156,8 @@ class ParallelConfig: enable_dbo: bool = False """Enable dual batch overlap for the model executor.""" + ubatch_size: int = 0 + """Number of ubatch size.""" dbo_decode_token_threshold: int = 32 """The threshold for dual batch overlap for batches only containing decodes. @@ -180,13 +186,14 @@ class ParallelConfig: distributed_executor_backend: ( str | DistributedExecutorBackend | type[Executor] | None ) = None - """Backend to use for distributed model - workers, either "ray" or "mp" (multiprocessing). If the product - of pipeline_parallel_size and tensor_parallel_size is less than - or equal to the number of GPUs available, "mp" will be used to - keep processing on a single host. Otherwise, this will default - to "ray" if Ray is installed and fail otherwise. Note that tpu - only support Ray for distributed inference.""" + """Backend to use for distributed model workers, either "ray" or "mp" + (multiprocessing). If the product of pipeline_parallel_size and tensor_parallel_size + is less than or equal to the number of GPUs available, "mp" will be used to + keep processing on a single host. Otherwise, an error will be raised. To use "mp" + you must also set nnodes, and to use "ray" you must manually set + distributed_executor_backend to "ray". + + Note that tpu only support Ray for distributed inference.""" worker_cls: str = "auto" """The full name of the worker class to use. If "auto", the worker class @@ -312,11 +319,6 @@ class ParallelConfig: "num_redundant_experts." ) - if self.prefill_context_parallel_size > 1: - raise ValueError( - "Prefill context parallelism is not fully supported. " - "Please set prefill_context_parallel_size to 1." - ) return self @property @@ -325,6 +327,14 @@ class ParallelConfig: including data parallelism.""" return self.world_size * self.data_parallel_size + @property + def use_ubatching(self) -> bool: + return self.enable_dbo or self.ubatch_size > 1 + + @property + def num_ubatches(self) -> int: + return 2 if self.enable_dbo else self.ubatch_size + def get_next_dp_init_port(self) -> int: """ We might need to initialize process groups in multiple @@ -562,8 +572,11 @@ class ParallelConfig: ): gpu_count = cuda_device_count_stateless() raise ValueError( - f"Tensor parallel size ({self.world_size}) cannot be " - f"larger than the number of available GPUs ({gpu_count})." + f"World size ({self.world_size}) is larger than the number of " + f"available GPUs ({gpu_count}) in this node. If this is " + "intentional and you are using:\n" + "- ray, set '--distributed-executor-backend ray'.\n" + "- multiprocessing, set '--nnodes' appropriately." ) elif self.data_parallel_backend == "ray": logger.info( @@ -593,10 +606,14 @@ class ParallelConfig: "max_parallel_loading_workers is currently " "not supported and will be ignored." ) - if self.distributed_executor_backend not in ("mp", "uni") and self.nnodes > 1: + allowed_backends = ("mp", "uni", "external_launcher") + if ( + self.distributed_executor_backend not in allowed_backends + and self.nnodes > 1 + ): raise ValueError( "nnodes > 1 can only be set when distributed executor " - "backend is mp or uni." + "backend is mp, uni or external_launcher." ) @property diff --git a/vllm/config/pooler.py b/vllm/config/pooler.py index aa4e7006d0247..976ae8c063eb7 100644 --- a/vllm/config/pooler.py +++ b/vllm/config/pooler.py @@ -111,13 +111,15 @@ class PoolerConfig: def get_use_activation(o: object): if softmax := getattr(o, "softmax", None) is not None: logger.warning_once( - "softmax will be deprecated, please use use_activation instead." + "softmax will be deprecated and will be removed in v0.15. " + "Please use use_activation instead." ) return softmax if activation := getattr(o, "activation", None) is not None: logger.warning_once( - "activation will be deprecated, please use use_activation instead." + "activation will be deprecated and will be removed in v0.15. " + "Please use use_activation instead." ) return activation diff --git a/vllm/config/profiler.py b/vllm/config/profiler.py new file mode 100644 index 0000000000000..76cc546f3c9e2 --- /dev/null +++ b/vllm/config/profiler.py @@ -0,0 +1,199 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +from typing import Any, Literal + +from pydantic import Field, model_validator +from pydantic.dataclasses import dataclass +from typing_extensions import Self + +import vllm.envs as envs +from vllm.config.utils import config +from vllm.logger import init_logger +from vllm.utils.hashing import safe_hash + +logger = init_logger(__name__) + +ProfilerKind = Literal["torch", "cuda"] + + +@config +@dataclass +class ProfilerConfig: + """Dataclass which contains profiler config for the engine.""" + + profiler: ProfilerKind | None = None + """Which profiler to use. Defaults to None. Options are: + + - 'torch': Use PyTorch profiler.\n + - 'cuda': Use CUDA profiler.""" + + torch_profiler_dir: str = "" + """Directory to save torch profiler traces. Both AsyncLLM's CPU traces and + worker's traces (CPU & GPU) will be saved under this directory. Note that + it must be an absolute path.""" + + torch_profiler_with_stack: bool = True + """If `True`, enables stack tracing in the torch profiler. Enabled by default.""" + + torch_profiler_with_flops: bool = False + """If `True`, enables FLOPS counting in the torch profiler. Disabled by default.""" + + torch_profiler_use_gzip: bool = True + """If `True`, saves torch profiler traces in gzip format. Enabled by default""" + + torch_profiler_dump_cuda_time_total: bool = True + """If `True`, dumps total CUDA time in torch profiler traces. Enabled by default.""" + + torch_profiler_record_shapes: bool = False + """If `True`, records tensor shapes in the torch profiler. Disabled by default.""" + + torch_profiler_with_memory: bool = False + """If `True`, enables memory profiling in the torch profiler. + Disabled by default.""" + + ignore_frontend: bool = False + """If `True`, disables the front-end profiling of AsyncLLM when using the + 'torch' profiler. This is needed to reduce overhead when using delay/limit options, + since the front-end profiling does not track iterations and will capture the + entire range. + """ + + delay_iterations: int = Field(default=0, ge=0) + """Number of engine iterations to skip before starting profiling. + Defaults to 0, meaning profiling starts immediately after receiving /start_profile. + """ + + max_iterations: int = Field(default=0, ge=0) + """Maximum number of engine iterations to profile after starting profiling. + Defaults to 0, meaning no limit. + """ + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() + return hash_str + + def _get_from_env_if_set(self, field_name: str, env_var_name: str) -> None: + """Get field from env var if set, with deprecation warning.""" + + if envs.is_set(env_var_name): + value = getattr(envs, env_var_name) + logger.warning_once( + "Using %s environment variable is deprecated and will be removed in " + "v0.14.0 or v1.0.0, whichever is soonest. Please use " + "--profiler-config.%s command line argument or " + "ProfilerConfig(%s=...) config field instead.", + env_var_name, + field_name, + field_name, + ) + return value + return None + + def _set_from_env_if_set( + self, + field_name: str, + env_var_name: str, + to_bool: bool = True, + to_int: bool = False, + ) -> None: + """Set field from env var if set, with deprecation warning.""" + value = self._get_from_env_if_set(field_name, env_var_name) + if value is not None: + if to_bool: + value = value == "1" + if to_int: + value = int(value) + setattr(self, field_name, value) + + @model_validator(mode="after") + def _validate_profiler_config(self) -> Self: + maybe_use_cuda_profiler = self._get_from_env_if_set( + "profiler", "VLLM_TORCH_CUDA_PROFILE" + ) + if maybe_use_cuda_profiler is not None: + self.profiler = "cuda" if maybe_use_cuda_profiler == "1" else None + else: + self._set_from_env_if_set( + "torch_profiler_dir", "VLLM_TORCH_PROFILER_DIR", to_bool=False + ) + if self.torch_profiler_dir: + self.profiler = "torch" + self._set_from_env_if_set( + "torch_profiler_record_shapes", + "VLLM_TORCH_PROFILER_RECORD_SHAPES", + ) + self._set_from_env_if_set( + "torch_profiler_with_memory", + "VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY", + ) + self._set_from_env_if_set( + "torch_profiler_with_stack", + "VLLM_TORCH_PROFILER_WITH_STACK", + ) + self._set_from_env_if_set( + "torch_profiler_with_flops", + "VLLM_TORCH_PROFILER_WITH_FLOPS", + ) + self._set_from_env_if_set( + "ignore_frontend", + "VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM", + ) + self._set_from_env_if_set( + "torch_profiler_use_gzip", + "VLLM_TORCH_PROFILER_USE_GZIP", + ) + self._set_from_env_if_set( + "torch_profiler_dump_cuda_time_total", + "VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL", + ) + + self._set_from_env_if_set( + "delay_iterations", "VLLM_PROFILER_DELAY_ITERS", to_bool=False, to_int=True + ) + self._set_from_env_if_set( + "max_iterations", "VLLM_PROFILER_MAX_ITERS", to_bool=False, to_int=True + ) + + has_delay_or_limit = self.delay_iterations > 0 or self.max_iterations > 0 + if self.profiler == "torch" and has_delay_or_limit and not self.ignore_frontend: + logger.warning_once( + "Using 'torch' profiler with delay_iterations or max_iterations " + "while ignore_frontend is False may result in high overhead." + ) + + profiler_dir = self.torch_profiler_dir + if profiler_dir and self.profiler != "torch": + raise ValueError( + "torch_profiler_dir is only applicable when profiler is set to 'torch'" + ) + if self.profiler == "torch" and not profiler_dir: + raise ValueError("torch_profiler_dir must be set when profiler is 'torch'") + + if profiler_dir: + is_gs_path = ( + profiler_dir.startswith("gs://") + and profiler_dir[5:] + and profiler_dir[5] != "/" + ) + if not is_gs_path: + self.torch_profiler_dir = os.path.abspath( + os.path.expanduser(profiler_dir) + ) + + return self diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index ff1ac0e18f324..8abbe8ba0103e 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -28,6 +28,19 @@ SchedulerPolicy = Literal["fcfs", "priority"] class SchedulerConfig: """Scheduler configuration.""" + max_model_len: InitVar[int] + """Maximum length of a sequence (including prompt and generated text). + + Note: This is stored in the ModelConfig, and is used only here to + provide fallbacks and validate other attributes.""" + + is_encoder_decoder: InitVar[bool] + """True if the model is an encoder-decoder model. + + Note: This is stored in the ModelConfig, and is used only here to + disable chunked prefill and prefix caching for encoder-decoder models. + """ + DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048 DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128 @@ -73,19 +86,6 @@ class SchedulerConfig: is_multimodal_model: bool = False """True if the model is multimodal.""" - max_model_len: InitVar[int] = 8192 - """Maximum length of a sequence (including prompt and generated text). - - Note: This is stored in the ModelConfig, and is used only here to - provide fallbacks and validate other attributes.""" - - is_encoder_decoder: InitVar[bool] = False - """True if the model is an encoder-decoder model. - - Note: This is stored in the ModelConfig, and is used only here to - disable chunked prefill and prefix caching for encoder-decoder models. - """ - # TODO (ywang96): Make this configurable. max_num_encoder_input_tokens: int = Field(init=False) """Multimodal encoder compute budget, only used in V1. @@ -122,10 +122,12 @@ class SchedulerConfig: the default scheduler. Can be a class directly or the path to a class of form "mod.custom_class".""" - disable_hybrid_kv_cache_manager: bool = False + disable_hybrid_kv_cache_manager: bool | None = None """If set to True, KV cache manager will allocate the same size of KV cache for all attention layers even if there are multiple type of attention layers like full attention and sliding window attention. + If set to None, the default value will be determined based on the environment + and starting configuration. """ async_scheduling: bool = False @@ -141,6 +143,17 @@ class SchedulerConfig: while a larger value (e.g., 10) reduces host overhead and may increase throughput by batching multiple tokens before sending.""" + @staticmethod + def default_factory(**kwargs): + """ + Factory method to create `SchedulerConfig` with default values for `InitVar`s. + """ + if "max_model_len" not in kwargs: + kwargs["max_model_len"] = 8192 + if "is_encoder_decoder" not in kwargs: + kwargs["is_encoder_decoder"] = False + return SchedulerConfig(**kwargs) + def get_scheduler_cls(self) -> type["SchedulerInterface"]: if self.scheduler_cls is None: if self.async_scheduling: @@ -175,9 +188,19 @@ class SchedulerConfig: excluding anything before input ids/embeddings and after the final hidden states. """ - # no factors to consider. - # this config will not affect the computation graph. factors: list[Any] = [] + + # max_num_batched_tokens need to be included in the hash due + # to two reasons: + # 1. LoRA creates static buffers based on max_num_batched_tokens. + # The tensor sizes and strides get captured in the torch.compile + # graph explicitly. + # 2. Inductor decides whether using 32-bit or 64-bit indexing integer + # based on the data sizes. `max_num_batched_tokens` has an + # impact on that. For more details, please check + # https://github.com/vllm-project/vllm/issues/29585 + factors.append(self.max_num_batched_tokens) + hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 80d53a543f149..bf533bf14e55c 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -167,6 +167,7 @@ class SpeculativeConfig: @staticmethod def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: + initial_architecture = hf_config.architectures[0] if hf_config.model_type in ("deepseek_v3", "deepseek_v32"): hf_config.model_type = "deepseek_mtp" if hf_config.model_type == "deepseek_mtp": @@ -226,6 +227,9 @@ class SpeculativeConfig: {"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]} ) + if initial_architecture == "MistralLarge3ForCausalLM": + hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]}) + return hf_config def __post_init__(self): @@ -333,6 +337,7 @@ class SpeculativeConfig: enforce_eager=self.target_model_config.enforce_eager, max_logprobs=self.target_model_config.max_logprobs, hf_overrides=SpeculativeConfig.hf_config_override, + config_format=self.target_model_config.config_format, ) # Automatically detect the method diff --git a/vllm/config/structured_outputs.py b/vllm/config/structured_outputs.py index 1b32675c3dbd2..8c060c816fd15 100644 --- a/vllm/config/structured_outputs.py +++ b/vllm/config/structured_outputs.py @@ -28,8 +28,10 @@ class StructuredOutputsConfig: disable_fallback: bool = False """If `True`, vLLM will not fallback to a different backend on error.""" disable_any_whitespace: bool = False - """If `True`, the model will not generate any whitespace during structured - outputs. This is only supported for xgrammar and guidance backends.""" + """If `True`, json output will always be compact without any whitespace. + If `False`, the model may generate whitespace between JSON fields, + which is still valid JSON. This is only supported for xgrammar + and guidance backends.""" disable_additional_properties: bool = False """If `True`, the `guidance` backend will not use `additionalProperties` in the JSON schema. This is only supported for the `guidance` backend and @@ -63,22 +65,6 @@ class StructuredOutputsConfig: @model_validator(mode="after") def _validate_structured_output_config(self) -> Self: - # Import here to avoid circular import - from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager - - if self.reasoning_parser_plugin and len(self.reasoning_parser_plugin) > 3: - ReasoningParserManager.import_reasoning_parser(self.reasoning_parser_plugin) - - valid_reasoning_parsers = ReasoningParserManager.list_registered() - if ( - self.reasoning_parser != "" - and self.reasoning_parser not in valid_reasoning_parsers - ): - raise ValueError( - f"invalid reasoning parser: {self.reasoning_parser} " - f"(chose from {{ {','.join(valid_reasoning_parsers)} }})" - ) - if self.disable_any_whitespace and self.backend not in ("xgrammar", "guidance"): raise ValueError( "disable_any_whitespace is only supported for " diff --git a/vllm/config/utils.py b/vllm/config/utils.py index 02f2b75f608f1..470296517deb1 100644 --- a/vllm/config/utils.py +++ b/vllm/config/utils.py @@ -10,7 +10,7 @@ import json import pathlib import textwrap from collections.abc import Iterable, Mapping, Sequence, Set -from dataclasses import MISSING, Field, field, fields, is_dataclass, replace +from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace from itertools import pairwise from typing import TYPE_CHECKING, Any, Protocol, TypeVar @@ -19,6 +19,10 @@ import torch from pydantic.fields import FieldInfo from typing_extensions import runtime_checkable +from vllm.logger import init_logger + +logger = init_logger(__name__) + if TYPE_CHECKING: from _typeshed import DataclassInstance else: @@ -69,14 +73,28 @@ def get_field(cls: ConfigType, name: str) -> Field: ) -def getattr_iter(object: object, names: Iterable[str], default: Any) -> Any: +def getattr_iter( + object: object, names: Iterable[str], default: Any, warn: bool = False +) -> Any: """ A helper function that retrieves an attribute from an object which may have multiple possible names. This is useful when fetching attributes from arbitrary `transformers.PretrainedConfig` instances. + + In the case where the first name in `names` is the preferred name, and + any other names are deprecated aliases, setting `warn=True` will log a + warning when a deprecated name is used. """ - for name in names: + for i, name in enumerate(names): if hasattr(object, name): + if warn and i > 0: + logger.warning_once( + "%s contains a deprecated attribute name '%s'. " + "Please use the preferred attribute name '%s' instead.", + type(object).__name__, + name, + names[0], + ) return getattr(object, name) return default @@ -293,3 +311,60 @@ def get_hash_factors(config: ConfigT, ignored_factors: set[str]) -> dict[str, ob def hash_factors(items: dict[str, object]) -> str: """Return a SHA-256 hex digest of the canonical items structure.""" return hashlib.sha256(json.dumps(items, sort_keys=True).encode()).hexdigest() + + +def handle_deprecated( + config: ConfigT, + old_name: str, + new_name_or_names: str | list[str], + removal_version: str, +) -> None: + old_val = getattr(config, old_name) + if old_val is None: + return + + if isinstance(new_name_or_names, str): + new_names = [new_name_or_names] + else: + new_names = new_name_or_names + + msg = ( + f"{old_name} is deprecated and will be removed in {removal_version}. " + f"Use {', '.join(new_names)} instead." + ) + logger.warning(msg) + + for new_name in new_names: + setattr(config, new_name, old_val) + + +@dataclass +class Range: + """ + A range of numbers. + Inclusive of start, inclusive of end. + """ + + start: int + end: int + + def is_single_size(self) -> bool: + return self.start == self.end + + def __contains__(self, size: int) -> bool: + # Inclusive of start, inclusive of end + return self.start <= size <= self.end + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Range): + return False + return self.start == other.start and self.end == other.end + + def __hash__(self) -> int: + return hash((self.start, self.end)) + + def __str__(self) -> str: + return f"({self.start}, {self.end})" + + def __repr__(self) -> str: + return self.__str__() diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 34e70e3e134be..0439dc52e7e6f 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -27,6 +27,7 @@ from vllm.transformers_utils.runai_utils import is_runai_obj_uri from vllm.utils import random_uuid from vllm.utils.hashing import safe_hash +from .attention import AttentionConfig from .cache import CacheConfig from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode from .device import DeviceConfig @@ -38,6 +39,7 @@ from .lora import LoRAConfig from .model import ModelConfig from .observability import ObservabilityConfig from .parallel import ParallelConfig +from .profiler import ProfilerConfig from .scheduler import SchedulerConfig from .speculative import SpeculativeConfig from .structured_outputs import StructuredOutputsConfig @@ -65,7 +67,7 @@ class OptimizationLevel(IntEnum): """O0 : No optimization. no compilation, no cudagraphs, no other optimization, just starting up immediately""" O1 = 1 - """O1: Quick optimizations. Dynamo+Inductor compilation and Piecewise + """O1: Quick optimizations. Dynamo+Inductor compilation and Piecewise cudagraphs""" O2 = 2 """O2: Full optimizations. -O1 as well as Full and Piecewise cudagraphs.""" @@ -83,22 +85,33 @@ IS_DENSE = False # See https://github.com/vllm-project/vllm/issues/25689. -def enable_fusion(cfg: "VllmConfig") -> bool: - """Returns True if RMS norm or quant FP8 is enabled.""" +def enable_norm_fusion(cfg: "VllmConfig") -> bool: + """Enable if either RMS norm or quant FP8 custom op is active; + otherwise Inductor handles fusion.""" + return cfg.compilation_config.is_custom_op_enabled( "rms_norm" ) or cfg.compilation_config.is_custom_op_enabled("quant_fp8") +def enable_act_fusion(cfg: "VllmConfig") -> bool: + """Enable if either SiLU+Mul or quant FP8 custom op is active; + otherwise Inductor handles fusion.""" + return cfg.compilation_config.is_custom_op_enabled( + "silu_and_mul" + ) or cfg.compilation_config.is_custom_op_enabled("quant_fp8") + + OPTIMIZATION_LEVEL_00 = { "compilation_config": { "pass_config": { - "enable_noop": False, - "enable_fusion": False, - "enable_fi_allreduce_fusion": False, - "enable_attn_fusion": False, - "enable_sequence_parallelism": False, - "enable_async_tp": False, + "eliminate_noops": False, + "fuse_norm_quant": False, + "fuse_act_quant": False, + "fuse_allreduce_rms": False, + "fuse_attn_quant": False, + "enable_sp": False, + "fuse_gemm_comms": False, }, "cudagraph_mode": CUDAGraphMode.NONE, "use_inductor_graph_partition": False, @@ -107,12 +120,13 @@ OPTIMIZATION_LEVEL_00 = { OPTIMIZATION_LEVEL_01 = { "compilation_config": { "pass_config": { - "enable_noop": True, - "enable_fusion": enable_fusion, - "enable_fi_allreduce_fusion": False, - "enable_attn_fusion": False, - "enable_sequence_parallelism": False, - "enable_async_tp": False, + "eliminate_noops": True, + "fuse_norm_quant": enable_norm_fusion, + "fuse_act_quant": enable_act_fusion, + "fuse_allreduce_rms": False, + "fuse_attn_quant": False, + "enable_sp": False, + "fuse_gemm_comms": False, }, "cudagraph_mode": CUDAGraphMode.PIECEWISE, "use_inductor_graph_partition": False, @@ -121,12 +135,13 @@ OPTIMIZATION_LEVEL_01 = { OPTIMIZATION_LEVEL_02 = { "compilation_config": { "pass_config": { - "enable_noop": True, - "enable_fusion": enable_fusion, - "enable_fi_allreduce_fusion": False, - "enable_attn_fusion": IS_QUANTIZED, - "enable_sequence_parallelism": IS_DENSE, - "enable_async_tp": IS_DENSE, + "eliminate_noops": True, + "fuse_norm_quant": enable_norm_fusion, + "fuse_act_quant": enable_act_fusion, + "fuse_allreduce_rms": False, + "fuse_attn_quant": IS_QUANTIZED, + "enable_sp": IS_DENSE, + "fuse_gemm_comms": IS_DENSE, }, "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE, "use_inductor_graph_partition": False, @@ -135,12 +150,13 @@ OPTIMIZATION_LEVEL_02 = { OPTIMIZATION_LEVEL_03 = { "compilation_config": { "pass_config": { - "enable_noop": True, - "enable_fusion": enable_fusion, - "enable_fi_allreduce_fusion": False, - "enable_attn_fusion": IS_QUANTIZED, - "enable_sequence_parallelism": IS_DENSE, - "enable_async_tp": IS_DENSE, + "eliminate_noops": True, + "fuse_norm_quant": enable_norm_fusion, + "fuse_act_quant": enable_act_fusion, + "fuse_allreduce_rms": False, + "fuse_attn_quant": IS_QUANTIZED, + "enable_sp": IS_DENSE, + "fuse_gemm_comms": IS_DENSE, }, "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE, "use_inductor_graph_partition": False, @@ -170,12 +186,16 @@ class VllmConfig: """Cache configuration.""" parallel_config: ParallelConfig = Field(default_factory=ParallelConfig) """Parallel configuration.""" - scheduler_config: SchedulerConfig = Field(default_factory=SchedulerConfig) + scheduler_config: SchedulerConfig = Field( + default_factory=SchedulerConfig.default_factory, + ) """Scheduler configuration.""" device_config: DeviceConfig = Field(default_factory=DeviceConfig) """Device configuration.""" load_config: LoadConfig = Field(default_factory=LoadConfig) """Load configuration.""" + attention_config: AttentionConfig = Field(default_factory=AttentionConfig) + """Attention configuration.""" lora_config: LoRAConfig | None = None """LoRA configuration.""" speculative_config: SpeculativeConfig | None = None @@ -199,6 +219,8 @@ class VllmConfig: You can specify the full compilation config like so: `{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}` """ + profiler_config: ProfilerConfig = Field(default_factory=ProfilerConfig) + """Profiling configuration.""" kv_transfer_config: KVTransferConfig | None = None """The configurations for distributed KV cache transfer.""" kv_events_config: KVEventsConfig | None = None @@ -263,12 +285,12 @@ class VllmConfig: vllm_factors.append(self.load_config.compute_hash()) else: vllm_factors.append("None") + if self.attention_config: + vllm_factors.append(self.attention_config.compute_hash()) + else: + vllm_factors.append("None") if self.lora_config: vllm_factors.append(self.lora_config.compute_hash()) - # LoRA creates static buffers based on max_num_batched_tokens. - # The tensor sizes and strides get captured in the torch.compile - # graph explicitly. - vllm_factors.append(str(self.scheduler_config.max_num_batched_tokens)) else: vllm_factors.append("None") if self.speculative_config: @@ -277,6 +299,8 @@ class VllmConfig: vllm_factors.append("None") if self.structured_outputs_config: vllm_factors.append(self.structured_outputs_config.compute_hash()) + if self.profiler_config: + vllm_factors.append(self.profiler_config.compute_hash()) else: vllm_factors.append("None") vllm_factors.append(self.observability_config.compute_hash()) @@ -567,6 +591,15 @@ class VllmConfig: else: self.scheduler_config.async_scheduling = True + if ( + self.scheduler_config.async_scheduling + and not self.parallel_config.disable_nccl_for_dp_synchronization + ): + logger.info( + "Disabling NCCL for DP synchronization when using async scheduling." + ) + self.parallel_config.disable_nccl_for_dp_synchronization = True + from vllm.platforms import current_platform if ( @@ -633,8 +666,9 @@ class VllmConfig: default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level] self._apply_optimization_level_defaults(default_config) + if ( - self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + self.compilation_config.cudagraph_mode.requires_piecewise_compilation() and self.compilation_config.mode != CompilationMode.VLLM_COMPILE ): logger.info( @@ -647,9 +681,9 @@ class VllmConfig: # async tp is built on top of sequence parallelism # and requires it to be enabled. - if self.compilation_config.pass_config.enable_async_tp: - self.compilation_config.pass_config.enable_sequence_parallelism = True - if self.compilation_config.pass_config.enable_sequence_parallelism: + if self.compilation_config.pass_config.fuse_gemm_comms: + self.compilation_config.pass_config.enable_sp = True + if self.compilation_config.pass_config.enable_sp: if "-rms_norm" in self.compilation_config.custom_ops: logger.warning( "RMS norm force disabled, sequence parallelism might break" @@ -659,36 +693,29 @@ class VllmConfig: if current_platform.support_static_graph_mode(): # if cudagraph_mode has full cudagraphs, we need to check support - if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): - # decode context parallel does not support full cudagraphs - if self.parallel_config.decode_context_parallel_size > 1: + if model_config := self.model_config: + if ( + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + and model_config.pooler_config is not None + ): logger.warning_once( - "Decode context parallel (DCP) is enabled, which is " - "incompatible with full CUDA graphs. " + "Pooling models do not support full cudagraphs. " "Overriding cudagraph_mode to PIECEWISE." ) self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE - # prefill context parallel do not support full cudagraphs - elif self.parallel_config.prefill_context_parallel_size > 1: - logger.warning_once( - "Prefill context parallel (PCP) is enabled, which is " - "incompatible with full CUDA graphs. " - "Overriding cudagraph_mode to PIECEWISE." + elif ( + model_config.is_encoder_decoder + and self.compilation_config.cudagraph_mode + not in (CUDAGraphMode.NONE, CUDAGraphMode.FULL_DECODE_ONLY) + ): + logger.info_once( + "Encoder-decoder models do not support %s. " + "Overriding cudagraph_mode to FULL_DECODE_ONLY.", + self.compilation_config.cudagraph_mode.name, + ) + self.compilation_config.cudagraph_mode = ( + CUDAGraphMode.FULL_DECODE_ONLY ) - self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE - elif self.model_config is not None: - if self.model_config.pooler_config is not None: - logger.warning_once( - "Pooling models do not support full cudagraphs. " - "Overriding cudagraph_mode to PIECEWISE." - ) - self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE - elif self.model_config.is_encoder_decoder: - logger.warning_once( - "Encoder-decoder models do not support full cudagraphs. " - "Overriding cudagraph_mode to PIECEWISE." - ) - self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE # disable cudagraph when enforce eager execution if self.model_config is not None and self.model_config.enforce_eager: @@ -720,28 +747,20 @@ class VllmConfig: "--kv-sharing-fast-prefill requires changes on model side for " "correctness and to realize prefill savings. " ) + # TODO: Move after https://github.com/vllm-project/vllm/pull/26847 lands + self._set_compile_ranges() - if self.model_config and self.model_config.is_encoder_decoder: - from vllm.multimodal import MULTIMODAL_REGISTRY - - self.scheduler_config.max_num_encoder_input_tokens = ( - MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config) + if ( + self.model_config + and self.model_config.architecture == "WhisperForConditionalGeneration" + and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn" + ): + logger.warning( + "Whisper is known to have issues with " + "forked workers. If startup is hanging, " + "try setting 'VLLM_WORKER_MULTIPROC_METHOD' " + "to 'spawn'." ) - logger.debug( - "Encoder-decoder model detected: setting " - "`max_num_encoder_input_tokens` to encoder length (%s)", - self.scheduler_config.max_num_encoder_input_tokens, - ) - if ( - self.model_config.architecture == "WhisperForConditionalGeneration" - and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn" - ): - logger.warning( - "Whisper is known to have issues with " - "forked workers. If startup is hanging, " - "try setting 'VLLM_WORKER_MULTIPROC_METHOD' " - "to 'spawn'." - ) if ( self.kv_events_config is not None @@ -791,21 +810,25 @@ class VllmConfig: f"({self.parallel_config.cp_kv_cache_interleave_size})." ) - assert ( - self.parallel_config.cp_kv_cache_interleave_size == 1 - or self.speculative_config is None - ), "MTP with cp_kv_cache_interleave_size > 1 is not supported now." - # Do this after all the updates to compilation_config.mode - if self.compilation_config.mode == CompilationMode.VLLM_COMPILE: - self.compilation_config.set_splitting_ops_for_v1() + self.compilation_config.set_splitting_ops_for_v1( + all2all_backend=self.parallel_config.all2all_backend, + data_parallel_size=self.parallel_config.data_parallel_size, + ) - if self.compilation_config.pass_config.enable_sequence_parallelism: + if self.compilation_config.pass_config.enable_sp: # With pipeline parallelism or dynamo partitioning, # native rms norm tracing errors due to incorrect residual shape. # Use custom rms norm to unblock. In the future, # the pass will operate on higher-level IR to avoid the issue. # TODO: https://github.com/vllm-project/vllm/issues/27894 + if self.compilation_config.mode != CompilationMode.VLLM_COMPILE: + logger.warning( + "Sequence parallelism is enabled, but running in wrong " + "vllm compile mode: %s.", + self.compilation_config.mode, + ) + is_fullgraph = ( self.compilation_config.use_inductor_graph_partition or len(self.compilation_config.splitting_ops) == 0 @@ -847,9 +870,12 @@ class VllmConfig: f"cudagraph_mode={self.compilation_config.cudagraph_mode}" ) - if self.parallel_config.enable_dbo: + if self.parallel_config.use_ubatching: a2a_backend = self.parallel_config.all2all_backend - assert a2a_backend in ["deepep_low_latency", "deepep_high_throughput"], ( + assert a2a_backend in [ + "deepep_low_latency", + "deepep_high_throughput", + ], ( "Microbatching currently only supports the deepep_low_latency and " f"deepep_high_throughput all2all backend. {a2a_backend} is not " "supported. To fix use --all2all-backend=deepep_low_latency or " @@ -864,17 +890,48 @@ class VllmConfig: if not self.instance_id: self.instance_id = random_uuid()[:5] - if not self.scheduler_config.disable_hybrid_kv_cache_manager: - # logger should only print warning message for hybrid models. As we - # can't know whether the model is hybrid or not now, so we don't log - # warning message here and will log it later. - if not current_platform.support_hybrid_kv_cache(): - # Hybrid KV cache manager is not supported on non-GPU platforms. - self.scheduler_config.disable_hybrid_kv_cache_manager = True + # Hybrid KV cache manager (HMA) runtime rules: + # - Explicit enable (--no-disable-kv-cache-manager): error if runtime + # disables it + # - No preference: auto-disable for unsupported features (e.g. kv connector) + # - Explicit disable (--disable-kv-cache-manager): always respect it + need_disable_hybrid_kv_cache_manager = False + # logger should only print warning message for hybrid models. As we + # can't know whether the model is hybrid or not now, so we don't log + # warning message here and will log it later. + if not current_platform.support_hybrid_kv_cache(): + # Hybrid KV cache manager is not supported on non-GPU platforms. + need_disable_hybrid_kv_cache_manager = True + if self.kv_events_config is not None: + # Hybrid KV cache manager is not compatible with KV events. + need_disable_hybrid_kv_cache_manager = True + if ( + self.model_config is not None + and self.model_config.attention_chunk_size is not None + ): + if ( + self.speculative_config is not None + and self.speculative_config.use_eagle() + ): + # Hybrid KV cache manager is not yet supported with chunked + # local attention + eagle. + need_disable_hybrid_kv_cache_manager = True + elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: + logger.warning( + "There is a latency regression when using chunked local" + " attention with the hybrid KV cache manager. Disabling" + " it, by default. To enable it, set the environment " + "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1." + ) + # Hybrid KV cache manager is not yet supported with chunked + # local attention. + need_disable_hybrid_kv_cache_manager = True + + if self.scheduler_config.disable_hybrid_kv_cache_manager is None: + # Default to disable HMA, but only if the user didn't express a preference. if self.kv_transfer_config is not None: - # NOTE(Kuntai): turn HMA off for connector for now. - # TODO(Kuntai): have a more elegent solution to check and - # turn off HMA for connector that does not support HMA. + # NOTE(Kuntai): turn HMA off for connector unless specifically enabled. + need_disable_hybrid_kv_cache_manager = True logger.warning( "Turning off hybrid kv cache manager because " "`--kv-transfer-config` is set. This will reduce the " @@ -882,33 +939,26 @@ class VllmConfig: "or Mamba attention. If you are a developer of kv connector" ", please consider supporting hybrid kv cache manager for " "your connector by making sure your connector is a subclass" - " of `SupportsHMA` defined in kv_connector/v1/base.py." + " of `SupportsHMA` defined in kv_connector/v1/base.py and" + " use --no-disable-hybrid-kv-cache-manager to start vLLM." ) - self.scheduler_config.disable_hybrid_kv_cache_manager = True - if self.kv_events_config is not None: - # Hybrid KV cache manager is not compatible with KV events. - self.scheduler_config.disable_hybrid_kv_cache_manager = True - if ( - self.model_config is not None - and self.model_config.attention_chunk_size is not None - ): - if ( - self.speculative_config is not None - and self.speculative_config.use_eagle() - ): - # Hybrid KV cache manager is not yet supported with chunked - # local attention + eagle. - self.scheduler_config.disable_hybrid_kv_cache_manager = True - elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: - logger.warning( - "There is a latency regression when using chunked local" - " attention with the hybrid KV cache manager. Disabling" - " it, by default. To enable it, set the environment " - "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1." - ) - # Hybrid KV cache manager is not yet supported with chunked - # local attention. - self.scheduler_config.disable_hybrid_kv_cache_manager = True + self.scheduler_config.disable_hybrid_kv_cache_manager = ( + need_disable_hybrid_kv_cache_manager + ) + elif ( + self.scheduler_config.disable_hybrid_kv_cache_manager is False + and need_disable_hybrid_kv_cache_manager + ): + raise ValueError( + "Hybrid KV cache manager was explicitly enabled but is not " + "supported in this configuration. Consider omitting the " + "--no-disable-hybrid-kv-cache-manager flag to let vLLM decide" + " automatically." + ) + + if self.scheduler_config.disable_hybrid_kv_cache_manager is None: + # Default to enable HMA if not explicitly disabled by user or logic above. + self.scheduler_config.disable_hybrid_kv_cache_manager = False if self.compilation_config.debug_dump_path: self.compilation_config.debug_dump_path = ( @@ -976,7 +1026,7 @@ class VllmConfig: max_graph_size = min(max_num_seqs * 2, 512) # 1, 2, 4, then multiples of 8 up to 256 and then multiples of 16 # up to max_graph_size - cuda_graph_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list( + cudagraph_capture_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list( range(256, max_graph_size + 1, 16)) In the end, `vllm_config.compilation_config.cudagraph_capture_sizes` @@ -1017,8 +1067,14 @@ class VllmConfig: self.compilation_config.max_cudagraph_capture_size ) if max_cudagraph_capture_size is None: + decode_query_len = 1 + if ( + self.speculative_config + and self.speculative_config.num_speculative_tokens + ): + decode_query_len += self.speculative_config.num_speculative_tokens max_cudagraph_capture_size = min( - self.scheduler_config.max_num_seqs * 2, 512 + self.scheduler_config.max_num_seqs * decode_query_len * 2, 512 ) max_num_tokens = self.scheduler_config.max_num_batched_tokens max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size) @@ -1058,7 +1114,7 @@ class VllmConfig: if ( self.parallel_config.tensor_parallel_size > 1 - and self.compilation_config.pass_config.enable_sequence_parallelism + and self.compilation_config.pass_config.enable_sp ): cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism( cudagraph_capture_sizes @@ -1115,6 +1171,52 @@ class VllmConfig: # complete the remaining process. self.compilation_config.post_init_cudagraph_sizes() + def _set_compile_ranges(self): + """ + Set the compile ranges for the compilation config. + """ + compilation_config = self.compilation_config + computed_compile_ranges_split_points = [] + + # The upper bound of the compile ranges is the max_num_batched_tokens + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + if max_num_batched_tokens is not None: + computed_compile_ranges_split_points.append(max_num_batched_tokens) + + # Add the compile ranges for flashinfer + if compilation_config.pass_config.fuse_allreduce_rms: + tp_size = self.parallel_config.tensor_parallel_size + max_size = compilation_config.pass_config.flashinfer_max_size(tp_size) + if max_size is not None: + max_token_num = max_size // ( + self.model_config.get_hidden_size() + * self.model_config.dtype.itemsize + ) + if ( + max_num_batched_tokens is not None + and max_token_num < max_num_batched_tokens + ): + computed_compile_ranges_split_points.append(max_token_num) + else: + logger.debug( + "Max num batched tokens below allreduce-rms fusion threshold, " + "allreduce-rms fusion will be enabled for all num_tokens." + ) + + if compilation_config.compile_ranges_split_points is not None: + for x in compilation_config.compile_ranges_split_points: + assert isinstance(x, int) + assert x > 0, f"Invalid compile range split point: {x}" + if ( + max_num_batched_tokens is not None + and x < max_num_batched_tokens + and x > 1 + ): + computed_compile_ranges_split_points.append(x) + compilation_config.compile_ranges_split_points = sorted( + computed_compile_ranges_split_points + ) + def recalculate_max_model_len(self, max_model_len: int): # Can only be called in try_verify_and_update_config model_config = self.model_config diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 2e878eef908ac..cd9c267beb5b5 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -225,7 +225,7 @@ class CudaCommunicator(DeviceCommunicatorBase): output_shape, dtype=input_tensor.dtype, device=input_tensor.device ) - if sizes is not None: + if sizes is not None and sizes.count(sizes[0]) != len(sizes): pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes) else: pynccl_comm.reduce_scatter(output, input_tensor) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 052df19e34d72..31c6084c9b507 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools import pickle +import threading import time from contextlib import contextmanager from dataclasses import dataclass, field @@ -27,6 +28,7 @@ from zmq import ( # type: ignore import vllm.envs as envs from vllm.distributed.utils import StatelessProcessGroup, sched_yield from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.utils.network_utils import ( get_ip, get_open_port, @@ -42,6 +44,33 @@ VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL from_bytes_big = functools.partial(int.from_bytes, byteorder="big") +# Memory fence for cross-process shared memory visibility. +# Required for correct producer-consumer synchronization when using +# shared memory without locks. +_memory_fence_lock = threading.Lock() + + +def memory_fence(): + """ + Full memory barrier for shared memory synchronization. + + Ensures all prior memory writes are visible to other processes before + any subsequent reads. This is critical for lock-free producer-consumer + patterns using shared memory. + + Implementation acquires and immediately releases a lock. Python's + threading.Lock provides sequentially consistent memory barrier semantics + across all major platforms (POSIX, Windows). This is a lightweight + operation (~20ns) that guarantees: + - All stores before the barrier are visible to other threads/processes + - All loads after the barrier see the latest values + """ + # Lock acquire/release provides full memory barrier semantics. + # Using context manager ensures lock release even on exceptions. + with _memory_fence_lock: + pass + + def to_bytes_big(value: int, size: int) -> bytes: return value.to_bytes(size, byteorder="big") @@ -413,6 +442,10 @@ class MessageQueue: n_warning = 1 while True: with self.buffer.get_metadata(self.current_idx) as metadata_buffer: + # Memory fence ensures we see the latest read flags from readers. + # Without this, we may read stale flags from our CPU cache and + # spin indefinitely even though readers have completed. + memory_fence() read_count = sum(metadata_buffer[1:]) written_flag = metadata_buffer[0] if written_flag and read_count != self.buffer.n_reader: @@ -457,6 +490,10 @@ class MessageQueue: metadata_buffer[i] = 0 # mark the block as written metadata_buffer[0] = 1 + # Memory fence ensures the write is visible to readers on other cores + # before we proceed. Without this, readers may spin indefinitely + # waiting for a write that's stuck in our CPU's store buffer. + memory_fence() self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks break @@ -472,6 +509,10 @@ class MessageQueue: n_warning = 1 while True: with self.buffer.get_metadata(self.current_idx) as metadata_buffer: + # Memory fence ensures we see the latest writes from the writer. + # Without this, we may read stale flags from our CPU cache + # and spin indefinitely even though writer has updated them. + memory_fence() read_flag = metadata_buffer[self.local_reader_rank + 1] written_flag = metadata_buffer[0] if not written_flag or read_flag: @@ -512,6 +553,10 @@ class MessageQueue: # caller has read from the buffer # set the read flag metadata_buffer[self.local_reader_rank + 1] = 1 + # Memory fence ensures the read flag is visible to the writer. + # Without this, writer may not see our read completion and + # could wait indefinitely for all readers to finish. + memory_fence() self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks self._read_spin_timer.record_activity() @@ -632,7 +677,7 @@ class MessageQueue: The MessageQueue instance for the calling process, and a list of handles (only non-empty for the reader process). """ - local_size = torch.cuda.device_count() + local_size = current_platform.device_count() rank = dist.get_rank() same_node = rank // local_size == reader_rank // local_size buffer_io = MessageQueue( diff --git a/vllm/distributed/device_communicators/shm_object_storage.py b/vllm/distributed/device_communicators/shm_object_storage.py index 4af2caa16b0d6..5da261fbc6cfc 100644 --- a/vllm/distributed/device_communicators/shm_object_storage.py +++ b/vllm/distributed/device_communicators/shm_object_storage.py @@ -574,7 +574,6 @@ class SingleWriterShmObjectStorage: value ) buffer_size = self.flag_bytes + data_bytes + md_bytes - # Sanity checks if buffer_size > self.max_object_size: raise ValueError( @@ -626,6 +625,44 @@ class SingleWriterShmObjectStorage: return obj + def touch( + self, + key: str, + address: int = 0, + monotonic_id: int = 0, + ) -> None: + """ + Touch an existing cached item to update its eviction status. + + For writers (ShmObjectStoreSenderCache): Increment writer_flag + For readers (ShmObjectStoreReceiverCache): Increment reader_count + + Args: + key: String key of the object to touch + address: Address of the object (only for readers) + monotonic_id: Monotonic ID of the object (only for readers) + + """ + if self._reader_lock is None: + if key not in self.key_index: + return None + address, monotonic_id = self.key_index[key] + # Writer side: increment writer_flag to raise eviction threshold + self.increment_writer_flag(monotonic_id) + else: + with ( + self._reader_lock, + self.ring_buffer.access_buf(address) as (data_view, _), + ): + reader_count = self.ring_buffer.byte2int(data_view[: self.flag_bytes]) + + # NOTE(Long): + # Avoid increasing flag on newly added item (sync with sender) + # Since when a new item is added + # pre-touch has no effect on writer side + if reader_count >= self.n_readers: + self.increment_reader_flag(data_view[: self.flag_bytes]) + def handle(self): """Get handle for sharing across processes.""" return ShmObjectStorageHandle( diff --git a/vllm/distributed/ec_transfer/ec_connector/shared_storage_connector.py b/vllm/distributed/ec_transfer/ec_connector/example_connector.py similarity index 96% rename from vllm/distributed/ec_transfer/ec_connector/shared_storage_connector.py rename to vllm/distributed/ec_transfer/ec_connector/example_connector.py index c8388141dcc97..5f2eff5a8e6a8 100644 --- a/vllm/distributed/ec_transfer/ec_connector/shared_storage_connector.py +++ b/vllm/distributed/ec_transfer/ec_connector/example_connector.py @@ -32,7 +32,7 @@ class MMMeta: @dataclass -class ECSharedStorageConnectorMetadata(ECConnectorMetadata): +class ECExampleConnectorMetadata(ECConnectorMetadata): mm_datas: list[MMMeta] def __init__(self): @@ -42,7 +42,7 @@ class ECSharedStorageConnectorMetadata(ECConnectorMetadata): self.mm_datas.append(mm_data) -class ECSharedStorageConnector(ECConnectorBase): +class ECExampleConnector(ECConnectorBase): # NOTE: This is Simple debug implementation of the EC connector. # It save / load the EC cache to / from the disk. @@ -76,7 +76,7 @@ class ECSharedStorageConnector(ECConnectorBase): # Get the metadata metadata: ECConnectorMetadata = self._get_connector_metadata() - assert isinstance(metadata, ECSharedStorageConnectorMetadata) + assert isinstance(metadata, ECExampleConnectorMetadata) assert encoder_cache is not None if metadata is None: logger.warning( @@ -160,7 +160,7 @@ class ECSharedStorageConnector(ECConnectorBase): Args: scheduler_output (SchedulerOutput): the scheduler output object. """ - meta = ECSharedStorageConnectorMetadata() + meta = ECExampleConnectorMetadata() for mm_hash, num_encoder_token in self._mm_datas_need_loads.items(): meta.add_mm_data(MMMeta.make_meta(mm_hash, num_encoder_token)) self._mm_datas_need_loads.clear() diff --git a/vllm/distributed/ec_transfer/ec_connector/factory.py b/vllm/distributed/ec_transfer/ec_connector/factory.py index e51b32e6f6dff..32f36ffbb14d2 100644 --- a/vllm/distributed/ec_transfer/ec_connector/factory.py +++ b/vllm/distributed/ec_transfer/ec_connector/factory.py @@ -79,7 +79,7 @@ class ECConnectorFactory: # only load the files corresponding to the current connector. ECConnectorFactory.register_connector( - "ECSharedStorageConnector", - "vllm.distributed.ec_transfer.ec_connector.shared_storage_connector", - "ECSharedStorageConnector", + "ECExampleConnector", + "vllm.distributed.ec_transfer.ec_connector.example_connector", + "ECExampleConnector", ) diff --git a/vllm/distributed/eplb/__init__.py b/vllm/distributed/eplb/__init__.py index 4cd51dd384ad2..12e6cd417c50d 100644 --- a/vllm/distributed/eplb/__init__.py +++ b/vllm/distributed/eplb/__init__.py @@ -1,8 +1,3 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Expert parallelism load balancer (EPLB). -""" - -from .eplb_state import * -from .rebalance_algo import * +"""Expert parallelism load balancer (EPLB).""" diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 9f8798a96a2fc..c5654659b79d6 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -45,7 +45,7 @@ from vllm.logger import init_logger from vllm.model_executor.models.interfaces import MixtureOfExperts from .async_worker import start_async_worker -from .rebalance_algo import rebalance_experts +from .policy import EPLB_POLICIES, AbstractEplbPolicy, DefaultEplbPolicy from .rebalance_execute import move_from_buffer, rearrange_expert_weights_inplace logger = init_logger(__name__) @@ -213,18 +213,23 @@ class EplbState: self.parallel_config = parallel_config self.device = device self.model_states: dict[str, EplbModelState] = {} + self.policy: type[AbstractEplbPolicy] = DefaultEplbPolicy + """ + Selected EPLB algorithm class + """ + self.expert_load_window_step: int = 0 """ Current step in the sliding window. Different from `expert_rearrangement_step`, each EP rank may have its own `expert_load_window_step`. """ - self.expert_load_window_step: int = 0 + self.expert_load_window_size: int = 0 """ Size of the expert load sliding window. This is a constant and is taken from the config. """ - self.expert_load_window_size: int = 0 + self.expert_rearrangement_step: int = 0 """ Steps after last rearrangement. Will trigger a rearrangement if it exceeds the threshold. @@ -415,6 +420,10 @@ class EplbState: ) self.expert_rearrangement_step_interval = eplb_step_interval + # Set the policy based on the selected eplb algorithm type. + policy_type = self.parallel_config.eplb_config.policy + self.policy = EPLB_POLICIES[policy_type] + logger.debug("Selected EPLB policy: %d", policy_type) if global_expert_load is not None: ep_group = get_ep_group().device_group assert global_expert_load.shape == ( @@ -441,7 +450,7 @@ class EplbState: new_physical_to_logical_map, new_logical_to_physical_map, new_logical_replica_count, - ) = rebalance_experts( + ) = self.policy.rebalance_experts( global_expert_load, num_replicas, num_groups, @@ -776,6 +785,7 @@ class EplbState: f"{num_gpus=}, {num_nodes=}" ) + # Get new expert mappings for eplb_model_state, global_expert_load_window in zip( self.model_states.values(), global_expert_load_windows ): @@ -784,7 +794,7 @@ class EplbState: new_physical_to_logical_map, new_logical_to_physical_map, new_logical_replica_count, - ) = rebalance_experts( + ) = self.policy.rebalance_experts( global_expert_load_window, num_replicas, num_groups, diff --git a/vllm/distributed/eplb/policy/__init__.py b/vllm/distributed/eplb/policy/__init__.py new file mode 100644 index 0000000000000..8e78d7bac0e35 --- /dev/null +++ b/vllm/distributed/eplb/policy/__init__.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import get_args + +from vllm.config.parallel import EPLBPolicyOption + +from .abstract import AbstractEplbPolicy +from .default import DefaultEplbPolicy + +EPLB_POLICIES = {"default": DefaultEplbPolicy} + +# Ensure that the EPLB_POLICIES keys match the EPLBPolicyOption values +assert set(EPLB_POLICIES.keys()) == set(get_args(EPLBPolicyOption)) + +__all__ = [ + "AbstractEplbPolicy", + "DefaultEplbPolicy", + "EPLB_POLICIES", +] diff --git a/vllm/distributed/eplb/policy/abstract.py b/vllm/distributed/eplb/policy/abstract.py new file mode 100644 index 0000000000000..40ed621c84892 --- /dev/null +++ b/vllm/distributed/eplb/policy/abstract.py @@ -0,0 +1,40 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod + +import torch + + +class AbstractEplbPolicy(ABC): + @classmethod + @abstractmethod + def rebalance_experts( + cls, + weight: torch.Tensor, + num_replicas: int, + num_groups: int, + num_nodes: int, + num_ranks: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Entry point for expert-parallelism load balancer. + + Parameters: + weight: [layers, num_logical_experts], the load statistics + for all logical experts + num_replicas: number of physical experts, must be a multiple of + `num_ranks` + num_groups: number of expert groups + num_nodes: number of server nodes + num_ranks: number of ranks, must be a multiple of `num_nodes` + + Returns: + physical_to_logical_map: [layers, num_replicas], the expert + index of each replica + logical_to_physical_map: [layers, num_logical_experts, X], + the replica indices for each expert + expert_count: [layers, num_logical_experts], number of + physical replicas for each logical expert + """ + raise NotImplementedError diff --git a/vllm/distributed/eplb/policy/default.py b/vllm/distributed/eplb/policy/default.py new file mode 100644 index 0000000000000..6127ec703184a --- /dev/null +++ b/vllm/distributed/eplb/policy/default.py @@ -0,0 +1,267 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Expert parallelism load balancer (EPLB) for vLLM. + +This module implements the core rearrangement algorithm. + +The rearrangement algorithm is adapted from +[DeepSeek EPLB](https://github.com/deepseek-ai/eplb). + +Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example +on how the EPLB algorithm works. +""" + +import numpy as np +import torch + +from .abstract import AbstractEplbPolicy + + +class DefaultEplbPolicy(AbstractEplbPolicy): + @classmethod + def balanced_packing( + cls, weight: torch.Tensor, num_packs: int + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Pack n weighted objects to m packs, such that each bin contains exactly + n/m objects and the weights of all packs are as balanced as possible. + + Parameters: + weight: [X, n], the weight of each item + num_packs: number of packs + + Returns: + pack_index: [X, n], the pack index of each item + rank_in_pack: [X, n], the rank of the item in the pack + """ + num_layers, num_groups = weight.shape + assert num_groups % num_packs == 0 + groups_per_pack = num_groups // num_packs + + device = weight.device + + if groups_per_pack == 1: + pack_index = torch.arange( + weight.size(-1), dtype=torch.int64, device=device + ).expand(weight.shape) + rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device) + return pack_index, rank_in_pack + + weight_np = weight.cpu().numpy() + + # Sort and get indices in decending order + indices_np = np.argsort(-weight_np, axis=-1) + + pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64) + rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64) + + # Run the packing algorithm + for i in range(num_layers): + pack_weights = [0.0] * num_packs + pack_items = [0] * num_packs + + for group in indices_np[i]: + # Find a pack with capacity that has the lowest weight + pack = min( + (j for j in range(num_packs) if pack_items[j] < groups_per_pack), + key=pack_weights.__getitem__, + ) + + assert pack_items[pack] < groups_per_pack + pack_index_np[i, group] = pack + rank_in_pack_np[i, group] = pack_items[pack] + pack_weights[pack] += weight_np[i, group] + pack_items[pack] += 1 + + pack_index = torch.from_numpy(pack_index_np).to(device) + rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device) + + return pack_index, rank_in_pack + + @classmethod + def replicate_experts( + cls, weight: torch.Tensor, num_phy: int + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Replicate `num_log` experts to `num_phy` replicas, such that the maximum + load of all replicas is minimized. + + Parameters: + weight: [X, num_log] + num_phy: total number of experts after replication + + Returns: + phy2log: [X, num_phy], logical expert id of each physical expert + rank: [X, num_phy], the replica rank + logcnt: [X, num_log], number of replicas for each logical expert + """ + n, num_log = weight.shape + num_redundant = num_phy - num_log + assert num_redundant >= 0 + device = weight.device + phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1) + rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device) + logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device) + arangen = torch.arange(n, dtype=torch.int64, device=device) + for i in range(num_log, num_phy): + redundant_indices = (weight / logcnt).max(dim=-1).indices + phy2log[:, i] = redundant_indices + rank[:, i] = logcnt[arangen, redundant_indices] + logcnt[arangen, redundant_indices] += 1 + return phy2log, rank, logcnt + + @classmethod + def rebalance_experts_hierarchical( + cls, + weight: torch.Tensor, + num_physical_experts: int, + num_groups: int, + num_nodes: int, + num_gpus: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters: + weight: [num_moe_layers, num_logical_experts] + num_physical_experts: number of physical experts after replication + num_groups: number of expert groups + num_nodes: number of server nodes, where the intra-node network + (e.g, NVLink) is faster + num_gpus: number of GPUs, must be a multiple of `num_nodes` + + Returns: + phy2log: [layers, num_replicas], the expert + index of each replica + log2phy: [layers, num_logical_experts, X], + the replica indices for each expert + logcnt: [layers, num_logical_experts], number of + physical replicas for each logical expert + """ + num_layers, num_logical_experts = weight.shape + assert num_logical_experts % num_groups == 0 + group_size = num_logical_experts // num_groups + assert num_groups % num_nodes == 0 + groups_per_node = num_groups // num_nodes + assert num_gpus % num_nodes == 0 + assert num_physical_experts % num_gpus == 0 + phy_experts_per_gpu = num_physical_experts // num_gpus + + def inverse(perm: torch.Tensor) -> torch.Tensor: + inv = torch.empty_like(perm) + inv.scatter_( + 1, + perm, + torch.arange( + perm.size(1), dtype=torch.int64, device=perm.device + ).expand(perm.shape), + ) + return inv + + # Step 1: pack groups to nodes + tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1) + group_pack_index, group_rank_in_pack = cls.balanced_packing( + tokens_per_group, num_nodes + ) + log2mlog = ( + ( + (group_pack_index * groups_per_node + group_rank_in_pack) * group_size + ).unsqueeze(-1) + + torch.arange( + group_size, dtype=torch.int64, device=group_pack_index.device + ) + ).flatten(-2) + mlog2log = inverse(log2mlog) + + # Step 2: construct redundant experts within nodes + # [num_layers * num_nodes, num_logical_experts // num_nodes] + tokens_per_mlog = weight.gather(-1, mlog2log).view( + -1, num_logical_experts // num_nodes + ) + phy2mlog, phyrank, mlogcnt = cls.replicate_experts( + tokens_per_mlog, num_physical_experts // num_nodes + ) + + # Step 3: pack physical_experts to GPUs + # [num_layers * num_nodes, num_physical_experts // num_nodes] + tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog) + pack_index, rank_in_pack = cls.balanced_packing( + tokens_per_phy, num_gpus // num_nodes + ) + phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack + pphy2phy = inverse(phy2pphy) + + pphy2mlog = phy2mlog.gather( + -1, pphy2phy + ) # [num_layers * num_nodes, num_log_per_nodes] + pphy2mlog = ( + pphy2mlog.view(num_layers, num_nodes, -1) + + torch.arange( + 0, + num_logical_experts, + num_logical_experts // num_nodes, + device=group_pack_index.device, + ).view(1, -1, 1) + ).flatten(-2) + pphy2log = mlog2log.gather(-1, pphy2mlog) + pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) + logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) + return pphy2log, pphyrank, logcnt + + @classmethod + def rebalance_experts( + cls, + weight: torch.Tensor, + num_replicas: int, + num_groups: int, + num_nodes: int, + num_ranks: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Entry point for expert-parallelism load balancer. + + Parameters: + weight: [layers, num_logical_experts], the load statistics for all + logical experts + num_replicas: number of physical experts, must be a multiple of + `num_gpus` + num_groups: number of expert groups + num_nodes: number of server nodes, where the intra-node network + (e.g, NVLink) is faster + num_ranks: number of ranks, must be a multiple of `num_nodes` + + Returns: + phy2log: [layers, num_replicas], the expert + index of each replica + log2phy: [layers, num_logical_experts, X], + the replica indices for each expert + logcnt: [layers, num_logical_experts], number of + physical replicas for each logical expert + """ + num_layers, num_logical_experts = weight.shape + weight = weight.float() + if num_groups % num_nodes == 0: + # use hierarchical load-balance policy + phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical( + weight, num_replicas, num_groups, num_nodes, num_ranks + ) + else: + # use global load-balance policy + phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical( + weight, num_replicas, 1, 1, num_ranks + ) + num_redundant_experts = num_replicas - num_logical_experts + maxlogcnt = num_redundant_experts + 1 + log2phy: torch.Tensor = torch.full( + (num_layers, num_logical_experts, maxlogcnt), + -1, + dtype=torch.int64, + device=logcnt.device, + ) + log2phy.view(num_layers, -1).scatter_( + -1, + phy2log * maxlogcnt + phyrank, + torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand( + num_layers, -1 + ), + ) + return phy2log, log2phy, logcnt diff --git a/vllm/distributed/eplb/rebalance_algo.py b/vllm/distributed/eplb/rebalance_algo.py deleted file mode 100644 index e6645e524cc3e..0000000000000 --- a/vllm/distributed/eplb/rebalance_algo.py +++ /dev/null @@ -1,260 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Expert parallelism load balancer (EPLB) for vLLM. - -This module implements the core rearrangement algorithm. - -The rearrangement algorithm is adapted from -[DeepSeek EPLB](https://github.com/deepseek-ai/eplb). - -Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example -on how the EPLB algorithm works. -""" - -import numpy as np -import torch - - -def balanced_packing( - weight: torch.Tensor, num_packs: int -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Pack n weighted objects to m packs, such that each bin contains exactly - n/m objects and the weights of all packs are as balanced as possible. - - Parameters: - weight: [X, n], the weight of each item - num_packs: number of packs - - Returns: - pack_index: [X, n], the pack index of each item - rank_in_pack: [X, n], the rank of the item in the pack - """ - num_layers, num_groups = weight.shape - assert num_groups % num_packs == 0 - groups_per_pack = num_groups // num_packs - - device = weight.device - - if groups_per_pack == 1: - pack_index = torch.arange( - weight.size(-1), dtype=torch.int64, device=device - ).expand(weight.shape) - rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device) - return pack_index, rank_in_pack - - weight_np = weight.cpu().numpy() - - # Sort and get indices in decending order - indices_np = np.argsort(-weight_np, axis=-1) - - pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64) - rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64) - - # Run the packing algorithm - for i in range(num_layers): - pack_weights = [0.0] * num_packs - pack_items = [0] * num_packs - - for group in indices_np[i]: - # Find a pack with capacity that has the lowest weight - pack = min( - (j for j in range(num_packs) if pack_items[j] < groups_per_pack), - key=pack_weights.__getitem__, - ) - - assert pack_items[pack] < groups_per_pack - pack_index_np[i, group] = pack - rank_in_pack_np[i, group] = pack_items[pack] - pack_weights[pack] += weight_np[i, group] - pack_items[pack] += 1 - - pack_index = torch.from_numpy(pack_index_np).to(device) - rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device) - - return pack_index, rank_in_pack - - -def replicate_experts( - weight: torch.Tensor, num_phy: int -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Replicate `num_log` experts to `num_phy` replicas, such that the maximum - load of all replicas is minimized. - - Parameters: - weight: [X, num_log] - num_phy: total number of experts after replication - - Returns: - phy2log: [X, num_phy], logical expert id of each physical expert - rank: [X, num_phy], the replica rank - logcnt: [X, num_log], number of replicas for each logical expert - """ - n, num_log = weight.shape - num_redundant = num_phy - num_log - assert num_redundant >= 0 - device = weight.device - phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1) - rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device) - logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device) - arangen = torch.arange(n, dtype=torch.int64, device=device) - for i in range(num_log, num_phy): - redundant_indices = (weight / logcnt).max(dim=-1).indices - phy2log[:, i] = redundant_indices - rank[:, i] = logcnt[arangen, redundant_indices] - logcnt[arangen, redundant_indices] += 1 - return phy2log, rank, logcnt - - -def rebalance_experts_hierarchical( - weight: torch.Tensor, - num_physical_experts: int, - num_groups: int, - num_nodes: int, - num_gpus: int, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Parameters: - weight: [num_moe_layers, num_logical_experts] - num_physical_experts: number of physical experts after replication - num_groups: number of expert groups - num_nodes: number of server nodes, where the intra-node network - (e.g., NVLink) is faster - num_gpus: number of GPUs, must be a multiple of `num_nodes` - - Returns: - physical_to_logical_map (torch.Tensor): - [num_moe_layers, num_physical_experts] - logical_to_physical_map (torch.Tensor): - [num_moe_layers, num_logical_experts, X] - logical_count (torch.Tensor): - [num_moe_layers, num_logical_experts] - """ - num_layers, num_logical_experts = weight.shape - assert num_logical_experts % num_groups == 0 - group_size = num_logical_experts // num_groups - assert num_groups % num_nodes == 0 - groups_per_node = num_groups // num_nodes - assert num_gpus % num_nodes == 0 - assert num_physical_experts % num_gpus == 0 - phy_experts_per_gpu = num_physical_experts // num_gpus - - def inverse(perm: torch.Tensor) -> torch.Tensor: - inv = torch.empty_like(perm) - inv.scatter_( - 1, - perm, - torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand( - perm.shape - ), - ) - return inv - - # Step 1: pack groups to nodes - tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1) - group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes) - log2mlog = ( - ( - (group_pack_index * groups_per_node + group_rank_in_pack) * group_size - ).unsqueeze(-1) - + torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device) - ).flatten(-2) - mlog2log = inverse(log2mlog) - - # Step 2: construct redundant experts within nodes - # [num_layers * num_nodes, num_logical_experts // num_nodes] - tokens_per_mlog = weight.gather(-1, mlog2log).view( - -1, num_logical_experts // num_nodes - ) - phy2mlog, phyrank, mlogcnt = replicate_experts( - tokens_per_mlog, num_physical_experts // num_nodes - ) - - # Step 3: pack physical_experts to GPUs - # [num_layers * num_nodes, num_physical_experts // num_nodes] - tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog) - pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes) - phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack - pphy2phy = inverse(phy2pphy) - - pphy2mlog = phy2mlog.gather( - -1, pphy2phy - ) # [num_layers * num_nodes, num_log_per_nodes] - pphy2mlog = ( - pphy2mlog.view(num_layers, num_nodes, -1) - + torch.arange( - 0, - num_logical_experts, - num_logical_experts // num_nodes, - device=group_pack_index.device, - ).view(1, -1, 1) - ).flatten(-2) - pphy2log = mlog2log.gather(-1, pphy2mlog) - pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) - logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) - return pphy2log, pphyrank, logcnt - - -def rebalance_experts( - weight: torch.Tensor, - num_replicas: int, - num_groups: int, - num_nodes: int, - num_gpus: int, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Entry point for expert-parallelism load balancer. - - Parameters: - weight: [layers, num_logical_experts], the load statistics for all - logical experts - num_replicas: number of physical experts, must be a multiple of - `num_gpus` - num_groups: number of expert groups - num_nodes: number of server nodes, where the intra-node network - (e.g, NVLink) is faster - num_gpus: number of GPUs, must be a multiple of `num_nodes` - - Returns: - physical_to_logical_map: - [layers, num_replicas], the expert index of each replica - logical_to_physical_map: - [layers, num_logical_experts, X], the replica indices for each - expert - expert_count: - [layers, num_logical_experts], number of physical - replicas for each logical expert - """ - num_layers, num_logical_experts = weight.shape - weight = weight.float() - if num_groups % num_nodes == 0: - # use hierarchical load-balance policy - phy2log, phyrank, logcnt = rebalance_experts_hierarchical( - weight, num_replicas, num_groups, num_nodes, num_gpus - ) - else: - # use global load-balance policy - phy2log, phyrank, logcnt = rebalance_experts_hierarchical( - weight, num_replicas, 1, 1, num_gpus - ) - num_redundant_experts = num_replicas - num_logical_experts - maxlogcnt = num_redundant_experts + 1 - log2phy: torch.Tensor = torch.full( - (num_layers, num_logical_experts, maxlogcnt), - -1, - dtype=torch.int64, - device=logcnt.device, - ) - log2phy.view(num_layers, -1).scatter_( - -1, - phy2log * maxlogcnt + phyrank, - torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand( - num_layers, -1 - ), - ) - return phy2log, log2phy, logcnt - - -__all__ = ["rebalance_experts"] diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index 376dad8a72ef1..55856d940f001 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -322,9 +322,6 @@ async def transfer_layer( num_local_physical_experts = next(iter(expert_weights[0])).shape[0] assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts) assert num_physical_experts == ep_size * num_local_physical_experts - # A buffer to hold the expert weights in one layer during the exchange. - # NOTE: Currently we assume the same weights across different layers - # have the same shape. is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer( num_local_experts=num_local_physical_experts, diff --git a/vllm/distributed/kv_events.py b/vllm/distributed/kv_events.py index 0795989c11d0e..144f3ea2b028e 100644 --- a/vllm/distributed/kv_events.py +++ b/vllm/distributed/kv_events.py @@ -5,7 +5,7 @@ import queue import threading import time from abc import ABC, abstractmethod -from collections import deque +from collections import Counter, deque from collections.abc import Callable from dataclasses import asdict from itertools import count @@ -54,11 +54,26 @@ class BlockStored(KVCacheEvent): lora_name: str | None medium: str | None + def __hash__(self) -> int: + return hash( + ( + tuple(self.block_hashes), + self.parent_block_hash, + tuple(self.token_ids), + self.block_size, + self.lora_id, + self.medium, + ) + ) + class BlockRemoved(KVCacheEvent): block_hashes: list[ExternalBlockHash] medium: str | None + def __hash__(self) -> int: + return hash((tuple(self.block_hashes), self.medium)) + class AllBlocksCleared(KVCacheEvent): pass @@ -68,6 +83,119 @@ class KVEventBatch(EventBatch): events: list[BlockStored | BlockRemoved | AllBlocksCleared] +class KVEventAggregator: + """ + Aggregates KV events across multiple workers. + Tracks how many times each event appears and returns only those + that were emitted by all workers. + """ + + __slots__ = ("_event_counter", "_num_workers") + + def __init__(self, num_workers: int) -> None: + if num_workers <= 0: + raise ValueError("num_workers must be greater than zero.") + self._event_counter: Counter[KVCacheEvent] = Counter() + self._num_workers: int = num_workers + + def add_events(self, events: list[KVCacheEvent]) -> None: + """ + Add events from a worker batch. + + :param events: List of KVCacheEvent objects. + """ + if not isinstance(events, list): + raise TypeError("events must be a list of KVCacheEvent.") + self._event_counter.update(events) + + def get_common_events(self) -> list[KVCacheEvent]: + """ + Return events that appeared in all workers. + + :return: List of events present in all workers. + """ + return [ + event + for event, count in self._event_counter.items() + if count == self._num_workers + ] + + def get_all_events(self) -> list[KVCacheEvent]: + """ + Return all events for all workers. + + :return: List of events for all workers. + """ + return list(self._event_counter.elements()) + + def clear_events(self) -> None: + """ + Clear all tracked events. + """ + self._event_counter.clear() + + def increment_workers(self, count: int = 1) -> None: + """ + Increment the number of workers contributing events. + + :param count: Number to increment the workers by. + """ + if count <= 0: + raise ValueError("count must be positive.") + self._num_workers += count + + def reset_workers(self) -> None: + """ + Reset the number of workers to 1. + """ + self._num_workers = 1 + + def get_number_of_workers(self) -> int: + """ + Return the number of workers. + + :return: int number of workers. + """ + return self._num_workers + + def __repr__(self) -> str: + return ( + f"<KVEventAggregator workers={self._num_workers}, " + f"events={len(self._event_counter)}>" + ) + + +class KVConnectorKVEvents(ABC): + """ + Abstract base class for KV events. + Acts as a container for KV events from the connector. + """ + + @abstractmethod + def add_events(self, events: list[KVCacheEvent]) -> None: + raise NotImplementedError + + @abstractmethod + def aggregate(self) -> "KVConnectorKVEvents": + raise NotImplementedError + + @abstractmethod + def increment_workers(self, count: int = 1) -> None: + raise NotImplementedError + + @abstractmethod + def get_all_events(self) -> list[KVCacheEvent]: + raise NotImplementedError + + @abstractmethod + def get_number_of_workers(self) -> int: + raise NotImplementedError + + @abstractmethod + def clear_events(self) -> None: + raise NotImplementedError + + class EventPublisher(ABC): """Lightweight publisher for EventBatch batches with data parallelism support. diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index df871dd7cbe4f..02d9a1ec9599e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -144,9 +144,9 @@ class KVConnectorFactory: # only load the files corresponding to the current connector. KVConnectorFactory.register_connector( - "SharedStorageConnector", - "vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector", - "SharedStorageConnector", + "ExampleConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.example_connector", + "ExampleConnector", ) KVConnectorFactory.register_connector( @@ -190,3 +190,8 @@ KVConnectorFactory.register_connector( "vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector", "DecodeBenchConnector", ) +KVConnectorFactory.register_connector( + "MooncakeConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.mooncake_connector", + "MooncakeConnector", +) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index b8eb5ea3b4939..117d159e25e71 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -4,13 +4,14 @@ KV cache helper for store. """ +from dataclasses import dataclass from typing import TYPE_CHECKING, Literal import torch -import vllm.envs as envs -from vllm import _custom_ops as ops -from vllm.config import VllmConfig, get_current_vllm_config +from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.config import get_current_vllm_config from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.logger import init_logger from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput @@ -21,89 +22,6 @@ if TYPE_CHECKING: logger = init_logger(__name__) -class model_aware_kv_ops_helper: - def __init__(self, config: VllmConfig): - self.is_deepseek_mla = config.model_config.is_deepseek_mla - self.use_mla_opt = not envs.VLLM_MLA_DISABLE - self.tp_size = config.parallel_config.tensor_parallel_size - - def get_model_args(self, model_executable: torch.nn.Module): - model_config = model_executable.model.config - self.model_executable = model_executable - num_heads = int(model_config.num_key_value_heads / self.tp_size) - hidden_size = model_config.hidden_size - num_attention_heads = model_config.num_attention_heads - - # Deepseek's MLA (Multi-head Latent Attention) uses two different - # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0. - # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied, - # resulting in a kv_cache shape of [num_blks, blk_size, 1, - # kv_lora_rank + qk_rope_head_dim]. - # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading - # to a kv_cache shape of [2, num_blks, blk_size, - # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim]. - # For more details, see vllm/v1/attention/backends/mla/common.py. - if self.is_deepseek_mla and self.use_mla_opt: - head_size = model_config.kv_lora_rank + model_config.qk_rope_head_dim - num_heads = 1 - elif self.is_deepseek_mla and not self.use_mla_opt: - head_size = model_config.qk_nope_head_dim + model_config.qk_rope_head_dim - else: - head_size = getattr(model_config, "head_dim", None) - if head_size is None: - head_size = int(hidden_size // num_attention_heads) - - return num_heads, head_size - - def get_kv_from_cache(self, kv_cache, num_heads, head_size): - if self.is_deepseek_mla and self.use_mla_opt: - key_cache = kv_cache.reshape(-1, num_heads, head_size) - value_cache = kv_cache.reshape(-1, num_heads, head_size) - else: - key_cache = kv_cache[0].reshape(-1, num_heads, head_size) - value_cache = kv_cache[1].reshape(-1, num_heads, head_size) - return key_cache, value_cache - - def put_kv_to_cache( - self, - model_executable: torch.nn.Module, - keys, - values, - layer, - kv_cache, - slot_mapping, - start_pos, - end_pos, - ): - model_config = model_executable.model.config - - if self.is_deepseek_mla and self.use_mla_opt: - layer.self_attn.attn = layer.self_attn.mla_attn - k_c_normed_k_pe = keys.squeeze(1) - k_c_normed = k_c_normed_k_pe[:, : model_config.kv_lora_rank] - k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank :] - ops.concat_and_cache_mla( - k_c_normed.to(kv_cache.device), - k_pe.to(kv_cache.device), - kv_cache, - slot_mapping[start_pos:end_pos], - layer.self_attn.attn.kv_cache_dtype, - layer.self_attn.attn._k_scale, - ) - else: - key_cache, value_cache = kv_cache[0], kv_cache[1] - ops.reshape_and_cache_flash( - keys.to(key_cache.device), - values.to(value_cache.device), - key_cache, - value_cache, - slot_mapping[start_pos:end_pos], - layer.self_attn.attn.kv_cache_dtype, - layer.self_attn.attn._k_scale, - layer.self_attn.attn._v_scale, - ) - - def get_kv_connector_cache_layout(): # NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is # used for faster transfer. @@ -160,6 +78,7 @@ class KVOutputAggregator: finished_sending = set[str]() finished_recving = set[str]() aggregated_kv_connector_stats = None + combined_kv_cache_events = None invalid_block_ids = set[int]() for model_runner_output in outputs: assert model_runner_output is not None @@ -201,6 +120,19 @@ class KVOutputAggregator: aggregated_kv_connector_stats.aggregate(kv_connector_stats) ) + # Combine kv_cache_events from all workers. + if combined_kv_cache_events is None: + # Use the first worker's kv_cache events as start event list. + combined_kv_cache_events = kv_output.kv_cache_events + elif kv_cache_events := kv_output.kv_cache_events: + assert isinstance( + combined_kv_cache_events, + type(kv_cache_events), + ) + worker_kv_cache_events = kv_cache_events.get_all_events() + combined_kv_cache_events.add_events(worker_kv_cache_events) + combined_kv_cache_events.increment_workers(1) + invalid_block_ids |= kv_output.invalid_block_ids # select output of the worker specified by output_rank @@ -211,6 +143,7 @@ class KVOutputAggregator: finished_sending=finished_sending or None, finished_recving=finished_recving or None, kv_connector_stats=aggregated_kv_connector_stats or None, + kv_cache_events=combined_kv_cache_events or None, invalid_block_ids=invalid_block_ids, expected_finished_count=self._expected_finished_count, ) @@ -266,3 +199,124 @@ def copy_kv_blocks( src_tensor = src_kv_caches[layer_name] dst_tensor = dst_kv_caches[layer_name] copy_fn(src_tensor, dst_tensor, src_indices, dst_indices) + + +@dataclass +class TpKVTopology: + """ + Helper class for tensor parallel and KV topology information for + mapping between local and remote TP workers. + """ + + tp_rank: int + remote_tp_size: dict[str, int] + is_mla: bool + total_num_kv_heads: int + attn_backend: type[AttentionBackend] + engine_id: str + remote_block_size: dict[str, int] + + def __post_init__(self): + # Figure out whether the first dimension of the cache is K/V + # or num_blocks. This is used to register the memory regions correctly. + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + ) + # Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D], + # we just mock num_blocks to 1 for the dimension check below. + self._is_kv_layout_blocks_first = ( + len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1 + ) + + attn_backend = AttentionBackendEnum[self.attn_backend.get_name()] + self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS + + @property + def is_kv_layout_blocks_first(self) -> bool: + return self._is_kv_layout_blocks_first + + @property + def split_k_and_v(self) -> bool: + # Whether to register regions for K and V separately (when present). + return not (self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first) + + @property + def tp_size(self) -> int: + return self.remote_tp_size[self.engine_id] + + @property + def block_size(self) -> int: + return self.remote_block_size[self.engine_id] + + def tp_ratio( + self, + remote_tp_size: int, + ) -> int: + """ + Calculate the tensor parallel ratio between local and remote TP. + We can think of it as the number of local TP workers-per-remote TP + workers. Local workers will read from the same remote TP worker in + groups of size `tp_ratio`. + """ + assert self.tp_size % remote_tp_size == 0, ( + f"Local tensor parallel size {self.tp_size} is not divisible " + f"by remote tensor parallel size {remote_tp_size}." + ) + return self.tp_size // remote_tp_size + + def block_size_ratio( + self, + remote_block_size: int, + ) -> float: + """ + Calculate the block size ratio between local and remote TP. + """ + assert self.block_size % remote_block_size == 0, ( + f"Local block size {self.block_size} is not divisible " + f"by remote block size {remote_block_size} or vice versa." + ) + return self.block_size // remote_block_size + + def tp_ratio_from_engine_id( + self, + remote_engine_id: str, + ) -> int: + remote_tp_size = self.remote_tp_size[remote_engine_id] + return self.tp_ratio(remote_tp_size) + + def block_size_ratio_from_engine_id( + self, + remote_engine_id: str, + ) -> float: + remote_block_size = self.remote_block_size[remote_engine_id] + return self.block_size_ratio(remote_block_size) + + def is_kv_replicated(self, engine_id: str) -> bool: + """ + Whether the KV cache is replicated across TP workers due to the + number of TP workers being greater than the number of KV heads. + """ + tp_size = self.remote_tp_size[engine_id] + return tp_size // self.total_num_kv_heads >= 1 + + def replicates_kv_cache(self, remote_engine_id: str) -> bool: + # MLA is always replicated as the hidden dim can't be split. + return self.is_mla or self.is_kv_replicated(remote_engine_id) + + def get_target_remote_rank( + self, + remote_tp_size: int, + ) -> int: + """ + Get the remote TP rank (on P) that the current local TP rank + (on D) will read from. + """ + tp_ratio = self.tp_ratio(remote_tp_size) + return self.tp_rank // tp_ratio + + def get_target_remote_rank_from_engine_id( + self, + remote_engine_id: str, + ) -> int: + remote_tp_size = self.remote_tp_size[remote_engine_id] + return self.get_target_remote_rank(remote_tp_size) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index d37ec25675b72..c05e5485a835e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -49,7 +49,7 @@ from vllm.v1.outputs import KVConnectorOutput if TYPE_CHECKING: from vllm.config import VllmConfig - from vllm.distributed.kv_events import KVCacheEvent + from vllm.distributed.kv_events import KVCacheEvent, KVConnectorKVEvents from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( KVConnectorPromMetrics, KVConnectorStats, @@ -239,7 +239,7 @@ class KVConnectorBase_V1(ABC): return def register_cross_layers_kv_cache( - self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend] + self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"] ): """ Initialize with a single KV cache tensor used by all layers. @@ -379,6 +379,14 @@ class KVConnectorBase_V1(ABC): """ return None + def get_kv_connector_kv_cache_events(self) -> Optional["KVConnectorKVEvents"]: + """ + Get the KV connector kv cache events collected during the last interval. + This function should be called by the model runner every time after the + model execution and before cleanup. + """ + return None + def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None: """ Get the KVConnector handshake metadata for this connector. @@ -573,3 +581,17 @@ class KVConnectorBase_V1(ABC): expose connector transfer stats via Prometheus. """ return None + + def reset_cache(self) -> bool | None: + """ + Reset the connector's internal cache. + + Returns: + bool: True if the cache was successfully reset, False otherwise. + """ + logger.debug( + "Connector cache reset requested, but %s does not implement reset_cache().", + type(self).__name__, + ) + + return None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py similarity index 98% rename from vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py rename to vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py index ed641cfc43ddd..41243fc866b59 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py @@ -65,7 +65,7 @@ class ReqMeta: @dataclass -class SharedStorageConnectorMetadata(KVConnectorMetadata): +class ExampleConnectorMetadata(KVConnectorMetadata): requests: list[ReqMeta] = field(default_factory=list) def add_request( @@ -81,7 +81,7 @@ class SharedStorageConnectorMetadata(KVConnectorMetadata): ) -class SharedStorageConnector(KVConnectorBase_V1): +class ExampleConnector(KVConnectorBase_V1): # NOTE: This is Simple debug implementation of the KV connector. # It save / load the KV cache to / from the disk. # It does extra work which will overwrite the existing prefix-cache in GPU @@ -157,7 +157,7 @@ class SharedStorageConnector(KVConnectorBase_V1): # Get the metadata metadata: KVConnectorMetadata = self._get_connector_metadata() - assert isinstance(metadata, SharedStorageConnectorMetadata) + assert isinstance(metadata, ExampleConnectorMetadata) if metadata is None: logger.warning( @@ -241,7 +241,7 @@ class SharedStorageConnector(KVConnectorBase_V1): return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...] connector_metadata = self._get_connector_metadata() - assert isinstance(connector_metadata, SharedStorageConnectorMetadata) + assert isinstance(connector_metadata, ExampleConnectorMetadata) for request in connector_metadata.requests: if request.is_store: filename = self._generate_filename_debug( @@ -315,7 +315,7 @@ class SharedStorageConnector(KVConnectorBase_V1): Args: scheduler_output (SchedulerOutput): the scheduler output object. """ - meta = SharedStorageConnectorMetadata() + meta = ExampleConnectorMetadata() total_need_load = 0 for new_req in scheduler_output.scheduled_new_reqs: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index 30da424ddcca0..17d468fe6c305 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -1,14 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable from typing import TYPE_CHECKING, Any import torch -from lmcache.integration.vllm.vllm_v1_adapter import ( - LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl, -) from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import VllmConfig +from vllm.distributed.kv_events import ( + BlockStored, + KVCacheEvent, + KVConnectorKVEvents, + KVEventAggregator, +) from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorMetadata, @@ -16,6 +20,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( ) from vllm.logger import init_logger from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.outputs import KVConnectorOutput if TYPE_CHECKING: from vllm.forward_context import ForwardContext @@ -26,6 +31,44 @@ if TYPE_CHECKING: logger = init_logger(__name__) +class LMCacheKVEvents(KVConnectorKVEvents): + """ + Concrete implementation of KVConnectorKVEvents using KVEventAggregator. + """ + + def __init__(self, num_workers: int) -> None: + self._aggregator = KVEventAggregator(num_workers) + + def add_events(self, events: list[KVCacheEvent]) -> None: + self._aggregator.add_events(events) + + def aggregate(self) -> "LMCacheKVEvents": + """ + Aggregate KV events and retain only common events. + """ + common_events = self._aggregator.get_common_events() + self._aggregator.clear_events() + self._aggregator.add_events(common_events) + self._aggregator.reset_workers() + return self + + def increment_workers(self, count: int = 1) -> None: + self._aggregator.increment_workers(count) + + def get_all_events(self) -> list[KVCacheEvent]: + return self._aggregator.get_all_events() + + def get_number_of_workers(self) -> int: + return self._aggregator.get_number_of_workers() + + def clear_events(self) -> None: + self._aggregator.clear_events() + self._aggregator.reset_workers() + + def __repr__(self) -> str: + return f"<LMCacheKVEvents events={self.get_all_events()}>" + + class LMCacheConnectorV1(KVConnectorBase_V1): def __init__( self, @@ -50,10 +93,17 @@ class LMCacheConnectorV1(KVConnectorBase_V1): cls = _adapter.LMCacheConnectorV1Impl else: logger.info("Initializing latest dev LMCache connector") + # lazy import + from lmcache.integration.vllm.vllm_v1_adapter import ( + LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl, + ) + cls = LMCacheConnectorLatestImpl self._lmcache_engine = cls(vllm_config, role, self) + self._kv_cache_events: LMCacheKVEvents | None = None + # ============================== # Worker-side methods # ============================== @@ -151,6 +201,31 @@ class LMCacheConnectorV1(KVConnectorBase_V1): # Fallback for older versions that don't support this method return set() + def get_kv_connector_kv_cache_events(self) -> LMCacheKVEvents | None: + """ + Get the KV connector kv cache events collected during the last interval. + """ + + events = self._lmcache_engine.get_kv_events() # type: ignore [attr-defined] + if not events: + return None + + blocks: list[BlockStored] = [ + BlockStored( + block_hashes=e.block_hashes, + parent_block_hash=e.parent_block_hash, + token_ids=e.token_ids, + lora_id=e.lora_id, + block_size=e.block_size, + medium=e.medium, + ) + for e in events + ] + + lmcache_kv_events = LMCacheKVEvents(num_workers=1) + lmcache_kv_events.add_events(blocks) + return lmcache_kv_events + # ============================== # Scheduler-side methods # ============================== @@ -198,6 +273,28 @@ class LMCacheConnectorV1(KVConnectorBase_V1): """ return self._lmcache_engine.build_connector_meta(scheduler_output) + def update_connector_output(self, connector_output: KVConnectorOutput): + """ + Update KVConnector state from worker-side connectors output. + + Args: + connector_output (KVConnectorOutput): the worker-side + connectors output. + """ + # Get the KV events + kv_cache_events = connector_output.kv_cache_events + if not kv_cache_events or not isinstance(kv_cache_events, LMCacheKVEvents): + return + + if self._kv_cache_events is None: + self._kv_cache_events = kv_cache_events + else: + self._kv_cache_events.add_events(kv_cache_events.get_all_events()) + self._kv_cache_events.increment_workers( + kv_cache_events.get_number_of_workers() + ) + return + def request_finished( self, request: "Request", @@ -214,3 +311,17 @@ class LMCacheConnectorV1(KVConnectorBase_V1): returned by the engine. """ return self._lmcache_engine.request_finished(request, block_ids) + + def take_events(self) -> Iterable["KVCacheEvent"]: + """ + Take the KV cache events from the connector. + + Yields: + New KV cache events since the last call. + """ + if self._kv_cache_events is not None: + self._kv_cache_events.aggregate() + kv_cache_events = self._kv_cache_events.get_all_events() + yield from kv_cache_events + self._kv_cache_events.clear_events() + self._kv_cache_events = None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py index 15ac5b049fce9..09af128f3ed74 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py @@ -27,7 +27,14 @@ from lmcache.v1.lookup_client.lmcache_async_lookup_client import ( LMCacheAsyncLookupServer, ) from lmcache.v1.offload_server.zmq_server import ZMQOffloadServer -from lmcache.v1.plugin.plugin_launcher import PluginLauncher + +try: + from lmcache.v1.plugin.runtime_plugin_launcher import RuntimePluginLauncher +except ImportError: + # Backwards compatibility for lmcache <= 0.3.10-post1 + from lmcache.v1.plugin.plugin_launcher import ( + PluginLauncher as RuntimePluginLauncher, + ) from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import VllmConfig @@ -683,7 +690,7 @@ class LMCacheConnectorV1Impl: self.api_server = InternalAPIServer(self) self.api_server.start() # Launch plugins - self.plugin_launcher = PluginLauncher( + self.plugin_launcher = RuntimePluginLauncher( self.config, role, self.worker_count, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py b/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py index eb8342eb7129f..28aad71ab48f2 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py @@ -7,7 +7,6 @@ from prometheus_client import Counter, Gauge, Histogram from vllm.config import KVTransferConfig, VllmConfig from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory -from vllm.distributed.kv_transfer.kv_transfer_state import has_kv_transfer_group from vllm.logger import init_logger PromMetric: TypeAlias = Gauge | Counter | Histogram @@ -53,8 +52,6 @@ class KVConnectorStats: class KVConnectorLogging: def __init__(self, kv_transfer_config: KVTransferConfig | None): - # This should be called on frontend process. - assert not has_kv_transfer_group() # Instantiate the connector's stats class. if kv_transfer_config and kv_transfer_config.kv_connector: self.connector_cls = KVConnectorFactory.get_connector_class( diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py new file mode 100644 index 0000000000000..705960aebe2da --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py @@ -0,0 +1,914 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +import threading +import time +import uuid +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import msgspec +import numpy as np +import torch +import zmq +import zmq.asyncio + +from vllm import envs +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.selector import get_attn_backend +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tp_group, +) +from vllm.forward_context import ForwardContext +from vllm.logger import init_logger +from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket +from vllm.v1.attention.backends.utils import get_kv_cache_layout +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.request import RequestStatus + +try: + from mooncake.engine import TransferEngine +except ImportError as e: + raise ImportError( + "Please install mooncake by following the instructions at " + "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 + "to run VLLM with MooncakeTransferEngine." + ) from e + +if TYPE_CHECKING: + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig + from vllm.v1.request import Request + +EngineId = str +ReqId = str + +TRANS_DONE = b"trans_done" +TRANS_ERROR = b"trans_error" + +logger = init_logger(__name__) + + +class MooncakeAgentMetadata( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True, +): + remote_hostname: str + remote_port: int + request_ids: list[ReqId] + kv_caches_base_addr: list[int] + block_ids: list[list[int]] + + +@dataclass +class RecvReqMeta: + local_block_ids: list[int] + remote_host: str + remote_port: int + + +@dataclass +class SendBlockMeta: + local_block_ids: list[int] + ready: threading.Event + expire_time: float = float("inf") + + +@dataclass +class SendReqMeta: + reqs: dict[ReqId, SendBlockMeta] + lock: threading.Lock + + +@dataclass +class FinishedSendReqSet: + set: set[ReqId] + lock: threading.Lock + + +@dataclass +class FinishedReceiveReqSet: + set: set[ReqId] + lock: asyncio.Lock + + +class MooncakeConnectorMetadata(KVConnectorMetadata): + def __init__(self): + self.reqs_to_recv: dict[ReqId, RecvReqMeta] = {} + self.reqs_to_send: dict[ReqId, list[int]] = {} + + def add_new_req( + self, + request_id: ReqId, + local_block_ids: list[int], + kv_transfer_params: dict[str, Any], + load_remote_cache: bool = True, + ): + if load_remote_cache: + self.reqs_to_recv[request_id] = RecvReqMeta( + local_block_ids=local_block_ids, + remote_host=kv_transfer_params["remote_host"], + remote_port=kv_transfer_params["remote_port"], + ) + else: + self.reqs_to_send[request_id] = local_block_ids + + +class MooncakeConnector(KVConnectorBase_V1): + def __init__( + self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super().__init__(vllm_config, role, kv_cache_config) + + assert vllm_config.kv_transfer_config is not None + assert vllm_config.kv_transfer_config.engine_id is not None + self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id + + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler: MooncakeConnectorScheduler | None = ( + MooncakeConnectorScheduler(vllm_config, self.engine_id) + ) + self.connector_worker: MooncakeConnectorWorker | None = None + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + self.connector_worker = MooncakeConnectorWorker(vllm_config, self.engine_id) + + ############################################################ + # Scheduler Side Methods + ############################################################ + + def get_num_new_matched_tokens( + self, request: "Request", num_computed_tokens: int + ) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens + ) + + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens + ) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, dict[str, Any] | None]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[set[str] | None, set[str] | None]: + """Get the finished recving and sending requests.""" + assert self.connector_worker is not None + return self.connector_worker.get_finished() + + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, MooncakeConnectorMetadata) + self.connector_worker.start_load_kv(self._connector_metadata) + + def wait_for_layer_load(self, layer_name: str) -> None: + """MooncakeConnector does not do layerwise saving.""" + pass + + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: AttentionMetadata, + **kwargs, + ) -> None: + """MooncakeConnector does not save explicitly.""" + pass + + def wait_for_save(self): + pass + + +class MooncakeConnectorScheduler: + """Implementation of Scheduler side methods""" + + def __init__(self, vllm_config: VllmConfig, engine_id: str): + self.vllm_config = vllm_config + self.engine_id: EngineId = engine_id + self.side_channel_host = get_ip() + self.side_channel_port = get_mooncake_side_channel_port(vllm_config) + + assert vllm_config.kv_transfer_config + self.kv_role = vllm_config.kv_transfer_config.kv_role + logger.info("Initializing Mooncake Transfer Engine Scheduler %s", engine_id) + + # Requests that need to start recv/send. + # New requests are added by update_state_after_alloc in + # the scheduler. Used to make metadata passed to Worker. + self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} + self._reqs_need_send: dict[ReqId, list[int]] = {} + + def get_num_new_matched_tokens( + self, request: "Request", num_computed_tokens: int + ) -> tuple[int, bool]: + """ + For remote prefill, pull all prompt blocks from remote + asynchronously relative to engine execution. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + Returns: + * the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + * true if the external KV cache tokens will be loaded + asynchronously (between scheduler steps). + """ + + params = request.kv_transfer_params + logger.debug( + "MooncakeConnector get_num_new_matched_tokens: " + "num_computed_tokens=%s, kv_transfer_params=%s", + num_computed_tokens, + params, + ) + + if params is not None and params.get("do_remote_prefill"): + # Remote prefill: get all prompt blocks from remote. + token_ids = request.prompt_token_ids or [] + count = len(token_ids) - num_computed_tokens + if count > 0: + return count, True + + # No remote prefill for this request. + return 0, False + + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): + params = request.kv_transfer_params + logger.debug( + "MooncakeConnector update_state_after_alloc: " + "num_external_tokens=%s, kv_transfer_params=%s", + num_external_tokens, + params, + ) + + if not params: + return + + if params.get("do_remote_prefill"): + assert self.kv_role != "kv_producer" + if all(p in params for p in ("remote_host", "remote_port")): + # If remote_blocks and num_external_tokens = 0, we have + # a full prefix cache hit on the D worker. We need to call + # send_notif in _read_blocks to free the memory on the P. + local_block_ids = ( + blocks.get_unhashed_block_ids() if num_external_tokens > 0 else [] + ) + # Get unhashed blocks to pull from remote. + self._reqs_need_recv[request.request_id] = (request, local_block_ids) + else: + logger.warning( + "Got invalid KVTransferParams: %s. This " + "request will not utilize KVTransfer", + params, + ) + # Only trigger 1 KV transfer per request. + params["do_remote_prefill"] = False + + elif params.get("do_remote_decode"): + # Add an empty list to worker to create event. + self._reqs_need_send[request.request_id] = [] + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + meta = MooncakeConnectorMetadata() + + # Loop through scheduled reqs and convert to RecvReqMeta. + if self.kv_role != "kv_producer": + for req_id, (req, block_ids) in self._reqs_need_recv.items(): + assert req.kv_transfer_params is not None + meta.add_new_req( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + ) + self._reqs_need_recv.clear() + + if self.kv_role != "kv_consumer": + for req_id, block_ids in self._reqs_need_send.items(): + meta.add_new_req( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params={}, + load_remote_cache=False, + ) + self._reqs_need_send.clear() + + return meta + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, dict[str, Any] | None]: + """ + Once a request is finished, determine whether request blocks + should be freed now or will be sent asynchronously and freed later. + """ + + params = request.kv_transfer_params + logger.debug( + "MooncakeConnector request_finished, request_status=%s, " + "kv_transfer_params=%s", + request.status, + params, + ) + if not params: + return False, None + + if params.get("do_remote_prefill"): + # If do_remote_prefill is still True when the request is finished, + # update_state_after_alloc must not have been called (the request + # must have been aborted before it was scheduled). + # To avoid stranding the prefill blocks in the prefill instance, + # we must add empty block_ids to _reqs_need_recv so that our + # worker side will notify and free blocks in the prefill instance. + assert self.kv_role != "kv_producer" + self._reqs_need_recv[request.request_id] = (request, []) + params["do_remote_prefill"] = False + return False, None + + if ( + not params.get("do_remote_decode") + or request.status != RequestStatus.FINISHED_LENGTH_CAPPED + ): + return False, None + + assert self.kv_role != "kv_consumer" + + # TODO: check whether block_ids actually ever be 0. If not we could + # remove the conditional below + delay_free_blocks = len(block_ids) > 0 + + if delay_free_blocks: + self._reqs_need_send[request.request_id] = block_ids + + return delay_free_blocks, dict( + do_remote_prefill=True, + do_remote_decode=False, + remote_host=self.side_channel_host, + remote_port=self.side_channel_port, + ) + + +class MooncakeConnectorWorker: + """Implementation of Worker side methods""" + + def __init__(self, vllm_config: VllmConfig, engine_id: str): + logger.info("Initializing Mooncake Transfer Engine worker %s", engine_id) + + self.vllm_config = vllm_config + + self.engine = TransferEngine() + self.hostname = get_ip() + ret_value = self.engine.initialize(self.hostname, "P2PHANDSHAKE", "rdma", "") + if ret_value != 0: + raise RuntimeError("Mooncake Transfer Engine initialization failed.") + + self.rpc_port = self.engine.get_rpc_port() + + logger.debug( + "Mooncake Transfer Engine initialized at %s:%d", + self.hostname, + self.rpc_port, + ) + + # Mooncake handshake port. + self.side_channel_port: int = get_mooncake_side_channel_port(vllm_config) + + self.engine_id: EngineId = engine_id + self.tp_rank = get_tensor_model_parallel_rank() + self.world_size = get_tensor_model_parallel_world_size() + self.tp_group = get_tp_group() + self.num_blocks = 0 + + assert vllm_config.kv_transfer_config + self.kv_role = vllm_config.kv_transfer_config.kv_role + self.num_workers = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "num_workers", 10 + ) + + self.kv_caches_base_addr: list[int] = [] + self.device_kv_caches: dict[str, torch.Tensor] = {} + self.reqs_need_send: SendReqMeta = SendReqMeta(reqs={}, lock=threading.Lock()) + + # For kv_both, we will act both prefiller and decoder. + if self.kv_role != "kv_consumer": + # Background thread for sending kvcaches to D. + self._mooncake_sender_t: threading.Thread | None = None + # Background thread for processing new sending requests. + self._sender_executor = ThreadPoolExecutor( + max_workers=self.num_workers, thread_name_prefix="vllm-mooncake-sender" + ) + logger.debug( + "Mooncake Prefiller: use %d workers to send kvcaches", self.num_workers + ) + if self.kv_role != "kv_producer": + self.receiver_loop = asyncio.new_event_loop() + self._mooncake_receiver_t = threading.Thread( + target=self._receiver_loop, args=(self.receiver_loop,), daemon=True + ) + self._mooncake_receiver_t.start() + logger.debug("Mooncake Decoder: start receiver thread") + + self.finished_sending_reqs: FinishedSendReqSet = FinishedSendReqSet( + set(), threading.Lock() + ) + self.finished_recving_reqs: FinishedReceiveReqSet = FinishedReceiveReqSet( + set(), asyncio.Lock() + ) + + self.block_size = vllm_config.cache_config.block_size + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.use_mla = self.model_config.use_mla + + backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.dtype, + self.cache_config.cache_dtype, + self.block_size, + use_mla=self.use_mla, + ) + self.backend_name = backend.get_name() + self.kv_cache_layout = get_kv_cache_layout() + logger.debug("Detected attention backend %s", self.backend_name) + logger.debug("Detected kv cache layout %s", self.kv_cache_layout) + + self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} + self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size} + self.kv_topo = TpKVTopology( + tp_rank=self.tp_rank, + engine_id=self.engine_id, + remote_tp_size=self._tp_size, # shared state + remote_block_size=self._block_size, # shared state + is_mla=self.use_mla, + total_num_kv_heads=self.model_config.get_total_num_kv_heads(), + attn_backend=backend, + ) + self._use_pallas = self.kv_topo._use_pallas + + self.zmq_ctx = zmq.Context() + self.async_zmq_ctx = zmq.asyncio.Context() + self._encoder = msgspec.msgpack.Encoder() + self._decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata) + + def __del__(self): + self.shutdown() + + def shutdown(self): + """Cleanup background threads on destruction.""" + self.zmq_ctx.term() + self.async_zmq_ctx.term() + if self.kv_role != "kv_consumer": + self._sender_executor.shutdown(wait=False) + if self._mooncake_sender_t: + self._mooncake_sender_t.join() + if self.kv_role != "kv_producer" and self.receiver_loop.is_running(): + self.receiver_loop.call_soon_threadsafe(self.receiver_loop.stop) + self._mooncake_receiver_t.join() + + def _receiver_loop(self, loop: asyncio.AbstractEventLoop): + asyncio.set_event_loop(loop) + loop.run_forever() + + def _mooncake_sender( + self, ready_event: threading.Event, base_port: int, tp_rank: int + ): + """ + Background thread that listens for Mooncake requests, dispatches them + to a thread pool, and sends acknowledgments upon completion. + """ + + frontend_path = make_zmq_path("tcp", self.hostname, base_port + tp_rank) + frontend = make_zmq_socket(self.zmq_ctx, frontend_path, zmq.ROUTER) + logger.debug("Mooncake sender starting listening on path: %s", frontend_path) + + backend_path = make_zmq_path("inproc", str(uuid.uuid4())) + backend = make_zmq_socket(self.zmq_ctx, backend_path, zmq.PULL) + + poller = zmq.Poller() + poller.register(frontend, zmq.POLLIN) + poller.register(backend, zmq.POLLIN) + + ready_event.set() + + try: + while True: + sockets = dict(poller.poll()) + + if frontend in sockets: + identity, _, metadata_bytes = frontend.recv_multipart() + self._sender_executor.submit( + self._sender_worker, + identity, + metadata_bytes, + backend_path, + ) + + if backend in sockets: + identity, status = backend.recv_multipart() + frontend.send_multipart((identity, b"", status)) + + except zmq.ContextTerminated: + logger.debug("ZMQ context terminated, exiting Mooncake sender thread.") + except Exception as e: + logger.error("Error in Mooncake sender thread: %s. Exiting thread.", str(e)) + finally: + frontend.close() + backend.close() + + def _sender_worker( + self, identity: bytes, metadata_bytes: bytes, worker_channel_path: str + ): + status = TRANS_ERROR + + try: + metadata = self._decoder.decode(metadata_bytes) + self.send_kv_to_decode(metadata) + status = TRANS_DONE + except Exception as e: + logger.error("Error processing Mooncake handshake: %s", e) + finally: + pusher = make_zmq_socket(self.zmq_ctx, worker_channel_path, zmq.PUSH) + try: + pusher.send_multipart((identity, status)) + except zmq.ZMQError as e: + logger.warning( + "Internal error, maybe the server is shutting down. Error: %s", + e, + ) + finally: + pusher.close() + + def send_kv_to_decode(self, meta: MooncakeAgentMetadata): + send_reqs: list[tuple[ReqId, SendBlockMeta]] = [] + with self.reqs_need_send.lock: + for req_id in meta.request_ids: + send_meta = self.reqs_need_send.reqs.get(req_id) + if send_meta is None: + logger.warning("Request %s not found in reqs_need_send", req_id) + return + # Mark it as not expired. We will send it now. + send_meta.expire_time = float("inf") + send_reqs.append((req_id, send_meta)) + + self._send_blocks(send_reqs, meta) + + with self.reqs_need_send.lock: + for req_id in meta.request_ids: + del self.reqs_need_send.reqs[req_id] + + with self.finished_sending_reqs.lock: + self.finished_sending_reqs.set.update(meta.request_ids) + + def _send_blocks( + self, + send_reqs: list[tuple[ReqId, SendBlockMeta]], + agent_meta: MooncakeAgentMetadata, + ): + src_ptrs = [] + dst_ptrs = [] + lengths = [] + local_base_addr = self.kv_caches_base_addr + remote_base_addr = agent_meta.kv_caches_base_addr + block_len = self.block_len + remote_session = f"{agent_meta.remote_hostname}:{agent_meta.remote_port}" + + assert len(send_reqs) == len(agent_meta.block_ids) + for (req_id, send_meta), remote_block_ids in zip( + send_reqs, agent_meta.block_ids + ): + send_meta.ready.wait() + + num_remote_blocks = len(remote_block_ids) + if num_remote_blocks == 0: + continue + + local_block_ids = send_meta.local_block_ids + # Partial prefix cache hit: just read uncomputed blocks. + num_local_blocks = len(local_block_ids) + assert num_local_blocks >= num_remote_blocks + if num_local_blocks > num_remote_blocks: + local_block_ids = local_block_ids[-num_remote_blocks:] + + # Group by indices + group_local_block_ids, group_remote_block_ids = group_concurrent_contiguous( + local_block_ids, remote_block_ids + ) + + for local_layer_addr, remote_layer_addr in zip( + local_base_addr, remote_base_addr + ): + for group_local_block_id, group_remote_block_id in zip( + group_local_block_ids, group_remote_block_ids + ): + src_ptrs.append( + local_layer_addr + group_local_block_id[0] * block_len + ) + dst_ptrs.append( + remote_layer_addr + group_remote_block_id[0] * block_len + ) + lengths.append(block_len * len(group_local_block_id)) + + logger.debug( + "Sending kv_caches for request %s (%d blocks) to %s", + req_id, + num_remote_blocks, + remote_session, + ) + + start_time = time.perf_counter() + ret_value = self.engine.batch_transfer_sync_write( + remote_session, src_ptrs, dst_ptrs, lengths + ) + if ret_value != 0: + raise RuntimeError(f"Error in batch_transfer_sync_write: {ret_value}") + + logger.debug( + "Sending to %s done, took %s", + remote_session, + time.perf_counter() - start_time, + ) + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """Register the KV Cache data in mooncake.""" + + logger.info("Registering KV_Caches. use_mla: %s", self.use_mla) + + kv_data_ptrs = [] + kv_data_lens = [] + seen_base_addresses = [] + + split_k_and_v = self.kv_topo.split_k_and_v + tensor_size_bytes = None + for layer_name, cache_or_caches in kv_caches.items(): + logger.debug( + "registering layer %s with shape %s", layer_name, cache_or_caches.shape + ) + cache_list = cache_or_caches if split_k_and_v else [cache_or_caches] + + for cache in cache_list: + base_addr = cache.data_ptr() + if base_addr in seen_base_addresses: + continue + + seen_base_addresses.append(base_addr) + curr_tensor_size_bytes = cache.nbytes + + if tensor_size_bytes is None: + tensor_size_bytes = curr_tensor_size_bytes + self.num_blocks = cache.shape[0] + + assert tensor_size_bytes == curr_tensor_size_bytes, ( + "All kv cache tensors must have the same size" + ) + kernel_block_size = cache.shape[-2 if self.use_mla else -3] + assert self.block_size == kernel_block_size + kv_data_ptrs.append(base_addr) + kv_data_lens.append(tensor_size_bytes) + + self.kv_caches_base_addr = seen_base_addresses + + ret_value = self.engine.batch_register_memory(kv_data_ptrs, kv_data_lens) + if ret_value != 0: + raise RuntimeError("Mooncake batch memory registration failed.") + + assert tensor_size_bytes is not None + assert self.num_blocks != 0 + assert tensor_size_bytes % self.num_blocks == 0 + self.block_len = tensor_size_bytes // self.num_blocks + self.device_kv_caches = kv_caches + logger.debug( + "registered num_blocks=%d block_len=%d", self.num_blocks, self.block_len + ) + + # No need to launch server for D node. + if self.kv_role == "kv_consumer": + return + + ready_event = threading.Event() + self._mooncake_sender_t = threading.Thread( + target=self._mooncake_sender, + args=(ready_event, self.side_channel_port, self.tp_rank), + daemon=True, + name="mooncake_sender", + ) + self._mooncake_sender_t.start() + ready_event.wait() # Wait for listener ZMQ socket to be ready. + + async def fetch_finished_recving_reqs(self) -> set[ReqId]: + async with self.finished_recving_reqs.lock: + finished_recving_reqs = self.finished_recving_reqs.set + self.finished_recving_reqs.set = set() + return finished_recving_reqs + + def get_finished(self) -> tuple[set[str] | None, set[str] | None]: + """ + Get requests that are done sending or recving on this specific worker. + The scheduler process (via the MultiprocExecutor) will use this output + to track which workers are done. + """ + fut = None + if self.kv_role != "kv_producer": + fut = asyncio.run_coroutine_threadsafe( + self.fetch_finished_recving_reqs(), self.receiver_loop + ) + + if self.kv_role != "kv_consumer": + with self.finished_sending_reqs.lock: + finished_sending_reqs = self.finished_sending_reqs.set + self.finished_sending_reqs.set = set() + else: + finished_sending_reqs = set() + + finished_recving_reqs = fut.result() if fut else set() + + if finished_sending_reqs or finished_recving_reqs: + logger.debug( + "Rank %s, get_finished: %s requests done sending " + "and %s requests done recving", + self.tp_rank, + len(finished_sending_reqs), + len(finished_recving_reqs), + ) + + # Handle timeout to avoid stranding blocks on remote. + now = time.perf_counter() + with self.reqs_need_send.lock: + expired_reqs = [ + req_id + for req_id, send_meta in self.reqs_need_send.reqs.items() + if send_meta.expire_time < now + ] + for req_id in expired_reqs: + logger.warning( + "Request %s timed out after %d seconds without " + "being sent. Freeing its blocks on the producer side.", + req_id, + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT, + ) + del self.reqs_need_send.reqs[req_id] + if expired_reqs: + finished_sending_reqs.update(expired_reqs) + + return finished_sending_reqs or None, finished_recving_reqs or None + + async def receive_kv(self, path: str, req_blocks: list[tuple[str, list[int]]]): + req_ids, block_ids = map(list, zip(*req_blocks)) + metadata = MooncakeAgentMetadata( + remote_hostname=self.hostname, + remote_port=self.rpc_port, + request_ids=req_ids, + kv_caches_base_addr=self.kv_caches_base_addr, + block_ids=block_ids, + ) + + encoded_data = self._encoder.encode(metadata) + logger.debug( + "Size of encoded MooncakeAgentMetadata: %d bytes", len(encoded_data) + ) + logger.debug("Sending kv transfer request for %s on path: %s", req_ids, path) + + # Send query for the request. + sock: zmq.asyncio.Socket = make_zmq_socket( + self.async_zmq_ctx, path, zmq.REQ, bind=False, linger=0 + ) + sock.setsockopt(zmq.RCVTIMEO, 60000) + try: + await sock.send(encoded_data) + ret_msg = await sock.recv() + if ret_msg != TRANS_DONE: + logger.error( + "Error happens during tranfering kvcache for %s, see logs in prefiller.", # noqa: E501 + req_ids, + ) + return + except zmq.ContextTerminated: + logger.debug("ZMQ context terminated, exiting Mooncake receiver thread.") + except Exception as e: + logger.error("MooncakeAgentMetadata transfer failed for %s: %s", req_ids, e) + return + finally: + sock.close() + + async with self.finished_recving_reqs.lock: + self.finished_recving_reqs.set.update(req_ids) + + logger.debug("pulling kv_caches for %s finished", req_ids) + + def group_kv_pull(self, metadata: MooncakeConnectorMetadata): + kv_pulls = defaultdict(list) + for req_id, meta in metadata.reqs_to_recv.items(): + logger.debug( + "start_load_kv for request %s from remote engine. " + "Num local_block_ids: %s.", + req_id, + len(meta.local_block_ids), + ) + path = make_zmq_path( + "tcp", meta.remote_host, meta.remote_port + self.tp_rank + ) + kv_pulls[path].append((req_id, meta.local_block_ids)) + + return kv_pulls + + def start_load_kv(self, metadata: MooncakeConnectorMetadata): + if self.kv_role != "kv_producer": + kv_pulls = self.group_kv_pull(metadata) + for path, req_blocks in kv_pulls.items(): + asyncio.run_coroutine_threadsafe( + self.receive_kv(path, req_blocks), self.receiver_loop + ) + + if self.kv_role != "kv_consumer": + with self.reqs_need_send.lock: + for req_id, block_ids in metadata.reqs_to_send.items(): + if block_ids: + # Already gone through request_finished() + send_meta = self.reqs_need_send.reqs[req_id] + send_meta.local_block_ids = block_ids + send_meta.ready.set() + send_meta.expire_time = ( + time.perf_counter() + + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT + ) + else: + # From update_state_after_alloc(), + # but not reach request_finished() yet + self.reqs_need_send.reqs[req_id] = SendBlockMeta( + local_block_ids=[], ready=threading.Event() + ) + + +def group_concurrent_contiguous( + src_indices: list[int], dst_indices: list[int] +) -> tuple[list[list[int]], list[list[int]]]: + """Vectorised NumPy implementation.""" + if len(src_indices) == 0: + return [], [] + + brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1 + src_groups = np.split(src_indices, brk) + dst_groups = np.split(dst_indices, brk) + + src_groups = [g.tolist() for g in src_groups] + dst_groups = [g.tolist() for g in dst_groups] + + return src_groups, dst_groups + + +def get_mooncake_side_channel_port(vllm_config: VllmConfig) -> int: + # This logic is now centralized + return ( + envs.VLLM_MOONCAKE_BOOTSTRAP_PORT + + vllm_config.parallel_config.data_parallel_rank + * vllm_config.parallel_config.tensor_parallel_size + ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 51d5df6c6ba15..6825745374959 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -259,6 +259,12 @@ class MultiConnector(KVConnectorBase_V1): agg_block_ids |= c.get_block_ids_with_load_errors() return agg_block_ids + # TODO: Add a generic implementation of 'get_kv_connector_kv_cache_events' method + # for the MultiConnector. It should be able to get events from multiple + # connectors, handling the case where only a subset of the requested connectors + # implements the 'get_kv_connector_kv_cache_events' + # Follow on PR from https://github.com/vllm-project/vllm/pull/28309#pullrequestreview-3566351082 + # ============================== # Scheduler-side methods # ============================== @@ -452,3 +458,7 @@ class MultiConnector(KVConnectorBase_V1): per_engine_labelvalues, prom_metrics, ) + + def reset_cache(self) -> bool: + results = [c.reset_cache() is not False for c in self._connectors] + return all(results) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 41e32bb73d40b..fb4b8ac391afb 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -20,10 +20,10 @@ import torch import zmq from vllm import envs -from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata -from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.selector import get_attn_backend from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology from vllm.distributed.kv_transfer.kv_connector.v1.base import ( CopyBlocksOp, KVConnectorBase_V1, @@ -55,10 +55,26 @@ if TYPE_CHECKING: from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request -Transfer = tuple[int, float] # (xfer_handle, start_time) +TransferHandle = int EngineId = str ReqId = str +# +# NIXL Connector Version +# +# Increment this version whenever there is an incompatible change to: +# - NixlAgentMetadata schema +# - kv_transfer_params schema or semantics +# - NIXL transfer protocol or wire format +# - KV cache memory layout or block organization +# - Any other change that breaks P/D interoperability +# +# Version History: +# 1: Initial version with compatibility checking +# 2: Add remote_request_id to kv_transfer_params +# +NIXL_CONNECTOR_VERSION: int = 2 + GET_META_MSG = b"get_meta_msg" logger = init_logger(__name__) @@ -97,28 +113,111 @@ _NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices()) @dataclass -class NixlAgentMetadata(KVConnectorHandshakeMetadata): +class NixlAgentMetadata: engine_id: str agent_metadata: bytes kv_caches_base_addr: list[int] device_id: int num_blocks: int block_lens: list[int] - attn_backend_name: str kv_cache_layout: str block_size: int +@dataclass +class NixlHandshakePayload(KVConnectorHandshakeMetadata): + """ + Wrapper for NIXL handshake sent over the wire. + + Enables two-phase decoding for graceful compatibility checking: + 1. Decode NixlHandshakePayload to get compatibility_hash + 2. Compute local hash and compare + 3. Only if hashes match, decode agent_metadata_bytes + + This prevents decoder errors when NixlAgentMetadata schema is + incompatible, allowing graceful failure with clear error message. + """ + + compatibility_hash: str + agent_metadata_bytes: bytes # NixlAgentMetadata encoded + + +def compute_nixl_compatibility_hash( + vllm_config: VllmConfig, attn_backend_name: str +) -> str: + """ + Compute compatibility hash for NIXL KV transfer. + + Hash only the factors that affect whether two NIXL instances can + successfully transfer KV cache data. + + Factors included: + - vLLM version and NIXL connector version + - Model architecture (name, dtype, KV heads, layers) + - KV cache format (dtype, sliding window) + - Attention backend + + Note: Factors like tensor_parallel_size, block_size, and kv_cache_layout + are validated at runtime in _validate_remote_agent_handshake and are not + included in this hash to support heterogeneous deployments. + + Note - the set of factors are likely to evolve significantly over + time to be more or less permissive. + + Returns: + SHA-256 hex digest + """ + from vllm import __version__ as vllm_version + from vllm.config.utils import hash_factors + + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + + factors = { + # Version compatibility + "vllm_version": vllm_version, + "nixl_connector_version": NIXL_CONNECTOR_VERSION, + # Model architecture - affects KV cache shape + "model": model_config.model, + "dtype": str(model_config.dtype), + "num_kv_heads": model_config.get_total_num_kv_heads(), + "head_size": model_config.get_head_size(), + "num_hidden_layers": model_config.get_total_num_hidden_layers(), + # Attention backend and KV cache dtype affect memory layout + "attn_backend_name": attn_backend_name, + "cache_dtype": str(cache_config.cache_dtype), + } + + compat_hash = hash_factors(factors) + logger.debug( + "NIXL compatibility hash: %s (model=%s, dtype=%s, num_kv_heads=%d, " + "cache_dtype=%s, attn_backend=%s)", + compat_hash, + factors["model"], + factors["dtype"], + factors["num_kv_heads"], + factors["cache_dtype"], + attn_backend_name, + ) + return compat_hash + + +@dataclass +class RemoteMeta: + block_ids: list[int] + host: str + port: int + engine_id: str + request_id: str + + @dataclass class ReqMeta: local_block_ids: list[int] # To be used when logical block size does not match the kernel block size local_physical_block_ids: list[int] - remote_block_ids: list[int] - remote_host: str - remote_port: int - remote_engine_id: str tp_size: int + remote: RemoteMeta | None = None class NixlConnectorMetadata(KVConnectorMetadata): @@ -129,30 +228,43 @@ class NixlConnectorMetadata(KVConnectorMetadata): self.reqs_in_batch: set[ReqId] = set() self.reqs_not_processed: set[ReqId] = set() - def add_new_req( + def _add_new_req( + self, + local_block_ids: list[int], + kv_transfer_params: dict[str, Any], + ) -> ReqMeta: + return ReqMeta( + local_block_ids=local_block_ids, + local_physical_block_ids=local_block_ids, + # P workers don't need to receive tp_size from proxy here. + tp_size=kv_transfer_params.get("tp_size", 1), + ) + + def add_new_req_to_save( self, request_id: ReqId, local_block_ids: list[int], kv_transfer_params: dict[str, Any], - load_remote_cache: bool = True, - save_to_host: bool = False, ): - # save and load are mutually exclusive - assert load_remote_cache ^ save_to_host - _req = ReqMeta( - local_block_ids=local_block_ids, - local_physical_block_ids=local_block_ids, - remote_block_ids=kv_transfer_params["remote_block_ids"], - remote_engine_id=kv_transfer_params["remote_engine_id"], - remote_host=kv_transfer_params["remote_host"], - remote_port=kv_transfer_params["remote_port"], - # P workers don't need to receive tp_size from proxy here. - tp_size=kv_transfer_params.get("tp_size", 1), + self.reqs_to_save[request_id] = self._add_new_req( + local_block_ids, kv_transfer_params ) - if save_to_host: - self.reqs_to_save[request_id] = _req - if load_remote_cache: - self.reqs_to_recv[request_id] = _req + + def add_new_req_to_recv( + self, + request_id: ReqId, + local_block_ids: list[int], + kv_transfer_params: dict[str, Any], + ): + req = self._add_new_req(local_block_ids, kv_transfer_params) + req.remote = RemoteMeta( + block_ids=kv_transfer_params["remote_block_ids"], + engine_id=kv_transfer_params["remote_engine_id"], + request_id=kv_transfer_params["remote_request_id"], + host=kv_transfer_params["remote_host"], + port=kv_transfer_params["remote_port"], + ) + self.reqs_to_recv[request_id] = req class NixlConnector(KVConnectorBase_V1): @@ -396,14 +508,14 @@ class NixlConnectorScheduler: encoded_data: dict[int, bytes] = {} encoder = msgspec.msgpack.Encoder() for tp_rank, rank_metadata in metadata.items(): - if not isinstance(rank_metadata, NixlAgentMetadata): + if not isinstance(rank_metadata, NixlHandshakePayload): raise ValueError( - "NixlConnectorScheduler expects NixlAgentMetadata for " + "NixlConnectorScheduler expects NixlHandshakePayload for " "handshake metadata." ) encoded_data[tp_rank] = encoder.encode(rank_metadata) logger.debug( - "Tp rank %d: encoded NixlAgentMetadata size: %s bytes", + "Tp rank %d: encoded NixlHandshakePayload size: %s bytes", tp_rank, str(len(encoded_data[tp_rank])), ) @@ -530,7 +642,12 @@ class NixlConnectorScheduler: if params.get("remote_block_ids"): if all( p in params - for p in ("remote_engine_id", "remote_host", "remote_port") + for p in ( + "remote_engine_id", + "remote_request_id", + "remote_host", + "remote_port", + ) ): # If remote_blocks and num_external_tokens = 0, we have # a full prefix cache hit on the D worker. We need to call @@ -566,22 +683,18 @@ class NixlConnectorScheduler: # Loop through scheduled reqs and convert to ReqMeta. for req_id, (req, block_ids) in self._reqs_need_recv.items(): assert req.kv_transfer_params is not None - meta.add_new_req( + meta.add_new_req_to_recv( request_id=req_id, local_block_ids=block_ids, kv_transfer_params=req.kv_transfer_params, - load_remote_cache=True, - save_to_host=False, ) for req_id, (req, block_ids) in self._reqs_need_save.items(): assert req.kv_transfer_params is not None - meta.add_new_req( + meta.add_new_req_to_save( request_id=req_id, local_block_ids=block_ids, kv_transfer_params=req.kv_transfer_params, - load_remote_cache=False, - save_to_host=True, ) meta.reqs_to_send = self._reqs_need_send @@ -659,6 +772,7 @@ class NixlConnectorScheduler: do_remote_decode=False, remote_block_ids=block_ids, remote_engine_id=self.engine_id, + remote_request_id=request.request_id, remote_host=self.side_channel_host, remote_port=self.side_channel_port, tp_size=self.vllm_config.parallel_config.tensor_parallel_size, @@ -668,128 +782,6 @@ class NixlConnectorScheduler: class NixlConnectorWorker: """Implementation of Worker side methods""" - @dataclass - class TpKVTopology: - """ - Helper class for tensor parallel and KV topology information for - mapping between local and remote TP workers. - """ - - tp_rank: int - remote_tp_size: dict[EngineId, int] - is_mla: bool - total_num_kv_heads: int - attn_backend: type[AttentionBackend] - engine_id: EngineId - remote_block_size: dict[EngineId, int] - - def __post_init__(self): - # Figure out whether the first dimension of the cache is K/V - # or num_blocks. This is used to register the memory regions correctly. - kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 - ) - # Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D], - # we just mock num_blocks to 1 for the dimension check below. - self._is_kv_layout_blocks_first = ( - len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1 - ) - - attn_backend = AttentionBackendEnum[self.attn_backend.get_name()] - self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS - - @property - def is_kv_layout_blocks_first(self) -> bool: - return self._is_kv_layout_blocks_first - - @property - def split_k_and_v(self) -> bool: - # Whether to register regions for K and V separately (when present). - return not ( - self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first - ) - - @property - def tp_size(self) -> int: - return self.remote_tp_size[self.engine_id] - - @property - def block_size(self) -> int: - return self.remote_block_size[self.engine_id] - - def tp_ratio( - self, - remote_tp_size: int, - ) -> int: - """ - Calculate the tensor parallel ratio between local and remote TP. - We can think of it as the number of local TP workers-per-remote TP - workers. Local workers will read from the same remote TP worker in - groups of size `tp_ratio`. - """ - assert self.tp_size % remote_tp_size == 0, ( - f"Local tensor parallel size {self.tp_size} is not divisible " - f"by remote tensor parallel size {remote_tp_size}." - ) - return self.tp_size // remote_tp_size - - def block_size_ratio( - self, - remote_block_size: int, - ) -> float: - """ - Calculate the block size ratio between local and remote TP. - """ - assert self.block_size % remote_block_size == 0, ( - f"Local block size {self.block_size} is not divisible " - f"by remote block size {remote_block_size} or vice versa." - ) - return self.block_size // remote_block_size - - def tp_ratio_from_engine_id( - self, - remote_engine_id: EngineId, - ) -> int: - remote_tp_size = self.remote_tp_size[remote_engine_id] - return self.tp_ratio(remote_tp_size) - - def block_size_ratio_from_engine_id( - self, - remote_engine_id: EngineId, - ) -> float: - remote_block_size = self.remote_block_size[remote_engine_id] - return self.block_size_ratio(remote_block_size) - - def is_kv_replicated(self, engine_id: EngineId) -> bool: - """ - Whether the KV cache is replicated across TP workers due to the - number of TP workers being greater than the number of KV heads. - """ - tp_size = self.remote_tp_size[engine_id] - return tp_size // self.total_num_kv_heads >= 1 - - def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool: - # MLA is always replicated as the hidden dim can't be split. - return self.is_mla or self.is_kv_replicated(remote_engine_id) - - def get_target_remote_rank( - self, - remote_tp_size: int, - ) -> int: - """ - Get the remote TP rank (on P) that the current local TP rank - (on D) will read from. - """ - tp_ratio = self.tp_ratio(remote_tp_size) - return self.tp_rank // tp_ratio - - def get_target_remote_rank_from_engine_id( - self, - remote_engine_id: EngineId, - ) -> int: - remote_tp_size = self.remote_tp_size[remote_engine_id] - return self.get_target_remote_rank(remote_tp_size) - def __init__(self, vllm_config: VllmConfig, engine_id: str): if NixlWrapper is None: logger.error("NIXL is not available") @@ -904,7 +896,7 @@ class NixlConnectorWorker: # In progress transfers. # [req_id -> list[handle]] self._recving_metadata: dict[ReqId, ReqMeta] = {} - self._recving_transfers = defaultdict[ReqId, list[Transfer]](list) + self._recving_transfers = defaultdict[ReqId, list[TransferHandle]](list) # Track the expiration time of requests that are waiting to be sent. self._reqs_to_send: dict[ReqId, float] = {} # Set of requests that have been part of a batch, regardless of status. @@ -916,7 +908,7 @@ class NixlConnectorWorker: self._failed_recv_reqs: set[ReqId] = set() # Handshake metadata of this worker for NIXL transfers. - self.xfer_handshake_metadata: NixlAgentMetadata | None = None + self.xfer_handshake_metadata: NixlHandshakePayload | None = None # Background thread for initializing new NIXL handshakes. self._handshake_initiation_executor = ThreadPoolExecutor( # NIXL is not guaranteed to be thread-safe, limit 1 worker. @@ -951,6 +943,13 @@ class NixlConnectorWorker: logger.debug("Detected attention backend %s", self.backend_name) logger.debug("Detected kv cache layout %s", self.kv_cache_layout) + self.compat_hash = compute_nixl_compatibility_hash( + self.vllm_config, self.backend_name + ) + self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config( + "enforce_handshake_compat", True + ) + self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size} # With heterogeneous TP, P must wait for all assigned D TP workers to @@ -958,7 +957,7 @@ class NixlConnectorWorker: self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) self.xfer_stats = NixlKVConnectorStats() - self.kv_topo = self.TpKVTopology( + self.kv_topo = TpKVTopology( tp_rank=self.tp_rank, engine_id=self.engine_id, remote_tp_size=self._tp_size, # shared state @@ -999,14 +998,58 @@ class NixlConnectorWorker: # Set receive timeout to 5 seconds to avoid hanging on dead server sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds sock.send(msg) - metadata_bytes = sock.recv() - decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) - metadata = decoder.decode(metadata_bytes) + handshake_bytes = sock.recv() + + # Decode handshake payload to get compatibility hash + handshake_decoder = msgspec.msgpack.Decoder(NixlHandshakePayload) + try: + handshake_payload = handshake_decoder.decode(handshake_bytes) + except (msgspec.DecodeError, msgspec.ValidationError) as e: + raise RuntimeError( + f"Failed to decode NixlHandshakePayload. This likely indicates " + f"an incompatibility between connector version. Error: {e}" + ) from e + got_metadata_time = time.perf_counter() logger.debug( "NIXL handshake: get metadata took: %s", got_metadata_time - start_time ) + # Check compatibility hash BEFORE decoding agent metadata + if ( + self.enforce_compat_hash + and handshake_payload.compatibility_hash != self.compat_hash + ): + raise RuntimeError( + f"NIXL compatibility hash mismatch. " + f"Local: {self.compat_hash}, " + f"Remote: {handshake_payload.compatibility_hash}. " + f"Prefill and decode instances have incompatible configurations. " + f"This may be due to: different vLLM versions, models, dtypes, " + f"KV cache layouts, attention backends, etc. " + f"Both instances must use identical configurations." + f"Disable this check using " + f'--kv-transfer-config \'{{"kv_connector_extra_config": ' + f'{{"enforce_handshake_compat": false}}}}\'' + ) + + logger.info( + "NIXL compatibility check passed (hash: %s)", + handshake_payload.compatibility_hash, + ) + + # Decode agent metadata + metadata_decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) + try: + metadata = metadata_decoder.decode( + handshake_payload.agent_metadata_bytes + ) + except (msgspec.DecodeError, msgspec.ValidationError) as e: + # This should not happen if hash matched + raise RuntimeError( + f"Failed to decode NixlAgentMetadata. Error: {e}" + ) from e + # Ensure engine id matches. if metadata.engine_id != expected_engine_id: raise RuntimeError( @@ -1094,10 +1137,11 @@ class NixlConnectorWorker: # Do NIXL handshake in background and add to _ready_requests when done. fut = self._handshake_futures.get(remote_engine_id) if fut is None: + assert meta.remote is not None fut = self._handshake_initiation_executor.submit( self._nixl_handshake, - meta.remote_host, - meta.remote_port, + meta.remote.host, + meta.remote.port, meta.tp_size, remote_engine_id, ) @@ -1180,14 +1224,11 @@ class NixlConnectorWorker: # Enable different block lengths for different layers when MLA is used. self.block_len_per_layer = list[int]() self.slot_size_per_layer = list[int]() # HD bytes in kv terms - self.device_id = self.tp_rank for layer_name, cache_or_caches in xfer_buffers.items(): cache_list = cache_or_caches if split_k_and_v else [cache_or_caches] for cache in cache_list: base_addr = cache.data_ptr() - if not self.use_host_buffer and current_platform.is_cuda_alike(): - self.device_id = cache.device.index if base_addr in seen_base_addresses: continue @@ -1230,8 +1271,7 @@ class NixlConnectorWorker: "All kv cache tensors must have the same size" ) # Need to make sure the device ID is non-negative for NIXL, - # Torch uses -1 to indicate CPU tensors while NIXL uses explicit - # memory type. + # Torch uses -1 to indicate CPU tensors. self.device_id = max(cache.get_device(), 0) caches_data.append( (base_addr, curr_tensor_size_bytes, self.device_id, "") @@ -1297,19 +1337,24 @@ class NixlConnectorWorker: assert len(self.block_window_per_layer) == self.num_layers # After KV Caches registered, listen for new connections. - self.xfer_handshake_metadata = NixlAgentMetadata( + agent_metadata = NixlAgentMetadata( engine_id=self.engine_id, agent_metadata=self.nixl_wrapper.get_agent_metadata(), kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], device_id=self.device_id, num_blocks=self.num_blocks, block_lens=self.block_len_per_layer, - attn_backend_name=self.backend_name, kv_cache_layout=self.kv_cache_layout if not self.use_host_buffer else self.host_buffer_kv_cache_layout, block_size=self.block_size, ) + # Wrap metadata in payload with hash for defensive decoding + encoder = msgspec.msgpack.Encoder() + self.xfer_handshake_metadata = NixlHandshakePayload( + compatibility_hash=self.compat_hash, + agent_metadata_bytes=encoder.encode(agent_metadata), + ) def register_local_xfer_handler( self, @@ -1524,8 +1569,6 @@ class NixlConnectorWorker: remote_engine_id = nixl_agent_meta.engine_id assert self._tp_size[remote_engine_id] == remote_tp_size - # TODO We may eventually want to skip enforcing the same attn backend. - assert nixl_agent_meta.attn_backend_name == self.backend_name tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id) block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( @@ -1745,6 +1788,7 @@ class NixlConnectorWorker: # clean up metadata for completed requests meta = self._recving_metadata.pop(req_id, None) assert meta is not None, f"{req_id} not found in recving_metadata list" + assert meta.remote is not None if self.use_host_buffer: self.sync_recved_kv_to_device(req_id, meta) if self.enable_permute_local_kv: @@ -1752,7 +1796,7 @@ class NixlConnectorWorker: # post processing for heteroblocksize block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( - meta.remote_engine_id + meta.remote.engine_id ) if ( not self.use_mla @@ -1818,9 +1862,7 @@ class NixlConnectorWorker: self._reqs_to_send.pop(req_id, None) return notified_req_ids - def _pop_done_transfers( - self, transfers: dict[str, list[tuple[int, float]]] - ) -> set[str]: + def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]: """ Pop completed xfers by checking for DONE state. Args: @@ -1831,7 +1873,7 @@ class NixlConnectorWorker: done_req_ids: set[str] = set() for req_id, handles in list(transfers.items()): in_progress = False - for handle, xfer_start_time in handles: + for handle in handles: try: xfer_state = self.nixl_wrapper.check_xfer_state(handle) if xfer_state == "DONE": @@ -1889,17 +1931,18 @@ class NixlConnectorWorker: meta.local_physical_block_ids = self._logical_to_kernel_block_ids( meta.local_block_ids ) - meta.remote_block_ids = self._logical_to_kernel_block_ids( - meta.remote_block_ids + assert meta.remote is not None + meta.remote.block_ids = self._logical_to_kernel_block_ids( + meta.remote.block_ids ) - remote_engine_id = meta.remote_engine_id + remote_engine_id = meta.remote.engine_id logger.debug( "start_load_kv for request %s from remote engine %s. " "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id, remote_engine_id, len(meta.local_physical_block_ids), - len(meta.remote_block_ids), + len(meta.remote.block_ids), ) # always store metadata for failure recovery self._recving_metadata[req_id] = meta @@ -1938,16 +1981,18 @@ class NixlConnectorWorker: self._reqs_to_send[req_id] = expiration_time def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): + assert meta.remote is not None logger.debug( "Remote agent %s available, calling _read_blocks for req %s", - meta.remote_engine_id, + meta.remote.engine_id, req_id, ) self._read_blocks( request_id=req_id, - dst_engine_id=meta.remote_engine_id, + dst_engine_id=meta.remote.engine_id, + remote_request_id=meta.remote.request_id, local_block_ids=meta.local_physical_block_ids, - remote_block_ids=meta.remote_block_ids, + remote_block_ids=meta.remote.block_ids, ) def _read_blocks( @@ -1956,6 +2001,7 @@ class NixlConnectorWorker: remote_block_ids: list[int], dst_engine_id: str, request_id: str, + remote_request_id: str, ): block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id) if block_size_ratio > 1: @@ -1988,7 +2034,7 @@ class NixlConnectorWorker: # Number of D TP workers that will read from dst P. Propagate tp_ratio # on notification so that dst worker can wait before freeing blocks. tp_ratio = self.kv_topo.tp_ratio_from_engine_id(dst_engine_id) - notif_id = f"{request_id}:{tp_ratio}".encode() + notif_id = f"{remote_request_id}:{tp_ratio}".encode() # Full prefix cache hit: do not need to read remote blocks, # just notify P worker that we have the blocks we need. @@ -2096,7 +2142,7 @@ class NixlConnectorWorker: self.nixl_wrapper.transfer(handle) # Use handle to check completion in future step(). - self._recving_transfers[request_id].append((handle, time.perf_counter())) + self._recving_transfers[request_id].append(handle) except Exception: logger.exception( "NIXL transfer setup/initiation failed for request %s. " @@ -2227,7 +2273,7 @@ class NixlConnectorWorker: """Shutdown the connector worker.""" self._handshake_initiation_executor.shutdown(wait=False) for handles in self._recving_transfers.values(): - for handle, _ in handles: + for handle in handles: self.nixl_wrapper.release_xfer_handle(handle) self._recving_transfers.clear() if self.src_xfer_side_handle: diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py deleted file mode 100644 index f48d03d0b0cd5..0000000000000 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py +++ /dev/null @@ -1,179 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -This file contains a new class `KVLookupBufferBase` that allows developers to -think of KV cache operations as inserting new KV cache entries (`insert`) -into the lookup buffer and querying existing KV caches (`drop_select`) -from the lookup buffer. - -This file also contains a new class `KVStoreBufferBase` that allows developers -to manage the KVCache buffer as a simple key-value storage buffer with basic -put/get operations. - -These classes above are abstracted behind class `KVCacheBufferBase`. -""" - -from abc import ABC, abstractmethod - -import torch - - -class KVCacheBufferBase(ABC): - """ - Abstract base class for a KVCache buffer. - """ - - @abstractmethod - def close(self) -> None: - """Close the buffer and release resources. - - This method is responsible for cleaning up resources related to the - KVCache buffer when it is no longer needed. - - Raises: - NotImplementedError: This method must be implemented in subclasses. - """ - raise NotImplementedError - - -class KVLookupBufferBase(KVCacheBufferBase): - """ - Abstract base class for a KVCache lookup buffer. - - This class provides an abstraction for a key-value (KV) cache lookup buffer. - - The key of the lookup buffer: - - input_tokens: token IDs of the request - - roi: a binary mask on top of input_tokens. - - Purpose of roi: Since KV cache may only be available for a subset of - tokens in the input (for example, when vLLM is connected to an external - KV cache service), roi specifies the subset of tokens that the KV cache - is associated with. - - NOTE: roi can be further extended to describe which part of KV the - current process is holding (each process may only hold a part of KV - due to TP and PP). This is not implemented for now. - - The value of the lookup buffer: - - key: the key tensor in the KV cache - - value: the value tensor in the KV cache - - hidden: the final hidden state generated by model forwarding. This allows - vLLM to bypass further model forwarding by transmitting the hidden state. - """ - - @abstractmethod - def insert( - self, - input_tokens: torch.Tensor, - roi: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - hidden: torch.Tensor, - ) -> None: - """Insert into the lookup buffer. - - The functionality is similar to the following python statement - ``` - buffer[input_tokens, roi] = [key, value, hidden] - ``` - - FIXME: in the future, we should only have two arguments, key and value, - where key is a tensor dict and value is a tensor dict. - - FIXME: we should transmit both sampler outputs and the hidden states. - - Args: - input_tokens (torch.Tensor): token IDs. - roi (torch.Tensor): A binary mask on top of the input tokens - key (torch.Tensor): The key tensor in the KV cache. - value (torch.Tensor): The value tensor in the KV cache. - hidden (torch.Tensor): The final hidden state tensor generated - during model forwarding to bypass model - forwarding. - - Raises: - NotImplementedError: This method must be implemented in subclasses. - """ - raise NotImplementedError - - @abstractmethod - def drop_select( - self, input_tokens: torch.Tensor | None, roi: torch.Tensor | None - ) -> list[torch.Tensor | None]: - """Select and *drop* KV cache entries from the lookup buffer. - - The functionality is similar to the following python statements - ``` - ret = buffer.pop(input_tokens, roi) - return ret - ``` - - If `input_tokens` and `roi` is `None`, it means selecting any of the - KV caches in the buffer, return, and remove it from the buffer, useful - when offloading KV cache to KV cache storage service. - - Args: - input_tokens (torch.Tensor): token IDs. - roi (torch.Tensor): A binary mask on top of the input tokens - - Returns: - list[Optional[torch.Tensor]]: A list of tensors. Can be None. - - Raises: - NotImplementedError: This method must be implemented in subclasses. - """ - raise NotImplementedError - - -class KVStoreBufferBase(KVCacheBufferBase): - """ - Abstract base class for a KVCache storage buffer with key-value semantics. - This class provides a simple key-value storage buffer abstract with basic - put/get operations, which enables flexible KVCache transfer granular - control. - - The functionality is similar to a distributed key-value store, where: - - Key: A unique string identifier for the cached entry - - Value: - - Tensor to be stored and retrieved - - None (indicating deletion or empty value) - """ - - @abstractmethod - def put( - self, - key: str, - value: torch.Tensor | None, - ) -> None: - """Store a key-value pair in the buffer. - - Args: - key (str): Unique identifier for a tensor, this tensor could be the - key cache tensor, value cache tensor, or hidden state tensor - generated during model forwarding. - - value (Optional[torch.Tensor]): Tensor to be stored. - - Raises: - NotImplementedError: This method must be implemented in subclasses. - """ - raise NotImplementedError - - @abstractmethod - def get( - self, - key: str, - ) -> torch.Tensor | None: - """Retrieve a value from the buffer by key. - - Args: - key (str): Unique identifier for a tensor, this tensor could be the - key cache tensor, value cache tensor, or hidden state tensor - generated during model forwarding. - - Returns: - Optional[torch.Tensor]: Stored tensor if exists, None otherwise. - - Raises: - NotImplementedError: This method must be implemented in subclasses. - """ - raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py deleted file mode 100644 index 7861bea1f9c54..0000000000000 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py +++ /dev/null @@ -1,164 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -This file contains a new class `MooncakeStore` that allows developers to -think of KV cache transfer operations as putting new KV cache entries -into a remote KVStore-based lookup buffer and getting existing KV caches -from this remote lookup buffer. -""" - -import json -import os -from dataclasses import dataclass - -import torch -from safetensors.torch import load as safetensors_load -from safetensors.torch import save as safetensors_save - -from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_lookup_buffer.base import KVStoreBufferBase -from vllm.logger import init_logger - -DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB -DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB - -logger = init_logger(__name__) - - -@dataclass -class MooncakeStoreConfig: - local_hostname: str - metadata_server: str - global_segment_size: int - local_buffer_size: int - protocol: str - device_name: str - master_server_address: str - - @staticmethod - def from_file(file_path: str) -> "MooncakeStoreConfig": - """Load the config from a JSON file.""" - with open(file_path) as fin: - config = json.load(fin) - return MooncakeStoreConfig( - local_hostname=config.get("local_hostname"), - metadata_server=config.get("metadata_server"), - global_segment_size=config.get( - "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE - ), - local_buffer_size=config.get( - "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE - ), - protocol=config.get("protocol", "tcp"), - device_name=config.get("device_name", ""), - master_server_address=config.get("master_server_address"), - ) - - @staticmethod - def load_from_env() -> "MooncakeStoreConfig": - """Load config from a file specified in the environment variable.""" - config_file_path = os.getenv("MOONCAKE_CONFIG_PATH") - if config_file_path is None: - raise ValueError( - "The environment variable 'MOONCAKE_CONFIG_PATH' is not set." - ) - return MooncakeStoreConfig.from_file(config_file_path) - - -class MooncakeStore(KVStoreBufferBase): - def __init__( - self, - config: VllmConfig, - ): - try: - from mooncake.store import MooncakeDistributedStore - except ImportError as e: - raise ImportError( - "Please install mooncake by following the instructions at " - "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 - "to run vLLM with MooncakeConnector." - ) from e - - try: - self.store = MooncakeDistributedStore() - self.config = MooncakeStoreConfig.load_from_env() - logger.info("Mooncake Configuration loaded successfully.") - - self.store.setup( - self.config.local_hostname, - self.config.metadata_server, - self.config.global_segment_size, - self.config.local_buffer_size, - self.config.protocol, - self.config.device_name, - self.config.master_server_address, - ) - - except ValueError as e: - logger.error("Configuration loading failed: %s", e) - raise - except Exception as exc: - logger.error("An error occurred while loading the configuration: %s", exc) - raise - - def close(self): - # MooncakeDistributedStore will automatically call the destructor, so - # it is unnecessary to close it manually. - pass - - def put( - self, - key: str, - value: torch.Tensor | None, - ) -> None: - # A message queue needs to be introduced before making it asynchronous. - if value is not None: - self._put_impl(key, value) - - def get( - self, - key: str, - ) -> torch.Tensor | None: - # A message queue needs to be introduced before making it asynchronous. - value = self._get_impl(key) - return value - - def _put_impl( - self, - key: str, - value: torch.Tensor, - ) -> None: - """Put KVCache to Mooncake Store""" - device_id = value.device.index if value.device.type == "cuda" else -1 - device_tensor = torch.tensor(device_id, dtype=torch.int32) - value_bytes = safetensors_save({"tensor": value, "device_id": device_tensor}) - try: - self.store.put(key, value_bytes) - except TypeError as err: - logger.error("Failed to put value into Mooncake Store: %s", err) - raise TypeError("Mooncake Store Put Type Error.") from err - - def _get_impl( - self, - key: str, - ) -> torch.Tensor | None: - """Get KVCache from Mooncake Store""" - try: - data = self.store.get(key) - except TypeError as err: - logger.error("Failed to get value from Mooncake Store: %s", err) - raise TypeError("Mooncake Store Get Type Error.") from err - - if data: - loaded_tensors = safetensors_load(data) - tensor = loaded_tensors["tensor"] - device_id_tensor = loaded_tensors["device_id"] - device_id = int(device_id_tensor.item()) - device = ( - torch.device("cuda", device_id) - if device_id >= 0 - else torch.device("cpu") - ) - return tensor.to(device) - - return None diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py deleted file mode 100644 index f046a349874e6..0000000000000 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +++ /dev/null @@ -1,242 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Implements a distributed key-value (KV) cache transfer mechanism. - -Key Features: -- Distributed KV cache transmission using PyNccl pipes. -- Non-blocking `insert`, blocking `drop_select`. -- Use CPU signal pipe to avoid racing condition -- Handles buffer size constraints and provide backpressure mechanism to - stop the prefill instance when the decode instance is slow. -""" - -import threading -from collections import deque - -import torch - -from vllm.distributed.kv_transfer.kv_lookup_buffer.base import KVLookupBufferBase -from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase -from vllm.logger import init_logger - -logger = init_logger(__name__) - - -class SimpleBuffer(KVLookupBufferBase): - def __init__( - self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, buffer_size_thresh: float - ): - """ - signal_pipe: on CPU - - NOTE: on-device recv will block all threads in the process, making the - KV cache producer unable to listen to new request while transmitting - KV cache. Luckily CPU recv only blocks the current thread so we use - CPU recv to listen to new request. - - data_pipe: on device (e.g. GPU) - """ - - self.buffer: deque[list[torch.Tensor]] = deque() - - self.buffer_size = 0 - self.buffer_size_threshold = buffer_size_thresh - self.buffer_cv = threading.Condition() - self.signal_pipe = signal_pipe - self.data_pipe = data_pipe - self.request_handling_thread: threading.Thread | None = None - - self.normal_signal = torch.tensor([0], device="cpu") - self.end_signal = None - - def _matches( - self, - tokens_roi_sender: list[torch.Tensor], - tokens_roi_recver: list[torch.Tensor], - ): - # tokens_roi_sender: tokens and roi of the producer (in the buffer) - # tokens_roi_recver: tokens and roi of the consumer (query) - - tokens_sender = tokens_roi_sender[0] - tokens_recver = tokens_roi_recver[0] - roi_sender = tokens_roi_sender[1] - roi_recver = tokens_roi_recver[1] - - if tokens_recver is None: - # consumer sends an empty request - # semantics: DROP SELECT * LIMIT 1 - # so any of the data in the buffer can be drop-selected - return True - - # Assuming that roi is a binary mask on tokens - tokens_sender = tokens_sender[roi_sender] - tokens_recver = tokens_recver[roi_recver] - - # simple common prefix matching - min_length = min(len(tokens_sender), len(tokens_recver)) - if torch.allclose(tokens_sender[:min_length], tokens_recver[:min_length]): - return min_length - - return 0 - - def _send_tensor_and_dec_size(self, tensor: torch.Tensor | None) -> None: - assert tensor is not None, "Use self.data_pipe.send(None) instead" - self.buffer_size -= tensor.element_size() * tensor.numel() - if tensor.dtype == torch.bool: - tensor = tensor.float() - self.data_pipe.send_tensor(tensor) - - def _get_element_size(self, data: list | torch.Tensor | None): - if isinstance(data, torch.Tensor): - return data.element_size() * data.numel() - if not data: - # cannot perform `not data` on a tensor - # so this check needs to go after the check above - return 0 - - raise AssertionError(f"Unknown data type {type(data)}") - - def _add_to_buffer( - self, - input_tokens: torch.Tensor, - roi: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - hidden: torch.Tensor, - ): - if isinstance(input_tokens, torch.Tensor): - input_tokens = input_tokens.clone() - if isinstance(roi, torch.Tensor): - roi = roi.clone() - if isinstance(key, torch.Tensor): - key = key.clone() - if isinstance(value, torch.Tensor): - value = value.clone() - if isinstance(hidden, torch.Tensor): - hidden = hidden.clone() - - buffer_item = [input_tokens, roi, key, value, hidden] - data_size = sum([self._get_element_size(data) for data in buffer_item]) - - with self.buffer_cv: - if self.buffer_size + data_size > self.buffer_size_threshold: - # log outside the while loop to avoid this message being logged - # repeatedly. - logger.debug("KV transfer buffer is full. Handling...") - while self.buffer_size + data_size > self.buffer_size_threshold: - self.buffer_cv.wait() - - self.buffer_size += data_size - self.buffer.append(buffer_item) - self.buffer_cv.notify() - - def _is_end_signal(self, signal): - return signal is None - - def drop_select_handler(self): - try: - while True: - signal = self.signal_pipe.recv_tensor() - if self._is_end_signal(signal): - logger.info("Received end signal!") - break - - input_tokens = self.data_pipe.recv_tensor() - - roi = self.data_pipe.recv_tensor() - assert roi is not None, ( - "Please provide the roi when sending drop-select request" - ) - roi = roi > 0.5 - tokens_roi_recver = [input_tokens, roi] - - def is_buffer_available( - tokens_roi_recver: list[torch.Tensor], - ) -> bool: - # perform input tokens and roi matching - # FIXME: this matching is O(n), ideally it should be O(1) - # but this buffer size won't (and shouldn't) be too large so - # the fix is not urgent. - for _ in range(len(self.buffer)): - if self._matches(self.buffer[0], tokens_roi_recver) > 0: - return True - # rotate the element we just accessed to the end - self.buffer.rotate(-1) - return False - - with self.buffer_cv: - while not is_buffer_available(tokens_roi_recver): - logger.debug("KV transfer buffer is not available. Waiting...") - self.buffer_cv.wait() - # need to clone the tensor - # in case the tensor is freed before sending finishes - matched_item = self.buffer.popleft() - for tensor in matched_item: - self._send_tensor_and_dec_size(tensor) - self.buffer_cv.notify() - - except RuntimeError as e: - if "Connection closed by peer" not in str(e): - raise e - - logger.debug("Closing drop_select_handler") - - def drop_select( - self, input_tokens: torch.Tensor | None, roi: torch.Tensor | None - ) -> list[torch.Tensor | None]: - assert self.request_handling_thread is None, ( - "drop_select should be called by the KV cache consumer " - "(e.g. the decode vLLM instance)" - ) - - if isinstance(input_tokens, torch.Tensor): - input_tokens = input_tokens.clone() - if isinstance(roi, torch.Tensor): - roi = roi.clone().float() - - self.signal_pipe.send_tensor(self.normal_signal) - self.data_pipe.send_tensor(input_tokens) - self.data_pipe.send_tensor(roi) - - input_tokens = self.data_pipe.recv_tensor() - roi = self.data_pipe.recv_tensor() - if roi is not None: - # convert from float tensor to bool tensor - # as PyNccl does not support sending bool tensor - roi = roi > 0.5 - key = self.data_pipe.recv_tensor() - value = self.data_pipe.recv_tensor() - hidden = self.data_pipe.recv_tensor() - - return [input_tokens, roi, key, value, hidden] - - def insert( - self, - input_tokens: torch.Tensor, - roi: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - hidden: torch.Tensor, - ) -> None: - self._add_to_buffer(input_tokens, roi, key, value, hidden) - - # when calling the insert, the current process is a sender - # need to launch the request handler and start listening to request. - if self.request_handling_thread is None: - self.request_handling_thread = threading.Thread( - target=self.drop_select_handler - ) - self.request_handling_thread.start() - - def close(self): - if ( - hasattr(self, "request_handling_thread") - and self.request_handling_thread is not None - ): - self.request_handling_thread.join() - - else: - # TODO: have a explicit close signal and have a explicit way to - # check if it's requester - self.signal_pipe.send_tensor(self.end_signal) diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py deleted file mode 100644 index 1fe7a90e9a712..0000000000000 --- a/vllm/distributed/kv_transfer/kv_pipe/base.py +++ /dev/null @@ -1,66 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -This file defines an interface `KVPipeBase` -that provides an abstraction for sending and receiving tensors, or None, via -distributed communications. - -All classes instantiated from this interface are assumed to be a FIFO pipe. - -If your distributed communication platform already supports key-value lookup, -you can bypass this interface and directly start from `kv_lookup_buffer`. -""" - -from abc import ABC, abstractmethod - -import torch - - -class KVPipeBase(ABC): - """ - This class provides an interface for sending and receiving tensors, or - None, by distributed communications. - """ - - @abstractmethod - def send_tensor(self, tensor: torch.Tensor | None) -> None: - """Send a tensor, or None, via the pipe. - - Need to support sending None -- important for error handling. - - TODO: add a `key` argument so that we can use traditional - key-value database as the distributed communication mechanism behind - the pipe. - - Args: - tensor (Optional[torch.Tensor]): The tensor to be sent. Can be None. - - Raises: - NotImplementedError: This method must be implemented in subclasses. - """ - raise NotImplementedError - - @abstractmethod - def recv_tensor(self) -> torch.Tensor | None: - """Receive a tensor (can be None) from the pipeline. - - Returns: - Optional[torch.Tensor]: The tensor received from the pipeline. Can - be None. - - Raises: - NotImplementedError: This method must be implemented in subclasses. - """ - raise NotImplementedError - - @abstractmethod - def close(self) -> None: - """Close the pipeline and release resources. - - This method is responsible for closing the communication pipeline - and releasing any resources associated with it. - - Raises: - NotImplementedError: This method must be implemented in subclasses. - """ - raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py deleted file mode 100644 index 542dde09abad4..0000000000000 --- a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py +++ /dev/null @@ -1,295 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import json -import os -import struct -from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass - -import torch -import zmq -from safetensors.torch import load as safetensors_load -from safetensors.torch import save as safetensors_save - -from vllm.config.kv_transfer import KVTransferConfig -from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase -from vllm.logger import init_logger -from vllm.utils.network_utils import join_host_port, make_zmq_path, split_host_port - -logger = init_logger(__name__) -NONE_INT = -150886311 - - -@dataclass -class MooncakeTransferEngineConfig: - prefill_url: str - decode_url: str - metadata_backend: str | None - metadata_server: str - protocol: str - device_name: str - - @staticmethod - def from_file(file_path: str) -> "MooncakeTransferEngineConfig": - """Load the config from a JSON file.""" - with open(file_path) as fin: - config = json.load(fin) - return MooncakeTransferEngineConfig( - prefill_url=config.get("prefill_url"), - decode_url=config.get("decode_url"), - metadata_backend=config.get("metadata_backend", None), - metadata_server=config.get("metadata_server"), - protocol=config.get("protocol", "tcp"), - device_name=config.get("device_name", ""), - ) - - @staticmethod - def load_from_env() -> "MooncakeTransferEngineConfig": - """Load config from a file specified in the environment variable.""" - config_file_path = os.getenv("MOONCAKE_CONFIG_PATH") - if config_file_path is None: - raise ValueError( - "The environment variable 'MOONCAKE_CONFIG_PATH' is not set." - ) - return MooncakeTransferEngineConfig.from_file(config_file_path) - - -class MooncakeTransferEngine: - """Handles the transfer of data using mooncake_vllm_adaptor and ZeroMQ.""" - - def __init__(self, kv_rank: int, local_rank: int): - try: - from mooncake.engine import TransferEngine - except ImportError as e: - raise ImportError( - "Please install mooncake by following the instructions at " - "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 - "to run vLLM with MooncakeConnector." - ) from e - - self.engine = TransferEngine() - self.local_rank = local_rank - - try: - self.config = MooncakeTransferEngineConfig.load_from_env() - logger.info("Mooncake Configuration loaded successfully.") - except ValueError as e: - logger.error(e) - raise - except Exception as exc: - logger.error("An error occurred while loading the configuration: %s", exc) - raise - prefill_host, base_prefill_port = split_host_port(self.config.prefill_url) - decode_host, base_decode_port = split_host_port(self.config.decode_url) - - # Avoid ports conflict when running prefill and decode on the same node - if prefill_host == decode_host and base_prefill_port == base_decode_port: - base_decode_port = base_decode_port + 100 - - prefill_port = base_prefill_port + self.local_rank - decode_port = base_decode_port + self.local_rank - self.prefill_url = join_host_port(prefill_host, prefill_port) - self.decode_url = join_host_port(decode_host, decode_port) - - self.initialize( - self.prefill_url if kv_rank == 0 else self.decode_url, - self.config.metadata_server, - self.config.protocol, - self.config.device_name, - self.config.metadata_backend, - ) - - self.remote_url = self.decode_url if kv_rank == 0 else self.prefill_url - - # Initialize ZeroMQ context and sockets - self.context = zmq.Context() # type: ignore[attr-defined] - self.sender_socket = self.context.socket(zmq.constants.PUSH) - self.receiver_socket = self.context.socket(zmq.constants.PULL) - self.sender_ack = self.context.socket(zmq.constants.PULL) - self.receiver_ack = self.context.socket(zmq.constants.PUSH) - - self.buffer_cleaner = ThreadPoolExecutor(max_workers=1) - self._setup_metadata_sockets( - kv_rank, prefill_host, base_prefill_port, decode_host, base_decode_port - ) - - def _setup_metadata_sockets( - self, kv_rank: int, p_host: str, p_port: int, d_host: str, d_port: int - ) -> None: - """Set up ZeroMQ sockets for sending and receiving data.""" - # Offsets < 8 are left for initialization in case tp and pp are enabled - p_rank_offset = p_port + 8 + self.local_rank * 2 - d_rank_offset = d_port + 8 + self.local_rank * 2 - if kv_rank == 0: - self.sender_socket.bind(make_zmq_path("tcp", p_host, p_rank_offset + 1)) - self.receiver_socket.connect( - make_zmq_path("tcp", d_host, d_rank_offset + 1) - ) - self.sender_ack.connect(make_zmq_path("tcp", d_host, d_rank_offset + 2)) - self.receiver_ack.bind(make_zmq_path("tcp", p_host, p_rank_offset + 2)) - else: - self.receiver_socket.connect( - make_zmq_path("tcp", p_host, p_rank_offset + 1) - ) - self.sender_socket.bind(make_zmq_path("tcp", d_host, d_rank_offset + 1)) - self.receiver_ack.bind(make_zmq_path("tcp", d_host, d_rank_offset + 2)) - self.sender_ack.connect(make_zmq_path("tcp", p_host, p_rank_offset + 2)) - - def initialize( - self, - local_hostname: str, - metadata_server: str, - protocol: str, - device_name: str, - metadata_backend: str | None, - ) -> None: - """Initialize the mooncake instance.""" - if metadata_backend is None: - self.engine.initialize( - local_hostname, metadata_server, protocol, device_name - ) - else: - supported_backend = ["etcd", "redis"] - metadata_backend = metadata_backend.lower() - if metadata_backend not in supported_backend: - raise ValueError( - "Mooncake Configuration error. `metadata_backend`" - f" should be one of {supported_backend}." - ) - - self.engine.initialize_ext( - local_hostname, metadata_server, protocol, device_name, metadata_backend - ) - - def allocate_managed_buffer(self, length: int) -> int: - """Allocate a managed buffer of the specified length.""" - ret = self.engine.allocate_managed_buffer(length) - if ret <= 0: - logger.error("Allocation Return Error") - raise Exception("Allocation Return Error") - return ret - - def free_managed_buffer(self, buffer: int, length: int) -> int: - """Free a previously allocated managed buffer.""" - return self.engine.free_managed_buffer(buffer, length) - - def transfer_sync(self, buffer: int, peer_buffer_address: int, length: int) -> int: - """Synchronously transfer data to the specified address.""" - ret = self.engine.transfer_sync_read( - self.remote_url, buffer, peer_buffer_address, length - ) - if ret < 0: - logger.error("Transfer Return Error") - raise Exception("Transfer Return Error") - return ret - - def write_bytes_to_buffer(self, buffer: int, user_data: bytes, length: int) -> int: - """Write bytes to the allocated buffer.""" - return self.engine.write_bytes_to_buffer(buffer, user_data, length) - - def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes: - """Read bytes from the allocated buffer.""" - return self.engine.read_bytes_from_buffer(buffer, length) - - def wait_for_ack(self, src_ptr: int, length: int) -> None: - """Asynchronously wait for ACK from the receiver.""" - ack = self.sender_ack.recv() - if ack != b"ACK": - logger.error("Failed to receive ACK from the receiver") - - self.free_managed_buffer(src_ptr, length) - - def send_bytes(self, user_data: bytes) -> None: - """Send bytes to the remote process.""" - length = len(user_data) - src_ptr = self.allocate_managed_buffer(length) - self.write_bytes_to_buffer(src_ptr, user_data, length) - self.sender_socket.send_multipart( - [struct.pack("!Q", src_ptr), struct.pack("!Q", length)] - ) - self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length) - - def recv_bytes(self) -> bytes: - """Receive bytes from the remote process.""" - data = self.receiver_socket.recv_multipart() - src_ptr = struct.unpack("!Q", data[0])[0] - length = struct.unpack("!Q", data[1])[0] - dst_ptr = self.allocate_managed_buffer(length) - self.transfer_sync(dst_ptr, src_ptr, length) - ret = self.read_bytes_from_buffer(dst_ptr, length) - - # Buffer cleanup - self.receiver_ack.send(b"ACK") - self.free_managed_buffer(dst_ptr, length) - - return ret - - -class MooncakePipe(KVPipeBase): - """MooncakeTransferEngine based Pipe implementation.""" - - def __init__( - self, local_rank: int, config: KVTransferConfig, device: str | None = None - ): - """Initialize the mooncake pipe and set related parameters.""" - self.config = config - self.local_rank = local_rank - self.kv_rank = self.config.kv_rank - assert self.kv_rank is not None - if device is None: - self.device = self._select_device(self.config.kv_buffer_device) - else: - self.device = self._select_device(device) - - self.transfer_engine = MooncakeTransferEngine(self.kv_rank, self.local_rank) - self.transport_thread: ThreadPoolExecutor | None = None - self.none_tensor = torch.tensor([NONE_INT], device=self.device) - - def _select_device(self, device: str) -> torch.device: - """Select available device (CUDA or CPU).""" - logger.info("Selecting device: %s", device) - if device == "cuda": - return torch.device(f"cuda:{self.local_rank}") - else: - return torch.device("cpu") - - def tensor_hash(self, tensor: torch.Tensor) -> int: - """Calculate the hash value of the tensor.""" - return hash(tensor.data_ptr()) - - def _send_impl(self, tensor: torch.Tensor) -> None: - """Implement the tensor sending logic using safetensors.""" - self.transfer_engine.send_bytes(safetensors_save({"tensor": tensor})) - - def _recv_impl(self) -> torch.Tensor: - """Implement the tensor receiving logic using safetensors.""" - data = self.transfer_engine.recv_bytes() - return safetensors_load(data)["tensor"].to(self.device) - - def send_tensor(self, tensor: torch.Tensor | None) -> None: - """Send tensor to the target process.""" - if self.transport_thread is None: - self.transport_thread = ThreadPoolExecutor(max_workers=1) - tensor = tensor if tensor is not None else self.none_tensor - assert len(tensor.shape) > 0 - self.transport_thread.submit(self._send_impl, tensor) - - def recv_tensor(self) -> torch.Tensor | None: - """Receive tensor from other processes.""" - if self.transport_thread is None: - self.transport_thread = ThreadPoolExecutor(max_workers=1) - tensor = self.transport_thread.submit(self._recv_impl).result() - if tensor.numel() == 1 and tensor.item() == NONE_INT: - return None - else: - return tensor - - def close(self) -> None: - """Cleanup logic when closing the pipe.""" - self.transfer_engine.sender_socket.close() - self.transfer_engine.receiver_socket.close() - self.transfer_engine.sender_ack.close() - self.transfer_engine.receiver_ack.close() - self.transfer_engine.context.term() # Terminate the ZMQ context - logger.info("Closed the transfer engine and cleaned up resources.") diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py deleted file mode 100644 index 526c5cd1d5278..0000000000000 --- a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py +++ /dev/null @@ -1,285 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -This module implements a PyNccl pipe for sending and receiving -Optional[torch.Tensor] between distributed ranks with advanced -communication features. - -Key Features: -- Supports sending and receiving tensors with metadata -- Handles both CUDA and CPU device communications -- Implements a non-blocking tensor transfer mechanism -- Manages buffer size and provides backpressure control -- Supports distributed process groups with configurable parameters -""" - -import threading -import time -from collections.abc import Callable -from concurrent.futures import ThreadPoolExecutor - -import torch - -from vllm.config.kv_transfer import KVTransferConfig -from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator -from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase -from vllm.distributed.utils import StatelessProcessGroup -from vllm.logger import init_logger - -logger = init_logger(__name__) - - -class BrokenPipeException(Exception): - def __init__(self, message): - self.message = message - super().__init__(self.message) - - -Metadata = dict[str, torch.Tensor | None] - - -class PyNcclPipe(KVPipeBase): - METADATA_LENGTH = 16 - MAX_TENSOR_DIMENSIONS = 14 - METADATA_DTYPE = torch.int64 - - def __init__( - self, - local_rank: int, - config: KVTransferConfig, - device: str | None = None, - port_offset: int = 0, - ): - self.config = config - self.local_rank = local_rank - self.kv_rank = self.config.kv_rank - assert self.kv_rank is not None - self.kv_parallel_size = self.config.kv_parallel_size - if device is None: - self.device = self._select_device(self.config.kv_buffer_device) - else: - self.device = self._select_device(device) - - # build distributed connection and send/recv implementation - store_timeout = self.config.get_from_extra_config("store_timeout", 300) - self.group = StatelessProcessGroup.create( - host=self.config.kv_ip, - port=self.config.kv_port + port_offset, - rank=self.kv_rank, - world_size=self.kv_parallel_size, - store_timeout=store_timeout, - ) - # add a barrier to make sure the connection is initiated properly - self.group.barrier() - impl = self._get_device_send_recv_impl(self.group) - self.device_send_func, self.device_recv_func = impl - # set target rank - self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size - self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size - - # transportation-related variables - self.transport_thread: ThreadPoolExecutor | None = None - self.buffer_size = 0 - self.buffer_size_lock = threading.Lock() - self.buffer_size_thresh = self.config.kv_buffer_size - - def _get_device_send_recv_impl( - self, group: StatelessProcessGroup - ) -> tuple[ - Callable[[torch.Tensor, int], None], Callable[[torch.Tensor, int], None] - ]: - send: Callable[[torch.Tensor, int], None] - recv: Callable[[torch.Tensor, int], None] - if self.device.type == "cuda": - # use PyNCCL for send / recv - comm = PyNcclCommunicator(group, device=self.local_rank) - comm.disabled = False - send, recv = comm.send, comm.recv # type: ignore - else: - # This send / recv implementation here is NOT intended to transfer - # KV caches (and should NOT be repurposed to transfer KV caches). - # Currently it is only used to transmit control-plane messages - # for PyNcclBuffer. - send = group.send_obj - - def my_recv(x, src): - x[...] = group.recv_obj(src) - - recv = my_recv - - return send, recv - - def _select_device(self, device: str): - logger.info("Selecting device: %s", device) - if device == "cuda": - return torch.device(f"cuda:{self.local_rank}") - else: - return torch.device("cpu") - - def _make_metadata(self, tensor: torch.Tensor | None) -> Metadata: - """ - Create the metadata as a dictionary based on the input tensor. - - Args: - tensor: The input tensor or None if no tensor is provided. - - Returns: - metadata: A dictionary with the following keys: - - "dtype": The data type of the tensor or None. - - "shape": The shape of the tensor or None. - """ - if tensor is None: - return {"dtype": None, "shape": None} - else: - return {"dtype": tensor.dtype, "shape": tensor.shape} - - def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor: - """ - Create a buffer to receive the tensor based on the provided metadata. - - Args: - metadata: A dictionary with keys "dtype" and "shape", - describing the tensor's data type and shape. - - Returns: - buffer: A tensor of the specified type and shape, - allocated on `self.device`. - """ - return torch.empty( - metadata["shape"], dtype=metadata["dtype"], device=self.device - ) - - def _send_metadata(self, metadata: Metadata): - """ - Send the metadata dictionary to the target rank. - - Args: - metadata: A dictionary with keys "dtype" and "shape". - """ - self.group.send_obj(metadata, self.target_rank_for_send) - - def _recv_metadata(self) -> Metadata: - """ - Receive the metadata dictionary from the target rank. - - Returns: - metadata: A dictionary with keys "dtype" and "shape" - describing the tensor. - """ - return self.group.recv_obj(self.target_rank_for_recv) - - def _send_impl(self, tensor: torch.Tensor | None) -> None: - """ - The actual implementation of sending the tensor and its metadata to the - target rank. - - Args: - tensor: The input tensor to be sent, or `None` if no tensor is - being sent. - """ - metadata = self._make_metadata(tensor) - self._send_metadata(metadata) - if tensor is not None: - self.device_send_func(tensor.to(self.device), self.target_rank_for_send) - - def _recv_impl(self) -> torch.Tensor | None: - """ - The actual implementation of receiving a tensor and its metadata from - the target rank. - - Returns: - buffer: The received tensor, or `None` if no tensor is received. - """ - metadata = self._recv_metadata() - if metadata["dtype"] is None: - return None - buffer = self._prepare_recv_buffer(metadata) - self.device_recv_func(buffer, self.target_rank_for_recv) - - return buffer - - def send_tensor_wrapper( - self, tensor: torch.Tensor | None, tensor_size: int - ) -> None: - """ - Wrapper for _send_impl to handle exceptions and update buffer size. - """ - try: - self._send_impl(tensor) - - with self.buffer_size_lock: - self.buffer_size -= tensor_size - except Exception as e: - logger.error( - "[rank%d]: Exception when trying to send %s, msg: %s", - torch.distributed.get_rank(), - str(tensor), - str(e), - ) - import traceback - - traceback.print_exc() - - def block_if_full(self): - """ - Block the current thread if the buffer size is larger than the - threshold. - """ - while self.buffer_size > self.buffer_size_thresh: - logger.debug("KV cache transfer pipe is full. Waiting...") - time.sleep(0.05) - - def send_tensor(self, tensor: torch.Tensor | None) -> None: - """ - Sends a tensor and its metadata to the destination rank in a - non-blocking way. - - Args: - tensor: The tensor to send, or `None` if no tensor is being sent. - """ - if self.transport_thread is None: - self.transport_thread = ThreadPoolExecutor(max_workers=1) - - if tensor is not None: - tensor_size = tensor.element_size() * tensor.numel() - else: - tensor_size = 0 - - self.block_if_full() - - with self.buffer_size_lock: - self.buffer_size += tensor_size - - self.transport_thread.submit(self.send_tensor_wrapper, tensor, tensor_size) - - def recv_tensor(self) -> torch.Tensor | None: - """ - Receives a tensor and its metadata from the source rank. Blocking call. - - Returns: - The received tensor, or `None` if no tensor is received. - """ - if self.transport_thread is None: - self.transport_thread = ThreadPoolExecutor(max_workers=1) - - future = self.transport_thread.submit(self._recv_impl) - - try: - tensor = future.result() - except Exception as e: - logger.error("Encountering exception in KV receiving thread") - logger.error("%s", e) - logger.error("My device: %s", self.device) - import traceback - - traceback.print_exc() - raise e - - return tensor - - def close(self): - """ - Close the pipe and release associated resources. - """ - if hasattr(self, "transport_thread") and self.transport_thread is not None: - self.transport_thread.shutdown() diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index c82a77c216af2..338cb1f1814b5 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1169,17 +1169,13 @@ def init_distributed_environment( from vllm.config import get_current_vllm_config config = get_current_vllm_config() - if config is not None and config.parallel_config.nnodes > 1: - parallel_config = config.parallel_config - ip = parallel_config.master_addr - rank = parallel_config.data_parallel_rank * world_size + rank - world_size = parallel_config.world_size_across_dp - port = parallel_config.master_port - distributed_init_method = get_distributed_init_method(ip, port) - elif ( + if ( config is not None - and config.parallel_config.data_parallel_size > 1 and config.parallel_config.distributed_executor_backend != "external_launcher" + and ( + config.parallel_config.nnodes > 1 + or config.parallel_config.data_parallel_size > 1 + ) ): parallel_config = config.parallel_config # adjust to take into account data parallelism @@ -1187,15 +1183,22 @@ def init_distributed_environment( rank = parallel_config.data_parallel_rank * world_size + rank # adjust the world size to take into account data parallelism world_size = parallel_config.world_size_across_dp - ip = parallel_config.data_parallel_master_ip - port = parallel_config.get_next_dp_init_port() - distributed_init_method = get_distributed_init_method(ip, port) - logger.debug( - "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP", - world_size, - rank, - distributed_init_method, - ) + + # Use appropriate IP and port based on configuration + if parallel_config.nnodes > 1: + ip = parallel_config.master_addr + port = parallel_config.master_port + distributed_init_method = get_distributed_init_method(ip, port) + else: + ip = parallel_config.data_parallel_master_ip + port = parallel_config.get_next_dp_init_port() + distributed_init_method = get_distributed_init_method(ip, port) + logger.debug( + "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP", + world_size, + rank, + distributed_init_method, + ) if not torch.distributed.is_initialized(): logger.info( "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s", @@ -1583,6 +1586,8 @@ def destroy_distributed_environment(): def cleanup_dist_env_and_memory(shutdown_ray: bool = False): + # Reset environment variable cache + envs.disable_envs_cache() # Ensure all objects are not frozen before cleanup gc.unfreeze() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 96b1b971552c6..ca19e468914c7 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -34,6 +34,7 @@ from typing_extensions import TypeIs import vllm.envs as envs from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import ( + AttentionConfig, CacheConfig, CompilationConfig, ConfigType, @@ -49,6 +50,7 @@ from vllm.config import ( ObservabilityConfig, ParallelConfig, PoolerConfig, + ProfilerConfig, SchedulerConfig, SpeculativeConfig, StructuredOutputsConfig, @@ -69,7 +71,6 @@ from vllm.config.model import ( LogprobsMode, ModelDType, RunnerOption, - TaskOption, TokenizerMode, ) from vllm.config.multimodal import MMCacheType, MMEncoderTPMode @@ -86,8 +87,9 @@ from vllm.transformers_utils.config import ( is_interleaved, maybe_override_with_speculators, ) +from vllm.transformers_utils.gguf_utils import is_gguf from vllm.transformers_utils.repo_utils import get_model_path -from vllm.transformers_utils.utils import is_cloud_storage, is_gguf +from vllm.transformers_utils.utils import is_cloud_storage from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.mem_constants import GiB_bytes from vllm.utils.network_utils import get_ip @@ -357,7 +359,6 @@ class EngineArgs: hf_config_path: str | None = ModelConfig.hf_config_path runner: RunnerOption = ModelConfig.runner convert: ConvertOption = ModelConfig.convert - task: TaskOption | None = ModelConfig.task skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds tokenizer_mode: TokenizerMode | str = ModelConfig.tokenizer_mode @@ -370,9 +371,8 @@ class EngineArgs: config_format: str = ModelConfig.config_format dtype: ModelDType = ModelConfig.dtype kv_cache_dtype: CacheDType = CacheConfig.cache_dtype - seed: int | None = 0 + seed: int = ModelConfig.seed max_model_len: int | None = ModelConfig.max_model_len - cuda_graph_sizes: list[int] | None = CompilationConfig.cudagraph_capture_sizes cudagraph_capture_sizes: list[int] | None = ( CompilationConfig.cudagraph_capture_sizes ) @@ -408,6 +408,7 @@ class EngineArgs: enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel all2all_backend: str | None = ParallelConfig.all2all_backend enable_dbo: bool = ParallelConfig.enable_dbo + ubatch_size: int = ParallelConfig.ubatch_size dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold disable_nccl_for_dp_synchronization: bool = ( @@ -420,10 +421,6 @@ class EngineArgs: ) _api_process_count: int = ParallelConfig._api_process_count _api_process_rank: int = ParallelConfig._api_process_rank - num_redundant_experts: int = EPLBConfig.num_redundant_experts - eplb_window_size: int = EPLBConfig.window_size - eplb_step_interval: int = EPLBConfig.step_interval - eplb_log_balancedness: bool = EPLBConfig.log_balancedness max_parallel_loading_workers: int | None = ( ParallelConfig.max_parallel_loading_workers ) @@ -464,7 +461,6 @@ class EngineArgs: MultiModalConfig, "media_io_kwargs" ) mm_processor_kwargs: dict[str, Any] | None = MultiModalConfig.mm_processor_kwargs - disable_mm_preprocessor_cache: bool = False # DEPRECATED mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb mm_processor_cache_type: MMCacheType | None = ( MultiModalConfig.mm_processor_cache_type @@ -496,7 +492,7 @@ class EngineArgs: enable_chunked_prefill: bool | None = None disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input - disable_hybrid_kv_cache_manager: bool = ( + disable_hybrid_kv_cache_manager: bool | None = ( SchedulerConfig.disable_hybrid_kv_cache_manager ) @@ -517,14 +513,25 @@ class EngineArgs: collect_detailed_traces: list[DetailedTraceModules] | None = ( ObservabilityConfig.collect_detailed_traces ) + kv_cache_metrics: bool = ObservabilityConfig.kv_cache_metrics + kv_cache_metrics_sample: float = get_field( + ObservabilityConfig, "kv_cache_metrics_sample" + ) + cudagraph_metrics: bool = ObservabilityConfig.cudagraph_metrics + enable_layerwise_nvtx_tracing: bool = ( + ObservabilityConfig.enable_layerwise_nvtx_tracing + ) scheduling_policy: SchedulerPolicy = SchedulerConfig.policy scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls pooler_config: PoolerConfig | None = ModelConfig.pooler_config compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config") + attention_config: AttentionConfig = get_field(VllmConfig, "attention_config") worker_cls: str = ParallelConfig.worker_cls worker_extension_cls: str = ParallelConfig.worker_extension_cls + profiler_config: ProfilerConfig = get_field(VllmConfig, "profiler_config") + kv_transfer_config: KVTransferConfig | None = None kv_events_config: KVEventsConfig | None = None @@ -537,6 +544,7 @@ class EngineArgs: ) model_impl: str = ModelConfig.model_impl override_attention_dtype: str = ModelConfig.override_attention_dtype + attention_backend: AttentionBackendEnum | None = AttentionConfig.backend calculate_kv_scales: bool = CacheConfig.calculate_kv_scales mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype @@ -548,9 +556,6 @@ class EngineArgs: use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load pt_load_map_location: str = LoadConfig.pt_load_map_location - # DEPRECATED - enable_multimodal_encoder_data_parallel: bool = False - logits_processors: list[str | type[LogitsProcessor]] | None = ( ModelConfig.logits_processors ) @@ -575,6 +580,8 @@ class EngineArgs: # CompilationConfig object if isinstance(self.compilation_config, dict): self.compilation_config = CompilationConfig(**self.compilation_config) + if isinstance(self.attention_config, dict): + self.attention_config = AttentionConfig(**self.attention_config) if isinstance(self.eplb_config, dict): self.eplb_config = EPLBConfig(**self.eplb_config) # Setup plugins @@ -616,7 +623,6 @@ class EngineArgs: model_group.add_argument("--model", **model_kwargs["model"]) model_group.add_argument("--runner", **model_kwargs["runner"]) model_group.add_argument("--convert", **model_kwargs["convert"]) - model_group.add_argument("--task", **model_kwargs["task"], deprecated=True) model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"]) model_group.add_argument("--tokenizer-mode", **model_kwargs["tokenizer_mode"]) model_group.add_argument( @@ -712,6 +718,16 @@ class EngineArgs: "--pt-load-map-location", **load_kwargs["pt_load_map_location"] ) + # Attention arguments + attention_kwargs = get_kwargs(AttentionConfig) + attention_group = parser.add_argument_group( + title="AttentionConfig", + description=AttentionConfig.__doc__, + ) + attention_group.add_argument( + "--attention-backend", **attention_kwargs["backend"] + ) + # Structured outputs arguments structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig) structured_outputs_group = parser.add_argument_group( @@ -826,6 +842,10 @@ class EngineArgs: "--all2all-backend", **parallel_kwargs["all2all_backend"] ) parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"]) + parallel_group.add_argument( + "--ubatch-size", + **parallel_kwargs["ubatch_size"], + ) parallel_group.add_argument( "--dbo-decode-token-threshold", **parallel_kwargs["dbo_decode_token_threshold"], @@ -860,11 +880,6 @@ class EngineArgs: parallel_group.add_argument( "--worker-extension-cls", **parallel_kwargs["worker_extension_cls"] ) - parallel_group.add_argument( - "--enable-multimodal-encoder-data-parallel", - action="store_true", - deprecated=True, - ) # KV cache arguments cache_kwargs = get_kwargs(CacheConfig) @@ -938,9 +953,6 @@ class EngineArgs: multimodal_group.add_argument( "--mm-processor-cache-gb", **multimodal_kwargs["mm_processor_cache_gb"] ) - multimodal_group.add_argument( - "--disable-mm-preprocessor-cache", action="store_true", deprecated=True - ) multimodal_group.add_argument( "--mm-processor-cache-type", **multimodal_kwargs["mm_processor_cache_type"] ) @@ -1013,6 +1025,21 @@ class EngineArgs: "--collect-detailed-traces", **observability_kwargs["collect_detailed_traces"], ) + observability_group.add_argument( + "--kv-cache-metrics", **observability_kwargs["kv_cache_metrics"] + ) + observability_group.add_argument( + "--kv-cache-metrics-sample", + **observability_kwargs["kv_cache_metrics_sample"], + ) + observability_group.add_argument( + "--cudagraph-metrics", + **observability_kwargs["cudagraph_metrics"], + ) + observability_group.add_argument( + "--enable-layerwise-nvtx-tracing", + **observability_kwargs["enable_layerwise_nvtx_tracing"], + ) # Scheduler arguments scheduler_kwargs = get_kwargs(SchedulerConfig) @@ -1083,15 +1110,6 @@ class EngineArgs: compilation_group.add_argument( "--cudagraph-capture-sizes", **compilation_kwargs["cudagraph_capture_sizes"] ) - compilation_kwargs["cudagraph_capture_sizes"]["help"] = ( - "--cuda-graph-sizes is deprecated and will be removed in v0.13.0 or v1.0.0," - " whichever is soonest. Please use --cudagraph-capture-sizes instead." - ) - compilation_group.add_argument( - "--cuda-graph-sizes", - **compilation_kwargs["cudagraph_capture_sizes"], - deprecated=True, - ) compilation_group.add_argument( "--max-cudagraph-capture-size", **compilation_kwargs["max_cudagraph_capture_size"], @@ -1120,13 +1138,16 @@ class EngineArgs: vllm_group.add_argument( "--compilation-config", "-cc", **vllm_kwargs["compilation_config"] ) + vllm_group.add_argument( + "--attention-config", "-ac", **vllm_kwargs["attention_config"] + ) vllm_group.add_argument( "--additional-config", **vllm_kwargs["additional_config"] ) vllm_group.add_argument( "--structured-outputs-config", **vllm_kwargs["structured_outputs_config"] ) - + vllm_group.add_argument("--profiler-config", **vllm_kwargs["profiler_config"]) vllm_group.add_argument( "--optimization-level", **vllm_kwargs["optimization_level"] ) @@ -1161,62 +1182,20 @@ class EngineArgs: if is_gguf(self.model): self.quantization = self.load_format = "gguf" - # NOTE(woosuk): In V1, we use separate processes for workers (unless - # VLLM_ENABLE_V1_MULTIPROCESSING=0), so setting a seed here - # doesn't affect the user process. - if self.seed is None: - logger.warning_once( - "`seed=None` is equivalent to `seed=0` in V1 Engine. " - "You will no longer be allowed to pass `None` in v0.13.", - scope="local", + if not envs.VLLM_ENABLE_V1_MULTIPROCESSING: + logger.warning( + "The global random seed is set to %d. Since " + "VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may " + "affect the random state of the Python process that " + "launched vLLM.", + self.seed, ) - self.seed = 0 - if not envs.VLLM_ENABLE_V1_MULTIPROCESSING: - logger.warning( - "The global random seed is set to %d. Since " - "VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may " - "affect the random state of the Python process that " - "launched vLLM.", - self.seed, - ) - - if self.disable_mm_preprocessor_cache: - logger.warning_once( - "`--disable-mm-preprocessor-cache` is deprecated " - "and will be removed in v0.13. " - "Please use `--mm-processor-cache-gb 0` instead.", - scope="local", - ) - - self.mm_processor_cache_gb = 0 - elif envs.VLLM_MM_INPUT_CACHE_GIB != 4: - logger.warning_once( - "VLLM_MM_INPUT_CACHE_GIB` is deprecated " - "and will be removed in v0.13. " - "Please use `--mm-processor-cache-gb %d` instead.", - envs.VLLM_MM_INPUT_CACHE_GIB, - scope="local", - ) - - self.mm_processor_cache_gb = envs.VLLM_MM_INPUT_CACHE_GIB - - if self.enable_multimodal_encoder_data_parallel: - logger.warning_once( - "--enable-multimodal-encoder-data-parallel` is deprecated " - "and will be removed in v0.13. " - "Please use `--mm-encoder-tp-mode data` instead.", - scope="local", - ) - - self.mm_encoder_tp_mode = "data" - return ModelConfig( model=self.model, hf_config_path=self.hf_config_path, runner=self.runner, convert=self.convert, - task=self.task, tokenizer=self.tokenizer, tokenizer_mode=self.tokenizer_mode, trust_remote_code=self.trust_remote_code, @@ -1564,22 +1543,6 @@ class EngineArgs: model_config.skip_tokenizer_init = True logger.info("Skipping tokenizer initialization for tokens-only mode.") - if self.async_scheduling and not self.disable_nccl_for_dp_synchronization: - logger.info( - "Disabling NCCL for DP synchronization when using async scheduling." - ) - self.disable_nccl_for_dp_synchronization = True - - # Forward the deprecated CLI args to the EPLB config. - if self.num_redundant_experts is not None: - self.eplb_config.num_redundant_experts = self.num_redundant_experts - if self.eplb_window_size is not None: - self.eplb_config.window_size = self.eplb_window_size - if self.eplb_step_interval is not None: - self.eplb_config.step_interval = self.eplb_step_interval - if self.eplb_log_balancedness is not None: - self.eplb_config.log_balancedness = self.eplb_log_balancedness - parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, @@ -1599,6 +1562,7 @@ class EngineArgs: enable_expert_parallel=self.enable_expert_parallel, all2all_backend=self.all2all_backend, enable_dbo=self.enable_dbo, + ubatch_size=self.ubatch_size, dbo_decode_token_threshold=self.dbo_decode_token_threshold, dbo_prefill_token_threshold=self.dbo_prefill_token_threshold, disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization, @@ -1683,6 +1647,22 @@ class EngineArgs: if model_config.quantization == "bitsandbytes": self.quantization = self.load_format = "bitsandbytes" + # Attention config overrides + attention_config = copy.deepcopy(self.attention_config) + if self.attention_backend is not None: + if attention_config.backend is not None: + raise ValueError( + "attention_backend and attention_config.backend " + "are mutually exclusive" + ) + # Convert string to enum if needed (CLI parsing returns a string) + if isinstance(self.attention_backend, str): + attention_config.backend = AttentionBackendEnum[ + self.attention_backend.upper() + ] + else: + attention_config.backend = self.attention_backend + load_config = self.create_load_config() # Pass reasoning_parser into StructuredOutputsConfig @@ -1698,22 +1678,14 @@ class EngineArgs: show_hidden_metrics_for_version=self.show_hidden_metrics_for_version, otlp_traces_endpoint=self.otlp_traces_endpoint, collect_detailed_traces=self.collect_detailed_traces, + kv_cache_metrics=self.kv_cache_metrics, + kv_cache_metrics_sample=self.kv_cache_metrics_sample, + cudagraph_metrics=self.cudagraph_metrics, + enable_layerwise_nvtx_tracing=self.enable_layerwise_nvtx_tracing, ) # Compilation config overrides compilation_config = copy.deepcopy(self.compilation_config) - if self.cuda_graph_sizes is not None: - logger.warning( - "--cuda-graph-sizes is deprecated and will be removed in v0.13.0 or " - "v1.0.0, whichever is soonest. Please use --cudagraph-capture-sizes " - "instead." - ) - if compilation_config.cudagraph_capture_sizes is not None: - raise ValueError( - "cuda_graph_sizes and compilation_config." - "cudagraph_capture_sizes are mutually exclusive" - ) - compilation_config.cudagraph_capture_sizes = self.cuda_graph_sizes if self.cudagraph_capture_sizes is not None: if compilation_config.cudagraph_capture_sizes is not None: raise ValueError( @@ -1736,15 +1708,17 @@ class EngineArgs: parallel_config=parallel_config, scheduler_config=scheduler_config, device_config=device_config, + load_config=load_config, + attention_config=attention_config, lora_config=lora_config, speculative_config=speculative_config, - load_config=load_config, structured_outputs_config=self.structured_outputs_config, observability_config=observability_config, compilation_config=compilation_config, kv_transfer_config=self.kv_transfer_config, kv_events_config=self.kv_events_config, ec_transfer_config=self.ec_transfer_config, + profiler_config=self.profiler_config, additional_config=self.additional_config, optimization_level=self.optimization_level, ) @@ -1821,6 +1795,7 @@ class EngineArgs: except Exception: # This is only used to set default_max_num_batched_tokens device_memory = 0 + device_name = "" # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces # throughput, see PR #17885 for more details. @@ -1885,16 +1860,6 @@ class EngineArgs: default_chunked_prefill = model_config.is_chunked_prefill_supported default_prefix_caching = model_config.is_prefix_caching_supported - if self.prefill_context_parallel_size > 1: - default_chunked_prefill = False - default_prefix_caching = False - logger.warning_once( - "--prefill-context-parallel-size > 1 is not compatible with " - "chunked prefill and prefix caching now. Chunked prefill " - "and prefix caching have been disabled by default.", - scope="local", - ) - if self.enable_chunked_prefill is None: self.enable_chunked_prefill = default_chunked_prefill @@ -2080,11 +2045,13 @@ def human_readable_int(value): "k": 10**3, "m": 10**6, "g": 10**9, + "t": 10**12, } binary_multiplier = { "K": 2**10, "M": 2**20, "G": 2**30, + "T": 2**40, } number, suffix = match.groups() diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index f2b19c845018c..d94951a0cffc8 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -116,8 +116,10 @@ class EngineClient(ABC): ... @abstractmethod - async def reset_prefix_cache(self) -> None: - """Reset the prefix cache""" + async def reset_prefix_cache( + self, reset_running_requests: bool = False, reset_connector: bool = False + ) -> bool: + """Reset the prefix cache and optionally any configured connector cache""" ... @abstractmethod diff --git a/vllm/entrypoints/anthropic/serving_messages.py b/vllm/entrypoints/anthropic/serving_messages.py index 340dabf0e7117..25c2d88a2c7a4 100644 --- a/vllm/entrypoints/anthropic/serving_messages.py +++ b/vllm/entrypoints/anthropic/serving_messages.py @@ -183,7 +183,9 @@ class AnthropicServingMessages(OpenAIServingChat): if anthropic_request.stream: req.stream = anthropic_request.stream - req.stream_options = StreamOptions.validate({"include_usage": True}) + req.stream_options = StreamOptions.validate( + {"include_usage": True, "continuous_usage_stats": True} + ) if anthropic_request.tool_choice is None: req.tool_choice = None @@ -322,6 +324,12 @@ class AnthropicServingMessages(OpenAIServingChat): id=origin_chunk.id, content=[], model=origin_chunk.model, + usage=AnthropicUsage( + input_tokens=origin_chunk.usage.prompt_tokens + if origin_chunk.usage + else 0, + output_tokens=0, + ), ), ) first_item = False diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 154cdeb42a3ea..b59f7120551e0 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -118,6 +118,7 @@ async def init_app( ) ) app.state.engine_client = engine + app.state.args = args return app diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 1643906894c66..ab055dfb1fb0e 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -9,7 +9,7 @@ from collections import Counter, defaultdict, deque from collections.abc import Awaitable, Callable, Iterable from functools import cached_property, lru_cache, partial from pathlib import Path -from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast import jinja2 import jinja2.ext @@ -24,6 +24,7 @@ from openai.types.chat import ( ChatCompletionContentPartInputAudioParam, ChatCompletionContentPartRefusalParam, ChatCompletionContentPartTextParam, + ChatCompletionFunctionToolParam, ChatCompletionMessageToolCallParam, ChatCompletionToolMessageParam, ) @@ -49,11 +50,20 @@ from vllm.logger import init_logger from vllm.model_executor.models import SupportsMultiModal from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector -from vllm.tokenizers import MistralTokenizer, TokenizerLike +from vllm.tokenizers import TokenizerLike from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path from vllm.transformers_utils.processor import cached_get_processor from vllm.utils import random_uuid +from vllm.utils.collection_utils import is_list_of from vllm.utils.func_utils import supports_kw +from vllm.utils.import_utils import LazyLoader + +if TYPE_CHECKING: + import torch + + from vllm.tokenizers.mistral import MistralTokenizer +else: + torch = LazyLoader("torch", globals(), "torch") logger = init_logger(__name__) @@ -260,6 +270,9 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): reasoning: str | None """The reasoning content for interleaved thinking.""" + tools: list[ChatCompletionFunctionToolParam] | None + """The tools for developer role.""" + ChatCompletionMessageParam: TypeAlias = ( OpenAIChatCompletionMessageParam @@ -291,6 +304,9 @@ class ConversationMessage(TypedDict, total=False): reasoning_content: str | None """Deprecated: The reasoning content for interleaved thinking.""" + tools: list[ChatCompletionFunctionToolParam] | None + """The tools for developer role.""" + # Passed in by user ChatTemplateContentFormatOption = Literal["auto", "string", "openai"] @@ -536,7 +552,7 @@ def resolve_hf_chat_template( def _resolve_chat_template_content_format( chat_template: str | None, tools: list[dict[str, Any]] | None, - tokenizer: TokenizerLike, + tokenizer: TokenizerLike | None, *, model_config: ModelConfig, ) -> _ChatTemplateContentFormat: @@ -593,7 +609,7 @@ def resolve_chat_template_content_format( chat_template: str | None, tools: list[dict[str, Any]] | None, given_format: ChatTemplateContentFormatOption, - tokenizer: TokenizerLike, + tokenizer: TokenizerLike | None, *, model_config: ModelConfig, ) -> _ChatTemplateContentFormat: @@ -620,6 +636,44 @@ ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"] _T = TypeVar("_T") +def _extract_embeds(tensors: list[torch.Tensor]): + if len(tensors) == 0: + return tensors + + if len(tensors) == 1: + tensors[0]._is_single_item = True # type: ignore + return tensors[0] # To keep backwards compatibility for single item input + + first_shape = tensors[0].shape + if all(t.shape == first_shape for t in tensors): + return torch.stack(tensors) + + return tensors + + +def _get_embeds_data(items_by_modality: dict[str, list[Any]], modality: str): + embeds_key = f"{modality}_embeds" + embeds = items_by_modality[embeds_key] + + if len(embeds) == 0: + return embeds + if is_list_of(embeds, torch.Tensor): + return _extract_embeds(embeds) + if is_list_of(embeds, dict): + if not embeds: + return {} + + first_keys = set(embeds[0].keys()) + if any(set(item.keys()) != first_keys for item in embeds[1:]): + raise ValueError( + "All dictionaries in the list of embeddings must have the same keys." + ) + + return {k: _extract_embeds([item[k] for item in embeds]) for k in first_keys} + + return embeds + + class BaseMultiModalItemTracker(ABC, Generic[_T]): """ Tracks multi-modal items in a given request and ensures that the number @@ -627,11 +681,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): maximum per prompt. """ - def __init__(self, model_config: ModelConfig, tokenizer: TokenizerLike): + def __init__(self, model_config: ModelConfig): super().__init__() self._model_config = model_config - self._tokenizer = tokenizer self._items_by_modality = defaultdict[str, list[_T | None]](list) self._uuids_by_modality = defaultdict[str, list[str | None]](list) @@ -689,27 +742,25 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): def all_mm_uuids(self) -> MultiModalUUIDDict | None: if not self._items_by_modality: return None - mm_uuids = {} + uuids_by_modality = dict(self._uuids_by_modality) if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality: raise ValueError("Mixing raw image and embedding inputs is not allowed") + if "audio" in uuids_by_modality and "audio_embeds" in uuids_by_modality: + raise ValueError("Mixing raw audio and embedding inputs is not allowed") + mm_uuids = {} if "image_embeds" in uuids_by_modality: - image_embeds_uuids = uuids_by_modality["image_embeds"] - if len(image_embeds_uuids) > 1: - raise ValueError("Only one message can have {'type': 'image_embeds'}") mm_uuids["image"] = uuids_by_modality["image_embeds"] if "image" in uuids_by_modality: mm_uuids["image"] = uuids_by_modality["image"] # UUIDs of images if "audio_embeds" in uuids_by_modality: - audio_embeds_uuids = uuids_by_modality["audio_embeds"] - if len(audio_embeds_uuids) > 1: - raise ValueError("Only one message can have {'type': 'audio_embeds'}") mm_uuids["audio"] = uuids_by_modality["audio_embeds"] if "audio" in uuids_by_modality: mm_uuids["audio"] = uuids_by_modality["audio"] # UUIDs of audios if "video" in uuids_by_modality: mm_uuids["video"] = uuids_by_modality["video"] # UUIDs of videos + return mm_uuids @abstractmethod @@ -721,29 +772,25 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]): def all_mm_data(self) -> MultiModalDataDict | None: if not self._items_by_modality: return None - mm_inputs = {} + items_by_modality = dict(self._items_by_modality) if "image" in items_by_modality and "image_embeds" in items_by_modality: raise ValueError("Mixing raw image and embedding inputs is not allowed") if "audio" in items_by_modality and "audio_embeds" in items_by_modality: raise ValueError("Mixing raw audio and embedding inputs is not allowed") + mm_inputs = {} if "image_embeds" in items_by_modality: - image_embeds_lst = items_by_modality["image_embeds"] - if len(image_embeds_lst) > 1: - raise ValueError("Only one message can have {'type': 'image_embeds'}") - mm_inputs["image"] = image_embeds_lst[0] + mm_inputs["image"] = _get_embeds_data(items_by_modality, "image") if "image" in items_by_modality: mm_inputs["image"] = items_by_modality["image"] # A list of images if "audio_embeds" in items_by_modality: - audio_embeds_lst = items_by_modality["audio_embeds"] - if len(audio_embeds_lst) > 1: - raise ValueError("Only one message can have {'type': 'audio_embeds'}") - mm_inputs["audio"] = audio_embeds_lst[0] + mm_inputs["audio"] = _get_embeds_data(items_by_modality, "audio") if "audio" in items_by_modality: mm_inputs["audio"] = items_by_modality["audio"] # A list of audios if "video" in items_by_modality: mm_inputs["video"] = items_by_modality["video"] # A list of videos + return mm_inputs def create_parser(self) -> "BaseMultiModalContentParser": @@ -754,38 +801,32 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): async def all_mm_data(self) -> MultiModalDataDict | None: if not self._items_by_modality: return None - mm_inputs = {} - items_by_modality = {} - for modality, items in self._items_by_modality.items(): - coros = [] - for item in items: - if item is not None: - coros.append(item) - else: - coros.append(asyncio.sleep(0)) - items_by_modality[modality] = await asyncio.gather(*coros) + coros_by_modality = { + modality: [item or asyncio.sleep(0) for item in items] + for modality, items in self._items_by_modality.items() + } + items_by_modality: dict[str, list[object | None]] = { + modality: await asyncio.gather(*coros) + for modality, coros in coros_by_modality.items() + } if "image" in items_by_modality and "image_embeds" in items_by_modality: raise ValueError("Mixing raw image and embedding inputs is not allowed") if "audio" in items_by_modality and "audio_embeds" in items_by_modality: raise ValueError("Mixing raw audio and embedding inputs is not allowed") + mm_inputs = {} if "image_embeds" in items_by_modality: - image_embeds_lst = items_by_modality["image_embeds"] - if len(image_embeds_lst) > 1: - raise ValueError("Only one message can have {'type': 'image_embeds'}") - mm_inputs["image"] = image_embeds_lst[0] + mm_inputs["image"] = _get_embeds_data(items_by_modality, "image") if "image" in items_by_modality: mm_inputs["image"] = items_by_modality["image"] # A list of images if "audio_embeds" in items_by_modality: - audio_embeds_lst = items_by_modality["audio_embeds"] - if len(audio_embeds_lst) > 1: - raise ValueError("Only one message can have {'type': 'audio_embeds'}") - mm_inputs["audio"] = audio_embeds_lst[0] + mm_inputs["audio"] = _get_embeds_data(items_by_modality, "audio") if "audio" in items_by_modality: mm_inputs["audio"] = items_by_modality["audio"] # A list of audios if "video" in items_by_modality: mm_inputs["video"] = items_by_modality["video"] # A list of videos + return mm_inputs def create_parser(self) -> "BaseMultiModalContentParser": @@ -1139,11 +1180,19 @@ def validate_chat_template(chat_template: Path | str | None): not any(c in chat_template for c in JINJA_CHARS) and not Path(chat_template).exists() ): - raise ValueError( - f"The supplied chat template string ({chat_template}) " - f"appears path-like, but doesn't exist!" + # Try to find the template in the built-in templates directory + from vllm.transformers_utils.chat_templates.registry import ( + CHAT_TEMPLATES_DIR, ) + builtin_template_path = CHAT_TEMPLATES_DIR / chat_template + if not builtin_template_path.exists(): + raise ValueError( + f"The supplied chat template string ({chat_template}) " + f"appears path-like, but doesn't exist! " + f"Tried: {chat_template} and {builtin_template_path}" + ) + else: raise TypeError(f"{type(chat_template)} is not a valid chat template type") @@ -1173,12 +1222,23 @@ def _load_chat_template( JINJA_CHARS = "{}\n" if not any(c in chat_template for c in JINJA_CHARS): - msg = ( - f"The supplied chat template ({chat_template}) " - f"looks like a file path, but it failed to be " - f"opened. Reason: {e}" + # Try to load from the built-in templates directory + from vllm.transformers_utils.chat_templates.registry import ( + CHAT_TEMPLATES_DIR, ) - raise ValueError(msg) from e + + builtin_template_path = CHAT_TEMPLATES_DIR / chat_template + try: + with open(builtin_template_path) as f: + return f.read() + except OSError: + msg = ( + f"The supplied chat template ({chat_template}) " + f"looks like a file path, but it failed to be opened. " + f"Tried: {chat_template} and {builtin_template_path}. " + f"Reason: {e}" + ) + raise ValueError(msg) from e # If opening a file fails, set chat template to be args to # ensure we decode so our escape are interpreted correctly @@ -1530,6 +1590,7 @@ def _parse_chat_message_content( role = message["role"] content = message.get("content") reasoning = message.get("reasoning") or message.get("reasoning_content") + if content is None: content = [] elif isinstance(content, str): @@ -1565,6 +1626,8 @@ def _parse_chat_message_content( if "name" in message and isinstance(message["name"], str): result_msg["name"] = message["name"] + if role == "developer": + result_msg["tools"] = message.get("tools", None) return result @@ -1575,12 +1638,17 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None: # so, for messages that have tool_calls, parse the string (which we get # from openAI format) to dict for message in messages: - if ( - message["role"] == "assistant" - and "tool_calls" in message - and isinstance(message["tool_calls"], list) - ): - for item in message["tool_calls"]: + if message["role"] == "assistant" and "tool_calls" in message: + tool_calls = message.get("tool_calls") + if not isinstance(tool_calls, list): + continue + + if len(tool_calls) == 0: + # Drop empty tool_calls to keep templates on the normal assistant path. + message.pop("tool_calls", None) + continue + + for item in tool_calls: # if arguments is None or empty string, set to {} if content := item["function"].get("arguments"): if not isinstance(content, (dict, list)): @@ -1592,7 +1660,6 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None: def parse_chat_messages( messages: list[ChatCompletionMessageParam], model_config: ModelConfig, - tokenizer: TokenizerLike, content_format: _ChatTemplateContentFormat, ) -> tuple[ list[ConversationMessage], @@ -1600,7 +1667,7 @@ def parse_chat_messages( MultiModalUUIDDict | None, ]: conversation: list[ConversationMessage] = [] - mm_tracker = MultiModalItemTracker(model_config, tokenizer) + mm_tracker = MultiModalItemTracker(model_config) for msg in messages: sub_messages = _parse_chat_message_content( @@ -1624,7 +1691,6 @@ def parse_chat_messages( def parse_chat_messages_futures( messages: list[ChatCompletionMessageParam], model_config: ModelConfig, - tokenizer: TokenizerLike, content_format: _ChatTemplateContentFormat, ) -> tuple[ list[ConversationMessage], @@ -1632,7 +1698,7 @@ def parse_chat_messages_futures( MultiModalUUIDDict | None, ]: conversation: list[ConversationMessage] = [] - mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer) + mm_tracker = AsyncMultiModalItemTracker(model_config) for msg in messages: sub_messages = _parse_chat_message_content( @@ -1781,7 +1847,7 @@ def apply_hf_chat_template( def apply_mistral_chat_template( - tokenizer: MistralTokenizer, + tokenizer: "MistralTokenizer", messages: list[ChatCompletionMessageParam], chat_template: str | None, tools: list[dict[str, Any]] | None, diff --git a/vllm/entrypoints/cli/__init__.py b/vllm/entrypoints/cli/__init__.py index 9dff68236fe94..dc02ac563406a 100644 --- a/vllm/entrypoints/cli/__init__.py +++ b/vllm/entrypoints/cli/__init__.py @@ -2,12 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.entrypoints.cli.benchmark.latency import BenchmarkLatencySubcommand from vllm.entrypoints.cli.benchmark.serve import BenchmarkServingSubcommand +from vllm.entrypoints.cli.benchmark.startup import BenchmarkStartupSubcommand from vllm.entrypoints.cli.benchmark.sweep import BenchmarkSweepSubcommand from vllm.entrypoints.cli.benchmark.throughput import BenchmarkThroughputSubcommand __all__: list[str] = [ "BenchmarkLatencySubcommand", "BenchmarkServingSubcommand", + "BenchmarkStartupSubcommand", "BenchmarkSweepSubcommand", "BenchmarkThroughputSubcommand", ] diff --git a/vllm/entrypoints/cli/benchmark/startup.py b/vllm/entrypoints/cli/benchmark/startup.py new file mode 100644 index 0000000000000..81eefd7c174dc --- /dev/null +++ b/vllm/entrypoints/cli/benchmark/startup.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse + +from vllm.benchmarks.startup import add_cli_args, main +from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase + + +class BenchmarkStartupSubcommand(BenchmarkSubcommandBase): + """The `startup` subcommand for `vllm bench`.""" + + name = "startup" + help = "Benchmark the startup time of vLLM models." + + @classmethod + def add_cli_args(cls, parser: argparse.ArgumentParser) -> None: + add_cli_args(parser) + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + main(args) diff --git a/vllm/entrypoints/cli/openai.py b/vllm/entrypoints/cli/openai.py index fb49be370203e..1c18b193d1cdc 100644 --- a/vllm/entrypoints/cli/openai.py +++ b/vllm/entrypoints/cli/openai.py @@ -109,6 +109,10 @@ def _add_query_options(parser: FlexibleArgumentParser) -> FlexibleArgumentParser help=( "API key for OpenAI services. If provided, this api key " "will overwrite the api key obtained through environment variables." + " It is important to note that this option only applies to the " + "OpenAI-compatible API endpoints and NOT other endpoints that may " + "be present in the server. See the security guide in the vLLM docs " + "for more details." ), ) return parser diff --git a/vllm/entrypoints/constants.py b/vllm/entrypoints/constants.py index b5bcccc35d6c8..5726ee0735d4c 100644 --- a/vllm/entrypoints/constants.py +++ b/vllm/entrypoints/constants.py @@ -8,3 +8,5 @@ Shared constants for vLLM entrypoints. # These constants help mitigate header abuse attacks H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT = 4194304 # 4 MB H11_MAX_HEADER_COUNT_DEFAULT = 256 + +MCP_PREFIX = "mcp_" diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py index 7a41c668d7645..b076b883b4d93 100644 --- a/vllm/entrypoints/context.py +++ b/vllm/entrypoints/context.py @@ -2,24 +2,49 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio import contextlib +import copy import json import logging from abc import ABC, abstractmethod +from collections.abc import Callable from contextlib import AsyncExitStack +from dataclasses import replace from typing import TYPE_CHECKING, Union +from openai.types.responses.response_function_tool_call_output_item import ( + ResponseFunctionToolCallOutputItem, +) from openai.types.responses.tool import Mcp from openai_harmony import Author, Message, Role, StreamState, TextContent from vllm import envs -from vllm.entrypoints.harmony_utils import ( +from vllm.entrypoints.chat_utils import ( + ChatTemplateContentFormatOption, +) +from vllm.entrypoints.constants import MCP_PREFIX +from vllm.entrypoints.openai.parser.harmony_utils import ( get_encoding, get_streamable_parser_for_assistant, render_for_completion, ) +from vllm.entrypoints.openai.parser.responses_parser import ( + get_responses_parser_for_simple_context, +) +from vllm.entrypoints.openai.protocol import ( + FunctionCall, + ResponseInputOutputItem, + ResponseRawMessageAndToken, + ResponsesRequest, +) +from vllm.entrypoints.responses_utils import construct_tool_dicts from vllm.entrypoints.tool import Tool from vllm.entrypoints.tool_server import ToolServer from vllm.outputs import RequestOutput +from vllm.reasoning.abs_reasoning_parsers import ReasoningParser +from vllm.tokenizers.protocol import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ToolParser +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import random_uuid if TYPE_CHECKING: from mcp.client import ClientSession @@ -51,24 +76,24 @@ class TurnMetrics: def __init__( self, - input_tokens=0, - output_tokens=0, - cached_input_tokens=0, - tool_output_tokens=0, - ): + input_tokens: int = 0, + output_tokens: int = 0, + cached_input_tokens: int = 0, + tool_output_tokens: int = 0, + ) -> None: self.input_tokens = input_tokens self.output_tokens = output_tokens self.cached_input_tokens = cached_input_tokens self.tool_output_tokens = tool_output_tokens - def reset(self): + def reset(self) -> None: """Reset counters for a new turn.""" self.input_tokens = 0 self.output_tokens = 0 self.cached_input_tokens = 0 self.tool_output_tokens = 0 - def copy(self): + def copy(self) -> "TurnMetrics": """Create a copy of this turn's token counts.""" return TurnMetrics( self.input_tokens, @@ -137,8 +162,16 @@ def _create_json_parse_error_messages( class SimpleContext(ConversationContext): + """This is a context that cannot handle MCP tool calls""" + def __init__(self): self.last_output = None + + # Accumulated final output for streaming mode + self._accumulated_text: str = "" + self._accumulated_token_ids: list[int] = [] + self._accumulated_logprobs: list = [] + self.num_prompt_tokens = 0 self.num_output_tokens = 0 self.num_cached_tokens = 0 @@ -147,6 +180,9 @@ class SimpleContext(ConversationContext): # not implemented yet for SimpleContext self.all_turn_metrics = [] + self.input_messages: list[ResponseRawMessageAndToken] = [] + self.output_messages: list[ResponseRawMessageAndToken] = [] + def append_output(self, output) -> None: self.last_output = output if not isinstance(output, RequestOutput): @@ -155,6 +191,44 @@ class SimpleContext(ConversationContext): self.num_cached_tokens = output.num_cached_tokens or 0 self.num_output_tokens += len(output.outputs[0].token_ids or []) + # Accumulate text, token_ids, and logprobs for streaming mode + delta_output = output.outputs[0] + self._accumulated_text += delta_output.text + self._accumulated_token_ids.extend(delta_output.token_ids) + if delta_output.logprobs is not None: + self._accumulated_logprobs.extend(delta_output.logprobs) + + if len(self.input_messages) == 0: + output_prompt = output.prompt or "" + output_prompt_token_ids = output.prompt_token_ids or [] + self.input_messages.append( + ResponseRawMessageAndToken( + message=output_prompt, + tokens=output_prompt_token_ids, + ) + ) + self.output_messages.append( + ResponseRawMessageAndToken( + message=delta_output.text, + tokens=delta_output.token_ids, + ) + ) + + @property + def final_output(self) -> RequestOutput | None: + """Return the final output, with complete text/token_ids/logprobs.""" + if self.last_output is not None and self.last_output.outputs: + assert isinstance(self.last_output, RequestOutput) + final_output = copy.copy(self.last_output) + # copy inner item to avoid modify last_output + final_output.outputs = [replace(item) for item in self.last_output.outputs] + final_output.outputs[0].text = self._accumulated_text + final_output.outputs[0].token_ids = tuple(self._accumulated_token_ids) + if self._accumulated_logprobs: + final_output.outputs[0].logprobs = self._accumulated_logprobs + return final_output + return self.last_output + def append_tool_output(self, output) -> None: raise NotImplementedError("Should not be called.") @@ -180,6 +254,253 @@ class SimpleContext(ConversationContext): raise NotImplementedError("Should not be called.") +class ParsableContext(ConversationContext): + def __init__( + self, + *, + response_messages: list[ResponseInputOutputItem], + tokenizer: AnyTokenizer, + reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser] | None, + request: ResponsesRequest, + available_tools: list[str] | None, + tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None, + chat_template: str | None, + chat_template_content_format: ChatTemplateContentFormatOption, + ): + self.num_prompt_tokens = 0 + self.num_output_tokens = 0 + self.num_cached_tokens = 0 + # TODO: num_reasoning_tokens is not implemented yet. + self.num_reasoning_tokens = 0 + # not implemented yet for ParsableContext + self.all_turn_metrics: list[TurnMetrics] = [] + + if reasoning_parser_cls is None: + raise ValueError("reasoning_parser_cls must be provided.") + + self.parser = get_responses_parser_for_simple_context( + tokenizer=tokenizer, + reasoning_parser_cls=reasoning_parser_cls, + response_messages=response_messages, + request=request, + tool_parser_cls=tool_parser_cls, + ) + self.tool_parser_cls = tool_parser_cls + self.request = request + self.tokenizer = tokenizer + + self.available_tools = available_tools or [] + self._tool_sessions: dict[str, ClientSession | Tool] = {} + self.called_tools: set[str] = set() + + self.tool_dicts = construct_tool_dicts(request.tools, request.tool_choice) + self.chat_template = chat_template + self.chat_template_content_format = chat_template_content_format + + self.input_messages: list[ResponseRawMessageAndToken] = [] + self.output_messages: list[ResponseRawMessageAndToken] = [] + + def append_output(self, output: RequestOutput) -> None: + self.num_prompt_tokens = len(output.prompt_token_ids or []) + self.num_cached_tokens = output.num_cached_tokens or 0 + self.num_output_tokens += len(output.outputs[0].token_ids or []) + self.parser.process(output.outputs[0]) + + # only store if enable_response_messages is True, save memory + if self.request.enable_response_messages: + output_prompt = output.prompt or "" + output_prompt_token_ids = output.prompt_token_ids or [] + if len(self.input_messages) == 0: + self.input_messages.append( + ResponseRawMessageAndToken( + message=output_prompt, + tokens=output_prompt_token_ids, + ) + ) + else: + self.output_messages.append( + ResponseRawMessageAndToken( + message=output_prompt, + tokens=output_prompt_token_ids, + ) + ) + self.output_messages.append( + ResponseRawMessageAndToken( + message=output.outputs[0].text, + tokens=output.outputs[0].token_ids, + ) + ) + + def append_tool_output(self, output: list[ResponseInputOutputItem]) -> None: + self.parser.response_messages.extend(output) + + def need_builtin_tool_call(self) -> bool: + """Return true if the last message is a MCP tool call""" + last_message = self.parser.response_messages[-1] + # TODO(qandrew): figure out which tools are MCP tools + if last_message.type == "function_call": # noqa: SIM102 + if last_message.name in ( + "code_interpreter", + "python", + "web_search_preview", + ) or last_message.name.startswith("container"): + return True + + return False + + async def call_python_tool( + self, tool_session: Union["ClientSession", Tool], last_msg: FunctionCall + ) -> list[ResponseInputOutputItem]: + self.called_tools.add("python") + if isinstance(tool_session, Tool): + return await tool_session.get_result_parsable_context(self) + args = json.loads(last_msg.arguments) + param = { + "code": args["code"], + } + result = await tool_session.call_tool("python", param) + result_str = result.content[0].text + + message = ResponseFunctionToolCallOutputItem( + id=f"mcpo_{random_uuid()}", + type="function_call_output", + call_id=f"call_{random_uuid()}", + output=result_str, + status="completed", + ) + + return [message] + + async def call_search_tool( + self, tool_session: Union["ClientSession", Tool], last_msg: FunctionCall + ) -> list[ResponseInputOutputItem]: + self.called_tools.add("browser") + if isinstance(tool_session, Tool): + return await tool_session.get_result_parsable_context(self) + if envs.VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY: + try: + args = json.loads(last_msg.arguments) + except json.JSONDecodeError as e: + return _create_json_parse_error_messages(last_msg, e) + else: + args = json.loads(last_msg.arguments) + result = await tool_session.call_tool("search", args) + result_str = result.content[0].text + + message = ResponseFunctionToolCallOutputItem( + id=f"fco_{random_uuid()}", + type="function_call_output", + call_id=f"call_{random_uuid()}", + output=result_str, + status="completed", + ) + + return [message] + + async def call_container_tool( + self, tool_session: Union["ClientSession", Tool], last_msg: Message + ) -> list[Message]: + """ + Call container tool. Expect this to be run in a stateful docker + with command line terminal. + The official container tool would at least + expect the following format: + - for tool name: exec + - args: + { + "cmd":List[str] "command to execute", + "workdir":optional[str] "current working directory", + "env":optional[object/dict] "environment variables", + "session_name":optional[str] "session name", + "timeout":optional[int] "timeout in seconds", + "user":optional[str] "user name", + } + """ + self.called_tools.add("container") + if isinstance(tool_session, Tool): + return await tool_session.get_result_parsable_context(self) + # tool_name = last_msg.recipient.split(".")[1].split(" ")[0] + if envs.VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY: + try: + args = json.loads(last_msg.arguments) + except json.JSONDecodeError as e: + return _create_json_parse_error_messages(last_msg, e) + else: + args = json.loads(last_msg.arguments) + result = await tool_session.call_tool("exec", args) + result_str = result.content[0].text + + message = ResponseFunctionToolCallOutputItem( + id=f"fco_{random_uuid()}", + type="function_call_output", + call_id=f"call_{random_uuid()}", + output=result_str, + status="completed", + ) + + return [message] + + async def call_tool(self) -> list[ResponseInputOutputItem]: + if not self.parser.response_messages: + return [] + last_msg = self.parser.response_messages[-1] + # change this to a mcp_ function call + last_msg.id = f"{MCP_PREFIX}{random_uuid()}" + self.parser.response_messages[-1] = last_msg + if last_msg.name == "code_interpreter": + return await self.call_python_tool(self._tool_sessions["python"], last_msg) + elif last_msg.name == "web_search_preview": + return await self.call_search_tool(self._tool_sessions["browser"], last_msg) + elif last_msg.name.startswith("container"): + return await self.call_container_tool( + self._tool_sessions["container"], last_msg + ) + return [] + + def render_for_completion(self): + raise NotImplementedError("Should not be called.") + + async def init_tool_sessions( + self, + tool_server: ToolServer | None, + exit_stack: AsyncExitStack, + request_id: str, + mcp_tools: dict[str, Mcp], + ): + if tool_server: + for tool_name in self.available_tools: + if tool_name in self._tool_sessions: + continue + + tool_type = _map_tool_name_to_tool_type(tool_name) + headers = ( + mcp_tools[tool_type].headers if tool_type in mcp_tools else None + ) + tool_session = await exit_stack.enter_async_context( + tool_server.new_session(tool_name, request_id, headers) + ) + self._tool_sessions[tool_name] = tool_session + exit_stack.push_async_exit(self.cleanup_session) + + async def cleanup_session(self, *args, **kwargs) -> None: + """Can be used as coro to used in __aexit__""" + + async def cleanup_tool_session(tool_session): + if not isinstance(tool_session, Tool): + logger.info( + "Cleaning up tool session for %s", tool_session._client_info + ) + with contextlib.suppress(Exception): + await tool_session.call_tool("cleanup_session", {}) + + await asyncio.gather( + *( + cleanup_tool_session(self._tool_sessions[tool]) + for tool in self.called_tools + ) + ) + + class HarmonyContext(ConversationContext): def __init__( self, diff --git a/vllm/entrypoints/harmony_utils.py b/vllm/entrypoints/harmony_utils.py deleted file mode 100644 index 47a252348c102..0000000000000 --- a/vllm/entrypoints/harmony_utils.py +++ /dev/null @@ -1,535 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import datetime -import json -from collections.abc import Iterable, Sequence -from typing import Literal - -from openai.types.responses import ( - ResponseFunctionToolCall, - ResponseOutputItem, - ResponseOutputMessage, - ResponseOutputText, - ResponseReasoningItem, -) -from openai.types.responses.response_function_web_search import ( - ActionFind, - ActionOpenPage, - ActionSearch, - ResponseFunctionWebSearch, -) -from openai.types.responses.response_reasoning_item import ( - Content as ResponseReasoningTextContent, -) -from openai.types.responses.tool import Tool -from openai_harmony import ( - Author, - ChannelConfig, - Conversation, - DeveloperContent, - HarmonyEncodingName, - Message, - ReasoningEffort, - Role, - StreamableParser, - SystemContent, - TextContent, - ToolDescription, - load_harmony_encoding, -) -from openai_harmony import Message as OpenAIHarmonyMessage -from openai_harmony import Role as OpenAIHarmonyRole - -from vllm import envs -from vllm.entrypoints.openai.protocol import ( - ChatCompletionToolsParam, - ResponseInputOutputItem, - ResponsesRequest, -) -from vllm.utils import random_uuid - -REASONING_EFFORT = { - "high": ReasoningEffort.HIGH, - "medium": ReasoningEffort.MEDIUM, - "low": ReasoningEffort.LOW, -} - -_harmony_encoding = None - -# Builtin tools that should be included in the system message when -# they are available and requested by the user. -# Tool args are provided by MCP tool descriptions. Output -# of the tools are stringified. -MCP_BUILTIN_TOOLS: set[str] = { - "web_search_preview", - "code_interpreter", - "container", -} - - -def has_custom_tools(tool_types: set[str]) -> bool: - """ - Checks if the given tool types are custom tools - (i.e. any tool other than MCP buildin tools) - """ - return not tool_types.issubset(MCP_BUILTIN_TOOLS) - - -def get_encoding(): - global _harmony_encoding - if _harmony_encoding is None: - _harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) - return _harmony_encoding - - -def get_system_message( - model_identity: str | None = None, - reasoning_effort: Literal["high", "medium", "low"] | None = None, - start_date: str | None = None, - browser_description: str | None = None, - python_description: str | None = None, - container_description: str | None = None, - instructions: str | None = None, - with_custom_tools: bool = False, -) -> Message: - sys_msg_content = SystemContent.new() - if model_identity is not None: - sys_msg_content = sys_msg_content.with_model_identity(model_identity) - if instructions is not None and envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: - current_identity = sys_msg_content.model_identity - new_identity = ( - f"{current_identity}\n{instructions}" if current_identity else instructions - ) - sys_msg_content = sys_msg_content.with_model_identity(new_identity) - if reasoning_effort is not None: - sys_msg_content = sys_msg_content.with_reasoning_effort( - REASONING_EFFORT[reasoning_effort] - ) - if start_date is None: - # NOTE(woosuk): This brings non-determinism in vLLM. Be careful. - start_date = datetime.datetime.now().strftime("%Y-%m-%d") - sys_msg_content = sys_msg_content.with_conversation_start_date(start_date) - if browser_description is not None: - sys_msg_content = sys_msg_content.with_tools(browser_description) - if python_description is not None: - sys_msg_content = sys_msg_content.with_tools(python_description) - if container_description is not None: - sys_msg_content = sys_msg_content.with_tools(container_description) - if not with_custom_tools: - channel_config = sys_msg_content.channel_config - invalid_channel = "commentary" - new_config = ChannelConfig.require_channels( - [c for c in channel_config.valid_channels if c != invalid_channel] - ) - sys_msg_content = sys_msg_content.with_channel_config(new_config) - sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content) - return sys_msg - - -def create_tool_definition(tool: ChatCompletionToolsParam | Tool): - if isinstance(tool, ChatCompletionToolsParam): - return ToolDescription.new( - name=tool.function.name, - description=tool.function.description, - parameters=tool.function.parameters, - ) - return ToolDescription.new( - name=tool.name, - description=tool.description, - parameters=tool.parameters, - ) - - -def get_developer_message( - instructions: str | None = None, - tools: list[Tool | ChatCompletionToolsParam] | None = None, -) -> Message: - dev_msg_content = DeveloperContent.new() - if instructions is not None and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: - dev_msg_content = dev_msg_content.with_instructions(instructions) - if tools is not None: - function_tools: list[Tool | ChatCompletionToolsParam] = [] - for tool in tools: - if tool.type in ( - "web_search_preview", - "code_interpreter", - "container", - "mcp", - ): - # These are built-in tools that are added to the system message. - # Adding in MCP for now until we support MCP tools executed - # server side - pass - - elif tool.type == "function": - function_tools.append(tool) - else: - raise ValueError(f"tool type {tool.type} not supported") - if function_tools: - function_tool_descriptions = [ - create_tool_definition(tool) for tool in function_tools - ] - dev_msg_content = dev_msg_content.with_function_tools( - function_tool_descriptions - ) - dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content) - return dev_msg - - -def get_user_message(content: str) -> Message: - return Message.from_role_and_content(Role.USER, content) - - -def parse_response_input( - response_msg: ResponseInputOutputItem, - prev_responses: list[ResponseOutputItem | ResponseReasoningItem], -) -> Message: - if not isinstance(response_msg, dict): - response_msg = response_msg.model_dump() - if "type" not in response_msg or response_msg["type"] == "message": - role = response_msg["role"] - content = response_msg["content"] - if role == "system": - # User is trying to set a system message. Change it to: - # <|start|>developer<|message|># Instructions - # {instructions}<|end|> - role = "developer" - text_prefix = "Instructions:\n" - else: - text_prefix = "" - if isinstance(content, str): - msg = Message.from_role_and_content(role, text_prefix + content) - else: - contents = [TextContent(text=text_prefix + c["text"]) for c in content] - msg = Message.from_role_and_contents(role, contents) - if role == "assistant": - msg = msg.with_channel("final") - elif response_msg["type"] == "function_call_output": - call_id = response_msg["call_id"] - call_response: ResponseFunctionToolCall | None = None - for prev_response in reversed(prev_responses): - if ( - isinstance(prev_response, ResponseFunctionToolCall) - and prev_response.call_id == call_id - ): - call_response = prev_response - break - if call_response is None: - raise ValueError(f"No call message found for {call_id}") - msg = Message.from_author_and_content( - Author.new(Role.TOOL, f"functions.{call_response.name}"), - response_msg["output"], - ) - elif response_msg["type"] == "reasoning": - content = response_msg["content"] - assert len(content) == 1 - msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"]) - elif response_msg["type"] == "function_call": - msg = Message.from_role_and_content(Role.ASSISTANT, response_msg["arguments"]) - msg = msg.with_channel("commentary") - msg = msg.with_recipient(f"functions.{response_msg['name']}") - msg = msg.with_content_type("json") - else: - raise ValueError(f"Unknown input type: {response_msg['type']}") - return msg - - -def parse_input_to_harmony_message(chat_msg) -> list[Message]: - if not isinstance(chat_msg, dict): - # Handle Pydantic models - chat_msg = chat_msg.model_dump(exclude_none=True) - - role = chat_msg.get("role") - - # Assistant message with tool calls - tool_calls = chat_msg.get("tool_calls") - if role == "assistant" and tool_calls: - msgs: list[Message] = [] - for call in tool_calls: - func = call.get("function", {}) - name = func.get("name", "") - arguments = func.get("arguments", "") or "" - msg = Message.from_role_and_content(Role.ASSISTANT, arguments) - msg = msg.with_channel("commentary") - msg = msg.with_recipient(f"functions.{name}") - msg = msg.with_content_type("json") - msgs.append(msg) - return msgs - - # Tool role message (tool output) - if role == "tool": - name = chat_msg.get("name", "") - content = chat_msg.get("content", "") or "" - if isinstance(content, list): - # Handle array format for tool message content - # by concatenating all text parts. - content = "".join( - item.get("text", "") - for item in content - if isinstance(item, dict) and item.get("type") == "text" - ) - - msg = Message.from_author_and_content( - Author.new(Role.TOOL, f"functions.{name}"), content - ).with_channel("commentary") - return [msg] - - # Default: user/assistant/system messages with content - content = chat_msg.get("content", "") - if isinstance(content, str): - contents = [TextContent(text=content)] - else: - # TODO: Support refusal. - contents = [TextContent(text=c.get("text", "")) for c in content] - msg = Message.from_role_and_contents(role, contents) - return [msg] - - -def construct_harmony_previous_input_messages( - request: ResponsesRequest, -) -> list[OpenAIHarmonyMessage]: - messages: list[OpenAIHarmonyMessage] = [] - if request.previous_input_messages: - for message in request.previous_input_messages: - # Handle both OpenAIHarmonyMessage objects and dictionary inputs - if isinstance(message, OpenAIHarmonyMessage): - message_role = message.author.role - # To match OpenAI, instructions, reasoning and tools are - # always taken from the most recent Responses API request - # not carried over from previous requests - if ( - message_role == OpenAIHarmonyRole.SYSTEM - or message_role == OpenAIHarmonyRole.DEVELOPER - ): - continue - messages.append(message) - else: - harmony_messages = parse_input_to_harmony_message(message) - for harmony_msg in harmony_messages: - message_role = harmony_msg.author.role - # To match OpenAI, instructions, reasoning and tools are - # always taken from the most recent Responses API request - # not carried over from previous requests - if ( - message_role == OpenAIHarmonyRole.SYSTEM - or message_role == OpenAIHarmonyRole.DEVELOPER - ): - continue - messages.append(harmony_msg) - return messages - - -def render_for_completion(messages: list[Message]) -> list[int]: - conversation = Conversation.from_messages(messages) - token_ids = get_encoding().render_conversation_for_completion( - conversation, Role.ASSISTANT - ) - return token_ids - - -def parse_output_message(message: Message) -> list[ResponseOutputItem]: - """ - Parse a Harmony message into a list of output response items. - """ - if message.author.role != "assistant": - # This is a message from a tool to the assistant (e.g., search result). - # Don't include it in the final output for now. This aligns with - # OpenAI's behavior on models like o4-mini. - return [] - - output_items: list[ResponseOutputItem] = [] - recipient = message.recipient - if recipient is not None and recipient.startswith("browser."): - if len(message.content) != 1: - raise ValueError("Invalid number of contents in browser message") - content = message.content[0] - # We do not need to check the VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY - # env variable since if it is not set, we are certain the json is valid - # The use of Actions for web search will be removed entirely in - # the future, so this is only necessary temporarily - try: - browser_call = json.loads(content.text) - except json.JSONDecodeError: - # If the content is not valid JSON, then it was - # caught and retried by vLLM, which means we - # need to make note of that so the user is aware - json_retry_output_message = ( - f"Invalid JSON args, caught and retried: {content.text}" - ) - browser_call = { - "query": json_retry_output_message, - "url": json_retry_output_message, - "pattern": json_retry_output_message, - } - # TODO: translate to url properly! - if recipient == "browser.search": - action = ActionSearch( - query=f"cursor:{browser_call.get('query', '')}", type="search" - ) - elif recipient == "browser.open": - action = ActionOpenPage( - url=f"cursor:{browser_call.get('url', '')}", type="open_page" - ) - elif recipient == "browser.find": - action = ActionFind( - pattern=browser_call["pattern"], - url=f"cursor:{browser_call.get('url', '')}", - type="find", - ) - else: - raise ValueError(f"Unknown browser action: {recipient}") - web_search_item = ResponseFunctionWebSearch( - id=f"ws_{random_uuid()}", - action=action, - status="completed", - type="web_search_call", - ) - output_items.append(web_search_item) - elif message.channel == "analysis": - for content in message.content: - reasoning_item = ResponseReasoningItem( - id=f"rs_{random_uuid()}", - summary=[], - type="reasoning", - content=[ - ResponseReasoningTextContent( - text=content.text, type="reasoning_text" - ) - ], - status=None, - ) - output_items.append(reasoning_item) - elif message.channel == "commentary": - if recipient is not None and recipient.startswith("functions."): - function_name = recipient.split(".")[-1] - for content in message.content: - random_id = random_uuid() - response_item = ResponseFunctionToolCall( - arguments=content.text, - call_id=f"call_{random_id}", - type="function_call", - name=function_name, - id=f"fc_{random_id}", - ) - output_items.append(response_item) - elif recipient is not None and ( - recipient.startswith("python") - or recipient.startswith("browser") - or recipient.startswith("container") - ): - for content in message.content: - reasoning_item = ResponseReasoningItem( - id=f"rs_{random_uuid()}", - summary=[], - type="reasoning", - content=[ - ResponseReasoningTextContent( - text=content.text, type="reasoning_text" - ) - ], - status=None, - ) - output_items.append(reasoning_item) - else: - raise ValueError(f"Unknown recipient: {recipient}") - elif message.channel == "final": - contents = [] - for content in message.content: - output_text = ResponseOutputText( - text=content.text, - annotations=[], # TODO - type="output_text", - logprobs=None, # TODO - ) - contents.append(output_text) - text_item = ResponseOutputMessage( - id=f"msg_{random_uuid()}", - content=contents, - role=message.author.role, - status="completed", - type="message", - ) - output_items.append(text_item) - else: - raise ValueError(f"Unknown channel: {message.channel}") - return output_items - - -def parse_remaining_state(parser: StreamableParser) -> list[ResponseOutputItem]: - if not parser.current_content: - return [] - if parser.current_role != Role.ASSISTANT: - return [] - current_recipient = parser.current_recipient - if current_recipient is not None and current_recipient.startswith("browser."): - return [] - - if parser.current_channel == "analysis": - reasoning_item = ResponseReasoningItem( - id=f"rs_{random_uuid()}", - summary=[], - type="reasoning", - content=[ - ResponseReasoningTextContent( - text=parser.current_content, type="reasoning_text" - ) - ], - status=None, - ) - return [reasoning_item] - elif parser.current_channel == "final": - output_text = ResponseOutputText( - text=parser.current_content, - annotations=[], # TODO - type="output_text", - logprobs=None, # TODO - ) - text_item = ResponseOutputMessage( - id=f"msg_{random_uuid()}", - content=[output_text], - role="assistant", - # if the parser still has messages (ie if the generator got cut - # abruptly), this should be incomplete - status="incomplete", - type="message", - ) - return [text_item] - return [] - - -def get_stop_tokens_for_assistant_actions() -> list[int]: - return get_encoding().stop_tokens_for_assistant_actions() - - -def get_streamable_parser_for_assistant() -> StreamableParser: - return StreamableParser(get_encoding(), role=Role.ASSISTANT) - - -def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser: - parser = get_streamable_parser_for_assistant() - for token_id in token_ids: - parser.process(token_id) - return parser - - -def parse_chat_output( - token_ids: Sequence[int], -) -> tuple[str | None, str | None, bool]: - parser = parse_output_into_messages(token_ids) - output_msgs = parser.messages - is_tool_call = False # TODO: update this when tool call is supported - if len(output_msgs) == 0: - # The generation has stopped during reasoning. - reasoning = parser.current_content - final_content = None - elif len(output_msgs) == 1: - # The generation has stopped during final message. - reasoning = output_msgs[0].content[0].text - final_content = parser.current_content - else: - reasoning_msg = output_msgs[:-1] - final_msg = output_msgs[-1] - reasoning = "\n".join([msg.content[0].text for msg in reasoning_msg]) - final_content = final_msg.content[0].text - return reasoning, final_content, is_tool_call diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index f005605c08d7e..2768e267f4837 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -9,7 +9,7 @@ import cloudpickle import torch.nn as nn from pydantic import ValidationError from tqdm.auto import tqdm -from typing_extensions import TypeVar, deprecated +from typing_extensions import TypeVar from vllm.beam_search import ( BeamSearchInstance, @@ -18,8 +18,10 @@ from vllm.beam_search import ( create_sort_beams_key_function, ) from vllm.config import ( + AttentionConfig, CompilationConfig, PoolerConfig, + ProfilerConfig, StructuredOutputsConfig, is_init_field, ) @@ -71,8 +73,8 @@ from vllm.platforms import current_platform from vllm.pooling_params import PoolingParams from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams from vllm.tasks import PoolingTask -from vllm.tokenizers import MistralTokenizer, TokenizerLike -from vllm.tokenizers.hf import get_cached_tokenizer +from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.mistral import MistralTokenizer from vllm.usage.usage_lib import UsageContext from vllm.utils.collection_utils import as_iter, is_list_of from vllm.utils.counter import Counter @@ -174,6 +176,10 @@ class LLM: compilation_config: Either an integer or a dictionary. If it is an integer, it is used as the mode of compilation optimization. If it is a dictionary, it can specify the full compilation configuration. + attention_config: Configuration for attention mechanisms. Can be a + dictionary or an AttentionConfig instance. If a dictionary, it will + be converted to an AttentionConfig. Allows specifying the attention + backend and other attention-related settings. **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs]. Note: @@ -198,7 +204,7 @@ class LLM: quantization: QuantizationMethods | None = None, revision: str | None = None, tokenizer_revision: str | None = None, - seed: int | None = None, + seed: int = 0, gpu_memory_utilization: float = 0.9, swap_space: float = 4, cpu_offload_gb: float = 0, @@ -211,6 +217,8 @@ class LLM: structured_outputs_config: dict[str, Any] | StructuredOutputsConfig | None = None, + profiler_config: dict[str, Any] | ProfilerConfig | None = None, + attention_config: dict[str, Any] | AttentionConfig | None = None, kv_cache_memory_bytes: int | None = None, compilation_config: int | dict[str, Any] | CompilationConfig | None = None, logits_processors: list[str | type[LogitsProcessor]] | None = None, @@ -250,37 +258,28 @@ class LLM: if hf_overrides is None: hf_overrides = {} - if compilation_config is not None: - if isinstance(compilation_config, int): - compilation_config_instance = CompilationConfig( - mode=CompilationMode(compilation_config) - ) - elif isinstance(compilation_config, dict): - compilation_config_instance = CompilationConfig( - **{ - k: v - for k, v in compilation_config.items() - if is_init_field(CompilationConfig, k) - } - ) - else: - compilation_config_instance = compilation_config - else: - compilation_config_instance = CompilationConfig() + def _make_config(value: Any, cls: type[_R]) -> _R: + """Convert dict/None/instance to a config instance.""" + if value is None: + return cls() + if isinstance(value, dict): + return cls(**{k: v for k, v in value.items() if is_init_field(cls, k)}) # type: ignore[arg-type] + return value - if structured_outputs_config is not None: - if isinstance(structured_outputs_config, dict): - structured_outputs_instance = StructuredOutputsConfig( - **{ - k: v - for k, v in structured_outputs_config.items() - if is_init_field(StructuredOutputsConfig, k) - } - ) - else: - structured_outputs_instance = structured_outputs_config + if isinstance(compilation_config, int): + compilation_config_instance = CompilationConfig( + mode=CompilationMode(compilation_config) + ) else: - structured_outputs_instance = StructuredOutputsConfig() + compilation_config_instance = _make_config( + compilation_config, CompilationConfig + ) + + structured_outputs_instance = _make_config( + structured_outputs_config, StructuredOutputsConfig + ) + profiler_config_instance = _make_config(profiler_config, ProfilerConfig) + attention_config_instance = _make_config(attention_config, AttentionConfig) # warn about single-process data parallel usage. _dp_size = int(kwargs.get("data_parallel_size", 1)) @@ -324,6 +323,8 @@ class LLM: mm_processor_kwargs=mm_processor_kwargs, pooler_config=pooler_config, structured_outputs_config=structured_outputs_instance, + profiler_config=profiler_config_instance, + attention_config=attention_config_instance, compilation_config=compilation_config_instance, logits_processors=logits_processors, **kwargs, @@ -350,16 +351,6 @@ class LLM: def get_tokenizer(self) -> TokenizerLike: return self.llm_engine.get_tokenizer() - @deprecated("`set_tokenizer` is deprecated and will be removed in v0.13.") - def set_tokenizer(self, tokenizer: TokenizerLike) -> None: - # While CachedTokenizer is dynamic, have no choice but - # compare class name. Misjudgment will arise from - # user-defined tokenizer started with 'Cached' - if tokenizer.__class__.__name__.startswith("Cached"): - self.llm_engine.tokenizer = tokenizer - else: - self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer) - def reset_mm_cache(self) -> None: self.input_processor.clear_mm_cache() self.llm_engine.reset_mm_cache() @@ -834,7 +825,6 @@ class LLM: conversation, mm_data, mm_uuids = parse_chat_messages( msgs, model_config, - tokenizer, content_format=resolved_content_format, ) @@ -1077,6 +1067,7 @@ class LLM: params=pooling_params, use_tqdm=use_tqdm, lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, ) outputs = self._run_engine(use_tqdm=use_tqdm) @@ -1114,6 +1105,7 @@ class LLM: use_tqdm: bool | Callable[..., tqdm] = True, pooling_params: PoolingParams | Sequence[PoolingParams] | None = None, lora_request: list[LoRARequest] | LoRARequest | None = None, + tokenization_kwargs: dict[str, Any] | None = None, ) -> list[EmbeddingRequestOutput]: """ Generate an embedding vector for each prompt. @@ -1151,6 +1143,7 @@ class LLM: pooling_params=pooling_params, lora_request=lora_request, pooling_task="embed", + tokenization_kwargs=tokenization_kwargs, ) return [EmbeddingRequestOutput.from_base(item) for item in items] @@ -1162,6 +1155,7 @@ class LLM: use_tqdm: bool | Callable[..., tqdm] = True, pooling_params: PoolingParams | Sequence[PoolingParams] | None = None, lora_request: list[LoRARequest] | LoRARequest | None = None, + tokenization_kwargs: dict[str, Any] | None = None, ) -> list[ClassificationRequestOutput]: """ Generate class logits for each prompt. @@ -1197,6 +1191,7 @@ class LLM: pooling_params=pooling_params, lora_request=lora_request, pooling_task="classify", + tokenization_kwargs=tokenization_kwargs, ) return [ClassificationRequestOutput.from_base(item) for item in items] @@ -1210,6 +1205,7 @@ class LLM: use_tqdm: bool | Callable[..., tqdm] = True, pooling_params: PoolingParams | Sequence[PoolingParams] | None = None, lora_request: list[LoRARequest] | LoRARequest | None = None, + tokenization_kwargs: dict[str, Any] | None = None, ) -> list[PoolingRequestOutput]: """ Generate rewards for each prompt. @@ -1237,6 +1233,7 @@ class LLM: pooling_params=pooling_params, truncate_prompt_tokens=truncate_prompt_tokens, pooling_task="token_classify", + tokenization_kwargs=tokenization_kwargs, ) def _embedding_score( @@ -1248,6 +1245,7 @@ class LLM: use_tqdm: bool | Callable[..., tqdm] = True, pooling_params: PoolingParams | None = None, lora_request: list[LoRARequest] | LoRARequest | None = None, + tokenization_kwargs: dict[str, Any] | None = None, ) -> list[ScoringRequestOutput]: encoded_output: list[PoolingRequestOutput] = self.encode( text_1 + text_2, @@ -1256,6 +1254,7 @@ class LLM: lora_request=lora_request, pooling_params=pooling_params, pooling_task="embed", + tokenization_kwargs=tokenization_kwargs, ) encoded_output_1: list[PoolingRequestOutput] = encoded_output[0 : len(text_1)] @@ -1280,6 +1279,7 @@ class LLM: use_tqdm: bool | Callable[..., tqdm] = True, pooling_params: PoolingParams | None = None, lora_request: list[LoRARequest] | LoRARequest | None = None, + tokenization_kwargs: dict[str, Any] | None = None, ) -> list[ScoringRequestOutput]: model_config = self.model_config @@ -1295,7 +1295,8 @@ class LLM: pooling_params.verify("score", model_config) pooling_params_list = list[PoolingParams]() - tokenization_kwargs: dict[str, Any] = {} + local_kwargs = tokenization_kwargs or {} + tokenization_kwargs = local_kwargs.copy() _validate_truncation_size( model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs @@ -1492,8 +1493,12 @@ class LLM: def stop_profile(self) -> None: self.llm_engine.stop_profile() - def reset_prefix_cache(self) -> None: - self.llm_engine.reset_prefix_cache() + def reset_prefix_cache( + self, reset_running_requests: bool = False, reset_connector: bool = False + ) -> bool: + return self.llm_engine.reset_prefix_cache( + reset_running_requests, reset_connector + ) def sleep(self, level: int = 1): """ @@ -1554,6 +1559,7 @@ class LLM: use_tqdm: bool | Callable[..., tqdm] = True, lora_request: Sequence[LoRARequest] | LoRARequest | None, priority: list[int] | None = None, + tokenization_kwargs: dict[str, Any] | None = None, ) -> None: if isinstance(prompts, (str, dict)): # Convert a single prompt to a list. @@ -1599,6 +1605,7 @@ class LLM: if isinstance(lora_request, Sequence) else lora_request, priority=priority[i] if priority else 0, + tokenization_kwargs=tokenization_kwargs, ) added_request_ids.append(request_id) except Exception as e: @@ -1662,9 +1669,12 @@ class LLM: *, lora_request: LoRARequest | None, priority: int, + tokenization_kwargs: dict[str, Any] | None = None, ) -> tuple[EngineCoreRequest, dict[str, Any]]: """Use the Processor to process inputs for LLMEngine.""" - tokenization_kwargs: dict[str, Any] = {} + + local_kwargs = tokenization_kwargs or {} + tokenization_kwargs = local_kwargs.copy() _validate_truncation_size( self.model_config.max_model_len, params.truncate_prompt_tokens, @@ -1687,6 +1697,7 @@ class LLM: params: SamplingParams | PoolingParams, lora_request: LoRARequest | None = None, priority: int = 0, + tokenization_kwargs: dict[str, Any] | None = None, ) -> str: prompt_text, _, _ = get_prompt_components(prompt) request_id = str(next(self.request_counter)) @@ -1697,6 +1708,7 @@ class LLM: params, lora_request=lora_request, priority=priority, + tokenization_kwargs=tokenization_kwargs, ) self.llm_engine.add_request( diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 6a648822d9b2b..5d0eacae34dd7 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -20,21 +20,15 @@ from http import HTTPStatus from typing import Annotated, Any, Literal import model_hosting_container_standards.sagemaker as sagemaker_standards -import prometheus_client import pydantic -import regex as re import uvloop from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Query, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse -from prometheus_client import make_asgi_app -from prometheus_fastapi_instrumentator import Instrumentator from starlette.concurrency import iterate_in_threadpool from starlette.datastructures import URL, Headers, MutableHeaders, State -from starlette.routing import Mount from starlette.types import ASGIApp, Message, Receive, Scope, Send -from typing_extensions import assert_never import vllm.envs as envs from vllm.config import VllmConfig @@ -56,21 +50,15 @@ from vllm.entrypoints.openai.protocol import ( ChatCompletionResponse, CompletionRequest, CompletionResponse, - DetokenizeRequest, - DetokenizeResponse, ErrorInfo, ErrorResponse, - GenerateRequest, - GenerateResponse, ResponsesRequest, ResponsesResponse, StreamingResponsesResponse, - TokenizeRequest, - TokenizeResponse, TranscriptionRequest, - TranscriptionResponse, + TranscriptionResponseVariant, TranslationRequest, - TranslationResponse, + TranslationResponseVariant, ) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion @@ -80,18 +68,20 @@ from vllm.entrypoints.openai.serving_models import ( OpenAIServingModels, ) from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses -from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization -from vllm.entrypoints.openai.serving_tokens import ServingTokens from vllm.entrypoints.openai.serving_transcription import ( OpenAIServingTranscription, OpenAIServingTranslation, ) -from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.entrypoints.openai.utils import validate_json_request from vllm.entrypoints.pooling.classify.serving import ServingClassification from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling from vllm.entrypoints.pooling.score.serving import ServingScores +from vllm.entrypoints.serve.disagg.serving import ServingTokens +from vllm.entrypoints.serve.elastic_ep.middleware import ( + ScalingMiddleware, +) +from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer from vllm.entrypoints.utils import ( cli_env_setup, @@ -104,13 +94,12 @@ from vllm.entrypoints.utils import ( from vllm.logger import init_logger from vllm.reasoning import ReasoningParserManager from vllm.tasks import POOLING_TASKS +from vllm.tool_parsers import ToolParserManager from vllm.usage.usage_lib import UsageContext from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.gc_utils import freeze_gc_heap from vllm.utils.network_utils import is_valid_ipv6_address from vllm.utils.system_utils import decorate_logs, set_ulimit -from vllm.v1.engine.exceptions import EngineDeadError -from vllm.v1.metrics.prometheus import get_prometheus_registry from vllm.version import __version__ as VLLM_VERSION prometheus_multiproc_dir: tempfile.TemporaryDirectory @@ -245,39 +234,6 @@ async def build_async_engine_client_from_engine_args( router = APIRouter() -class PrometheusResponse(Response): - media_type = prometheus_client.CONTENT_TYPE_LATEST - - -def mount_metrics(app: FastAPI): - """Mount prometheus metrics to a FastAPI app.""" - - registry = get_prometheus_registry() - - # `response_class=PrometheusResponse` is needed to return an HTTP response - # with header "Content-Type: text/plain; version=0.0.4; charset=utf-8" - # instead of the default "application/json" which is incorrect. - # See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364 - Instrumentator( - excluded_handlers=[ - "/metrics", - "/health", - "/load", - "/ping", - "/version", - "/server_info", - ], - registry=registry, - ).add().instrument(app).expose(app, response_class=PrometheusResponse) - - # Add prometheus asgi middleware to route /metrics requests - metrics_route = Mount("/metrics", make_asgi_app(registry=registry)) - - # Workaround for 307 Redirect for /metrics - metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$") - app.routes.append(metrics_route) - - def base(request: Request) -> OpenAIServing: # Reuse the existing instance return tokenization(request) @@ -323,16 +279,6 @@ def generate_tokens(request: Request) -> ServingTokens | None: return request.app.state.serving_tokens -@router.get("/health", response_class=Response) -async def health(raw_request: Request) -> Response: - """Health check.""" - try: - await engine_client(raw_request).check_health() - return Response(status_code=200) - except EngineDeadError: - return Response(status_code=503) - - @router.get("/load") async def get_server_load_metrics(request: Request): # This endpoint returns the current server load metrics. @@ -352,167 +298,6 @@ async def get_server_load_metrics(request: Request): return JSONResponse(content={"server_load": request.app.state.server_load_metrics}) -@router.post("/pause") -async def pause_generation( - raw_request: Request, - wait_for_inflight_requests: bool = Query(False), - clear_cache: bool = Query(True), -) -> JSONResponse: - """Pause generation requests to allow weight updates. - - Args: - wait_for_inflight_requests: When ``True`` waits for in-flight - requests to finish before pausing. When ``False`` (default), - aborts any in-flight requests immediately. - clear_cache: Whether to clear KV/prefix caches after draining. - """ - - engine = engine_client(raw_request) - - try: - await engine.pause_generation( - wait_for_inflight_requests=wait_for_inflight_requests, - clear_cache=clear_cache, - ) - return JSONResponse( - content={"status": "paused"}, - status_code=HTTPStatus.OK.value, - ) - - except ValueError as err: - return JSONResponse( - content={"error": str(err)}, - status_code=HTTPStatus.BAD_REQUEST.value, - ) - except Exception as err: # pragma: no cover - defensive - logger.exception("Failed to pause generation") - return JSONResponse( - content={"error": f"Failed to pause generation: {err}"}, - status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - ) - - -@router.post("/resume") -async def resume_generation(raw_request: Request) -> JSONResponse: - """Resume generation after a pause.""" - - engine = engine_client(raw_request) - - try: - await engine.resume_generation() - return JSONResponse( - content={"status": "resumed"}, - status_code=HTTPStatus.OK.value, - ) - except Exception as err: # pragma: no cover - defensive - logger.exception("Failed to resume generation") - return JSONResponse( - content={"error": f"Failed to resume generation: {err}"}, - status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - ) - - -@router.get("/is_paused") -async def is_paused(raw_request: Request) -> JSONResponse: - """Return the current pause status.""" - - engine = engine_client(raw_request) - - try: - paused = await engine.is_paused() - except Exception as err: # pragma: no cover - defensive - logger.exception("Failed to fetch pause status") - return JSONResponse( - content={"error": f"Failed to fetch pause status: {err}"}, - status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - ) - - return JSONResponse(content={"is_paused": paused}) - - -@router.post( - "/tokenize", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, - HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, - HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, - HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse}, - }, -) -@with_cancellation -async def tokenize(request: TokenizeRequest, raw_request: Request): - handler = tokenization(raw_request) - - try: - generator = await handler.create_tokenize(request, raw_request) - except NotImplementedError as e: - raise HTTPException( - status_code=HTTPStatus.NOT_IMPLEMENTED.value, detail=str(e) - ) from e - except Exception as e: - raise HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) - ) from e - - if isinstance(generator, ErrorResponse): - return JSONResponse( - content=generator.model_dump(), status_code=generator.error.code - ) - elif isinstance(generator, TokenizeResponse): - return JSONResponse(content=generator.model_dump()) - - assert_never(generator) - - -@router.post( - "/detokenize", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, - HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, - HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, - }, -) -@with_cancellation -async def detokenize(request: DetokenizeRequest, raw_request: Request): - handler = tokenization(raw_request) - - try: - generator = await handler.create_detokenize(request, raw_request) - except OverflowError as e: - raise RequestValidationError(errors=[str(e)]) from e - except Exception as e: - raise HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) - ) from e - - if isinstance(generator, ErrorResponse): - return JSONResponse( - content=generator.model_dump(), status_code=generator.error.code - ) - elif isinstance(generator, DetokenizeResponse): - return JSONResponse(content=generator.model_dump()) - - assert_never(generator) - - -def maybe_register_tokenizer_info_endpoint(args): - """Conditionally register the tokenizer info endpoint if enabled.""" - if getattr(args, "enable_tokenizer_info_endpoint", False): - - @router.get("/tokenizer_info") - async def get_tokenizer_info(raw_request: Request): - """Get comprehensive tokenizer information.""" - result = await tokenization(raw_request).get_tokenizer_info() - return JSONResponse( - content=result.model_dump(), - status_code=result.error.code - if isinstance(result, ErrorResponse) - else 200, - ) - - @router.get("/v1/models") async def show_available_models(raw_request: Request): handler = models(raw_request) @@ -809,7 +594,7 @@ async def create_transcriptions( content=generator.model_dump(), status_code=generator.error.code ) - elif isinstance(generator, TranscriptionResponse): + elif isinstance(generator, TranscriptionResponseVariant): return JSONResponse(content=generator.model_dump()) return StreamingResponse(content=generator, media_type="text/event-stream") @@ -848,7 +633,7 @@ async def create_translations( content=generator.model_dump(), status_code=generator.error.code ) - elif isinstance(generator, TranslationResponse): + elif isinstance(generator, TranslationResponseVariant): return JSONResponse(content=generator.model_dump()) return StreamingResponse(content=generator, media_type="text/event-stream") @@ -877,13 +662,28 @@ if envs.VLLM_SERVER_DEV_MODE: return JSONResponse(content=server_info) @router.post("/reset_prefix_cache") - async def reset_prefix_cache(raw_request: Request): + async def reset_prefix_cache( + raw_request: Request, + reset_running_requests: bool = Query(default=False), + reset_external: bool = Query(default=False), + ): """ - Reset the prefix cache. Note that we currently do not check if the - prefix cache is successfully reset in the API server. + Reset the local prefix cache. + + Optionally, if the query parameter `reset_external=true` + also resets the external (connector-managed) prefix cache. + + Note that we currently do not check if the prefix cache + is successfully reset in the API server. + + Example: + POST /reset_prefix_cache?reset_external=true """ logger.info("Resetting prefix cache...") - await engine_client(raw_request).reset_prefix_cache() + + await engine_client(raw_request).reset_prefix_cache( + reset_running_requests, reset_external + ) return Response(status_code=200) @router.post("/reset_mm_cache") @@ -896,33 +696,6 @@ if envs.VLLM_SERVER_DEV_MODE: await engine_client(raw_request).reset_mm_cache() return Response(status_code=200) - @router.post("/sleep") - async def sleep(raw_request: Request): - # get POST params - level = raw_request.query_params.get("level", "1") - await engine_client(raw_request).sleep(int(level)) - # FIXME: in v0 with frontend multiprocessing, the sleep command - # is sent but does not finish yet when we return a response. - return Response(status_code=200) - - @router.post("/wake_up") - async def wake_up(raw_request: Request): - tags = raw_request.query_params.getlist("tags") - if tags == []: - # set to None to wake up all tags if no tags are provided - tags = None - logger.info("wake up the engine with tags: %s", tags) - await engine_client(raw_request).wake_up(tags) - # FIXME: in v0 with frontend multiprocessing, the wake-up command - # is sent but does not finish yet when we return a response. - return Response(status_code=200) - - @router.get("/is_sleeping") - async def is_sleeping(raw_request: Request): - logger.info("check whether the engine is sleeping") - is_sleeping = await engine_client(raw_request).is_sleeping() - return JSONResponse(content={"is_sleeping": is_sleeping}) - @router.post("/collective_rpc") async def collective_rpc(raw_request: Request): try: @@ -950,138 +723,13 @@ if envs.VLLM_SERVER_DEV_MODE: return Response(status_code=200) response: list[Any] = [] for result in results: - if result is None or isinstance(result, (dict, list)): + if result is None or isinstance(result, dict | list): response.append(result) else: response.append(str(result)) return JSONResponse(content={"results": response}) -@router.post( - "/scale_elastic_ep", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.OK.value: {"model": dict}, - HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, - HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse}, - HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, - }, -) -async def scale_elastic_ep(raw_request: Request): - try: - body = await raw_request.json() - except json.JSONDecodeError as e: - raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904 - - new_data_parallel_size = body.get("new_data_parallel_size") - drain_timeout = body.get("drain_timeout", 120) # Default 2 minutes - - if new_data_parallel_size is None: - raise HTTPException( - status_code=400, detail="new_data_parallel_size is required" - ) - - if not isinstance(new_data_parallel_size, int) or new_data_parallel_size <= 0: - raise HTTPException( - status_code=400, detail="new_data_parallel_size must be a positive integer" - ) - - if not isinstance(drain_timeout, int) or drain_timeout <= 0: - raise HTTPException( - status_code=400, detail="drain_timeout must be a positive integer" - ) - - # Set scaling flag to prevent new requests - global _scaling_elastic_ep - _scaling_elastic_ep = True - client = engine_client(raw_request) - try: - await client.scale_elastic_ep(new_data_parallel_size, drain_timeout) - return JSONResponse( - { - "message": f"Scaled to {new_data_parallel_size} data parallel engines", - } - ) - except TimeoutError as e: - raise HTTPException( - status_code=408, - detail="Scale failed due to request drain timeout " - f"after {drain_timeout} seconds", - ) from e - except Exception as e: - logger.error("Scale failed: %s", e) - raise HTTPException(status_code=500, detail="Scale failed") from e - finally: - _scaling_elastic_ep = False - - -@router.post("/is_scaling_elastic_ep") -async def is_scaling_elastic_ep(raw_request: Request): - return JSONResponse({"is_scaling_elastic_ep": _scaling_elastic_ep}) - - -@router.post( - "/inference/v1/generate", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, - HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, - HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, - HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, - }, -) -@with_cancellation -@load_aware_call -async def generate(request: GenerateRequest, raw_request: Request): - handler = generate_tokens(raw_request) - if handler is None: - return base(raw_request).create_error_response( - message="The model does not support generate tokens API" - ) - try: - generator = await handler.serve_tokens(request, raw_request) - except Exception as e: - raise HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) - ) from e - if isinstance(generator, ErrorResponse): - return JSONResponse( - content=generator.model_dump(), status_code=generator.error.code - ) - - elif isinstance(generator, GenerateResponse): - return JSONResponse(content=generator.model_dump()) - - return StreamingResponse(content=generator, media_type="text/event-stream") - - -if envs.VLLM_TORCH_PROFILER_DIR: - logger.warning_once( - "Torch Profiler is enabled in the API server. This should ONLY be " - "used for local development!" - ) -elif envs.VLLM_TORCH_CUDA_PROFILE: - logger.warning_once( - "CUDA Profiler is enabled in the API server. This should ONLY be " - "used for local development!" - ) -if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE: - - @router.post("/start_profile") - async def start_profile(raw_request: Request): - logger.info("Starting profiler...") - await engine_client(raw_request).start_profile() - logger.info("Profiler started.") - return Response(status_code=200) - - @router.post("/stop_profile") - async def stop_profile(raw_request: Request): - logger.info("Stopping profiler...") - await engine_client(raw_request).stop_profile() - logger.info("Profiler stopped.") - return Response(status_code=200) - - def load_log_config(log_config_file: str | None) -> dict | None: if not log_config_file: return None @@ -1174,41 +822,6 @@ class XRequestIdMiddleware: return self.app(scope, receive, send_with_request_id) -# Global variable to track scaling state -_scaling_elastic_ep = False - - -class ScalingMiddleware: - """ - Middleware that checks if the model is currently scaling and - returns a 503 Service Unavailable response if it is. - - This middleware applies to all HTTP requests and prevents - processing when the model is in a scaling state. - """ - - def __init__(self, app: ASGIApp) -> None: - self.app = app - - def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]: - if scope["type"] != "http": - return self.app(scope, receive, send) - - # Check global scaling state - global _scaling_elastic_ep - if _scaling_elastic_ep: - # Return 503 Service Unavailable response - response = JSONResponse( - content={ - "error": "The model is currently scaling. Please try again later." - }, - status_code=503, - ) - return response(scope, receive, send) - - return self.app(scope, receive, send) - - def _extract_content_from_chunk(chunk_data: dict) -> str: """Extract content from a streaming response chunk.""" try: @@ -1351,15 +964,10 @@ def build_app(args: Namespace) -> FastAPI: ) else: app = FastAPI(lifespan=lifespan) + app.state.args = args + from vllm.entrypoints.serve import register_vllm_serve_api_routers - if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: - logger.warning( - "LoRA dynamic loading & unloading is enabled in the API server. " - "This should ONLY be used for local development!" - ) - from vllm.entrypoints.dynamic_lora import register_dynamic_lora_routes - - register_dynamic_lora_routes(router) + register_vllm_serve_api_routers(app) from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes @@ -1368,8 +976,6 @@ def build_app(args: Namespace) -> FastAPI: app.root_path = args.root_path - mount_metrics(app) - from vllm.entrypoints.pooling import register_pooling_api_routers register_pooling_api_routers(app) @@ -1460,31 +1066,6 @@ def build_app(args: Namespace) -> FastAPI: ) app = sagemaker_standards.bootstrap(app) - # Optional endpoints - if args.tokens_only: - - @app.post("/abort_requests") - async def abort_requests(raw_request: Request): - """ - Abort one or more requests. To be used in a - Disaggregated Everything setup. - """ - try: - body = await raw_request.json() - except json.JSONDecodeError as e: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST.value, - detail=f"JSON decode error: {e}", - ) from e - request_ids = body.get("request_ids") - if request_ids is None: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST.value, - detail="Missing 'request_ids' in request body", - ) - # Abort requests in background - asyncio.create_task(engine_client(raw_request).abort(request_ids)) - return Response(status_code=200) return app @@ -1513,7 +1094,7 @@ async def init_app_state( state.engine_client = engine_client state.log_stats = not args.disable_log_stats state.vllm_config = vllm_config - + state.args = args supported_tasks = await engine_client.get_supported_tasks() logger.info("Supported tasks: %s", supported_tasks) @@ -1837,7 +1418,6 @@ async def run_server_worker( args, client_config=client_config, ) as engine_client: - maybe_register_tokenizer_info_endpoint(args) app = build_app(args) await init_app_state(engine_client, app.state, args) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 946362ce2ef0a..a8eef76cd8ae4 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -27,8 +27,8 @@ from vllm.entrypoints.constants import ( H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT, ) from vllm.entrypoints.openai.serving_models import LoRAModulePath -from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.logger import init_logger +from vllm.tool_parsers import ToolParserManager from vllm.utils.argparse_utils import FlexibleArgumentParser logger = init_logger(__name__) @@ -176,7 +176,7 @@ class FrontendArgs: enable_force_include_usage: bool = False """If set to True, including usage on every request.""" enable_tokenizer_info_endpoint: bool = False - """Enable the /get_tokenizer_info endpoint. May expose chat + """Enable the `/tokenizer_info` endpoint. May expose chat templates and other tokenizer configuration.""" enable_log_outputs: bool = False """If True, log model outputs (generations). diff --git a/vllm/entrypoints/openai/parser/__init__.py b/vllm/entrypoints/openai/parser/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/entrypoints/openai/parser/harmony_utils.py b/vllm/entrypoints/openai/parser/harmony_utils.py new file mode 100644 index 0000000000000..376d97a03964e --- /dev/null +++ b/vllm/entrypoints/openai/parser/harmony_utils.py @@ -0,0 +1,825 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import datetime +import json +from collections.abc import Iterable, Sequence +from typing import Literal + +from openai.types.responses import ( + ResponseFunctionToolCall, + ResponseOutputItem, + ResponseOutputMessage, + ResponseOutputText, + ResponseReasoningItem, +) +from openai.types.responses.response_function_web_search import ( + ActionFind, + ActionOpenPage, + ActionSearch, + ResponseFunctionWebSearch, +) +from openai.types.responses.response_output_item import McpCall +from openai.types.responses.response_reasoning_item import ( + Content as ResponseReasoningTextContent, +) +from openai.types.responses.tool import Tool +from openai_harmony import ( + Author, + ChannelConfig, + Conversation, + DeveloperContent, + HarmonyEncodingName, + Message, + ReasoningEffort, + Role, + StreamableParser, + SystemContent, + TextContent, + ToolDescription, + load_harmony_encoding, +) +from openai_harmony import Message as OpenAIHarmonyMessage +from openai_harmony import Role as OpenAIHarmonyRole + +from vllm import envs +from vllm.entrypoints.openai.protocol import ( + ChatCompletionToolsParam, + ResponseInputOutputItem, + ResponsesRequest, +) +from vllm.utils import random_uuid + +REASONING_EFFORT = { + "high": ReasoningEffort.HIGH, + "medium": ReasoningEffort.MEDIUM, + "low": ReasoningEffort.LOW, +} + +_harmony_encoding = None + +# Builtin tools that should be included in the system message when +# they are available and requested by the user. +# Tool args are provided by MCP tool descriptions. Output +# of the tools are stringified. +MCP_BUILTIN_TOOLS: set[str] = { + "web_search_preview", + "code_interpreter", + "container", +} + + +def has_custom_tools(tool_types: set[str]) -> bool: + """ + Checks if the given tool types are custom tools + (i.e. any tool other than MCP buildin tools) + """ + return not tool_types.issubset(MCP_BUILTIN_TOOLS) + + +def get_encoding(): + global _harmony_encoding + if _harmony_encoding is None: + _harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) + return _harmony_encoding + + +def get_system_message( + model_identity: str | None = None, + reasoning_effort: Literal["high", "medium", "low"] | None = None, + start_date: str | None = None, + browser_description: str | None = None, + python_description: str | None = None, + container_description: str | None = None, + instructions: str | None = None, + with_custom_tools: bool = False, +) -> Message: + sys_msg_content = SystemContent.new() + if model_identity is not None: + sys_msg_content = sys_msg_content.with_model_identity(model_identity) + if instructions is not None and envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: + current_identity = sys_msg_content.model_identity + new_identity = ( + f"{current_identity}\n{instructions}" if current_identity else instructions + ) + sys_msg_content = sys_msg_content.with_model_identity(new_identity) + if reasoning_effort is not None: + sys_msg_content = sys_msg_content.with_reasoning_effort( + REASONING_EFFORT[reasoning_effort] + ) + if start_date is None: + # NOTE(woosuk): This brings non-determinism in vLLM. Be careful. + start_date = datetime.datetime.now().strftime("%Y-%m-%d") + sys_msg_content = sys_msg_content.with_conversation_start_date(start_date) + if browser_description is not None: + sys_msg_content = sys_msg_content.with_tools(browser_description) + if python_description is not None: + sys_msg_content = sys_msg_content.with_tools(python_description) + if container_description is not None: + sys_msg_content = sys_msg_content.with_tools(container_description) + if not with_custom_tools: + channel_config = sys_msg_content.channel_config + invalid_channel = "commentary" + new_config = ChannelConfig.require_channels( + [c for c in channel_config.valid_channels if c != invalid_channel] + ) + sys_msg_content = sys_msg_content.with_channel_config(new_config) + sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content) + return sys_msg + + +def create_tool_definition(tool: ChatCompletionToolsParam | Tool): + if isinstance(tool, ChatCompletionToolsParam): + return ToolDescription.new( + name=tool.function.name, + description=tool.function.description, + parameters=tool.function.parameters, + ) + return ToolDescription.new( + name=tool.name, + description=tool.description, + parameters=tool.parameters, + ) + + +def get_developer_message( + instructions: str | None = None, + tools: list[Tool | ChatCompletionToolsParam] | None = None, +) -> Message: + dev_msg_content = DeveloperContent.new() + if instructions is not None and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: + dev_msg_content = dev_msg_content.with_instructions(instructions) + if tools is not None: + function_tools: list[Tool | ChatCompletionToolsParam] = [] + for tool in tools: + if tool.type in ( + "web_search_preview", + "code_interpreter", + "container", + ): + pass + + elif tool.type == "function": + function_tools.append(tool) + else: + raise ValueError(f"tool type {tool.type} not supported") + if function_tools: + function_tool_descriptions = [ + create_tool_definition(tool) for tool in function_tools + ] + dev_msg_content = dev_msg_content.with_function_tools( + function_tool_descriptions + ) + dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content) + return dev_msg + + +def get_user_message(content: str) -> Message: + return Message.from_role_and_content(Role.USER, content) + + +def parse_response_input( + response_msg: ResponseInputOutputItem, + prev_responses: list[ResponseOutputItem | ResponseReasoningItem], +) -> Message: + if not isinstance(response_msg, dict): + response_msg = response_msg.model_dump() + if "type" not in response_msg or response_msg["type"] == "message": + role = response_msg["role"] + content = response_msg["content"] + if role == "system": + # User is trying to set a system message. Change it to: + # <|start|>developer<|message|># Instructions + # {instructions}<|end|> + role = "developer" + text_prefix = "Instructions:\n" + else: + text_prefix = "" + if isinstance(content, str): + msg = Message.from_role_and_content(role, text_prefix + content) + else: + contents = [TextContent(text=text_prefix + c["text"]) for c in content] + msg = Message.from_role_and_contents(role, contents) + if role == "assistant": + msg = msg.with_channel("final") + elif response_msg["type"] == "function_call_output": + call_id = response_msg["call_id"] + call_response: ResponseFunctionToolCall | None = None + for prev_response in reversed(prev_responses): + if ( + isinstance(prev_response, ResponseFunctionToolCall) + and prev_response.call_id == call_id + ): + call_response = prev_response + break + if call_response is None: + raise ValueError(f"No call message found for {call_id}") + msg = Message.from_author_and_content( + Author.new(Role.TOOL, f"functions.{call_response.name}"), + response_msg["output"], + ) + elif response_msg["type"] == "reasoning": + content = response_msg["content"] + assert len(content) == 1 + msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"]) + elif response_msg["type"] == "function_call": + msg = Message.from_role_and_content(Role.ASSISTANT, response_msg["arguments"]) + msg = msg.with_channel("commentary") + msg = msg.with_recipient(f"functions.{response_msg['name']}") + msg = msg.with_content_type("json") + else: + raise ValueError(f"Unknown input type: {response_msg['type']}") + return msg + + +def parse_chat_inputs_to_harmony_messages(chat_msgs: list) -> list[Message]: + """ + Parse a list of messages from request.messages in the Chat Completion API to + Harmony messages. + """ + msgs: list[Message] = [] + tool_id_names: dict[str, str] = {} + + # Collect tool id to name mappings for tool response recipient values + for chat_msg in chat_msgs: + for tool_call in chat_msg.get("tool_calls", []): + tool_id_names[tool_call.get("id")] = tool_call.get("function", {}).get( + "name" + ) + + for chat_msg in chat_msgs: + msgs.extend(parse_chat_input_to_harmony_message(chat_msg, tool_id_names)) + + msgs = auto_drop_analysis_messages(msgs) + return msgs + + +def auto_drop_analysis_messages(msgs: list[Message]) -> list[Message]: + """ + Harmony models expect the analysis messages (representing raw chain of thought) to + be dropped after an assistant message to the final channel is produced from the + reasoning of those messages. + + The openai-harmony library does this if the very last assistant message is to the + final channel, but it does not handle the case where we're in longer multi-turn + conversations and the client gave us reasoning content from previous turns of + the conversation with multiple assistant messages to the final channel in the + conversation. + + So, we find the index of the last assistant message to the final channel and drop + all analysis messages that precede it, leaving only the analysis messages that + are relevant to the current part of the conversation. + """ + last_assistant_final_index = -1 + for i in range(len(msgs) - 1, -1, -1): + msg = msgs[i] + if msg.author.role == "assistant" and msg.channel == "final": + last_assistant_final_index = i + break + + cleaned_msgs: list[Message] = [] + for i, msg in enumerate(msgs): + if i < last_assistant_final_index and msg.channel == "analysis": + continue + cleaned_msgs.append(msg) + + return cleaned_msgs + + +def flatten_chat_text_content(content: str | list | None) -> str | None: + """ + Extract the text parts from a chat message content field and flatten them + into a single string. + """ + if isinstance(content, list): + return "".join( + item.get("text", "") + for item in content + if isinstance(item, dict) and item.get("type") == "text" + ) + return content + + +def parse_chat_input_to_harmony_message( + chat_msg, tool_id_names: dict[str, str] | None = None +) -> list[Message]: + """ + Parse a message from request.messages in the Chat Completion API to + Harmony messages. + """ + tool_id_names = tool_id_names or {} + + if not isinstance(chat_msg, dict): + # Handle Pydantic models + chat_msg = chat_msg.model_dump(exclude_none=True) + + role = chat_msg.get("role") + msgs: list[Message] = [] + + # Assistant message with tool calls + tool_calls = chat_msg.get("tool_calls", []) + + if role == "assistant" and tool_calls: + content = flatten_chat_text_content(chat_msg.get("content")) + if content: + commentary_msg = Message.from_role_and_content(Role.ASSISTANT, content) + commentary_msg = commentary_msg.with_channel("commentary") + msgs.append(commentary_msg) + + reasoning_content = chat_msg.get("reasoning") or chat_msg.get( + "reasoning_content" + ) + if reasoning_content: + analysis_msg = Message.from_role_and_content( + Role.ASSISTANT, reasoning_content + ) + analysis_msg = analysis_msg.with_channel("analysis") + msgs.append(analysis_msg) + + for call in tool_calls: + func = call.get("function", {}) + name = func.get("name", "") + arguments = func.get("arguments", "") or "" + msg = Message.from_role_and_content(Role.ASSISTANT, arguments) + msg = msg.with_channel("commentary") + msg = msg.with_recipient(f"functions.{name}") + # Officially, this should be `<|constrain|>json` but there is not clear + # evidence that improves accuracy over `json` and some anecdotes to the + # contrary. Further testing of the different content_types is needed. + msg = msg.with_content_type("json") + msgs.append(msg) + return msgs + + # Tool role message (tool output) + if role == "tool": + tool_call_id = chat_msg.get("tool_call_id", "") + name = tool_id_names.get(tool_call_id, "") + content = chat_msg.get("content", "") or "" + content = flatten_chat_text_content(content) + + msg = ( + Message.from_author_and_content( + Author.new(Role.TOOL, f"functions.{name}"), content + ) + .with_channel("commentary") + .with_recipient("assistant") + ) + return [msg] + + # Non-tool reasoning content + reasoning_content = chat_msg.get("reasoning") or chat_msg.get("reasoning_content") + if role == "assistant" and reasoning_content: + analysis_msg = Message.from_role_and_content(Role.ASSISTANT, reasoning_content) + analysis_msg = analysis_msg.with_channel("analysis") + msgs.append(analysis_msg) + + # Default: user/assistant/system messages with content + content = chat_msg.get("content") or "" + if content is None: + content = "" + if isinstance(content, str): + contents = [TextContent(text=content)] + else: + # TODO: Support refusal. + contents = [TextContent(text=c.get("text", "")) for c in content] + + # Only add assistant messages if they have content, as reasoning or tool calling + # assistant messages were already added above. + if role == "assistant" and contents and contents[0].text: + msg = Message.from_role_and_contents(role, contents) + # Send non-tool assistant messages to the final channel + msg = msg.with_channel("final") + msgs.append(msg) + # For user/system/developer messages, add them directly even if no content. + elif role != "assistant": + msg = Message.from_role_and_contents(role, contents) + msgs.append(msg) + + return msgs + + +def parse_input_to_harmony_message(chat_msg) -> list[Message]: + """ + Parse a message from request.previous_input_messages in the Responsees API to + Harmony messages. + """ + if not isinstance(chat_msg, dict): + # Handle Pydantic models + chat_msg = chat_msg.model_dump(exclude_none=True) + + role = chat_msg.get("role") + + # Assistant message with tool calls + tool_calls = chat_msg.get("tool_calls") + if role == "assistant" and tool_calls: + msgs: list[Message] = [] + for call in tool_calls: + func = call.get("function", {}) + name = func.get("name", "") + arguments = func.get("arguments", "") or "" + msg = Message.from_role_and_content(Role.ASSISTANT, arguments) + msg = msg.with_channel("commentary") + msg = msg.with_recipient(f"functions.{name}") + msg = msg.with_content_type("json") + msgs.append(msg) + return msgs + + # Tool role message (tool output) + if role == "tool": + name = chat_msg.get("name", "") + content = chat_msg.get("content", "") or "" + content = flatten_chat_text_content(content) + + msg = Message.from_author_and_content( + Author.new(Role.TOOL, f"functions.{name}"), content + ).with_channel("commentary") + return [msg] + + # Default: user/assistant/system messages with content + content = chat_msg.get("content", "") + if isinstance(content, str): + contents = [TextContent(text=content)] + else: + # TODO: Support refusal. + contents = [TextContent(text=c.get("text", "")) for c in content] + msg = Message.from_role_and_contents(role, contents) + return [msg] + + +def construct_harmony_previous_input_messages( + request: ResponsesRequest, +) -> list[OpenAIHarmonyMessage]: + messages: list[OpenAIHarmonyMessage] = [] + if request.previous_input_messages: + for message in request.previous_input_messages: + # Handle both OpenAIHarmonyMessage objects and dictionary inputs + if isinstance(message, OpenAIHarmonyMessage): + message_role = message.author.role + # To match OpenAI, instructions, reasoning and tools are + # always taken from the most recent Responses API request + # not carried over from previous requests + if ( + message_role == OpenAIHarmonyRole.SYSTEM + or message_role == OpenAIHarmonyRole.DEVELOPER + ): + continue + messages.append(message) + else: + harmony_messages = parse_input_to_harmony_message(message) + for harmony_msg in harmony_messages: + message_role = harmony_msg.author.role + # To match OpenAI, instructions, reasoning and tools are + # always taken from the most recent Responses API request + # not carried over from previous requests + if ( + message_role == OpenAIHarmonyRole.SYSTEM + or message_role == OpenAIHarmonyRole.DEVELOPER + ): + continue + messages.append(harmony_msg) + return messages + + +def render_for_completion(messages: list[Message]) -> list[int]: + conversation = Conversation.from_messages(messages) + token_ids = get_encoding().render_conversation_for_completion( + conversation, Role.ASSISTANT + ) + return token_ids + + +def _parse_browser_tool_call(message: Message, recipient: str) -> ResponseOutputItem: + """Parse browser tool calls (search, open, find) into web search items.""" + if len(message.content) != 1: + raise ValueError("Invalid number of contents in browser message") + content = message.content[0] + + # Parse JSON args (with retry detection) + try: + browser_call = json.loads(content.text) + except json.JSONDecodeError: + json_retry_output_message = ( + f"Invalid JSON args, caught and retried: {content.text}" + ) + browser_call = { + "query": json_retry_output_message, + "url": json_retry_output_message, + "pattern": json_retry_output_message, + } + + # Create appropriate action based on recipient + if recipient == "browser.search": + action = ActionSearch( + query=f"cursor:{browser_call.get('query', '')}", type="search" + ) + elif recipient == "browser.open": + action = ActionOpenPage( + url=f"cursor:{browser_call.get('url', '')}", type="open_page" + ) + elif recipient == "browser.find": + action = ActionFind( + pattern=browser_call.get("pattern", ""), + url=f"cursor:{browser_call.get('url', '')}", + type="find", + ) + else: + raise ValueError(f"Unknown browser action: {recipient}") + + return ResponseFunctionWebSearch( + id=f"ws_{random_uuid()}", + action=action, + status="completed", + type="web_search_call", + ) + + +def _parse_function_call(message: Message, recipient: str) -> list[ResponseOutputItem]: + """Parse function calls into function tool call items.""" + function_name = recipient.split(".")[-1] + output_items = [] + for content in message.content: + random_id = random_uuid() + response_item = ResponseFunctionToolCall( + arguments=content.text, + call_id=f"call_{random_id}", + type="function_call", + name=function_name, + id=f"fc_{random_id}", + ) + output_items.append(response_item) + return output_items + + +def _parse_reasoning_content(message: Message) -> list[ResponseOutputItem]: + """Parse reasoning/analysis content into reasoning items.""" + output_items = [] + for content in message.content: + reasoning_item = ResponseReasoningItem( + id=f"rs_{random_uuid()}", + summary=[], + type="reasoning", + content=[ + ResponseReasoningTextContent(text=content.text, type="reasoning_text") + ], + status=None, + ) + output_items.append(reasoning_item) + return output_items + + +def _parse_final_message(message: Message) -> ResponseOutputItem: + """Parse final channel messages into output message items.""" + contents = [] + for content in message.content: + output_text = ResponseOutputText( + text=content.text, + annotations=[], # TODO + type="output_text", + logprobs=None, # TODO + ) + contents.append(output_text) + return ResponseOutputMessage( + id=f"msg_{random_uuid()}", + content=contents, + role=message.author.role, + status="completed", + type="message", + ) + + +def _parse_mcp_recipient(recipient: str) -> tuple[str, str]: + """ + Parse MCP recipient into (server_label, tool_name). + + For dotted recipients like "repo_browser.list": + - server_label: "repo_browser" (namespace/server) + - tool_name: "list" (specific tool) + + For simple recipients like "filesystem": + - server_label: "filesystem" + - tool_name: "filesystem" + """ + if "." in recipient: + server_label = recipient.split(".")[0] + tool_name = recipient.split(".")[-1] + else: + server_label = recipient + tool_name = recipient + return server_label, tool_name + + +def _parse_mcp_call(message: Message, recipient: str) -> list[ResponseOutputItem]: + """Parse MCP calls into MCP call items.""" + server_label, tool_name = _parse_mcp_recipient(recipient) + output_items = [] + for content in message.content: + response_item = McpCall( + arguments=content.text, + type="mcp_call", + name=tool_name, + server_label=server_label, + id=f"mcp_{random_uuid()}", + status="completed", + ) + output_items.append(response_item) + return output_items + + +def parse_output_message(message: Message) -> list[ResponseOutputItem]: + """ + Parse a Harmony message into a list of output response items. + """ + if message.author.role != "assistant": + # This is a message from a tool to the assistant (e.g., search result). + # Don't include it in the final output for now. This aligns with + # OpenAI's behavior on models like o4-mini. + return [] + + output_items: list[ResponseOutputItem] = [] + recipient = message.recipient + + if recipient is not None: + # Browser tool calls + if recipient.startswith("browser."): + output_items.append(_parse_browser_tool_call(message, recipient)) + + # Function calls (should only happen on commentary channel) + elif message.channel == "commentary" and recipient.startswith("functions."): + output_items.extend(_parse_function_call(message, recipient)) + + # Built-in tools are treated as reasoning + elif recipient.startswith(("python", "browser", "container")): + # Built-in tool recipients (python/browser/container) + # generate reasoning output + output_items.extend(_parse_reasoning_content(message)) + + # All other recipients are MCP calls + else: + output_items.extend(_parse_mcp_call(message, recipient)) + + # No recipient - handle based on channel for non-tool messages + elif message.channel == "analysis": + output_items.extend(_parse_reasoning_content(message)) + + elif message.channel == "commentary": + # Per Harmony format, commentary channel can contain preambles to calling + # multiple functions - explanatory text with no recipient + output_items.extend(_parse_reasoning_content(message)) + + elif message.channel == "final": + output_items.append(_parse_final_message(message)) + + else: + raise ValueError(f"Unknown channel: {message.channel}") + + return output_items + + +def parse_remaining_state(parser: StreamableParser) -> list[ResponseOutputItem]: + if not parser.current_content: + return [] + if parser.current_role != Role.ASSISTANT: + return [] + current_recipient = parser.current_recipient + if current_recipient is not None and current_recipient.startswith("browser."): + return [] + + if current_recipient and parser.current_channel in ("commentary", "analysis"): + if current_recipient.startswith("functions."): + rid = random_uuid() + return [ + ResponseFunctionToolCall( + arguments=parser.current_content, + call_id=f"call_{rid}", + type="function_call", + name=current_recipient.split(".")[-1], + id=f"fc_{rid}", + status="in_progress", + ) + ] + # Built-in tools (python, browser, container) should be treated as reasoning + elif not ( + current_recipient.startswith("python") + or current_recipient.startswith("browser") + or current_recipient.startswith("container") + ): + # All other recipients are MCP calls + rid = random_uuid() + server_label, tool_name = _parse_mcp_recipient(current_recipient) + return [ + McpCall( + arguments=parser.current_content, + type="mcp_call", + name=tool_name, + server_label=server_label, + id=f"mcp_{rid}", + status="in_progress", + ) + ] + + if parser.current_channel == "commentary": + return [ + ResponseReasoningItem( + id=f"rs_{random_uuid()}", + summary=[], + type="reasoning", + content=[ + ResponseReasoningTextContent( + text=parser.current_content, type="reasoning_text" + ) + ], + status=None, + ) + ] + + if parser.current_channel == "analysis": + return [ + ResponseReasoningItem( + id=f"rs_{random_uuid()}", + summary=[], + type="reasoning", + content=[ + ResponseReasoningTextContent( + text=parser.current_content, type="reasoning_text" + ) + ], + status=None, + ) + ] + + if parser.current_channel == "final": + output_text = ResponseOutputText( + text=parser.current_content, + annotations=[], # TODO + type="output_text", + logprobs=None, # TODO + ) + text_item = ResponseOutputMessage( + id=f"msg_{random_uuid()}", + content=[output_text], + role="assistant", + # if the parser still has messages (ie if the generator got cut + # abruptly), this should be incomplete + status="incomplete", + type="message", + ) + return [text_item] + + return [] + + +def get_stop_tokens_for_assistant_actions() -> list[int]: + return get_encoding().stop_tokens_for_assistant_actions() + + +def get_streamable_parser_for_assistant() -> StreamableParser: + return StreamableParser(get_encoding(), role=Role.ASSISTANT) + + +def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser: + parser = get_streamable_parser_for_assistant() + for token_id in token_ids: + parser.process(token_id) + return parser + + +def parse_chat_output( + token_ids: Sequence[int], +) -> tuple[str | None, str | None, bool]: + """ + Parse the output of a Harmony chat completion into reasoning and final content. + Note that when the `openai` tool parser is used, serving_chat only uses this + for the reasoning content and gets the final content from the tool call parser. + + When the `openai` tool parser is not enabled, or when `GptOssReasoningParser` is + in use,this needs to return the final content without any tool calls parsed. + + Empty reasoning or final content is returned as None instead of an empty string. + """ + parser = parse_output_into_messages(token_ids) + output_msgs = parser.messages + is_tool_call = False # TODO: update this when tool call is supported + + # Get completed messages from the parser + reasoning_texts = [ + msg.content[0].text for msg in output_msgs if msg.channel == "analysis" + ] + final_texts = [ + msg.content[0].text for msg in output_msgs if msg.channel != "analysis" + ] + + # Extract partial messages from the parser + if parser.current_channel == "analysis" and parser.current_content: + reasoning_texts.append(parser.current_content) + elif parser.current_channel != "analysis" and parser.current_content: + final_texts.append(parser.current_content) + + # Flatten multiple messages into a single string + reasoning: str | None = "\n".join(reasoning_texts) + final_content: str | None = "\n".join(final_texts) + + # Return None instead of empty string since existing callers check for None + reasoning = reasoning or None + final_content = final_content or None + + return reasoning, final_content, is_tool_call diff --git a/vllm/entrypoints/openai/parser/responses_parser.py b/vllm/entrypoints/openai/parser/responses_parser.py new file mode 100644 index 0000000000000..c364d6d80544d --- /dev/null +++ b/vllm/entrypoints/openai/parser/responses_parser.py @@ -0,0 +1,171 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import logging +from collections.abc import Callable + +from openai.types.responses import ResponseFunctionToolCall, ResponseOutputItem +from openai.types.responses.response_function_tool_call_output_item import ( + ResponseFunctionToolCallOutputItem, +) +from openai.types.responses.response_output_item import McpCall +from openai.types.responses.response_output_message import ResponseOutputMessage +from openai.types.responses.response_output_text import ResponseOutputText +from openai.types.responses.response_reasoning_item import ( + Content, + ResponseReasoningItem, +) + +from vllm.entrypoints.constants import MCP_PREFIX +from vllm.entrypoints.openai.protocol import ResponseInputOutputItem, ResponsesRequest +from vllm.outputs import CompletionOutput +from vllm.reasoning.abs_reasoning_parsers import ReasoningParser +from vllm.tokenizers.protocol import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ToolParser +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import random_uuid + +logger = logging.getLogger(__name__) + + +class ResponsesParser: + """Incremental parser over completion tokens with reasoning support.""" + + def __init__( + self, + *, + tokenizer: AnyTokenizer, + reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser], + response_messages: list[ResponseInputOutputItem], + request: ResponsesRequest, + tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None, + ): + self.response_messages: list[ResponseInputOutputItem] = ( + # TODO: initial messages may not be properly typed + response_messages + ) + self.num_init_messages = len(response_messages) + self.tokenizer = tokenizer + self.request = request + + self.reasoning_parser_instance = reasoning_parser_cls(tokenizer) + self.tool_parser_instance = None + if tool_parser_cls is not None: + self.tool_parser_instance = tool_parser_cls(tokenizer) + + def process(self, output: CompletionOutput) -> "ResponsesParser": + reasoning_content, content = self.reasoning_parser_instance.extract_reasoning( + output.text, request=self.request + ) + if reasoning_content: + self.response_messages.append( + ResponseReasoningItem( + type="reasoning", + id=f"rs_{random_uuid()}", + summary=[], + content=[ + Content( + type="reasoning_text", + text=reasoning_content, + ) + ], + ) + ) + + function_calls: list[ResponseFunctionToolCall] = [] + if self.tool_parser_instance is not None: + tool_call_info = self.tool_parser_instance.extract_tool_calls( + content if content is not None else "", + request=self.request, # type: ignore + ) + if tool_call_info is not None and tool_call_info.tools_called: + # extract_tool_calls() returns a list of tool calls. + function_calls.extend( + ResponseFunctionToolCall( + id=f"fc_{random_uuid()}", + call_id=f"call_{random_uuid()}", + type="function_call", + status="completed", + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ) + for tool_call in tool_call_info.tool_calls + ) + content = tool_call_info.content + if content and content.strip() == "": + content = None + + if content: + self.response_messages.append( + ResponseOutputMessage( + type="message", + id=f"msg_{random_uuid()}", + status="completed", + role="assistant", + content=[ + ResponseOutputText( + annotations=[], # TODO + type="output_text", + text=content, + logprobs=None, # TODO + ) + ], + ) + ) + if len(function_calls) > 0: + self.response_messages.extend(function_calls) + + return self + + def make_response_output_items_from_parsable_context( + self, + ) -> list[ResponseOutputItem]: + """Given a list of sentences, construct ResponseOutput Items.""" + response_messages = self.response_messages[self.num_init_messages :] + output_messages: list[ResponseOutputItem] = [] + for message in response_messages: + if not isinstance(message, ResponseFunctionToolCallOutputItem): + output_messages.append(message) + else: + if len(output_messages) == 0: + raise ValueError( + "Cannot have a FunctionToolCallOutput before FunctionToolCall." + ) + if isinstance(output_messages[-1], ResponseFunctionToolCall): + mcp_message = McpCall( + id=f"{MCP_PREFIX}{random_uuid()}", + arguments=output_messages[-1].arguments, + name=output_messages[-1].name, + server_label=output_messages[ + -1 + ].name, # TODO: store the server label + type="mcp_call", + status="completed", + output=message.output, + # TODO: support error output + ) + output_messages[-1] = mcp_message + + return output_messages + + +def get_responses_parser_for_simple_context( + *, + tokenizer: AnyTokenizer, + reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser], + response_messages: list[ResponseInputOutputItem], + request: ResponsesRequest, + tool_parser_cls, +) -> ResponsesParser: + """Factory function to create a ResponsesParser with + optional reasoning parser. + + Returns: + ResponsesParser instance configured with the provided parser + """ + return ResponsesParser( + tokenizer=tokenizer, + reasoning_parser_cls=reasoning_parser_cls, + response_messages=response_messages, + request=request, + tool_parser_cls=tool_parser_cls, + ) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index fb73416f45b24..a7c4980cd3674 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -25,6 +25,10 @@ from openai.types.responses import ( ResponseContentPartDoneEvent, ResponseFunctionToolCall, ResponseInputItemParam, + ResponseMcpCallArgumentsDeltaEvent, + ResponseMcpCallArgumentsDoneEvent, + ResponseMcpCallCompletedEvent, + ResponseMcpCallInProgressEvent, ResponseOutputItem, ResponseOutputItemAddedEvent, ResponseOutputItemDoneEvent, @@ -316,6 +320,7 @@ class ResponsesRequest(OpenAIBaseModel): max_tool_calls: int | None = None metadata: Metadata | None = None model: str | None = None + logit_bias: dict[str, float] | None = None parallel_tool_calls: bool | None = True previous_response_id: str | None = None prompt: ResponsePrompt | None = None @@ -329,6 +334,7 @@ class ResponsesRequest(OpenAIBaseModel): tools: list[Tool] = Field(default_factory=list) top_logprobs: int | None = 0 top_p: float | None = None + top_k: int | None = None truncation: Literal["auto", "disabled"] | None = "disabled" user: str | None = None @@ -383,6 +389,7 @@ class ResponsesRequest(OpenAIBaseModel): _DEFAULT_SAMPLING_PARAMS = { "temperature": 1.0, "top_p": 1.0, + "top_k": 0, } def to_sampling_params( @@ -404,6 +411,10 @@ class ResponsesRequest(OpenAIBaseModel): top_p = default_sampling_params.get( "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"] ) + if (top_k := self.top_k) is None: + top_k = default_sampling_params.get( + "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"] + ) stop_token_ids = default_sampling_params.get("stop_token_ids") # Structured output @@ -424,6 +435,7 @@ class ResponsesRequest(OpenAIBaseModel): return SamplingParams.from_optional( temperature=temperature, top_p=top_p, + top_k=top_k, max_tokens=max_tokens, logprobs=self.top_logprobs if self.is_include_output_logprobs() else None, stop_token_ids=stop_token_ids, @@ -431,6 +443,7 @@ class ResponsesRequest(OpenAIBaseModel): RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY ), structured_outputs=structured_outputs, + logit_bias=self.logit_bias, ) def is_include_output_logprobs(self) -> bool: @@ -1598,6 +1611,20 @@ def serialize_messages(msgs): return [serialize_message(msg) for msg in msgs] if msgs else None +class ResponseRawMessageAndToken(OpenAIBaseModel): + """Class to show the raw message. + If message / tokens diverge, tokens is the source of truth""" + + message: str + tokens: list[int] + type: Literal["raw_message_tokens"] = "raw_message_tokens" + + +ResponseInputOutputMessage: TypeAlias = ( + list[ChatCompletionMessageParam] | list[ResponseRawMessageAndToken] +) + + class ResponsesResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"resp_{random_uuid()}") created_at: int = Field(default_factory=lambda: int(time.time())) @@ -1631,8 +1658,8 @@ class ResponsesResponse(OpenAIBaseModel): # These are populated when enable_response_messages is set to True # NOTE: custom serialization is needed # see serialize_input_messages and serialize_output_messages - input_messages: list[ChatCompletionMessageParam] | None = None - output_messages: list[ChatCompletionMessageParam] | None = None + input_messages: ResponseInputOutputMessage | None = None + output_messages: ResponseInputOutputMessage | None = None # --8<-- [end:responses-extra-params] # NOTE: openAI harmony doesn't serialize TextContent properly, @@ -1658,8 +1685,8 @@ class ResponsesResponse(OpenAIBaseModel): output: list[ResponseOutputItem], status: ResponseStatus, usage: ResponseUsage | None = None, - input_messages: list[ChatCompletionMessageParam] | None = None, - output_messages: list[ChatCompletionMessageParam] | None = None, + input_messages: ResponseInputOutputMessage | None = None, + output_messages: ResponseInputOutputMessage | None = None, ) -> "ResponsesResponse": incomplete_details: IncompleteDetails | None = None if status == "incomplete": @@ -1776,6 +1803,10 @@ StreamingResponsesResponse: TypeAlias = ( | ResponseCodeInterpreterCallCodeDoneEvent | ResponseCodeInterpreterCallInterpretingEvent | ResponseCodeInterpreterCallCompletedEvent + | ResponseMcpCallArgumentsDeltaEvent + | ResponseMcpCallArgumentsDoneEvent + | ResponseMcpCallInProgressEvent + | ResponseMcpCallCompletedEvent ) @@ -2126,13 +2157,13 @@ class TranscriptionSegment(OpenAIBaseModel): id: int """Unique identifier of the segment.""" - avg_logprob: float + avg_logprob: float | None = None """Average logprob of the segment. If the value is lower than -1, consider the logprobs failed. """ - compression_ratio: float + compression_ratio: float | None = None """Compression ratio of the segment. If the value is greater than 2.4, consider the compression failed. @@ -2141,7 +2172,7 @@ class TranscriptionSegment(OpenAIBaseModel): end: float """End time of the segment in seconds.""" - no_speech_prob: float + no_speech_prob: float | None = None """Probability of no speech in the segment. If the value is higher than 1.0 and the `avg_logprob` is below -1, consider @@ -2181,6 +2212,11 @@ class TranscriptionResponseVerbose(OpenAIBaseModel): """Extracted words and their corresponding timestamps.""" +TranscriptionResponseVariant: TypeAlias = ( + TranscriptionResponse | TranscriptionResponseVerbose +) + + class TranslationResponseStreamChoice(OpenAIBaseModel): delta: DeltaMessage finish_reason: str | None = None @@ -2325,13 +2361,13 @@ class TranslationSegment(OpenAIBaseModel): id: int """Unique identifier of the segment.""" - avg_logprob: float + avg_logprob: float | None = None """Average logprob of the segment. If the value is lower than -1, consider the logprobs failed. """ - compression_ratio: float + compression_ratio: float | None = None """Compression ratio of the segment. If the value is greater than 2.4, consider the compression failed. @@ -2340,7 +2376,7 @@ class TranslationSegment(OpenAIBaseModel): end: float """End time of the segment in seconds.""" - no_speech_prob: float + no_speech_prob: float | None = None """Probability of no speech in the segment. If the value is higher than 1.0 and the `avg_logprob` is below -1, consider @@ -2380,6 +2416,9 @@ class TranslationResponseVerbose(OpenAIBaseModel): """Extracted words and their corresponding timestamps.""" +TranslationResponseVariant: TypeAlias = TranslationResponse | TranslationResponseVerbose + + ####### Tokens IN <> Tokens OUT ####### class GenerateRequest(BaseModel): request_id: str = Field( diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index cecd1da1e5548..98fc7810faf96 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -21,16 +21,16 @@ from vllm.entrypoints.chat_utils import ( get_history_tool_calls_cnt, make_tool_call_id, ) -from vllm.entrypoints.harmony_utils import ( +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.parser.harmony_utils import ( get_developer_message, get_stop_tokens_for_assistant_actions, get_streamable_parser_for_assistant, get_system_message, + parse_chat_inputs_to_harmony_messages, parse_chat_output, - parse_input_to_harmony_message, render_for_completion, ) -from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( ChatCompletionLogProb, ChatCompletionLogProbs, @@ -51,13 +51,15 @@ from vllm.entrypoints.openai.protocol import ( ToolCall, UsageInfo, ) -from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs +from vllm.entrypoints.openai.serving_engine import ( + GenerationError, + OpenAIServing, + clamp_prompt_logprobs, +) from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.entrypoints.openai.tool_parsers import ToolParser -from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls from vllm.entrypoints.utils import get_max_tokens, should_include_usage -from vllm.inputs.data import TokensPrompt as EngineTokensPrompt +from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger from vllm.logprobs import Logprob from vllm.outputs import CompletionOutput, RequestOutput @@ -69,6 +71,8 @@ from vllm.tokenizers.mistral import ( truncate_tool_call_ids, validate_request_params, ) +from vllm.tool_parsers import ToolParser +from vllm.tool_parsers.mistral_tool_parser import MistralToolCall from vllm.utils.collection_utils import as_list from vllm.v1.sample.logits_processor import validate_logits_processors_parameters @@ -230,11 +234,7 @@ class OpenAIServingChat(OpenAIServing): ) if error_check_ret is not None: return error_check_ret - ( - conversation, - request_prompts, - engine_prompts, - ) = await self._preprocess_chat( + conversation, engine_prompts = await self._preprocess_chat( request, tokenizer, request.messages, @@ -250,11 +250,7 @@ class OpenAIServingChat(OpenAIServing): ) else: # For GPT-OSS. - ( - conversation, - request_prompts, - engine_prompts, - ) = self._make_request_with_harmony(request) + conversation, engine_prompts = self._make_request_with_harmony(request) except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(f"{e} {e.__cause__}") @@ -274,7 +270,7 @@ class OpenAIServingChat(OpenAIServing): generators: list[AsyncGenerator[RequestOutput, None]] = [] try: for i, engine_prompt in enumerate(engine_prompts): - prompt_text, _, _ = self._get_prompt_components(request_prompts[i]) + prompt_text, _, _ = self._get_prompt_components(engine_prompt) # If we are creating sub requests for multiple prompts, ensure that they # have unique request ids. sub_request_id = ( @@ -309,7 +305,7 @@ class OpenAIServingChat(OpenAIServing): self._log_inputs( sub_request_id, - request_prompts[i], + engine_prompt, params=sampling_params, lora_request=lora_request, ) @@ -380,6 +376,8 @@ class OpenAIServingChat(OpenAIServing): tokenizer, request_metadata, ) + except GenerationError as e: + return self._convert_generation_error_to_response(e) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -531,7 +529,7 @@ class OpenAIServingChat(OpenAIServing): request_id: str, model_name: str, conversation: list[ConversationMessage], - tokenizer: TokenizerLike, + tokenizer: TokenizerLike | None, request_metadata: RequestResponseMetadata, ) -> AsyncGenerator[str, None]: created_time = int(time.time()) @@ -585,6 +583,11 @@ class OpenAIServingChat(OpenAIServing): try: if self.reasoning_parser: + if tokenizer is None: + raise ValueError( + "Tokenizer not available when `skip_tokenizer_init=True`" + ) + reasoning_parser = self.reasoning_parser( tokenizer, chat_template_kwargs=request.chat_template_kwargs, # type: ignore @@ -598,6 +601,11 @@ class OpenAIServingChat(OpenAIServing): # Prepare the tool parser if it's needed try: if tool_choice_auto and self.tool_parser: + if tokenizer is None: + raise ValueError( + "Tokenizer not available when `skip_tokenizer_init=True`" + ) + tool_parsers: list[ToolParser | None] = [ self.tool_parser(tokenizer) ] * num_choices @@ -816,6 +824,9 @@ class OpenAIServingChat(OpenAIServing): if delta_message is not None: harmony_tools_streamed[i] = True + elif cur_channel == "commentary": + # Tool call preambles meant to be shown to the user + delta_message = DeltaMessage(content=delta_text) else: delta_message = None # handle streaming deltas for tools with named tool_choice @@ -953,21 +964,9 @@ class OpenAIServingChat(OpenAIServing): assert reasoning_end_arr is not None output_token_ids = as_list(output.token_ids) if not reasoning_end_arr[i]: - delta_message = ( - reasoning_parser.extract_reasoning_streaming( - previous_text, - current_text, - delta_text, - previous_token_ids, - current_token_ids, - output_token_ids, - ) - ) # When encountering think end id in prompt_token_ids # i.e {"enable_thinking": False}, # set reasoning status to end. - # Remove the text and token ids related - # to 'reasoning'. if ( res.prompt_token_ids and reasoning_parser.is_reasoning_end( @@ -976,30 +975,38 @@ class OpenAIServingChat(OpenAIServing): ): reasoning_end_arr[i] = True current_token_ids = output_token_ids - if delta_message and delta_message.content: - current_text = delta_message.content - delta_message.content = None - else: - current_text = "" - # When encountering think end id in delta_token_ids, - # set reasoning status to end. - # Remove the text and token ids related - # to 'reasoning'. - if reasoning_parser.is_reasoning_end(output_token_ids): - reasoning_end_arr[i] = True - current_token_ids = ( - reasoning_parser.extract_content_ids( - output_token_ids + # Don't update current_text, keep it as is from delta + else: + delta_message = ( + reasoning_parser.extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output_token_ids, ) ) - if delta_message and delta_message.content: - current_text = delta_message.content - delta_message.content = None - else: - current_text = "" + + # When encountering think end id in delta_token_ids, + # set reasoning status to end. + # Remove the text and token ids related + # to 'reasoning'. + if reasoning_parser.is_reasoning_end(output_token_ids): + reasoning_end_arr[i] = True + current_token_ids = ( + reasoning_parser.extract_content_ids( + output_token_ids + ) + ) + if delta_message and delta_message.content: + current_text = delta_message.content + delta_message.content = None + else: + current_text = "" # handle tool calls only after reasoning is done, - else: + if reasoning_end_arr[i]: delta_token_ids = output_token_ids # First time to tool call, # add the remaining text and token ids @@ -1072,10 +1079,15 @@ class OpenAIServingChat(OpenAIServing): # wasn't ready to send a token, then # get the next token without streaming a chunk if delta_message is None: - if output.finish_reason is None: + # NOTE: If return_token_ids is enabled, we still need to + # send a chunk with token_ids even if delta_message is None + # to ensure all tokens are included in the response + if ( + output.finish_reason is None + and not request.return_token_ids + ): continue - else: - delta_message = DeltaMessage() + delta_message = DeltaMessage() # Log streaming delta if output logging is enabled if self.enable_log_outputs and self.request_logger: @@ -1115,6 +1127,10 @@ class OpenAIServingChat(OpenAIServing): # if the model is finished generating else: + # check for error finish reason and abort streaming + # finish_reason='error' indicates a retryable error + self._raise_if_error(output.finish_reason, request_id) + # check to make sure we haven't "forgotten" to stream # any tokens that were generated but previously # matched by partial json parsing @@ -1282,6 +1298,8 @@ class OpenAIServingChat(OpenAIServing): delta=False, ) + except GenerationError as e: + yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n" except Exception as e: # TODO: Use a vllm-specific Validation Error logger.exception("Error in chat completion stream generator.") @@ -1297,7 +1315,7 @@ class OpenAIServingChat(OpenAIServing): request_id: str, model_name: str, conversation: list[ConversationMessage], - tokenizer: TokenizerLike, + tokenizer: TokenizerLike | None, request_metadata: RequestResponseMetadata, ) -> ErrorResponse | ChatCompletionResponse: created_time = int(time.time()) @@ -1322,6 +1340,9 @@ class OpenAIServingChat(OpenAIServing): role = self.get_chat_request_role(request) for output in final_res.outputs: + # check for error finish reason and raise GenerationError + # finish_reason='error' indicates a retryable request-level internal error + self._raise_if_error(output.finish_reason, request_id) token_ids = output.token_ids out_logprobs = output.logprobs tool_call_info = None @@ -1344,6 +1365,11 @@ class OpenAIServingChat(OpenAIServing): reasoning = None if self.tool_parser is not None: + if tokenizer is None: + raise ValueError( + "Tokenizer not available when `skip_tokenizer_init=True`" + ) + tool_parser = self.tool_parser(tokenizer) # NOTE: We use token_ids for openai tool parser tool_call_info = tool_parser.extract_tool_calls( @@ -1386,6 +1412,11 @@ class OpenAIServingChat(OpenAIServing): if self.reasoning_parser: try: + if tokenizer is None: + raise ValueError( + "Tokenizer not available when `skip_tokenizer_init=True`" + ) + reasoning_parser = self.reasoning_parser( tokenizer, chat_template_kwargs=request.chat_template_kwargs, # type: ignore @@ -1625,7 +1656,7 @@ class OpenAIServingChat(OpenAIServing): self, logprobs: dict[int, Logprob], top_logprobs: int | None, - tokenizer: TokenizerLike, + tokenizer: TokenizerLike | None, should_return_as_token_id: bool, ) -> list[ChatCompletionLogProb]: return [ @@ -1649,7 +1680,7 @@ class OpenAIServingChat(OpenAIServing): self, token_ids: GenericSequence[int], top_logprobs: GenericSequence[dict[int, Logprob] | None], - tokenizer: TokenizerLike, + tokenizer: TokenizerLike | None, num_output_top_logprobs: int | None = None, return_as_token_id: bool | None = None, ) -> ChatCompletionLogProbs: @@ -1667,6 +1698,11 @@ class OpenAIServingChat(OpenAIServing): if should_return_as_token_id: token = f"token_id:{token_id}" else: + if tokenizer is None: + raise ValueError( + "Tokenizer not available when `skip_tokenizer_init=True`" + ) + token = tokenizer.decode(token_id) logprobs_content.append( @@ -1750,6 +1786,11 @@ class OpenAIServingChat(OpenAIServing): ): messages: list[OpenAIMessage] = [] + # because of issues with pydantic we need to potentially + # re-serialize the tool_calls field of the request + # for more info: see comment in `maybe_serialize_tool_calls` + maybe_serialize_tool_calls(request) + # Add system message. # NOTE: In Chat Completion API, browsing is enabled by default # if the model supports it. TODO: Support browsing. @@ -1768,15 +1809,14 @@ class OpenAIServingChat(OpenAIServing): messages.append(dev_msg) # Add user message. - for chat_msg in request.messages: - messages.extend(parse_input_to_harmony_message(chat_msg)) + messages.extend(parse_chat_inputs_to_harmony_messages(request.messages)) # Render prompt token ids. prompt_token_ids = render_for_completion(messages) - engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids) + engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) # Add cache_salt if provided in the request if request.cache_salt is not None: engine_prompt["cache_salt"] = request.cache_salt - return messages, [prompt_token_ids], [engine_prompt] + return messages, [engine_prompt] diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 3e421e21e3e80..1be0afc8c74e5 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -24,7 +24,11 @@ from vllm.entrypoints.openai.protocol import ( RequestResponseMetadata, UsageInfo, ) -from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs +from vllm.entrypoints.openai.serving_engine import ( + GenerationError, + OpenAIServing, + clamp_prompt_logprobs, +) from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.utils import get_max_tokens, should_include_usage @@ -300,6 +304,8 @@ class OpenAIServingCompletion(OpenAIServing): ) except asyncio.CancelledError: return self.create_error_response("Client disconnected") + except GenerationError as e: + return self._convert_generation_error_to_response(e) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -437,6 +443,8 @@ class OpenAIServingCompletion(OpenAIServing): finish_reason = output.finish_reason stop_reason = output.stop_reason + self._raise_if_error(finish_reason, request_id) + chunk = CompletionStreamResponse( id=request_id, created=created_time, @@ -498,8 +506,11 @@ class OpenAIServingCompletion(OpenAIServing): # report to FastAPI middleware aggregate usage across all choices request_metadata.final_usage_info = final_usage_info + except GenerationError as e: + yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n" except Exception as e: # TODO: Use a vllm-specific Validation Error + logger.exception("Error in completion stream generator.") data = self.create_streaming_error_response(str(e)) yield f"data: {data}\n\n" yield "data: [DONE]\n\n" @@ -530,6 +541,8 @@ class OpenAIServingCompletion(OpenAIServing): out_logprobs: GenericSequence[dict[int, Logprob] | None] | None for output in final_res.outputs: + self._raise_if_error(output.finish_reason, request_id) + assert request.max_tokens is not None if request.echo: if request.return_token_ids: diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 1d89aa011af21..5f7cfaa53ec18 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -5,19 +5,60 @@ import json import sys import time import traceback -from collections.abc import AsyncGenerator, Callable, Iterable, Mapping, Sequence +from collections.abc import AsyncGenerator, Callable, Iterable, Mapping from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from http import HTTPStatus from typing import Any, ClassVar, Generic, TypeAlias, TypeVar import numpy as np -import torch from fastapi import Request +from openai.types.responses import ( + ToolChoiceFunction, +) from pydantic import ConfigDict, TypeAdapter from starlette.datastructures import Headers -from typing_extensions import TypeIs +import vllm.envs as envs +from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import ( + ChatCompletionMessageParam, + ChatTemplateContentFormatOption, + ConversationMessage, + apply_hf_chat_template, + apply_mistral_chat_template, + parse_chat_messages_futures, + resolve_chat_template_content_format, +) +from vllm.entrypoints.context import ( + ConversationContext, + HarmonyContext, + ParsableContext, + StreamingHarmonyContext, +) +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import ( + ChatCompletionNamedToolChoiceParam, + ChatCompletionRequest, + ChatCompletionResponse, + CompletionRequest, + CompletionResponse, + DetokenizeRequest, + ErrorInfo, + ErrorResponse, + FunctionCall, + FunctionDefinition, + ResponseInputOutputItem, + ResponsesRequest, + TokenizeChatRequest, + TokenizeCompletionRequest, + TokenizeResponse, + TranscriptionRequest, + TranscriptionResponse, + TranslationRequest, +) +from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.pooling.classify.protocol import ( ClassificationChatRequest, ClassificationCompletionRequest, @@ -39,57 +80,13 @@ from vllm.entrypoints.pooling.score.protocol import ( ScoreRequest, ScoreResponse, ) - -if sys.version_info >= (3, 12): - from typing import TypedDict -else: - from typing_extensions import TypedDict - -from openai.types.responses import ( - ToolChoiceFunction, -) - -import vllm.envs as envs -from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function -from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import ( - ChatCompletionMessageParam, - ChatTemplateContentFormatOption, - ConversationMessage, - apply_hf_chat_template, - apply_mistral_chat_template, - parse_chat_messages_futures, - resolve_chat_template_content_format, -) -from vllm.entrypoints.context import ConversationContext -from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.protocol import ( - ChatCompletionNamedToolChoiceParam, - ChatCompletionRequest, - ChatCompletionResponse, - CompletionRequest, - CompletionResponse, - DetokenizeRequest, - ErrorInfo, - ErrorResponse, - FunctionCall, - FunctionDefinition, - GenerateRequest, - GenerateResponse, - ResponsesRequest, - TokenizeChatRequest, - TokenizeCompletionRequest, - TokenizeResponse, - TranscriptionRequest, - TranscriptionResponse, - TranslationRequest, -) -from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig +from vllm.entrypoints.responses_utils import ( + construct_input_messages, +) +from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse from vllm.entrypoints.utils import _validate_truncation_size -from vllm.inputs.data import PromptType -from vllm.inputs.data import TokensPrompt as EngineTokensPrompt +from vllm.inputs.data import PromptType, TokensPrompt from vllm.inputs.parse import ( PromptComponents, get_prompt_components, @@ -98,15 +95,15 @@ from vllm.inputs.parse import ( from vllm.logger import init_logger from vllm.logprobs import Logprob, PromptLogprobs from vllm.lora.request import LoRARequest -from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin - MultiModalDataDict, - MultiModalUUIDDict, -) +from vllm.multimodal import MultiModalDataDict from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.sampling_params import BeamSearchParams, SamplingParams -from vllm.tokenizers import MistralTokenizer, TokenizerLike +from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer +from vllm.tokenizers.mistral import MistralTokenizer +from vllm.tool_parsers import ToolParser, ToolParserManager from vllm.tracing import ( contains_trace_headers, extract_trace_headers, @@ -122,6 +119,15 @@ from vllm.utils.async_utils import ( from vllm.utils.collection_utils import is_list_of from vllm.v1.engine import EngineCoreRequest + +class GenerationError(Exception): + """raised when finish_reason indicates internal server error (500)""" + + def __init__(self, message: str = "Internal server error"): + super().__init__(message) + self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR + + logger = init_logger(__name__) CompletionLikeRequest: TypeAlias = ( @@ -163,34 +169,6 @@ AnyResponse: TypeAlias = ( ) -class TextTokensPrompt(TypedDict): - prompt: str - prompt_token_ids: list[int] - - -class EmbedsPrompt(TypedDict): - prompt_embeds: torch.Tensor - - -RequestPrompt: TypeAlias = list[int] | str | TextTokensPrompt | EmbedsPrompt - - -def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]: - return ( - isinstance(prompt, dict) - and "prompt_token_ids" in prompt - and "prompt_embeds" not in prompt - ) - - -def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]: - return ( - isinstance(prompt, dict) - and "prompt_token_ids" not in prompt - and "prompt_embeds" in prompt - ) - - RequestT = TypeVar("RequestT", bound=AnyRequest) @@ -201,8 +179,7 @@ class RequestProcessingMixin: handling prompt preparation and engine input. """ - request_prompts: Sequence[RequestPrompt] | None = field(default_factory=list) - engine_prompts: list[EngineTokensPrompt] | None = field(default_factory=list) + engine_prompts: list[TokensPrompt] | None = field(default_factory=list) @dataclass(kw_only=True) @@ -403,7 +380,7 @@ class OpenAIServing: prompts_batch, lora_req_batch = zip( *[ ( - EngineTokensPrompt( + TokensPrompt( prompt_token_ids=beam.tokens, multi_modal_data=beam.multi_modal_data, mm_processor_kwargs=beam.mm_processor_kwargs, @@ -445,6 +422,29 @@ class OpenAIServing: # Iterate through all beam inference results for i, result in enumerate(output): current_beam = all_beams[i] + + # check for error finish reason and abort beam search + if result.outputs[0].finish_reason == "error": + # yield error output and terminate beam search + yield RequestOutput( + request_id=request_id, + prompt=prompt_text, + outputs=[ + CompletionOutput( + index=0, + text="", + token_ids=[], + cumulative_logprob=None, + logprobs=None, + finish_reason="error", + ) + ], + finished=True, + prompt_token_ids=prompt_token_ids, + prompt_logprobs=None, + ) + return + if result.outputs[0].logprobs is not None: logprobs = result.outputs[0].logprobs[0] all_beams_token_id.extend(list(logprobs.keys())) @@ -769,6 +769,35 @@ class OpenAIServing: ) return json_str + def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None: + """Raise GenerationError if finish_reason indicates an error.""" + if finish_reason == "error": + logger.error( + "Request %s failed with an internal error during generation", + request_id, + ) + raise GenerationError("Internal server error") + + def _convert_generation_error_to_response( + self, e: GenerationError + ) -> ErrorResponse: + """Convert GenerationError to ErrorResponse.""" + return self.create_error_response( + str(e), + err_type="InternalServerError", + status_code=e.status_code, + ) + + def _convert_generation_error_to_streaming_response( + self, e: GenerationError + ) -> str: + """Convert GenerationError to streaming error response.""" + return self.create_streaming_error_response( + str(e), + err_type="InternalServerError", + status_code=e.status_code, + ) + async def _check_model( self, request: AnyRequest, @@ -873,7 +902,7 @@ class OpenAIServing: prompt: str, tokenizer: TokenizerLike, add_special_tokens: bool, - ) -> TextTokensPrompt: + ) -> TokensPrompt: async_tokenizer = self._get_async_tokenizer(tokenizer) if ( @@ -914,7 +943,7 @@ class OpenAIServing: request: AnyRequest, prompt_ids: list[int], tokenizer: TokenizerLike | None, - ) -> TextTokensPrompt: + ) -> TokensPrompt: truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None) if truncate_prompt_tokens is None: @@ -937,7 +966,7 @@ class OpenAIServing: request: AnyRequest, input_ids: list[int], input_text: str, - ) -> TextTokensPrompt: + ) -> TokensPrompt: token_num = len(input_ids) # Note: EmbeddingRequest, ClassificationRequest, @@ -968,7 +997,7 @@ class OpenAIServing: f"{token_num} tokens in the input for {operation}. " f"Please reduce the length of the input." ) - return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) + return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids) # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens # and does not require model context length validation @@ -976,7 +1005,7 @@ class OpenAIServing: request, (TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest), ): - return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) + return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids) # chat completion endpoint supports max_completion_tokens if isinstance(request, ChatCompletionRequest): @@ -1004,7 +1033,7 @@ class OpenAIServing: f" - {token_num})." ) - return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) + return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids) async def _tokenize_prompt_input_async( self, @@ -1012,7 +1041,7 @@ class OpenAIServing: tokenizer: TokenizerLike, prompt_input: str | list[int], add_special_tokens: bool = True, - ) -> TextTokensPrompt: + ) -> TokensPrompt: """ A simpler implementation that tokenizes a single prompt input. """ @@ -1031,7 +1060,7 @@ class OpenAIServing: tokenizer: TokenizerLike, prompt_inputs: Iterable[str | list[int]], add_special_tokens: bool = True, - ) -> AsyncGenerator[TextTokensPrompt, None]: + ) -> AsyncGenerator[TokensPrompt, None]: """ A simpler implementation that tokenizes multiple prompt inputs. """ @@ -1084,16 +1113,7 @@ class OpenAIServing: chat_template_kwargs: dict[str, Any] | None = None, tool_parser: Callable[[TokenizerLike], ToolParser] | None = None, add_special_tokens: bool = False, - ) -> tuple[ - list[ConversationMessage], - Sequence[RequestPrompt], - list[EngineTokensPrompt], - ]: - if tokenizer is None: - raise ValueError( - "Unable to get tokenizer because `skip_tokenizer_init=True`" - ) - + ) -> tuple[list[ConversationMessage], list[TokensPrompt]]: model_config = self.model_config resolved_content_format = resolve_chat_template_content_format( @@ -1106,7 +1126,6 @@ class OpenAIServing: conversation, mm_data_future, mm_uuids = parse_chat_messages_futures( messages, model_config, - tokenizer, content_format=resolved_content_format, ) @@ -1129,6 +1148,13 @@ class OpenAIServing: messages=messages, **_chat_template_kwargs, ) + elif isinstance(tokenizer, DeepseekV32Tokenizer): + request_prompt = tokenizer.apply_chat_template( + conversation=conversation, + messages=messages, + model_config=model_config, + **_chat_template_kwargs, + ) else: request_prompt = apply_hf_chat_template( tokenizer=tokenizer, @@ -1160,9 +1186,7 @@ class OpenAIServing: "Prompt has to be a string", "when the tokenizer is not initialised", ) - prompt_inputs = TextTokensPrompt( - prompt=request_prompt, prompt_token_ids=[1] - ) + prompt_inputs = TokensPrompt(prompt=request_prompt, prompt_token_ids=[1]) elif isinstance(request_prompt, str): prompt_inputs = await self._tokenize_prompt_input_async( request, @@ -1175,14 +1199,15 @@ class OpenAIServing: assert is_list_of(request_prompt, int), ( "Prompt has to be either a string or a list of token ids" ) - prompt_inputs = TextTokensPrompt( + prompt_inputs = TokensPrompt( prompt=tokenizer.decode(request_prompt), prompt_token_ids=request_prompt, ) - engine_prompt = EngineTokensPrompt( - prompt_token_ids=prompt_inputs["prompt_token_ids"] - ) + engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["prompt_token_ids"]) + if "prompt" in prompt_inputs: + engine_prompt["prompt"] = prompt_inputs["prompt"] + if mm_data is not None: engine_prompt["multi_modal_data"] = mm_data @@ -1195,7 +1220,7 @@ class OpenAIServing: if hasattr(request, "cache_salt") and request.cache_salt is not None: engine_prompt["cache_salt"] = request.cache_salt - return conversation, [request_prompt], [engine_prompt] + return conversation, [engine_prompt] async def _process_inputs( self, @@ -1224,18 +1249,43 @@ class OpenAIServing: ) return engine_request, tokenization_kwargs + async def _render_next_turn( + self, + request: ResponsesRequest, + tokenizer: TokenizerLike | None, + messages: list[ResponseInputOutputItem], + tool_dicts: list[dict[str, Any]] | None, + tool_parser, + chat_template: str | None, + chat_template_content_format: ChatTemplateContentFormatOption, + ): + new_messages = construct_input_messages( + request_input=messages, + ) + + _, engine_prompts = await self._preprocess_chat( + request, + tokenizer, + new_messages, + tool_dicts=tool_dicts, + tool_parser=tool_parser, + chat_template=chat_template, + chat_template_content_format=chat_template_content_format, + ) + return engine_prompts + async def _generate_with_builtin_tools( self, request_id: str, - request_prompt: RequestPrompt, - engine_prompt: EngineTokensPrompt, + engine_prompt: TokensPrompt, sampling_params: SamplingParams, context: ConversationContext, lora_request: LoRARequest | None = None, priority: int = 0, **kwargs, ): - prompt_text, _, _ = self._get_prompt_components(request_prompt) + prompt_text, _, _ = self._get_prompt_components(engine_prompt) + orig_priority = priority sub_request = 0 while True: @@ -1243,7 +1293,7 @@ class OpenAIServing: sub_request_id = f"{request_id}_{sub_request}" self._log_inputs( sub_request_id, - request_prompt, + engine_prompt, params=sampling_params, lora_request=lora_request, ) @@ -1286,28 +1336,37 @@ class OpenAIServing: # Create inputs for the next turn. # Render the next prompt token ids. - prompt_token_ids = context.render_for_completion() - engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids) - request_prompt = prompt_token_ids + if isinstance(context, (HarmonyContext, StreamingHarmonyContext)): + prompt_token_ids = context.render_for_completion() + engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) + elif isinstance(context, ParsableContext): + engine_prompts = await self._render_next_turn( + context.request, + context.tokenizer, + context.parser.response_messages, + context.tool_dicts, + context.tool_parser_cls, + context.chat_template, + context.chat_template_content_format, + ) + engine_prompt = engine_prompts[0] + prompt_text, _, _ = self._get_prompt_components(engine_prompt) + # Update the sampling params. - sampling_params.max_tokens = self.max_model_len - len(prompt_token_ids) + sampling_params.max_tokens = self.max_model_len - len( + engine_prompt["prompt_token_ids"] + ) # OPTIMIZATION priority = orig_priority - 1 sub_request += 1 - def _get_prompt_components( - self, - prompt: RequestPrompt | PromptType, - ) -> PromptComponents: - if isinstance(prompt, list): - return PromptComponents(token_ids=prompt) - - return get_prompt_components(prompt) # type: ignore[arg-type] + def _get_prompt_components(self, prompt: PromptType) -> PromptComponents: + return get_prompt_components(prompt) def _log_inputs( self, request_id: str, - inputs: RequestPrompt | PromptType, + inputs: PromptType, params: SamplingParams | PoolingParams | BeamSearchParams | None, lora_request: LoRARequest | None, ) -> None: @@ -1369,7 +1428,7 @@ class OpenAIServing: @staticmethod def _parse_tool_calls_from_content( request: ResponsesRequest | ChatCompletionRequest, - tokenizer: TokenizerLike, + tokenizer: TokenizerLike | None, enable_auto_tools: bool, tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None, content: str | None = None, @@ -1409,6 +1468,11 @@ class OpenAIServing: and enable_auto_tools and (request.tool_choice == "auto" or request.tool_choice is None) ): + if tokenizer is None: + raise ValueError( + "Tokenizer not available when `skip_tokenizer_init=True`" + ) + # Automatic Tool Call Parsing try: tool_parser = tool_parser_cls(tokenizer) diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index 81495a0777546..1f9b5704624ab 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -50,6 +50,7 @@ from openai.types.responses.response_reasoning_item import ( ) from openai.types.responses.tool import Mcp, Tool from openai_harmony import Message as OpenAIHarmonyMessage +from pydantic import TypeAdapter from vllm import envs from vllm.engine.protocol import EngineClient @@ -60,10 +61,12 @@ from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.context import ( ConversationContext, HarmonyContext, + ParsableContext, SimpleContext, StreamingHarmonyContext, ) -from vllm.entrypoints.harmony_utils import ( +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.parser.harmony_utils import ( construct_harmony_previous_input_messages, get_developer_message, get_stop_tokens_for_assistant_actions, @@ -75,7 +78,6 @@ from vllm.entrypoints.harmony_utils import ( parse_response_input, render_for_completion, ) -from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( DeltaMessage, ErrorResponse, @@ -85,6 +87,7 @@ from vllm.entrypoints.openai.protocol import ( ResponseCompletedEvent, ResponseCreatedEvent, ResponseInProgressEvent, + ResponseInputOutputMessage, ResponseReasoningPartAddedEvent, ResponseReasoningPartDoneEvent, ResponsesRequest, @@ -92,15 +95,18 @@ from vllm.entrypoints.openai.protocol import ( ResponseUsage, StreamingResponsesResponse, ) -from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_engine import ( + GenerationError, + OpenAIServing, +) from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.responses_utils import ( construct_input_messages, - convert_tool_responses_to_completions_format, + construct_tool_dicts, extract_tool_types, ) from vllm.entrypoints.tool_server import ToolServer -from vllm.inputs.data import TokensPrompt as EngineTokensPrompt +from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger from vllm.logprobs import Logprob as SampleLogprob from vllm.logprobs import SampleLogprobs @@ -228,7 +234,6 @@ class OpenAIServingResponses(OpenAIServing): self.tool_parser = self._get_tool_parser( tool_parser_name=tool_parser, enable_auto_tools=enable_auto_tools ) - self.exclude_tools_when_tool_choice_none = False # HACK(woosuk): This is a hack. We should use a better store. # FIXME: If enable_store=True, this may cause a memory leak since we # never remove responses from the store. @@ -252,7 +257,7 @@ class OpenAIServingResponses(OpenAIServing): self.tool_server = tool_server def _validate_generator_input( - self, engine_prompt: EngineTokensPrompt + self, engine_prompt: TokensPrompt ) -> ErrorResponse | None: """Add validations to the input to the generator here.""" if self.max_model_len <= len(engine_prompt["prompt_token_ids"]): @@ -347,11 +352,11 @@ class OpenAIServingResponses(OpenAIServing): tokenizer = await self.engine_client.get_tokenizer() if self.use_harmony: - messages, request_prompts, engine_prompts = ( - self._make_request_with_harmony(request, prev_response) + messages, engine_prompts = self._make_request_with_harmony( + request, prev_response ) else: - messages, request_prompts, engine_prompts = await self._make_request( + messages, engine_prompts = await self._make_request( request, prev_response, tokenizer ) @@ -373,7 +378,7 @@ class OpenAIServingResponses(OpenAIServing): generators: list[AsyncGenerator[ConversationContext, None]] = [] builtin_tool_list: list[str] = [] - if self.use_harmony and self.tool_server is not None: + if self.tool_server is not None: if self.tool_server.has_tool("browser"): builtin_tool_list.append("browser") if self.tool_server.has_tool("python"): @@ -387,7 +392,7 @@ class OpenAIServingResponses(OpenAIServing): assert len(builtin_tool_list) == 0 available_tools = [] try: - for i, engine_prompt in enumerate(engine_prompts): + for engine_prompt in engine_prompts: maybe_error = self._validate_generator_input(engine_prompt) if maybe_error is not None: return maybe_error @@ -413,7 +418,21 @@ class OpenAIServingResponses(OpenAIServing): else: context = HarmonyContext(messages, available_tools) else: - context = SimpleContext() + if envs.VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT: + # This is a feature in development for parsing + # tokens during generation instead of at the end + context = ParsableContext( + response_messages=messages, + tokenizer=tokenizer, + reasoning_parser_cls=self.reasoning_parser, + request=request, + tool_parser_cls=self.tool_parser, + available_tools=available_tools, + chat_template=self.chat_template, + chat_template_content_format=self.chat_template_content_format, + ) + else: + context = SimpleContext() if self.reasoning_parser is not None: reasoning_parser = self.reasoning_parser(tokenizer) @@ -429,7 +448,6 @@ class OpenAIServingResponses(OpenAIServing): ) generator = self._generate_with_builtin_tools( request_id=request.request_id, - request_prompt=request_prompts[i], engine_prompt=engine_prompt, sampling_params=sampling_params, context=context, @@ -525,6 +543,8 @@ class OpenAIServingResponses(OpenAIServing): tokenizer, request_metadata, ) + except GenerationError as e: + return self._convert_generation_error_to_response(e) except Exception as e: return self.create_error_response(str(e)) @@ -534,15 +554,7 @@ class OpenAIServingResponses(OpenAIServing): prev_response: ResponsesResponse | None, tokenizer: TokenizerLike, ): - if request.tools is None or ( - request.tool_choice == "none" and self.exclude_tools_when_tool_choice_none - ): - tool_dicts = None - else: - tool_dicts = [ - convert_tool_responses_to_completions_format(tool.model_dump()) - for tool in request.tools - ] + tool_dicts = construct_tool_dicts(request.tools, request.tool_choice) # Construct the input messages. messages = construct_input_messages( request_instructions=request.instructions, @@ -550,7 +562,7 @@ class OpenAIServingResponses(OpenAIServing): prev_msg=self.msg_store.get(prev_response.id) if prev_response else None, prev_response_output=prev_response.output if prev_response else None, ) - _, request_prompts, engine_prompts = await self._preprocess_chat( + _, engine_prompts = await self._preprocess_chat( request, tokenizer, messages, @@ -559,7 +571,7 @@ class OpenAIServingResponses(OpenAIServing): chat_template=self.chat_template, chat_template_content_format=self.chat_template_content_format, ) - return messages, request_prompts, engine_prompts + return messages, engine_prompts def _make_request_with_harmony( self, @@ -572,13 +584,13 @@ class OpenAIServingResponses(OpenAIServing): ) messages = self._construct_input_messages_with_harmony(request, prev_response) prompt_token_ids = render_for_completion(messages) - engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids) + engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) # Add cache_salt if provided in the request if request.cache_salt is not None: engine_prompt["cache_salt"] = request.cache_salt - return messages, [prompt_token_ids], [engine_prompt] + return messages, [engine_prompt] async def _initialize_tool_sessions( self, @@ -626,8 +638,8 @@ class OpenAIServingResponses(OpenAIServing): # "completed" is implemented as the "catch-all" for now. status: ResponseStatus = "completed" - input_messages = None - output_messages = None + input_messages: ResponseInputOutputMessage | None = None + output_messages: ResponseInputOutputMessage | None = None if self.use_harmony: assert isinstance(context, HarmonyContext) output = self._make_response_output_items_with_harmony(context) @@ -640,28 +652,42 @@ class OpenAIServingResponses(OpenAIServing): status = "incomplete" elif context.finish_reason == "abort": status = "cancelled" + else: + self._raise_if_error(context.finish_reason, request.request_id) else: status = "incomplete" + elif isinstance(context, ParsableContext): + output = context.parser.make_response_output_items_from_parsable_context() + + if request.enable_response_messages: + input_messages = context.input_messages + output_messages = context.output_messages + + # TODO: Calculate usage. + # assert final_res.prompt_token_ids is not None + num_tool_output_tokens = 0 else: assert isinstance(context, SimpleContext) - final_res = context.last_output + # Use final_output which has accumulated text/token_ids/logprobs + final_res = context.final_output assert final_res is not None assert len(final_res.outputs) == 1 final_output = final_res.outputs[0] + # finish_reason='error' indicates retryable internal error + self._raise_if_error(final_output.finish_reason, request.request_id) + output = self._make_response_output_items(request, final_output, tokenizer) - # TODO: context for non-gptoss models doesn't use messages - # so we can't get them out yet if request.enable_response_messages: - raise NotImplementedError( - "enable_response_messages is currently only supported for gpt-oss" - ) + input_messages = context.input_messages + output_messages = context.output_messages + # Calculate usage. assert final_res.prompt_token_ids is not None num_tool_output_tokens = 0 - assert isinstance(context, (SimpleContext, HarmonyContext)) + assert isinstance(context, (SimpleContext, HarmonyContext, ParsableContext)) num_prompt_tokens = context.num_prompt_tokens num_generated_tokens = context.num_output_tokens num_cached_tokens = context.num_cached_tokens @@ -1044,6 +1070,8 @@ class OpenAIServingResponses(OpenAIServing): async for event in generator: event_deque.append(event) new_event_signal.set() # Signal new event available + except GenerationError as e: + response = self._convert_generation_error_to_response(e) except Exception as e: logger.exception("Background request failed for %s", request.request_id) response = self.create_error_response(str(e)) @@ -1067,6 +1095,8 @@ class OpenAIServingResponses(OpenAIServing): ): try: response = await self.responses_full_generator(request, *args, **kwargs) + except GenerationError as e: + response = self._convert_generation_error_to_response(e) except Exception as e: logger.exception("Background request failed for %s", request.request_id) response = self.create_error_response(str(e)) @@ -1205,6 +1235,8 @@ class OpenAIServingResponses(OpenAIServing): continue if ctx.last_output.outputs: output = ctx.last_output.outputs[0] + # finish_reason='error' indicates a retryable error + self._raise_if_error(output.finish_reason, request.request_id) if reasoning_parser: delta_message = reasoning_parser.extract_reasoning_streaming( previous_text=previous_text, @@ -1500,6 +1532,9 @@ class OpenAIServingResponses(OpenAIServing): async for ctx in result_generator: assert isinstance(ctx, StreamingHarmonyContext) + # finish_reason='error' indicates a retryable error + self._raise_if_error(ctx.finish_reason, request.request_id) + if ctx.is_expecting_start(): current_output_index += 1 sent_output_item_added = False @@ -1994,18 +2029,25 @@ class OpenAIServingResponses(OpenAIServing): ) ) - async for event_data in processer( - request, - sampling_params, - result_generator, - context, - model_name, - tokenizer, - request_metadata, - created_time, - _increment_sequence_number_and_return, - ): - yield event_data + try: + async for event_data in processer( + request, + sampling_params, + result_generator, + context, + model_name, + tokenizer, + request_metadata, + created_time, + _increment_sequence_number_and_return, + ): + yield event_data + except GenerationError as e: + error_json = self._convert_generation_error_to_streaming_response(e) + yield _increment_sequence_number_and_return( + TypeAdapter(StreamingResponsesResponse).validate_json(error_json) + ) + return async def empty_async_generator(): # A hack to trick Python to think this is a generator but diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py index 33da7034afabc..189b532810b43 100644 --- a/vllm/entrypoints/openai/serving_transcription.py +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -12,10 +12,12 @@ from vllm.entrypoints.openai.protocol import ( TranscriptionRequest, TranscriptionResponse, TranscriptionResponseStreamChoice, + TranscriptionResponseVerbose, TranscriptionStreamResponse, TranslationRequest, TranslationResponse, TranslationResponseStreamChoice, + TranslationResponseVerbose, TranslationStreamResponse, ) from vllm.entrypoints.openai.serving_models import OpenAIServingModels @@ -51,7 +53,12 @@ class OpenAIServingTranscription(OpenAISpeechToText): async def create_transcription( self, audio_data: bytes, request: TranscriptionRequest, raw_request: Request - ) -> TranscriptionResponse | AsyncGenerator[str, None] | ErrorResponse: + ) -> ( + TranscriptionResponse + | TranscriptionResponseVerbose + | AsyncGenerator[str, None] + | ErrorResponse + ): """Transcription API similar to OpenAI's API. See https://platform.openai.com/docs/api-reference/audio/createTranscription @@ -61,7 +68,11 @@ class OpenAIServingTranscription(OpenAISpeechToText): audio_data=audio_data, request=request, raw_request=raw_request, - response_class=TranscriptionResponse, + response_class=( + TranscriptionResponseVerbose + if request.response_format == "verbose_json" + else TranscriptionResponse + ), stream_generator_method=self.transcription_stream_generator, ) @@ -112,7 +123,12 @@ class OpenAIServingTranslation(OpenAISpeechToText): async def create_translation( self, audio_data: bytes, request: TranslationRequest, raw_request: Request - ) -> TranslationResponse | AsyncGenerator[str, None] | ErrorResponse: + ) -> ( + TranslationResponse + | TranslationResponseVerbose + | AsyncGenerator[str, None] + | ErrorResponse + ): """Translation API similar to OpenAI's API. See https://platform.openai.com/docs/api-reference/audio/createTranslation @@ -122,7 +138,11 @@ class OpenAIServingTranslation(OpenAISpeechToText): audio_data=audio_data, request=request, raw_request=raw_request, - response_class=TranslationResponse, + response_class=( + TranslationResponseVerbose + if request.response_format == "verbose_json" + else TranslationResponse + ), stream_generator_method=self.translation_stream_generator, ) diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py index 3dece07748cc4..cea9924ebbaca 100644 --- a/vllm/entrypoints/openai/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text.py @@ -10,6 +10,7 @@ from typing import Literal, TypeAlias, TypeVar, cast import numpy as np from fastapi import Request +from transformers import PreTrainedTokenizerBase import vllm.envs as envs from vllm.engine.protocol import EngineClient @@ -20,9 +21,13 @@ from vllm.entrypoints.openai.protocol import ( RequestResponseMetadata, TranscriptionResponse, TranscriptionResponseStreamChoice, + TranscriptionResponseVerbose, + TranscriptionSegment, TranscriptionStreamResponse, TranslationResponse, TranslationResponseStreamChoice, + TranslationResponseVerbose, + TranslationSegment, TranslationStreamResponse, UsageInfo, ) @@ -32,6 +37,7 @@ from vllm.inputs.data import PromptType from vllm.logger import init_logger from vllm.model_executor.models import SupportsTranscription from vllm.outputs import RequestOutput +from vllm.tokenizers import get_tokenizer from vllm.utils.import_utils import PlaceholderModule try: @@ -40,7 +46,20 @@ except ImportError: librosa = PlaceholderModule("librosa") # type: ignore[assignment] SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse +SpeechToTextResponseVerbose: TypeAlias = ( + TranscriptionResponseVerbose | TranslationResponseVerbose +) +SpeechToTextSegment: TypeAlias = TranscriptionSegment | TranslationSegment T = TypeVar("T", bound=SpeechToTextResponse) +V = TypeVar("V", bound=SpeechToTextResponseVerbose) +S = TypeVar("S", bound=SpeechToTextSegment) + +ResponseType: TypeAlias = ( + TranscriptionResponse + | TranslationResponse + | TranscriptionResponseVerbose + | TranslationResponseVerbose +) logger = init_logger(__name__) @@ -78,6 +97,14 @@ class OpenAISpeechToText(OpenAIServing): self.enable_force_include_usage = enable_force_include_usage self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB + if self.model_cls.supports_segment_timestamp: + self.tokenizer = cast( + PreTrainedTokenizerBase, + get_tokenizer( + tokenizer_name=self.model_config.tokenizer, + tokenizer_mode=self.model_config.tokenizer_mode, + ), + ) if self.default_sampling_params: logger.info( @@ -133,17 +160,87 @@ class OpenAISpeechToText(OpenAIServing): request_prompt=request.prompt, to_language=to_language, ) + if request.response_format == "verbose_json": + if not isinstance(prompt, dict): + raise ValueError(f"Expected prompt to be a dict,got {type(prompt)}") + prompt_dict = cast(dict, prompt) + decoder_prompt = prompt.get("decoder_prompt") + if not isinstance(decoder_prompt, str): + raise ValueError( + f"Expected decoder_prompt to bestr, got {type(decoder_prompt)}" + ) + prompt_dict["decoder_prompt"] = decoder_prompt.replace( + "<|notimestamps|>", "<|0.00|>" + ) prompts.append(prompt) return prompts, duration + def _get_verbose_segments( + self, + tokens: tuple, + request: SpeechToTextRequest, + segment_class: type[SpeechToTextSegment], + start_time: float = 0, + ) -> list[SpeechToTextSegment]: + """ + Convert tokens to verbose segments. + + This method expects the model to produce + timestamps as tokens (similar to Whisper). + If the tokens do not include timestamp information, + the segments may not be generated correctly. + + Note: Fields like avg_logprob, compression_ratio, + and no_speech_prob are not supported + in this implementation and will be None. See docs for details. + """ + BASE_OFFSET = 0.02 + init_token = self.tokenizer.encode("<|0.00|>", add_special_tokens=False)[0] + if tokens[-1] == self.tokenizer.eos_token_id: + tokens = tokens[:-1] + + tokens_with_start = (init_token,) + tokens + segments: list[SpeechToTextSegment] = [] + last_timestamp_start = 0 + + if tokens_with_start[-2] < init_token and tokens_with_start[-1] >= init_token: + tokens_with_start = tokens_with_start + (tokens_with_start[-1],) + for idx, token in enumerate(tokens_with_start): + # Timestamp tokens (e.g., <|0.00|>) are assumed to be sorted. + # If the ordering is violated, this slicing may produce incorrect results. + if ( + token >= init_token + and idx != 0 + and tokens_with_start[idx - 1] >= init_token + ): + sliced_timestamp_tokens = tokens_with_start[last_timestamp_start:idx] + start_timestamp = sliced_timestamp_tokens[0] - init_token + end_timestamp = sliced_timestamp_tokens[-1] - init_token + + casting_segment = cast( + SpeechToTextSegment, + segment_class( + id=len(segments), + seek=start_time, + start=start_time + BASE_OFFSET * start_timestamp, + end=start_time + BASE_OFFSET * end_timestamp, + temperature=request.temperature, + text=self.tokenizer.decode(sliced_timestamp_tokens[1:-1]), + tokens=sliced_timestamp_tokens[1:-1], + ), + ) + segments.append(casting_segment) + last_timestamp_start = idx + return segments + async def _create_speech_to_text( self, audio_data: bytes, request: SpeechToTextRequest, raw_request: Request, - response_class: type[T], + response_class: type[T | V], stream_generator_method: Callable[..., AsyncGenerator[str, None]], - ) -> T | AsyncGenerator[str, None] | ErrorResponse: + ) -> T | V | AsyncGenerator[str, None] | ErrorResponse: """Base method for speech-to-text operations like transcription and translation.""" error_check_ret = await self._check_model(request) @@ -156,11 +253,24 @@ class OpenAISpeechToText(OpenAIServing): if self.engine_client.errored: raise self.engine_client.dead_error - if request.response_format not in ["text", "json"]: + if request.response_format not in ["text", "json", "verbose_json"]: return self.create_error_response( - "Currently only support response_format `text` or `json`" + ("Currently only support response_format") + + ("`text`, `json` or `verbose_json`") ) + if ( + request.response_format == "verbose_json" + and not self.model_cls.supports_segment_timestamp + ): + return self.create_error_response( + f"Currently do not support verbose_json for {request.model}" + ) + + if request.response_format == "verbose_json" and request.stream: + return self.create_error_response( + "verbose_json format doesn't support streaming case" + ) request_id = f"{self.task_type}-{self._base_request_id(raw_request)}" request_metadata = RequestResponseMetadata(request_id=request_id) @@ -215,25 +325,69 @@ class OpenAISpeechToText(OpenAIServing): request, list_result_generator, request_id, request_metadata, duration_s ) # Non-streaming response. + total_segments = [] + text_parts = [] try: assert list_result_generator is not None + segments_types: dict[str, type[SpeechToTextSegment]] = { + "transcribe": TranscriptionSegment, + "translate": TranslationSegment, + } + segment_class: type[SpeechToTextSegment] = segments_types[self.task_type] text = "" - for result_generator in list_result_generator: + for idx, result_generator in enumerate(list_result_generator): async for op in result_generator: - text += op.outputs[0].text + if request.response_format == "verbose_json": + segments: list[SpeechToTextSegment] = ( + self._get_verbose_segments( + tokens=tuple(op.outputs[0].token_ids), + segment_class=segment_class, + request=request, + start_time=idx * self.asr_config.max_audio_clip_s, + ) + ) + total_segments.extend(segments) + text_parts.extend([seg.text for seg in segments]) + else: + text_parts.append(op.outputs[0].text) + text = "".join(text_parts) if self.task_type == "transcribe": + final_response: ResponseType # add usage in TranscriptionResponse. usage = { "type": "duration", # rounded up as per openAI specs "seconds": int(math.ceil(duration_s)), } - final_response = cast(T, response_class(text=text, usage=usage)) + if request.response_format != "verbose_json": + final_response = cast( + T, TranscriptionResponse(text=text, usage=usage) + ) + else: + final_response = cast( + V, + TranscriptionResponseVerbose( + text=text, + language=request.language, + duration=str(duration_s), + segments=total_segments, + ), + ) else: # no usage in response for translation task - final_response = cast(T, response_class(text=text)) # type: ignore[call-arg] - + if request.response_format != "verbose_json": + final_response = cast(T, TranslationResponse(text=text)) + else: + final_response = cast( + V, + TranslationResponseVerbose( + text=text, + language=request.language, + duration=str(duration_s), + segments=total_segments, + ), + ) return final_response except asyncio.CancelledError: return self.create_error_response("Client disconnected") diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 89e439dd53f5f..ad1b682a9ef65 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -1,142 +1,33 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, - ToolParserManager, -) - -__all__ = ["ToolParser", "ToolParserManager"] +import warnings -""" -Register a lazy module mapping. +def __getattr__(name: str): + if name == "ToolParser": + from vllm.tool_parsers import ToolParser -Example: - ToolParserManager.register_lazy_module( - name="kimi_k2", - module_path="vllm.entrypoints.openai.tool_parsers.kimi_k2_parser", - class_name="KimiK2ToolParser", - ) -""" + warnings.warn( + "`vllm.entrypoints.openai.tool_parsers.ToolParser` has been moved to " + "`vllm.tool_parsers.ToolParser`. " + "The old name will be removed in v0.14.", + DeprecationWarning, + stacklevel=2, + ) + return ToolParser + if name == "ToolParserManager": + from vllm.tool_parsers import ToolParserManager -_TOOL_PARSERS_TO_REGISTER = { - "deepseek_v3": ( # name - "deepseekv3_tool_parser", # filename - "DeepSeekV3ToolParser", # class_name - ), - "deepseek_v31": ( - "deepseekv31_tool_parser", - "DeepSeekV31ToolParser", - ), - "ernie45": ( - "ernie45_tool_parser", - "Ernie45ToolParser", - ), - "glm45": ( - "glm4_moe_tool_parser", - "Glm4MoeModelToolParser", - ), - "granite-20b-fc": ( - "granite_20b_fc_tool_parser", - "Granite20bFCToolParser", - ), - "granite": ( - "granite_tool_parser", - "GraniteToolParser", - ), - "hermes": ( - "hermes_tool_parser", - "Hermes2ProToolParser", - ), - "hunyuan_a13b": ( - "hunyuan_a13b_tool_parser", - "HunyuanA13BToolParser", - ), - "internlm": ( - "internlm2_tool_parser", - "Internlm2ToolParser", - ), - "jamba": ( - "jamba_tool_parser", - "JambaToolParser", - ), - "kimi_k2": ( - "kimi_k2_tool_parser", - "KimiK2ToolParser", - ), - "llama3_json": ( - "llama_tool_parser", - "Llama3JsonToolParser", - ), - "llama4_json": ( - "llama_tool_parser", - "Llama3JsonToolParser", - ), - "llama4_pythonic": ( - "llama4_pythonic_tool_parser", - "Llama4PythonicToolParser", - ), - "longcat": ( - "longcat_tool_parser", - "LongcatFlashToolParser", - ), - "minimax_m2": ( - "minimax_m2_tool_parser", - "MinimaxM2ToolParser", - ), - "minimax": ( - "minimax_tool_parser", - "MinimaxToolParser", - ), - "mistral": ( - "mistral_tool_parser", - "MistralToolParser", - ), - "olmo3": ( - "olmo3_tool_parser", - "Olmo3PythonicToolParser", - ), - "openai": ( - "openai_tool_parser", - "OpenAIToolParser", - ), - "phi4_mini_json": ( - "phi4mini_tool_parser", - "Phi4MiniJsonToolParser", - ), - "pythonic": ( - "pythonic_tool_parser", - "PythonicToolParser", - ), - "qwen3_coder": ( - "qwen3coder_tool_parser", - "Qwen3CoderToolParser", - ), - "qwen3_xml": ( - "qwen3xml_tool_parser", - "Qwen3XMLToolParser", - ), - "seed_oss": ( - "seed_oss_tool_parser", - "SeedOssToolParser", - ), - "step3": ( - "step3_tool_parser", - "Step3ToolParser", - ), - "xlam": ( - "xlam_tool_parser", - "xLAMToolParser", - ), -} + warnings.warn( + "`vllm.entrypoints.openai.tool_parsers.ToolParserManager` " + "has been moved to `vllm.tool_parsers.ToolParserManager`. " + "The old name will be removed in v0.14.", + DeprecationWarning, + stacklevel=2, + ) + return ToolParserManager -def register_lazy_tool_parsers(): - for name, (file_name, class_name) in _TOOL_PARSERS_TO_REGISTER.items(): - module_path = f"vllm.entrypoints.openai.tool_parsers.{file_name}" - ToolParserManager.register_lazy_module(name, module_path, class_name) - - -register_lazy_tool_parsers() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py deleted file mode 100644 index 7e2d67a1fb659..0000000000000 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ /dev/null @@ -1,390 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import json -from collections.abc import Sequence -from random import choices -from string import ascii_letters, digits - -import partial_json_parser -import regex as re -from partial_json_parser.core.options import Allow -from pydantic import Field - -from vllm.entrypoints.openai.protocol import ( - ChatCompletionRequest, - DeltaFunctionCall, - DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, - ToolCall, -) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) -from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff -from vllm.logger import init_logger -from vllm.tokenizers import MistralTokenizer, TokenizerLike - -logger = init_logger(__name__) - -ALPHANUMERIC = ascii_letters + digits - - -class MistralToolCall(ToolCall): - id: str = Field(default_factory=lambda: MistralToolCall.generate_random_id()) - - @staticmethod - def generate_random_id(): - # Mistral Tool Call Ids must be alphanumeric with a length of 9. - # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299 - return "".join(choices(ALPHANUMERIC, k=9)) - - @staticmethod - def is_valid_id(id: str) -> bool: - return id.isalnum() and len(id) == 9 - - -def _is_fn_name_regex_support(model_tokenizer: TokenizerLike) -> bool: - return ( - isinstance(model_tokenizer, MistralTokenizer) and model_tokenizer.version >= 11 - ) - - -class MistralToolParser(ToolParser): - """ - Tool call parser for Mistral 7B Instruct v0.3, intended for use with - - [`mistral_common`](https://github.com/mistralai/mistral-common/) - - the examples/tool_chat_template_mistral.jinja template. - - Used when --enable-auto-tool-choice --tool-call-parser mistral are all set - """ - - def __init__(self, tokenizer: TokenizerLike): - super().__init__(tokenizer) - - if not isinstance(self.model_tokenizer, MistralTokenizer): - logger.info("Non-Mistral tokenizer detected when using a Mistral model...") - - # initialize properties used for state when parsing tool calls in - # streaming mode - self.prev_tool_call_arr: list[dict] = [] - self.current_tool_id: int = -1 - self.current_tool_name_sent: bool = False - self.streamed_args_for_tool: list[ - str - ] = [] # map what has been streamed for each tool so far to a list - self.bot_token = "[TOOL_CALLS]" - self.bot_token_id = self.vocab.get(self.bot_token) - self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL) - if _is_fn_name_regex_support(self.model_tokenizer): - self.fn_name_regex = re.compile( - r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)", re.DOTALL - ) - else: - self.fn_name_regex = None - - if self.bot_token_id is None: - raise RuntimeError( - "Mistral Tool Parser could not locate the tool call token in " - "the tokenizer!" - ) - - def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: - request = super().adjust_request(request) - if ( - not isinstance(self.model_tokenizer, MistralTokenizer) - and request.tools - and request.tool_choice != "none" - ): - # Do not skip special tokens when using chat template - # with Mistral parser as TOOL_CALL token is needed - # for tool detection. - # Note: we don't want skip_special_tokens=False - # with MistralTokenizer as it is incompatible - request.skip_special_tokens = False - return request - - def extract_tool_calls( - self, - model_output: str, - request: ChatCompletionRequest, - ) -> ExtractedToolCallInformation: - """ - Extract the tool calls from a complete model response. Requires - find-and-replacing single quotes with double quotes for JSON parsing, - make sure your tool call arguments don't ever include quotes! - """ - - # case -- if a tool call token is not present, return a text response - if self.bot_token not in model_output: - return ExtractedToolCallInformation( - tools_called=False, tool_calls=[], content=model_output - ) - - # first remove the BOT token - tool_content = model_output.replace(self.bot_token, "").strip() - - try: - # we first try to directly load the json as parsing very nested - # jsons is difficult - try: - if self.fn_name_regex: - matches = self.fn_name_regex.findall(tool_content) - - function_call_arr = [] - for match in matches: - fn_name = match[0] - args = match[1] - - # fn_name is encoded outside serialized json dump - # only arguments are serialized - function_call_arr.append( - {"name": fn_name, "arguments": json.loads(args)} - ) - else: - function_call_arr = json.loads(tool_content) - except json.JSONDecodeError: - # use a regex to find the part corresponding to the tool call. - # NOTE: This use case should not happen if the model is trained - # correctly. It's an easy possible fix so it's included, but - # can be brittle for very complex / highly nested tool calls - raw_tool_call = self.tool_call_regex.findall(tool_content)[0] - function_call_arr = json.loads(raw_tool_call) - - # Tool Call - tool_calls: list[MistralToolCall] = [ - MistralToolCall( - type="function", - function=FunctionCall( - name=raw_function_call["name"], - # function call args are JSON but as a string - arguments=json.dumps( - raw_function_call["arguments"], ensure_ascii=False - ), - ), - ) - for raw_function_call in function_call_arr - ] - - # get any content before the tool call - content = model_output.split(self.bot_token)[0] - return ExtractedToolCallInformation( - tools_called=True, - tool_calls=tool_calls, - content=content if len(content) > 0 else None, - ) - - except Exception: - logger.exception("Error in extracting tool call from response.") - # return information to just treat the tool call as regular JSON - return ExtractedToolCallInformation( - tools_called=False, tool_calls=[], content=tool_content - ) - - def extract_tool_calls_streaming( - self, - previous_text: str, - current_text: str, - delta_text: str, - previous_token_ids: Sequence[int], - current_token_ids: Sequence[int], - delta_token_ids: Sequence[int], - request: ChatCompletionRequest, - ) -> DeltaMessage | None: - # if the tool call token is not in the tokens generated so far, append - # output to contents since it's not a tool - if self.bot_token not in current_text: - return DeltaMessage(content=delta_text) - - # if the tool call token ID IS in the tokens generated so far, that - # means we're parsing as tool calls now - - # handle if we detected the BOT token which means the start of tool - # calling - if self.bot_token_id in delta_token_ids and len(delta_token_ids) == 1: - # if it's the only token, return None, so we don't send a chat - # completion any don't send a control token - return None - - # bit mask flags for partial JSON parsing. If the name hasn't been - # sent yet, don't allow sending - # an incomplete string since OpenAI only ever (as far as I have - # seen) allows sending the entire tool/ function name at once. - flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR - try: - # replace BOT token with empty string, and convert single quotes - # to double to allow parsing as JSON since mistral uses single - # quotes instead of double for tool calls - parsable_arr = current_text.split(self.bot_token)[-1] - - # tool calls are generated in an array, so do partial JSON - # parsing on the entire array - try: - tool_call_arr: list[dict] = partial_json_parser.loads( - parsable_arr, flags - ) - except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug("not enough tokens to parse into JSON yet") - return None - - # select as the current tool call the one we're on the state at - - current_tool_call: dict = ( - tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {} - ) - - # case -- if no tokens have been streamed for the tool, e.g. - # only the array brackets, stream nothing - if len(tool_call_arr) == 0: - return None - - # case: we are starting a new tool in the array - # -> array has > 0 length AND length has moved past cursor - elif ( - len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1 - ): - # if we're moving on to a new call, first make sure we - # haven't missed anything in the previous one that was - # auto-generated due to JSON completions, but wasn't - # streamed to the client yet. - if self.current_tool_id >= 0: - diff: str | None = current_tool_call.get("arguments") - - if diff: - diff = json.dumps(diff, ensure_ascii=False).replace( - self.streamed_args_for_tool[self.current_tool_id], "" - ) - delta = DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=diff - ).model_dump(exclude_none=True), - ) - ] - ) - self.streamed_args_for_tool[self.current_tool_id] += diff - else: - delta = None - else: - delta = None - # re-set stuff pertaining to progress in the current tool - self.current_tool_id = len(tool_call_arr) - 1 - self.current_tool_name_sent = False - self.streamed_args_for_tool.append("") - logger.debug("starting on new tool %d", self.current_tool_id) - return delta - - # case: update an existing tool - this is handled below - - # if the current tool name hasn't been sent, send if available - # - otherwise send nothing - if not self.current_tool_name_sent: - function_name = current_tool_call.get("name") - if function_name: - delta = DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - type="function", - id=MistralToolCall.generate_random_id(), - function=DeltaFunctionCall( - name=function_name - ).model_dump(exclude_none=True), - ) - ] - ) - self.current_tool_name_sent = True - else: - delta = None - - # now we know we're on the same tool call and we're streaming - # arguments - else: - prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( - "arguments" - ) - cur_arguments = current_tool_call.get("arguments") - - new_text = delta_text.replace("'", '"') - if '"}' in new_text: - new_text = new_text[: new_text.rindex('"}')] - - if not cur_arguments and not prev_arguments: - delta = None - elif not cur_arguments and prev_arguments: - logger.error( - "INVARIANT - impossible to have arguments reset mid-arguments" - ) - delta = None - elif cur_arguments and not prev_arguments: - cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)[ - :-2 - ] - logger.debug("finding %s in %s", new_text, cur_arguments_json) - - if new_text not in cur_arguments_json: - return None - arguments_delta = cur_arguments_json[ - : cur_arguments_json.rindex(new_text) + len(new_text) - ] - logger.debug( - "First tokens in arguments received: %s", arguments_delta - ) - delta = DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=arguments_delta - ).model_dump(exclude_none=True), - ) - ] - ) - self.streamed_args_for_tool[self.current_tool_id] += arguments_delta - - elif cur_arguments and prev_arguments: - cur_args_json = json.dumps(cur_arguments, ensure_ascii=False) - prev_args_json = json.dumps(prev_arguments, ensure_ascii=False) - logger.debug( - "Searching for diff between \n%s\n%s", - cur_args_json, - prev_args_json, - ) - - argument_diff = extract_intermediate_diff( - cur_args_json, prev_args_json - ) - logger.debug("got arguments diff: %s", argument_diff) - delta = DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff - ).model_dump(exclude_none=True), - ) - ] - ) - self.streamed_args_for_tool[self.current_tool_id] += argument_diff - else: - # try parsing it with regular JSON - if it works we're - # at the end, and we need to send the difference between - # tokens streamed so far and the valid JSON - delta = None - - # check to see if the name is defined and has been sent. if so, - # stream the name - otherwise keep waiting - # finish by setting old and returning None as base case - self.prev_tool_call_arr = tool_call_arr - return delta - - except Exception: - logger.exception("Error trying to handle streaming tool call.") - logger.debug( - "Skipping chunk as a result of tool streaming extraction error" - ) - return None diff --git a/vllm/entrypoints/pooling/classify/serving.py b/vllm/entrypoints/pooling/classify/serving.py index d6d3825daf7bb..e166405a6f05a 100644 --- a/vllm/entrypoints/pooling/classify/serving.py +++ b/vllm/entrypoints/pooling/classify/serving.py @@ -72,11 +72,7 @@ class ClassificationMixin(OpenAIServing): if ret: return ret - ( - _, - _, - engine_prompts, - ) = await self._preprocess_chat( + _, engine_prompts = await self._preprocess_chat( cast(ChatCompletionRequest, chat_request), ctx.tokenizer, messages, diff --git a/vllm/entrypoints/pooling/embed/api_router.py b/vllm/entrypoints/pooling/embed/api_router.py index 5b10a32e79f81..24b0c8c2b3cf6 100644 --- a/vllm/entrypoints/pooling/embed/api_router.py +++ b/vllm/entrypoints/pooling/embed/api_router.py @@ -59,8 +59,8 @@ async def create_embedding( return JSONResponse(content=generator.model_dump()) elif isinstance(generator, EmbeddingBytesResponse): return StreamingResponse( - content=generator.body, - headers={"metadata": generator.metadata}, + content=generator.content, + headers=generator.headers, media_type=generator.media_type, ) diff --git a/vllm/entrypoints/pooling/embed/protocol.py b/vllm/entrypoints/pooling/embed/protocol.py index 7eb53e14d5d8a..6a8f8c4434e55 100644 --- a/vllm/entrypoints/pooling/embed/protocol.py +++ b/vllm/entrypoints/pooling/embed/protocol.py @@ -203,6 +203,6 @@ class EmbeddingResponse(OpenAIBaseModel): class EmbeddingBytesResponse(OpenAIBaseModel): - body: list[bytes] - metadata: str + content: list[bytes] + headers: dict[str, str] | None = None media_type: str = "application/octet-stream" diff --git a/vllm/entrypoints/pooling/embed/serving.py b/vllm/entrypoints/pooling/embed/serving.py index 868a3cb017a6b..f5a21208ed802 100644 --- a/vllm/entrypoints/pooling/embed/serving.py +++ b/vllm/entrypoints/pooling/embed/serving.py @@ -20,7 +20,6 @@ from vllm.entrypoints.openai.serving_engine import ( EmbeddingServeContext, OpenAIServing, ServeContext, - TextTokensPrompt, ) from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.pooling.embed.protocol import ( @@ -32,7 +31,7 @@ from vllm.entrypoints.pooling.embed.protocol import ( EmbeddingResponseData, ) from vllm.entrypoints.renderer import RenderConfig -from vllm.inputs.data import TokensPrompt as EngineTokensPrompt +from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger from vllm.outputs import ( EmbeddingRequestOutput, @@ -83,11 +82,7 @@ class EmbeddingMixin(OpenAIServing): renderer = self._get_renderer(tokenizer) if isinstance(ctx.request, EmbeddingChatRequest): - ( - _, - _, - ctx.engine_prompts, - ) = await self._preprocess_chat( + _, ctx.engine_prompts = await self._preprocess_chat( ctx.request, tokenizer, ctx.request.messages, @@ -163,29 +158,35 @@ class EmbeddingMixin(OpenAIServing): usage=usage, ) - def encode_bytes(): - body, items, usage = encode_pooling_bytes( + def encode_bytes(bytes_only: bool) -> EmbeddingBytesResponse: + content, items, usage = encode_pooling_bytes( pooling_outputs=final_res_batch_checked, embed_dtype=embed_dtype, endianness=endianness, ) - metadata = { - "id": ctx.request_id, - "created": ctx.created_time, - "model": ctx.model_name, - "data": items, - "usage": usage, - } - return EmbeddingBytesResponse( - body=body, - metadata=json.dumps(metadata), + headers = ( + None + if bytes_only + else { + "metadata": json.dumps( + { + "id": ctx.request_id, + "created": ctx.created_time, + "model": ctx.model_name, + "data": items, + "usage": usage, + } + ) + } ) + return EmbeddingBytesResponse(content=content, headers=headers) + if encoding_format == "float" or encoding_format == "base64": return encode_float_base64() - elif encoding_format == "bytes": - return encode_bytes() + elif encoding_format == "bytes" or encoding_format == "bytes_only": + return encode_bytes(bytes_only=encoding_format == "bytes_only") else: assert_never(encoding_format) @@ -203,14 +204,13 @@ class EmbeddingMixin(OpenAIServing): async def _process_chunked_request( self, ctx: EmbeddingServeContext, - original_prompt: TextTokensPrompt, + token_ids: list[int], pooling_params, trace_headers, prompt_idx: int, ) -> list[AsyncGenerator[PoolingRequestOutput, None]]: """Process a single prompt using chunked processing.""" generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] - token_ids = original_prompt["prompt_token_ids"] # Split into chunks using max_position_embeddings max_pos_embeddings = self._get_max_position_embeddings() @@ -222,18 +222,12 @@ class EmbeddingMixin(OpenAIServing): chunk_request_id = f"{ctx.request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}" # Create engine prompt for this chunk - chunk_engine_prompt = EngineTokensPrompt(prompt_token_ids=chunk_tokens) - - # Create chunk request prompt for logging - chunk_text = "" - chunk_request_prompt = TextTokensPrompt( - prompt=chunk_text, prompt_token_ids=chunk_tokens - ) + chunk_engine_prompt = TokensPrompt(prompt_token_ids=chunk_tokens) # Log the chunk self._log_inputs( chunk_request_id, - chunk_request_prompt, + chunk_engine_prompt, params=pooling_params, lora_request=ctx.lora_request, ) @@ -257,7 +251,7 @@ class EmbeddingMixin(OpenAIServing): request, input_ids: list[int], input_text: str, - ) -> TextTokensPrompt: + ) -> TokensPrompt: """Override to support chunked processing for embedding requests.""" token_num = len(input_ids) @@ -322,23 +316,15 @@ class EmbeddingMixin(OpenAIServing): ) ) - return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) + return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids) # For other request types, use the parent's implementation return super()._validate_input(request, input_ids, input_text) - def _is_text_tokens_prompt(self, prompt) -> bool: - """Check if a prompt is a TextTokensPrompt (has prompt_token_ids).""" - return ( - isinstance(prompt, dict) - and "prompt_token_ids" in prompt - and "prompt_embeds" not in prompt - ) - async def _create_single_prompt_generator( self, ctx: EmbeddingServeContext, - engine_prompt: EngineTokensPrompt, + engine_prompt: TokensPrompt, pooling_params: PoolingParams, trace_headers: Mapping[str, str] | None, prompt_index: int, @@ -407,14 +393,16 @@ class EmbeddingMixin(OpenAIServing): for i, engine_prompt in enumerate(ctx.engine_prompts): # Check if this specific prompt needs chunked processing - if self._is_text_tokens_prompt(engine_prompt): - # Cast to TextTokensPrompt since we've verified - # prompt_token_ids - text_tokens_prompt = cast(TextTokensPrompt, engine_prompt) - if len(text_tokens_prompt["prompt_token_ids"]) > max_pos_embeddings: + if "prompt_token_ids" in engine_prompt: + prompt_token_ids = engine_prompt["prompt_token_ids"] + if len(prompt_token_ids) > max_pos_embeddings: # Use chunked processing for this prompt chunk_generators = await self._process_chunked_request( - ctx, text_tokens_prompt, pooling_params, trace_headers, i + ctx, + prompt_token_ids, + pooling_params, + trace_headers, + i, ) generators.extend(chunk_generators) continue @@ -572,14 +560,13 @@ class EmbeddingMixin(OpenAIServing): # Get original prompt token IDs for this prompt original_prompt = ctx.engine_prompts[prompt_idx] - if not self._is_text_tokens_prompt(original_prompt): + if "prompt_token_ids" not in original_prompt: return self.create_error_response( - f"Chunked prompt {prompt_idx} is not a TextTokensPrompt" + f"Chunked prompt {prompt_idx} does not contain " + "token IDs" ) - original_token_ids = cast(TextTokensPrompt, original_prompt)[ - "prompt_token_ids" - ] + original_token_ids = original_prompt["prompt_token_ids"] pooling_request_output = PoolingRequestOutput( request_id=aggregator["request_id"], diff --git a/vllm/entrypoints/pooling/pooling/api_router.py b/vllm/entrypoints/pooling/pooling/api_router.py index 674da94d126cf..4baaf8f30f6bb 100644 --- a/vllm/entrypoints/pooling/pooling/api_router.py +++ b/vllm/entrypoints/pooling/pooling/api_router.py @@ -55,8 +55,8 @@ async def create_pooling(request: PoolingRequest, raw_request: Request): return JSONResponse(content=generator.model_dump()) elif isinstance(generator, PoolingBytesResponse): return StreamingResponse( - content=generator.body, - headers={"metadata": generator.metadata}, + content=generator.content, + headers=generator.headers, media_type=generator.media_type, ) diff --git a/vllm/entrypoints/pooling/pooling/protocol.py b/vllm/entrypoints/pooling/pooling/protocol.py index 364cd93738b84..76b361b49b668 100644 --- a/vllm/entrypoints/pooling/pooling/protocol.py +++ b/vllm/entrypoints/pooling/pooling/protocol.py @@ -143,6 +143,6 @@ class PoolingResponse(OpenAIBaseModel): class PoolingBytesResponse(OpenAIBaseModel): - body: list[bytes] - metadata: str + content: list[bytes] + headers: dict[str, str] | None = None media_type: str = "application/octet-stream" diff --git a/vllm/entrypoints/pooling/pooling/serving.py b/vllm/entrypoints/pooling/pooling/serving.py index 7fb767e26d019..4e1b326806eae 100644 --- a/vllm/entrypoints/pooling/pooling/serving.py +++ b/vllm/entrypoints/pooling/pooling/serving.py @@ -137,11 +137,8 @@ class OpenAIServingPooling(OpenAIServing): ) if error_check_ret is not None: return error_check_ret - ( - _, - _, - engine_prompts, - ) = await self._preprocess_chat( + + _, engine_prompts = await self._preprocess_chat( request, tokenizer, request.messages, @@ -314,29 +311,38 @@ class OpenAIServingPooling(OpenAIServing): usage=usage, ) - def encode_bytes(): - body, items, usage = encode_pooling_bytes( + def encode_bytes(bytes_only: bool) -> PoolingBytesResponse: + content, items, usage = encode_pooling_bytes( pooling_outputs=final_res_batch, embed_dtype=embed_dtype, endianness=endianness, ) - metadata = { - "id": request_id, - "created": created_time, - "model": model_name, - "data": items, - "usage": usage, - } + headers = ( + None + if bytes_only + else { + "metadata": json.dumps( + { + "id": request_id, + "created": created_time, + "model": model_name, + "data": items, + "usage": usage, + } + ) + } + ) + return PoolingBytesResponse( - body=body, - metadata=json.dumps(metadata), + content=content, + headers=headers, ) if encoding_format == "float" or encoding_format == "base64": return encode_float_base64() - elif encoding_format == "bytes": - return encode_bytes() + elif encoding_format == "bytes" or encoding_format == "bytes_only": + return encode_bytes(bytes_only=encoding_format == "bytes_only") else: assert_never(encoding_format) diff --git a/vllm/entrypoints/pooling/score/protocol.py b/vllm/entrypoints/pooling/score/protocol.py index a22219707c357..e81bda2eec3d7 100644 --- a/vllm/entrypoints/pooling/score/protocol.py +++ b/vllm/entrypoints/pooling/score/protocol.py @@ -120,6 +120,7 @@ class RerankResult(BaseModel): class RerankUsage(BaseModel): + prompt_tokens: int total_tokens: int diff --git a/vllm/entrypoints/pooling/score/serving.py b/vllm/entrypoints/pooling/score/serving.py index e5a66783005a6..edbfcd03ac92c 100644 --- a/vllm/entrypoints/pooling/score/serving.py +++ b/vllm/entrypoints/pooling/score/serving.py @@ -38,7 +38,8 @@ from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput -from vllm.tokenizers import MistralTokenizer, TokenizerLike +from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils.async_utils import make_async, merge_async_iterators logger = init_logger(__name__) @@ -501,5 +502,7 @@ class ServingScores(OpenAIServing): id=request_id, model=model_name, results=results, - usage=RerankUsage(total_tokens=num_prompt_tokens), + usage=RerankUsage( + total_tokens=num_prompt_tokens, prompt_tokens=num_prompt_tokens + ), ) diff --git a/vllm/entrypoints/renderer.py b/vllm/entrypoints/renderer.py index 10b90bbbb0f32..0f89c840be80f 100644 --- a/vllm/entrypoints/renderer.py +++ b/vllm/entrypoints/renderer.py @@ -12,9 +12,7 @@ import torch from pydantic import Field from vllm.config import ModelConfig -from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt -from vllm.inputs.data import TextPrompt as EngineTextPrompt -from vllm.inputs.data import TokensPrompt as EngineTokensPrompt +from vllm.inputs.data import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.inputs.parse import get_prompt_components, parse_raw_prompts from vllm.tokenizers import TokenizerLike from vllm.utils.async_utils import AsyncMicrobatchTokenizer @@ -33,7 +31,7 @@ class RenderConfig: `0` yields an empty list (and skips embeds). `-1` maps to `model_config.max_model_len`.""" - add_special_tokens: bool | None = True + add_special_tokens: bool = True """Whether to add model-specific special tokens during tokenization.""" cache_salt: str | None = None @@ -97,7 +95,7 @@ class BaseRenderer(ABC): *, prompt_or_prompts: str | list[str] | list[int] | list[list[int]], config: RenderConfig, - ) -> list[EngineTokensPrompt]: + ) -> list[TokensPrompt]: """ Convert text or token inputs into engine-ready TokensPrompt objects. @@ -115,7 +113,7 @@ class BaseRenderer(ABC): (e.g., tokenization and length handling). Returns: - list[EngineTokensPrompt]: Engine-ready token prompts. + list[TokensPrompt]: Engine-ready token prompts. Raises: ValueError: If input formats are invalid or length limits exceeded. @@ -129,7 +127,7 @@ class BaseRenderer(ABC): prompt_or_prompts: str | list[str] | list[int] | list[list[int]] | None = None, prompt_embeds: bytes | list[bytes] | None = None, config: RenderConfig, - ) -> list[EngineTokensPrompt | EngineEmbedsPrompt]: + ) -> list[TokensPrompt | EmbedsPrompt]: """ Convert text/token and/or base64-encoded embeddings inputs into engine-ready prompt objects using a unified RenderConfig. @@ -146,7 +144,7 @@ class BaseRenderer(ABC): (e.g., tokenization and length handling). Returns: - list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]: + list[Union[TokensPrompt, EmbedsPrompt]]: Engine-ready prompt objects. Raises: @@ -161,31 +159,34 @@ class BaseRenderer(ABC): prompt_embeds: bytes | list[bytes], truncate_prompt_tokens: Annotated[int, Field(ge=0)] | None = None, cache_salt: str | None = None, - ) -> list[EngineEmbedsPrompt]: + ) -> list[EmbedsPrompt]: """Load and validate base64-encoded embeddings into prompt objects.""" if not self.model_config.enable_prompt_embeds: raise ValueError( "You must set `--enable-prompt-embeds` to input `prompt_embeds`." ) - def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt: - tensor = torch.load( - io.BytesIO(pybase64.b64decode(embed, validate=True)), - weights_only=True, - map_location=torch.device("cpu"), - ) - assert isinstance(tensor, torch.Tensor) and tensor.dtype in ( - torch.float32, - torch.bfloat16, - torch.float16, - ) - tensor = tensor.to_dense() + def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt: + # Enable sparse tensor integrity checks to prevent out-of-bounds + # writes from maliciously crafted tensors + with torch.sparse.check_sparse_tensor_invariants(): + tensor = torch.load( + io.BytesIO(pybase64.b64decode(embed, validate=True)), + weights_only=True, + map_location=torch.device("cpu"), + ) + assert isinstance(tensor, torch.Tensor) and tensor.dtype in ( + torch.float32, + torch.bfloat16, + torch.float16, + ) + tensor = tensor.to_dense() if tensor.dim() > 2: tensor = tensor.squeeze(0) assert tensor.dim() == 2 if truncate_prompt_tokens is not None: tensor = tensor[-truncate_prompt_tokens:] - embeds_prompt = EngineEmbedsPrompt(prompt_embeds=tensor) + embeds_prompt = EmbedsPrompt(prompt_embeds=tensor) if cache_salt is not None: embeds_prompt["cache_salt"] = cache_salt return embeds_prompt @@ -213,7 +214,7 @@ class CompletionRenderer(BaseRenderer): *, prompt_or_prompts: str | list[str] | list[int] | list[list[int]], config: RenderConfig, - ) -> list[EngineTokensPrompt]: + ) -> list[TokensPrompt]: """Implementation of prompt rendering for completion-style requests. Uses async tokenizer pooling for improved performance. See base class @@ -240,7 +241,7 @@ class CompletionRenderer(BaseRenderer): prompt_or_prompts: str | list[str] | list[int] | list[list[int]] | None = None, prompt_embeds: bytes | list[bytes] | None = None, config: RenderConfig, - ) -> list[EngineTokensPrompt | EngineEmbedsPrompt]: + ) -> list[TokensPrompt | EmbedsPrompt]: """ Render text/token prompts and/or precomputed embedding prompts. At least one of `prompt_or_prompts` or `prompt_embeds` must be provided. @@ -249,7 +250,7 @@ class CompletionRenderer(BaseRenderer): if truncate_prompt_tokens == 0: return [] - rendered: list[EngineTokensPrompt | EngineEmbedsPrompt] = [] + rendered: list[TokensPrompt | EmbedsPrompt] = [] if prompt_embeds is not None: rendered.extend( @@ -281,10 +282,10 @@ class CompletionRenderer(BaseRenderer): async def _create_prompt( self, - prompt_input: EngineTextPrompt | EngineTokensPrompt, + prompt_input: TextPrompt | TokensPrompt, config: RenderConfig, truncate_prompt_tokens: int | None, - ) -> EngineTokensPrompt: + ) -> TokensPrompt: prompt, prompt_token_ids, _ = get_prompt_components(prompt_input) if prompt_token_ids is not None: @@ -315,9 +316,9 @@ class CompletionRenderer(BaseRenderer): text: str, max_length: int | None, truncate_prompt_tokens: int | None, - add_special_tokens: bool | None, + add_special_tokens: bool, cache_salt: str | None, - ) -> EngineTokensPrompt: + ) -> TokensPrompt: """Tokenize text input asynchronously.""" async_tokenizer = self._get_async_tokenizer() @@ -350,7 +351,7 @@ class CompletionRenderer(BaseRenderer): truncate_prompt_tokens: int | None, cache_salt: str | None, needs_detokenization: bool | None = False, - ) -> EngineTokensPrompt: + ) -> TokensPrompt: """Optionally detokenize token IDs and build a tokens prompt.""" token_ids = self._maybe_apply_truncation(token_ids, truncate_prompt_tokens) @@ -392,8 +393,8 @@ class CompletionRenderer(BaseRenderer): max_length: int | None = None, cache_salt: str | None = None, prompt: str | None = None, - ) -> EngineTokensPrompt: - """Create validated EngineTokensPrompt.""" + ) -> TokensPrompt: + """Create validated TokensPrompt.""" if max_length is not None and len(token_ids) > max_length: raise ValueError( f"This model's maximum context length is {max_length} tokens. " @@ -401,7 +402,7 @@ class CompletionRenderer(BaseRenderer): "Please reduce the length of the input messages." ) - tokens_prompt = EngineTokensPrompt(prompt_token_ids=token_ids) + tokens_prompt = TokensPrompt(prompt_token_ids=token_ids) if cache_salt is not None: tokens_prompt["cache_salt"] = cache_salt if prompt is not None: diff --git a/vllm/entrypoints/responses_utils.py b/vllm/entrypoints/responses_utils.py index 07abb80ebc9e3..df3d0495755da 100644 --- a/vllm/entrypoints/responses_utils.py +++ b/vllm/entrypoints/responses_utils.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + from openai.types.chat import ( ChatCompletionAssistantMessageParam, ChatCompletionMessageToolCallParam, @@ -10,6 +12,7 @@ from openai.types.chat.chat_completion_message_tool_call_param import ( Function as FunctionCallTool, ) from openai.types.responses import ResponseFunctionToolCall, ResponseOutputItem +from openai.types.responses.response import ToolChoice from openai.types.responses.response_function_tool_call_output_item import ( ResponseFunctionToolCallOutputItem, ) @@ -18,6 +21,7 @@ from openai.types.responses.response_reasoning_item import ResponseReasoningItem from openai.types.responses.tool import Tool from vllm import envs +from vllm.entrypoints.constants import MCP_PREFIX from vllm.entrypoints.openai.protocol import ( ChatCompletionMessageParam, ResponseInputOutputItem, @@ -62,12 +66,63 @@ def construct_input_messages( if isinstance(request_input, str): messages.append({"role": "user", "content": request_input}) else: - for item in request_input: - messages.append(construct_chat_message_with_tool_call(item)) + input_messages = construct_chat_messages_with_tool_call(request_input) + messages.extend(input_messages) return messages -def construct_chat_message_with_tool_call( +def _maybe_combine_reasoning_and_tool_call( + item: ResponseInputOutputItem, messages: list[ChatCompletionMessageParam] +) -> ChatCompletionMessageParam | None: + """Many models treat MCP calls and reasoning as a single message. + This function checks if the last message is a reasoning message and + the current message is a tool call""" + if not ( + isinstance(item, ResponseFunctionToolCall) and item.id.startswith(MCP_PREFIX) + ): + return None + if len(messages) == 0: + return None + last_message = messages[-1] + if not ( + last_message.get("role") == "assistant" + and last_message.get("reasoning") is not None + ): + return None + + last_message["tool_calls"] = [ + ChatCompletionMessageToolCallParam( + id=item.call_id, + function=FunctionCallTool( + name=item.name, + arguments=item.arguments, + ), + type="function", + ) + ] + return last_message + + +def construct_chat_messages_with_tool_call( + input_messages: list[ResponseInputOutputItem], +) -> list[ChatCompletionMessageParam]: + """This function wraps _construct_single_message_from_response_item + Because some chatMessages come from multiple response items + for example a reasoning item and a MCP tool call are two response items + but are one chat message + """ + messages: list[ChatCompletionMessageParam] = [] + for item in input_messages: + maybe_combined_message = _maybe_combine_reasoning_and_tool_call(item, messages) + if maybe_combined_message is not None: + messages[-1] = maybe_combined_message + else: + messages.append(_construct_single_message_from_response_item(item)) + + return messages + + +def _construct_single_message_from_response_item( item: ResponseInputOutputItem, ) -> ChatCompletionMessageParam: if isinstance(item, ResponseFunctionToolCall): @@ -97,13 +152,18 @@ def construct_chat_message_with_tool_call( "role": "assistant", "reasoning": reasoning_content, } + elif isinstance(item, ResponseOutputMessage): + return { + "role": "assistant", + "content": item.content[0].text, + } elif isinstance(item, ResponseFunctionToolCallOutputItem): return ChatCompletionToolMessageParam( role="tool", content=item.output, tool_call_id=item.call_id, ) - elif item.get("type") == "function_call_output": + elif isinstance(item, dict) and item.get("type") == "function_call_output": # Append the function call output as a tool message. return ChatCompletionToolMessageParam( role="tool", @@ -141,3 +201,16 @@ def convert_tool_responses_to_completions_format(tool: dict) -> dict: "type": "function", "function": tool, } + + +def construct_tool_dicts( + tools: list[Tool], tool_choice: ToolChoice +) -> list[dict[str, Any]] | None: + if tools is None or (tool_choice == "none"): + tool_dicts = None + else: + tool_dicts = [ + convert_tool_responses_to_completions_format(tool.model_dump()) + for tool in tools + ] + return tool_dicts diff --git a/vllm/entrypoints/sagemaker/routes.py b/vllm/entrypoints/sagemaker/routes.py index 108fdd773e321..ea88c0fc4b979 100644 --- a/vllm/entrypoints/sagemaker/routes.py +++ b/vllm/entrypoints/sagemaker/routes.py @@ -16,7 +16,6 @@ from vllm.entrypoints.openai.api_server import ( completion, create_chat_completion, create_completion, - health, validate_json_request, ) from vllm.entrypoints.openai.protocol import ( @@ -38,6 +37,7 @@ from vllm.entrypoints.pooling.score.api_router import ( score, ) from vllm.entrypoints.pooling.score.protocol import RerankRequest, ScoreRequest +from vllm.entrypoints.serve.instrumentator.health import health # TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers # (requires typing_extensions >= 4.13) diff --git a/vllm/entrypoints/score_utils.py b/vllm/entrypoints/score_utils.py index 602f59ac09f55..072ddd4c90b16 100644 --- a/vllm/entrypoints/score_utils.py +++ b/vllm/entrypoints/score_utils.py @@ -19,7 +19,7 @@ from vllm.inputs import TokensPrompt from vllm.model_executor.models.interfaces import supports_score_template from vllm.multimodal.inputs import MultiModalDataDict from vllm.outputs import PoolingRequestOutput -from vllm.transformers_utils.tokenizer import TokenizerLike +from vllm.tokenizers import TokenizerLike ScoreContentPartParam: TypeAlias = ( ChatCompletionContentPartImageParam | ChatCompletionContentPartImageEmbedsParam @@ -89,12 +89,10 @@ def parse_score_data( data_1: str | ScoreContentPartParam, data_2: str | ScoreContentPartParam, model_config: ModelConfig, - tokenizer: TokenizerLike, ) -> tuple[str, str, MultiModalDataDict | None]: - mm_tracker = MultiModalItemTracker(model_config, tokenizer) + mm_tracker = MultiModalItemTracker(model_config) content_1 = _parse_score_content(data_1, mm_tracker) - content_2 = _parse_score_content(data_2, mm_tracker) def ensure_str(content: _ContentPart | None) -> str: @@ -188,7 +186,6 @@ def get_score_prompt( data_1, data_2, model_config, - tokenizer, ) from vllm.model_executor.model_loader import get_model_cls diff --git a/vllm/entrypoints/serve/__init__.py b/vllm/entrypoints/serve/__init__.py new file mode 100644 index 0000000000000..c4fcc92db931f --- /dev/null +++ b/vllm/entrypoints/serve/__init__.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +from fastapi import FastAPI + + +def register_vllm_serve_api_routers(app: FastAPI): + from vllm.entrypoints.serve.lora.api_router import ( + attach_router as attach_lora_router, + ) + + attach_lora_router(app) + from vllm.entrypoints.serve.elastic_ep.api_router import ( + attach_router as attach_elastic_ep_router, + ) + + attach_elastic_ep_router(app) + + from vllm.entrypoints.serve.profile.api_router import ( + attach_router as attach_profile_router, + ) + + attach_profile_router(app) + + from vllm.entrypoints.serve.sleep.api_router import ( + attach_router as attach_sleep_router, + ) + + attach_sleep_router(app) + + from vllm.entrypoints.serve.tokenize.api_router import ( + attach_router as attach_tokenize_router, + ) + + attach_tokenize_router(app) + + from vllm.entrypoints.serve.disagg.api_router import ( + attach_router as attach_disagg_router, + ) + + attach_disagg_router(app) + + from vllm.entrypoints.serve.rlhf.api_router import ( + attach_router as attach_rlhf_router, + ) + + attach_rlhf_router(app) + + from vllm.entrypoints.serve.instrumentator.metrics import ( + attach_router as attach_metrics_router, + ) + + attach_metrics_router(app) + + from vllm.entrypoints.serve.instrumentator.health import ( + attach_router as attach_health_router, + ) + + attach_health_router(app) diff --git a/vllm/entrypoints/serve/disagg/__init__.py b/vllm/entrypoints/serve/disagg/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/entrypoints/serve/disagg/api_router.py b/vllm/entrypoints/serve/disagg/api_router.py new file mode 100644 index 0000000000000..c38ede30dad1c --- /dev/null +++ b/vllm/entrypoints/serve/disagg/api_router.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import asyncio +import json +from http import HTTPStatus + +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response +from fastapi.responses import JSONResponse, StreamingResponse + +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.openai.api_server import validate_json_request +from vllm.entrypoints.openai.protocol import ( + ErrorResponse, +) +from vllm.entrypoints.serve.disagg.protocol import ( + GenerateRequest, + GenerateResponse, +) +from vllm.entrypoints.serve.disagg.serving import ( + ServingTokens, +) +from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization +from vllm.entrypoints.utils import ( + load_aware_call, + with_cancellation, +) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def tokenization(request: Request) -> OpenAIServingTokenization: + return request.app.state.openai_serving_tokenization + + +def generate_tokens(request: Request) -> ServingTokens | None: + return request.app.state.serving_tokens + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +router = APIRouter() + + +@router.post( + "/inference/v1/generate", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) +@with_cancellation +@load_aware_call +async def generate(request: GenerateRequest, raw_request: Request): + handler = generate_tokens(raw_request) + if handler is None: + return tokenization(raw_request).create_error_response( + message="The model does not support generate tokens API" + ) + try: + generator = await handler.serve_tokens(request, raw_request) + except Exception as e: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e + if isinstance(generator, ErrorResponse): + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) + + elif isinstance(generator, GenerateResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + +def attach_router(app: FastAPI): + if getattr(app.state.args, "tokens_only", False): + + @router.post("/abort_requests") + async def abort_requests(raw_request: Request): + """ + Abort one or more requests. To be used in a + Disaggregated Everything setup. + """ + try: + body = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"JSON decode error: {e}", + ) from e + request_ids = body.get("request_ids") + if request_ids is None: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail="Missing 'request_ids' in request body", + ) + # Abort requests in background + asyncio.create_task(engine_client(raw_request).abort(request_ids)) + return Response(status_code=200) + + app.include_router(router) diff --git a/vllm/entrypoints/serve/disagg/protocol.py b/vllm/entrypoints/serve/disagg/protocol.py new file mode 100644 index 0000000000000..251fcf12ed7dd --- /dev/null +++ b/vllm/entrypoints/serve/disagg/protocol.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + +from pydantic import BaseModel, Field + +from vllm.entrypoints.openai.protocol import ( + ChatCompletionLogProbs, + Logprob, + SamplingParams, + StreamOptions, +) +from vllm.utils import random_uuid + + +####### Tokens IN <> Tokens OUT ####### +class GenerateRequest(BaseModel): + request_id: str = Field( + default_factory=lambda: f"{random_uuid()}", + description=( + "The request_id related to this request. If the caller does " + "not set it, a random_uuid will be generated. This id is used " + "through out the inference process and return in response." + ), + ) + token_ids: list[int] + """The token ids to generate text from.""" + + # features: MultiModalFeatureSpec + # TODO (NickLucche): implement once Renderer work is completed + features: str | None = None + """The processed MM inputs for the model.""" + + sampling_params: SamplingParams + """The sampling parameters for the model.""" + + model: str | None = None + + stream: bool | None = False + stream_options: StreamOptions | None = None + cache_salt: str | None = Field( + default=None, + description=( + "If specified, the prefix cache will be salted with the provided " + "string to prevent an attacker to guess prompts in multi-user " + "environments. The salt should be random, protected from " + "access by 3rd parties, and long enough to be " + "unpredictable (e.g., 43 characters base64-encoded, corresponding " + "to 256 bit)." + ), + ) + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling." + ), + ) + kv_transfer_params: dict[str, Any] | None = Field( + default=None, + description="KVTransfer parameters used for disaggregated serving.", + ) + + +class GenerateResponseChoice(BaseModel): + index: int + logprobs: ChatCompletionLogProbs | None = None + # per OpenAI spec this is the default + finish_reason: str | None = "stop" + token_ids: list[int] | None = None + + +class GenerateResponse(BaseModel): + request_id: str = Field( + default_factory=lambda: f"{random_uuid()}", + description=( + "The request_id related to this request. If the caller does " + "not set it, a random_uuid will be generated. This id is used " + "through out the inference process and return in response." + ), + ) + choices: list[GenerateResponseChoice] + + prompt_logprobs: list[dict[int, Logprob] | None] | None = None + + kv_transfer_params: dict[str, Any] | None = Field( + default=None, + description="KVTransfer parameters used for disaggregated serving.", + ) diff --git a/vllm/entrypoints/openai/serving_tokens.py b/vllm/entrypoints/serve/disagg/serving.py similarity index 97% rename from vllm/entrypoints/openai/serving_tokens.py rename to vllm/entrypoints/serve/disagg/serving.py index daa739e41fa07..1798b174b1413 100644 --- a/vllm/entrypoints/openai/serving_tokens.py +++ b/vllm/entrypoints/serve/disagg/serving.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + import asyncio import time from collections.abc import AsyncGenerator @@ -14,16 +16,18 @@ from vllm.entrypoints.openai.protocol import ( ChatCompletionLogProbs, ChatCompletionLogProbsContent, ErrorResponse, - GenerateRequest, - GenerateResponse, - GenerateResponseChoice, PromptTokenUsageInfo, RequestResponseMetadata, UsageInfo, ) from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.inputs.data import TokensPrompt as EngineTokensPrompt +from vllm.entrypoints.serve.disagg.protocol import ( + GenerateRequest, + GenerateResponse, + GenerateResponseChoice, +) +from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger from vllm.logprobs import Logprob from vllm.outputs import RequestOutput @@ -95,7 +99,7 @@ class ServingTokens(OpenAIServing): # TODO(NickLucche): Change to EngineCoreRequest once Renderer work is # completed - engine_prompt = EngineTokensPrompt(prompt_token_ids=request.token_ids) + engine_prompt = TokensPrompt(prompt_token_ids=request.token_ids) if request.features is not None: engine_prompt["multi_modal_data"] = None @@ -111,7 +115,7 @@ class ServingTokens(OpenAIServing): self._log_inputs( request_id, - request.token_ids, + TokensPrompt(prompt_token_ids=request.token_ids), params=sampling_params, lora_request=lora_request, ) diff --git a/vllm/entrypoints/serve/elastic_ep/__init__.py b/vllm/entrypoints/serve/elastic_ep/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/entrypoints/serve/elastic_ep/api_router.py b/vllm/entrypoints/serve/elastic_ep/api_router.py new file mode 100644 index 0000000000000..21d5d2e60778a --- /dev/null +++ b/vllm/entrypoints/serve/elastic_ep/api_router.py @@ -0,0 +1,96 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import json +from http import HTTPStatus + +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse + +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.openai.api_server import validate_json_request +from vllm.entrypoints.openai.protocol import ( + ErrorResponse, +) +from vllm.entrypoints.serve.elastic_ep.middleware import ( + get_scaling_elastic_ep, + set_scaling_elastic_ep, +) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +router = APIRouter() + + +@router.post( + "/scale_elastic_ep", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: {"model": dict}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) +async def scale_elastic_ep(raw_request: Request): + try: + body = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904 + + new_data_parallel_size = body.get("new_data_parallel_size") + drain_timeout = body.get("drain_timeout", 120) # Default 2 minutes + + if new_data_parallel_size is None: + raise HTTPException( + status_code=400, detail="new_data_parallel_size is required" + ) + + if not isinstance(new_data_parallel_size, int) or new_data_parallel_size <= 0: + raise HTTPException( + status_code=400, + detail="new_data_parallel_size must be a positive integer", + ) + + if not isinstance(drain_timeout, int) or drain_timeout <= 0: + raise HTTPException( + status_code=400, detail="drain_timeout must be a positive integer" + ) + + # Set scaling flag to prevent new requests + set_scaling_elastic_ep(True) + client = engine_client(raw_request) + try: + await client.scale_elastic_ep(new_data_parallel_size, drain_timeout) + return JSONResponse( + { + "message": f"Scaled to {new_data_parallel_size} data parallel engines", + } + ) + except TimeoutError as e: + raise HTTPException( + status_code=408, + detail="Scale failed due to request drain timeout " + f"after {drain_timeout} seconds", + ) from e + except Exception as e: + logger.error("Scale failed: %s", e) + raise HTTPException(status_code=500, detail="Scale failed") from e + finally: + set_scaling_elastic_ep(False) + + +@router.post("/is_scaling_elastic_ep") +async def is_scaling_elastic_ep(raw_request: Request): + return JSONResponse({"is_scaling_elastic_ep": get_scaling_elastic_ep()}) + + +def attach_router(app: FastAPI): + app.include_router(router) diff --git a/vllm/entrypoints/serve/elastic_ep/middleware.py b/vllm/entrypoints/serve/elastic_ep/middleware.py new file mode 100644 index 0000000000000..23f45eafeaa0d --- /dev/null +++ b/vllm/entrypoints/serve/elastic_ep/middleware.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Awaitable + +from fastapi.responses import JSONResponse +from starlette.types import ASGIApp, Receive, Scope, Send + +# Global variable to track scaling state +_scaling_elastic_ep = False + + +def get_scaling_elastic_ep(): + return _scaling_elastic_ep + + +def set_scaling_elastic_ep(value): + global _scaling_elastic_ep + _scaling_elastic_ep = value + + +class ScalingMiddleware: + """ + Middleware that checks if the model is currently scaling and + returns a 503 Service Unavailable response if it is. + + This middleware applies to all HTTP requests and prevents + processing when the model is in a scaling state. + """ + + def __init__(self, app: ASGIApp) -> None: + self.app = app + + def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]: + if scope["type"] != "http": + return self.app(scope, receive, send) + + # Check global scaling state + if get_scaling_elastic_ep(): + # Return 503 Service Unavailable response + response = JSONResponse( + content={ + "error": "The model is currently scaling. Please try again later." + }, + status_code=503, + ) + return response(scope, receive, send) + + return self.app(scope, receive, send) diff --git a/vllm/entrypoints/serve/instrumentator/__init__.py b/vllm/entrypoints/serve/instrumentator/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/entrypoints/serve/instrumentator/health.py b/vllm/entrypoints/serve/instrumentator/health.py new file mode 100644 index 0000000000000..029ef677aaa25 --- /dev/null +++ b/vllm/entrypoints/serve/instrumentator/health.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +from fastapi import APIRouter, Request +from fastapi.responses import Response + +from vllm.engine.protocol import EngineClient +from vllm.logger import init_logger +from vllm.v1.engine.exceptions import EngineDeadError + +logger = init_logger(__name__) + + +router = APIRouter() + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +@router.get("/health", response_class=Response) +async def health(raw_request: Request) -> Response: + """Health check.""" + try: + await engine_client(raw_request).check_health() + return Response(status_code=200) + except EngineDeadError: + return Response(status_code=503) + + +def attach_router(app): + app.include_router(router) diff --git a/vllm/entrypoints/serve/instrumentator/metrics.py b/vllm/entrypoints/serve/instrumentator/metrics.py new file mode 100644 index 0000000000000..5231451383a2b --- /dev/null +++ b/vllm/entrypoints/serve/instrumentator/metrics.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import prometheus_client +import regex as re +from fastapi import FastAPI, Response +from prometheus_client import make_asgi_app +from prometheus_fastapi_instrumentator import Instrumentator +from starlette.routing import Mount + +from vllm.v1.metrics.prometheus import get_prometheus_registry + + +class PrometheusResponse(Response): + media_type = prometheus_client.CONTENT_TYPE_LATEST + + +def attach_router(app: FastAPI): + """Mount prometheus metrics to a FastAPI app.""" + + registry = get_prometheus_registry() + + # `response_class=PrometheusResponse` is needed to return an HTTP response + # with header "Content-Type: text/plain; version=0.0.4; charset=utf-8" + # instead of the default "application/json" which is incorrect. + # See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364 + Instrumentator( + excluded_handlers=[ + "/metrics", + "/health", + "/load", + "/ping", + "/version", + "/server_info", + ], + registry=registry, + ).add().instrument(app).expose(app, response_class=PrometheusResponse) + + # Add prometheus asgi middleware to route /metrics requests + metrics_route = Mount("/metrics", make_asgi_app(registry=registry)) + + # Workaround for 307 Redirect for /metrics + metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$") + app.routes.append(metrics_route) diff --git a/vllm/entrypoints/serve/lora/__init__.py b/vllm/entrypoints/serve/lora/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/entrypoints/dynamic_lora.py b/vllm/entrypoints/serve/lora/api_router.py similarity index 80% rename from vllm/entrypoints/dynamic_lora.py rename to vllm/entrypoints/serve/lora/api_router.py index cc0f437e5c77f..6a57e73f334f2 100644 --- a/vllm/entrypoints/dynamic_lora.py +++ b/vllm/entrypoints/serve/lora/api_router.py @@ -1,9 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + import model_hosting_container_standards.sagemaker as sagemaker_standards -from fastapi import APIRouter, Depends, Request +from fastapi import APIRouter, Depends, FastAPI, Request from fastapi.responses import JSONResponse, Response +from vllm import envs from vllm.entrypoints.openai.api_server import models, validate_json_request from vllm.entrypoints.openai.protocol import ( ErrorResponse, @@ -14,9 +17,18 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.logger import init_logger logger = init_logger(__name__) +router = APIRouter() -def register_dynamic_lora_routes(router: APIRouter): +def attach_router(app: FastAPI): + if not envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: + """If LoRA dynamic loading & unloading is not enabled, do nothing.""" + return + logger.warning( + "LoRA dynamic loading & unloading is enabled in the API server. " + "This should ONLY be used for local development!" + ) + @sagemaker_standards.register_load_adapter_handler( request_shape={ "lora_name": "body.name", @@ -54,4 +66,5 @@ def register_dynamic_lora_routes(router: APIRouter): return Response(status_code=200, content=response) - return router + # register the router + app.include_router(router) diff --git a/vllm/entrypoints/serve/profile/__init__.py b/vllm/entrypoints/serve/profile/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/entrypoints/serve/profile/api_router.py b/vllm/entrypoints/serve/profile/api_router.py new file mode 100644 index 0000000000000..eeed6b45ef4e9 --- /dev/null +++ b/vllm/entrypoints/serve/profile/api_router.py @@ -0,0 +1,46 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +from fastapi import APIRouter, FastAPI, Request +from fastapi.responses import Response + +from vllm.config import ProfilerConfig +from vllm.engine.protocol import EngineClient +from vllm.logger import init_logger + +logger = init_logger(__name__) + +router = APIRouter() + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +@router.post("/start_profile") +async def start_profile(raw_request: Request): + logger.info("Starting profiler...") + await engine_client(raw_request).start_profile() + logger.info("Profiler started.") + return Response(status_code=200) + + +@router.post("/stop_profile") +async def stop_profile(raw_request: Request): + logger.info("Stopping profiler...") + await engine_client(raw_request).stop_profile() + logger.info("Profiler stopped.") + return Response(status_code=200) + + +def attach_router(app: FastAPI): + profiler_config = getattr(app.state.args, "profiler_config", None) + assert profiler_config is None or isinstance(profiler_config, ProfilerConfig) + if profiler_config is not None and profiler_config.profiler is not None: + logger.warning_once( + "Profiler with mode '%s' is enabled in the " + "API server. This should ONLY be used for local development!", + profiler_config.profiler, + ) + app.include_router(router) diff --git a/vllm/entrypoints/serve/rlhf/__init__.py b/vllm/entrypoints/serve/rlhf/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/entrypoints/serve/rlhf/api_router.py b/vllm/entrypoints/serve/rlhf/api_router.py new file mode 100644 index 0000000000000..3b37840ae0899 --- /dev/null +++ b/vllm/entrypoints/serve/rlhf/api_router.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +from http import HTTPStatus + +from fastapi import APIRouter, FastAPI, Query, Request +from fastapi.responses import JSONResponse + +from vllm.engine.protocol import EngineClient +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +router = APIRouter() + + +@router.post("/pause") +async def pause_generation( + raw_request: Request, + wait_for_inflight_requests: bool = Query(False), + clear_cache: bool = Query(True), +) -> JSONResponse: + """Pause generation requests to allow weight updates. + + Args: + wait_for_inflight_requests: When ``True`` waits for in-flight + requests to finish before pausing. When ``False`` (default), + aborts any in-flight requests immediately. + clear_cache: Whether to clear KV/prefix caches after draining. + """ + + engine = engine_client(raw_request) + + try: + await engine.pause_generation( + wait_for_inflight_requests=wait_for_inflight_requests, + clear_cache=clear_cache, + ) + return JSONResponse( + content={"status": "paused"}, + status_code=HTTPStatus.OK.value, + ) + + except ValueError as err: + return JSONResponse( + content={"error": str(err)}, + status_code=HTTPStatus.BAD_REQUEST.value, + ) + except Exception as err: # pragma: no cover - defensive + logger.exception("Failed to pause generation") + return JSONResponse( + content={"error": f"Failed to pause generation: {err}"}, + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + ) + + +@router.post("/resume") +async def resume_generation(raw_request: Request) -> JSONResponse: + """Resume generation after a pause.""" + + engine = engine_client(raw_request) + + try: + await engine.resume_generation() + return JSONResponse( + content={"status": "resumed"}, + status_code=HTTPStatus.OK.value, + ) + except Exception as err: # pragma: no cover - defensive + logger.exception("Failed to resume generation") + return JSONResponse( + content={"error": f"Failed to resume generation: {err}"}, + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + ) + + +@router.get("/is_paused") +async def is_paused(raw_request: Request) -> JSONResponse: + """Return the current pause status.""" + + engine = engine_client(raw_request) + + try: + paused = await engine.is_paused() + except Exception as err: # pragma: no cover - defensive + logger.exception("Failed to fetch pause status") + return JSONResponse( + content={"error": f"Failed to fetch pause status: {err}"}, + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + ) + + return JSONResponse(content={"is_paused": paused}) + + +def attach_router(app: FastAPI): + app.include_router(router) diff --git a/vllm/entrypoints/serve/sleep/__init__.py b/vllm/entrypoints/serve/sleep/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/entrypoints/serve/sleep/api_router.py b/vllm/entrypoints/serve/sleep/api_router.py new file mode 100644 index 0000000000000..bc01e185315c8 --- /dev/null +++ b/vllm/entrypoints/serve/sleep/api_router.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +from fastapi import APIRouter, FastAPI, Request +from fastapi.responses import JSONResponse, Response + +import vllm.envs as envs +from vllm.engine.protocol import EngineClient +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +router = APIRouter() + + +@router.post("/sleep") +async def sleep(raw_request: Request): + # get POST params + level = raw_request.query_params.get("level", "1") + await engine_client(raw_request).sleep(int(level)) + # FIXME: in v0 with frontend multiprocessing, the sleep command + # is sent but does not finish yet when we return a response. + return Response(status_code=200) + + +@router.post("/wake_up") +async def wake_up(raw_request: Request): + tags = raw_request.query_params.getlist("tags") + if tags == []: + # set to None to wake up all tags if no tags are provided + tags = None + logger.info("wake up the engine with tags: %s", tags) + await engine_client(raw_request).wake_up(tags) + # FIXME: in v0 with frontend multiprocessing, the wake-up command + # is sent but does not finish yet when we return a response. + return Response(status_code=200) + + +@router.get("/is_sleeping") +async def is_sleeping(raw_request: Request): + logger.info("check whether the engine is sleeping") + is_sleeping = await engine_client(raw_request).is_sleeping() + return JSONResponse(content={"is_sleeping": is_sleeping}) + + +def attach_router(app: FastAPI): + if not envs.VLLM_SERVER_DEV_MODE: + return + logger.warning( + "SECURITY WARNING: Development endpoints are enabled! " + "This should NOT be used in production!" + ) + + app.include_router(router) diff --git a/vllm/entrypoints/serve/tokenize/__init__.py b/vllm/entrypoints/serve/tokenize/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/entrypoints/serve/tokenize/api_router.py b/vllm/entrypoints/serve/tokenize/api_router.py new file mode 100644 index 0000000000000..a10e78c8d28ee --- /dev/null +++ b/vllm/entrypoints/serve/tokenize/api_router.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +from http import HTTPStatus + +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse +from typing_extensions import assert_never + +from vllm.entrypoints.openai.api_server import validate_json_request +from vllm.entrypoints.openai.protocol import ( + DetokenizeRequest, + DetokenizeResponse, + ErrorResponse, + TokenizeRequest, + TokenizeResponse, +) +from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization +from vllm.entrypoints.utils import ( + with_cancellation, +) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def tokenization(request: Request) -> OpenAIServingTokenization: + return request.app.state.openai_serving_tokenization + + +router = APIRouter() + + +@router.post( + "/tokenize", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse}, + }, +) +@with_cancellation +async def tokenize(request: TokenizeRequest, raw_request: Request): + handler = tokenization(raw_request) + + try: + generator = await handler.create_tokenize(request, raw_request) + except NotImplementedError as e: + raise HTTPException( + status_code=HTTPStatus.NOT_IMPLEMENTED.value, detail=str(e) + ) from e + except Exception as e: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e + + if isinstance(generator, ErrorResponse): + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) + elif isinstance(generator, TokenizeResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post( + "/detokenize", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) +@with_cancellation +async def detokenize(request: DetokenizeRequest, raw_request: Request): + handler = tokenization(raw_request) + + try: + generator = await handler.create_detokenize(request, raw_request) + except OverflowError as e: + raise RequestValidationError(errors=[str(e)]) from e + except Exception as e: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e + + if isinstance(generator, ErrorResponse): + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) + elif isinstance(generator, DetokenizeResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +def attach_router(app: FastAPI): + if getattr(app.state.args, "enable_tokenizer_info_endpoint", False): + """Conditionally register the tokenizer info endpoint if enabled.""" + + @router.get("/tokenizer_info") + async def get_tokenizer_info(raw_request: Request): + """Get comprehensive tokenizer information.""" + result = await tokenization(raw_request).get_tokenizer_info() + return JSONResponse( + content=result.model_dump(), + status_code=result.error.code + if isinstance(result, ErrorResponse) + else 200, + ) + + app.include_router(router) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/serve/tokenize/serving.py similarity index 96% rename from vllm/entrypoints/openai/serving_tokenization.py rename to vllm/entrypoints/serve/tokenize/serving.py index 979da02d14500..0b07f0b18dfd5 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/serve/tokenize/serving.py @@ -21,6 +21,7 @@ from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.renderer import RenderConfig +from vllm.inputs import TokensPrompt from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike @@ -80,11 +81,8 @@ class OpenAIServingTokenization(OpenAIServing): ) if error_check_ret is not None: return error_check_ret - ( - _, - _, - engine_prompts, - ) = await self._preprocess_chat( + + _, engine_prompts = await self._preprocess_chat( request, tokenizer, request.messages, @@ -141,7 +139,10 @@ class OpenAIServingTokenization(OpenAIServing): tokenizer = await self.engine_client.get_tokenizer() self._log_inputs( - request_id, request.tokens, params=None, lora_request=lora_request + request_id, + TokensPrompt(prompt_token_ids=request.tokens), + params=None, + lora_request=lora_request, ) prompt_input = await self._tokenize_prompt_input_async( diff --git a/vllm/entrypoints/tool.py b/vllm/entrypoints/tool.py index c74ce1ee16de1..4feed827385d1 100644 --- a/vllm/entrypoints/tool.py +++ b/vllm/entrypoints/tool.py @@ -1,12 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json import os from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any +from openai.types.responses.response_function_tool_call_output_item import ( + ResponseFunctionToolCallOutputItem, +) from openai_harmony import Author, Message, Role, TextContent from vllm.logger import init_logger +from vllm.utils import random_uuid if TYPE_CHECKING: # Avoid circular import. @@ -46,6 +51,10 @@ class Tool(ABC): async def get_result(self, context: "ConversationContext") -> Any: pass + @abstractmethod + async def get_result_parsable_context(self, context: "ConversationContext") -> Any: + pass + class HarmonyBrowserTool(Tool): def __init__(self): @@ -81,6 +90,9 @@ class HarmonyBrowserTool(Tool): tool_output_msgs.append(msg) return tool_output_msgs + async def get_result_parsable_context(self, context: "ConversationContext") -> Any: + raise NotImplementedError("Not implemented yet") + @property def tool_config(self) -> Any: return self.browser_tool.tool_config @@ -138,6 +150,38 @@ class HarmonyPythonTool(Tool): tool_output_msgs.append(msg) return tool_output_msgs + async def get_result_parsable_context(self, context: "ConversationContext") -> Any: + """ + This function converts parsable context types to harmony and + back so we can use GPTOSS demo python tool + """ + from vllm.entrypoints.context import ParsableContext + + assert isinstance(context, ParsableContext) + + last_msg = context.parser.response_messages[-1] + args = json.loads(last_msg.arguments) + + last_msg_harmony = Message( + author=Author(role="assistant", name=None), + content=[TextContent(text=args["code"])], + channel="analysis", + recipient="python", + content_type="code", + ) + + tool_output_msgs = [] + async for msg in self.python_tool.process(last_msg_harmony): + processed = ResponseFunctionToolCallOutputItem( + id=f"fco_{random_uuid()}", + type="function_call_output", + call_id=f"call_{random_uuid()}", + output=msg.content[0].text, + status="completed", + ) + tool_output_msgs.append(processed) + return tool_output_msgs + @property def tool_config(self) -> Any: return self.python_tool.tool_config diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index daeeb995bc749..f4a633c69cb0b 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -30,7 +30,7 @@ from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.serving_models import LoRAModulePath from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.tokenizers import MistralTokenizer +from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils.argparse_utils import FlexibleArgumentParser logger = init_logger(__name__) diff --git a/vllm/envs.py b/vllm/envs.py index 46f1aa3222be7..d0f2798096263 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -37,7 +37,7 @@ if TYPE_CHECKING: VLLM_DISABLE_FLASHINFER_PREFILL: bool = False VLLM_DO_NOT_TRACK: bool = False VLLM_USAGE_SOURCE: str = "" - VLLM_CONFIGURE_LOGGING: int = 1 + VLLM_CONFIGURE_LOGGING: bool = True VLLM_LOGGING_LEVEL: str = "INFO" VLLM_LOGGING_PREFIX: str = "" VLLM_LOGGING_STREAM: str = "ext://sys.stdout" @@ -72,14 +72,14 @@ if TYPE_CHECKING: VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25 VLLM_VIDEO_LOADER_BACKEND: str = "opencv" VLLM_MEDIA_CONNECTOR: str = "http" - VLLM_MM_INPUT_CACHE_GIB: int = 4 VLLM_TARGET_DEVICE: str = "cuda" - VLLM_MAIN_CUDA_VERSION: str = "12.8" + VLLM_MAIN_CUDA_VERSION: str = "12.9" + VLLM_FLOAT32_MATMUL_PRECISION: Literal["ieee", "tf32"] = "ieee" MAX_JOBS: str | None = None NVCC_THREADS: str | None = None VLLM_USE_PRECOMPILED: bool = False + VLLM_SKIP_PRECOMPILED_VERSION_SUFFIX: bool = False VLLM_DOCKER_BUILD_CONTEXT: bool = False - VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL: bool = False VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: bool = False CMAKE_BUILD_TYPE: Literal["Debug", "Release", "RelWithDebInfo"] | None = None VERBOSE: bool = False @@ -88,20 +88,23 @@ if TYPE_CHECKING: VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5 # seconds VLLM_PLUGINS: list[str] | None = None VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None - VLLM_TORCH_CUDA_PROFILE: bool = False + # Deprecated env variables for profiling, kept for backward compatibility + # See also vllm/config/profiler.py and `--profiler-config` argument + VLLM_TORCH_CUDA_PROFILE: str | None = None VLLM_TORCH_PROFILER_DIR: str | None = None - VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False - VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False - VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM: bool = False + VLLM_TORCH_PROFILER_RECORD_SHAPES: str | None = None + VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: str | None = None + VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM: str | None = None + VLLM_TORCH_PROFILER_WITH_STACK: str | None = None + VLLM_TORCH_PROFILER_WITH_FLOPS: str | None = None + VLLM_TORCH_PROFILER_USE_GZIP: str | None = None + VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL: str | None = None + VLLM_PROFILER_DELAY_ITERS: str | None = None + VLLM_PROFILER_MAX_ITERS: str | None = None + # End of deprecated env variables for profiling VLLM_USE_AOT_COMPILE: bool = False VLLM_USE_BYTECODE_HOOK: bool = False VLLM_FORCE_AOT_LOAD: bool = False - VLLM_TORCH_PROFILER_WITH_STACK: bool = True - VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False - VLLM_PROFILER_DELAY_ITERS: int = 0 - VLLM_PROFILER_MAX_ITERS: int = 0 - VLLM_TORCH_PROFILER_USE_GZIP: bool = True - VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL: bool = True VLLM_USE_TRITON_AWQ: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_SKIP_P2P_CHECK: bool = False @@ -144,6 +147,7 @@ if TYPE_CHECKING: VLLM_DP_MASTER_IP: str = "" VLLM_DP_MASTER_PORT: int = 0 VLLM_MOE_DP_CHUNK_SIZE: int = 256 + VLLM_ENABLE_MOE_DP_CHUNK: bool = True VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False VLLM_RAY_DP_PACK_STRATEGY: Literal["strict", "fill", "span"] = "strict" VLLM_MARLIN_USE_ATOMIC_ADD: bool = False @@ -175,6 +179,7 @@ if TYPE_CHECKING: VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost" VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5600 + VLLM_MOONCAKE_BOOTSTRAP_PORT: int = 8998 VLLM_ALL2ALL_BACKEND: Literal[ "naive", "pplx", @@ -197,6 +202,7 @@ if TYPE_CHECKING: VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480 + VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT: int = 480 VLLM_USE_CUDNN_PREFILL: bool = False VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False VLLM_ENABLE_CUDAGRAPH_GC: bool = False @@ -214,6 +220,7 @@ if TYPE_CHECKING: VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True VLLM_TUNED_CONFIG_FOLDER: str | None = None VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS: set[str] = set() + VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT: bool = False VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY: bool = False VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False @@ -232,6 +239,7 @@ if TYPE_CHECKING: VLLM_NCCL_INCLUDE_PATH: str | None = None VLLM_USE_FBGEMM: bool = False VLLM_GC_DEBUG: str = "" + VLLM_DEBUG_WORKSPACE: bool = False VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256 VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" @@ -445,10 +453,19 @@ environment_variables: dict[str, Callable[[], Any]] = { # Target device of vLLM, supporting [cuda (by default), # rocm, cpu] "VLLM_TARGET_DEVICE": lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda").lower(), - # Main CUDA version of vLLM, supporting [12.6, 12.8, 12.9], - # 12.8 is the default. This follows PyTorch but can be overridden. + # Main CUDA version of vLLM. This follows PyTorch but can be overridden. "VLLM_MAIN_CUDA_VERSION": lambda: os.getenv("VLLM_MAIN_CUDA_VERSION", "").lower() - or "12.8", + or "12.9", + # Controls PyTorch float32 matmul precision mode within vLLM workers. + # Accepted values: + # - "ieee" (default): force full IEEE FP32 matmul precision. + # - "tf32": enable TensorFloat32-based fast matmul. + "VLLM_FLOAT32_MATMUL_PRECISION": env_with_choices( + "VLLM_FLOAT32_MATMUL_PRECISION", + "ieee", + ["ieee", "tf32"], + case_sensitive=False, + ), # Maximum number of compilation jobs to run in parallel. # By default this is the number of CPUs "MAX_JOBS": lambda: os.getenv("MAX_JOBS", None), @@ -462,17 +479,16 @@ environment_variables: dict[str, Callable[[], Any]] = { .lower() in ("1", "true") or bool(os.environ.get("VLLM_PRECOMPILED_WHEEL_LOCATION")), + # If set, skip adding +precompiled suffix to version string + "VLLM_SKIP_PRECOMPILED_VERSION_SUFFIX": lambda: bool( + int(os.environ.get("VLLM_SKIP_PRECOMPILED_VERSION_SUFFIX", "0")) + ), # Used to mark that setup.py is running in a Docker build context, # in order to force the use of precompiled binaries. "VLLM_DOCKER_BUILD_CONTEXT": lambda: os.environ.get("VLLM_DOCKER_BUILD_CONTEXT", "") .strip() .lower() in ("1", "true"), - # Whether to force using nightly wheel in python build. - # This is used for testing the nightly wheel in python build. - "VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL": lambda: bool( - int(os.getenv("VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL", "0")) - ), # CMake build type # If not set, defaults to "Debug" or "RelWithDebInfo" # Available options: "Debug", "Release", "RelWithDebInfo" @@ -618,7 +634,9 @@ environment_variables: dict[str, Callable[[], Any]] = { # If set to 0, vllm will not configure logging # If set to 1, vllm will configure logging using the default configuration # or the configuration file specified by VLLM_LOGGING_CONFIG_PATH - "VLLM_CONFIGURE_LOGGING": lambda: int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")), + "VLLM_CONFIGURE_LOGGING": lambda: bool( + int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")) + ), "VLLM_LOGGING_CONFIG_PATH": lambda: os.getenv("VLLM_LOGGING_CONFIG_PATH"), # this is used for configuring the default logging level "VLLM_LOGGING_LEVEL": lambda: os.getenv("VLLM_LOGGING_LEVEL", "INFO").upper(), @@ -770,9 +788,6 @@ environment_variables: dict[str, Callable[[], Any]] = { # imported at runtime. # If a non-existing backend is used, an AssertionError will be thrown. "VLLM_MEDIA_CONNECTOR": lambda: os.getenv("VLLM_MEDIA_CONNECTOR", "http"), - # [DEPRECATED] Cache size (in GiB per process) for multimodal input cache - # Default is 4 GiB per API process + 4 GiB per engine core process - "VLLM_MM_INPUT_CACHE_GIB": lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")), # Path to the XLA persistent cache directory. # Only used for XLA devices such as TPUs. "VLLM_XLA_CACHE_PATH": lambda: os.path.expanduser( @@ -837,71 +852,52 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_LORA_RESOLVER_CACHE_DIR": lambda: os.getenv( "VLLM_LORA_RESOLVER_CACHE_DIR", None ), - # Enables torch CUDA profiling if set. - # On NVIDIA GPUs, this will start/stop cudaProfilerApi when triggered. - "VLLM_TORCH_CUDA_PROFILE": lambda: bool( - os.getenv("VLLM_TORCH_CUDA_PROFILE", "0") != "0" - ), + # Enables torch CUDA profiling if set to 1. + # Deprecated, see profiler_config. + "VLLM_TORCH_CUDA_PROFILE": lambda: os.getenv("VLLM_TORCH_CUDA_PROFILE"), # Enables torch profiler if set. - # Both AsyncLLM's CPU traces as well as workers' - # traces (CPU & GPU) will be saved under this directory. - # Note that it must be an absolute path. - "VLLM_TORCH_PROFILER_DIR": lambda: ( - None - if (val := os.getenv("VLLM_TORCH_PROFILER_DIR")) is None - else ( - val - if val.startswith("gs://") and val[5:] and val[5] != "/" - else os.path.abspath(os.path.expanduser(val)) - ) + # Deprecated, see profiler_config. + "VLLM_TORCH_PROFILER_DIR": lambda: os.getenv("VLLM_TORCH_PROFILER_DIR"), + # Enable torch profiler to record shapes if set to 1. + # Deprecated, see profiler_config. + "VLLM_TORCH_PROFILER_RECORD_SHAPES": lambda: ( + os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES") ), - # Enable torch profiler to record shapes if set - # VLLM_TORCH_PROFILER_RECORD_SHAPES=1. If not set, torch profiler will - # not record shapes. - "VLLM_TORCH_PROFILER_RECORD_SHAPES": lambda: bool( - os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES", "0") != "0" + # Enable torch profiler to profile memory if set to 1. + # Deprecated, see profiler_config. + "VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY": lambda: ( + os.getenv("VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY") ), - # Enable torch profiler to profile memory if set - # VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1. If not set, torch profiler - # will not profile memory. - "VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY": lambda: bool( - os.getenv("VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY", "0") != "0" + # Enable torch profiler to profile stack if set to 1. + # Deprecated, see profiler_config. + "VLLM_TORCH_PROFILER_WITH_STACK": lambda: ( + os.getenv("VLLM_TORCH_PROFILER_WITH_STACK") ), - # Enable torch profiler to profile stack if set - # VLLM_TORCH_PROFILER_WITH_STACK=1. If not set, torch profiler WILL - # profile stack by default. - "VLLM_TORCH_PROFILER_WITH_STACK": lambda: bool( - os.getenv("VLLM_TORCH_PROFILER_WITH_STACK", "1") != "0" + # Enable torch profiler to profile flops if set to 1. + # Deprecated, see profiler_config. + "VLLM_TORCH_PROFILER_WITH_FLOPS": lambda: ( + os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS") ), - # Enable torch profiler to profile flops if set - # VLLM_TORCH_PROFILER_WITH_FLOPS=1. If not set, torch profiler will - # not profile flops. - "VLLM_TORCH_PROFILER_WITH_FLOPS": lambda: bool( - os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS", "0") != "0" - ), - # Disable torch profiling of the AsyncLLMEngine process. - # If set to 1, will not profile the engine process. - "VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM": lambda: bool( - os.getenv("VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM", "0") != "0" + # Disable torch profiling of the AsyncLLMEngine process if set to 1. + # Deprecated, see profiler_config. + "VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM": lambda: ( + os.getenv("VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM") ), # Delay number of iterations before starting profiling when using # the torch/torch CUDA profiler. If set to 0, will start profiling immediately. - "VLLM_PROFILER_DELAY_ITERS": lambda: int( - os.getenv("VLLM_PROFILER_DELAY_ITERS", "0") - ), + # Deprecated, see profiler_config. + "VLLM_PROFILER_DELAY_ITERS": lambda: (os.getenv("VLLM_PROFILER_DELAY_ITERS")), # Maximum number of iterations to profile when using the torch/torch CUDA profiler. # If set to 0, will not limit the number of iterations. - "VLLM_PROFILER_MAX_ITERS": lambda: int(os.getenv("VLLM_PROFILER_MAX_ITERS", "0")), + "VLLM_PROFILER_MAX_ITERS": lambda: os.getenv("VLLM_PROFILER_MAX_ITERS"), # Control whether torch profiler gzip-compresses profiling files. - # Set VLLM_TORCH_PROFILER_USE_GZIP=0 to disable gzip (enabled by default). - "VLLM_TORCH_PROFILER_USE_GZIP": lambda: bool( - os.getenv("VLLM_TORCH_PROFILER_USE_GZIP", "1") != "0" - ), + # Deprecated, see profiler_config. + "VLLM_TORCH_PROFILER_USE_GZIP": lambda: os.getenv("VLLM_TORCH_PROFILER_USE_GZIP"), # Control whether torch profiler dumps the self_cuda_time_total table. - # Set VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL=0 to disable dumping - # (enabled by default). - "VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL": lambda: bool( - os.getenv("VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL", "1") != "0" + # Set to 0 to disable dumping the table. + # Deprecated, see profiler_config. + "VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL": lambda: ( + os.getenv("VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL") ), # If set, vLLM will use Triton implementations of AWQ. "VLLM_USE_TRITON_AWQ": lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))), @@ -1098,6 +1094,9 @@ environment_variables: dict[str, Callable[[], Any]] = { # rank. All DP ranks process the activations in VLLM_MOE_DP_CHUNK_SIZE # units. "VLLM_MOE_DP_CHUNK_SIZE": lambda: int(os.getenv("VLLM_MOE_DP_CHUNK_SIZE", "256")), + "VLLM_ENABLE_MOE_DP_CHUNK": lambda: bool( + int(os.getenv("VLLM_ENABLE_MOE_DP_CHUNK", "1")) + ), # Randomize inputs during dummy runs when using Data Parallel "VLLM_RANDOMIZE_DP_DUMMY_INPUTS": lambda: os.environ.get( "VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0" @@ -1259,6 +1258,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_NIXL_SIDE_CHANNEL_PORT": lambda: int( os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5600") ), + # Port used for Mooncake handshake between remote agents. + "VLLM_MOONCAKE_BOOTSTRAP_PORT": lambda: int( + os.getenv("VLLM_MOONCAKE_BOOTSTRAP_PORT", "8998") + ), # all2all backend for vllm's expert parallel communication # Available options: # - "naive": naive all2all implementation using broadcasts @@ -1368,6 +1371,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": lambda: int( os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "480") ), + # Timeout (in seconds) for MooncakeConnector in PD disaggregated setup. + "VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT": lambda: int( + os.getenv("VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT", "480") + ), # Controls whether or not to use cudnn prefill "VLLM_USE_CUDNN_PREFILL": lambda: bool( int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0")) @@ -1445,6 +1452,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_ALLREDUCE_USE_SYMM_MEM": lambda: bool( int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")) ), + # Experimental: use this to enable MCP tool calling for non harmony models + "VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT": lambda: bool( + int(os.getenv("VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", "0")) + ), # Allows vllm to find tuned config under customized folder "VLLM_TUNED_CONFIG_FOLDER": lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None), # Valid values are container,code_interpreter,web_search_preview @@ -1527,6 +1538,9 @@ environment_variables: dict[str, Callable[[], Any]] = { # - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger with # top 5 collected objects "VLLM_GC_DEBUG": lambda: os.getenv("VLLM_GC_DEBUG", ""), + # Debug workspace allocations. + # logging of workspace resize operations. + "VLLM_DEBUG_WORKSPACE": lambda: bool(int(os.getenv("VLLM_DEBUG_WORKSPACE", "0"))), # Disables parallel execution of shared_experts via separate cuda stream "VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool( int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0")) @@ -1568,6 +1582,12 @@ def __getattr__(name: str): raise AttributeError(f"module {__name__!r} has no attribute {name!r}") +def _is_envs_cache_enabled() -> bool: + """Checked if __getattr__ is wrapped with functools.cache""" + global __getattr__ + return hasattr(__getattr__, "cache_clear") + + def enable_envs_cache() -> None: """ Enables caching of environment variables. This is useful for performance @@ -1578,6 +1598,9 @@ def enable_envs_cache() -> None: runtime overhead. This also means that environment variables should NOT be updated after the service is initialized. """ + if _is_envs_cache_enabled(): + # Avoid wrapping functools.cache multiple times + return # Tag __getattr__ with functools.cache global __getattr__ __getattr__ = functools.cache(__getattr__) @@ -1587,6 +1610,17 @@ def enable_envs_cache() -> None: __getattr__(key) +def disable_envs_cache() -> None: + """ + Resets the environment variables cache. It could be used to isolate environments + between unit tests. + """ + global __getattr__ + # If __getattr__ is wrapped by functions.cache, unwrap the caching layer. + if _is_envs_cache_enabled(): + __getattr__ = __getattr__.__wrapped__ + + def __dir__(): return list(environment_variables.keys()) @@ -1649,7 +1683,6 @@ def compile_factors() -> dict[str, object]: "VLLM_MEDIA_CONNECTOR", "VLLM_ASSETS_CACHE", "VLLM_ASSETS_CACHE_MODEL_CLEAN", - "VLLM_MM_INPUT_CACHE_GIB", "VLLM_WORKER_MULTIPROC_METHOD", "VLLM_ENABLE_V1_MULTIPROCESSING", "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 173d366267e87..033cc1f544b3b 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -292,7 +292,7 @@ def set_forward_context( if num_tokens_across_dp is None: assert ubatch_slices is None assert num_tokens is not None - _, num_tokens_across_dp = coordinate_batch_across_dp( + _, num_tokens_across_dp, _ = coordinate_batch_across_dp( num_tokens_unpadded=num_tokens, parallel_config=vllm_config.parallel_config, allow_microbatching=False, diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index 211551be8e60b..71289277eb987 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -33,22 +33,31 @@ def parse_raw_prompts( if len(prompt) == 0: raise ValueError("please provide at least one prompt") + # case 2: array of strings if is_list_of(prompt, str): - # case 2: array of strings prompt = cast(list[str], prompt) return [TextPrompt(prompt=elem) for elem in prompt] + + # case 3: array of tokens if is_list_of(prompt, int): - # case 3: array of tokens prompt = cast(list[int], prompt) return [TokensPrompt(prompt_token_ids=prompt)] - if is_list_of(prompt, list): - prompt = cast(list[list[int]], prompt) - if len(prompt[0]) == 0: - raise ValueError("please provide at least one prompt") - if is_list_of(prompt[0], int): - # case 4: array of token arrays - return [TokensPrompt(prompt_token_ids=elem) for elem in prompt] + # case 4: array of token arrays + if is_list_of(prompt, list): + first = prompt[0] + if not isinstance(first, list): + raise ValueError("prompt expected to be a list of lists") + + if len(first) == 0: + raise ValueError("Please provide at least one prompt") + + # strict validation: every nested list must be list[int] + if not all(is_list_of(elem, int) for elem in prompt): + raise TypeError("Nested lists must contain only integers") + + prompt = cast(list[list[int]], prompt) + return [TokensPrompt(prompt_token_ids=elem) for elem in prompt] raise TypeError( "prompt must be a string, array of strings, " diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 2893a56b1190f..0372b06d0017f 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -198,7 +198,7 @@ class InputPreprocessor: ) -> dict[str, Any]: kwargs = dict[str, Any]() - if self.model_config.hf_config.model_type == "whisper": + if self.model_config.is_encoder_decoder: # For Whisper, special tokens should be provided by the user based # on the task and language of their request. Also needed to avoid # appending an EOS token to the prompt which disrupts generation. @@ -573,7 +573,6 @@ class InputPreprocessor: """ encoder_inputs: SingletonInputs decoder_inputs: SingletonInputs | None - if is_explicit_encoder_decoder_prompt(prompt): # `cast` is needed for mypy, but not pyright prompt_ = cast(ExplicitEncoderDecoderPrompt, prompt) @@ -585,7 +584,9 @@ class InputPreprocessor: if (decoder_input := prompt_["decoder_prompt"]) is None: decoder_inputs = None else: - decoder_inputs = self._prompt_to_llm_inputs(decoder_input) + decoder_inputs = self._prompt_to_llm_inputs( + decoder_input, tokenization_kwargs=tokenization_kwargs + ) # For multimodal model, override decoder prompt from processor # with explicit decoder prompt. if self.model_config.is_multimodal_model: diff --git a/vllm/logger.py b/vllm/logger.py index ad3123c0f0149..5506e09b8a65b 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -62,7 +62,7 @@ DEFAULT_LOGGING_CONFIG = { "loggers": { "vllm": { "handlers": ["vllm"], - "level": "DEBUG", + "level": envs.VLLM_LOGGING_LEVEL, "propagate": False, }, }, @@ -175,6 +175,9 @@ def _configure_vllm_root_logger() -> None: vllm_handler["stream"] = envs.VLLM_LOGGING_STREAM vllm_handler["formatter"] = "vllm_color" if _use_color() else "vllm" + vllm_loggers = logging_config["loggers"]["vllm"] + vllm_loggers["level"] = envs.VLLM_LOGGING_LEVEL + if envs.VLLM_LOGGING_CONFIG_PATH: if not path.exists(envs.VLLM_LOGGING_CONFIG_PATH): raise RuntimeError( @@ -226,6 +229,11 @@ def suppress_logging(level: int = logging.INFO) -> Generator[None, Any, None]: # guaranteed by the Python GIL. _configure_vllm_root_logger() +# Transformers uses httpx to access the Hugging Face Hub. httpx is quite verbose, +# so we set its logging level to WARNING when vLLM's logging level is INFO. +if envs.VLLM_LOGGING_LEVEL == "INFO": + logging.getLogger("httpx").setLevel(logging.WARNING) + logger = init_logger(__name__) diff --git a/vllm/lora/lora_model.py b/vllm/lora/lora_model.py new file mode 100644 index 0000000000000..f5e36697ed18c --- /dev/null +++ b/vllm/lora/lora_model.py @@ -0,0 +1,246 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os + +import safetensors +import torch + +from vllm.logger import init_logger +from vllm.lora.lora_weights import LoRALayerWeights +from vllm.lora.peft_helper import PEFTHelper +from vllm.lora.utils import ( + get_lora_id, + is_base_embeddding_weights, + is_regex_target_modules, + parse_fine_tuned_lora_name, +) +from vllm.model_executor.model_loader.tensorizer import TensorizerConfig +from vllm.model_executor.models.utils import WeightsMapper +from vllm.utils.platform_utils import is_pin_memory_available + +logger = init_logger(__name__) + + +class LoRAModel: + """A LoRA fine-tuned model.""" + + def __init__( + self, + lora_model_id: int, + rank: int, + loras: dict[str, LoRALayerWeights], + ) -> None: + """ + Args: + lora_model_id: The integer id for the lora model. + rank: lora rank. + loras: module name -> weights for lora-replaced layers. + + """ + self.id = lora_model_id + + assert lora_model_id > 0, ( + f"a valid lora id should be greater than 0, got {self.id}" + ) + self.rank = rank + self.loras: dict[str, LoRALayerWeights] = loras + + def clone(self, lora_model_id: int) -> "LoRAModel": + """Return a copy of the object with different ids. + + Will share the underlying tensors.""" + return self.__class__( + lora_model_id, + rank=self.rank, + loras=self.loras.copy(), + ) + + def get_lora(self, module_name: str) -> LoRALayerWeights | None: + """Get LoRA for a given module by name""" + return self.loras.get(module_name, None) + + def check_lora_name(self, lora_name: str) -> bool: + return lora_name in self.loras + + @classmethod + def from_lora_tensors( + cls, + lora_model_id: int, + tensors: dict[str, torch.Tensor], + peft_helper: PEFTHelper, + device: str = "cuda", + dtype: torch.dtype | None = None, + model_vocab_size: int | None = None, + weights_mapper: WeightsMapper | None = None, + ) -> "LoRAModel": + """Create a LoRAModel from a dictionary of tensors.""" + pin_memory = str(device) == "cpu" and is_pin_memory_available() + loras: dict[str, LoRALayerWeights] = {} + for tensor_name, tensor in tensors.items(): + if is_base_embeddding_weights(tensor_name): + continue + module_name, is_lora_a = parse_fine_tuned_lora_name( + tensor_name, weights_mapper + ) + if module_name not in loras: + loras[module_name] = LoRALayerWeights.from_config( + module_name, peft_helper + ) + + if is_lora_a: + if ( + "lora_embedding_A" in tensor_name + and model_vocab_size is not None + and model_vocab_size != tensor.shape[1] + ): + raise RuntimeError( + f"The embedding LoRA size({tensor.shape[1]}) must be consistent" + f" with the base model's vocabulary size({model_vocab_size})." + ) + loras[module_name].lora_a = tensor.to(device=device, dtype=dtype) + if pin_memory: + loras[module_name].lora_a = loras[module_name].lora_a.pin_memory() + else: + loras[module_name].lora_b = tensor.to(device=device, dtype=dtype) + + if pin_memory: + loras[module_name].lora_b = loras[module_name].lora_b.pin_memory() + + return cls(lora_model_id, peft_helper.r, loras) + + @classmethod + def from_local_checkpoint( + cls, + lora_dir: str, + expected_lora_modules: set[str], + peft_helper: PEFTHelper, + *, + lora_model_id: int | None = None, + device: str = "cuda", + dtype: torch.dtype | None = None, + model_vocab_size: int | None = None, + weights_mapper: WeightsMapper | None = None, + tensorizer_config_dict: dict | None = None, + ) -> "LoRAModel": + """Create a LoRAModel from a local checkpoint. + + Args: + lora_dir: The local path that has lora data. + expected_lora_modules: Name of modules that are expected to be + replaced by lora. + peft_helper: Loaded lora configuration information. + lora_model_id: LoRA model id. If not given, automatically set by + a global counter. + device: Device where the lora model is loaded. + dtype: dtype of the lora model weights. + + Returns: + Loaded LoRA Model. + """ + lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") + lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") + lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt") + + tensors: dict[str, torch.Tensor] = {} + unexpected_modules: list[list[str] | str] = [] + + def check_unexpected_modules(modules: dict): + for lora_module in modules.keys(): # noqa + if is_base_embeddding_weights(lora_module): + continue + # Handle PEFT file format where experts.base_layer is the + # gate_up_proj and experts is the down_proj + if "base_layer" in lora_module: + continue + module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper) + # Case for expert lora weights + if ".experts" in module_name: + expert_idx = module_name.find(".experts") + expert_suffix = module_name[expert_idx + 1 :] + if expert_suffix not in expected_lora_modules: + unexpected_modules.append(module_name) + + elif module_name.rsplit(".", 1)[-1] not in expected_lora_modules: + unexpected_modules.append(module_name) + + if unexpected_modules: + raise ValueError( + f"While loading {lora_dir}, expected" + f" target modules in {expected_lora_modules}" + f" but received {unexpected_modules}." + f" Please verify that the loaded LoRA module is correct" + ) + + if tensorizer_config_dict: + from tensorizer import TensorDeserializer + + tensorizer_config = TensorizerConfig(**tensorizer_config_dict) + lora_tensor_path = os.path.join( + tensorizer_config.tensorizer_dir, "adapter_model.tensors" + ) + tensorizer_args = tensorizer_config._construct_tensorizer_args() + tensors = TensorDeserializer( + lora_tensor_path, + dtype=tensorizer_config.dtype, + **tensorizer_args.deserialization_kwargs, + ) + check_unexpected_modules(tensors) + + elif os.path.isfile(lora_tensor_path): + # Find unexpected modules. + # Use safetensor key as a source of truth to find expected modules. + # in peft if you have target_modules A, B, C and C does not exist + # in the model it won’t error and model will be trained with A, B + # loraified. C won’t exist in the safetensor but it will exist in + # the target_modules of the adapter_config.json. + unexpected_modules = [] + with safetensors.safe_open(lora_tensor_path, framework="pt") as f: # type: ignore + # Load tensors if there are only expected modules. + check_unexpected_modules(f) + for module in f.keys(): # noqa + tensors[module] = f.get_tensor(module) + elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path): + # When a bin/pt file is provided, we rely on config to find + # unexpected modules. + unexpected_modules = [] + target_modules = peft_helper.target_modules + if not isinstance(target_modules, list): + target_modules = [target_modules] + for module in target_modules: + # Compatible with more modules, + # such as:layers.11.self_attn.k_proj + part_name = module.split(".")[-1] + if part_name not in expected_lora_modules: + unexpected_modules.append(module) + # loaded lora's target modules must be a subset of + # expected_lora_modules. It is not reliable. See + # https://github.com/vllm-project/vllm/pull/5909. But there's no + # other better mechanism. + if unexpected_modules and not is_regex_target_modules( + peft_helper.target_modules, expected_lora_modules + ): + raise ValueError( + f"While loading {lora_dir}, expected" + f" target modules in {expected_lora_modules}" + f" but received {unexpected_modules}." + f" Please verify that the loaded LoRA module is correct" + ) + lora_file_path = ( + lora_bin_file_path + if os.path.isfile(lora_bin_file_path) + else lora_pt_file_path + ) + tensors = torch.load(lora_file_path, map_location=device, weights_only=True) + else: + raise ValueError(f"{lora_dir} doesn't contain tensors") + + return cls.from_lora_tensors( + lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id, + tensors=tensors, + peft_helper=peft_helper, + device=device, + dtype=dtype, + model_vocab_size=model_vocab_size, + weights_mapper=weights_mapper, + ) diff --git a/vllm/lora/models.py b/vllm/lora/model_manager.py similarity index 72% rename from vllm/lora/models.py rename to vllm/lora/model_manager.py index f568b8b9ba595..44e0448d92de0 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/model_manager.py @@ -2,38 +2,32 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math -import os from collections.abc import Callable from typing import TypeVar import regex as re -import safetensors.torch import torch from torch import nn from vllm.config.lora import LoRAConfig from vllm.logger import init_logger from vllm.lora.layers import BaseLayerWithLoRA, FusedMoE3DWithLoRA, LoRAMapping +from vllm.lora.lora_model import LoRAModel from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights -from vllm.lora.peft_helper import PEFTHelper from vllm.lora.punica_wrapper import get_punica_wrapper from vllm.lora.utils import ( from_layer, from_layer_logits_processor, get_supported_lora_modules, - is_base_embeddding_weights, is_moe_model, - is_regex_target_modules, - parse_fine_tuned_lora_name, process_packed_modules_mapping, replace_submodule, ) from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.models import SupportsLoRA, supports_multimodal from vllm.model_executor.models.interfaces import is_pooling_model from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper +from vllm.model_executor.models.utils import PPMissingLayer from vllm.utils.cache import LRUCache from vllm.utils.platform_utils import is_pin_memory_available @@ -53,239 +47,6 @@ class AdapterLRUCache(LRUCache[int, T]): return super()._on_remove(key, value) -_GLOBAL_LORA_ID = 0 - - -def get_lora_id(): - global _GLOBAL_LORA_ID - _GLOBAL_LORA_ID += 1 - return _GLOBAL_LORA_ID - - -class LoRAModel: - """A LoRA fine-tuned model.""" - - def __init__( - self, - lora_model_id: int, - rank: int, - loras: dict[str, LoRALayerWeights], - ) -> None: - """ - Args: - lora_model_id: The integer id for the lora model. - rank: lora rank. - loras: module name -> weights for lora-replaced layers. - - """ - self.id = lora_model_id - - assert lora_model_id > 0, ( - f"a valid lora id should be greater than 0, got {self.id}" - ) - self.rank = rank - self.loras: dict[str, LoRALayerWeights] = loras - - def clone(self, lora_model_id: int) -> "LoRAModel": - """Return a copy of the object with different ids. - - Will share the underlying tensors.""" - return self.__class__( - lora_model_id, - rank=self.rank, - loras=self.loras.copy(), - ) - - def get_lora(self, module_name: str) -> LoRALayerWeights | None: - """Get LoRA for a given module by name""" - return self.loras.get(module_name, None) - - def check_lora_name(self, lora_name: str) -> bool: - return lora_name in self.loras - - @classmethod - def from_lora_tensors( - cls, - lora_model_id: int, - tensors: dict[str, torch.Tensor], - peft_helper: PEFTHelper, - device: str = "cuda", - dtype: torch.dtype | None = None, - model_vocab_size: int | None = None, - weights_mapper: WeightsMapper | None = None, - ) -> "LoRAModel": - """Create a LoRAModel from a dictionary of tensors.""" - pin_memory = str(device) == "cpu" and is_pin_memory_available() - loras: dict[str, LoRALayerWeights] = {} - for tensor_name, tensor in tensors.items(): - if is_base_embeddding_weights(tensor_name): - continue - module_name, is_lora_a = parse_fine_tuned_lora_name( - tensor_name, weights_mapper - ) - if module_name not in loras: - loras[module_name] = LoRALayerWeights.from_config( - module_name, peft_helper - ) - - if is_lora_a: - if ( - "lora_embedding_A" in tensor_name - and model_vocab_size is not None - and model_vocab_size != tensor.shape[1] - ): - raise RuntimeError( - f"The embedding LoRA size({tensor.shape[1]}) must be consistent" - f" with the base model's vocabulary size({model_vocab_size})." - ) - loras[module_name].lora_a = tensor.to(device=device, dtype=dtype) - if pin_memory: - loras[module_name].lora_a = loras[module_name].lora_a.pin_memory() - else: - loras[module_name].lora_b = tensor.to(device=device, dtype=dtype) - - if pin_memory: - loras[module_name].lora_b = loras[module_name].lora_b.pin_memory() - - return cls(lora_model_id, peft_helper.r, loras) - - @classmethod - def from_local_checkpoint( - cls, - lora_dir: str, - expected_lora_modules: set[str], - peft_helper: PEFTHelper, - *, - lora_model_id: int | None = None, - device: str = "cuda", - dtype: torch.dtype | None = None, - model_vocab_size: int | None = None, - weights_mapper: WeightsMapper | None = None, - tensorizer_config_dict: dict | None = None, - ) -> "LoRAModel": - """Create a LoRAModel from a local checkpoint. - - Args: - lora_dir: The local path that has lora data. - expected_lora_modules: Name of modules that are expected to be - replaced by lora. - peft_helper: Loaded lora configuration information. - lora_model_id: LoRA model id. If not given, automatically set by - a global counter. - device: Device where the lora model is loaded. - dtype: dtype of the lora model weights. - - Returns: - Loaded LoRA Model. - """ - lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") - lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") - lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt") - - tensors: dict[str, torch.Tensor] = {} - unexpected_modules: list[list[str] | str] = [] - - def check_unexpected_modules(modules: dict): - for lora_module in modules.keys(): # noqa - if is_base_embeddding_weights(lora_module): - continue - # Handle PEFT file format where experts.base_layer is the - # gate_up_proj and experts is the down_proj - if "base_layer" in lora_module: - continue - module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper) - # Case for expert lora weights - if ".experts" in module_name: - expert_idx = module_name.find(".experts") - expert_suffix = module_name[expert_idx + 1 :] - if expert_suffix not in expected_lora_modules: - unexpected_modules.append(module_name) - - elif module_name.rsplit(".", 1)[-1] not in expected_lora_modules: - unexpected_modules.append(module_name) - - if unexpected_modules: - raise ValueError( - f"While loading {lora_dir}, expected" - f" target modules in {expected_lora_modules}" - f" but received {unexpected_modules}." - f" Please verify that the loaded LoRA module is correct" - ) - - if tensorizer_config_dict: - from tensorizer import TensorDeserializer - - tensorizer_config = TensorizerConfig(**tensorizer_config_dict) - lora_tensor_path = os.path.join( - tensorizer_config.tensorizer_dir, "adapter_model.tensors" - ) - tensorizer_args = tensorizer_config._construct_tensorizer_args() - tensors = TensorDeserializer( - lora_tensor_path, - dtype=tensorizer_config.dtype, - **tensorizer_args.deserialization_kwargs, - ) - check_unexpected_modules(tensors) - - elif os.path.isfile(lora_tensor_path): - # Find unexpected modules. - # Use safetensor key as a source of truth to find expected modules. - # in peft if you have target_modules A, B, C and C does not exist - # in the model it won’t error and model will be trained with A, B - # loraified. C won’t exist in the safetensor but it will exist in - # the target_modules of the adapter_config.json. - unexpected_modules = [] - with safetensors.safe_open(lora_tensor_path, framework="pt") as f: # type: ignore - # Load tensors if there are only expected modules. - check_unexpected_modules(f) - for module in f.keys(): # noqa - tensors[module] = f.get_tensor(module) - elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path): - # When a bin/pt file is provided, we rely on config to find - # unexpected modules. - unexpected_modules = [] - target_modules = peft_helper.target_modules - if not isinstance(target_modules, list): - target_modules = [target_modules] - for module in target_modules: - # Compatible with more modules, - # such as:layers.11.self_attn.k_proj - part_name = module.split(".")[-1] - if part_name not in expected_lora_modules: - unexpected_modules.append(module) - # loaded lora's target modules must be a subset of - # expected_lora_modules. It is not reliable. See - # https://github.com/vllm-project/vllm/pull/5909. But there's no - # other better mechanism. - if unexpected_modules and not is_regex_target_modules( - peft_helper.target_modules, expected_lora_modules - ): - raise ValueError( - f"While loading {lora_dir}, expected" - f" target modules in {expected_lora_modules}" - f" but received {unexpected_modules}." - f" Please verify that the loaded LoRA module is correct" - ) - lora_file_path = ( - lora_bin_file_path - if os.path.isfile(lora_bin_file_path) - else lora_pt_file_path - ) - tensors = torch.load(lora_file_path, map_location=device, weights_only=True) - else: - raise ValueError(f"{lora_dir} doesn't contain tensors") - - return cls.from_lora_tensors( - lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id, - tensors=tensors, - peft_helper=peft_helper, - device=device, - dtype=dtype, - model_vocab_size=model_vocab_size, - weights_mapper=weights_mapper, - ) - - class LoRAModelManager: """A manager that manages multiple LoRA-fine-tuned models.""" @@ -574,9 +335,9 @@ class LoRAModelManager: def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): assert isinstance(module, BaseLayerWithLoRA), ( - f"Module {module_name} must be a BaseLayerWithLoRA instance," + f"Module {module_name} must be a BaseLayerWithLoRA instance, " + f"got {type(module)}" ) - f" got {type(module)}" self.modules[module_name] = module def create_dummy_lora( @@ -742,6 +503,32 @@ class LoRAModelManager: for lora in lora_model.loras.values(): lora.optimize() + first_lora: LoRALayerWeights = next(iter(lora_model.loras.values())) + assert first_lora.lora_a is not None + if isinstance(first_lora.lora_a, list): + lora_device = next(iter(first_lora.lora_a)) + else: + lora_device = first_lora.lora_a.device + # Execute pin_memory after LoRA weight merging, mainly because: + # 1. Some MoE models have a large number of LoRA weights. If we + # perform # pin_memory immediately after loading weights, the + # overhead is significant. + # 2. The weight packing above (e.g., pack_moe) may invalidate the + # pin_memory allocation, so we execute it after packing. + + pin_memory = str(lora_device) == "cpu" and is_pin_memory_available() + if pin_memory: + for lora in lora_model.loras.values(): + if isinstance(lora.lora_a, list): + for index in range(len(lora.lora_a)): + if lora.lora_a[index] is None: + continue + lora.lora_a[index] = lora.lora_a[index].pin_memory() + lora.lora_b[index] = lora.lora_b[index].pin_memory() + else: + lora.lora_a = lora.lora_a.pin_memory() + lora.lora_b = lora.lora_b.pin_memory() + def _get_lora_layer_weights( self, lora_model: LoRAModel, module_name: str ) -> LoRALayerWeights | None: diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 413ee8ecbbf96..34383cdf1767c 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -96,10 +96,14 @@ def _fused_moe_lora_kernel( slice_id = tl.program_id(axis=1) lora_idx = tl.program_id(axis=2) lora_id = tl.load(lora_ids + lora_idx) - moe_enabled = tl.load(adapter_enabled + lora_id) - if lora_id == -1 or moe_enabled == 0: + + if lora_id == -1: # Early exit for the no-lora case. return + moe_enabled = tl.load(adapter_enabled + lora_id) + if moe_enabled == 0: + # Early exit for the no moe lora case. + return max_loras = tl.num_programs(axis=2) grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K) diff --git a/vllm/lora/request.py b/vllm/lora/request.py index c97e435e32165..55756bdb103bd 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -14,11 +14,6 @@ class LoRARequest( """ Request for a LoRA adapter. - Note that this class should be used internally. For online - serving, it is recommended to not allow users to use this class but - instead provide another layer of abstraction to prevent users from - accessing unauthorized LoRA adapters. - lora_int_id must be globally unique for a given adapter. This is currently not enforced in vLLM. """ diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 47484b2b984df..4d264c06826b8 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -48,6 +48,15 @@ if TYPE_CHECKING: logger = init_logger(__name__) +_GLOBAL_LORA_ID = 0 + + +def get_lora_id(): + global _GLOBAL_LORA_ID + _GLOBAL_LORA_ID += 1 + return _GLOBAL_LORA_ID + + _all_lora_classes: set[type[BaseLayerWithLoRA]] = { VocabParallelEmbeddingWithLoRA, ColumnParallelLinearWithLoRA, diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 7d77ba7247ef0..28c2a53d84e42 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -8,8 +8,8 @@ import torch from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.lora.models import ( - LoRAModel, +from vllm.lora.lora_model import LoRAModel +from vllm.lora.model_manager import ( LoRAModelManager, LRUCacheLoRAModelManager, create_lora_manager, diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 9ef696d80712c..66250f816f459 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -38,8 +38,9 @@ class CustomOp(nn.Module): ) return super().__new__(op_cls_to_instantiate) - def __init__(self): + def __init__(self, enforce_enable: bool = False): super().__init__() + self._enforce_enable = enforce_enable self._forward_method = self.dispatch_forward() def forward(self, *args, **kwargs): @@ -84,7 +85,11 @@ class CustomOp(nn.Module): # NOTE(woosuk): Here we assume that vLLM was built for only one # specific backend. Currently, we do not support dynamic dispatching. compilation_config = get_cached_compilation_config() - enabled = self.enabled() + + # CustomOp object can be enforce enabled, e.g., enable device-specific + # kernels in ViT models when enabling graph mode. By default, it will + # follow the compilation_config to determine whether enable itself. + enabled = self._enforce_enable or self.enabled() if enabled: compilation_config.enabled_custom_ops.update([self.__class__.name]) else: diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 3471ee327cf8c..7038d0868c7eb 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -159,6 +159,13 @@ class GeluAndMulSparse(CustomOp): self.approximate = approximate if approximate not in ("none", "tanh"): raise ValueError(f"Unknown approximate mode: {approximate}") + if current_platform.is_rocm() and approximate == "tanh": + # TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile + logger.warning_once( + "[ROCm] Pytorch's native GELU with tanh approximation is currently " + "unstable and produces garbage. Fallback to 'none' approximation." + ) + self.approximate = "none" # Sparsity. if activation_sparsity == 0.0: @@ -209,6 +216,12 @@ class GeluAndMul(CustomOp): self.op = torch.ops._C.gelu_and_mul elif approximate == "tanh": self.op = torch.ops._C.gelu_tanh_and_mul + if current_platform.is_rocm() and approximate == "tanh": + logger.warning_once( + "[ROCm] PyTorch's native GELU with tanh approximation is unstable " + "with torch.compile. For native implementation, fallback to 'none' " + "approximation. The custom kernel implementation is unaffected." + ) elif current_platform.is_xpu(): from vllm._ipex_ops import ipex_ops @@ -219,8 +232,12 @@ class GeluAndMul(CustomOp): def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" + # TODO: [ROCm] PyTorch's native GELU with tanh is unstable with torch.compile + approximate = self.approximate + if current_platform.is_rocm() and approximate == "tanh": + approximate = "none" d = x.shape[-1] // 2 - return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:] + return F.gelu(x[..., :d], approximate=approximate) * x[..., d:] def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 @@ -522,7 +539,16 @@ _ACTIVATION_REGISTRY = LazyDict( "gelu": lambda: nn.GELU(), "gelu_fast": lambda: FastGELU(), "gelu_new": lambda: NewGELU(), - "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"), + "gelu_pytorch_tanh": lambda: ( + # TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile + logger.warning_once( + "[ROCm] PyTorch's native GELU with tanh approximation is unstable. " + "Falling back to GELU(approximate='none')." + ), + nn.GELU(approximate="none"), + )[1] + if current_platform.is_rocm() + else nn.GELU(approximate="tanh"), "relu": lambda: nn.ReLU(), "relu2": lambda: ReLUSquaredActivation(), "silu": lambda: nn.SiLU(), diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py index 4154122636dcf..fde0826779eb1 100644 --- a/vllm/model_executor/layers/batch_invariant.py +++ b/vllm/model_executor/layers/batch_invariant.py @@ -6,7 +6,7 @@ from typing import Any import torch -import vllm.envs as envs +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.triton_utils import tl, triton @@ -935,7 +935,11 @@ def enable_batch_invariant_mode(): # Batch invariant matmuls are no longer needed after cublas overrides if not is_torch_equal_or_newer("2.10.0.dev"): - if current_platform.is_device_capability(100): + if ( + current_platform.is_device_capability_family(100) + or current_platform.is_device_capability(80) + or current_platform.is_device_capability(89) + ): # For PyTorch 2.9, B200 uses GEMV for bs=1 # Requires https://github.com/pytorch/pytorch/pull/166735 _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA") @@ -1000,27 +1004,30 @@ def vllm_is_batch_invariant() -> bool: return VLLM_BATCH_INVARIANT -def override_envs_for_invariance(): - curr_attn_backend = envs.VLLM_ATTENTION_BACKEND +def override_envs_for_invariance( + attention_backend: AttentionBackendEnum | None, +): supported_backends = [ - "FLASH_ATTN", # best supported backend - "FLASHINFER", - "FLASH_ATTN_MLA", + AttentionBackendEnum.FLASH_ATTN, # best supported backend + AttentionBackendEnum.FLASHINFER, + AttentionBackendEnum.FLASH_ATTN_MLA, + AttentionBackendEnum.TRITON_MLA, # Not yet supported MLA backends - # "FLASHMLA", - # "FLEX_ATTENTION", # IMA issue even if we disable batch invariance - # "FLASHINFER_MLA", https://github.com/vllm-project/vllm/pull/28967 - # "TRITON_MLA", + # AttentionBackendEnum.FLASHMLA, + # AttentionBackendEnum.FLEX_ATTENTION, # IMA issue + # AttentionBackendEnum.FLASHINFER_MLA, # PR #28967 ] - if curr_attn_backend not in supported_backends: + if attention_backend not in supported_backends: + supported_names = [b.name for b in supported_backends] + backend_name = attention_backend.name if attention_backend else None error = ( "VLLM batch_invariant mode requires an attention backend in " - f"{supported_backends}, but got '{curr_attn_backend}'. " - "Please set the 'VLLM_ATTENTION_BACKEND' environment variable " - "to one of the supported backends before enabling batch_invariant." + f"{supported_names}, but got '{backend_name}'. " + "Please use --attention-backend or attention_config to set " + "one of the supported backends before enabling batch_invariant." ) raise RuntimeError(error) - if os.environ["VLLM_ATTENTION_BACKEND"] != supported_backends[0]: + if attention_backend != supported_backends[0]: warning = ( "You are using a decode-invariant form of batch invariance. " "This will not be invariant between prefill and decode." @@ -1046,10 +1053,12 @@ def override_envs_for_invariance(): os.environ["VLLM_USE_AOT_COMPILE"] = "0" -def init_batch_invariance(): +def init_batch_invariance( + attention_backend: AttentionBackendEnum | None, +): # this will hit all the csrc overrides as well if vllm_is_batch_invariant(): - override_envs_for_invariance() + override_envs_for_invariance(attention_backend) enable_batch_invariant_mode() # Disable TF32 for batch invariance - it causes non-deterministic rounding diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 669abcb3d6ff1..d71cfc5ad8200 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -4,7 +4,10 @@ from contextlib import contextmanager from typing import Any -from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + RoutingMethodType, +) from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( FusedMoEMethodBase, ) @@ -49,6 +52,7 @@ __all__ = [ "FusedMoEPermuteExpertsUnpermute", "FusedMoEActivationFormat", "FusedMoEPrepareAndFinalize", + "RoutingMethodType", "SharedFusedMoE", "activation_without_mul", "override_config", @@ -60,14 +64,13 @@ if HAS_TRITON: from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( BatchedDeepGemmExperts, ) - from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 - BatchedTritonOrDeepGemmExperts, - ) from vllm.model_executor.layers.fused_moe.cutlass_moe import ( CutlassBatchedExpertsFp8, CutlassExpertsFp8, + CutlassExpertsW4A8Fp8, cutlass_moe_fp4, cutlass_moe_fp8, + cutlass_moe_w4a8_fp8, ) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( @@ -91,14 +94,15 @@ if HAS_TRITON: "grouped_topk", "cutlass_moe_fp8", "cutlass_moe_fp4", + "cutlass_moe_w4a8_fp8", "CutlassExpertsFp8", "CutlassBatchedExpertsFp8", + "CutlassExpertsW4A8Fp8", "TritonExperts", "BatchedTritonExperts", "DeepGemmExperts", "BatchedDeepGemmExperts", "TritonOrDeepGemmExperts", - "BatchedTritonOrDeepGemmExperts", ] else: # Some model classes directly use the custom ops. Add placeholders diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 53362277dae8a..15f6e3a18ed6c 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -287,7 +287,10 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): """ DeepGemm supports packed ue8m0 activation scales format in devices == sm100 """ - return is_deep_gemm_e8m0_used() and current_platform.is_device_capability(100) + return ( + is_deep_gemm_e8m0_used() + and current_platform.is_device_capability_family(100) + ) def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: # Let PrepareAndFinalize::finalize() decide the impl. diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py deleted file mode 100644 index e69e9fd307aeb..0000000000000 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ /dev/null @@ -1,180 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - BatchedDeepGemmExperts, -) -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts -from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout - - -class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__( - self, - max_num_tokens: int, - num_dispatchers: int, - quant_config: FusedMoEQuantConfig, - allow_deep_gemm: bool = False, - ): - super().__init__(quant_config) - - self.batched_triton_experts = BatchedTritonExperts( - max_num_tokens=max_num_tokens, - num_dispatchers=num_dispatchers, - quant_config=self.quant_config, - ) - - self.allow_deep_gemm = ( - allow_deep_gemm - and self.quant_config.use_fp8_w8a8 - and self.block_shape == get_mk_alignment_for_contiguous_layout() - ) - - self.batched_deep_gemm_experts = ( - BatchedDeepGemmExperts( - max_num_tokens=max_num_tokens, - num_dispatchers=num_dispatchers, - quant_config=self.quant_config, - ) - if self.allow_deep_gemm - else None - ) - - assert ( - self.batched_deep_gemm_experts is not None - or self.batched_triton_experts is not None - ) - - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - if self.batched_triton_experts is not None: - assert ( - self.batched_deep_gemm_experts is None - or self.batched_deep_gemm_experts.activation_formats - == self.batched_triton_experts.activation_formats - ) - return self.batched_triton_experts.activation_formats - else: - assert self.batched_deep_gemm_experts is not None - return self.batched_deep_gemm_experts.activation_formats - - def supports_chunking(self) -> bool: - bdge = self.batched_deep_gemm_experts - bte = self.batched_triton_experts - return (bdge is None or bdge.supports_chunking()) and ( - bte is None or bte.supports_chunking() - ) - - def supports_expert_map(self) -> bool: - bdge = self.batched_deep_gemm_experts - bte = self.batched_triton_experts - return (bdge is None or bdge.supports_expert_map()) and ( - bte is None or bte.supports_expert_map() - ) - - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: - bdge = self.batched_deep_gemm_experts - bte = self.batched_triton_experts - bdge_war = bdge.finalize_weight_and_reduce_impl() if bdge else None - bte_war = bte.finalize_weight_and_reduce_impl() if bte else None - is_bdge_war = bdge_war is not None - is_bte_war = bte_war is not None - - if is_bdge_war and is_bte_war: - assert bdge_war == bte_war, ( - "Both implementations should agree on WeightAndReduce impls. " - f"Got bdge_war: {bdge_war}, and bte_war: {bte_war}" - ) - - if bdge_war is not None: - return bdge_war - - assert bte_war is not None - return bte_war - - def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: - return act_dtype - - def workspace_shapes( - self, - M: int, - N: int, - K: int, - topk: int, - global_num_experts: int, - local_num_experts: int, - expert_tokens_metadata: mk.ExpertTokensMetadata | None, - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: - # Note: the deep gemm workspaces are strictly larger than the triton - # workspaces so we can be pessimistic here and allocate for DeepGemm - # even if we fall back to triton later, e.g. if expert maps are set. - if self.allow_deep_gemm: - assert self.batched_deep_gemm_experts is not None - return self.batched_deep_gemm_experts.workspace_shapes( - M, - N, - K, - topk, - global_num_experts, - local_num_experts, - expert_tokens_metadata, - ) - else: - assert self.batched_triton_experts is not None - return self.batched_triton_experts.workspace_shapes( - M, - N, - K, - topk, - global_num_experts, - local_num_experts, - expert_tokens_metadata, - ) - - def apply( - self, - output: torch.Tensor, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: torch.Tensor | None, - a1q_scale: torch.Tensor | None, - a2_scale: torch.Tensor | None, - workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: mk.ExpertTokensMetadata | None, - apply_router_weight_on_input: bool, - ): - experts = ( - self.batched_deep_gemm_experts - if self.allow_deep_gemm - else self.batched_triton_experts - ) - assert experts is not None - experts.apply( - output, - hidden_states, - w1, - w2, - topk_weights, - topk_ids, - activation, - global_num_experts, - expert_map, - a1q_scale, - a2_scale, - workspace13, - workspace2, - expert_tokens_meta, - apply_router_weight_on_input, - ) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 1826fafa8c4f5..a9a2990ca2b53 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -143,6 +143,7 @@ class FusedMoEQuantDesc: scale: Union[torch.Tensor, "PrecisionConfig", None] = None # Quantization alphas or gscales, used for nvfp4 types. + # W4A8 FP8: used for per-channel scales # TODO(bnell): put some of these in subclasses alpha_or_gscale: torch.Tensor | None = None @@ -345,6 +346,10 @@ class FusedMoEQuantConfig: def use_mxfp4_w4a16(self) -> bool: return self._a1.dtype is None and self._w1.dtype == "mxfp4" + @property + def use_mxfp4_w4a4(self) -> bool: + return self._a1.dtype == "mxfp4" and self._w1.dtype == "mxfp4" + @property def use_nvfp4_w4a4(self) -> bool: return self.quant_dtype == "nvfp4" @@ -438,7 +443,9 @@ class FusedMoEQuantConfig: - a1_scale: Optional scale to be used for a1. - a2_scale: Optional scale to be used for a2. - g1_alphas: Optional global quantization scales for w1 (for nvfp4). + per-channel scales for w1 (for W4A8 FP8). - g2_alphas: Optional global quantization scales for w2 (for nvfp4). + per-channel scales for w2 (for W4A8 FP8). - a1_gscale: Optional global quantization scales for a1 (for nvfp4). - a2_gscale: Optional global quantization scales for a2 (for nvfp4). - w1_bias: Optional biases for w1 (GPT OSS Triton). @@ -457,6 +464,7 @@ class FusedMoEQuantConfig: "mxfp4", "mxfp6_e3m2", "mxfp6_e2m3", + "int4", } if weight_dtype is None: @@ -535,6 +543,42 @@ def int8_w8a8_moe_quant_config( ) +def gptq_marlin_moe_quant_config( + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + weight_bits: int, + group_size: int, + w1_zp: torch.Tensor | None = None, + w2_zp: torch.Tensor | None = None, + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, +): + """ + Construct a quant config for gptq marlin quantization. + """ + from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape + + w_shape = None if group_size == -1 else GroupShape(row=1, col=group_size) + + # Activations are NOT quantized for GPTQ (fp16/bf16) + a_shape = w_shape # Same as weight shape for alignment + + # Determine weight dtype + if weight_bits == 4: + weight_dtype = "int4" + elif weight_bits == 8: + weight_dtype = torch.int8 + else: + raise ValueError(f"Unsupported weight_bits: {weight_bits}") + + return FusedMoEQuantConfig( + _a1=FusedMoEQuantDesc(dtype=None, shape=a_shape), + _a2=FusedMoEQuantDesc(dtype=None, shape=a_shape), + _w1=FusedMoEQuantDesc(weight_dtype, w_shape, w1_scale, None, w1_zp, w1_bias), + _w2=FusedMoEQuantDesc(weight_dtype, w_shape, w2_scale, None, w2_zp, w2_bias), + ) + + def mxfp4_w4a16_moe_quant_config( w1_scale: Union[torch.Tensor, "PrecisionConfig"], w2_scale: Union[torch.Tensor, "PrecisionConfig"], @@ -667,6 +711,67 @@ def int8_w8a16_moe_quant_config( ) +def int4_w4afp8_moe_quant_config( + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + g1_alphas: torch.Tensor, + g2_alphas: torch.Tensor, + per_act_token_quant: bool = False, + per_out_ch_quant: bool = False, + block_shape: list[int] | None = None, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for fp8 activations and int4 weights. + """ + return FusedMoEQuantConfig.make( + torch.float8_e4m3fn, # quant dtype for activations + w1_scale=w1_scale, + w2_scale=w2_scale, + g1_alphas=g1_alphas, + g2_alphas=g2_alphas, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_out_ch_quant, + block_shape=block_shape, + weight_dtype="int4", # weight dtype for weights + ) + + +def awq_marlin_moe_quant_config( + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + w1_zp: torch.Tensor | None, + w2_zp: torch.Tensor | None, + weight_bits: int, + group_size: int, + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for awq marlin quantization. + """ + from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape + + w_shape = None if group_size == -1 else GroupShape(row=1, col=group_size) + + # Activations are NOT quantized for AWQ (fp16/bf16) + a_shape = w_shape # Same as weight shape for alignment + + # Determine weight dtype + if weight_bits == 4: + weight_dtype = "int4" + elif weight_bits == 8: + weight_dtype = torch.int8 + else: + raise ValueError(f"Unsupported weight_bits: {weight_bits}") + + return FusedMoEQuantConfig( + _a1=FusedMoEQuantDesc(dtype=None, shape=a_shape), + _a2=FusedMoEQuantDesc(dtype=None, shape=a_shape), + _w1=FusedMoEQuantDesc(weight_dtype, w_shape, w1_scale, None, w1_zp, w1_bias), + _w2=FusedMoEQuantDesc(weight_dtype, w_shape, w2_scale, None, w2_zp, w2_bias), + ) + + def biased_moe_quant_config( w1_bias: torch.Tensor | None, w2_bias: torch.Tensor | None, diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000..a9f24c20a25a2 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 6753a19250b3b..4a0b4e82c1b39 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -1050,3 +1050,404 @@ def run_cutlass_block_scaled_fused_experts( return ( c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype) ).sum(dim=1) + + +# W4A8 +def run_cutlass_moe_w4a8_fp8( + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation_callable: Callable, + global_num_experts: int, + expert_map: torch.Tensor | None, + w1_scale: torch.Tensor | None, + w2_scale: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, + w1_chan_scale: torch.Tensor, + w2_chan_scale: torch.Tensor, + a_strides1: torch.Tensor, + a_strides2: torch.Tensor, + b_strides1: torch.Tensor, + b_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, + s_strides1: torch.Tensor, + s_strides2: torch.Tensor, + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: torch.Tensor | None, + out_dtype: torch.dtype, + per_act_token: bool, + per_out_ch: bool, + use_batched_format: bool, + topk_weights: torch.Tensor | None, + group_size: int, +): + a1q = hidden_states + M = a1q.size(0) + local_E = w1.size(0) + device = a1q.device + _, K, N_packed = w2.shape + N = N_packed * 8 # logical N, pack 8 int4 into 1 int32 + + assert per_act_token, "W4A8 must use per-token scales" + assert per_out_ch, "W4A8 must use per-channel scales" + assert w1_scale is not None + assert w2_scale is not None + assert w1_scale.dtype == torch.float8_e4m3fn + assert w2_scale.dtype == torch.float8_e4m3fn + assert w1.dtype == torch.int32 + assert w2.dtype == torch.int32 + assert w1_chan_scale.dtype == torch.float32 + assert w2_chan_scale.dtype == torch.float32 + assert w1.size(0) == w2.size(0), "Weights expert number mismatch" + assert a1q_scale is not None + assert a2_scale is None + assert out_dtype in [torch.bfloat16], f"Invalid output dtype: {out_dtype}" + if expert_map is not None: + assert expert_num_tokens is None + assert not use_batched_format, "batched format not supported yet" + assert group_size == 128, f"Only group size 128 supported but got {group_size=}" + + assert global_num_experts != -1 + assert w1.size(2) * 8 == K, ( + f"w1 hidden size mismatch: got {w1.size(2) * 8}, expected {K=}" + ) + + # Translate info from expert_map to topk_ids + if expert_map is not None: + local_topk_ids = torch.where( + expert_map[topk_ids] != -1, expert_map[topk_ids], -1 + ) + else: + local_topk_ids = topk_ids + + topk = local_topk_ids.size(1) + a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), (M * topk, K)) + mm1_out = _resize_cache(workspace13, (M * topk, N * 2)) + act_out = _resize_cache(workspace2, (M * topk, N)) + # original workspace are based on input hidden_states dtype (bf16) + quant_out = _resize_cache( + workspace13.view(dtype=torch.float8_e4m3fn), (M * topk, N) + ) + mm2_out = _resize_cache(workspace2, (M * topk, K)) + + problem_sizes1 = torch.empty( + (global_num_experts, 3), dtype=torch.int32, device=device + ) + problem_sizes2 = torch.empty( + (global_num_experts, 3), dtype=torch.int32, device=device + ) + + num_expert = global_num_experts if expert_map is None else expert_map.size(0) + # permuted a1q reuses workspace2 + a1q, a1q_scale, expert_offsets, inv_perm, _ = moe_permute( + a1q, + a1q_scale, + topk_ids, + num_expert, + local_E, + expert_map, + permuted_hidden_states=a1q_perm, + ) + expert_offsets = expert_offsets[:-1] + + # For RS gemm SwapAB is always enabled (swap logical M, N in the problem shape) + ops.get_cutlass_moe_mm_problem_sizes( + local_topk_ids, + problem_sizes1, + problem_sizes2, + global_num_experts, + N, + K, + force_swap_ab=True, + ) + + ops.cutlass_w4a8_moe_mm( + mm1_out, + a1q, + w1, + a1q_scale, + w1_chan_scale, + w1_scale, + group_size, + expert_offsets, + problem_sizes1, + a_strides1, + b_strides1, + c_strides1, + s_strides1, + ) + + activation_callable(act_out, mm1_out) + + a2q, a2q_scale = ops.scaled_fp8_quant( + act_out, a2_scale, use_per_token_if_dynamic=per_act_token, output=quant_out + ) + + if expert_map is not None: + mm2_out.fill_(0) + + ops.cutlass_w4a8_moe_mm( + mm2_out, + a2q, + w2, + a2q_scale, + w2_chan_scale, + w2_scale, + group_size, + expert_offsets, + problem_sizes2, + a_strides2, + b_strides2, + c_strides2, + s_strides2, + ) + + # for non-chunking mode the output is resized from workspace13 + # so we need to make sure mm2_out uses workspace2. + moe_unpermute( + out=output, + permuted_hidden_states=mm2_out, + topk_weights=topk_weights, + inv_permuted_idx=inv_perm, + ) + + +class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute): + def __init__( + self, + out_dtype: torch.dtype | None, + a_strides1: torch.Tensor, + a_strides2: torch.Tensor, + b_strides1: torch.Tensor, + b_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, + s_strides1: torch.Tensor, + s_strides2: torch.Tensor, + quant_config: FusedMoEQuantConfig, + group_size: int, + ): + super().__init__(quant_config) + self.out_dtype = out_dtype + self.a_strides1 = a_strides1 + self.a_strides2 = a_strides2 + self.b_strides1 = b_strides1 + self.b_strides2 = b_strides2 + self.c_strides1 = c_strides1 + self.c_strides2 = c_strides2 + self.s_strides1 = s_strides1 + self.s_strides2 = s_strides2 + self.group_size = group_size + + @property + def activation_formats( + self, + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return ( + mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard, + ) + + def supports_chunking(self) -> bool: + return True + + def supports_expert_map(self) -> bool: + return True + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # topk weights and reduction are fused in moe_unpermute cuda kernel + return TopKWeightAndReduceNoOP() + + def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: + return self.out_dtype if self.out_dtype is not None else act_dtype + + def workspace_shapes( + self, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + workspace1 = (M * topk, max(N, K)) + workspace2 = (M * topk, max(N // 2, K)) + output = (M, K) + return (workspace1, workspace2, output) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, + workspace13: torch.Tensor | None, + workspace2: torch.Tensor | None, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + apply_router_weight_on_input: bool, + ): + assert self.w1_zp is None, "w1_zp is not supported in CUTLASS MoE" + assert self.w2_zp is None, "w2_zp is not supported in CUTLASS MoE" + + expert_num_tokens = None + activation_callable = lambda o, i: self.activation(activation, o, i) + + use_batched_format = ( + self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts + ) + assert not use_batched_format, "batched format not supported" + + in_dtype = hidden_states.dtype + + run_cutlass_moe_w4a8_fp8( + output, + hidden_states, + w1, + w2, + topk_ids, + activation_callable, + global_num_experts, + expert_map, + self.w1_scale, + self.w2_scale, + a1q_scale, + a2_scale, + self.g1_alphas, # per-channel scales + self.g2_alphas, # per-channel scales + self.a_strides1, + self.a_strides2, + self.b_strides1, + self.b_strides2, + self.c_strides1, + self.c_strides2, + self.s_strides1, + self.s_strides2, + workspace13, + workspace2, + expert_num_tokens, + self.out_dtype if self.out_dtype is not None else in_dtype, + self.per_act_token_quant, + self.per_out_ch_quant, + use_batched_format, + topk_weights, + self.group_size, + ) + + +def cutlass_moe_w4a8_fp8( + a: torch.Tensor, + w1_q: torch.Tensor, + w2_q: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + a_strides1: torch.Tensor, + a_strides2: torch.Tensor, + b_strides1: torch.Tensor, + b_strides2: torch.Tensor, + c_strides1: torch.Tensor, + c_strides2: torch.Tensor, + s_strides1: torch.Tensor, + s_strides2: torch.Tensor, + quant_config: FusedMoEQuantConfig, + activation: str = "silu", + expert_map: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + global_num_experts: int = -1, + group_size: int = 128, +) -> torch.Tensor: + """ + This function computes a w4a8-quantized Mixture of Experts (MoE) layer + using two sets of quantized weights, w1_q and w2_q, and top-k gating + mechanism. The matrix multiplications are implemented with CUTLASS + mixed-dtype grouped gemm. + + Parameters: + - a (torch.Tensor): The input tensor to the MoE layer. + Shape: [M, K] + - w1_q (torch.Tensor): The first set of fp8-quantized expert weights. + Shape: [num_experts, 2*N, K // packed_factor] + - w2_q (torch.Tensor): The second set of fp8-quantized expert weights. + Shape: [num_experts, K, N // packed_factor] + - topk_weights (torch.Tensor): The weights of each token->expert mapping. + - topk_ids (torch.Tensor): The token->expert mappings. + - a_strides1 (torch.Tensor): The input strides for the first gemm. + Shape: [num_experts] + - a_strides2 (torch.Tensor): The input strides for the second gemm. + Shape: [num_experts] + - b_strides1 (torch.Tensor): The packed layout for the first gemm weights. + Shape: [num_experts, 3] + dtype: torch.int32 + - b_strides2 (torch.Tensor): The packed layout for the second gemm weights. + Shape: [num_experts, 3] + dtype: torch.int32 + - c_strides1 (torch.Tensor): The output strides for the first gemm. + Shape: [num_experts] + - c_strides2 (torch.Tensor): The output strides for the second gemm. + Shape: [num_experts] + - s_strides1 (torch.Tensor): strides for the group-wise scales for the first gemm. + Shape: [num_experts, 2] + dtype: torch.int64 + - s_strides2 (torch.Tensor): strides for the group-wise scales for the second gemm. + Shape: [num_experts, 2] + dtype: torch.int64 + - per_act_token (Optional[bool]): Whether the scale is per-token or + per-tensor. + - activation (str): The activation function to use. + - expert_map (Optional[torch.Tensor]): In the case of Expert parallel, + every Rank is responsible for a subset of experts. expert_map is a + mapping from global expert-id to local expert-id. When expert_map[i] + is -1, it means that this Rank is not responsible for global + expert-id i. + - apply_router_weight_on_input (bool): When true, the topk weights are + applied directly on the inputs. This is only applicable when topk is 1. + - global_num_experts (int): The total number of experts. + - group_size (int): The number of weights per scale factor + + Returns: + - torch.Tensor: The bf16 output tensor after applying the MoE layer. + """ + assert quant_config is not None + + num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0) + + fn = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + CutlassExpertsW4A8Fp8( + out_dtype=a.dtype, + a_strides1=a_strides1, + a_strides2=a_strides2, + b_strides1=b_strides1, + b_strides2=b_strides2, + c_strides1=c_strides1, + c_strides2=c_strides2, + s_strides1=s_strides1, + s_strides2=s_strides2, + quant_config=quant_config, + group_size=group_size, + ), + ) + + return fn( + a, + w1_q, + w2_q, + topk_weights, + topk_ids, + activation=activation, + global_num_experts=num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 86cdd25f2c873..5ca91768c9760 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -2,9 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch -from tqdm import tqdm -import vllm.envs as env import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( @@ -25,12 +23,14 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, + per_token_group_quant_fp8_packed_for_deepgemm, + silu_mul_per_token_group_quant_fp8_colmajor, ) from vllm.utils.deep_gemm import ( + DeepGemmQuantScaleFMT, get_mk_alignment_for_contiguous_layout, m_grouped_fp8_gemm_nt_contiguous, ) -from vllm.utils.func_utils import run_once from vllm.utils.import_utils import has_deep_gemm logger = init_logger(__name__) @@ -108,70 +108,6 @@ def _valid_deep_gemm( return True -@run_once -def warmup_deepgemm_gg_contiguous_kernels( - w1: torch.Tensor, - w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - num_topk: int, -): - """ - DeepGemm JITs the grouped-gemm kernels. The JIT'ing happens based on the - input tensor shapes. In this function, we construct all possible input - tensor shapes so all the kernels are JIT'ed and cached. - Note that this warmup is expected to happen during the model profile - call and not during actual model inference. - """ - - assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts" - - block_m = get_mk_alignment_for_contiguous_layout()[0] - num_experts = w1.size(0) - device = w1.device - - # This is the maximum GroupedGemm M size that we expect to run - # the grouped_gemm with. - MAX_M = compute_aligned_M( - env.VLLM_FUSED_MOE_CHUNK_SIZE, - num_topk, - num_experts, - block_m, - expert_tokens_meta=None, - ) - # Distribute expert-ids evenly. - MAX_BLOCKS = MAX_M // block_m - expert_ids_block = torch.randint( - low=0, high=num_experts, size=(MAX_BLOCKS,), device=device, dtype=torch.int32 - ) - expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0) - - def _warmup(w: torch.Tensor, w_scale: torch.Tensor): - _, n, k = w.size() - a1q = torch.empty((MAX_M, k), device=device).to(torch.float8_e4m3fn) - a1q_scales = torch.empty( - (MAX_M, k // block_m), device=device, dtype=torch.float32 - ) - out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16) - - pbar = tqdm( - total=MAX_BLOCKS, desc=f"DeepGemmExperts GEMM warmup (MAX_M={MAX_M})" - ) - num_tokens = MAX_M - while num_tokens > 0: - m_grouped_fp8_gemm_nt_contiguous( - (a1q[:num_tokens], a1q_scales[:num_tokens]), - (w, w_scale), - out[:num_tokens], - expert_ids[:num_tokens], - ) - pbar.update(1) - num_tokens = num_tokens - block_m - - _warmup(w1, w1_scale) - _warmup(w2, w2_scale) - - class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__(self, quant_config: FusedMoEQuantConfig): super().__init__(quant_config) @@ -215,11 +151,49 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ) assert M_sum % block_m == 0 - workspace1 = (M_sum, N) - workspace2 = (M_sum, max(N // 2, K)) + workspace1 = (M_sum, max(N // 2, K)) + workspace2 = (M_sum, max(N, K)) output = (M, K) return (workspace1, workspace2, output) + def _act_mul_quant( + self, input: torch.Tensor, output: torch.Tensor, activation: str + ) -> tuple[torch.Tensor, torch.Tensor]: + assert self.block_shape is not None + block_k = self.block_shape[1] + scale_fmt = DeepGemmQuantScaleFMT.from_oracle() + + # 1. DeepGemm UE8M0: use packed per-token-group quant + if scale_fmt == DeepGemmQuantScaleFMT.UE8M0: + M_sum, N = input.size() + act_out = torch.empty( + (M_sum, N // 2), dtype=input.dtype, device=input.device + ) + self.activation(activation, act_out, input) + a2q, a2q_scale = per_token_group_quant_fp8_packed_for_deepgemm( + act_out, + block_k, + out_q=output, + ) + return a2q, a2q_scale + + # 2. Hopper / non‑E8M0: prefer the fused SiLU+mul+quant kernel + if activation == "silu": + use_ue8m0 = scale_fmt == DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0 + return silu_mul_per_token_group_quant_fp8_colmajor( + input=input, + output=output, + use_ue8m0=use_ue8m0, + ) + + # 3. fallback path for non-SiLU activations in non‑UE8M0 cases. + M_sum, N = input.size() + act_out = torch.empty((M_sum, N // 2), dtype=input.dtype, device=input.device) + self.activation(activation, act_out, input) + return per_token_group_quant_fp8( + act_out, block_k, column_major_scales=True, out_q=output + ) + def apply( self, output: torch.Tensor, @@ -261,14 +235,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): expert_tokens_meta=expert_tokens_meta, ) - a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), (M_sum, K)) - mm1_out = _resize_cache(workspace13, (M_sum, N)) - act_out = _resize_cache(workspace2, (M_sum, N // 2)) - quant_out = _resize_cache( - workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2) + a1q_perm = _resize_cache( + workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, K) ) - mm2_out = _resize_cache(workspace2, (M_sum, K)) - a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute( aq=a1q, aq_scale=a1q_scale, @@ -280,17 +249,19 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ) assert a1q.size(0) == M_sum + mm1_out = _resize_cache(workspace2, (M_sum, N)) m_grouped_fp8_gemm_nt_contiguous( (a1q, a1q_scale), (w1, self.w1_scale), mm1_out, expert_ids ) - self.activation(activation, act_out, mm1_out.view(-1, N)) - - a2q_scale: torch.Tensor | None = None - a2q, a2q_scale = per_token_group_quant_fp8( - act_out, self.block_shape[1], column_major_scales=True, out_q=quant_out + quant_out = _resize_cache( + workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2) + ) + a2q, a2q_scale = self._act_mul_quant( + input=mm1_out.view(-1, N), output=quant_out, activation=activation ) + mm2_out = _resize_cache(workspace2, (M_sum, K)) m_grouped_fp8_gemm_nt_contiguous( (a2q, a2q_scale), (w2, self.w2_scale), mm2_out, expert_ids ) @@ -322,7 +293,7 @@ def deep_gemm_moe_fp8( expert_map: torch.Tensor | None = None, a1_scale: torch.Tensor | None = None, a2_scale: torch.Tensor | None = None, - apply_router_weight_on_input=False, + apply_router_weight_on_input: bool = False, ) -> torch.Tensor: """ This function computes a a8w8-quantized Mixture of Experts (MoE) layer diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py b/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py index 6cca954123274..57d303cd53fef 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py @@ -84,10 +84,16 @@ def _fwd_kernel_ep_scatter_1( m_indices_start_ptr = m_indices + cur_expert_start off_expert = tl.arange(0, BLOCK_E) + # any rows in the per-expert aligned region that do not correspond to + # real tokens are left untouched here and should remain initialized to + # -1 so DeepGEMM can skip them for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4): + offs = start_m + off_expert + mask = offs < cur_expert_token_num tl.store( - m_indices_start_ptr + start_m + off_expert, + m_indices_start_ptr + offs, cur_expert, + mask=mask, ) @@ -366,12 +372,17 @@ def deepgemm_moe_permute( (M_sum, H // block_k), device=device, dtype=torch.float32 ) - maybe_has_empty_blocks = (expert_tokens_meta is None) or ( - expert_tokens_meta.expert_num_tokens_cpu is None + # DeepGEMM uses negative values in m_indices (here expert_ids) to mark + # completely invalid / padded blocks that should be skipped. We always + # initialize expert_ids to -1 so any row that is not explicitly written + # by the scatter kernel will be treated as invalid and skipped by + # DeepGEMM's scheduler. + expert_ids = torch.full( + (M_sum,), + fill_value=-1, + device=device, + dtype=torch.int32, ) - expert_ids_init = torch.zeros if maybe_has_empty_blocks else torch.empty - - expert_ids = expert_ids_init((M_sum), device=device, dtype=torch.int32) inv_perm = torch.empty(topk_ids.shape, device=device, dtype=torch.int32) expert_num_tokens = None diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 9c377db720132..92d72b75656cd 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -316,7 +316,11 @@ def fused_marlin_moe( if global_num_experts == -1: global_num_experts = E sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - topk_ids, block_size_m, global_num_experts, expert_map + topk_ids, + block_size_m, + global_num_experts, + expert_map, + ignore_invalid_experts=True, ) assert activation is not None diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index df208eae2e71c..b286c3bc6fc07 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -885,16 +885,57 @@ def get_moe_configs( # If no optimized configuration is available, we will use the default # configuration - logger.warning( - ( - "Using default MoE config. Performance might be sub-optimal! " - "Config file not found at %s" - ), - config_file_paths, + logger.warning_once( + "Using default MoE config. Performance might be sub-optimal! " + "Config file not found at %s", + ", ".join(config_file_paths), + scope="local", ) return None +def _ensure_block_size_k_divisible( + size_k: int, block_size_k: int, group_size: int +) -> int: + """Ensure block_size_k is a divisor of size_k and divisible by group_size. + + This ensures BLOCK_SIZE_K compatibility with MoeWNA16 CUDA kernel which + requires size_k % BLOCK_SIZE_K == 0 and BLOCK_SIZE_K % group_size == 0. + + Args: + size_k: The size_k dimension that must be divisible by result. + block_size_k: Preferred block size (will be adjusted if needed). + group_size: The result must be divisible by this. + + Returns: + A valid BLOCK_SIZE_K that divides size_k and is divisible by group_size. + """ + # Fast path: already valid + if size_k % block_size_k == 0 and block_size_k % group_size == 0: + return block_size_k + + # Find the largest value that: + # 1. Divides size_k (size_k % candidate == 0) + # 2. Is divisible by group_size (candidate % group_size == 0) + # 3. Is <= block_size_k (prefer smaller values close to block_size_k) + # + # Strategy: Search from min(block_size_k, size_k) down to group_size, + # stepping by group_size to ensure divisibility by group_size + max_search = min(block_size_k, size_k) + start = (max_search // group_size) * group_size + for candidate in range(start, group_size - 1, -group_size): + if size_k % candidate == 0: + return candidate + + # Fallback: if group_size divides size_k, use it + # This should always be true with correct group_size configuration + if size_k % group_size == 0: + return group_size + + # This should not happen with correct group_size, but ensure divisibility + return size_k + + def get_moe_wna16_block_config( config: dict[str, int], use_moe_wna16_cuda: bool, @@ -960,6 +1001,9 @@ def get_moe_wna16_block_config( # at the same time. block_size_n = 1024 + # Ensure BLOCK_SIZE_K is a divisor of size_k for CUDA kernel compatibility + block_size_k = _ensure_block_size_k_divisible(size_k, block_size_k, group_size) + return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k} @@ -1887,7 +1931,11 @@ def fused_experts_impl( ) sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - curr_topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map + curr_topk_ids, + config["BLOCK_SIZE_M"], + global_num_experts, + expert_map, + ignore_invalid_experts=True, ) invoke_fused_moe_kernel( @@ -1946,6 +1994,9 @@ def fused_experts_impl( block_shape=block_shape, ) + if expert_map is not None: + intermediate_cache3.zero_() + invoke_fused_moe_kernel( qintermediate_cache2, w2, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py index ef7090c349fc6..8c9d8a2777d58 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import abstractmethod -from collections.abc import Callable import torch @@ -100,22 +99,5 @@ class FusedMoEMethodBase(QuantizeMethodBase): layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 x: torch.Tensor, router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py index c23c41df226f0..9c9bc2514bb4b 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable import torch @@ -51,6 +50,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): old_quant_method.select_gemm_impl(prepare_finalize, moe_layer), shared_experts, getattr(moe_layer, "shared_experts_stream", None), + moe_parallel_config=moe_layer.moe_parallel_config, ), ) @@ -91,23 +91,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 x: torch.Tensor, router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: topk_weights, topk_ids, zero_expert_result = layer.select_experts( hidden_states=x, @@ -121,10 +104,10 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): topk_weights=topk_weights, topk_ids=topk_ids, inplace=self.allow_inplace, - activation=activation, - global_num_experts=global_num_experts, - apply_router_weight_on_input=apply_router_weight_on_input, - expert_map=None if self.disable_expert_map else expert_map, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + expert_map=None if self.disable_expert_map else layer.expert_map, ) if layer.zero_expert_num != 0 and layer.zero_expert_type is not None: diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 902a77987d61a..cc3afade709d9 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -33,10 +33,6 @@ from vllm.model_executor.layers.fused_moe.config import ( RoutingMethodType, ) from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton -from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEPermuteExpertsUnpermute, - FusedMoEPrepareAndFinalize, -) from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( init_aiter_topK_meta_data, ) @@ -57,11 +53,8 @@ from vllm.utils.torch_utils import ( from vllm.v1.worker.ubatching import dbo_current_ubatch_id if current_platform.is_cuda_alike(): - from .fused_moe import eplb_map_to_physical_and_record, fused_experts + from .fused_moe import eplb_map_to_physical_and_record else: - fused_experts = None # type: ignore - FusedMoEPermuteExpertsUnpermute = object # type: ignore - FusedMoEPrepareAndFinalize = object # type: ignore def _eplb_map_to_physical_and_record( topk_ids: torch.Tensor, @@ -376,7 +369,9 @@ class FusedMoE(CustomOp): # aux_stream() returns None on non-cuda-alike platforms. self.shared_experts_stream = aux_stream() if self.shared_experts_stream is not None: - logger.info_once("Enabled separate cuda stream for MoE shared_experts") + logger.info_once( + "Enabled separate cuda stream for MoE shared_experts", scope="local" + ) if params_dtype is None: params_dtype = torch.get_default_dtype() @@ -483,7 +478,7 @@ class FusedMoE(CustomOp): enable_eplb=self.enable_eplb, ) - self.expert_map: torch.Tensor | None + self._expert_map: torch.Tensor | None local_num_experts, expert_map, expert_mask = determine_expert_map( ep_size=self.ep_size, ep_rank=self.ep_rank, @@ -493,7 +488,7 @@ class FusedMoE(CustomOp): return_expert_mask=self.rocm_aiter_fmoe_enabled, ) self.local_num_experts = local_num_experts - self.register_buffer("expert_map", expert_map) + self.register_buffer("_expert_map", expert_map) self.register_buffer("expert_mask", expert_mask) self._maybe_init_expert_routing_tables() logger.info_once( @@ -506,10 +501,10 @@ class FusedMoE(CustomOp): self.expert_placement_strategy, self.local_num_experts, self.global_num_experts, - get_compressed_expert_map(self.expert_map), + get_compressed_expert_map(self._expert_map), ) else: - self.local_num_experts, self.expert_map, self.expert_mask = ( + self.local_num_experts, self._expert_map, self.expert_mask = ( self.global_num_experts, None, None, @@ -520,6 +515,10 @@ class FusedMoE(CustomOp): self._init_aiter_shared_experts_topK_buffer( vllm_config=vllm_config, dp_size=dp_size_ ) + if self.use_ep and self.rocm_aiter_fmoe_enabled: + assert self.expert_mask is None or torch.all( + (expert_mask == 0) | (expert_mask == 1) + ), "Aiter Fused MoE kernel only supports expert_map with 0 and 1s." assert intermediate_size % self.tp_size == 0 self.hidden_size = hidden_size @@ -749,7 +748,7 @@ class FusedMoE(CustomOp): self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels) - ) + ) and envs.VLLM_ENABLE_MOE_DP_CHUNK @property def is_internal_router(self) -> bool: @@ -777,7 +776,7 @@ class FusedMoE(CustomOp): ), ) - if self.expert_map is None: + if self._expert_map is None: return None routing_tables = self.ensure_round_robin_expert_routing_tables( @@ -785,7 +784,7 @@ class FusedMoE(CustomOp): ep_size=self.ep_size, ep_rank=self.ep_rank, local_num_experts=self.local_num_experts, - device=self.expert_map.device, + device=self._expert_map.device, ) global_to_physical, physical_to_global, local_global = routing_tables @@ -836,8 +835,8 @@ class FusedMoE(CustomOp): def update_expert_map(self): # ep_size and ep_rank should already be updated - assert self.expert_map is not None - with self.expert_map.device: + assert self._expert_map is not None + with self._expert_map.device: local_num_experts, expert_map, expert_mask = determine_expert_map( ep_size=self.ep_size, ep_rank=self.ep_rank, @@ -847,7 +846,7 @@ class FusedMoE(CustomOp): return_expert_mask=self.rocm_aiter_fmoe_enabled, ) self.local_num_experts = local_num_experts - self.register_buffer("expert_map", expert_map) + self.register_buffer("_expert_map", expert_map) self.register_buffer("expert_mask", expert_mask) self._maybe_init_expert_routing_tables() if self.aiter_fmoe_shared_expert_enabled: @@ -863,7 +862,8 @@ class FusedMoE(CustomOp): use_chunked_impl: bool, ) -> tuple[bool, torch.Tensor | None]: use_shared_experts_stream = ( - has_separate_shared_experts + current_platform.is_cuda() + and has_separate_shared_experts and not use_chunked_impl and self.shared_experts_stream is not None and ( @@ -883,7 +883,7 @@ class FusedMoE(CustomOp): # Record that the clone will be used by shared_experts_stream # to avoid gc issue from deallocation of hidden_states_clone # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501 - # NOTE: We dont need shared_output.record_stream(current_stream()) + # NOTE: We don't need shared_output.record_stream(current_stream()) # because we synch the streams before using shared_output. hidden_states_clone.record_stream(self.shared_experts_stream) @@ -1063,9 +1063,9 @@ class FusedMoE(CustomOp): expert_data.copy_(loaded_weight) def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: - if self.expert_map is None: + if self._expert_map is None: return expert_id - return self.expert_map[expert_id].item() + return self._expert_map[expert_id].item() def _init_aiter_shared_experts_topK_buffer( self, vllm_config: VllmConfig, dp_size: int @@ -1202,10 +1202,14 @@ class FusedMoE(CustomOp): if full_load: shard_dim += 1 - # Materialize GGUF UninitializedParameter + # Materialize GGUF UninitializedParameter accounting merged weights if is_gguf_weight and isinstance(param, UninitializedParameter): + # To materialize a tensor, we must have full shape including + # number of experts, making this portion to require `full_load`. + assert full_load final_shape = list(loaded_weight.shape) - if shard_id in ["w1", "w3"]: + # w1 and w3 are merged per expert. + if shard_id in {"w1", "w3"}: final_shape[1] *= 2 final_shape[shard_dim] = final_shape[shard_dim] // self.tp_size param.materialize(final_shape, dtype=loaded_weight.dtype) @@ -1558,6 +1562,14 @@ class FusedMoE(CustomOp): f"EPLB is not supported for {self.quant_method.method_name}." ) + def valid_grouping() -> bool: + # Check if num_experts is greater than num_expert_group + # and is divisible by num_expert_group + num_experts = router_logits.shape[-1] + if num_experts <= self.num_expert_group: + return False + return num_experts % self.num_expert_group == 0 + indices_type = self.quant_method.topk_indices_dtype # Check if we should use a routing simulation strategy @@ -1572,7 +1584,7 @@ class FusedMoE(CustomOp): ) # DeepSeekv2 uses grouped_top_k - elif self.use_grouped_topk: + elif self.use_grouped_topk and valid_grouping(): assert self.topk_group is not None assert self.num_expert_group is not None if rocm_aiter_ops.is_fused_moe_enabled(): @@ -1739,6 +1751,12 @@ class FusedMoE(CustomOp): reduce_output(fused_output)[..., :og_hidden_states], ) + @property + def expert_map(self) -> torch.Tensor | None: + return ( + self._expert_map if not self.rocm_aiter_fmoe_enabled else self.expert_mask + ) + def forward_cuda( self, hidden_states: torch.Tensor, @@ -1800,24 +1818,6 @@ class FusedMoE(CustomOp): layer=self, x=staged_hidden_states, router_logits=staged_router_logits, - top_k=self.top_k, - renormalize=self.renormalize, - use_grouped_topk=self.use_grouped_topk, - global_num_experts=self.global_num_experts, - expert_map=self.expert_map - if not self.rocm_aiter_fmoe_enabled - else self.expert_mask, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group, - custom_routing_function=self.custom_routing_function, - scoring_func=self.scoring_func, - routed_scaling_factor=self.routed_scaling_factor, - e_score_correction_bias=self.e_score_correction_bias, - activation=self.activation, - enable_eplb=self.enable_eplb, - expert_load_view=self.expert_load_view, - logical_to_physical_map=self.logical_to_physical_map, - logical_replica_count=self.logical_replica_count, ) if has_separate_shared_experts: @@ -1963,25 +1963,6 @@ class FusedMoE(CustomOp): if do_naive_dispatch_combine else hidden_states, router_logits=router_logits, - top_k=self.top_k, - renormalize=self.renormalize, - use_grouped_topk=self.use_grouped_topk, - global_num_experts=self.global_num_experts, - expert_map=self.expert_map - if not self.rocm_aiter_fmoe_enabled - else self.expert_mask, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group, - custom_routing_function=self.custom_routing_function, - scoring_func=self.scoring_func, - routed_scaling_factor=self.routed_scaling_factor, - e_score_correction_bias=self.e_score_correction_bias, - activation=self.activation, - apply_router_weight_on_input=self.apply_router_weight_on_input, - enable_eplb=self.enable_eplb, - expert_load_view=self.expert_load_view, - logical_to_physical_map=self.logical_to_physical_map, - logical_replica_count=self.logical_replica_count, ) if has_separate_shared_experts: diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index b2af58cdca887..484314091cb15 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -10,10 +10,12 @@ from typing import final import torch import vllm.envs as envs -from vllm.config import get_current_vllm_config from vllm.forward_context import get_forward_context, is_forward_context_available from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEParallelConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, count_expert_num_tokens, @@ -22,12 +24,12 @@ from vllm.model_executor.layers.fused_moe.utils import ( from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv from vllm.v1.worker.ubatching import ( - dbo_current_ubatch_id, dbo_enabled, dbo_maybe_run_recv_hook, dbo_register_recv_hook, dbo_yield, ) +from vllm.v1.worker.workspace import current_workspace_manager logger = init_logger(__name__) @@ -367,7 +369,7 @@ class FusedMoEPrepareAndFinalize(ABC): class FusedMoEPermuteExpertsUnpermute(ABC): """ An abstract base class for the [Permute-Experts-Unpermute] step described - above. + above. """ def __init__( @@ -661,25 +663,6 @@ def _slice_scales( return None -class SharedResizableBuffer: - def __init__(self): - self.buffer = None - - def get( - self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype - ) -> torch.Tensor: - assert shape != () - shape_numel = prod(shape) - if ( - self.buffer is None - or self.buffer.numel() < shape_numel - or self.buffer.device != device - or self.buffer.dtype != dtype - ): - self.buffer = torch.empty(shape_numel, device=device, dtype=dtype) - return self.buffer[:shape_numel].view(*shape) - - @final class FusedMoEModularKernel(torch.nn.Module): """ @@ -694,28 +677,13 @@ class FusedMoEModularKernel(torch.nn.Module): objects. """ - class SharedBuffers: - def __init__(self) -> None: - self.fused_out = SharedResizableBuffer() - self.workspace13 = SharedResizableBuffer() - self.workspace2 = SharedResizableBuffer() - - # Persistent buffers that are shared across `FusedMoEModularKernel` - # instances (layers), to save memory and allocattions. - # - # We have two sets of buffers to support dual batch overlap (DBO) where each - # microbatch (ubatch) should use its own set of buffers to avoid - # cross-ubatch contimination. - # NOTE that memory is lazily allocated for these buffers, meaning that if - # DBO isn't being used, the second SharedBuffers will be empty. - shared_buffers: list[SharedBuffers] = [SharedBuffers(), SharedBuffers()] - def __init__( self, prepare_finalize: FusedMoEPrepareAndFinalize, fused_experts: FusedMoEPermuteExpertsUnpermute, shared_experts: torch.nn.Module | None = None, shared_experts_stream: torch.cuda.Stream | None = None, + moe_parallel_config: FusedMoEParallelConfig | None = None, ): super().__init__() self.prepare_finalize = prepare_finalize @@ -723,6 +691,17 @@ class FusedMoEModularKernel(torch.nn.Module): self.shared_experts = shared_experts self.shared_experts_stream = shared_experts_stream + # prefer an explicit FusedMoEParallelConfig when available (from + # FusedMoE layers / tests). + # if not provided, assume this kernel is + # running in a non-DP+EP context + self.moe_parallel_config: FusedMoEParallelConfig | None = moe_parallel_config + self.is_dp_ep = ( + moe_parallel_config is not None + and moe_parallel_config.dp_size > 1 + and moe_parallel_config.use_ep + ) + self._post_init_setup() assert ( prepare_finalize.activation_format == fused_experts.activation_formats[0] @@ -797,10 +776,6 @@ class FusedMoEModularKernel(torch.nn.Module): assert M_full > 0 and M_chunk > 0 num_chunks, _ = self._chunk_info(M_full) - - # select per-ubatch buffers to avoid cross-ubatch reuse under DBO - ubatch_idx = dbo_current_ubatch_id() - buffers = self.shared_buffers[ubatch_idx] workspace_dtype = self.fused_experts.workspace_dtype(out_dtype) # Force worst-case allocation in profiling run for @@ -811,33 +786,24 @@ class FusedMoEModularKernel(torch.nn.Module): is_forward_context_available() and get_forward_context().attn_metadata is None ) - if is_profile_run and self.fused_experts.supports_chunking(): - parallel_config = get_current_vllm_config().parallel_config - is_dp_ep = ( - parallel_config.data_parallel_size > 1 - and parallel_config.enable_expert_parallel + if is_profile_run and self.fused_experts.supports_chunking() and self.is_dp_ep: + max_workspace_13, max_workspace_2, max_fused_out_shape = ( + self.fused_experts.workspace_shapes( + envs.VLLM_FUSED_MOE_CHUNK_SIZE, + N, + K, + top_k, + global_num_experts, + local_num_experts, + expert_tokens_meta, + ) + ) + + current_workspace_manager().get_simultaneous( + (max_workspace_13, workspace_dtype), + (max_workspace_2, workspace_dtype), + (max_fused_out_shape, out_dtype), ) - if is_dp_ep: - max_workspace_13, max_workspace_2, max_fused_out_shape = ( - self.fused_experts.workspace_shapes( - envs.VLLM_FUSED_MOE_CHUNK_SIZE, - N, - K, - top_k, - global_num_experts, - local_num_experts, - expert_tokens_meta, - ) - ) - buffers.workspace13.get( - max_workspace_13, device=device, dtype=workspace_dtype - ) - buffers.workspace2.get( - max_workspace_2, device=device, dtype=workspace_dtype - ) - buffers.fused_out.get( - max_fused_out_shape, device=device, dtype=workspace_dtype - ) # Get intermediate workspace shapes based off the chunked M size. workspace13_shape, workspace2_shape, _ = self.fused_experts.workspace_shapes( @@ -863,22 +829,23 @@ class FusedMoEModularKernel(torch.nn.Module): # We can reuse the memory between cache1 and cache3 because by the # time we need cache3, we're done with cache1. - workspace13 = buffers.workspace13.get( - workspace13_shape, device=device, dtype=workspace_dtype - ) - workspace2 = buffers.workspace2.get( - workspace2_shape, device=device, dtype=workspace_dtype - ) - # Construct the entire output that can then be processed in chunks. # Reuse workspace13 for the output in the non-chunked case as long # as it is large enough. This will not always be the case for standard # format experts and with experts that have empty workspaces. if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape): + workspace13, workspace2 = current_workspace_manager().get_simultaneous( + (workspace13_shape, workspace_dtype), + (workspace2_shape, workspace_dtype), + ) fused_out = _resize_cache(workspace13, fused_out_shape) else: - fused_out = buffers.fused_out.get( - fused_out_shape, device=device, dtype=out_dtype + workspace13, workspace2, fused_out = ( + current_workspace_manager().get_simultaneous( + (workspace13_shape, workspace_dtype), + (workspace2_shape, workspace_dtype), + (fused_out_shape, out_dtype), + ) ) return workspace13, workspace2, fused_out diff --git a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py index 7f6155997264d..7fc8bfcf824d9 100644 --- a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py +++ b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py @@ -14,6 +14,7 @@ def moe_align_block_size( num_experts: int, expert_map: torch.Tensor | None = None, pad_sorted_ids: bool = False, + ignore_invalid_experts: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Aligns the token distribution across experts to be compatible with block @@ -35,7 +36,13 @@ def moe_align_block_size( expert parallel shard. If the expert is not in the current expert parallel shard, the mapping is set to -1. - pad_sorted_ids: A flag indicating whether the sorted_token_ids length - should be padded to a multiple of block_size, + should be padded to a multiple of block_size, + - ignore_invalid_experts: A flag indicating whether to ignore invalid + experts. When False, all expert_ids in topk_ids will participate in + counting and ranking, but invalid experts in expert_ids will be marked + as -1. When True, all invalid expert_ids in topk_ids will be ignored + and will not participate in counting or ranking, and there will be no + -1 in expert_ids. Returns: - sorted_token_ids: A tensor containing the sorted token indices according @@ -67,6 +74,10 @@ def moe_align_block_size( max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) if pad_sorted_ids: max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + if topk_ids.numel() < num_experts: + max_num_tokens_padded = min( + topk_ids.numel() * block_size, max_num_tokens_padded + ) sorted_ids = torch.empty( (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device ) @@ -77,9 +88,16 @@ def moe_align_block_size( num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) ops.moe_align_block_size( - topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + expert_map if ignore_invalid_experts else None, ) - if expert_map is not None: + + if expert_map is not None and not ignore_invalid_experts: expert_ids = expert_map[expert_ids] return sorted_ids, expert_ids, num_tokens_post_pad diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 8f05828d74f5f..882ad0a537cd5 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -221,8 +221,8 @@ def rocm_aiter_fused_experts( else: quant_method = QuantMethod.NO.value - # quark moe for mxfp4 w_dtype - if quant_config.use_mxfp4_w4a16: + # quark moe for mxfp4 w_dtype mxfp4 a_dtype + if quant_config.use_mxfp4_w4a4: quant_method = QuantMethod.BLOCK_1X32.value # w8a8 block-scaled if quant_config.block_shape is not None and quant_config.use_fp8_w8a8: diff --git a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py index 9aaeec4f98a61..a143347b19f2c 100644 --- a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py @@ -29,14 +29,14 @@ class SharedFusedMoE(FusedMoE): self._shared_experts = shared_experts # Disable shared expert overlap if: - # - we are using eplb, because of correctness issues + # - we are using eplb with non-default backend, because of correctness issues # - we are using flashinfer with DP, since there nothint to gain - # - we are using marlin kjernels + # - we are using marlin kernels + backend = self.moe_parallel_config.all2all_backend self.use_overlapped = ( use_overlapped and not ( - # TODO(wentao): find the root cause and remove this condition - self.enable_eplb + (self.enable_eplb and backend != "allgather_reducescatter") or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1) ) and self._shared_experts is not None diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index 48e5a8907f926..6182f10aa70f0 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable import torch import torch.nn.functional as F @@ -269,53 +268,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def apply( self, - layer: torch.nn.Module, + layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 x: torch.Tensor, router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - if enable_eplb: - assert expert_load_view is not None - assert logical_to_physical_map is not None - assert logical_replica_count is not None - return self.forward( - x=x, layer=layer, + x=x, router_logits=router_logits, - top_k=top_k, - renormalize=renormalize, - use_grouped_topk=use_grouped_topk, - topk_group=topk_group, - num_expert_group=num_expert_group, - global_num_experts=global_num_experts, - expert_map=expert_map, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - routed_scaling_factor=routed_scaling_factor, - e_score_correction_bias=e_score_correction_bias, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - enable_eplb=enable_eplb, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, ) def get_fused_moe_quant_config( @@ -333,24 +293,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): self, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, router_logits: torch.Tensor, - renormalize: bool, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: topk_weights, topk_ids, zero_expert_result = layer.select_experts( hidden_states=x, @@ -364,9 +307,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, - expert_map=expert_map, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=layer.expert_map, + activation=layer.activation, + apply_router_weight_on_input=layer.apply_router_weight_on_input, ) elif self.flashinfer_cutlass_moe_enabled: return self.flashinfer_cutlass_moe( @@ -375,8 +318,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, + activation=layer.activation, + apply_router_weight_on_input=layer.apply_router_weight_on_input, ) else: result = fused_experts( @@ -386,11 +329,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, - activation=activation, + activation=layer.activation, quant_config=self.moe_quant_config, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, ) if layer.zero_expert_num != 0 and layer.zero_expert_type is not None: @@ -405,148 +348,101 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): self, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, router_logits: torch.Tensor, - renormalize: bool, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if ( - enable_eplb is not False - or expert_load_view is not None - or logical_to_physical_map is not None - or logical_replica_count is not None + layer.enable_eplb is not False + or layer.expert_load_view is not None + or layer.logical_to_physical_map is not None + or layer.logical_replica_count is not None ): raise NotImplementedError("Expert load balancing is not supported for CPU.") + return layer.cpu_fused_moe( layer, x, - use_grouped_topk, - top_k, + layer.use_grouped_topk, + layer.top_k, router_logits, - renormalize, - topk_group, - num_expert_group, - global_num_experts, - expert_map, - custom_routing_function, - scoring_func, - routed_scaling_factor, - e_score_correction_bias, - apply_router_weight_on_input, - activation, + layer.renormalize, + layer.topk_group, + layer.num_expert_group, + layer.global_num_experts, + layer.expert_map, + layer.custom_routing_function, + layer.scoring_func, + layer.routed_scaling_factor, + layer.e_score_correction_bias, + layer.apply_router_weight_on_input, + layer.activation, ) def forward_xpu( self, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, router_logits: torch.Tensor, - renormalize: bool, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if ( - enable_eplb is not False - or expert_load_view is not None - or logical_to_physical_map is not None - or logical_replica_count is not None + layer.enable_eplb is not False + or layer.expert_load_view is not None + or layer.logical_to_physical_map is not None + or layer.logical_replica_count is not None ): raise NotImplementedError("Expert load balancing is not supported for XPU.") return layer.ipex_fusion( x, - use_grouped_topk, - top_k, + layer.use_grouped_topk, + layer.top_k, router_logits, - renormalize, - topk_group, - num_expert_group, - custom_routing_function=custom_routing_function, + layer.renormalize, + layer.topk_group, + layer.num_expert_group, + custom_routing_function=layer.custom_routing_function, ) def forward_tpu( self, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, router_logits: torch.Tensor, - renormalize: bool, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert not use_grouped_topk - assert num_expert_group is None - assert topk_group is None - assert custom_routing_function is None - assert apply_router_weight_on_input is False - if scoring_func != "softmax": + assert not layer.use_grouped_topk + assert layer.num_expert_group is None + assert layer.topk_group is None + assert layer.custom_routing_function is None + assert layer.apply_router_weight_on_input is False + if layer.scoring_func != "softmax": raise NotImplementedError( "Only softmax scoring function is supported for TPU." ) - if e_score_correction_bias is not None: + if layer.e_score_correction_bias is not None: raise NotImplementedError( "Expert score correction bias is not supported for TPU." ) - assert activation == "silu", f"{activation} is not supported for TPU." - assert routed_scaling_factor == 1.0, ( - f"routed_scaling_factor {routed_scaling_factor} is not supported for TPU." + assert layer.activation == "silu", ( + f"{layer.activation} is not supported for TPU." + ) + assert layer.routed_scaling_factor == 1.0, ( + f"routed_scaling_factor {layer.routed_scaling_factor} is " + "not supported for TPU." ) if ( - enable_eplb is not False - or expert_load_view is not None - or logical_to_physical_map is not None - or logical_replica_count is not None + layer.enable_eplb is not False + or layer.expert_load_view is not None + or layer.logical_to_physical_map is not None + or layer.logical_replica_count is not None ): raise NotImplementedError("Expert load balancing is not supported for TPU.") return fused_moe_pallas( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, - topk=top_k, + topk=layer.top_k, gating_output=router_logits, - global_num_experts=global_num_experts, - expert_map=expert_map, - renormalize=renormalize, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + renormalize=layer.renormalize, ) if current_platform.is_tpu(): diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 90e520e244416..0b63acf2dc5a5 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -252,7 +252,6 @@ class MambaMixer(MambaBase, CustomOp): conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] has_initial_states_p = attn_metadata.has_initial_states_p - num_padded_decodes = attn_metadata.num_padded_decodes # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) @@ -281,7 +280,7 @@ class MambaMixer(MambaBase, CustomOp): state_indices_tensor, num_prefill_tokens, num_prefills, - num_padded_decodes, + num_decode_tokens, ) hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d @@ -470,24 +469,24 @@ def split_batch_to_prefill_and_decode( state_indices_tensor: torch.Tensor, num_prefill_tokens: int, num_prefills: int, - num_padded_decodes: int, + num_decode_tokens: int, ) -> PrefillDecodeSplit: - num_actual_tokens = num_prefill_tokens + num_padded_decodes + num_actual_tokens = num_prefill_tokens + num_decode_tokens # In v1, decode tokens come first, then prefill tokens. hidden_states_BC_d, hidden_states_BC_p = torch.split( hidden_states_BC[..., :num_actual_tokens], - [num_padded_decodes, num_prefill_tokens], + [num_decode_tokens, num_prefill_tokens], dim=-1, ) gate_d, gate_p = torch.split( - gate[..., :num_actual_tokens], [num_padded_decodes, num_prefill_tokens], dim=-1 + gate[..., :num_actual_tokens], [num_decode_tokens, num_prefill_tokens], dim=-1 ) - # num_padded_decodes accounts for CUDA graph padding when applicable + # num_decode_tokens accounts for CUDA graph padding when applicable state_indices_tensor_d, state_indices_tensor_p = torch.split( - state_indices_tensor[: num_padded_decodes + num_prefills], - [num_padded_decodes, num_prefills], + state_indices_tensor[: num_decode_tokens + num_prefills], + [num_decode_tokens, num_prefills], dim=0, ) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 53fd5d5458b09..800f8bd840792 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -36,10 +36,14 @@ else: is not None } ) +@triton.heuristics( + {"IS_SPEC_DECODING": lambda args: args["num_accepted_tokens_ptr"] is not None} +) +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens_ptr"] is not None}) @triton.heuristics( {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])} ) -@triton.jit +@triton.jit(do_not_specialize=["N"]) def _selective_scan_update_kernel( # Pointers to matrices state_ptr, @@ -55,8 +59,10 @@ def _selective_scan_update_kernel( state_batch_indices_ptr, dst_state_batch_indices_ptr, pad_slot_id, + num_accepted_tokens_ptr, + cu_seqlens_ptr, # Matrix dimensions - batch, + N, nheads, dim, dstate, @@ -91,6 +97,10 @@ def _selective_scan_update_kernel( stride_out_batch, stride_out_head, stride_out_dim, + stride_state_indices_batch, + stride_state_indices_T, + stride_dst_state_indices_batch, + stride_dst_state_indices_T, # Meta-parameters DT_SOFTPLUS: tl.constexpr, TIE_HDIM: tl.constexpr, @@ -99,22 +109,50 @@ def _selective_scan_update_kernel( HAS_D: tl.constexpr, HAS_Z: tl.constexpr, HAS_STATE_BATCH_INDICES: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + IS_VARLEN: tl.constexpr, BLOCK_SIZE_DSTATE: tl.constexpr, ): pid_m = tl.program_id(axis=0) pid_b = tl.program_id(axis=1) pid_h = tl.program_id(axis=2) + if IS_VARLEN: + bos = tl.load(cu_seqlens_ptr + pid_b).to(tl.int64) + eos = tl.load(cu_seqlens_ptr + pid_b + 1).to(tl.int64) + seq_len = eos - bos + + if seq_len == 0: + return + else: + bos = pid_b + seq_len = 1 + + state_ptr_base = state_ptr + # If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate # is taken from the state_batch_indices_ptr Otherwise, the state coordinate # is the same as the batch id. if HAS_STATE_BATCH_INDICES: - dst_state_batch_indices_ptr += pid_b - dst_state_batch_idx = tl.load(dst_state_batch_indices_ptr).to(tl.int64) - dst_state_ptr = state_ptr + ( - dst_state_batch_idx * stride_state_batch + pid_h * stride_state_head + if IS_SPEC_DECODING: + num_accepted = tl.load(num_accepted_tokens_ptr + pid_b).to(tl.int64) + init_token_idx = tl.maximum(num_accepted - 1, 0) + else: + init_token_idx = 0 + + dst_state_batch_indices_ptr += pid_b * stride_dst_state_indices_batch + if not IS_SPEC_DECODING: + dst_state_batch_idx = tl.load( + dst_state_batch_indices_ptr + + init_token_idx * stride_dst_state_indices_T + ).to(tl.int64) + dst_state_ptr = state_ptr + ( + dst_state_batch_idx * stride_state_batch + pid_h * stride_state_head + ) + + state_batch_indices_ptr += ( + pid_b * stride_state_indices_batch + init_token_idx * stride_state_indices_T ) - state_batch_indices_ptr += pid_b state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64) state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head else: @@ -123,86 +161,112 @@ def _selective_scan_update_kernel( ) state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head - x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head - dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head + x_ptr += bos * stride_x_batch + pid_h * stride_x_head + dt_ptr += bos * stride_dt_batch + pid_h * stride_dt_head if HAS_DT_BIAS: dt_bias_ptr += pid_h * stride_dt_bias_head A_ptr += pid_h * stride_A_head - B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group - C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group + B_ptr += bos * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group + C_ptr += bos * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group if HAS_Z: - z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head - out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + z_ptr += bos * stride_z_batch + pid_h * stride_z_head + out_ptr += bos * stride_out_batch + pid_h * stride_out_head offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) state_ptrs = state_ptr + ( offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate ) - dst_state_ptrs = dst_state_ptr + ( - offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate - ) - x_ptrs = x_ptr + offs_m * stride_x_dim - dt_ptrs = dt_ptr + offs_m * stride_dt_dim + if not IS_SPEC_DECODING: + dst_state_ptrs = dst_state_ptr + ( + offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate + ) + + mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) + if HAS_STATE_BATCH_INDICES: + mask &= state_batch_idx != pad_slot_id + state = tl.load(state_ptrs, mask=mask, other=0.0).to(tl.float32) + if HAS_DT_BIAS: dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim if HAS_D: D_ptr += pid_h * stride_D_head - A_ptrs = A_ptr + ( - offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate - ) - B_ptrs = B_ptr + offs_n * stride_B_dstate - C_ptrs = C_ptr + offs_n * stride_C_dstate - if HAS_D: D_ptrs = D_ptr + offs_m * stride_D_dim - if HAS_Z: - z_ptrs = z_ptr + offs_m * stride_z_dim - out_ptrs = out_ptr + offs_m * stride_out_dim - mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) - if HAS_STATE_BATCH_INDICES: - mask &= state_batch_idx != pad_slot_id - state = tl.load(state_ptrs, mask=mask, other=0.0) + A_ptrs = A_ptr + offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate - x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - if not TIE_HDIM: - dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - if HAS_DT_BIAS: - dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - if DT_SOFTPLUS: - dt = softplus(dt) - A = tl.load( - A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0 - ).to(tl.float32) - dA = tl.exp(A * dt[:, None]) - else: - dt = tl.load(dt_ptr).to(tl.float32) - if HAS_DT_BIAS: - dt += tl.load(dt_bias_ptr).to(tl.float32) - if DT_SOFTPLUS: - dt = softplus(dt) - A = tl.load(A_ptr).to(tl.float32) - dA = tl.exp(A * dt) # scalar, not a matrix + for i_t in range(seq_len): + x_ptrs = x_ptr + offs_m * stride_x_dim + dt_ptrs = dt_ptr + offs_m * stride_dt_dim + B_ptrs = B_ptr + offs_n * stride_B_dstate + C_ptrs = C_ptr + offs_n * stride_C_dstate + if HAS_Z: + z_ptrs = z_ptr + offs_m * stride_z_dim + out_ptrs = out_ptr + offs_m * stride_out_dim - B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) - C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) - if HAS_D: - D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - if HAS_Z: - z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if not TIE_HDIM: + dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if DT_SOFTPLUS: + dt = softplus(dt) + A = tl.load( + A_ptrs, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), + other=0.0, + ).to(tl.float32) + dA = tl.exp(A * dt[:, None]) + else: + dt = tl.load(dt_ptr).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptr).to(tl.float32) + if DT_SOFTPLUS: + dt = softplus(dt) + A = tl.load(A_ptr).to(tl.float32) + dA = tl.exp(A * dt) # scalar, not a matrix - dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt - state = state * dA + dB * x[:, None] + B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + if HAS_D: + D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_Z: + z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) - if HAS_STATE_BATCH_INDICES: - mask &= state_batch_idx != pad_slot_id - tl.store(dst_state_ptrs, state, mask=mask) - out = tl.sum(state * C[None, :], axis=1) - if HAS_D: - out += x * D - if HAS_Z: - out *= z * tl.sigmoid(z) - tl.store(out_ptrs, out, mask=offs_m < dim) + dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt + state = state * dA + dB * x[:, None] + + if IS_SPEC_DECODING: + dst_idx_ptr = dst_state_batch_indices_ptr + i_t * stride_dst_state_indices_T + token_dst_idx = tl.load(dst_idx_ptr).to(tl.int64) + if token_dst_idx != pad_slot_id: + token_dst_ptrs = ( + state_ptr_base + + token_dst_idx * stride_state_batch + + pid_h * stride_state_head + + offs_m[:, None] * stride_state_dim + + offs_n[None, :] * stride_state_dstate + ) + tl.store( + token_dst_ptrs, state.to(token_dst_ptrs.dtype.element_ty), mask=mask + ) + + out = tl.sum(state * C[None, :], axis=1) + if HAS_D: + out += x * D + if HAS_Z: + out *= z * tl.sigmoid(z) + tl.store(out_ptrs, out, mask=offs_m < dim) + + x_ptr += stride_x_batch + dt_ptr += stride_dt_batch + B_ptr += stride_B_batch + C_ptr += stride_C_batch + out_ptr += stride_out_batch + if HAS_Z: + z_ptr += stride_z_batch + + if not IS_SPEC_DECODING: + tl.store(dst_state_ptrs, state.to(dst_state_ptrs.dtype.element_ty), mask=mask) def selective_state_update( @@ -220,6 +284,8 @@ def selective_state_update( dst_state_batch_indices=None, pad_slot_id=PAD_SLOT_ID, out=None, + num_accepted_tokens=None, + cu_seqlens=None, ): """ Argument: @@ -240,6 +306,11 @@ def selective_state_update( indices 0 and 3 out: Preallocated ssm output tensor. Assume same shape as x. In-place updated. + num_accepted_tokens: (batch,) + number of accepted tokens from previous verification step, + tells the kernel which initial state to use + cu_seqlens: (batch,) + length per sequence, for variable length in speculative decoding cases """ if state.dim() == 3: state = state.unsqueeze(1) @@ -261,9 +332,26 @@ def selective_state_update( dt_bias = dt_bias.unsqueeze(0) if out.dim() == 2: out = out.unsqueeze(1) + if num_accepted_tokens is not None: + assert state_batch_indices is not None and state_batch_indices.dim() == 2 + assert dst_state_batch_indices is None or dst_state_batch_indices.dim() == 2 + if state_batch_indices is not None and state_batch_indices.dim() == 1: + state_batch_indices = state_batch_indices.unsqueeze(1) + if dst_state_batch_indices is not None and dst_state_batch_indices.dim() == 1: + dst_state_batch_indices = dst_state_batch_indices.unsqueeze(1) _, nheads, dim, dstate = state.shape batch = x.shape[0] + if cu_seqlens is not None: + N = len(cu_seqlens) - 1 + # Only used to verify the shape of + # state_batch_indices and dst_state_batch_indices + max_seqlen = ( + state_batch_indices.size(-1) if state_batch_indices is not None else 1 + ) + else: + N = batch + max_seqlen = 1 assert x.shape == (batch, nheads, dim) assert dt.shape == x.shape @@ -279,16 +367,30 @@ def selective_state_update( if dt_bias is not None: assert dt_bias.shape == (nheads, dim) if state_batch_indices is not None: - assert state_batch_indices.shape == (batch,) + assert state_batch_indices.shape[0] >= N + assert state_batch_indices.shape[1] >= max_seqlen if dst_state_batch_indices is not None: - assert dst_state_batch_indices.shape == (batch,) + assert dst_state_batch_indices.shape[0] >= N + assert dst_state_batch_indices.shape[1] >= max_seqlen else: # revert to the default behavior of in-place state updates dst_state_batch_indices = state_batch_indices assert out.shape == x.shape + if num_accepted_tokens is not None: + assert num_accepted_tokens.shape == (N,) - grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads) + grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), N, nheads) z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0) + state_batch_indices_strides = ( + (state_batch_indices.stride(0), state_batch_indices.stride(1)) + if state_batch_indices is not None + else (0, 0) + ) + dst_state_batch_indices_strides = ( + (dst_state_batch_indices.stride(0), dst_state_batch_indices.stride(1)) + if dst_state_batch_indices is not None + else (0, 0) + ) # We don't want autotune since it will overwrite the state # We instead tune by hand. BLOCK_SIZE_M, num_warps = ( @@ -321,7 +423,9 @@ def selective_state_update( state_batch_indices, dst_state_batch_indices, pad_slot_id, - batch, + num_accepted_tokens, + cu_seqlens, + N, nheads, dim, dstate, @@ -353,6 +457,10 @@ def selective_state_update( out.stride(0), out.stride(1), out.stride(2), + state_batch_indices_strides[0], + state_batch_indices_strides[1], + dst_state_batch_indices_strides[0], + dst_state_batch_indices_strides[1], dt_softplus, tie_hdim, BLOCK_SIZE_M, diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index 6ebfa47a9dc3f..1656f4deb6717 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -24,9 +24,9 @@ class MLAModules: q_b_proj: torch.nn.Module | None q_proj: torch.nn.Module | None indexer: torch.nn.Module | None - indexer_rotary_emb: torch.nn.Module | None is_sparse: bool topk_indices_buffer: torch.Tensor | None + indexer_rotary_emb: torch.nn.Module | None = None @CustomOp.register("multi_head_latent_attention") @@ -111,6 +111,7 @@ class MultiHeadLatentAttentionWrapper(CustomOp): self, positions: torch.Tensor, hidden_states: torch.Tensor, + llama_4_scaling: torch.Tensor | None = None, ) -> torch.Tensor: q_c = None kv_lora = None @@ -159,6 +160,9 @@ class MultiHeadLatentAttentionWrapper(CustomOp): hidden_states, q_c, positions, self.indexer_rope_emb ) + if llama_4_scaling is not None: + q *= llama_4_scaling + attn_out = self.mla_attn( q, kv_c_normed, diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 7dd02e32ff211..d1942689d7f5c 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -64,42 +64,6 @@ class PoolingParamsUpdate: params.requires_token_ids = self.requires_token_ids -def get_prompt_lens( - hidden_states: torch.Tensor | list[torch.Tensor], - pooling_metadata: PoolingMetadata, -) -> torch.Tensor: - return pooling_metadata.prompt_lens - - -def get_prompt_token_ids(pooling_metadata: PoolingMetadata) -> list[torch.Tensor]: - assert pooling_metadata.prompt_token_ids is not None, ( - "Please set `requires_token_ids=True` in `get_pooling_updates`" - ) - - return [ - pooling_metadata.prompt_token_ids[i, :num] - for i, num in enumerate(pooling_metadata.prompt_lens) - ] - - -def get_pooling_params(pooling_metadata: PoolingMetadata) -> list[PoolingParams]: - pooling_params = pooling_metadata.pooling_params - return pooling_params - - -def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]: - pooling_params = get_pooling_params(pooling_metadata) - - tasks: list[PoolingTask] = [ - task - for pooling_param in pooling_params - if (task := pooling_param.task) is not None - ] - assert len(pooling_params) == len(tasks) - - return tasks - - def get_classification_activation_function(config: PretrainedConfig): # Implement alignment with transformers ForSequenceClassificationLoss # https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92 @@ -163,14 +127,14 @@ class PoolingMethod(nn.Module, ABC): self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor, - ) -> list[torch.Tensor] | torch.Tensor: + ) -> PoolerOutput: raise NotImplementedError def forward( self, hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, - ) -> list[torch.Tensor] | torch.Tensor: + ) -> PoolerOutput: pooling_cursor = pooling_metadata.pooling_cursor return self.forward_all(hidden_states, pooling_cursor) @@ -183,7 +147,7 @@ class CLSPool(PoolingMethod): self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor, - ) -> list[torch.Tensor] | torch.Tensor: + ) -> PoolerOutput: assert not pooling_cursor.is_partial_prefill(), ( "partial prefill not supported with CLS pooling" ) @@ -199,27 +163,65 @@ class LastPool(PoolingMethod): self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor, - ) -> list[torch.Tensor] | torch.Tensor: + ) -> PoolerOutput: return hidden_states[pooling_cursor.last_token_indices_gpu] class AllPool(PoolingMethod): + def __init__(self): + super().__init__() + + vllm_config = get_current_vllm_config() + self.enable_chunked_prefill = ( + vllm_config.scheduler_config.enable_chunked_prefill + ) + def get_supported_tasks(self) -> Set[PoolingTask]: return {"token_embed", "token_classify"} def forward_all( - self, - hidden_states: torch.Tensor, - pooling_cursor: PoolingCursor, - ) -> list[torch.Tensor] | torch.Tensor: - assert not pooling_cursor.is_partial_prefill(), ( - "partial prefill not supported with ALL pooling" + self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor + ) -> PoolerOutput: + raise NotImplementedError( + "forward_all is not implemented for AllPool. Use forward instead." ) + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + pooling_cursor = pooling_metadata.pooling_cursor + is_finished = pooling_cursor.is_finished() hidden_states_lst = list( hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist()) ) - return [hidden_states_lst[i] for i in pooling_cursor.index] + hidden_states_lst = [hidden_states_lst[i] for i in pooling_cursor.index] + + if not self.enable_chunked_prefill: + return hidden_states_lst + + pooling_states = pooling_metadata.pooling_states + + # If chunked_prefill is enabled + # 1. first store the chunked hidden_states in pooling_states.hidden_states_cache + for p, hs_chunk in zip(pooling_states, hidden_states_lst): + p.hidden_states_cache.append(hs_chunk) + + # 2. Once prefill is finished, send hidden_states_cache to PoolerHead + output_list: PoolerOutput = [] + for p, finished in zip(pooling_states, is_finished): + if finished: + hidden_states_cache = p.hidden_states_cache + if len(hidden_states_cache) == 1: + output_list.append(hidden_states_cache[0]) + else: + output_list.append(torch.concat(hidden_states_cache, dim=0)) + p.clean() + else: + output_list.append(None) + + return output_list class MeanPool(PoolingMethod): @@ -230,7 +232,7 @@ class MeanPool(PoolingMethod): self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor, - ) -> list[torch.Tensor] | torch.Tensor: + ) -> PoolerOutput: assert not pooling_cursor.is_partial_prefill(), ( "partial prefill not supported with MEAN pooling" ) @@ -435,7 +437,7 @@ class PoolerHead(nn.Module): self, pooled_data: list[torch.Tensor] | torch.Tensor, pooling_metadata: PoolingMetadata, - ): + ) -> PoolerOutput: return self.activation(pooled_data) @@ -454,7 +456,7 @@ class EmbeddingPoolerHead(PoolerHead): self, pooled_data: list[torch.Tensor] | torch.Tensor, pooling_metadata: PoolingMetadata, - ): + ) -> PoolerOutput: if isinstance(pooled_data, list): pooled_data = torch.stack(pooled_data) # pooled_data shape: [batchsize, hidden_dimension] @@ -466,7 +468,7 @@ class EmbeddingPoolerHead(PoolerHead): pooled_data = self.projector(pooled_data) # pooled_data shape: [batchsize, embedding_dimension] - pooling_params = get_pooling_params(pooling_metadata) + pooling_params = pooling_metadata.pooling_params # for matryoshka representation dimensions_list = [pooling_param.dimensions for pooling_param in pooling_params] @@ -606,7 +608,7 @@ class ClassifierPooler(Pooler): if self.logit_bias is not None: pooled_data -= self.logit_bias - pooling_params = get_pooling_params(pooling_metadata) + pooling_params = pooling_metadata.pooling_params flags = [p.use_activation for p in pooling_params] if len(set(flags)) == 1: @@ -622,8 +624,12 @@ class ClassifierPooler(Pooler): class TokenEmbeddingPoolerHead(EmbeddingPoolerHead): def forward( - self, pooled_data: torch.Tensor, pooling_param: PoolingParams - ) -> torch.Tensor: + self, pooled_data: torch.Tensor | None, pooling_param: PoolingParams + ) -> PoolerOutput: + # for unfinished chunked prefill + if pooled_data is None: + return None + pooled_data = pooled_data.to(self.head_dtype) # pooled_data shape: [n_tokens, hidden_dimension] @@ -666,9 +672,13 @@ class TokenClassifierPoolerHead(nn.Module): def forward( self, - hidden_states: torch.Tensor, + hidden_states: torch.Tensor | None, pooling_param: PoolingParams, - ) -> torch.Tensor: + ) -> PoolerOutput: + # for unfinished chunked prefill + if hidden_states is None: + return None + hidden_states = hidden_states.to(self.head_dtype) # hidden_states shape: [n_token, hidden_size] @@ -704,7 +714,7 @@ class AllPooler(Pooler): pooling_metadata: PoolingMetadata, ) -> PoolerOutput: pooled_data = self.pooling(hidden_states, pooling_metadata) - pooling_params = get_pooling_params(pooling_metadata) + pooling_params = pooling_metadata.pooling_params assert len(pooled_data) == len(pooling_params) pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)] @@ -722,17 +732,20 @@ class StepPooler(Pooler): self, hidden_states: torch.Tensor | list[torch.Tensor], pooling_metadata: PoolingMetadata, - ) -> torch.Tensor | list[torch.Tensor]: + ) -> PoolerOutput: pooled_data_lst = self.pooling(hidden_states, pooling_metadata) - prompt_token_ids = get_prompt_token_ids(pooling_metadata) - - pooled_data = list[torch.Tensor]() - - pooling_params = get_pooling_params(pooling_metadata) + prompt_token_ids = pooling_metadata.get_prompt_token_ids() + pooling_params = pooling_metadata.pooling_params + pooled_data: PoolerOutput = [] for data, token_id, pooling_param in zip( pooled_data_lst, prompt_token_ids, pooling_params ): + # for unfinished chunked prefill + if data is None: + pooled_data.append(data) + continue + step_tag_id = pooling_param.step_tag_id returned_token_ids = pooling_param.returned_token_ids @@ -757,7 +770,7 @@ class StepPooler(Pooler): pooling_metadata: PoolingMetadata, ) -> PoolerOutput: pooled_data = self.extract_states(hidden_states, pooling_metadata) - pooling_params = get_pooling_params(pooling_metadata) + pooling_params = pooling_metadata.pooling_params assert len(pooled_data) == len(pooling_params) pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)] @@ -794,7 +807,7 @@ class DispatchPooler(Pooler): outputs = list[torch.Tensor]() offset = 0 - for task, group in groupby(get_tasks(pooling_metadata)): + for task, group in groupby(pooling_metadata.tasks): if not (pooler := poolers_by_task.get(task)): raise ValueError( f"Unsupported task: {task} " diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index d463e181fd2db..3ed15ed7dd422 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable from typing import TYPE_CHECKING, Any, Optional import torch @@ -471,6 +470,11 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): } ) + intermediate_size_full = extra_weight_attrs.pop( + "intermediate_size_full", intermediate_size_per_partition + ) + self.is_k_full = intermediate_size_per_partition == intermediate_size_full + w13_qweight = Parameter( torch.empty( num_experts, @@ -598,6 +602,13 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): ) replace_parameter(layer, "w2_qweight", marlin_w2_qweight) + # The modular kernel expects w13_weight and w2_weight, + # but AWQ uses w13_qweight and w2_qweight + # Alias for modular kernel + layer.w13_weight = layer.w13_qweight + # Alias for modular kernel + layer.w2_weight = layer.w2_qweight + # Why does this take the intermediate size for size_k? marlin_w13_scales = marlin_moe_permute_scales( s=layer.w13_scales, @@ -662,32 +673,96 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: - return None + from vllm.model_executor.layers.fused_moe.config import ( + awq_marlin_moe_quant_config, + ) + + return awq_marlin_moe_quant_config( + w1_scale=layer.w13_scales, + w2_scale=layer.w2_scales, + weight_bits=self.quant_config.weight_bits, + group_size=self.quant_config.group_size, + w1_zp=getattr(layer, "w13_qzeros", None) + if self.quant_config.zero_point + else None, + w2_zp=getattr(layer, "w2_qzeros", None) + if self.quant_config.zero_point + else None, + w1_bias=getattr(layer, "w13_bias", None), + w2_bias=getattr(layer, "w2_bias", None), + ) + + def select_gemm_impl( + self, + prepare_finalize, + layer: torch.nn.Module, + ): + """ + Select the GEMM implementation for AWQ-Marlin MoE. + Returns MarlinExperts configured for AWQ quantization. + This is ONLY used when LoRA is enabled. + Without LoRA, AWQ uses its own apply() method. + """ + # Only use modular kernels when LoRA is enabled + # Without LoRA, AWQ's own apply() method works fine and is more efficient + if not self.moe.is_lora_enabled: + raise NotImplementedError( + "AWQ-Marlin uses its own apply() method when LoRA is not enabled. " + "Modular kernels are only used for LoRA support." + ) + + from vllm.model_executor.layers.fused_moe import modular_kernel as mk + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + BatchedMarlinExperts, + MarlinExperts, + ) + + # Ensure quant config is initialized + assert self.moe_quant_config is not None, ( + "moe_quant_config must be initialized before select_gemm_impl" + ) + + w13_g_idx = getattr(layer, "w13_g_idx", None) + w2_g_idx = getattr(layer, "w2_g_idx", None) + w13_g_idx_sort_indices = getattr(layer, "w13_g_idx_sort_indices", None) + w2_g_idx_sort_indices = getattr(layer, "w2_g_idx_sort_indices", None) + + # Check if using batched expert format (for Expert Parallelism) + if ( + prepare_finalize.activation_format + == mk.FusedMoEActivationFormat.BatchedExperts + ): + # For batched format, use BatchedMarlinExperts + max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() + assert max_num_tokens_per_rank is not None + return BatchedMarlinExperts( + max_num_tokens=max_num_tokens_per_rank, + num_dispatchers=prepare_finalize.num_dispatchers(), + quant_config=self.moe_quant_config, + w13_g_idx=w13_g_idx, + w2_g_idx=w2_g_idx, + w13_g_idx_sort_indices=w13_g_idx_sort_indices, + w2_g_idx_sort_indices=w2_g_idx_sort_indices, + is_k_full=self.is_k_full, + ) + else: + # Standard Marlin experts for AWQ + return MarlinExperts( + quant_config=self.moe_quant_config, + w13_g_idx=w13_g_idx, + w2_g_idx=w2_g_idx, + w13_g_idx_sort_indices=w13_g_idx_sort_indices, + w2_g_idx_sort_indices=w2_g_idx_sort_indices, + is_k_full=self.is_k_full, + ) def apply( self, layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert activation == "silu", "Only SiLU activation is supported." + assert layer.activation == "silu", "Only SiLU activation is supported." topk_weights, topk_ids, _ = layer.select_experts( hidden_states=x, @@ -708,9 +783,9 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): input_global_scale1=getattr(layer, "w13_input_global_scale", None), input_global_scale2=getattr(layer, "w2_input_global_scale", None), quant_type_id=self.quant_type.id, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, w1_zeros=layer.w13_qzeros, w2_zeros=layer.w2_qzeros, workspace=layer.workspace, diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 1e57fa218b797..1fd959cb3423d 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable from typing import Any, Union import torch @@ -498,23 +497,6 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts @@ -534,10 +516,10 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, + activation=layer.activation, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, quant_config=self.moe_quant_config, ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 02086c3c0052d..f835584219cca 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -116,16 +116,37 @@ class CompressedTensorsConfig(QuantizationConfig): return "compressed-tensors" def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): - self.target_scheme_map = hf_to_vllm_mapper.apply_dict(self.target_scheme_map) - self.ignore = hf_to_vllm_mapper.apply_list(self.ignore) - self.sparsity_scheme_map = hf_to_vllm_mapper.apply_dict( - self.sparsity_scheme_map - ) - self.sparsity_ignore_list = hf_to_vllm_mapper.apply_list( - self.sparsity_ignore_list - ) + """ + Transform layer paths in config targets to match vLLM's naming. + + The WeightsMapper is designed for weight paths, but some backends + (e.g. transformers) use broad prefix mappings like "" -> "model." + which would incorrectly transform non-path targets. + + compressed-tensors targets can be: + - Layer paths: "layers.0.self_attn.q_proj" -> transformed + - Module class names: "Linear" -> preserved (no ".") + - Regex patterns: "re:.*proj" -> preserved (starts with "re:") + """ + + def _map_target(target: str) -> str | None: + is_layer_path = "." in target and not target.startswith("re:") + if is_layer_path: + return hf_to_vllm_mapper._map_name(target) + return target + + def _apply_dict(d: dict) -> dict: + return {k: v for t, v in d.items() if (k := _map_target(t)) is not None} + + def _apply_list(lst: list) -> list: + return [t for x in lst if (t := _map_target(x)) is not None] + + self.target_scheme_map = _apply_dict(self.target_scheme_map) + self.ignore = _apply_list(self.ignore) + self.sparsity_scheme_map = _apply_dict(self.sparsity_scheme_map) + self.sparsity_ignore_list = _apply_list(self.sparsity_ignore_list) if self.kv_cache_scheme is not None: - self.kv_cache_scheme = hf_to_vllm_mapper.apply_dict(self.kv_cache_scheme) + self.kv_cache_scheme = _apply_dict(self.kv_cache_scheme) def get_quant_method( self, @@ -256,7 +277,7 @@ class CompressedTensorsConfig(QuantizationConfig): if format is not None else is_activation_quantization_format(quant_format) ) - # TODO(czhu): w4a8fp8 is in packed-quantized format + # w4a8fp8 is in packed-quantized format # but needs input activation quantization input_activations = quant_config.get("input_activations") if act_quant_format or input_activations: @@ -767,8 +788,10 @@ class CompressedTensorsConfig(QuantizationConfig): targets=self.target_scheme_map.keys(), fused_mapping=self.packed_modules_mapping, ) - - return self.target_scheme_map[matched_target] + scheme_dict = self.target_scheme_map[matched_target] + if scheme_dict.get("format") is None: + scheme_dict["format"] = self.quant_format + return scheme_dict return None diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 80ee443d4dd6a..f650a6eabbb9c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -2,12 +2,15 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import enum -from collections.abc import Callable from enum import Enum import torch from compressed_tensors import CompressionFormat -from compressed_tensors.quantization import ActivationOrdering, QuantizationStrategy +from compressed_tensors.quantization import ( + ActivationOrdering, + QuantizationArgs, + QuantizationStrategy, +) from torch.nn.parameter import Parameter import vllm.envs as envs @@ -29,6 +32,7 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, int4_w4a16_moe_quant_config, + int4_w4afp8_moe_quant_config, int8_w8a8_moe_quant_config, int8_w8a16_moe_quant_config, nvfp4_moe_quant_config, @@ -75,7 +79,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( prepare_moe_fp8_layer_for_marlin, ) -from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + convert_bf16_scales_to_fp8, + convert_packed_uint4b8_to_signed_int4_inplace, + swizzle_blockscale, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, normalize_e4m3fn_to_e4m3fnuz, @@ -86,8 +94,10 @@ from vllm.platforms import CpuArchEnum, current_platform from vllm.scalar_type import scalar_types from vllm.utils.deep_gemm import ( get_col_major_tma_aligned_tensor, + get_mk_alignment_for_contiguous_layout, is_deep_gemm_e8m0_used, ) +from vllm.utils.import_utils import has_deep_gemm logger = init_logger(__name__) @@ -142,10 +152,26 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): # are supported + check if the layer is being ignored. weight_quant = scheme_dict.get("weights") input_quant = scheme_dict.get("input_activations") + format = scheme_dict.get("format") if quant_config._is_wNa16_group_channel(weight_quant, input_quant): # group_size=None means channelwise group_size = weight_quant.group_size or -1 + + valid_format_and_bits = ( + weight_quant.num_bits in WNA16_SUPPORTED_BITS + and format == CompressionFormat.pack_quantized.value + ) + + if not valid_format_and_bits: + raise ValueError( + "For Fused MoE layers, only format: ", + f"{CompressionFormat.pack_quantized.value} ", + f" and bits: {WNA16_SUPPORTED_BITS} is supported ", + f"but got format: {CompressionFormat.pack_quantized.value} " + f" and bits: {weight_quant.num_bits}", + ) + # Prefer to use the MarlinMoE kernel when it is supported. if ( not check_moe_marlin_supports_layer(layer, group_size) @@ -161,12 +187,12 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): ) logger.info_once("Using CompressedTensorsWNA16MoEMethod") return CompressedTensorsWNA16MoEMethod( - quant_config, layer.moe_config, layer_name + weight_quant, input_quant, layer.moe_config ) else: logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") return CompressedTensorsWNA16MarlinMoEMethod( - quant_config, layer.moe_config, layer_name + weight_quant, input_quant, layer.moe_config ) elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): return CompressedTensorsW4A4Nvfp4MoEMethod(layer.moe_config, layer_name) @@ -176,15 +202,20 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): or quant_config._is_fp8_w8a8(weight_quant, input_quant) ): return CompressedTensorsW8A8Fp8MoEMethod( - quant_config, layer.moe_config, layer_name + weight_quant, input_quant, layer.moe_config ) elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Int8MoEMethod( - quant_config, layer.moe_config, layer_name + weight_quant, input_quant, layer.moe_config + ) + elif quant_config._is_fp8_w4a8_sm90(weight_quant, input_quant): + logger.info_once("Using CompressedTensorsW4A8Fp8MoEMethod") + return CompressedTensorsW4A8Fp8MoEMethod( + weight_quant, input_quant, layer.moe_config ) elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant): return CompressedTensorsW4A8Int8MoEMethod( - quant_config, layer.moe_config, layer_name + weight_quant, input_quant, layer.moe_config ) else: raise RuntimeError( @@ -438,16 +469,14 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ) logger.debug_once("Finished shuffling weights for TRT-LLM MOE") - layer.gemm1_weights_fp4_shuffled = Parameter( + layer.w13_weight = Parameter( gemm1_weights_fp4_shuffled, requires_grad=False ) - layer.gemm2_weights_fp4_shuffled = Parameter( - gemm2_weights_fp4_shuffled, requires_grad=False - ) - layer.gemm1_scales_fp4_shuffled = Parameter( + layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False) + layer.w13_weight_scale = Parameter( gemm1_scales_fp4_shuffled, requires_grad=False ) - layer.gemm2_scales_fp4_shuffled = Parameter( + layer.w2_weight_scale = Parameter( gemm2_scales_fp4_shuffled, requires_grad=False ) @@ -456,12 +485,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32), requires_grad=False, ) - - # Clean up weights that won't be used by TRT-LLM - del layer.w2_weight - del layer.w2_weight_scale - del layer.w13_weight - del layer.w13_weight_scale else: # swizzle weight scales layer.w13_weight_scale = torch.nn.Parameter( @@ -526,31 +549,14 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert activation == "silu", "Only SiLU activation is supported." + assert layer.activation == "silu", "Only SiLU activation is supported." if ( self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM ): - if enable_eplb: + if layer.enable_eplb: raise NotImplementedError( "EPLB not supported for `CompressedTensorsW4A4MoEMethod` yet." ) @@ -559,12 +565,12 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): layer=layer, x=x, router_logits=router_logits, - top_k=top_k, - global_num_experts=global_num_experts, - num_expert_group=num_expert_group, - topk_group=topk_group, - custom_routing_function=custom_routing_function, - e_score_correction_bias=e_score_correction_bias, + top_k=layer.top_k, + global_num_experts=layer.global_num_experts, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + custom_routing_function=layer.custom_routing_function, + e_score_correction_bias=layer.e_score_correction_bias, ) topk_weights, topk_ids, _ = layer.select_experts( @@ -587,9 +593,9 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): global_scale1=layer.w13_weight_scale_2, global_scale2=layer.w2_weight_scale_2, quant_type_id=scalar_types.float4_e2m1f.id, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, input_dtype=self.marlin_input_dtype, workspace=layer.workspace, ) @@ -614,15 +620,15 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): topk_ids=topk_ids, quant_config=self.moe_quant_config, inplace=False, # TODO(shuw): fix later, now output is high prec - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, ) else: from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 - assert expert_map is None, ( + assert layer.expert_map is None, ( "Expert Parallelism / expert_map " "is currently not supported for " "CompressedTensorsW4A4Nvfp4MoEMethod." @@ -638,7 +644,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): topk_weights=topk_weights, topk_ids=topk_ids, quant_config=self.moe_quant_config, - apply_router_weight_on_input=apply_router_weight_on_input, + apply_router_weight_on_input=layer.apply_router_weight_on_input, # TODO(bnell): derive these from arguments m=x.shape[0], n=layer.w2_weight.shape[2] * 2, @@ -650,17 +656,19 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): def __init__( self, - quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + weight_quant: QuantizationArgs, + input_quant: QuantizationArgs, moe: FusedMoEConfig, layer_name: str | None = None, ): - super().__init__(moe) - self.quant_config = quant_config - self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights") - self.input_quant = self.quant_config.target_scheme_map["Linear"].get( - "input_activations" + from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsConfig, ) + super().__init__(moe) + self.weight_quant = weight_quant + self.input_quant = input_quant + per_tensor = ( self.weight_quant.strategy == QuantizationStrategy.TENSOR and self.input_quant.strategy == QuantizationStrategy.TENSOR @@ -698,11 +706,13 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() # cutlass path - self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100( + self.is_fp8_w8a8_sm100 = CompressedTensorsConfig._is_fp8_w8a8_sm100( self.weight_quant, self.input_quant ) self.use_cutlass = not self.block_quant and ( - quant_config._is_fp8_w8a8_sm90(self.weight_quant, self.input_quant) + CompressedTensorsConfig._is_fp8_w8a8_sm90( + self.weight_quant, self.input_quant + ) or self.is_fp8_w8a8_sm100 ) self.disable_expert_map = False @@ -1064,9 +1074,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): return experts - # triton path - from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 - BatchedTritonOrDeepGemmExperts, + from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( + BatchedDeepGemmExperts, + ) + from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedTritonExperts, ) from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( TritonOrDeepGemmExperts, @@ -1074,6 +1086,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): assert not self.rocm_aiter_moe_enabled and not self.use_marlin + use_deep_gemm = envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM + if ( prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts @@ -1081,22 +1095,47 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() assert max_num_tokens_per_rank is not None - logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__) - return BatchedTritonOrDeepGemmExperts( - max_num_tokens=max_num_tokens_per_rank, - num_dispatchers=prepare_finalize.num_dispatchers(), - quant_config=self.moe_quant_config, - allow_deep_gemm=( - envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM - ), + if use_deep_gemm and not has_deep_gemm(): + raise RuntimeError( + "DeepGEMM requested for MoE layer but not installed." + ) + + compatible_with_deep_gemm = ( + self.moe_quant_config.use_fp8_w8a8 + and self.moe_quant_config.block_shape + == get_mk_alignment_for_contiguous_layout() ) + + # If this MoE layer is compatible with DeepGEMM, the proper env + # vars are set and DeepGEMM is not installed, throw an error. + if use_deep_gemm and compatible_with_deep_gemm and not has_deep_gemm(): + raise RuntimeError( + f"MoE layer incompatible with DeepGEMM, expected " + f"fp8==True, got {self.moe_quant_config.use_fp8_w8a8}" + f"or block_shape {self.moe_quant_config.block_shape}" + f"=={get_mk_alignment_for_contiguous_layout()}." + ) + + if use_deep_gemm and compatible_with_deep_gemm and has_deep_gemm(): + logger.debug("BatchedDeepGemmExperts(%s)", self.__class__.__name__) + return BatchedDeepGemmExperts( + max_num_tokens=max_num_tokens_per_rank, + num_dispatchers=prepare_finalize.num_dispatchers(), + quant_config=self.moe_quant_config, + ) + else: + logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__) + return BatchedTritonExperts( + max_num_tokens=max_num_tokens_per_rank, + num_dispatchers=prepare_finalize.num_dispatchers(), + quant_config=self.moe_quant_config, + ) + else: logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__) return TritonOrDeepGemmExperts( self.moe_quant_config, - allow_deep_gemm=( - envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM - ), + allow_deep_gemm=use_deep_gemm, ) def get_fused_moe_quant_config( @@ -1123,23 +1162,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: topk_weights, topk_ids, _ = layer.select_experts( hidden_states=x, @@ -1150,7 +1172,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL if self.use_marlin: - assert activation == "silu", f"{activation} not supported for Marlin MoE." + assert layer.activation == "silu", ( + f"{layer.activation} not supported for Marlin MoE." + ) return fused_marlin_moe( x, layer.w13_weight, @@ -1163,9 +1187,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): topk_weights, topk_ids, quant_type_id=scalar_types.float8_e4m3fn.id, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, input_dtype=self.marlin_input_dtype, workspace=layer.workspace, ) @@ -1183,9 +1207,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - expert_map=expert_map, + activation=layer.activation, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + expert_map=layer.expert_map, quant_config=self.moe_quant_config, ) @@ -1205,10 +1229,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=None if self.disable_expert_map else expert_map, + activation=layer.activation, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + expert_map=None + if self.disable_expert_map + else layer.expert_map, # ??? quant_config=self.moe_quant_config, ) else: @@ -1225,9 +1251,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): topk_weights, topk_ids, quant_config=self.moe_quant_config, - activation=activation, - global_num_experts=global_num_experts, - expert_map=None if self.disable_expert_map else expert_map, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=None if self.disable_expert_map else layer.expert_map, ab_strides1=self.ab_strides1_c_strides2, ab_strides2=self.ab_strides2, c_strides1=self.c_strides1, @@ -1246,10 +1272,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, + activation=layer.activation, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, quant_config=self.moe_quant_config, ) @@ -1261,16 +1287,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): def __init__( self, - quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + weight_quant: QuantizationArgs, + input_quant: QuantizationArgs, moe: FusedMoEConfig, layer_name: str | None = None, ): super().__init__(moe) - self.quant_config = quant_config - self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights") - self.input_quant = self.quant_config.target_scheme_map["Linear"].get( - "input_activations" - ) + self.weight_quant = weight_quant + self.input_quant = input_quant per_channel = ( self.weight_quant.strategy == QuantizationStrategy.CHANNEL @@ -1371,23 +1395,6 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts @@ -1403,10 +1410,10 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, + activation=layer.activation, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, quant_config=self.moe_quant_config, ) @@ -1414,36 +1421,27 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): def __init__( self, - quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + weight_quant: QuantizationArgs, + input_quant: QuantizationArgs | None, moe: FusedMoEConfig, layer_name: str | None = None, ): super().__init__(moe) - self.quant_config = quant_config - # TODO: @dsikka: refactor this to use schemes as other kernels - # are supported + check if the layer is being ignored. - config = self.quant_config.target_scheme_map["Linear"].get("weights") - self.num_bits = config.num_bits - self.packed_factor = 32 // config.num_bits - self.strategy = config.strategy - self.group_size = config.group_size - self.actorder = config.actorder - self.layer_name = layer_name - self.marlin_input_dtype = get_marlin_input_dtype(layer_name) - assert config.symmetric, "Only symmetric quantization is supported for MoE" + self.weight_quant = weight_quant + self.input_quant = input_quant + assert weight_quant.symmetric, ( + "Only symmetric quantization is supported for MoE" + ) + # Extract properties from weight_quant + self.num_bits = weight_quant.num_bits + self.packed_factor = 32 // weight_quant.num_bits + self.strategy = weight_quant.strategy + self.group_size = weight_quant.group_size + self.actorder = weight_quant.actorder - if not ( - self.quant_config.quant_format == CompressionFormat.pack_quantized.value - and self.num_bits in WNA16_SUPPORTED_BITS - ): - raise ValueError( - "For Fused MoE layers, only ", - f"{CompressionFormat.pack_quantized.value} ", - "is supported for the following bits: ", - f"{WNA16_SUPPORTED_BITS}", - ) self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits] self.use_marlin = True + self.marlin_input_dtype = get_marlin_input_dtype(layer_name) def create_weights( self, @@ -1757,25 +1755,10 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert activation == "silu", f"{activation} not supported for Marlin MoE." + assert layer.activation == "silu", ( + f"{layer.activation} not supported for Marlin MoE." + ) topk_weights, topk_ids, _ = layer.select_experts( hidden_states=x, @@ -1796,9 +1779,9 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): input_global_scale1=getattr(layer, "w13_input_global_scale", None), input_global_scale2=getattr(layer, "w2_input_global_scale", None), quant_type_id=self.quant_type.id, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, g_idx1=layer.w13_weight_g_idx, g_idx2=layer.w2_weight_g_idx, sort_indices1=layer.w13_g_idx_sort_indices, @@ -1812,35 +1795,26 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): def __init__( self, - quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + weight_quant: QuantizationArgs, + input_quant: QuantizationArgs | None, moe: FusedMoEConfig, layer_name: str | None = None, ): super().__init__(moe) - self.quant_config = quant_config - # TODO: @dsikka: refactor this to use schemes as other kernels - # are supported + check if the layer is being ignored. - config = self.quant_config.target_scheme_map["Linear"].get("weights") - self.num_bits = config.num_bits - self.packed_factor = 32 // config.num_bits - self.strategy = config.strategy + self.weight_quant = weight_quant + self.input_quant = input_quant + # Extract properties from weight_quant + self.num_bits = weight_quant.num_bits + self.packed_factor = 32 // weight_quant.num_bits + self.strategy = weight_quant.strategy # channelwise is not supported by this kernel - assert config.strategy == "group" - self.group_size = config.group_size + assert weight_quant.strategy == "group" + self.group_size = weight_quant.group_size # grouped actorder isn't supported by this kernel - assert config.actorder != "group" - assert config.symmetric, "Only symmetric quantization is supported for MoE" - - if not ( - self.quant_config.quant_format == CompressionFormat.pack_quantized.value - and self.num_bits in WNA16_SUPPORTED_BITS - ): - raise ValueError( - "For Fused MoE layers, only ", - f"{CompressionFormat.pack_quantized.value} ", - "is supported for the following bits: ", - f"{WNA16_SUPPORTED_BITS}", - ) + assert weight_quant.actorder != "group" + assert weight_quant.symmetric, ( + "Only symmetric quantization is supported for MoE" + ) def create_weights( self, @@ -2009,23 +1983,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts @@ -2041,10 +1998,10 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, + activation=layer.activation, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, quant_config=self.moe_quant_config, ) @@ -2065,28 +2022,33 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod): def __init__( self, - quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + weight_quant: QuantizationArgs, + input_quant: QuantizationArgs, moe: FusedMoEConfig, layer_name: str | None = None, ): super().__init__(moe) self.has_bias = self.moe.has_bias - self.quant_config = quant_config + self.weight_quant = weight_quant + self.input_quant = input_quant # Validate scheme: weights=W4 (channel or group), # activations=dynamic TOKEN (A8) - wq = self.quant_config.target_scheme_map["Linear"].get("weights") - aq = self.quant_config.target_scheme_map["Linear"].get("input_activations") # Must be dynamic per-token activations - if aq.strategy != QuantizationStrategy.TOKEN or not aq.dynamic: + if ( + input_quant.strategy != QuantizationStrategy.TOKEN + or not input_quant.dynamic + ): raise ValueError( "W4A8-int MoE needs dynamic per-token activation quantization." ) # Weight can be channel-wise (group_size=None) or group-wise - self.group_size = wq.group_size if (wq.group_size is not None) else -1 - if wq.num_bits != 4: + self.group_size = ( + weight_quant.group_size if (weight_quant.group_size is not None) else -1 + ) + if weight_quant.num_bits != 4: raise ValueError("This method only supports 4-bit weights (num_bits=4).") # CPU only @@ -2319,32 +2281,15 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod): def apply( self, - layer: torch.nn.Module, + layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor: - assert not enable_eplb, "EPLB not supported for W4A8-int MoE yet." - assert activation in ("silu", "swigluoai", "swiglu"), ( + assert not layer.enable_eplb, "EPLB not supported for W4A8-int MoE yet." + assert layer.activation in ("silu", "swigluoai", "swiglu"), ( "Only SiLU/SwiGLUGU/SwiGLUUG are supported." ) - assert expert_map is None, """expert_map/EP not implemented + assert layer.expert_map is None, """expert_map/EP not implemented for CPU dyn-4bit MoE.""" def _act_kind(s: str) -> int: @@ -2361,15 +2306,9 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod): topk_weights, topk_ids = select_experts( hidden_states=x, router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - routed_scaling_factor=routed_scaling_factor, - e_score_correction_bias=e_score_correction_bias, + top_k=layer.top_k, + use_grouped_topk=layer.use_grouped_topk, + renormalize=layer.renormalize, ) return torch.ops._C.dynamic_4bit_int_moe( @@ -2382,6 +2321,317 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod): layer.w2_in_features, layer.w13_out_features, layer.group_size, - apply_router_weight_on_input, - int(_act_kind(activation)), + layer.apply_router_weight_on_input, + int(_act_kind(layer.activation)), ) + + +class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod): + def __init__( + self, + weight_quant: QuantizationArgs, + input_quant: QuantizationArgs, + moe: FusedMoEConfig, + layer_name: str | None = None, + ): + super().__init__(moe) + self.weight_quant = weight_quant + self.input_quant = input_quant + + self.group_size = self.weight_quant.group_size + self.num_bits = self.weight_quant.num_bits + self.packed_factor = 32 // self.num_bits + + assert self.weight_quant.symmetric, ( + "Only symmetric quantization is supported for W4A8 MoE" + ) + assert self.weight_quant.actorder != "group" + assert self.group_size == 128, "Only group size 128 supported for W4A8 MoE" + + self.disable_expert_map = False + self.layer_name = layer_name + + from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 + from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + ) + + self.quant_fp8 = QuantFP8(static=False, group_shape=GroupShape.PER_TOKEN) + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + layer.intermediate_size_per_partition = intermediate_size_per_partition + layer.hidden_size = hidden_size + layer.num_experts = num_experts + layer.orig_dtype = params_dtype + layer.weight_block_size = None + + # requirement for CUTLASS reorder_tensor + assert hidden_size % 256 == 0, f"{hidden_size=} must be divisible by 256" + assert intermediate_size_per_partition % 256 == 0, ( + f"{intermediate_size_per_partition=} must be divisible by 256" + ) + # storage type, pack 8xint4 into int32 + params_dtype = torch.int32 + + # WEIGHTS + w13_weight_packed = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // self.packed_factor, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_packed", w13_weight_packed) + set_weight_attrs(w13_weight_packed, extra_weight_attrs) + + w2_weight_packed = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // self.packed_factor, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_packed", w2_weight_packed) + set_weight_attrs(w2_weight_packed, extra_weight_attrs) + + # SCALES + # weight_scale refers to the group-wise scales + # they are initially loaded as bf16, we will convert to fp8 + # after loading + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // self.group_size, + dtype=layer.orig_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + hidden_size, + intermediate_size_per_partition // self.group_size, + dtype=layer.orig_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-GROUP quantization for FusedMoE.weight_loader. + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} + ) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # weight shapes + w2_weight_shape = torch.nn.Parameter( + torch.empty(num_experts, 2), requires_grad=False + ) + layer.register_parameter("w2_weight_shape", w2_weight_shape) + set_weight_attrs(w2_weight_shape, extra_weight_attrs) + w13_weight_shape = torch.nn.Parameter( + torch.empty(num_experts, 2), requires_grad=False + ) + layer.register_parameter("w13_weight_shape", w13_weight_shape) + set_weight_attrs(w13_weight_shape, extra_weight_attrs) + + # don't use input scales + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer): + device = layer.w13_weight_packed.device + + # STRIDES + # A, C + self.a_strides1_c_strides2 = torch.full( + (layer.local_num_experts,), + layer.hidden_size, + device=device, + dtype=torch.int64, + ) + self.a_strides2 = torch.full( + (layer.local_num_experts,), + layer.intermediate_size_per_partition, + device=device, + dtype=torch.int64, + ) + self.c_strides1 = torch.full( + (layer.local_num_experts,), + 2 * layer.intermediate_size_per_partition, + device=device, + dtype=torch.int64, + ) + + # S (group-wise scales) + # sizeof(StrideS) = 16 bytes, so we need to use 2xint64 to encode it + self.s_strides1 = torch.zeros( + (layer.local_num_experts, 2), device=device, dtype=torch.int64 + ) + self.s_strides1[:, 0] = 2 * layer.intermediate_size_per_partition + + self.s_strides2 = torch.zeros( + (layer.local_num_experts, 2), device=device, dtype=torch.int64 + ) + self.s_strides2[:, 0] = layer.hidden_size + + # encode and reorder weight tensors, and get the layout to pass to + # the grouped gemm kernel. `b_strides1/2` specifies the entire layout + convert_packed_uint4b8_to_signed_int4_inplace(layer.w13_weight_packed) + w13_weight_shuffled, self.b_strides1 = ( + ops.cutlass_encode_and_reorder_int4b_grouped(layer.w13_weight_packed) + ) + replace_parameter(layer, "w13_weight_packed", w13_weight_shuffled) + convert_packed_uint4b8_to_signed_int4_inplace(layer.w2_weight_packed) + w2_weight_shuffled, self.b_strides2 = ( + ops.cutlass_encode_and_reorder_int4b_grouped(layer.w2_weight_packed) + ) + replace_parameter(layer, "w2_weight_packed", w2_weight_shuffled) + + # convert bf16 scales to (fp8_scales, channel_scales) + w13_weight_scale, w13_weight_chan_scale = convert_bf16_scales_to_fp8( + self.quant_fp8, layer.w13_weight_scale + ) + w2_weight_scale, w2_weight_chan_scale = convert_bf16_scales_to_fp8( + self.quant_fp8, layer.w2_weight_scale + ) + + # register channel scales + layer.register_parameter( + "w13_weight_chan_scale", + torch.nn.Parameter(w13_weight_chan_scale, requires_grad=False), + ) + layer.register_parameter( + "w2_weight_chan_scale", + torch.nn.Parameter(w2_weight_chan_scale, requires_grad=False), + ) + + # The scales are stored as (E, N, K // 128) but the kernel expects + # (E, K // 128, N) in row-major format, so we need to permute the last 2 dims + # and make it contiguous + w13_weight_scale_packed = ops.cutlass_pack_scale_fp8( + w13_weight_scale.permute(0, 2, 1).contiguous() + ) + replace_parameter(layer, "w13_weight_scale", w13_weight_scale_packed) + w2_weight_scale_packed = ops.cutlass_pack_scale_fp8( + w2_weight_scale.permute(0, 2, 1).contiguous() + ) + replace_parameter(layer, "w2_weight_scale", w2_weight_scale_packed) + + def maybe_make_prepare_finalize( + self, + routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, + ) -> mk.FusedMoEPrepareAndFinalize | None: + return super().maybe_make_prepare_finalize(routing_tables) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> FusedMoEQuantConfig | None: + # Store quantization scales; both per-group and per-channel + # Note we haven't specified the group size here because + # the quant config logic assumes group-wise scaling + # and channel-wise scaling are exclusive. + return int4_w4afp8_moe_quant_config( + w1_scale=layer.w13_weight_scale, # group scale + w2_scale=layer.w2_weight_scale, # group scale + g1_alphas=layer.w13_weight_chan_scale, + g2_alphas=layer.w2_weight_chan_scale, + per_act_token_quant=True, # always use dynamc per-token + per_out_ch_quant=True, # always use per-channel + ) + + def select_gemm_impl( + self, + prepare_finalize: mk.FusedMoEPrepareAndFinalize, + layer: torch.nn.Module, + ) -> mk.FusedMoEPermuteExpertsUnpermute: + assert self.moe_quant_config is not None + assert ( + prepare_finalize.activation_format == FusedMoEActivationFormat.Standard + ), "BatchedExperts not supported" + + from vllm.model_executor.layers.fused_moe import CutlassExpertsW4A8Fp8 + + experts: FusedMoEPermuteExpertsUnpermute + + logger.debug("CutlassExpertsW4A8Fp8(%s)", self.__class__.__name__) + experts = CutlassExpertsW4A8Fp8( + out_dtype=self.moe.in_dtype, + a_strides1=self.a_strides1_c_strides2, + a_strides2=self.a_strides2, + b_strides1=self.b_strides1, + b_strides2=self.b_strides2, + c_strides1=self.c_strides1, + c_strides2=self.a_strides1_c_strides2, + s_strides1=self.s_strides1, + s_strides2=self.s_strides2, + quant_config=self.moe_quant_config, + group_size=self.group_size, + ) + + num_dispatchers = prepare_finalize.num_dispatchers() + self.disable_expert_map = ( + num_dispatchers > 1 or not experts.supports_expert_map() + ) + + return experts + + def apply( + self, + layer: FusedMoE, + x: torch.Tensor, + router_logits: torch.Tensor, + ): + if layer.enable_eplb: + raise NotImplementedError( + "EPLB not supported for `CompressedTensorsW4A8Fp8MoEMethod` yet." + ) + assert self.moe_quant_config is not None + topk_weights, topk_ids, _ = layer.select_experts( + hidden_states=x, + router_logits=router_logits, + ) + + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_w4a8_fp8, + ) + + return cutlass_moe_w4a8_fp8( + x, + layer.w13_weight_packed, + layer.w2_weight_packed, + topk_weights, + topk_ids, + quant_config=self.moe_quant_config, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=None if self.disable_expert_map else layer.expert_map, + a_strides1=self.a_strides1_c_strides2, + a_strides2=self.a_strides2, + b_strides1=self.b_strides1, + b_strides2=self.b_strides2, + c_strides1=self.c_strides1, + c_strides2=self.a_strides1_c_strides2, + s_strides1=self.s_strides1, + s_strides2=self.s_strides2, + group_size=self.group_size, + ) + + @property + def supports_eplb(self) -> bool: + return False diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py index 3afadc6eb7e5b..d2701a464f129 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py @@ -28,7 +28,7 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme): @classmethod def get_min_capability(cls) -> int: - # dont restrict as emulations + # don't restrict as emulations return 80 def create_weights( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py index a23961e897534..9a25e08cbad75 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py @@ -128,14 +128,15 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme): ), ) - # TODO(czhu): allocate the packed fp8 scales memory here? - # the scales will be expanded by 8x via `cutlass_pack_scale_fp8` + # After loading, we will transform bf16 -> fp8 -> + # expand by 8x via `cutlass_pack_scale_fp8` + # and construct per-channel fp32 scales. weight_scale_args = { "weight_loader": weight_loader, "data": torch.empty( output_size_per_partition, scales_and_zp_size, - dtype=torch.float8_e4m3fn, + dtype=params_dtype, ), } @@ -152,17 +153,9 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme): data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader ) - # per-channel scales - weight_chan_scale = ChannelQuantScaleParameter( - data=torch.empty((output_size_per_partition, 1), dtype=torch.float32), - output_dim=0, - weight_loader=weight_loader, - ) - layer.register_parameter("weight_packed", weight) layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_shape", weight_shape) - layer.register_parameter("weight_chan_scale", weight_chan_scale) self.kernel = kernel_type( mp_linear_kernel_config, diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 7ebe40ec84687..11097cf36f5ca 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable from typing import Any, Optional import torch @@ -140,23 +139,6 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts @@ -172,10 +154,10 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase): topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, + activation=layer.activation, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, quant_config=self.moe_quant_config, ) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 48223c9f103ea..f2b66a2beb6d7 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable from enum import Enum from functools import partial from typing import TYPE_CHECKING, Any, Optional @@ -95,7 +94,7 @@ from vllm.model_executor.parameter import ( ModelWeightParameter, PerTensorScaleParameter, ) -from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.utils import replace_parameter, set_weight_attrs from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.utils.deep_gemm import ( @@ -124,17 +123,21 @@ class Fp8MoeBackend(Enum): def get_fp8_moe_backend( - block_quant: bool, moe_parallel_config: FusedMoEParallelConfig + block_quant: bool, + moe_parallel_config: FusedMoEParallelConfig, + with_lora_support: bool, ) -> Fp8MoeBackend: """ Select the primary FP8 MoE backend Note: Shape-specific fallbacks may still occur at runtime. """ + if with_lora_support: + return Fp8MoeBackend.TRITON # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100. if ( current_platform.is_cuda() and ( - current_platform.is_device_capability(100) + current_platform.is_device_capability_family(100) or current_platform.is_device_capability(90) ) and envs.VLLM_USE_FLASHINFER_MOE_FP8 @@ -145,7 +148,7 @@ def get_fp8_moe_backend( logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100") return Fp8MoeBackend.FLASHINFER_TRTLLM else: - if block_quant and current_platform.is_device_capability(100): + if block_quant and current_platform.is_device_capability_family(100): raise ValueError( "FlashInfer FP8 MoE throughput backend does not " "support block quantization. Please use " @@ -190,7 +193,7 @@ def get_fp8_moe_backend( # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights if ( current_platform.is_cuda() - and current_platform.is_device_capability(100) + and current_platform.is_device_capability_family(100) and block_quant ): logger.info_once( @@ -329,7 +332,10 @@ class Fp8Config(QuantizationConfig): fused_mapping=self.packed_modules_mapping, ): return UnquantizedFusedMoEMethod(layer.moe_config) - moe_quant_method = Fp8MoEMethod(self, layer) + if self.is_checkpoint_fp8_serialized: + moe_quant_method = Fp8MoEMethod(self, layer) + else: + moe_quant_method = Fp8OnlineMoEMethod(self, layer) moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix) return moe_quant_method elif isinstance(layer, Attention): @@ -461,6 +467,30 @@ class Fp8LinearMethod(LinearMethodBase): output_size_per_partition, input_size_per_partition, weight_loader ) else: + + def patched_weight_loader(param, loaded_weight, *args, **kwargs): + # load the current weight chunk + res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc] + + # track how many elements we have updated + if not hasattr(layer, "_loaded_numel"): + layer._loaded_numel = 0 + layer._loaded_numel += loaded_weight.numel() + + # if we have loaded all of the elements, call + # process_weights_after_loading + target_loaded_numel = layer.weight.numel() + if layer._loaded_numel == target_loaded_numel: + self.process_weights_after_loading(layer) + + # Delete the bookkeeping + del layer._loaded_numel + # Prevent the usual `process_weights_after_loading` call from doing + # anything + layer._already_called_process_weights_after_loading = True + + return res + # For non-serialized checkpoints, use original dtype weight = ModelWeightParameter( data=torch.empty( @@ -470,7 +500,7 @@ class Fp8LinearMethod(LinearMethodBase): ), input_dim=1, output_dim=0, - weight_loader=weight_loader, + weight_loader=patched_weight_loader, ) layer.register_parameter("weight", weight) @@ -511,6 +541,9 @@ class Fp8LinearMethod(LinearMethodBase): layer.register_parameter("input_scale", None) def process_weights_after_loading(self, layer: Module) -> None: + if getattr(layer, "_already_called_process_weights_after_loading", False): + return + size_k_first = True input_scale = None # TODO(rob): refactor block quant into separate class. @@ -518,46 +551,50 @@ class Fp8LinearMethod(LinearMethodBase): assert not self.act_q_static size_k_first = False - weight, weight_scale = process_fp8_weight_block_strategy( + weight, weight_scale_inv = process_fp8_weight_block_strategy( layer.weight, layer.weight_scale_inv ) - # Delete the weight_scale_inv parameter to avoid confusion - # with the weight_scale parameter - del layer.weight_scale_inv + + # Update layer with new values + replace_parameter(layer, "weight", weight.data) + replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data) # If checkpoint not serialized fp8, quantize the weights. - elif not self.quant_config.is_checkpoint_fp8_serialized: - qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None) - weight = qweight.t() - - # If checkpoint is fp8 per-tensor, handle that there are N scales for N - # shards in a fused module else: - weight = layer.weight - weight_scale = layer.weight_scale + if not self.quant_config.is_checkpoint_fp8_serialized: + qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None) + weight = qweight.t() - # If using w8a8, torch._scaled_mm needs per tensor, so - # requantize the logical shards as a single weight. - if not self.use_marlin: - weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy( - weight, - weight_scale, - layer.logical_widths, - getattr(layer, "input_scale", None), - ) - if self.act_q_static: - assert input_scale is not None - input_scale = input_scale.max() - weight = weight.t() + # If checkpoint is fp8 per-tensor, handle that there are N scales for N + # shards in a fused module + else: + weight = layer.weight + weight_scale = layer.weight_scale - # Update layer with new values. - layer.weight = Parameter(weight.data, requires_grad=False) - layer.weight_scale = Parameter(weight_scale.data, requires_grad=False) - layer.input_scale = ( - Parameter(input_scale, requires_grad=False) - if input_scale is not None - else None - ) + # If using w8a8, torch._scaled_mm needs per tensor, so + # requantize the logical shards as a single weight. + if not self.use_marlin: + weight, weight_scale, input_scale = ( + process_fp8_weight_tensor_strategy( + weight, + weight_scale, + layer.logical_widths, + getattr(layer, "input_scale", None), + ) + ) + if self.act_q_static: + assert input_scale is not None + input_scale = input_scale.max() + weight = weight.t() + + # Update layer with new values. + replace_parameter(layer, "weight", weight.data) + replace_parameter(layer, "weight_scale", weight_scale.data) + + if input_scale is not None: + replace_parameter(layer, "input_scale", input_scale) + else: + layer.input_scale = None if self.use_marlin: prepare_fp8_layer_for_marlin( @@ -584,7 +621,7 @@ class Fp8LinearMethod(LinearMethodBase): return self.w8a8_block_fp8_linear.apply( input=x, weight=layer.weight, - weight_scale=layer.weight_scale, + weight_scale=layer.weight_scale_inv, input_scale=layer.input_scale, bias=bias, ) @@ -613,10 +650,15 @@ class Fp8LinearMethod(LinearMethodBase): return torch.nn.functional.linear(x, weight_bf16.t(), bias) if self.use_marlin: + if self.block_quant: + weight_scale = layer.weight_scale_inv + else: + weight_scale = layer.weight_scale + return apply_fp8_marlin_linear( input=x, weight=layer.weight, - weight_scale=layer.weight_scale, + weight_scale=weight_scale, workspace=layer.workspace, size_n=layer.output_size_per_partition, size_k=layer.input_size_per_partition, @@ -630,7 +672,7 @@ class Fp8LinearMethod(LinearMethodBase): return self.w8a8_block_fp8_linear.apply( input=x, weight=layer.weight, - weight_scale=layer.weight_scale, + weight_scale=layer.weight_scale_inv, input_scale=layer.input_scale, bias=bias, ) @@ -665,7 +707,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): self.weight_block_size = self.quant_config.weight_block_size self.block_quant: bool = self.weight_block_size is not None self.fp8_backend = get_fp8_moe_backend( - self.block_quant, layer.moe_parallel_config + self.block_quant, layer.moe_parallel_config, self.moe.is_lora_enabled ) self.marlin_input_dtype = None @@ -706,8 +748,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.orig_dtype = params_dtype layer.weight_block_size = None - if self.quant_config.is_checkpoint_fp8_serialized: - params_dtype = torch.float8_e4m3fn + assert self.quant_config.is_checkpoint_fp8_serialized + params_dtype = torch.float8_e4m3fn + if self.block_quant: assert self.weight_block_size is not None layer.weight_block_size = self.weight_block_size @@ -801,21 +844,11 @@ class Fp8MoEMethod(FusedMoEMethodBase): if self.block_quant else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} ) - # If loading fp8 checkpoint, pass the weight loaders. - # If loading an fp16 checkpoint, do not (we will quantize in - # process_weights_after_loading() - if self.quant_config.is_checkpoint_fp8_serialized: - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) # INPUT_SCALES if self.quant_config.activation_scheme == "static": - if not self.quant_config.is_checkpoint_fp8_serialized: - raise ValueError( - "Found static activation scheme for checkpoint that " - "was not serialized fp8." - ) - w13_input_scale = torch.nn.Parameter( torch.ones(num_experts, dtype=torch.float32), requires_grad=False ) @@ -835,6 +868,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): self.rocm_aiter_moe_enabled = False def process_weights_after_loading(self, layer: Module) -> None: + if getattr(layer, "_already_called_process_weights_after_loading", False): + return + # Lazy import to avoid importing triton too early. self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() @@ -869,22 +905,18 @@ class Fp8MoEMethod(FusedMoEMethodBase): w2_weight_scale_inv = layer.w2_weight_scale_inv # torch.compile() cannot use Parameter subclasses. - layer.w13_weight = Parameter(w13_weight, requires_grad=False) - layer.w13_weight_scale_inv = Parameter( - w13_weight_scale_inv, requires_grad=False - ) - layer.w2_weight = Parameter(w2_weight, requires_grad=False) - layer.w2_weight_scale_inv = Parameter( - w2_weight_scale_inv, requires_grad=False - ) + replace_parameter(layer, "w13_weight", w13_weight) + replace_parameter(layer, "w13_weight_scale_inv", w13_weight_scale_inv) + replace_parameter(layer, "w2_weight", w2_weight) + replace_parameter(layer, "w2_weight_scale_inv", w2_weight_scale_inv) if self.rocm_aiter_moe_enabled: # reshaping weights is required for aiter moe kernel. shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight.data, layer.w2_weight.data ) - layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) - layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) + replace_parameter(layer, "w13_weight", shuffled_w13) + replace_parameter(layer, "w2_weight", shuffled_w2) # DeepGemm scales need to be transposed and aligned. We try to do # it ahead of time for performance reasons. @@ -913,43 +945,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.w2_weight_scale_inv = Parameter( dg_w2_weight_scale_inv, requires_grad=False ) - - # If checkpoint is fp16, quantize in place. - elif not self.quant_config.is_checkpoint_fp8_serialized: - fp8_dtype = current_platform.fp8_dtype() - w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) - w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) - - # Re-initialize w13_scale because we directly quantize - # merged w13 weights and generate a single scaling factor. - layer.w13_weight_scale = torch.nn.Parameter( - torch.ones( - layer.local_num_experts, - dtype=torch.float32, - device=w13_weight.device, - ), - requires_grad=False, - ) - for expert in range(layer.local_num_experts): - w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( - ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) - ) - w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( - ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) - ) - layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) - layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) - if self.rocm_aiter_moe_enabled: - # reshaping weights is required for aiter moe kernel. - shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( - layer.w13_weight, layer.w2_weight - ) - - layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) - layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) - # If checkpoint is fp8, we need to handle that the - # MoE kernels require single activation scale and single weight - # scale for w13 per expert. else: # Fp8 moe kernels require a single activation scale. # We take the max of all the scales in case they differ. @@ -967,12 +962,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): "fp8 MoE layer. Using the maximum across experts " "for each layer." ) - layer.w13_input_scale = torch.nn.Parameter( - layer.w13_input_scale.max(), requires_grad=False - ) - layer.w2_input_scale = torch.nn.Parameter( - layer.w2_input_scale.max(), requires_grad=False - ) + replace_parameter(layer, "w13_input_scale", layer.w13_input_scale.max()) + replace_parameter(layer, "w2_input_scale", layer.w2_input_scale.max()) if current_platform.is_fp8_fnuz(): # Normalize the weights and scales w13_weight, w13_weight_scale, w13_input_scale = ( @@ -986,22 +977,14 @@ class Fp8MoEMethod(FusedMoEMethodBase): ) ) # Reset the parameter - layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) - layer.w13_weight_scale = torch.nn.Parameter( - w13_weight_scale, requires_grad=False - ) + replace_parameter(layer, "w13_weight", w13_weight) + replace_parameter(layer, "w13_weight_scale", w13_weight_scale) if w13_input_scale is not None: - layer.w13_input_scale = torch.nn.Parameter( - w13_input_scale, requires_grad=False - ) - layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) - layer.w2_weight_scale = torch.nn.Parameter( - w2_weight_scale, requires_grad=False - ) + replace_parameter(layer, "w13_input_scale", w13_input_scale) + replace_parameter(layer, "w2_weight", w2_weight) + replace_parameter(layer, "w2_weight_scale", w2_weight_scale) if w2_input_scale is not None: - layer.w2_input_scale = torch.nn.Parameter( - w2_input_scale, requires_grad=False - ) + replace_parameter(layer, "w2_input_scale", w2_input_scale) # Fp8 moe kernel needs single weight scale for w13 per expert. # We take the max then dequant and requant each expert. @@ -1025,12 +1008,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.w13_weight, layer.w2_weight ) - layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) - layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) + replace_parameter(layer, "w13_weight", shuffled_w13) + replace_parameter(layer, "w2_weight", shuffled_w2) - layer.w13_weight_scale = torch.nn.Parameter( - max_w13_scales, requires_grad=False - ) + replace_parameter(layer, "w13_weight_scale", max_w13_scales) if self.flashinfer_moe_backend is not None: # NOTE: weights have to be swapped since the activation is @@ -1084,6 +1065,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): from vllm.model_executor.layers.fused_moe import ( BatchedDeepGemmExperts, BatchedTritonExperts, + TritonExperts, TritonOrDeepGemmExperts, ) @@ -1116,7 +1098,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): num_dispatchers=prepare_finalize.num_dispatchers(), quant_config=self.moe_quant_config, ) - + elif self.moe.is_lora_enabled: + return TritonExperts(quant_config=self.moe_quant_config) elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: # Select GEMM experts with block-scale when weights are block-quantized experts = select_cutlass_fp8_gemm_impl( @@ -1171,41 +1154,20 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - if enable_eplb: - assert expert_load_view is not None - assert logical_to_physical_map is not None - assert logical_replica_count is not None - assert isinstance(layer, FusedMoE) - if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: - assert activation == "silu", ( - f"Expected 'silu' activation but got {activation}" + if layer.enable_eplb: + raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.") + assert layer.activation == "silu", ( + f"Expected 'silu' activation but got {layer.activation}" ) if self.block_quant: import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 e_score_correction_bias = ( - e_score_correction_bias.to(x.dtype) - if e_score_correction_bias is not None + layer.e_score_correction_bias.to(x.dtype) + if layer.e_score_correction_bias is not None else None ) routing_method_type = layer.routing_method_type @@ -1219,29 +1181,31 @@ class Fp8MoEMethod(FusedMoEMethodBase): w13_weight_scale_inv=layer.w13_weight_scale_inv, w2_weight=layer.w2_weight, w2_weight_scale_inv=layer.w2_weight_scale_inv, - global_num_experts=global_num_experts, - top_k=top_k, - num_expert_group=num_expert_group, - topk_group=topk_group, + global_num_experts=layer.global_num_experts, + top_k=layer.top_k, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, intermediate_size=layer.intermediate_size_per_partition, expert_offset=layer.ep_rank * layer.local_num_experts, local_num_experts=layer.local_num_experts, block_shape=self.weight_block_size, routing_method_type=routing_method_type, - routed_scaling=routed_scaling_factor, + routed_scaling=layer.routed_scaling_factor, ) else: - assert not renormalize and custom_routing_function is not None + assert ( + not layer.renormalize and layer.custom_routing_function is not None + ) result = apply_flashinfer_per_tensor_scale_fp8( layer=layer, hidden_states=x, router_logits=router_logits, - routing_bias=e_score_correction_bias, - global_num_experts=global_num_experts, - top_k=top_k, - num_expert_group=num_expert_group, - topk_group=topk_group, - apply_router_weight_on_input=apply_router_weight_on_input, + routing_bias=layer.e_score_correction_bias, + global_num_experts=layer.global_num_experts, + top_k=layer.top_k, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + apply_router_weight_on_input=layer.apply_router_weight_on_input, ) select_result = layer.select_experts( @@ -1262,13 +1226,15 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - expert_map=expert_map, + activation=layer.activation, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + expert_map=layer.expert_map, quant_config=self.moe_quant_config, ) elif self.use_marlin: - assert activation == "silu", f"{activation} not supported for Marlin MoE." + assert layer.activation == "silu", ( + f"{layer.activation} not supported for Marlin MoE." + ) result = fused_marlin_moe( x, layer.w13_weight, @@ -1281,20 +1247,22 @@ class Fp8MoEMethod(FusedMoEMethodBase): topk_weights, topk_ids, quant_type_id=scalar_types.float8_e4m3fn.id, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, input_dtype=self.marlin_input_dtype, workspace=layer.workspace, ) elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: - assert activation == "silu", ( - f"Expected 'silu' activation but got {activation}" + assert layer.activation == "silu", ( + f"Expected 'silu' activation but got {layer.activation}" ) if not self.block_quant: - assert not renormalize and custom_routing_function is not None - assert scoring_func == "sigmoid", ( - f"Expected 'sigmoid' scoring func but got {scoring_func}" + assert ( + not layer.renormalize and layer.custom_routing_function is not None + ) + assert layer.scoring_func == "sigmoid", ( + f"Expected 'sigmoid' scoring func but got {layer.scoring_func}" ) # Delegate to CUTLASS FlashInfer path; function already bound with # use_deepseek_fp8_block_scale for block-quant when applicable @@ -1304,10 +1272,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): topk_weights, topk_ids, inplace=False, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, ) else: from vllm.model_executor.layers.fused_moe import fused_experts @@ -1319,10 +1287,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, - activation=activation, - global_num_experts=global_num_experts, - apply_router_weight_on_input=apply_router_weight_on_input, - expert_map=expert_map, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + expert_map=layer.expert_map, quant_config=self.moe_quant_config, allow_deep_gemm=self.allow_deep_gemm, allow_cutlass_block_scaled_grouped_gemm=( @@ -1339,6 +1307,151 @@ class Fp8MoEMethod(FusedMoEMethodBase): return result +class Fp8OnlineMoEMethod(Fp8MoEMethod): + """MoE method for online FP8 quantization. + Supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): + super().__init__(quant_config, layer) + assert not quant_config.is_checkpoint_fp8_serialized + assert quant_config.activation_scheme == "dynamic" + assert quant_config.weight_block_size is None + assert self.flashinfer_moe_backend is None + + def create_weights( + self, + layer: Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + layer.intermediate_size_per_partition = intermediate_size_per_partition + layer.hidden_size = hidden_size + layer.num_experts = num_experts + layer.orig_dtype = params_dtype + layer.weight_block_size = None + + # We are doing online quantization, patch the weight loaded + # to call `process_weights_after_loading` in a streaming fashion + # as soon as the last weight chunk is loaded. + weight_loader = extra_weight_attrs["weight_loader"] + # create a new holder to prevent modifying behavior of any other + # objects which might depend on the old one + new_extra_weight_attrs = extra_weight_attrs + + def patched_weight_loader(param, loaded_weight, *args, **kwargs): + # load the current weight chunk + res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc] + + # add a counter to track how many elements we have updated + if not hasattr(layer, "_loaded_numel"): + layer._loaded_numel = 0 + layer._loaded_numel += loaded_weight.numel() + + # if we have loaded all of the elements, call + # process_weights_after_loading + target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel() + if layer._loaded_numel == target_loaded_numel: + self.process_weights_after_loading(layer) + + # Delete the bookkeeping + del layer._loaded_numel + # Prevent the usual `process_weights_after_loading` call + # from doing anything + layer._already_called_process_weights_after_loading = True + + return res + + new_extra_weight_attrs["weight_loader"] = patched_weight_loader + extra_weight_attrs = new_extra_weight_attrs + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + layer.w13_input_scale = None + layer.w2_input_scale = None + + self.rocm_aiter_moe_enabled = False + + def process_weights_after_loading(self, layer: Module) -> None: + if getattr(layer, "_already_called_process_weights_after_loading", False): + return + + # Lazy import to avoid importing triton too early. + self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() + + # If checkpoint is fp16, quantize in place. + fp8_dtype = current_platform.fp8_dtype() + w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) + w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) + + for expert in range(layer.local_num_experts): + w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( + ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( + ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) + ) + replace_parameter(layer, "w13_weight", w13_weight) + replace_parameter(layer, "w2_weight", w2_weight) + + # Reshuffle weights for AITER if needed. + if self.rocm_aiter_moe_enabled: + shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( + layer.w13_weight, layer.w2_weight + ) + replace_parameter(layer, "w13_weight", shuffled_w13) + replace_parameter(layer, "w2_weight", shuffled_w2) + + # Rushuffle weights for MARLIN if needed. + if self.use_marlin: + prepare_moe_fp8_layer_for_marlin( + layer, False, input_dtype=self.marlin_input_dtype + ) + + class Fp8KVCacheMethod(BaseKVCacheMethod): """ Supports loading kv-cache scaling factors from FP8 checkpoints. diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index bcdfafb50fc5a..9dd734f2fea6a 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable, Mapping +from collections.abc import Mapping from types import MappingProxyType from typing import Any, Optional @@ -33,6 +33,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ) from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op logger = init_logger(__name__) @@ -52,6 +53,11 @@ class GGUFConfig(QuantizationConfig): return "gguf" def get_supported_act_dtypes(self) -> list[torch.dtype]: + # GGUF dequantization kernels use half precision (fp16) internally. + # bfloat16 has precision issues on Blackwell devices. + if current_platform.has_device_capability(100): + logger.warning_once("GGUF has precision issues with bfloat16 on Blackwell.") + return [torch.half, torch.float32] return [torch.half, torch.bfloat16, torch.float32] @classmethod @@ -82,6 +88,7 @@ class GGUFConfig(QuantizationConfig): return UnquantizedEmbeddingMethod() return GGUFEmbeddingMethod(self) elif isinstance(layer, FusedMoE): + # TODO: Select UnquantizedFusedMoEMethod on unquantized layers. return GGUFMoEMethod(self, layer.moe_config) return None @@ -624,26 +631,9 @@ class GGUFMoEMethod(FusedMoEMethodBase): layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert activation == "silu", "Only SiLU activation is supported." - if apply_router_weight_on_input: + assert layer.activation == "silu", "Only SiLU activation is supported." + if layer.apply_router_weight_on_input: raise NotImplementedError( "Apply router weight on input is not supported for" "fused GGUF MoE method." @@ -661,7 +651,7 @@ class GGUFMoEMethod(FusedMoEMethodBase): topk_ids, layer.w13_qweight_type.weight_type, layer.w2_qweight_type.weight_type, - activation, + layer.activation, ) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 56034e11329dc..6e5dcfe59b2f9 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable from copy import deepcopy from typing import Any, Optional @@ -733,6 +732,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): is_a_8bit=is_a_8bit, ) replace_parameter(layer, "w2_qweight", marlin_w2_qweight) + + # The modular kernel expects w13_weight and w2_weight, + # but GPTQ uses w13_qweight and w2_qweight + # Alias for modular kernel + layer.w13_weight = layer.w13_qweight + # Alias for modular kernel + layer.w2_weight = layer.w2_qweight + # Repack scales marlin_w13_scales = marlin_moe_permute_scales( s=layer.w13_scales, @@ -783,32 +790,115 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: - return None + from vllm.model_executor.layers.fused_moe.config import ( + gptq_marlin_moe_quant_config, + ) + + return gptq_marlin_moe_quant_config( + w1_scale=layer.w13_scales, + w2_scale=layer.w2_scales, + weight_bits=self.quant_config.weight_bits, + group_size=self.quant_config.group_size, + w1_zp=getattr(layer, "w13_qzeros", None) + if not self.quant_config.is_sym + else None, + w2_zp=getattr(layer, "w2_qzeros", None) + if not self.quant_config.is_sym + else None, + w1_bias=getattr(layer, "w13_bias", None), + w2_bias=getattr(layer, "w2_bias", None), + ) + + def select_gemm_impl( + self, + prepare_finalize, + layer: torch.nn.Module, + ): + """ + Select the GEMM implementation for GPTQ-Marlin MoE. + + Returns MarlinExperts configured for GPTQ quantization. + This is ONLY used when LoRA is enabled. + Without LoRA, GPTQ uses its own apply() method. + """ + # Only use modular kernels when LoRA is enabled + # Without LoRA, GPTQ's own apply() method works fine and is more efficient + if not self.moe.is_lora_enabled: + raise NotImplementedError( + "GPTQ-Marlin uses its own apply() method when LoRA is not enabled. " + "Modular kernels are only used for LoRA support." + ) + + # The modular marlin kernels do not support 8-bit weights. + if self.quant_config.weight_bits == 8: + raise NotImplementedError( + "GPTQ-Marlin kernel does not support 8-bit weights." + ) + + from vllm.model_executor.layers.fused_moe import modular_kernel as mk + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + BatchedMarlinExperts, + MarlinExperts, + ) + + # Ensure quant config is initialized + assert self.moe_quant_config is not None, ( + "moe_quant_config must be initialized before select_gemm_impl" + ) + + w13_g_idx = ( + getattr(layer, "w13_g_idx", None) if self.quant_config.desc_act else None + ) + w2_g_idx = ( + getattr(layer, "w2_g_idx", None) if self.quant_config.desc_act else None + ) + w13_g_idx_sort_indices = ( + getattr(layer, "w13_g_idx_sort_indices", None) + if self.quant_config.desc_act + else None + ) + w2_g_idx_sort_indices = ( + getattr(layer, "w2_g_idx_sort_indices", None) + if self.quant_config.desc_act + else None + ) + + # Check if using batched expert format (for Expert Parallelism) + if ( + prepare_finalize.activation_format + == mk.FusedMoEActivationFormat.BatchedExperts + ): + # For batched format, use BatchedMarlinExperts + max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() + assert max_num_tokens_per_rank is not None + return BatchedMarlinExperts( + max_num_tokens=max_num_tokens_per_rank, + num_dispatchers=prepare_finalize.num_dispatchers(), + quant_config=self.moe_quant_config, + w13_g_idx=w13_g_idx, + w2_g_idx=w2_g_idx, + w13_g_idx_sort_indices=w13_g_idx_sort_indices, + w2_g_idx_sort_indices=w2_g_idx_sort_indices, + is_k_full=self.is_k_full, + ) + else: + # Standard Marlin experts for GPTQ + return MarlinExperts( + quant_config=self.moe_quant_config, + w13_g_idx=w13_g_idx, + w2_g_idx=w2_g_idx, + w13_g_idx_sort_indices=w13_g_idx_sort_indices, + w2_g_idx_sort_indices=w2_g_idx_sort_indices, + is_k_full=self.is_k_full, + ) def apply( self, layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert activation == "silu", "Only SiLU activation is supported." + assert layer.activation == "silu", "Only SiLU activation is supported." topk_weights, topk_ids, _ = layer.select_experts( hidden_states=x, @@ -829,9 +919,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): input_global_scale1=getattr(layer, "w13_input_global_scale", None), input_global_scale2=getattr(layer, "w2_input_global_scale", None), quant_type_id=self.quant_type.id, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, g_idx1=layer.w13_g_idx, g_idx2=layer.w2_g_idx, sort_indices1=layer.w13_g_idx_sort_indices, diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index 7ded8eea79060..a5db086fb4729 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -5,6 +5,7 @@ import torch import torch.nn.functional as F from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform @@ -45,10 +46,13 @@ class QuantFP8(CustomOp): super().__init__() self.static = static self.group_shape = group_shape + self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN self.num_token_padding = num_token_padding self.column_major_scales = column_major_scales self.use_ue8m0 = use_ue8m0 + self.use_aiter = rocm_aiter_ops.is_linear_fp8_enaled() + self.is_group_quant = group_shape.is_per_group() if self.is_group_quant: assert not static, "Group quantization only supports dynamic mode" @@ -92,6 +96,33 @@ class QuantFP8(CustomOp): use_per_token_if_dynamic=self.use_per_token_if_dynamic, ) + def forward_hip( + self, + x: torch.Tensor, + scale: torch.Tensor | None = None, + scale_ub: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + use_aiter_quant = ( + not self.is_group_quant + and self.use_aiter + and scale_ub is None + and x.is_contiguous() + ) + use_aiter_per_tensor_quant = ( + use_aiter_quant and self.group_shape == GroupShape.PER_TENSOR + ) + use_aiter_per_token_quant = ( + use_aiter_quant and self.group_shape == GroupShape.PER_TOKEN + ) + + if use_aiter_per_tensor_quant: + return rocm_aiter_ops.per_tensor_quant(x, _FP8_DTYPE, scale) + if use_aiter_per_token_quant: + return rocm_aiter_ops.per_token_quant(x, _FP8_DTYPE, scale) + + # Fallback to CUDA implementation + return self.forward_cuda(x, scale, scale_ub) + def forward_native( self, x: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py index a1571afba2974..463c74c1c1482 100644 --- a/vllm/model_executor/layers/quantization/ipex_quant.py +++ b/vllm/model_executor/layers/quantization/ipex_quant.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable from typing import Any, Optional import torch @@ -440,31 +439,14 @@ class XPUFp8MoEMethod(FusedMoEMethodBase): layer: torch.nn.Module, x: torch.Tensor, router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor: return layer.ipex_fusion( x, - use_grouped_topk, - top_k, + layer.use_grouped_topk, + layer.top_k, router_logits, - renormalize, - topk_group, - num_expert_group, - custom_routing_function=custom_routing_function, + layer.renormalize, + layer.topk_group, + layer.num_expert_group, + custom_routing_function=layer.custom_routing_function, ) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index 0cf3f12af5522..c4160157cd628 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -30,6 +30,9 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKer MPLinearKernel, MPLinearLayerConfig, ) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.xpu import ( # noqa: E501 + XPUwNa16LinearKernel, +) from vllm.platforms import current_platform # in priority/performance order (when available) @@ -42,6 +45,7 @@ _POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [ BitBLASLinearKernel, ConchLinearKernel, ExllamaLinearKernel, + XPUwNa16LinearKernel, ] diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py index 8ef6457c952f1..c9c1a3abf7fd3 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py @@ -6,7 +6,11 @@ import torch from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + convert_bf16_scales_to_fp8, + convert_packed_uint4b8_to_signed_int4_inplace, +) from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_ from vllm.platforms import current_platform from vllm.scalar_type import scalar_types @@ -48,7 +52,6 @@ class CutlassW4A8LinearKernel(MPLinearKernel): "CUTLASS W4A8, only supported int4", ) - # TODO(czhu): support -1 (column-wise) if c.group_size != 128: return False, "Only group_size 128 is supported" @@ -71,9 +74,9 @@ class CutlassW4A8LinearKernel(MPLinearKernel): # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} # `weight_scale` is: {input_dim = 0, output_dim = 1} def process_weights_after_loading(self, layer: torch.nn.Module): - # TODO(czhu): optimize speed/mem usage def transform_w_q(x): assert isinstance(x, BasevLLMParameter) + convert_packed_uint4b8_to_signed_int4_inplace(x.data) permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) x.data = ops.cutlass_encode_and_reorder_int4b(x.data.t().contiguous().t()) return x @@ -85,10 +88,18 @@ class CutlassW4A8LinearKernel(MPLinearKernel): x.data = ops.cutlass_pack_scale_fp8(x.data) return x + w_s = getattr(layer, self.w_s_name) + fp8_scales, chan_scales = convert_bf16_scales_to_fp8(self.quant_fp8, w_s.data) + w_s.data = fp8_scales + + # register per-channel scales + layer.register_parameter( + "weight_chan_scale", torch.nn.Parameter(chan_scales, requires_grad=False) + ) + # Encode/reorder weights and pack scales self._transform_param(layer, self.w_q_name, transform_w_q) self._transform_param(layer, self.w_s_name, transform_w_s) - self._transform_param(layer, "weight_chan_scale", lambda x: x) def apply_weights( self, diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/xpu.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/xpu.py new file mode 100644 index 0000000000000..abd2e047aeed0 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/xpu.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import torch + +from vllm.platforms import current_platform + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class XPUwNa16LinearKernel(MPLinearKernel): + @classmethod + def get_min_capability(cls) -> int: + return 0 + + @classmethod + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]: + if not current_platform.is_xpu(): + return False, "IPEX wNa16 only supported on XPU/CPU devices" + + # TODO: (yiliu30) relax these restrictions in later PRs + if c.zero_points: + return False, "Zero points not supported for Now" + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + from packaging import version + + MIN_IPEX_VERSION = "2.6.0" + bias = layer.bias if not layer.skip_bias_add else None + + try: + import intel_extension_for_pytorch as ipex + + if version.parse(ipex.__version__) < version.parse(MIN_IPEX_VERSION): + raise ImportError( + "intel_extension_for_pytorch version is " + "wrong. Please install " + f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}." + ) + except ImportError as err: + raise ImportError( + "Please install " + f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via " + f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`" + " to use IPEX-AWQ linear method." + ) from err + # Using the compute dtype (lowp_mode) as INT8 to leverage instructions + # with better performance. + lowp_mode = ipex.quantization.WoqLowpMode.INT8 + # The weight will be de-packed from INT4 to INT8. + weight_dtype = ipex.quantization.WoqWeightDtype.INT4 + # The float activation will be quantized (dynamic, per-token) to INT8. + act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH + + qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping( + weight_dtype=weight_dtype, + lowp_mode=lowp_mode, + act_quant_mode=act_quant_mode, + group_size=self.config.group_size, + weight_qscheme=ipex.quantization.WoqWeightQScheme.SYMMETRIC, + ) + qweight = layer.weight_packed + g_idx = layer.weight_g_idx if self.config.has_g_idx else None + scales = layer.weight_scale + qzeros = None + if self.config.zero_points: + qzeros = layer.weight_zero_point.contiguous() + qweight = qweight.t().contiguous() + scales = scales.t().contiguous() + layer.ipex_output_size = self.config.partition_weight_shape[1] + layer.ipex_qlinear = ( + ipex.llm.quantization.woq_linear.IPEXWeightOnlyQuantizedLinear.from_weight( + qweight, + scales, + qzeros, + in_features=self.config.partition_weight_shape[0], + out_features=self.config.partition_weight_shape[1], + qconfig=qconfig, + g_idx=g_idx, + bias=bias, + group_size=self.config.group_size, + quant_method=0, # `0` stands for the IPEX GPTQ + ) + ) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + reshaped_x = x.reshape(-1, x.shape[-1]) + out = layer.ipex_qlinear(reshaped_x) + return out.reshape(x.shape[:-1] + (layer.ipex_output_size,)) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py index 2a885ec899458..7be220f7a3734 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -17,7 +17,9 @@ class ScaledMMLinearLayerConfig: class ScaledMMLinearKernel(ABC): @classmethod @abstractmethod - def get_min_capability(cls) -> int: + def is_supported( + cls, compute_capability: int | None = None + ) -> tuple[bool, str | None]: raise NotImplementedError @classmethod @@ -35,6 +37,7 @@ class ScaledMMLinearKernel(ABC): azp_adj_param_name: str, ) -> None: assert self.can_implement(c) + assert self.is_supported() self.config = c self.w_q_name = w_q_param_name self.w_s_name = w_s_param_name diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index dd59e5d935dcb..20d050d387d49 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -27,7 +27,7 @@ from vllm.platforms import PlatformEnum, current_platform # in priority/performance order (when available) _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { PlatformEnum.CPU: [CPUScaledMMLinearKernel], - PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], + PlatformEnum.CUDA: [CutlassScaledMMLinearKernel, TritonScaledMMLinearKernel], PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel], PlatformEnum.TPU: [XLAScaledMMLinearKernel], } @@ -55,41 +55,25 @@ def choose_scaled_mm_linear_kernel( type[ScaledMMLinearKernel]: Chosen kernel. """ - if compute_capability is None: - _cc = current_platform.get_device_capability() - if _cc is not None: - compute_capability = _cc[0] * 10 + _cc[1] - failure_reasons = [] for kernel in _POSSIBLE_KERNELS[current_platform._enum]: if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","): - failure_reasons.append( - f" {kernel.__name__} disabled by environment variable" - ) + failure_reasons.append(f"{kernel.__name__}: disabled by env var") continue # If the current platform uses compute_capability, - # make sure the kernel supports the compute cability. - if compute_capability is not None: - kernel_min_capability = kernel.get_min_capability() - if ( - kernel_min_capability is not None - and kernel_min_capability > compute_capability - ): - failure_reasons.append( - f"{kernel.__name__} requires capability " - f"{kernel_min_capability}, current compute capability " - f"is {compute_capability}" - ) - continue + # make sure the kernel supports the compute capability. + is_supported, reason = kernel.is_supported(compute_capability) + if not is_supported: + failure_reasons.append(f"{kernel.__name__}: {reason}") + continue - can_implement, failure_reason = kernel.can_implement(config) - if can_implement: - return kernel - else: - failure_reasons.append( - f" {kernel.__name__} cannot implement due to: {failure_reason}" - ) + can_implement, reason = kernel.can_implement(config) + if not can_implement: + failure_reasons.append(f"{kernel.__name__}: {reason}") + continue + + return kernel raise ValueError( "Failed to find a kernel that can implement the " diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index 038a92c516cec..971bd2005a23b 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -14,17 +14,21 @@ from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): @classmethod - def get_min_capability(cls) -> int: - return 90 - - @classmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + def is_supported( + cls, compute_capability: int | None = None + ) -> tuple[bool, str | None]: if not current_platform.is_rocm(): return ( False, "AiterScaledMMLinearKernel requires `aiter` which is not " + "currently supported on non-ROCm platform.", ) + if compute_capability is None: + _cc = current_platform.get_device_capability() + if _cc is not None: + compute_capability = _cc.major * 10 + _cc.minor + if compute_capability is not None and compute_capability < 90: + return False, f"requires capability 90, got {compute_capability}" try: import aiter # noqa: F401 # deliberately attempt to import aiter @@ -34,8 +38,8 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): "AiterScaledMMLinearKernel requires `aiter` which is not " + "installed on ROCm.", ) - # Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled - if not (rocm_aiter_ops.is_linear_enabled()): + + if not rocm_aiter_ops.is_linear_enabled(): return ( False, "AiterScaledMMLinearKernel is disabled. " @@ -44,6 +48,10 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): + "`VLLM_ROCM_USE_AITER_LINEAR` default is True.", ) + return True, None + + @classmethod + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: if not c.input_symmetric: return ( False, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py index feb1e0bee1aaf..6401b94d6278b 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py @@ -19,14 +19,15 @@ from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfi class CPUScaledMMLinearKernel(ScaledMMLinearKernel): @classmethod - def get_min_capability(cls) -> int: - return 75 + def is_supported( + cls, compute_capability: int | None = None + ) -> tuple[bool, str | None]: + if not current_platform.is_cpu(): + return False, "Requires CPU." + return True, None @classmethod def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: - if not current_platform.is_cpu(): - return False, "CPUScaledMM requires running on CPU." - return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index e8769916b4cef..2f00e0df8ed47 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -16,14 +16,21 @@ from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfi class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): @classmethod - def get_min_capability(cls) -> int: - return 75 + def is_supported( + cls, compute_capability: int | None = None + ) -> tuple[bool, str | None]: + if not current_platform.is_cuda(): + return False, "Requires CUDA." + if compute_capability is None: + _cc = current_platform.get_device_capability() + if _cc is not None: + compute_capability = _cc.major * 10 + _cc.minor + if compute_capability is not None and compute_capability < 75: + return False, f"requires capability 75, got {compute_capability}" + return True, None @classmethod def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: - if not current_platform.is_cuda(): - return False, "CutlassScaledMM requires running on CUDA." - return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py index 3f4ec7f2a738b..760f1f7f79576 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py @@ -4,34 +4,53 @@ import torch +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import ( # noqa: E501 + triton_scaled_mm, +) +from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.platforms import current_platform -from .cutlass import CutlassScaledMMLinearKernel -from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig +from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig -class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel): +class TritonScaledMMLinearKernel(ScaledMMLinearKernel): @classmethod - def get_min_capability(cls) -> int: - return 75 + def is_supported( + cls, compute_capability: int | None = None + ) -> tuple[bool, str | None]: + if current_platform.is_cuda_alike(): + return True, None + return False, "Requires ROCm or CUDA." @classmethod def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: - if current_platform.is_cpu(): - return ( - False, - "TritonScaledMMLinearKernel requires Triton which is not " - + "currently supported on CPU.", - ) if not c.input_symmetric: - return ( - False, - "TritonScaledMMLinearKernel only supports symmetric " + "quantization.", - ) + return False, "Only symmetric input is supported." return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - super().process_weights_after_loading(layer) + weight = getattr(layer, self.w_q_name) + replace_parameter( + layer, + self.w_q_name, + torch.nn.Parameter(weight.t().data, requires_grad=False), + ) + + # INPUT SCALE + if self.config.is_static_input_scheme: + input_scale = getattr(layer, self.i_s_name) + replace_parameter( + layer, + self.i_s_name, + torch.nn.Parameter(input_scale.max(), requires_grad=False), + ) + setattr(layer, self.i_zp_name, None) + else: + setattr(layer, self.i_s_name, None) + setattr(layer, self.i_zp_name, None) + + setattr(layer, self.azp_adj_name, None) def apply_weights( self, @@ -39,4 +58,14 @@ class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return super().apply_weights(layer, x, bias) + w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) + + x_q, x_s, x_zp = ops.scaled_int8_quant( + x.contiguous(), i_s, i_zp, symmetric=True + ) + + assert x_zp is None, "Triton kernel only supports symmetric quantization" + + return triton_scaled_mm( + x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index ddac9f13cf4f3..0be858c51993d 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -17,11 +17,12 @@ from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfi class XLAScaledMMLinearKernel(ScaledMMLinearKernel): @classmethod - def get_min_capability(cls) -> int: - raise NotImplementedError( - "TPU platform does have a concept of compute capability, " - "this method should not be called." - ) + def is_supported( + cls, compute_capability: int | None = None + ) -> tuple[bool, str | None]: + if not current_platform.is_tpu(): + return False, "Requires TPU." + return True, None @classmethod def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index 78456dcf1ca56..f0497a8722909 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -45,6 +45,13 @@ class BaseKVCacheMethod(QuantizeMethodBase): raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.") def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # skip if there are no weights to process (for example, weight reloading) + if not hasattr(layer, "q_scale"): + assert not hasattr(layer, "k_scale") + assert not hasattr(layer, "v_scale") + assert not hasattr(layer, "prob_scale") + return + # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 # regardless whether the kv-scale is available in the checkpoint. # No need to process kv scales after loading if we are going to diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 034e97a713cdd..f71854e6b63c5 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable from fnmatch import fnmatch from typing import TYPE_CHECKING, Any, Optional @@ -39,6 +38,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( build_flashinfer_fp4_cutlass_moe_prepare_finalize, flashinfer_trtllm_fp4_moe, + flashinfer_trtllm_fp4_routed_moe, prepare_static_weights_for_trtllm_fp4_moe, reorder_w1w3_to_w3w1, select_nvfp4_gemm_impl, @@ -81,6 +81,7 @@ from vllm.utils.flashinfer import ( has_flashinfer, has_flashinfer_moe, ) +from vllm.utils.math_utils import round_up if TYPE_CHECKING: from vllm.model_executor.models.utils import WeightsMapper @@ -187,7 +188,24 @@ class ModelOptQuantConfigBase(QuantizationConfig): def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): if len(self.exclude_modules) > 0: - self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules) + # This is a workaround for the weights remapping issue: + # https://github.com/vllm-project/vllm/issues/28072 + # Right now, the Nvidia ModelOpt library use just one wildcard pattern: + # module_path* + # It gets applied if the whole tree of modules rooted at module_path + # is not quantized. Here we replace such pattern by 2 patterns that are + # collectively equivalent to the original pattern: + # module_path + # module_path.* + new_exclude_modules = [] + for exclude in self.exclude_modules: + if len(exclude) >= 2 and exclude[-1] == "*" and exclude[-2] != ".": + new_exclude_modules.append(exclude[:-1]) + new_exclude_modules.append(exclude[:-1] + ".*") + else: + new_exclude_modules.append(exclude) + + self.exclude_modules = hf_to_vllm_mapper.apply_list(new_exclude_modules) @staticmethod def get_config_filenames() -> list[str]: @@ -607,6 +625,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): Only supports pre-quantized checkpoints with FP8 weights and scales. """ + if self.flashinfer_moe_backend is not None: + self._maybe_pad_intermediate_for_flashinfer(layer) + layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False) layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) @@ -684,6 +705,50 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) register_moe_scaling_factors(layer) + def _maybe_pad_intermediate_for_flashinfer(self, layer: torch.nn.Module) -> None: + """Pad intermediate size so FlashInfer kernels' alignment constraints hold. + + Some FlashInfer FP8 MoE kernels require the (gated) intermediate size + used for GEMM to be divisible by a small alignment value. When this is + not satisfied (e.g. with certain tensor-parallel sizes), we pad the + gate/up and down projection weights along the intermediate dim. + """ + if not hasattr(layer, "w13_weight") or not hasattr(layer, "w2_weight"): + return + + # Current local intermediate size (per partition) is the K dimension of + # the down projection. + num_experts, hidden_size, intermediate = layer.w2_weight.shape + + min_alignment = 16 + padded_intermediate = round_up(intermediate, min_alignment) + + if padded_intermediate == intermediate: + return + + logger.info( + "Padding intermediate size from %d to %d for up/down projection weights.", + intermediate, + padded_intermediate, + ) + + up_mult = 2 if self.moe.is_act_and_mul else 1 + padded_gate_up_dim = up_mult * padded_intermediate + + # Pad w13 and w12 along its intermediate dimension. + w13 = layer.w13_weight.data + padded_w13 = w13.new_zeros((num_experts, padded_gate_up_dim, hidden_size)) + padded_w13[:, : w13.shape[1], :] = w13 + layer.w13_weight.data = padded_w13 + + w2 = layer.w2_weight.data + padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate)) + padded_w2[:, :, :intermediate] = w2 + layer.w2_weight.data = padded_w2 + + if hasattr(layer, "intermediate_size_per_partition"): + layer.intermediate_size_per_partition = padded_intermediate + def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: @@ -707,43 +772,27 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: if layer.enable_eplb: raise NotImplementedError( "EPLB not supported for `ModelOptFp8MoEMethod` yet." ) - assert activation == "silu", ( - f"Expected 'silu' activation but got {activation}" + assert layer.activation == "silu", ( + f"Expected 'silu' activation but got {layer.activation}" ) - assert not renormalize + + assert not layer.renormalize return apply_flashinfer_per_tensor_scale_fp8( layer=layer, hidden_states=x, router_logits=router_logits, - routing_bias=e_score_correction_bias, - global_num_experts=global_num_experts, - top_k=top_k, - num_expert_group=num_expert_group, - topk_group=topk_group, - apply_router_weight_on_input=apply_router_weight_on_input, + routing_bias=layer.e_score_correction_bias, + global_num_experts=layer.global_num_experts, + top_k=layer.top_k, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + apply_router_weight_on_input=layer.apply_router_weight_on_input, ) # Expert selection @@ -753,9 +802,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): ) if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: - assert activation in ("silu", "relu2_no_mul"), ( + assert layer.activation in ("silu", "relu2_no_mul"), ( "Expected activation to be in ('silu', 'relu2_no_mul')," - f"but got {activation}" + f"but got {layer.activation}" ) return flashinfer_cutlass_moe_fp8( x, @@ -763,10 +812,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): topk_weights, topk_ids, inplace=False, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, ) else: from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts @@ -780,11 +829,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, - activation=activation, + activation=layer.activation, quant_config=self.moe_quant_config, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, ) @@ -1342,7 +1391,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): "Accuracy may be affected." ) - w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0] + w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous() layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False) # Common processing for input scales and alphas @@ -1409,16 +1458,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ) logger.debug_once("Finished shuffling weights for TRT-LLM MOE") - layer.gemm1_weights_fp4_shuffled = Parameter( + layer.w13_weight = Parameter( gemm1_weights_fp4_shuffled, requires_grad=False ) - layer.gemm2_weights_fp4_shuffled = Parameter( - gemm2_weights_fp4_shuffled, requires_grad=False - ) - layer.gemm1_scales_fp4_shuffled = Parameter( + layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False) + layer.w13_weight_scale = Parameter( gemm1_scales_fp4_shuffled, requires_grad=False ) - layer.gemm2_scales_fp4_shuffled = Parameter( + layer.w2_weight_scale = Parameter( gemm2_scales_fp4_shuffled, requires_grad=False ) @@ -1427,12 +1474,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32), requires_grad=False, ) - - # Clean up weights that won't be used by TRT-LLM - del layer.w2_weight - del layer.w2_weight_scale - del layer.w13_weight - del layer.w13_weight_scale elif self.use_marlin: # Marlin processing prepare_moe_fp4_layer_for_marlin(layer) @@ -1499,28 +1540,15 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): a2_gscale=layer.w2_input_scale_quant, ) + @property + def supports_eplb(self) -> bool: + return True + def apply( self, layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if not self.moe.is_act_and_mul: assert ( @@ -1534,21 +1562,18 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): if ( self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM + and not layer.enable_eplb ): - if enable_eplb: - raise NotImplementedError( - "EPLB not supported for `ModelOptNvFp4FusedMoE` yet." - ) return flashinfer_trtllm_fp4_moe( layer=layer, x=x, router_logits=router_logits, - top_k=top_k, - global_num_experts=global_num_experts, - num_expert_group=num_expert_group, - topk_group=topk_group, - custom_routing_function=custom_routing_function, - e_score_correction_bias=e_score_correction_bias, + top_k=layer.top_k, + global_num_experts=layer.global_num_experts, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + custom_routing_function=layer.custom_routing_function, + e_score_correction_bias=layer.e_score_correction_bias, ) topk_weights, topk_ids, _ = layer.select_experts( @@ -1556,6 +1581,20 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): router_logits=router_logits, ) + # EPLB path + if ( + self.allow_flashinfer + and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM + ): + return flashinfer_trtllm_fp4_routed_moe( + layer=layer, + x=x, + topk_ids=topk_ids, + topk_weights=topk_weights, + top_k=layer.top_k, + global_num_experts=layer.global_num_experts, + ) + if self.use_marlin: return fused_marlin_moe( x, @@ -1571,9 +1610,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): global_scale1=layer.w13_weight_scale_2, global_scale2=layer.w2_weight_scale_2, quant_type_id=scalar_types.float4_e2m1f.id, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, input_dtype=self.marlin_input_dtype, ) @@ -1604,10 +1643,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): topk_ids=topk_ids, quant_config=self.moe_quant_config, inplace=False, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, ) else: # If no modular kernel is provided, use cutlass_moe_fp4 for TP case @@ -1622,8 +1661,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): topk_weights=topk_weights, topk_ids=topk_ids, quant_config=self.moe_quant_config, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, # TODO: derive from arguments m=x.shape[0], n=layer.w2_weight.shape[2] * 2, diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index cf348290a2716..4bedb951a33f5 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable from typing import Any, Optional import torch @@ -18,6 +17,9 @@ from vllm.model_executor.layers.fused_moe.layer import ( FusedMoEMethodBase, FusedMoeWeightScaleSupported, ) +from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import ( + UnquantizedFusedMoEMethod, +) from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( @@ -60,7 +62,7 @@ class MoeWNA16Config(QuantizationConfig): if self.linear_quant_method == "gptq": self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(full_config) - elif self.linear_quant_method == "awq": + elif self.linear_quant_method in ("awq", "awq_marlin"): capability_tuple = current_platform.get_device_capability() device_capability = ( -1 if capability_tuple is None else capability_tuple.to_int() @@ -107,7 +109,7 @@ class MoeWNA16Config(QuantizationConfig): if linear_quant_method == "gptq": has_zp = not cls.get_from_keys(config, ["sym"]) modules_to_not_convert = [] - elif linear_quant_method == "awq": + elif linear_quant_method in ("awq", "awq_marlin"): has_zp = cls.get_from_keys(config, ["zero_point"]) modules_to_not_convert = cls.get_from_keys_or( config, ["modules_to_not_convert"], None @@ -163,6 +165,8 @@ class MoeWNA16Config(QuantizationConfig): self, layer: torch.nn.Module, prefix: str ) -> Optional["QuantizeMethodBase"]: if is_layer_skipped_quant(prefix, self.modules_to_not_convert): + if isinstance(layer, FusedMoE): + return UnquantizedFusedMoEMethod(layer.moe_config) return UnquantizedLinearMethod() elif isinstance(layer, LinearBase): # Avoid circular import @@ -184,7 +188,7 @@ class MoeWNA16Config(QuantizationConfig): return GPTQConfig.from_config(self.full_config).get_quant_method( layer, prefix ) - elif self.linear_quant_method == "awq": + elif self.linear_quant_method in ("awq", "awq_marlin"): if self.use_marlin and check_marlin_supports_layer( layer, self.group_size ): @@ -362,27 +366,10 @@ class MoeWNA16Method(FusedMoEMethodBase): layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts - assert activation == "silu", "Only SiLU activation is supported." + assert layer.activation == "silu", "Only SiLU activation is supported." topk_weights, topk_ids, _ = layer.select_experts( hidden_states=x, router_logits=router_logits, @@ -395,9 +382,9 @@ class MoeWNA16Method(FusedMoEMethodBase): topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, quant_config=self.moe_quant_config, ) @@ -468,7 +455,8 @@ class MoeWNA16Method(FusedMoEMethodBase): shard_size = layer.intermediate_size_per_partition # convert gptq and awq weight to a standard format - if layer.quant_config.linear_quant_method == "awq": + # awq_marlin uses the same weight format as awq + if layer.quant_config.linear_quant_method in ("awq", "awq_marlin"): assert layer.quant_config.weight_bits == 4 if "weight" in weight_name: loaded_weight = convert_awq_tensor(loaded_weight, "qweight") diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 5d330e837eea0..e96e87d15787d 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable from enum import Enum from typing import Optional @@ -119,19 +118,19 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend: logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90") return Mxfp4Backend.SM90_FI_MXFP4_BF16 elif ( - current_platform.is_device_capability(100) + current_platform.is_device_capability_family(100) and has_flashinfer() and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS ): logger.info_once("Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100") return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS elif ( - current_platform.is_device_capability(100) + current_platform.is_device_capability_family(100) and has_flashinfer() and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 ): return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM - elif current_platform.is_device_capability(100) and has_flashinfer(): + elif current_platform.is_device_capability_family(100) and has_flashinfer(): logger.info_once( "Using FlashInfer MXFP4 BF16 backend for SM100, " "For faster performance on SM100, consider setting " @@ -140,7 +139,7 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend: ) return Mxfp4Backend.SM100_FI_MXFP4_BF16 elif ( - current_platform.is_device_capability(100) + current_platform.is_device_capability_family(100) or current_platform.is_device_capability(90) ) and not has_flashinfer(): logger.warning_once( @@ -892,25 +891,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - if enable_eplb: + if layer.enable_eplb: raise NotImplementedError("EPLB is not supported for mxfp4") if self.mxfp4_backend == Mxfp4Backend.MARLIN: @@ -933,26 +915,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): global_scale1=None, global_scale2=None, quant_type_id=scalar_types.float4_e2m1f.id, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - activation=activation, - expert_map=expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + activation=layer.activation, + expert_map=layer.expert_map, input_dtype=self.marlin_input_dtype, ) assert _can_support_mxfp4( - use_grouped_topk, - topk_group, - num_expert_group, - expert_map, - custom_routing_function, - e_score_correction_bias, - apply_router_weight_on_input, - scoring_func, - activation, - expert_load_view, - logical_to_physical_map, - logical_replica_count, + layer.use_grouped_topk, + layer.topk_group, + layer.num_expert_group, + layer.expert_map, + layer.custom_routing_function, + layer.e_score_correction_bias, + layer.apply_router_weight_on_input, + layer.scoring_func, + layer.activation, + layer.expert_load_view, + layer.logical_to_physical_map, + layer.logical_replica_count, ), "MXFP4 are not supported with this configuration." if ( @@ -988,8 +970,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): None, # output1_scale_scalar None, # output1_scale_gate_scalar None, # output2_scale_scalar - global_num_experts, - top_k, + layer.global_num_experts, + layer.top_k, None, # n_group None, # topk_group self.intermediate_size, # padded to multiple of 256 @@ -997,7 +979,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): self.num_experts, # local num experts None, None, - 1 if renormalize else 0, # routing_method_type, renormalize + 1 if layer.renormalize else 0, # routing_method_type, renormalize True, # do finalize tune_max_num_tokens=max(self.max_capture_size, 1), )[0] @@ -1081,12 +1063,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): w1=layer.w13_weight, w2=layer.w2_weight, gating_output=router_logits, - topk=top_k, - renormalize=renormalize, - global_num_experts=global_num_experts, - expert_map=expert_map, + topk=layer.top_k, + renormalize=layer.renormalize, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, quant_config=self.moe_quant_config, - apply_router_weight_on_input=apply_router_weight_on_input, + apply_router_weight_on_input=layer.apply_router_weight_on_input, ) else: raise ValueError(f"Unsupported backend: {self.mxfp4_backend}") @@ -1138,37 +1120,20 @@ class IpexMxfp4MoEMethod(Mxfp4MoEMethod): layer: torch.nn.Module, x: torch.Tensor, router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor: - assert activation == "swigluoai", ( + assert layer.activation == "swigluoai", ( "Only swiglu_oai activation is supported for IPEX MXFP4 MoE" ) hidden_size_pad = round_up(self.original_hidden_size, 128) x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1))) hidden_states = layer.ipex_fusion( x_pad, - use_grouped_topk, - top_k, + layer.use_grouped_topk, + layer.top_k, router_logits, - renormalize, - topk_group, - num_expert_group, + layer.renormalize, + layer.topk_group, + layer.num_expert_group, activation="swiglu_oai", ) hidden_states = hidden_states[..., : self.original_hidden_size].contiguous() diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 8be0299eaa66f..d84e22d1fa0f2 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable from typing import Any import torch @@ -337,23 +336,6 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: topk_weights, topk_ids, _ = layer.select_experts( hidden_states=x, @@ -371,13 +353,15 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, + activation=layer.activation, + apply_router_weight_on_input=layer.apply_router_weight_on_input, quant_config=self.moe_quant_config, - expert_map=expert_map, + expert_map=layer.expert_map, ) elif self.use_marlin: - assert activation == "silu", f"{activation} not supported for Marlin MoE." + assert layer.activation == "silu", ( + f"{layer.activation} not supported for Marlin MoE." + ) return fused_marlin_moe( x, layer.w13_weight, @@ -390,9 +374,9 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): topk_weights, topk_ids, quant_type_id=scalar_types.float8_e4m3fn.id, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, ) else: from vllm.model_executor.layers.fused_moe import fused_experts @@ -404,10 +388,10 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, + activation=layer.activation, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, quant_config=self.moe_quant_config, ) @@ -597,23 +581,6 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: topk_weights, topk_ids, _ = layer.select_experts( hidden_states=x, @@ -631,8 +598,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, - activation=activation, + activation=layer.activation, quant_config=self.moe_quant_config, + expert_map=layer.expert_map, ) else: from vllm.model_executor.layers.fused_moe import fused_experts @@ -644,10 +612,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, - activation=activation, - global_num_experts=global_num_experts, - apply_router_weight_on_input=apply_router_weight_on_input, - expert_map=expert_map, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + expert_map=layer.expert_map, quant_config=self.moe_quant_config, ) + return out diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py index 7b51b828009fc..b2ecb0b175f81 100644 --- a/vllm/model_executor/layers/quantization/rtn.py +++ b/vllm/model_executor/layers/quantization/rtn.py @@ -3,7 +3,6 @@ # Copyright © 2025, Oracle and/or its affiliates. import os -from collections.abc import Callable from typing import Any, Optional import numpy as np @@ -359,23 +358,6 @@ class RTNMoEMethod(FusedMoEMethodBase): layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: topk_weights, topk_ids, _ = layer.select_experts( hidden_states=x, @@ -394,9 +376,9 @@ class RTNMoEMethod(FusedMoEMethodBase): topk_weights, topk_ids, quant_type_id=self.quant_config.quant_type.id, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, workspace=workspace, ) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index eda40657b1e39..76bce8a8d98d6 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -50,7 +50,7 @@ def is_flashinfer_fp4_cutedsl_moe_available() -> bool: envs.VLLM_USE_FLASHINFER_MOE_FP4 and has_flashinfer_cutedsl_grouped_gemm_nt_masked() and current_platform.is_cuda() - and current_platform.is_device_capability(100) + and current_platform.is_device_capability_family(100) ) @@ -301,18 +301,14 @@ def flashinfer_trtllm_fp4_moe( hidden_states_scale=hidden_states_scale_linear_fp4.view( torch.float8_e4m3fn ).flatten(), - gemm1_weights=layer.gemm1_weights_fp4_shuffled.data, - gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view( - torch.float8_e4m3fn - ), + gemm1_weights=layer.w13_weight.data, + gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn), gemm1_bias=None, gemm1_alpha=None, gemm1_beta=None, gemm1_clamp_limit=None, - gemm2_weights=layer.gemm2_weights_fp4_shuffled.data, - gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view( - torch.float8_e4m3fn - ), + gemm2_weights=layer.w2_weight.data, + gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn), gemm2_bias=None, output1_scale_scalar=layer.g1_scale_c.data, output1_scale_gate_scalar=layer.g1_alphas.data, @@ -331,3 +327,78 @@ def flashinfer_trtllm_fp4_moe( )[0] return out + + +def flashinfer_trtllm_fp4_routed_moe( + layer: torch.nn.Module, + x: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + top_k: int, + global_num_experts: int, +) -> torch.Tensor: + """ + Apply FlashInfer TensorRT-LLM FP4 MoE kernel. Uses packed + input top k expert indices and scores rather than computing + top k expert indices from scores. + + Args: + layer: The MoE layer with weights and scales + x: Input tensor + topk_ids: Ids of selected experts + top_k: Number of experts to select per token + global_num_experts: Total number of experts across all ranks + + Returns: + Output tensor from the MoE layer + """ + import flashinfer + + # Pack top k ids and expert weights into a single int32 tensor, as + # required by TRT-LLM + packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to( + torch.bfloat16 + ).view(torch.int16) + + # Quantize input to FP4 + a1_gscale = layer.w13_input_scale_quant + (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( + x, + a1_gscale, + is_sf_swizzled_layout=False, + ) + + # Call TRT-LLM FP4 block-scale MoE kernel + out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe( + topk_ids=packed_tensor, + routing_bias=None, + hidden_states=hidden_states_fp4, + hidden_states_scale=hidden_states_scale_linear_fp4.view( + torch.float8_e4m3fn + ).flatten(), + gemm1_weights=layer.w13_weight.data, + gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn), + gemm1_bias=None, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_weights=layer.w2_weight.data, + gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn), + gemm2_bias=None, + output1_scale_scalar=layer.g1_scale_c.data, + output1_scale_gate_scalar=layer.g1_alphas.data, + output2_scale_scalar=layer.g2_alphas.data, + num_experts=global_num_experts, + top_k=top_k, + n_group=0, + topk_group=0, + intermediate_size=layer.intermediate_size_per_partition, + local_expert_offset=layer.ep_rank * layer.local_num_experts, + local_num_experts=layer.local_num_experts, + routed_scaling_factor=None, + tile_tokens_dim=None, + routing_method_type=1, + do_finalize=True, + )[0] + + return out diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index eef7a0896c375..3d6e9cda87667 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -257,6 +257,7 @@ def flashinfer_cutlass_moe_fp8( out_dtype=hidden_states.dtype, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, ), + moe_parallel_config=layer.moe_parallel_config, ) return fused_experts( @@ -284,7 +285,7 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend: if flashinfer_moe_backend in backend_map: if ( flashinfer_moe_backend == "latency" - and not current_platform.is_device_capability(100) + and not current_platform.is_device_capability_family(100) ): logger.info_once( "Flashinfer TRTLLM MOE backend is only supported on " diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index ae63b4a767268..ea68745585160 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -27,6 +27,7 @@ from vllm.model_executor.parameter import ( ChannelQuantScaleParameter, PerTensorScaleParameter, ) +from vllm.model_executor.utils import replace_parameter from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.deep_gemm import ( @@ -194,6 +195,39 @@ direct_register_custom_op( ) +def _triton_per_token_group_quant_fp8_impl( + x: torch.Tensor, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + return per_token_group_quant_fp8( + x, group_size, column_major_scales=False, use_ue8m0=False + ) + + +def _triton_per_token_group_quant_fp8_fake( + x: torch.Tensor, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + M, N = x.shape + x_fp8 = torch.empty((M, N), dtype=current_platform.fp8_dtype(), device=x.device) + out_bs = torch.empty( + ( + M, + (N + group_size - 1) // group_size, + ), + dtype=torch.float32, + device=x.device, + ) + return x_fp8, out_bs + + +direct_register_custom_op( + "triton_per_token_group_quant_fp8", + _triton_per_token_group_quant_fp8_impl, + fake_impl=_triton_per_token_group_quant_fp8_fake, +) + + # TODO fix ROCm->Triton custom path: # https://github.com/vllm-project/vllm/issues/14397 class W8A8BlockFp8LinearOp: @@ -213,6 +247,7 @@ class W8A8BlockFp8LinearOp: self.act_quant_group_shape = act_quant_group_shape self.is_deep_gemm_supported = is_deep_gemm_supported() self.is_hopper = current_platform.is_device_capability(90) + self.is_blackwell = current_platform.is_device_capability_family(100) self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used() # Get the correct blockscale mul and input quant operations. @@ -268,8 +303,15 @@ class W8A8BlockFp8LinearOp: weight: torch.Tensor, weight_scale: torch.Tensor, ) -> torch.Tensor: - assert self.deepgemm_input_quant_op is not None - q_input, input_scale = self.deepgemm_input_quant_op(input_2d) + if self.use_deep_gemm_e8m0 and self.is_blackwell: + q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm( + input_2d, + group_size=self.act_quant_group_shape.col, + use_ue8m0=True, + ) + else: + assert self.deepgemm_input_quant_op is not None + q_input, input_scale = self.deepgemm_input_quant_op(input_2d) output = torch.empty( (q_input.shape[0], weight.shape[0]), dtype=torch.bfloat16, @@ -332,17 +374,15 @@ class W8A8BlockFp8LinearOp: if input_scale is not None: q_input = input_2d - # MI350 case uses triton kernel elif use_triton: - q_input, input_scale = per_token_group_quant_fp8( + q_input, input_scale = torch.ops.vllm.triton_per_token_group_quant_fp8( input_2d, self.act_quant_group_shape.col, - column_major_scales=False, - use_ue8m0=False, ) - # MI300 uses tuned AITER ASM/C++ kernel else: - q_input, input_scale = rocm_aiter_ops.group_fp8_quant(input_2d) + q_input, input_scale = rocm_aiter_ops.group_fp8_quant( + input_2d, self.act_quant_group_shape.col + ) return gemm_a8w8_blockscale_op( q_input, @@ -492,6 +532,139 @@ def _per_token_group_quant_fp8( tl.store(y_s_ptr, y_s) +@triton.jit +def _silu_mul_per_token_group_quant_fp8_colmajor( + y_ptr, # [M, N] + y_q_ptr, # [M, N // 2] + y_s_ptr, # [M, (N // 2) // GROUP_SIZE] + M, # num tokens + N, # intermediate size + # Stride + y_s_col_stride: tl.int64, + # Information for float8 + eps, + fp8_min, + fp8_max, + use_ue8m0: tl.constexpr, + # Meta-parameters + GROUP_SIZE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # TODO(varun) : Add expert_ids so we may early-exit no-op thread blocks. + """ + Each thread block (BLOCK_N) computes [BLOCK_M, GROUP_SIZE] act-mul outputs. Then + the thread block quantizes the [BLOCK_M, GROUP_SIZE] block of values and fills + the outputs tensors at the right positions. + """ + + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + N_2 = N // 2 + + m_offset = pid_m * BLOCK_M + n_offset = pid_n * BLOCK_N + if m_offset >= M: + return + + offs_n = tl.arange(0, BLOCK_N).to(tl.int64) + offs_m = tl.arange(0, BLOCK_M).to(tl.int64) + + base_y_ptr = y_ptr + m_offset * N + n_offset + + act_in_ptrs = base_y_ptr + offs_m[:, None] * N + offs_n[None, :] + + act_in = tl.load(act_in_ptrs) + mul_in = tl.load(act_in_ptrs + N_2) + + # silu & mul + act_in = act_in.to(tl.float32) + one_f32 = tl.cast(1, tl.float32) + silu_out = (act_in / (one_f32 + tl.exp(-act_in))).to(y_ptr.dtype.element_ty) + y = (silu_out * mul_in).to(tl.float32) + + # quant + _absmax = tl.maximum(tl.max(tl.abs(y), axis=1), eps) + scale_raw = _absmax / fp8_max + y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw + y_s = tl.reshape(y_s, (BLOCK_M, 1)) + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + # store y_q + base_y_q_ptr = y_q_ptr + m_offset * N_2 + n_offset + y_q_ptrs = base_y_q_ptr + offs_m[:, None] * N_2 + offs_n[None, :] + tl.store(y_q_ptrs, y_q) + + # store y_s + group_id = n_offset // GROUP_SIZE + base_y_s_ptr = y_s_ptr + group_id * y_s_col_stride + m_offset + y_s_ptrs = base_y_s_ptr + offs_m + y_s = tl.reshape(y_s, (BLOCK_M,)) + tl.store(y_s_ptrs, y_s) + + +def silu_mul_per_token_group_quant_fp8_colmajor( + input: torch.Tensor, # [M, N] + output: torch.Tensor | None = None, # [M, N // 2] + use_ue8m0: bool | None = None, + eps: float = 1e-10, +): + """ + silu+mul + block-fp8 quant with group size 128. + """ + GROUP_SIZE = 128 + assert input.ndim == 2 + if output is not None: + assert output.ndim == 2 + assert input.size(0) % GROUP_SIZE == 0 + assert input.size(1) % (GROUP_SIZE * 2) == 0 + + if use_ue8m0 is None: + use_ue8m0 = is_deep_gemm_e8m0_used() + + M, N = input.size() + N_2 = N // 2 + + if output is None: + output = torch.empty((M, N_2), dtype=torch.float8_e4m3fn, device=input.device) + + output_scales = torch.empty( + ((N_2 // GROUP_SIZE), M), dtype=torch.float32, device=input.device + ).transpose(0, 1) + + BLOCK_M = 8 + BLOCK_N = GROUP_SIZE + assert M % BLOCK_M == 0 + assert N_2 % BLOCK_N == 0 + + finfo = torch.finfo(torch.float8_e4m3fn) + fp8_min = finfo.min + fp8_max = finfo.max + + # Force even division so we can avoid edgecases within the kernel. + assert M % BLOCK_M == 0 + assert N_2 % BLOCK_N == 0 + grid = (M // BLOCK_M, N_2 // BLOCK_N) + + _silu_mul_per_token_group_quant_fp8_colmajor[grid]( + input, + output, + output_scales, + M, + N, + output_scales.stride(-1), + eps, + fp8_min, + fp8_max, + use_ue8m0, + GROUP_SIZE, + BLOCK_M, + BLOCK_N, + ) + + return output, output_scales + + @triton.jit def _per_token_group_quant_fp8_colmajor( # Pointers to inputs and output @@ -589,14 +762,17 @@ def per_token_group_quant_fp8( ) assert x.stride(-1) == 1, "`x` groups must be contiguous" + # Using the default value (240.0) from pytorch will cause accuracy + # issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm + # platforms that use the torch.float8_e4mefnuz dtype. finfo = torch.finfo(dtype) - fp8_min = finfo.min - fp8_max = finfo.max + fp8_min = -224.0 if current_platform.is_fp8_fnuz() else finfo.min + fp8_max = 224.0 if current_platform.is_fp8_fnuz() else finfo.max assert out_q is None or out_q.shape == x.shape x_q = out_q if x_q is None: - x_q = torch.empty_like(x, device=x.device, dtype=dtype) + x_q = torch.empty(x.shape, device=x.device, dtype=dtype) # Allocate the scale tensor in either row- or column-major format. if column_major_scales: @@ -658,6 +834,80 @@ def per_token_group_quant_fp8( return x_q, x_s +def per_token_group_quant_fp8_packed_for_deepgemm( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + use_ue8m0: bool | None = None, + out_q: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """FP8 per-token-group quantization for DeepGEMM. + + Returns: + (x_q, x_s_packed) + x_q: FP8 activations, same shape as `x`. + x_s_packed: Int32 tensor with logical shape + [mn, ceil(num_groups_per_row / 4)], laid out with + TMA-aligned stride along the packed-K dimension + """ + if use_ue8m0 is None: + use_ue8m0 = is_deep_gemm_e8m0_used() + # for DeepGEMM UE8M0-packed layout we *require* UE8M0 scales. + assert use_ue8m0, ( + "per_token_group_quant_fp8_packed_for_deepgemm requires UE8M0 scales." + ) + + dtype = current_platform.fp8_dtype() + assert x.shape[-1] % group_size == 0, ( + f"the last dimension of `x` {x.shape[-1]} must be divisible " + f"by `group_size` {group_size}" + ) + assert x.stride(-1) == 1, "`x` groups must be contiguous" + + finfo = torch.finfo(dtype) + fp8_min, fp8_max = finfo.min, finfo.max + + # compute DeepGEMM-style packed scale tensor shape. + hidden_dim = x.shape[-1] + mn = x.numel() // hidden_dim + num_groups_per_row = hidden_dim // group_size + k_num_packed_sf_k = (num_groups_per_row + 3) // 4 + tma_aligned_mn = ((mn + 3) // 4) * 4 + + x_s_packed = torch.empty_strided( + (mn, k_num_packed_sf_k), + (1, tma_aligned_mn), + device=x.device, + dtype=torch.int32, + ) + + # CUDA kernel path only (DeepGEMM + E8M0 is CUDA-specific). + assert current_platform.is_cuda(), ( + "per_token_group_quant_fp8_packed_for_deepgemm is only valid on CUDA " + "platforms using DeepGEMM." + ) + + x_contiguous = x.contiguous() + if out_q is not None: + x_q_local = out_q + else: + x_q_local = torch.empty_like(x_contiguous, device=x.device, dtype=dtype) + + torch.ops._C.per_token_group_fp8_quant_packed( + x_contiguous, + x_q_local, + x_s_packed, + group_size, + eps, + fp8_min, + fp8_max, + ) + + # return a tensor with the original logical shape. + x_q = x_q_local.view_as(x) + return x_q, x_s_packed + + @triton.jit def _w8a8_triton_block_scaled_mm( # Pointers to inputs and output @@ -1189,12 +1439,12 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module): if should_use_deepgemm: dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block( wq=layer.weight.data, - ws=layer.weight_scale.data, + ws=layer.weight_scale_inv.data, quant_block_shape=tuple(layer.weight_block_size), use_e8m0=is_deep_gemm_e8m0_used(), ) - layer.weight = torch.nn.Parameter(dg_weight, requires_grad=False) - layer.weight_scale = torch.nn.Parameter(dg_weight_scale, requires_grad=False) + replace_parameter(layer, "weight", dg_weight) + replace_parameter(layer, "weight_scale_inv", dg_weight_scale) def expert_weight_is_col_major(x: torch.Tensor) -> bool: diff --git a/vllm/model_executor/layers/quantization/utils/int8_utils.py b/vllm/model_executor/layers/quantization/utils/int8_utils.py index 925d0a516ce63..32192225f61e2 100644 --- a/vllm/model_executor/layers/quantization/utils/int8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/int8_utils.py @@ -83,26 +83,11 @@ def block_dequant( if current_platform.is_rocm(): - from triton.language import core - - # NOTE: This can be removed when hip.libdevice.round() is available. - @core.extern - def round_f32(arg0, _builder=None): - return core.extern_elementwise( - "", - "", - [arg0], - { - (core.dtype("fp32"),): ("llvm.round", core.dtype("fp32")), - (core.dtype("fp64"),): ("llvm.round", core.dtype("fp64")), - }, - is_pure=True, - _builder=_builder, - ) @triton.jit def round_int8(x): - return round_f32(x).to(tl.int8) + return tl.extra.hip.libdevice.round(x).to(tl.int8) + else: @triton.jit diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 14337ee1d7bee..072b46f055210 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -179,6 +179,8 @@ def check_marlin_supports_shape( def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool: + if current_platform.is_rocm(): + return False output_size_per_partition = ( getattr(layer, "output_size_per_partition", None) or layer.output_size ) @@ -195,6 +197,8 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool: def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool: + if current_platform.is_rocm(): + return False hidden_size = layer.hidden_size intermediate_size_per_partition = layer.intermediate_size_per_partition # apply_router_weight_on_input is not supported for moe marlin diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index e6b4f567caea4..c67e4f437cf0c 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( marlin_quant_input, should_use_atomic_add_reduce, ) +from vllm.model_executor.utils import replace_parameter from vllm.platforms import current_platform from vllm.scalar_type import scalar_types @@ -130,7 +131,7 @@ def prepare_fp8_layer_for_marlin( size_n=part_size_n, num_bits=8, ) - layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + replace_parameter(layer, "weight", marlin_qweight) # WEIGHT SCALES # Permute scales @@ -138,7 +139,6 @@ def prepare_fp8_layer_for_marlin( scales = layer.weight_scale.to(layer.orig_dtype) elif "weight_scale_inv" in dir(layer): scales = layer.weight_scale_inv.to(layer.orig_dtype) - del layer.weight_scale_inv group_size = -1 if weight_block_size is None else weight_block_size[1] @@ -177,12 +177,15 @@ def prepare_fp8_layer_for_marlin( ) if input_dtype != torch.float8_e4m3fn: marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) - layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) + if hasattr(layer, "weight_scale"): + replace_parameter(layer, "weight_scale", marlin_scales) + elif hasattr(layer, "weight_scale_inv"): + replace_parameter(layer, "weight_scale_inv", marlin_scales) if hasattr(layer, "bias") and layer.bias is not None: assert layer.bias.shape == (part_size_n,) bias = marlin_permute_bias(layer.bias) - layer.bias = torch.nn.Parameter(bias, requires_grad=False) + replace_parameter(layer, "bias", bias) def prepare_moe_fp8_layer_for_marlin( diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index d0c8b3d1a3093..e9ecf0547033d 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -57,12 +57,18 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps): mx_axis=1, num_warps=num_warps ) ) - if current_platform.is_cuda() and current_platform.is_device_capability(100): - constraints = { - "is_persistent": True, - "epilogue_subtile": 1, - } - opt_flags.update_opt_flags_constraints(constraints) + if current_platform.is_cuda(): + if current_platform.is_device_capability(90): + constraints = { + "split_k": 1, + } + opt_flags.update_opt_flags_constraints(constraints) + elif current_platform.is_device_capability_family(100): + constraints = { + "is_persistent": True, + "epilogue_subtile": 1, + } + opt_flags.update_opt_flags_constraints(constraints) # transpose the tensor so that the quantization axis is on dim1 quant_tensor = quant_tensor.transpose(-2, -1) scale = scale.transpose(-2, -1) diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index d056d3404385a..d01263f82007d 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """This file is used for /tests and /benchmarks""" -from collections.abc import Mapping +from collections.abc import Callable, Mapping from dataclasses import dataclass from types import MappingProxyType from typing import ClassVar, NamedTuple @@ -115,6 +115,12 @@ kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True) kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16)) kNvfp4Quant = QuantKey(FP4_DTYPE, scale=kNvfp4GroupScale, scale2=kStaticTensorScale) +kDynamic128Scale = ScaleDesc(torch.float32, False, GroupShape(1, 128)) +kFp8Dynamic128Sym = QuantKey(FP8_DTYPE, kDynamic128Scale, symmetric=True) + +kDynamic64Scale = ScaleDesc(torch.float32, False, GroupShape(1, 64)) +kFp8Dynamic64Sym = QuantKey(FP8_DTYPE, kDynamic64Scale, symmetric=True) + # Normalize the group_shape to the full extent for any dims that are -1 def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape): @@ -685,3 +691,51 @@ def cutlass_fp4_supported() -> bool: capability_tuple = current_platform.get_device_capability() capability = -1 if capability_tuple is None else capability_tuple.to_int() return cutlass_scaled_mm_supports_fp4(capability) + + +def convert_bf16_scales_to_fp8( + quant_fp8: Callable, scales: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Convert a BF16 scale tensor into the pair of (fp8_scales, channel_scales) + expected by W4A8 GEMM kernels. + """ + assert scales.is_contiguous(), ( + f"scale tensor must be contiguous, got {scales.stride()=}" + ) + assert scales.is_cuda, "scales must be on gpu" + + orig_shape = scales.shape + k_groups = orig_shape[-1] + flat_scales = scales.view(-1, k_groups) + + fp8_scales, chan_scales = quant_fp8(flat_scales) + fp8_scales = (fp8_scales.float() / 8.0).to(torch.float8_e4m3fn) + chan_scales *= 8.0 + + # restore original shape + fp8_scales = fp8_scales.view(orig_shape) + chan_scales = chan_scales.view(orig_shape[:-1], -1) + + return fp8_scales, chan_scales + + +def convert_packed_uint4b8_to_signed_int4_inplace(t: torch.Tensor) -> torch.Tensor: + """ + Convert int4b8 (packed to int32) to signed int4 + """ + assert t.is_cuda, "tensor must be on gpu" + assert t.dtype == torch.int32, f"expected int32 packed weights but got {t.dtype}" + + # loop through the 8 4-bit nibbles in each int32 entry + for i in range(8): + shift = 4 * i + # extract the i-th 4-bit nibble + nib = (t >> shift) & 0xF + # clear the original nibble by masking out + t &= ~(0xF << shift) + # convert int4b8 [0..15] to signed int4 [-8..7] by subtracting 8 + # and update in-place + t |= ((nib - 8) & 0xF) << shift + + return t diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index fceed3e55c2df..4287922417c63 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -118,8 +118,11 @@ def requantize_with_max_scale( # from disk in this case. Skip requantization in this case (since) # we already are quantized with the single scale. # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8 + # + # Extra note: upon weight reloading weight_scale.ndim == 0 unfused_module_in_checkpoint = ( - weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min + weight_scale.ndim != 0 + and weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min ) # If unfused checkpoint, need requanize with the single scale. diff --git a/vllm/model_executor/layers/rotary_embedding/__init__.py b/vllm/model_executor/layers/rotary_embedding/__init__.py index 0f10bff6ac4f5..452b87ea4e7a5 100644 --- a/vllm/model_executor/layers/rotary_embedding/__init__.py +++ b/vllm/model_executor/layers/rotary_embedding/__init__.py @@ -25,12 +25,10 @@ _ROPE_DICT: dict[tuple, RotaryEmbedding] = {} def get_rope( head_size: int, - rotary_dim: int, max_position: int, is_neox_style: bool = True, rope_parameters: dict[str, Any] | None = None, dtype: torch.dtype | None = None, - partial_rotary_factor: float = 1.0, dual_chunk_attention_config: dict[str, Any] | None = None, ) -> RotaryEmbedding: if dtype is None: @@ -55,8 +53,15 @@ def get_rope( else: dual_chunk_attention_args = None - if partial_rotary_factor < 1.0: - rotary_dim = int(rotary_dim * partial_rotary_factor) + rope_parameters = rope_parameters or {} + base = rope_parameters.get("rope_theta", 10000) + scaling_type = rope_parameters.get("rope_type", "default") + partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0) + + if partial_rotary_factor <= 0.0 or partial_rotary_factor > 1.0: + raise ValueError(f"{partial_rotary_factor=} must be between 0.0 and 1.0") + rotary_dim = int(head_size * partial_rotary_factor) + key = ( head_size, rotary_dim, @@ -69,7 +74,6 @@ def get_rope( if key in _ROPE_DICT: return _ROPE_DICT[key] - base = rope_parameters["rope_theta"] if rope_parameters else 10000 if dual_chunk_attention_config is not None: extra_kwargs = { k: v @@ -85,109 +89,76 @@ def get_rope( dtype, **extra_kwargs, ) - elif not rope_parameters: - rotary_emb = RotaryEmbedding( + elif scaling_type == "default": + if "mrope_section" in rope_parameters: + rotary_emb = MRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + mrope_section=rope_parameters["mrope_section"], + mrope_interleaved=rope_parameters.get("mrope_interleaved", False), + ) + else: + rotary_emb = RotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + ) + elif scaling_type == "llama3": + scaling_factor = rope_parameters["factor"] + low_freq_factor = rope_parameters["low_freq_factor"] + high_freq_factor = rope_parameters["high_freq_factor"] + original_max_position = rope_parameters["original_max_position_embeddings"] + rotary_emb = Llama3RotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + scaling_factor, + low_freq_factor, + high_freq_factor, + original_max_position, + ) + elif scaling_type == "mllama4": + rotary_emb = Llama4VisionRotaryEmbedding( head_size, rotary_dim, max_position, base, is_neox_style, dtype ) - else: - scaling_type = rope_parameters["rope_type"] - - if scaling_type == "llama3": - scaling_factor = rope_parameters["factor"] - low_freq_factor = rope_parameters["low_freq_factor"] - high_freq_factor = rope_parameters["high_freq_factor"] - original_max_position = rope_parameters["original_max_position_embeddings"] - rotary_emb = Llama3RotaryEmbedding( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - dtype, - scaling_factor, - low_freq_factor, - high_freq_factor, - original_max_position, - ) - elif scaling_type == "mllama4": - rotary_emb = Llama4VisionRotaryEmbedding( - head_size, rotary_dim, max_position, base, is_neox_style, dtype - ) - elif scaling_type == "default": - if "mrope_section" in rope_parameters: - rotary_emb = MRotaryEmbedding( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - dtype, - mrope_section=rope_parameters["mrope_section"], - mrope_interleaved=rope_parameters.get("mrope_interleaved", False), - ) - else: - rotary_emb = RotaryEmbedding( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - dtype, - ) - elif scaling_type == "linear": - scaling_factor = rope_parameters["factor"] - rotary_emb = LinearScalingRotaryEmbedding( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - scaling_factor, - dtype, - ) - elif scaling_type == "ntk": - scaling_factor = rope_parameters["factor"] - mixed_b = rope_parameters.get("mixed_b") - rotary_emb = NTKScalingRotaryEmbedding( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - scaling_factor, - dtype, - mixed_b, - ) - elif scaling_type == "dynamic": - if "alpha" in rope_parameters: - scaling_alpha = rope_parameters["alpha"] - rotary_emb = DynamicNTKAlphaRotaryEmbedding( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - scaling_alpha, - dtype, - ) - elif "factor" in rope_parameters: - scaling_factor = rope_parameters["factor"] - rotary_emb = DynamicNTKScalingRotaryEmbedding( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - scaling_factor, - dtype, - ) - else: - raise ValueError( - "Dynamic rope scaling must contain either 'alpha' or 'factor' field" - ) - elif scaling_type == "xdrope": + elif scaling_type == "linear": + scaling_factor = rope_parameters["factor"] + rotary_emb = LinearScalingRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + scaling_factor, + dtype, + ) + elif scaling_type == "ntk": + scaling_factor = rope_parameters["factor"] + mixed_b = rope_parameters.get("mixed_b") + rotary_emb = NTKScalingRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + scaling_factor, + dtype, + mixed_b, + ) + elif scaling_type == "dynamic": + if "alpha" in rope_parameters: scaling_alpha = rope_parameters["alpha"] - rotary_emb = XDRotaryEmbedding( + rotary_emb = DynamicNTKAlphaRotaryEmbedding( head_size, rotary_dim, max_position, @@ -195,67 +166,66 @@ def get_rope( is_neox_style, scaling_alpha, dtype, - xdrope_section=rope_parameters["xdrope_section"], ) - elif scaling_type == "yarn": + elif "factor" in rope_parameters: scaling_factor = rope_parameters["factor"] - original_max_position = rope_parameters["original_max_position_embeddings"] - extra_kwargs = { - k: v - for k, v in rope_parameters.items() - if k - in ( - "extrapolation_factor", - "attn_factor", - "beta_fast", - "beta_slow", - "apply_yarn_scaling", - "truncate", - ) - } - if "mrope_section" in rope_parameters: - extra_kwargs.pop("apply_yarn_scaling", None) - rotary_emb = MRotaryEmbedding( - head_size, - rotary_dim, - original_max_position, - base, - is_neox_style, - dtype, - mrope_section=rope_parameters["mrope_section"], - mrope_interleaved=rope_parameters.get("mrope_interleaved", False), - scaling_factor=scaling_factor, - **extra_kwargs, - ) - else: - rotary_emb = YaRNScalingRotaryEmbedding( - head_size, - rotary_dim, - original_max_position, - base, - is_neox_style, - scaling_factor, - dtype, - **extra_kwargs, - ) - elif scaling_type == "deepseek_yarn": - scaling_factor = rope_parameters["factor"] - original_max_position = rope_parameters["original_max_position_embeddings"] - # assert max_position == original_max_position * scaling_factor - extra_kwargs = { - k: v - for k, v in rope_parameters.items() - if k - in ( - "extrapolation_factor", - "attn_factor", - "beta_fast", - "beta_slow", - "mscale", - "mscale_all_dim", - ) - } - rotary_emb = DeepseekScalingRotaryEmbedding( + rotary_emb = DynamicNTKScalingRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + scaling_factor, + dtype, + ) + else: + raise ValueError( + "Dynamic rope scaling must contain either 'alpha' or 'factor' field" + ) + elif scaling_type == "xdrope": + scaling_alpha = rope_parameters["alpha"] + rotary_emb = XDRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + scaling_alpha, + dtype, + xdrope_section=rope_parameters["xdrope_section"], + ) + elif scaling_type == "yarn": + scaling_factor = rope_parameters["factor"] + original_max_position = rope_parameters["original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in rope_parameters.items() + if k + in ( + "extrapolation_factor", + "attn_factor", + "beta_fast", + "beta_slow", + "apply_yarn_scaling", + "truncate", + ) + } + if "mrope_section" in rope_parameters: + extra_kwargs.pop("apply_yarn_scaling", None) + rotary_emb = MRotaryEmbedding( + head_size, + rotary_dim, + original_max_position, + base, + is_neox_style, + dtype, + mrope_section=rope_parameters["mrope_section"], + mrope_interleaved=rope_parameters.get("mrope_interleaved", False), + scaling_factor=scaling_factor, + **extra_kwargs, + ) + else: + rotary_emb = YaRNScalingRotaryEmbedding( head_size, rotary_dim, original_max_position, @@ -265,28 +235,55 @@ def get_rope( dtype, **extra_kwargs, ) - elif scaling_type == "longrope": - short_factor = rope_parameters["short_factor"] - long_factor = rope_parameters["long_factor"] - original_max_position = rope_parameters["original_max_position_embeddings"] - extra_kwargs = { - k: v - for k, v in rope_parameters.items() - if k in ("short_mscale", "long_mscale") - } - rotary_emb = Phi3LongRoPEScaledRotaryEmbedding( - head_size, - rotary_dim, - max_position, - original_max_position, - base, - is_neox_style, - dtype, - short_factor, - long_factor, - **extra_kwargs, + elif scaling_type in ["deepseek_yarn", "deepseek_llama_scaling"]: + scaling_factor = rope_parameters["factor"] + original_max_position = rope_parameters["original_max_position_embeddings"] + # assert max_position == original_max_position * scaling_factor + extra_kwargs = { + k: v + for k, v in rope_parameters.items() + if k + in ( + "extrapolation_factor", + "attn_factor", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", ) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + } + rotary_emb = DeepseekScalingRotaryEmbedding( + head_size, + rotary_dim, + original_max_position, + base, + is_neox_style, + scaling_factor, + dtype, + **extra_kwargs, + ) + elif scaling_type == "longrope": + short_factor = rope_parameters["short_factor"] + long_factor = rope_parameters["long_factor"] + original_max_position = rope_parameters["original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in rope_parameters.items() + if k in ("short_mscale", "long_mscale") + } + rotary_emb = Phi3LongRoPEScaledRotaryEmbedding( + head_size, + rotary_dim, + max_position, + original_max_position, + base, + is_neox_style, + dtype, + short_factor, + long_factor, + **extra_kwargs, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") _ROPE_DICT[key] = rotary_emb return rotary_emb diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 4114b21168cc8..afa69324c4e2e 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -7,7 +7,7 @@ import torch from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.custom_op import CustomOp -from .common import apply_rotary_emb_torch +from .common import ApplyRotaryEmb @CustomOp.register("rotary_embedding") @@ -49,6 +49,10 @@ class RotaryEmbeddingBase(CustomOp): rocm_aiter_ops.is_triton_rotary_embed_enabled() ) + self.apply_rotary_emb = ApplyRotaryEmb( + is_neox_style=self.is_neox_style, + ) + def _compute_inv_freq(self, base: float) -> torch.Tensor: """Compute the inverse frequency.""" # NOTE(woosuk): To exactly match the HF implementation, we need to @@ -123,7 +127,12 @@ class RotaryEmbedding(RotaryEmbeddingBase): query = query.view(num_tokens, -1, head_size) query_rot = query[..., :rotary_dim] query_pass = query[..., rotary_dim:] - query_rot = apply_rotary_emb_torch(query_rot, cos, sin, is_neox_style) + query_rot = ApplyRotaryEmb.forward_static( + query_rot, + cos, + sin, + is_neox_style, + ) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) # key may be None in some cases, e.g. cross-layer KV sharing @@ -132,7 +141,12 @@ class RotaryEmbedding(RotaryEmbeddingBase): key = key.view(num_tokens, -1, head_size) key_rot = key[..., :rotary_dim] key_pass = key[..., rotary_dim:] - key_rot = apply_rotary_emb_torch(key_rot, cos, sin, is_neox_style) + key_rot = ApplyRotaryEmb.forward_static( + key_rot, + cos, + sin, + is_neox_style, + ) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py index 13f8d15cc0f72..3e6584dbc3da0 100644 --- a/vllm/model_executor/layers/rotary_embedding/common.py +++ b/vllm/model_executor/layers/rotary_embedding/common.py @@ -2,19 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math -from collections.abc import Callable -from functools import cache from importlib.util import find_spec import torch from vllm.logger import init_logger -from vllm.platforms import current_platform +from vllm.model_executor.custom_op import CustomOp from vllm.utils.torch_utils import direct_register_custom_op -if current_platform.is_cuda(): - from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb - logger = init_logger(__name__) @@ -32,71 +27,6 @@ def rotate_gptj(x: torch.Tensor) -> torch.Tensor: return x.flatten(-2) -def apply_rotary_emb_torch( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - is_neox_style: bool, -) -> torch.Tensor: - cos = cos.unsqueeze(-2).to(x.dtype) - sin = sin.unsqueeze(-2).to(x.dtype) - if is_neox_style: - x1, x2 = torch.chunk(x, 2, dim=-1) - else: - x1 = x[..., ::2] - x2 = x[..., 1::2] - o1 = x1 * cos - x2 * sin - o2 = x2 * cos + x1 * sin - if is_neox_style: - return torch.cat((o1, o2), dim=-1) - else: - return torch.stack((o1, o2), dim=-1).flatten(-2) - - -def apply_rotary_emb_dispatch( - x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, is_neox_style: bool -) -> torch.Tensor: - """ - Args: - x: [num_tokens, num_heads, head_size] - cos: [num_tokens, head_size // 2] - sin: [num_tokens, head_size // 2] - is_neox_style: Whether to use the Neox-style or GPT-J-style rotary - positional embeddings. - """ - if current_platform.is_cuda(): - return apply_rotary_emb(x.unsqueeze(0), cos, sin, not is_neox_style).squeeze(0) - else: - return apply_rotary_emb_torch(x, cos, sin, is_neox_style) - - -@cache -def dispatch_rotary_emb_function( - default: Callable[..., torch.Tensor] | None = None, -) -> Callable[..., torch.Tensor]: - if current_platform.is_cuda(): - return apply_rotary_emb - - # if torch compile is not enabled - # use rotary embedding function from flash_attn package - # otherwise use the naive pytorch embedding implementation - # is faster when torch compile is enabled. - if current_platform.is_rocm() and not torch.compiler.is_compiling(): - if find_spec("flash_attn") is not None: - from flash_attn.ops.triton.rotary import apply_rotary - - return apply_rotary - else: - logger.warning( - "flash_attn is not installed. Falling back to PyTorch " - "implementation for rotary embeddings." - ) - if default is not None: - return default - - return apply_rotary_emb_torch - - # yarn functions # Inverse dim formula to find dim based on number of rotations def yarn_find_correction_dim( @@ -186,3 +116,155 @@ direct_register_custom_op( mutates_args=["query", "key"], # These tensors are modified in-place fake_impl=_flashinfer_rotary_embedding_fake, ) + + +@CustomOp.register("apply_rotary_emb") +class ApplyRotaryEmb(CustomOp): + def __init__( + self, + enforce_enable: bool = False, + is_neox_style: bool = True, + enable_fp32_compute: bool = False, + ) -> None: + super().__init__(enforce_enable) + self.is_neox_style = is_neox_style + self.enable_fp32_compute = enable_fp32_compute + + self.apply_rotary_emb_flash_attn = None + if find_spec("flash_attn") is not None: + from flash_attn.ops.triton.rotary import apply_rotary + + self.apply_rotary_emb_flash_attn = apply_rotary + + @staticmethod + def forward_static( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool = True, + enable_fp32_compute: bool = False, + ) -> torch.Tensor: + """ + Args: + x: [batch_size (optional), seq_len, num_heads, head_size] + cos: [seq_len, head_size // 2] + sin: [seq_len, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style. + enable_fp32_compute: Temporarily convert x, cos, sin to FP32 dtype + for higher accuracy. + """ + origin_dtype = x.dtype + if enable_fp32_compute: + x = x.float() + + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + + if is_neox_style: + output = torch.cat((o1, o2), dim=-1) + else: + output = torch.stack((o1, o2), dim=-1).flatten(-2) + + if enable_fp32_compute: + output = output.to(origin_dtype) + return output + + def forward_native( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + output = self.forward_static( + x, cos, sin, self.is_neox_style, self.enable_fp32_compute + ) + return output + + def forward_cuda( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb + + origin_dtype = x.dtype + if self.enable_fp32_compute: + x = x.float() + cos = cos.float() + sin = sin.float() + + origin_shape = x.shape + if len(origin_shape) == 3: + # x: [seq_len, num_heads, head_size] + x = x.unsqueeze(0) + + """ + Arguments of apply_rotary_emb() in vllm_flash_attn: + x: [batch_size, seq_len, nheads, headdim] + cos, sin: [seqlen_rotary, rotary_dim / 2] + interleaved: defalut as False (Neox-style). + ... + """ + interleaved = not self.is_neox_style + output = apply_rotary_emb(x, cos, sin, interleaved) + + if len(origin_shape) == 3: + output = output.squeeze(0) + if self.enable_fp32_compute: + output = output.to(origin_dtype) + return output + + def forward_hip( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + if self.apply_rotary_emb_flash_attn is not None: + origin_dtype = x.dtype + if self.enable_fp32_compute: + x = x.float() + cos = cos.float() + sin = sin.float() + + origin_shape = x.shape + if len(origin_shape) == 3: + # x: [seq_len, num_heads, head_size] + x = x.unsqueeze(0) + + """ + Arguments of apply_rotary() in flash_attn: + x: [batch_size, seq_len, nheads, headdim] + cos, sin: [seqlen_rotary, rotary_dim / 2] + interleaved: defalut as False (Neox-style). + ... + """ + interleaved = not self.is_neox_style + output = self.apply_rotary_emb_flash_attn( + x, cos, sin, interleaved=interleaved + ).type_as(x) + + if len(origin_shape) == 3: + output = output.squeeze(0) + if self.enable_fp32_compute: + output = output.to(origin_dtype) + else: + # Falling back to PyTorch native implementation. + output = self.forward_native(x, cos, sin) + + return output + + def extra_repr(self) -> str: + s = f"is_neox_style={self.is_neox_style}" + s += f"enable_fp32_compute={self.enable_fp32_compute}" + return s diff --git a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py index 749cdbe88a62e..2eda63a34ac44 100644 --- a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py @@ -4,7 +4,6 @@ import torch -from .common import apply_rotary_emb_dispatch from .mrope import MRotaryEmbedding @@ -55,14 +54,22 @@ class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding): query = query.view(num_tokens, -1, self.head_size) query_rot = query[..., : self.rotary_dim] query_pass = query[..., self.rotary_dim :] - query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style) + query_rot = self.apply_rotary_emb.forward_native( + query_rot, + cos, + sin, + ) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) key_shape = key.shape key = key.view(num_tokens, -1, self.head_size) key_rot = key[..., : self.rotary_dim] key_pass = key[..., self.rotary_dim :] - key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style) + key_rot = self.apply_rotary_emb.forward_native( + key_rot, + cos, + sin, + ) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index 0592aa8f967a6..a74bf092b182b 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -8,7 +8,6 @@ import torch from vllm.triton_utils import tl, triton from .base import RotaryEmbeddingBase -from .common import apply_rotary_emb_dispatch from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale @@ -301,14 +300,22 @@ class MRotaryEmbedding(RotaryEmbeddingBase): query = query.view(num_tokens, -1, self.head_size) query_rot = query[..., : self.rotary_dim] query_pass = query[..., self.rotary_dim :] - query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style) + query_rot = self.apply_rotary_emb.forward_native( + query_rot, + cos, + sin, + ) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) key_shape = key.shape key = key.view(num_tokens, -1, self.head_size) key_rot = key[..., : self.rotary_dim] key_pass = key[..., self.rotary_dim :] - key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style) + key_rot = self.apply_rotary_emb.forward_native( + key_rot, + cos, + sin, + ) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key @@ -347,13 +354,21 @@ class MRotaryEmbedding(RotaryEmbeddingBase): query = query.view(num_tokens, -1, self.head_size) query_rot = query[..., : self.rotary_dim] query_pass = query[..., self.rotary_dim :] - query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style) + query_rot = self.apply_rotary_emb( + query_rot, + cos, + sin, + ) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) key = key.view(num_tokens, -1, self.head_size) key_rot = key[..., : self.rotary_dim] key_pass = key[..., self.rotary_dim :] - key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style) + key_rot = self.apply_rotary_emb( + key_rot, + cos, + sin, + ) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key diff --git a/vllm/model_executor/layers/rotary_embedding/xdrope.py b/vllm/model_executor/layers/rotary_embedding/xdrope.py index 2432273faf195..dab7aad9759a2 100644 --- a/vllm/model_executor/layers/rotary_embedding/xdrope.py +++ b/vllm/model_executor/layers/rotary_embedding/xdrope.py @@ -4,7 +4,6 @@ import numpy as np import torch -from .common import apply_rotary_emb_dispatch from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding @@ -36,7 +35,7 @@ class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding): dtype, ) - def forward( + def forward_native( self, positions: torch.Tensor, query: torch.Tensor, @@ -68,14 +67,73 @@ class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding): query = query.view(num_tokens, -1, self.head_size) query_rot = query[..., : self.rotary_dim] query_pass = query[..., self.rotary_dim :] - query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style) + query_rot = self.apply_rotary_emb.forward_native( + query_rot, + cos, + sin, + ) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) key_shape = key.shape key = key.view(num_tokens, -1, self.head_size) key_rot = key[..., : self.rotary_dim] key_pass = key[..., self.rotary_dim :] - key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style) + key_rot = self.apply_rotary_emb.forward_native( + key_rot, + cos, + sin, + ) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None = None, + offsets: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """PyTorch-native implementation equivalent to forward(). + + Args: + positions: + [4, num_tokens] (P/W/H/T positions with multimodal inputs) + query: [num_tokens, num_heads * head_size] + key: [num_tokens, num_kv_heads * head_size] + """ + assert positions.ndim == 2 + assert key is not None + + num_tokens = positions.shape[-1] + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + cos = torch.cat( + [m[i] for i, m in enumerate(cos.split(self.xdrope_section, dim=-1))], dim=-1 + ) + sin = torch.cat( + [m[i] for i, m in enumerate(sin.split(self.xdrope_section, dim=-1))], dim=-1 + ) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = self.apply_rotary_emb( + query_rot, + cos, + sin, + ) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = self.apply_rotary_emb( + key_rot, + cos, + sin, + ) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index 74052f72ceab9..7f94bd234fd38 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -4,6 +4,7 @@ import os from collections.abc import Generator import gguf +import regex as re import torch import torch.nn as nn from huggingface_hub import hf_hub_download @@ -94,6 +95,7 @@ class GGUFModelLoader(BaseModelLoader): hasattr(config, "vision_config") and config.vision_config is not None ) gguf_to_hf_name_map = {} + sideload_params: list[re.Pattern] = [] # hack: ggufs have a different name than transformers if model_type == "cohere": model_type = "command-r" @@ -118,6 +120,12 @@ class GGUFModelLoader(BaseModelLoader): gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = ( f"model.layers.{idx}.mlp.experts.0.up_proj.weight" ) + sideload_params.append( + re.compile( + f"model\\.layers\\.{idx}" + r"\.mlp\.experts\.[0-9]+\.(gate|up|down)_proj\.weight" + ) + ) if model_type in ("qwen2_moe", "qwen3_moe"): model_type = model_type.replace("_", "") # GGUF layer map assumes that we will have a merged expert weights @@ -132,6 +140,12 @@ class GGUFModelLoader(BaseModelLoader): gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = ( f"model.layers.{idx}.mlp.experts.0.up_proj.weight" ) + sideload_params.append( + re.compile( + f"model\\.layers\\.{idx}" + r"\.mlp\.experts\.[0-9]+\.(gate|up|down)_proj\.weight" + ) + ) arch = None for key, value in gguf.MODEL_ARCH_NAMES.items(): @@ -241,7 +255,15 @@ class GGUFModelLoader(BaseModelLoader): # Parameter not in manual overrides either unmapped_params.append(hf_name) - # All parameters must be mapped: both vision/projector and backbone + # All parameters (except those initialized by other means) must be mapped: + # both vision/projector and backbone + if unmapped_params: + unmapped_params = list( + filter( + lambda x: not any(re.fullmatch(p, x) for p in sideload_params), + unmapped_params, + ) + ) if unmapped_params: raise RuntimeError( f"Failed to map GGUF parameters " diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index eeb2444150eef..74b02e4c62583 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -167,7 +167,6 @@ _MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]() def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]: from vllm.model_executor.models.adapters import ( as_embedding_model, - as_reward_model, as_seq_cls_model, try_create_mm_pooling_model_cls, ) @@ -207,9 +206,6 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], elif convert_type == "classify": logger.debug_once("Converting to sequence classification model.") model_cls = as_seq_cls_model(model_cls) - elif convert_type == "reward": - logger.debug_once("Converting to reward model.") - model_cls = as_reward_model(model_cls) else: assert_never(convert_type) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 0809bdfa9d4c2..610e6a620ade2 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -641,7 +641,6 @@ def safetensors_weights_iterator( if safetensors_load_strategy == "eager": loading_desc += " (eager)" - state_dict = {} leftover_state_dict: dict[str, torch.Tensor] = {} for st_file in tqdm( @@ -667,6 +666,7 @@ def safetensors_weights_iterator( ) with safe_open(st_file, framework="pt") as f: + state_dict = {} for name in f.keys(): # noqa: SIM118 state_dict[name] = f.get_tensor(name) @@ -921,7 +921,17 @@ def gguf_quant_weights_iterator( name = gguf_to_hf_name_map[tensor.name] if weight_type.name not in ("F32", "BF16", "F16"): name = name.replace("weight", "qweight") - param = torch.tensor(weight) + if weight_type.name == "BF16" and tensor.data.dtype == np.uint8: + # BF16 is currently the only "quantization" type that isn't + # actually quantized but is read as a raw byte tensor. + # Reinterpret as `torch.bfloat16` tensor. + weight = weight.view(np.uint16) + if reader.byte_order == "S": + # GGUF endianness != system endianness + weight = weight.byteswap() + param = torch.tensor(weight).view(torch.bfloat16) + else: + param = torch.tensor(weight) yield name, param diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 05f257feea3ee..504de9fe10871 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -175,9 +175,14 @@ def _create_pooling_model_cls(orig_cls: _T) -> _T: self.vllm_config = vllm_config # These are not used in pooling models - for attr in ("lm_head", "logits_processor"): - if hasattr(self, attr): - delattr(self, attr) + objects_to_clean = [self] + if language_model := getattr(self, "language_model", None): + objects_to_clean.append(language_model) + + for obj in objects_to_clean: + for attr in ("lm_head", "logits_processor"): + if hasattr(obj, attr): + delattr(obj, attr) # If the model already defines a pooler instance, don't overwrite it if not getattr(self, "pooler", None): @@ -332,6 +337,18 @@ def as_seq_cls_model(cls: _T) -> _T: tokens = getattr(text_config, "classifier_from_token", None) method = getattr(text_config, "method", None) + def auto_set_score_bias(weights): + for name, weight in weights: + if name == "score.bias": + device = self.score.weight.device + dtype = self.score.weight.dtype + bias = weight.to(device).to(dtype) + self.score.bias = torch.nn.Parameter(bias) + self.score.skip_bias_add = False + else: + yield name, weight + + weights = auto_set_score_bias(weights) if tokens is None and method is None: return super().load_weights(weights) else: @@ -346,44 +363,6 @@ def as_seq_cls_model(cls: _T) -> _T: return ModelForSequenceClassification # type: ignore -def as_reward_model(cls: _T) -> _T: - """ - Subclass an existing vLLM model to support reward modeling. - - By default, we return the hidden states of each token directly. - - Note: - We assume that no extra layers are added to the original model; - please implement your own model if this is not the case. - """ - # Avoid modifying existing reward models - if is_pooling_model(cls): - return cls - - # Lazy import - from vllm.model_executor.layers.pooler import DispatchPooler, Pooler - - from .interfaces_base import default_pooling_type - - @default_pooling_type("ALL") - class ModelForReward(_create_pooling_model_cls(cls)): - def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): - pooler_config = vllm_config.model_config.pooler_config - assert pooler_config is not None - - self.pooler = DispatchPooler( - { - "token_classify": Pooler.for_token_classify( - pooler_config=pooler_config - ) - } - ) - - ModelForReward.__name__ = _get_pooling_model_name(cls.__name__, "ForReward") - - return ModelForReward # type: ignore - - class SequenceClassificationConfig(VerifyAndUpdateConfig): @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: @@ -444,7 +423,7 @@ def load_weights_using_from_2_way_softmax( ) loaded_weights = pooling_model_cls.load_weights(model, weights, load_lm_head=True) - from vllm.transformers_utils.tokenizer import get_tokenizer + from vllm.tokenizers import get_tokenizer tokenizer = get_tokenizer( model_config.tokenizer, @@ -498,7 +477,7 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te # Skip ModelForSequenceClassification in MRO to avoid infinite recursion loaded_weights = type(model).__mro__[1].load_weights(model, weights) - from vllm.transformers_utils.tokenizer import get_tokenizer + from vllm.tokenizers import get_tokenizer tokenizer = get_tokenizer( model_config.tokenizer, diff --git a/vllm/model_executor/models/afmoe.py b/vllm/model_executor/models/afmoe.py index 85827d54c911a..f5dfe43067414 100644 --- a/vllm/model_executor/models/afmoe.py +++ b/vllm/model_executor/models/afmoe.py @@ -241,9 +241,8 @@ class AfmoeAttention(nn.Module): if self.is_local_attention: self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, - rope_parameters=config["rope_parameters"], + rope_parameters=config.rope_parameters, is_neox_style=True, ) else: diff --git a/vllm/model_executor/models/apertus.py b/vllm/model_executor/models/apertus.py index 4a69787af55e2..e3f97a718b0f4 100644 --- a/vllm/model_executor/models/apertus.py +++ b/vllm/model_executor/models/apertus.py @@ -148,8 +148,6 @@ class ApertusAttention(nn.Module): if head_dim is None: head_dim = self.hidden_size // self.total_num_heads self.head_dim = head_dim - # Phi models introduced a partial_rotary_factor parameter in the config - self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -228,11 +226,9 @@ class ApertusAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=int(self.partial_rotary_factor * self.head_dim), max_position=self.max_position_embeddings, rope_parameters=config.rope_parameters, is_neox_style=is_neox_style, - partial_rotary_factor=self.partial_rotary_factor, ) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 266d29a8d9b2b..0200984c0ec85 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -314,7 +314,6 @@ class ArcticAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=self.max_position_embeddings, rope_parameters=config.rope_parameters, is_neox_style=True, diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 3d07e6b612ca3..c6d7f19cbe90d 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -499,8 +499,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): model to perform tasks that involve both image and text inputs. """ - merge_by_field_config = True - hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ # mapping for new names in checkpoint saved after transformers v4.52 diff --git a/vllm/model_executor/models/audioflamingo3.py b/vllm/model_executor/models/audioflamingo3.py new file mode 100644 index 0000000000000..0ca5f2c4e0a75 --- /dev/null +++ b/vllm/model_executor/models/audioflamingo3.py @@ -0,0 +1,639 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The vLLM team. +# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights +# reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Iterable, Mapping, Sequence +from typing import Annotated, Any, Literal, TypeAlias + +import torch +import torch.nn as nn +from transformers import BatchFeature, PretrainedConfig +from transformers.models.audioflamingo3 import ( + AudioFlamingo3Config, + AudioFlamingo3Processor, +) +from transformers.models.qwen2_audio import Qwen2AudioEncoder + +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + DictEmbeddingItems, + ModalityData, + ModalityDataItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) +from .utils import ( + AutoWeightsLoader, + init_vllm_registered_model, + maybe_prefix, +) + +MAX_AUDIO_LEN = 10 * 60 + + +# === Audio Inputs === # +class AudioFlamingo3FeatureInputs(TensorSchema): + """ + Dimensions: + - num_chunks: Number of audio chunks (flattened) + - nmb: Number of mel bins + - num_audios: Number of original audio files + """ + + type: Literal["audio_features"] + input_features: Annotated[ + torch.Tensor | list[torch.Tensor], + TensorShape("num_chunks", "nmb", 3000), + ] + + feature_attention_mask: Annotated[ + torch.Tensor, + TensorShape("num_chunks", 3000), + ] + + chunk_counts: Annotated[ + torch.Tensor, + TensorShape("num_audios"), + ] + + +class AudioFlamingo3EmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size + - naf: Number of audio features + - hs: Hidden size (must match the hidden size of language model + backbone) + """ + + type: Literal["audio_embeds"] = "audio_embeds" + + audio_embeds: Annotated[ + list[torch.Tensor], + TensorShape("bn", "naf", "hs"), + ] + + +AudioFlamingo3Inputs: TypeAlias = ( + AudioFlamingo3FeatureInputs | AudioFlamingo3EmbeddingInputs +) + + +class AudioFlamingo3Encoder(Qwen2AudioEncoder): + def __init__( + self, + config: PretrainedConfig, + ): + super().__init__(config) + self.avg_pooler = nn.AvgPool1d(kernel_size=2, stride=2) + # self.layer_norm is already initialized in super().__init__ + + def forward( + self, + input_features: torch.Tensor | list[torch.Tensor], + attention_mask: torch.Tensor = None, + ): + # input_features: (batch, num_mel_bins, seq_len) + if isinstance(input_features, list): + input_features = torch.stack(input_features) + + hidden_states = nn.functional.gelu(self.conv1(input_features)) + hidden_states = nn.functional.gelu(self.conv2(hidden_states)) + hidden_states = hidden_states.transpose(-1, -2) + hidden_states = ( + hidden_states + self.embed_positions.weight[: hidden_states.size(-2), :] + ).to(hidden_states.dtype) + + for layer in self.layers: + layer_outputs = layer(hidden_states, attention_mask) + hidden_states = layer_outputs[0] + + # AvgPool (time/2) + LayerNorm + # hidden_states: (batch, seq_len, hidden_size) + hidden_states = hidden_states.permute(0, 2, 1) # (batch, hidden_size, seq_len) + hidden_states = self.avg_pooler(hidden_states) + hidden_states = hidden_states.permute( + 0, 2, 1 + ) # (batch, seq_len/2, hidden_size) + hidden_states = self.layer_norm(hidden_states) + + return hidden_states + + def _get_feat_extract_output_lengths(self, input_lengths: torch.Tensor): + """ + Computes the output length of the convolutional layers and the output length + of the audio encoder + """ + input_lengths = (input_lengths - 1) // 2 + 1 + output_lengths = (input_lengths - 2) // 2 + 1 + return input_lengths, output_lengths + + +class AudioFlamingo3MultiModalProjector(nn.Module): + def __init__(self, config: PretrainedConfig): + super().__init__() + self.linear_1 = nn.Linear( + config.audio_config.hidden_size, + config.text_config.hidden_size, + bias=config.projector_bias, + ) + self.act = get_act_fn(config.projector_hidden_act) + self.linear_2 = nn.Linear( + config.text_config.hidden_size, + config.text_config.hidden_size, + bias=config.projector_bias, + ) + + def forward(self, audio_features): + hidden_states = self.linear_1(audio_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class AudioFlamingo3ProcessingInfo(BaseProcessingInfo): + def get_hf_config(self): + return self.ctx.get_hf_config(AudioFlamingo3Config) + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(AudioFlamingo3Processor, **kwargs) + + def get_feature_extractor(self, **kwargs: object): + hf_processor = self.get_hf_processor(**kwargs) + feature_extractor = hf_processor.feature_extractor + return feature_extractor + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"audio": None} + + +class AudioFlamingo3DummyInputsBuilder( + BaseDummyInputsBuilder[AudioFlamingo3ProcessingInfo] +): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + hf_processor = self.info.get_hf_processor() + audio_token = hf_processor.audio_token + return audio_token * num_audios + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + feature_extractor = self.info.get_feature_extractor() + sampling_rate = feature_extractor.sampling_rate + audio_len = MAX_AUDIO_LEN * sampling_rate + num_audios = mm_counts.get("audio", 0) + audio_overrides = mm_options.get("audio") if mm_options else None + + return { + "audio": self._get_dummy_audios( + length=audio_len, + num_audios=num_audios, + overrides=audio_overrides, + ) + } + + +def _audioflamingo3_field_config(hf_inputs: Mapping[str, torch.Tensor]): + chunk_counts = hf_inputs.get("chunk_counts") + if chunk_counts is not None: + return dict( + audio_embeds=MultiModalFieldConfig.batched("audio"), + input_features=MultiModalFieldConfig.flat_from_sizes( + "audio", chunk_counts, dim=0 + ), + feature_attention_mask=MultiModalFieldConfig.flat_from_sizes( + "audio", chunk_counts, dim=0 + ), + chunk_counts=MultiModalFieldConfig.batched("audio"), + ) + return dict( + audio_embeds=MultiModalFieldConfig.batched("audio"), + input_features=MultiModalFieldConfig.batched("audio"), + feature_attention_mask=MultiModalFieldConfig.batched("audio"), + chunk_counts=MultiModalFieldConfig.batched("audio"), + ) + + +class AudioFlamingo3MultiModalDataParser(MultiModalDataParser): + def _parse_audio_data( + self, + data: dict[str, torch.Tensor] | ModalityData[Any], + ) -> ModalityDataItems[Any, Any] | None: + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality="audio", + required_fields={"audio_embeds"}, + fields_factory=_audioflamingo3_field_config, + ) + return super()._parse_audio_data(data) + + +class AudioFlamingo3MultiModalProcessor( + BaseMultiModalProcessor[AudioFlamingo3ProcessingInfo] +): + def _get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.info.get_feature_extractor() + return AudioFlamingo3MultiModalDataParser( + target_sr=feature_extractor.sampling_rate + ) + + def _call_hf_processor( + self, + prompt: str, + mm_data: dict[str, object], + mm_kwargs: Mapping[str, Any], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + audios = mm_data.pop("audios", []) + if audios: + mm_data["audio"] = audios + + if not mm_data.get("audio", []): + prompt_ids = self.info.get_tokenizer().encode(prompt) + prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + feature_extractor = self.info.get_feature_extractor(**mm_kwargs) + mm_kwargs = dict( + **mm_kwargs, + sampling_rate=feature_extractor.sampling_rate, + ) + + # Calculate chunk counts + audio_list = mm_data.get("audio") + if not isinstance(audio_list, list): + audio_list = [audio_list] + + chunk_counts = [] + sampling_rate = feature_extractor.sampling_rate + chunk_length = feature_extractor.chunk_length + window_size = int(sampling_rate * chunk_length) + # MAX_AUDIO_LEN is 10 * 60 in HF processor. + max_windows = int(MAX_AUDIO_LEN // chunk_length) + + for audio in audio_list: + # audio is numpy array or list + n_samples = len(audio) if isinstance(audio, list) else audio.shape[0] + + n_win = max(1, (n_samples + window_size - 1) // window_size) + if n_win > max_windows: + n_win = max_windows + chunk_counts.append(n_win) + + outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + + if "input_features_mask" in outputs: + outputs["feature_attention_mask"] = outputs.pop("input_features_mask") + + outputs["chunk_counts"] = torch.tensor(chunk_counts, dtype=torch.long) + + return outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return _audioflamingo3_field_config(hf_inputs) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + audio_token = getattr(processor, "audio_token", "<sound>") + audio_token_id = vocab.get(audio_token) + if audio_token_id is None: + # Fallback if not found, though it should be there + audio_token_id = processor.audio_token_id + + out_mm_data = out_mm_kwargs.get_data() + feature_attention_mask = out_mm_data.get("feature_attention_mask") + chunk_counts = out_mm_data.get("chunk_counts") + + def get_replacement_audioflamingo3(item_idx: int): + if feature_attention_mask is not None: + if chunk_counts is not None: + counts = ( + chunk_counts.tolist() + if isinstance(chunk_counts, torch.Tensor) + else chunk_counts + ) + start_idx = sum(counts[:item_idx]) + count = counts[item_idx] + end_idx = start_idx + count + + if isinstance(feature_attention_mask, list): + mask_list = feature_attention_mask[start_idx:end_idx] + if len(mask_list) > 0 and isinstance( + mask_list[0], torch.Tensor + ): + mask = torch.stack(mask_list) + else: + mask = torch.tensor(mask_list) + else: + mask = feature_attention_mask[start_idx:end_idx] + else: + # feature_attention_mask is list[Tensor] or Tensor + if isinstance(feature_attention_mask, list): + mask = feature_attention_mask[item_idx] + else: + mask = feature_attention_mask[item_idx].unsqueeze(0) + + # mask shape: (num_chunks, 3000) + input_lengths = mask.sum(-1) + conv_lengths = (input_lengths - 1) // 2 + 1 + audio_output_lengths = (conv_lengths - 2) // 2 + 1 + num_features = audio_output_lengths.sum().item() + else: + audio_embeds = out_mm_data["audio_embeds"][item_idx] + num_features = audio_embeds.shape[0] + + if num_features == 0: + raise ValueError("Audio is too short") + + audio_tokens = [audio_token_id] * int(num_features) + return PromptUpdateDetails.select_token_id( + audio_tokens, + embed_token_id=audio_token_id, + ) + + return [ + PromptReplacement( + modality="audio", + target=audio_token, + replacement=get_replacement_audioflamingo3, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor( + AudioFlamingo3MultiModalProcessor, + info=AudioFlamingo3ProcessingInfo, + dummy_inputs=AudioFlamingo3DummyInputsBuilder, +) +class AudioFlamingo3ForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA +): + """ + AudioFlamingo3 model for conditional generation. + + This model integrates a Whisper-based audio encoder with a Qwen2 language model. + It supports multi-chunk audio processing. + """ + + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model.", + connector="multi_modal_projector.", + tower_model="audio_tower.", + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.multimodal_config = multimodal_config + + self.audio_tower = AudioFlamingo3Encoder( + config.audio_config, + ) + self.multi_modal_projector = AudioFlamingo3MultiModalProjector(config) + + self.quant_config = quant_config + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Qwen2ForCausalLM"], + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + def _parse_and_validate_audio_input( + self, **kwargs: object + ) -> AudioFlamingo3Inputs | None: + input_features = kwargs.pop("input_features", None) + audio_embeds = kwargs.pop("audio_embeds", None) + feature_attention_mask = kwargs.pop("feature_attention_mask", None) + chunk_counts = kwargs.pop("chunk_counts", None) + + if input_features is None and audio_embeds is None: + return None + + if audio_embeds is not None: + return AudioFlamingo3EmbeddingInputs( + type="audio_embeds", audio_embeds=audio_embeds + ) + + if input_features is not None: + return AudioFlamingo3FeatureInputs( + type="audio_features", + input_features=input_features, + feature_attention_mask=feature_attention_mask, + chunk_counts=chunk_counts, + ) + + raise AssertionError("This line should be unreachable.") + + def _process_audio_input( + self, audio_input: AudioFlamingo3Inputs + ) -> torch.Tensor | tuple[torch.Tensor, ...]: + if audio_input["type"] == "audio_embeds": + audio_embeds = audio_input["audio_embeds"] + return tuple(audio_embeds) + + input_features = audio_input["input_features"] + feature_attention_mask = audio_input["feature_attention_mask"] + chunk_counts = audio_input.get("chunk_counts") + + if isinstance(input_features, list): + input_features = torch.cat(input_features, dim=0) + feature_attention_mask = torch.cat(feature_attention_mask, dim=0) + + if chunk_counts is None: + chunk_counts = [1] * input_features.shape[0] + elif isinstance(chunk_counts, torch.Tensor): + chunk_counts = chunk_counts.tolist() + elif ( + isinstance(chunk_counts, list) + and chunk_counts + and isinstance(chunk_counts[0], torch.Tensor) + ): + chunk_counts = [c.item() for c in chunk_counts] + + # Calculate output lengths + input_lengths = feature_attention_mask.sum(-1) + # Conv downsampling + conv_lengths = (input_lengths - 1) // 2 + 1 + # AvgPool downsampling + audio_output_lengths = (conv_lengths - 2) // 2 + 1 + + batch_size, _, max_mel_seq_len = input_features.shape + + # Calculate max_seq_len after convs (before pooling) for attention mask + max_seq_len = (max_mel_seq_len - 1) // 2 + 1 + + # Create a sequence tensor of shape (batch_size, max_seq_len) + seq_range = ( + torch.arange( + 0, + max_seq_len, + dtype=conv_lengths.dtype, + device=conv_lengths.device, + ) + .unsqueeze(0) + .expand(batch_size, max_seq_len) + ) + lengths_expand = conv_lengths.unsqueeze(-1).expand(batch_size, max_seq_len) + # Create mask + padding_mask = seq_range >= lengths_expand + + audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand( + batch_size, 1, max_seq_len, max_seq_len + ) + audio_attention_mask = audio_attention_mask_.to( + dtype=self.audio_tower.conv1.weight.dtype, + device=self.audio_tower.conv1.weight.device, + ) + audio_attention_mask[audio_attention_mask_] = float("-inf") + + # Forward pass + audio_features = self.audio_tower( + input_features, attention_mask=audio_attention_mask + ) + + # Project + audio_features = self.multi_modal_projector(audio_features) + + # Masking after pooling + num_audios, max_audio_tokens, embed_dim = audio_features.shape + audio_output_lengths = audio_output_lengths.unsqueeze(1) + audio_features_mask = ( + torch.arange(max_audio_tokens) + .expand(num_audios, max_audio_tokens) + .to(audio_output_lengths.device) + < audio_output_lengths + ) + masked_audio_features = audio_features[audio_features_mask].view(-1, embed_dim) + + # Split to tuple of embeddings for individual audio input. + chunk_embeddings = torch.split( + masked_audio_features, audio_output_lengths.flatten().tolist() + ) + + grouped_embeddings = [] + current_idx = 0 + for count in chunk_counts: + audio_chunks = chunk_embeddings[current_idx : current_idx + count] + grouped_embeddings.append(torch.cat(audio_chunks, dim=0)) + current_idx += count + return tuple(grouped_embeddings) + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: + audio_input = self._parse_and_validate_audio_input(**kwargs) + if audio_input is None: + return [] + masked_audio_features = self._process_audio_input(audio_input) + return masked_audio_features + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> torch.Tensor | IntermediateTensors: + if intermediate_tensors is not None: + inputs_embeds = None + + hidden_states = self.language_model.model( + input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + return self.language_model.compute_logits(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py index 0ada2ed5028bb..ee9e210a3240f 100644 --- a/vllm/model_executor/models/aya_vision.py +++ b/vllm/model_executor/models/aya_vision.py @@ -318,8 +318,6 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: dummy_inputs=AyaVisionDummyInputsBuilder, ) class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - merge_by_field_config = True - hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ # mapping for new names in checkpoint saved after transformers v4.52 diff --git a/vllm/model_executor/models/bagel.py b/vllm/model_executor/models/bagel.py new file mode 100644 index 0000000000000..98229c6d4ca1b --- /dev/null +++ b/vllm/model_executor/models/bagel.py @@ -0,0 +1,584 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 Bytedance Ltd. and/or its affiliates. +"""Inference-only BAGEL model compatible with HuggingFace weights. + +BAGEL is a unified multimodal model for image understanding and generation. +For vLLM, we focus on the image understanding (vision-to-text) capabilities. +""" + +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, Literal, TypeAlias + +import torch +import torch.nn as nn + +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, +) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.processors.bagel import BagelProcessor +from vllm.utils.tensor_schema import TensorSchema + +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) +from .siglip import SiglipVisionModel +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) + +logger = init_logger(__name__) + + +class BagelImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height of each image + - w: Width of each image + """ + + type: Literal["pixel_values"] + pixel_values: torch.Tensor # Shape: (bn, 3, h, w) + + +BagelImageInputs: TypeAlias = BagelImagePixelInputs + + +class BagelVisionMLP(nn.Module): + """MLP connector for vision features.""" + + def __init__( + self, + in_features: int, + hidden_features: int, + out_features: int, + act_layer: str = "gelu_pytorch_tanh", + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + self.fc1 = ColumnParallelLinear( + in_features, + hidden_features, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.act = get_act_fn(act_layer) + self.fc2 = RowParallelLinear( + hidden_features, + out_features, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.fc1(x) + x = self.act(x) + x, _ = self.fc2(x) + return x + + +class PositionEmbedding(nn.Module): + """2D position embedding for vision tokens using sin-cos embeddings.""" + + def __init__(self, max_num_patch_per_side: int, hidden_size: int): + super().__init__() + self.max_num_patch_per_side = max_num_patch_per_side + self.hidden_size = hidden_size + + # Create learnable 2D position embeddings (frozen sin-cos) + pos_embed = self._get_2d_sincos_pos_embed(hidden_size, max_num_patch_per_side) + self.register_buffer( + "pos_embed", + torch.from_numpy(pos_embed).float(), + persistent=False, + ) + + @staticmethod + def _get_2d_sincos_pos_embed(embed_dim: int, grid_size: int): + """Generate 2D sin-cos position embeddings.""" + import numpy as np + + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # w goes first + grid = np.stack(grid, axis=0) + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = PositionEmbedding._get_2d_sincos_pos_embed_from_grid( + embed_dim, grid + ) + return pos_embed + + @staticmethod + def _get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid): + """Generate 2D sin-cos position embeddings from grid.""" + import numpy as np + + assert embed_dim % 2 == 0 + # use half of dimensions to encode grid_h + emb_h = PositionEmbedding._get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[0] + ) + emb_w = PositionEmbedding._get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[1] + ) + emb = np.concatenate([emb_h, emb_w], axis=1) + return emb + + @staticmethod + def _get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos): + """Generate 1D sin-cos position embeddings.""" + import numpy as np + + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega + + pos = pos.reshape(-1) + out = np.einsum("m,d->md", pos, omega) + + emb_sin = np.sin(out) + emb_cos = np.cos(out) + emb = np.concatenate([emb_sin, emb_cos], axis=1) + return emb + + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + """ + Args: + position_ids: Flattened position IDs, shape (N,) where each ID + corresponds to a position in the flattened grid + Returns: + Position embeddings of shape (N, hidden_size) + """ + # Ensure position_ids are on the same device as pos_embed + position_ids = position_ids.to(self.pos_embed.device) + return self.pos_embed[position_ids] + + +class BagelProcessingInfo(BaseProcessingInfo): + """Processing information for BAGEL model.""" + + def get_hf_processor(self, **kwargs: object) -> BagelProcessor: + from vllm.transformers_utils.processor import cached_get_image_processor + + image_processor = cached_get_image_processor( + self.ctx.model_config.model, + revision=self.ctx.model_config.revision, + trust_remote_code=self.ctx.model_config.trust_remote_code, + ) + + tokenizer = self.get_tokenizer() + + return BagelProcessor( + image_processor=image_processor, + tokenizer=tokenizer, + **kwargs, + ) + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"image": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + hf_config = self.get_hf_config() + # Calculate max tokens per image + # For BAGEL: (vit_max_num_patch_per_side) ** 2 + max_num_patches = hf_config.vit_max_num_patch_per_side**2 + return {"image": max_num_patches} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + hf_config = self.get_hf_config() + vit_config = hf_config.vit_config + patch_size = vit_config.patch_size + + # Calculate number of patches + num_patches_h = image_height // patch_size + num_patches_w = image_width // patch_size + return num_patches_h * num_patches_w + + +class BagelDummyInputsBuilder(BaseDummyInputsBuilder[BagelProcessingInfo]): + """Build dummy inputs for BAGEL model profiling.""" + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + # Use a simple placeholder for each image + return "<|image_pad|>" * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + hf_config = self.info.get_hf_config() + vit_config = hf_config.vit_config + + # Use the configured image size + image_size = vit_config.image_size + image_overrides = mm_options.get("image") if mm_options else None + + return { + "image": self._get_dummy_images( + width=image_size, + height=image_size, + num_images=num_images, + overrides=image_overrides, + ), + } + + +class BagelMultiModalProcessor(BaseMultiModalProcessor[BagelProcessingInfo]): + """Multimodal processor for BAGEL model.""" + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + return False + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptReplacement]: + """Replace image placeholders with the correct number of tokens.""" + hf_config = self.info.get_hf_config() + + # Get the tokenizer to look up the image token ID + tokenizer = self.info.get_tokenizer() + image_token_id = tokenizer.get_vocab().get("<|image_pad|>") + if image_token_id is None: + raise ValueError( + "Image token '<|image_pad|>' not found in tokenizer vocabulary" + ) + + def get_replacement_bagel(item_idx: int): + # For BAGEL, calculate number of tokens based on max patch size + num_tokens = hf_config.vit_max_num_patch_per_side**2 + # Use the image token ID from tokenizer + return [image_token_id] * num_tokens + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=get_replacement_bagel, + ) + ] + + def _get_mm_fields_config( + self, + hf_inputs: Any, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return { + "pixel_values": MultiModalFieldConfig.batched("image"), + } + + +@MULTIMODAL_REGISTRY.register_processor( + BagelMultiModalProcessor, + info=BagelProcessingInfo, + dummy_inputs=BagelDummyInputsBuilder, +) +class BagelForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP +): + """ + BAGEL: A unified multimodal model for image understanding and generation. + + For vLLM, we focus on the image understanding (vision-to-text) capabilities. + The image generation part is not supported in vLLM. + """ + + # Weight mapping from HF to vLLM + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "language_model.": "language_model.", + "vit_model.": "vit_model.", + "connector.": "connector.", + "vit_pos_embed.": "vit_pos_embed.", + } + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + # Ensure we have a BagelConfig (check by name to handle trust_remote_code) + # When trust_remote_code=True, the config comes from transformers_modules + if type(config).__name__ != "BagelConfig": + raise ValueError( + f"Expected BagelConfig, got {type(config).__name__}. " + "Make sure the model config is properly loaded." + ) + + self.config = config + self.multimodal_config = multimodal_config + + # Initialize language model (Qwen2) + # Pass the llm_config from BagelConfig to initialize Qwen2 properly + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.llm_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Qwen2ForCausalLM"], + ) + + # Initialize vision model (SigLIP) if visual understanding is enabled + if config.visual_und: + # Fix vit_config: checkpoint has 26 layers (0-25) but config says 27 + # Also disable head as it's not in checkpoint + vit_config = config.vit_config + if vit_config.num_hidden_layers == 27: + logger.warning( + "Overriding vit_config.num_hidden_layers from 27 to 26 " + "to match the Bagel model checkpoint." + ) + vit_config.num_hidden_layers = 26 + if not hasattr(vit_config, "vision_use_head"): + logger.warning( + "Setting vit_config.vision_use_head to False as it is not " + "present in the Bagel model checkpoint." + ) + vit_config.vision_use_head = False + + self.vit_model = SiglipVisionModel( + config=vit_config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "vit_model"), + ) + + # Initialize connector (MLP) + vit_hidden_size = config.vit_config.hidden_size + llm_hidden_size = config.llm_config.hidden_size + + self.connector = BagelVisionMLP( + in_features=vit_hidden_size, + hidden_features=llm_hidden_size, + out_features=llm_hidden_size, + act_layer=config.connector_act, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "connector"), + ) + + # Position embedding for vision tokens + self.vit_pos_embed = PositionEmbedding( + max_num_patch_per_side=config.vit_max_num_patch_per_side, + hidden_size=llm_hidden_size, + ) + else: + self.vit_model = None + self.connector = None + self.vit_pos_embed = None + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + def _parse_and_validate_image_input( + self, **kwargs: object + ) -> BagelImageInputs | None: + pixel_values = kwargs.pop("pixel_values", None) + + if pixel_values is None: + return None + + return BagelImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + ) + + def _process_image_input( + self, image_input: BagelImageInputs + ) -> tuple[torch.Tensor, ...]: + """Process image inputs through vision encoder and connector.""" + pixel_values = image_input["pixel_values"] + + # Handle potential extra batch dimension + # Expected shape: (batch_size * num_images, 3, H, W) + # But might receive: (batch_size, num_images, 3, H, W) + if pixel_values.ndim == 5: + # Flatten batch and num_images dimensions + batch_size, num_images, channels, height, width = pixel_values.shape + pixel_values = pixel_values.reshape( + batch_size * num_images, channels, height, width + ) + + # Get vision features from SigLIP + # pixel_values shape: (batch_size * num_images, 3, H, W) + vision_features = self.vit_model(pixel_values) + + # Pass through connector + vision_embeds = self.connector(vision_features) + + # Add position embeddings + batch_size, num_patches, hidden_size = vision_embeds.shape + patch_size = self.config.vit_config.patch_size + image_size = self.config.vit_config.image_size + + # Calculate grid dimensions + num_patches_per_side = image_size // patch_size + + # Create flattened position IDs (0 to num_patches-1) + # For BAGEL, we use extrapolate mode by default + h_coords = torch.arange(num_patches_per_side, device=vision_embeds.device) + w_coords = torch.arange(num_patches_per_side, device=vision_embeds.device) + position_ids = ( + h_coords[:, None] * self.config.vit_max_num_patch_per_side + w_coords + ).flatten() + position_ids = position_ids.unsqueeze(0).expand(batch_size, -1).flatten() + + # Add position embeddings + pos_embeds = self.vit_pos_embed(position_ids) + pos_embeds = pos_embeds.reshape(batch_size, num_patches, hidden_size) + # Ensure pos_embeds are on the same device as vision_embeds + pos_embeds = pos_embeds.to(vision_embeds.device) + vision_embeds = vision_embeds + pos_embeds + + # Split by image + return tuple(vision_embeds) + + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + """Get multimodal embeddings from input.""" + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return [] + + return self._process_image_input(image_input) + + def get_language_model(self) -> nn.Module: + return self.language_model + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> torch.Tensor | IntermediateTensors: + """Run forward pass for BAGEL. + + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a batch. + positions: Flattened (concatenated) position ids corresponding to a batch. + intermediate_tensors: Intermediate tensors from prior forward pass. + inputs_embeds: Optional tensor of input embeddings. + """ + if intermediate_tensors is not None: + inputs_embeds = None + + hidden_states = self.language_model.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + return self.language_model.compute_logits(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights from checkpoint.""" + skip_prefixes = [] + # Skip vit_pos_embed.pos_embed as it's handled by PositionEmbedding module + skip_prefixes.append("vit_pos_embed.pos_embed") + + # If visual understanding is disabled, skip vision-related weights + if self.vit_model is None: + skip_prefixes.extend(["vit_model.", "connector.", "vit_pos_embed"]) + + # Skip generation-related weights since we only support text2text and image2text + # Filter out all image generation components: + # - 'moe_gen': MoE generation weights + # - 'latent_pos_embed': Latent position embeddings for VAE + # - 'llm2vae', 'vae2llm': LLM-VAE projections + # - 'time_embedder': Timestep embeddings for diffusion + # - VAE encoder/decoder: Use specific prefixes to avoid matching vision encoder + generation_keywords = [ + "moe_gen", + "latent_pos_embed", + "llm2vae", + "vae2llm", + "time_embedder", + ] + vae_prefixes = [ + "decoder.", + "encoder.", + ] # VAE encoder/decoder, not vision encoder + filtered_weights = [] + for name, tensor in weights: + # Skip generation-related keywords + if any(skip in name for skip in generation_keywords): + continue + if any(name.startswith(prefix) for prefix in vae_prefixes): + continue + + if "patch_embedding.weight" in name and tensor.ndim == 2: + out_channels = tensor.shape[0] + in_features = tensor.shape[1] + patch_size = self.config.vit_config.patch_size + in_channels = self.config.vit_config.num_channels + if in_features == in_channels * patch_size * patch_size: + tensor = tensor.reshape( + out_channels, patch_size, patch_size, in_channels + ) + tensor = tensor.permute(0, 3, 1, 2).contiguous() + + filtered_weights.append((name, tensor)) + + loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) + return loader.load_weights(filtered_weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index beb22995a0719..ee4a1dbd6df94 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -189,7 +189,6 @@ class BaiChuanAttention(nn.Module): else: self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=self.max_position_embeddings, rope_parameters=rope_parameters, ) diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py index f7a5d4e7889e5..4bccee7521749 100644 --- a/vllm/model_executor/models/bailing_moe.py +++ b/vllm/model_executor/models/bailing_moe.py @@ -127,17 +127,14 @@ class BailingAttention(nn.Module): prefix=f"{prefix}.dense", ) - self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) - - self.rotary_dim = getattr(config, "rotary_dim", self.head_dim) + rotary_dim = getattr(config, "rotary_dim", self.head_dim) + config.rope_parameters["partial_rotary_factor"] = rotary_dim / self.head_dim self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.rotary_dim, max_position=config.max_position_embeddings, rope_parameters=config.rope_parameters, is_neox_style=True, - partial_rotary_factor=self.partial_rotary_factor, ) self.attn = Attention( diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 1d6493b18c343..22631bbc5489b 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -178,16 +178,11 @@ class BambaAttentionDecoderLayer(nn.Module): self.scaling = self.head_dim**-0.5 self.max_position_embeddings = max_position_embeddings - if hasattr(config, "partial_rotary_factor"): - rotary_dim = int(self.head_dim * config.partial_rotary_factor) - elif hasattr(config, "attn_rotary_emb"): - rotary_dim = config.attn_rotary_emb # for backward compatibility - else: - rotary_dim = self.head_dim # default + rotary_dim = getattr(config, "attn_rotary_emb", self.head_dim) + config.rope_parameters["partial_rotary_factor"] = rotary_dim / self.head_dim self.rotary_emb = get_rope( head_size=self.head_dim, - rotary_dim=rotary_dim, max_position=max_position_embeddings, rope_parameters=config.rope_parameters, is_neox_style=True, diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index e774cd647ea8c..ee429bf458843 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -55,7 +55,9 @@ class BertEmbedding(nn.Module): "position_ids", torch.arange(config.max_position_embeddings).unsqueeze(0), ) - self.position_embedding_type = config.position_embedding_type + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) if self.position_embedding_type != "absolute": raise ValueError( "Only 'absolute' position_embedding_type" + " is supported" diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index f71b9c01d359d..1244f97a1bd68 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -523,8 +523,6 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]): class Blip2ForConditionalGeneration( nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant ): - merge_by_field_config = True - @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 3aa01bb1905fe..176c5cd14c6e2 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -314,7 +314,6 @@ class ChameleonAttention(nn.Module): self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim)) self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, rope_parameters=rope_parameters, ) @@ -918,8 +917,6 @@ class ChameleonModel(nn.Module): class ChameleonForConditionalGeneration( nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant ): - merge_by_field_config = True - packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 3d485fdd0a2e1..26181d1c9bae4 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -99,13 +99,16 @@ class GLMAttention(nn.Module): # https://huggingface.co/zai-org/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141 rope_ratio = getattr(config, "rope_ratio", 1.0) max_positions = getattr(config, "seq_length", 8192) - rope_parameters = {"rope_type": "default", "rope_theta": 10000 * rope_ratio} + rope_parameters = { + "rope_type": "default", + "rope_theta": 10000 * rope_ratio, + "partial_rotary_factor": 0.5, + } # NOTE: zai-org/cogagent-9b-20241220 uses original_rope=False, # which is equivalent to is_neox_style=True is_neox_style = not config.original_rope self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim // 2, max_position=max_positions, rope_parameters=rope_parameters, is_neox_style=is_neox_style, diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index b8af3050990bc..22f3ecad748e6 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -784,7 +784,6 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): is_pooling_model = True packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} - merge_by_field_config = True @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: diff --git a/vllm/model_executor/models/cohere2_vision.py b/vllm/model_executor/models/cohere2_vision.py index 139ccba9df6d8..07dc7a01dc316 100644 --- a/vllm/model_executor/models/cohere2_vision.py +++ b/vllm/model_executor/models/cohere2_vision.py @@ -331,8 +331,6 @@ class Cohere2VisionMultiModalProcessor( dummy_inputs=Cohere2VisionDummyInputsBuilder, ) class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - merge_by_field_config = True - hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "model.vision_tower.": "vision_tower.", diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index f837502c468f1..63a93eaa2d4f3 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -175,7 +175,6 @@ class CohereAttention(nn.Module): ) self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=self.max_position_embeddings, rope_parameters=config.rope_parameters, is_neox_style=False, diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index d7e802ba1aca0..4b08472538db4 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -4,11 +4,10 @@ from copy import deepcopy from math import lcm from typing import TYPE_CHECKING -import vllm.envs as envs +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.logger import init_logger from vllm.model_executor.models import ModelRegistry from vllm.platforms import current_platform -from vllm.transformers_utils.config import set_default_rope_theta from vllm.utils.math_utils import cdiv, round_up from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec @@ -43,9 +42,10 @@ class GteNewModelConfig(VerifyAndUpdateConfig): config.hidden_act = "geglu" head_dim = config.hidden_size // config.num_attention_heads + rotary_dim = getattr(config, "rotary_emb_dim", head_dim) + config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim config.rotary_kwargs = { "head_size": head_dim, - "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), "max_position": config.max_position_embeddings, "rope_parameters": config.rope_parameters, } @@ -78,11 +78,11 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig): if not model_config.enforce_eager: max_position = round_up(max_position, 8) - set_default_rope_theta(config, default_theta=config.rotary_emb_base) + rotary_dim = getattr(config, "rotary_emb_dim", head_dim) + config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim config.rotary_kwargs = { "head_size": head_dim, - "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), "max_position": max_position, "rope_parameters": config.rope_parameters, } @@ -116,14 +116,10 @@ class NomicBertModelConfig(VerifyAndUpdateConfig): config.num_hidden_layers = config.n_layer head_dim = config.hidden_size // config.num_attention_heads - rotary_emb_dim = int(head_dim * config.rotary_emb_fraction) max_trained_positions = getattr(config, "max_trained_positions", 2048) - set_default_rope_theta(config, default_theta=config.rotary_emb_base) - config.rotary_kwargs = { "head_size": head_dim, - "rotary_dim": rotary_emb_dim, "max_position": max_trained_positions, "rope_parameters": config.rope_parameters, } @@ -219,7 +215,7 @@ class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig): tokens = getattr(config, "classifier_from_token", None) assert tokens is not None and len(tokens) == 2, ( "Try loading the original Qwen3 Reranker?, see: " - "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py" + "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/offline_reranker.py" ) vllm_config.model_config.hf_config.method = "from_2_way_softmax" @@ -245,9 +241,10 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig): config.hidden_act = "geglu" head_dim = config.hidden_size // config.num_attention_heads + rotary_dim = getattr(config, "rotary_emb_dim", head_dim) + config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim config.rotary_kwargs = { "head_size": head_dim, - "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), "max_position": config.max_position_embeddings, "rope_parameters": config.rope_parameters, } @@ -336,6 +333,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): # Enable FULL_AND_PIECEWISE by default MambaModelConfig.verify_and_update_config(vllm_config) + attention_config = vllm_config.attention_config cache_config = vllm_config.cache_config model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config @@ -352,7 +350,9 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): # * CUTLASS_MLA backend: kernel_block_size 128 alignment # * Other MLA backends: kernel_block_size 64 alignment if model_config.use_mla: - use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA" + use_cutlass_mla = ( + attention_config.backend == AttentionBackendEnum.CUTLASS_MLA + ) kernel_block_alignment_size = 128 if use_cutlass_mla else 64 attn_page_size_1_token = MLAAttentionSpec( block_size=1, @@ -363,11 +363,11 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): else: kernel_block_alignment_size = 16 if ( - current_platform.is_device_capability(100) + current_platform.is_device_capability_family(100) and model_config.get_head_size() == 256 and ( - envs.VLLM_ATTENTION_BACKEND is None - or envs.VLLM_ATTENTION_BACKEND == "FLASHINFER" + attention_config.backend is None + or attention_config.backend == AttentionBackendEnum.FLASHINFER ) ): # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that` @@ -490,6 +490,26 @@ class DeepseekV32ForCausalLM(VerifyAndUpdateConfig): logger.info("Using bfloat16 kv-cache for DeepSeekV3.2") +class NemotronHForCausalLMConfig(VerifyAndUpdateConfig): + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + """Update mamba_ssm_cache_dtype for NemotronH models when set to 'auto' + (or not explicitly set), to the value specified in the HF config, or to + float16 if not specified. + """ + cache_config = vllm_config.cache_config + if cache_config.mamba_ssm_cache_dtype == "auto": + hf_config = vllm_config.model_config.hf_config + mamba_ssm_cache_dtype = getattr( + hf_config, "mamba_ssm_cache_dtype", "float16" + ) + logger.info( + "Updating mamba_ssm_cache_dtype to '%s' for NemotronH model", + mamba_ssm_cache_dtype, + ) + cache_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype + + MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { "GteModel": SnowflakeGteNewModelConfig, "GteNewModel": GteNewModelConfig, @@ -507,4 +527,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { "Mamba2ForCausalLM": MambaModelConfig, "FalconMambaForCausalLM": MambaModelConfig, "DeepseekV32ForCausalLM": DeepseekV32ForCausalLM, + "NemotronHForCausalLM": NemotronHForCausalLMConfig, } diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 946baffc8817a..db4fe61b0d85f 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -222,7 +222,6 @@ class DbrxAttention(nn.Module): ) self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=self.max_position, rope_parameters=rope_parameters, is_neox_style=True, diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 6e23037b919ab..ca77b8322e2e8 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -346,11 +346,16 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts): # Use expert_params_mapping to locate the destination # param and delegate to its expert-aware weight_loader # with expert_id. + is_expert_weight = False for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in chunk_name: continue + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + # Do not modify `name` since the loop may continue here # Instead, create a new variable name_mapped = chunk_name.replace(weight_name, param_name) @@ -377,6 +382,12 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts): loaded_params.add(name_mapped) break else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue diff --git a/vllm/model_executor/models/deepseek_ocr.py b/vllm/model_executor/models/deepseek_ocr.py index 8179f916ff417..1f07381c0cbd0 100644 --- a/vllm/model_executor/models/deepseek_ocr.py +++ b/vllm/model_executor/models/deepseek_ocr.py @@ -27,7 +27,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, + MultiModalKwargsItems, NestedTensors, ) from vllm.multimodal.parse import ( @@ -45,6 +45,7 @@ from vllm.multimodal.processing import ( from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors +from vllm.tokenizers import cached_tokenizer_from_config from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config from vllm.transformers_utils.processors.deepseek_ocr import ( BASE_SIZE, @@ -53,7 +54,6 @@ from vllm.transformers_utils.processors.deepseek_ocr import ( DeepseekOCRProcessor, count_tiles, ) -from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.v1.sample.logits_processor import ( AdapterLogitsProcessor, @@ -305,7 +305,7 @@ class DeepseekOCRMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) @@ -344,8 +344,6 @@ class DeepseekOCRMultiModalProcessor( dummy_inputs=DeepseekOCRDummyInputsBuilder, ) class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): - merge_by_field_config = True - hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ # map prefix for language backbone diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 73cac2556c55a..146124153c79d 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -83,6 +83,7 @@ from vllm.v1.attention.backends.mla.indexer import ( DeepseekV32IndexerMetadata, ) from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec +from vllm.v1.worker.workspace import current_workspace_manager from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP from .utils import ( @@ -156,7 +157,6 @@ class DeepseekAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, rope_parameters=config.rope_parameters, ) @@ -395,6 +395,16 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: return 0.1 * mscale * math.log(scale) + 1.0 +def _get_llama_4_scaling( + original_max_position_embeddings: int, scaling_beta: float, positions: torch.Tensor +) -> torch.Tensor: + scaling = 1 + scaling_beta * torch.log( + 1 + torch.floor(positions / original_max_position_embeddings) + ) + # Broadcast over num_heads and head_dim + return scaling[..., None, None] + + class DeepseekV2Attention(nn.Module): def __init__( self, @@ -481,17 +491,23 @@ class DeepseekV2Attention(nn.Module): prefix=f"{prefix}.o_proj", ) if config.rope_parameters["rope_type"] != "default": - config.rope_parameters["rope_type"] = "deepseek_yarn" + config.rope_parameters["rope_type"] = ( + "deepseek_yarn" + if config.rope_parameters.get("apply_yarn_scaling", True) + else "deepseek_llama_scaling" + ) self.rotary_emb = get_rope( qk_rope_head_dim, - rotary_dim=qk_rope_head_dim, max_position=max_position_embeddings, rope_parameters=config.rope_parameters, is_neox_style=False, ) - if config.rope_parameters["rope_type"] != "default": + if ( + config.rope_parameters["rope_type"] != "default" + and config.rope_parameters["rope_type"] == "deepseek_yarn" + ): mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False) scaling_factor = config.rope_parameters["factor"] mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) @@ -511,6 +527,7 @@ class DeepseekV2Attention(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, + llama_4_scaling: torch.Tensor | None, ) -> torch.Tensor: if self.q_lora_rank is not None: q = self.q_a_proj(hidden_states)[0] @@ -536,6 +553,11 @@ class DeepseekV2Attention(nn.Module): k = torch.empty_like(q) k[..., : self.qk_nope_head_dim] = k_nope k[..., self.qk_nope_head_dim :] = k_pe + + # Apply llama 4 scaling if provided + if llama_4_scaling is not None: + q *= llama_4_scaling + # padding value to qk_head_dim for alignment v = torch.nn.functional.pad( v, [0, self.qk_head_dim - self.v_head_dim], value=0 @@ -595,8 +617,15 @@ def sparse_attn_indexer( # careful! this will be None in dummy run attn_metadata = get_forward_context().attn_metadata fp8_dtype = current_platform.fp8_dtype() + # assert isinstance(attn_metadata, dict) if not isinstance(attn_metadata, dict): + # Reserve workspace for indexer during profiling run + current_workspace_manager().get_simultaneous( + ((total_seq_lens, head_dim), torch.float8_e4m3fn), + ((total_seq_lens, 4), torch.uint8), + ) + return sparse_attn_indexer_fake( hidden_states, k_cache_prefix, @@ -630,17 +659,17 @@ def sparse_attn_indexer( topk_indices_buffer[: hidden_states.shape[0]] = -1 if has_prefill: prefill_metadata = attn_metadata.prefill + + # Get the full shared workspace buffers once (will allocate on first use) + workspace_manager = current_workspace_manager() + k_fp8_full, k_scale_full = workspace_manager.get_simultaneous( + ((total_seq_lens, head_dim), fp8_dtype), + ((total_seq_lens, 4), torch.uint8), + ) + for chunk in prefill_metadata.chunks: - k_fp8 = torch.empty( - [chunk.total_seq_lens, head_dim], - device=k.device, - dtype=fp8_dtype, - ) - k_scale = torch.empty( - [chunk.total_seq_lens, 4], - device=k.device, - dtype=torch.uint8, - ) + k_fp8 = k_fp8_full[: chunk.total_seq_lens] + k_scale = k_scale_full[: chunk.total_seq_lens] ops.cp_gather_indexer_k_quant_cache( kv_cache, k_fp8, @@ -661,11 +690,10 @@ def sparse_attn_indexer( chunk.cu_seqlen_ke, ) num_rows = logits.shape[0] - assert topk_tokens == 2048, "top_k_per_row assumes size 2048" topk_indices = topk_indices_buffer[ chunk.token_start : chunk.token_end, :topk_tokens ] - torch.ops._C.top_k_per_row( + torch.ops._C.top_k_per_row_prefill( logits, chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, @@ -673,6 +701,7 @@ def sparse_attn_indexer( num_rows, logits.stride(0), logits.stride(1), + topk_tokens, ) if has_decode: @@ -715,7 +744,6 @@ def sparse_attn_indexer( max_model_len=max_model_len, ) num_rows = logits.shape[0] - assert topk_tokens == 2048, "top_k_per_row assumes size 2048" topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens] torch.ops._C.top_k_per_row_decode( @@ -726,6 +754,7 @@ def sparse_attn_indexer( num_rows, logits.stride(0), logits.stride(1), + topk_tokens, ) if decode_metadata.requires_padding: # if padded, we need to unpack @@ -756,15 +785,6 @@ def sparse_attn_indexer_fake( total_seq_lens: int, topk_indices_buffer: torch.Tensor | None, ) -> torch.Tensor: - # profile run - # NOTE(Chen): create the max possible flattened_kv. So that - # profile_run can get correct memory usage. - _flattened_kv = torch.empty( - [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8 - ) - fp8_dtype = current_platform.fp8_dtype() - _k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous() - _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous() return topk_indices_buffer @@ -987,15 +1007,23 @@ class DeepseekV2MLAAttention(nn.Module): ) if config.rope_parameters["rope_type"] != "default": - config.rope_parameters["rope_type"] = "deepseek_yarn" + config.rope_parameters["rope_type"] = ( + "deepseek_yarn" + if config.rope_parameters.get("apply_yarn_scaling", True) + else "deepseek_llama_scaling" + ) + self.rotary_emb = get_rope( qk_rope_head_dim, - rotary_dim=qk_rope_head_dim, max_position=max_position_embeddings, rope_parameters=config.rope_parameters, is_neox_style=False, ) - if config.rope_parameters["rope_type"] != "default": + + if ( + config.rope_parameters["rope_type"] != "default" + and config.rope_parameters["rope_type"] == "deepseek_yarn" + ): mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False) scaling_factor = config.rope_parameters["factor"] mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) @@ -1006,7 +1034,6 @@ class DeepseekV2MLAAttention(nn.Module): if self.is_v32: self.indexer_rope_emb = get_rope( qk_rope_head_dim, - rotary_dim=qk_rope_head_dim, max_position=max_position_embeddings, rope_parameters=config.rope_parameters, is_neox_style=True, @@ -1064,8 +1091,9 @@ class DeepseekV2MLAAttention(nn.Module): self, positions: torch.Tensor, hidden_states: torch.Tensor, + llama_4_scaling: torch.Tensor | None, ) -> torch.Tensor: - return self.mla_attn(positions, hidden_states) + return self.mla_attn(positions, hidden_states, llama_4_scaling) class DeepseekV2DecoderLayer(nn.Module): @@ -1102,6 +1130,8 @@ class DeepseekV2DecoderLayer(nn.Module): dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim) ) + self.use_mha = use_mha + if use_mha: attn_cls = DeepseekAttention elif model_config.use_mla: @@ -1155,6 +1185,7 @@ class DeepseekV2DecoderLayer(nn.Module): positions: torch.Tensor, hidden_states: torch.Tensor, residual: torch.Tensor | None, + llama_4_scaling: torch.Tensor | None = None, ) -> torch.Tensor: # Self Attention if residual is None: @@ -1162,10 +1193,14 @@ class DeepseekV2DecoderLayer(nn.Module): hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm(hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - ) + + attn_kwargs = { + "positions": positions, + "hidden_states": hidden_states, + } + if not self.use_mha: + attn_kwargs["llama_4_scaling"] = llama_4_scaling + hidden_states = self.self_attn(**attn_kwargs) if ( not isinstance(self.self_attn, DeepseekAttention) @@ -1266,8 +1301,24 @@ class DeepseekV2Model(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] + # Compute llama 4 scaling once per forward pass if enabled + llama_4_scaling_config = getattr(self.config, "llama_4_scaling", None) + llama_4_scaling: torch.Tensor | None + if llama_4_scaling_config is not None: + llama_4_scaling = _get_llama_4_scaling( + original_max_position_embeddings=llama_4_scaling_config[ + "original_max_position_embeddings" + ], + scaling_beta=llama_4_scaling_config["beta"], + positions=positions, + ) + else: + llama_4_scaling = None + for layer in islice(self.layers, self.start_layer, self.end_layer): - hidden_states, residual = layer(positions, hidden_states, residual) + hidden_states, residual = layer( + positions, hidden_states, residual, llama_4_scaling + ) if not get_pp_group().is_last_rank: return IntermediateTensors( @@ -1325,6 +1376,7 @@ class DeepseekV2ForCausalLM( packed_modules_mapping = { "gate_up_proj": ["gate_proj", "up_proj"], } + model_cls = DeepseekV2Model def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -1355,7 +1407,7 @@ class DeepseekV2ForCausalLM( "kv_a_proj_with_mqa", ] - self.model = DeepseekV2Model( + self.model = self.model_cls( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) if get_pp_group().is_last_rank: diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index 1b6e4110039c4..9f8faf9ed91ce 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -41,13 +41,13 @@ from vllm.multimodal.processing import ( ) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.tokenizers import cached_tokenizer_from_config from vllm.transformers_utils.configs.deepseek_vl2 import ( DeepseekVLV2Config, MlpProjectorConfig, VisionEncoderConfig, ) from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor -from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.torch_utils import set_default_torch_dtype @@ -344,8 +344,6 @@ class DeepseekVL2MultiModalProcessor( dummy_inputs=DeepseekVL2DummyInputsBuilder, ) class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): - merge_by_field_config = True - hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "language.": "language_model.", diff --git a/vllm/model_executor/models/dots1.py b/vllm/model_executor/models/dots1.py index 3beee9f864634..870a37039f151 100644 --- a/vllm/model_executor/models/dots1.py +++ b/vllm/model_executor/models/dots1.py @@ -250,7 +250,6 @@ class Dots1Attention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, rope_parameters=config.rope_parameters, ) diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py index 5cc2a48f26d64..6d8dbec9236c9 100644 --- a/vllm/model_executor/models/dots_ocr.py +++ b/vllm/model_executor/models/dots_ocr.py @@ -5,15 +5,14 @@ from typing import Annotated, Literal, TypeAlias import torch import torch.nn as nn -import torch.nn.functional as F from torch.nn import LayerNorm from transformers.models.qwen2_vl import Qwen2VLProcessor from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.attention.layer import ( - maybe_get_vit_flash_attn_backend, +from vllm.attention.layers.mm_encoder_attention import ( + MMEncoderAttention, ) -from vllm.config import VllmConfig +from vllm.config import MultiModalConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import utils as dist_utils from vllm.distributed.parallel_state import ( @@ -30,6 +29,9 @@ from vllm.model_executor.layers.linear import ( RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding.common import ( + ApplyRotaryEmb, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import ( MultiModalEmbeddings, @@ -159,32 +161,6 @@ class DotsOCRProcessingInfo(Qwen2VLProcessingInfo): return processor -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb_vision( - tensor: torch.Tensor, freqs: torch.Tensor -) -> torch.Tensor: - orig_dtype = tensor.dtype - tensor = tensor.float() - - cos = freqs.cos() - sin = freqs.sin() - - cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - - output = (tensor * cos) + (rotate_half(tensor) * sin) - - output = output.to(orig_dtype) - - return output - - class VisionRotaryEmbedding(nn.Module): def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() @@ -254,11 +230,15 @@ class DotsVisionAttention(nn.Module): bias: bool = True, *, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.embed_dim = dim self.tp_size = ( @@ -287,31 +267,18 @@ class DotsVisionAttention(nn.Module): prefix=f"{prefix}.proj", disable_tp=use_data_parallel, ) - # Select attention backend - self.attn_backend = get_vit_attn_backend( - self.hidden_size_per_attention_head, - torch.get_default_dtype(), - attn_backend_override=attn_backend_override, + + self.attn = MMEncoderAttention( + num_heads=self.num_attention_heads_per_partition, + head_size=self.hidden_size_per_attention_head, + multimodal_config=multimodal_config, + prefix=f"{prefix}.attn", ) - self.attn_backend, self.flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - attn_backend_override=attn_backend_override, - ) + self.apply_rotary_emb = ApplyRotaryEmb( + enforce_enable=True, + enable_fp32_compute=True, ) - if self.attn_backend not in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.TORCH_SDPA, - AttentionBackendEnum.ROCM_AITER_FA, - }: - raise RuntimeError( - f"Unsupported vision attention backend: {self.attn_backend}" - ) - self.is_flash_attn_backend = self.attn_backend in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.ROCM_AITER_FA, - } def forward( self, @@ -319,7 +286,7 @@ class DotsVisionAttention(nn.Module): cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor | None = None, *, - max_seqlen: int | None = None, + max_seqlen: torch.Tensor | None = None, ) -> torch.Tensor: # [S, C] -> [S, B=1, C] x = hidden_states.unsqueeze(1) @@ -333,44 +300,20 @@ class DotsVisionAttention(nn.Module): if rotary_pos_emb is not None: qk_concat = torch.cat([q, k], dim=0) - qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) + qk_rotated = self.apply_rotary_emb( + qk_concat, + rotary_pos_emb.cos(), + rotary_pos_emb.sin(), + ) q, k = torch.chunk(qk_rotated, 2, dim=0) - if self.is_flash_attn_backend: - q_ = q.reshape(bs * q.shape[1], q.shape[2], q.shape[3]) - k_ = k.reshape(bs * k.shape[1], k.shape[2], k.shape[3]) - v_ = v.reshape(bs * v.shape[1], v.shape[2], v.shape[3]) - output = self.flash_attn_varlen_func( - q_, - k_, - v_, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0.0, - causal=False, - ) - context_layer = output.view( - bs, - -1, - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - ) - elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: - outputs = [] - for i in range(1, len(cu_seqlens)): - s = int(cu_seqlens[i - 1]) - e = int(cu_seqlens[i]) - q_i = q[:, s:e].permute(0, 2, 1, 3) - k_i = k[:, s:e].permute(0, 2, 1, 3) - v_i = v[:, s:e].permute(0, 2, 1, 3) - out_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) - out_i = out_i.permute(0, 2, 1, 3) - outputs.append(out_i) - context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0] - else: - raise RuntimeError("Unsupported attention backend") + context_layer = self.attn( + query=q, + key=k, + value=v, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) # [B,S,H,D] -> [S,B,H*D] -> [S, C] context_layer = context_layer.permute(1, 0, 2, 3).contiguous() @@ -385,14 +328,19 @@ class DotsSwiGLUFFN(nn.Module): config, *, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, ): super().__init__() hidden_features = config.intermediate_size in_features = config.embed_dim bias = config.use_bias + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) # Referenced aimv2.py AIMv2SwiGLUFFN self.fc13 = MergedColumnParallelLinear( in_features, @@ -498,9 +446,8 @@ class DotsVisionBlock(nn.Module): config, *, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() @@ -510,16 +457,15 @@ class DotsVisionBlock(nn.Module): num_heads=config.num_attention_heads, bias=config.use_bias, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.attn", - use_data_parallel=use_data_parallel, - attn_backend_override=attn_backend_override, ) self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) self.mlp = DotsSwiGLUFFN( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.mlp", - use_data_parallel=use_data_parallel, ) self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) @@ -546,12 +492,11 @@ class DotsVisionTransformer(nn.Module): self, config: DotsVisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, *, num_hidden_layers_override: int | None = None, require_post_norm: bool | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() self.config = config @@ -561,6 +506,11 @@ class DotsVisionTransformer(nn.Module): head_dim = config.embed_dim // config.num_attention_heads self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend + if multimodal_config is not None + else None + ) self.attn_backend = get_vit_attn_backend( head_size=head_dim, dtype=torch.get_default_dtype(), @@ -578,9 +528,8 @@ class DotsVisionTransformer(nn.Module): DotsVisionBlock( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.blocks.{i}", - use_data_parallel=use_data_parallel, - attn_backend_override=attn_backend_override, ) for i in range(num_layers) ] @@ -592,6 +541,11 @@ class DotsVisionTransformer(nn.Module): else: self.post_trunk_norm = None + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.merger = PatchMerger( dim=config.hidden_size, context_dim=config.embed_dim, @@ -647,7 +601,7 @@ class DotsVisionTransformer(nn.Module): self.attn_backend == AttentionBackendEnum.FLASH_ATTN or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() return max_seqlen def forward( @@ -690,8 +644,6 @@ class DotsVisionTransformer(nn.Module): dummy_inputs=DotsOCRDummyInputsBuilder, ) class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): - merge_by_field_config = True - hf_to_vllm_mapper = WeightsMapper( orig_to_new_substr={ ".attn.qkv_proj.": ".attn.qkv.", @@ -735,17 +687,12 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA self.config.vision_config = vision_config else: vision_config = self.config.vision_config - attn_backend_override = ( - multimodal_config.mm_encoder_attn_backend - if multimodal_config is not None - else None - ) + self.vision_tower = DotsVisionTransformer( vision_config, quant_config=self.quant_config, + multimodal_config=multimodal_config, prefix=maybe_prefix(prefix, "vision_tower"), - use_data_parallel=self.use_data_parallel, - attn_backend_override=attn_backend_override, ) self.language_model: Qwen2ForCausalLM = init_vllm_registered_model( vllm_config=vllm_config, diff --git a/vllm/model_executor/models/ernie45_moe.py b/vllm/model_executor/models/ernie45_moe.py index 278ba45e9684c..fbbd31a485383 100644 --- a/vllm/model_executor/models/ernie45_moe.py +++ b/vllm/model_executor/models/ernie45_moe.py @@ -288,7 +288,6 @@ class Ernie4_5_MoeAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, rope_parameters=rope_parameters, is_neox_style=False, diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 81663dd7bbb45..61cf78fdb5a67 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -33,14 +33,14 @@ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange, repeat +from einops import rearrange from transformers import BatchFeature from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.attention.layer import ( - maybe_get_vit_flash_attn_backend, +from vllm.attention.layers.mm_encoder_attention import ( + MMEncoderAttention, ) -from vllm.config import VllmConfig +from vllm.config import MultiModalConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils @@ -53,6 +53,9 @@ from vllm.model_executor.layers.linear import ( RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding.common import ( + ApplyRotaryEmb, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( @@ -69,7 +72,6 @@ from vllm.multimodal.processing import ( PromptUpdate, ) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape @@ -89,52 +91,6 @@ logger = init_logger(__name__) # === Vision Transformer === # -def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: - if not interleaved: - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - else: - x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange( - torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 - ) - - -def apply_rotary_emb_torch( - x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False -) -> torch.Tensor: - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) - """ - ro_dim = cos.shape[-1] * 2 - assert ro_dim <= x.shape[-1] - cos = repeat( - cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" - ) - sin = repeat( - sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" - ) - return torch.cat( - [ - x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, - x[..., ro_dim:], - ], - dim=-1, - ) - - -def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: - t_ = t.float() - cos = freqs.cos() - sin = freqs.sin() - apply_rotary_emb = apply_rotary_emb_torch - if current_platform.is_cuda(): - from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb - output = apply_rotary_emb(t_, cos, sin).type_as(t) - return output - - def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int): """All-gather the input tensor interleavely across model parallel group.""" import torch.distributed as dist @@ -163,8 +119,8 @@ class Ernie4_5_VisionAttention(nn.Module): num_heads: int, projection_size: int, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() # Per attention head and per partition values. @@ -193,33 +149,18 @@ class Ernie4_5_VisionAttention(nn.Module): prefix=f"{prefix}.proj", ) - # Detect attention implementation. - self.attn_backend = get_vit_attn_backend( + self.attn = MMEncoderAttention( + num_heads=self.num_attention_heads_per_partition, head_size=self.hidden_size_per_attention_head, - dtype=torch.get_default_dtype(), - attn_backend_override=attn_backend_override, + multimodal_config=multimodal_config, + prefix=f"{prefix}.attn", ) - self.attn_backend, self.flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - attn_backend_override=attn_backend_override, - ) + self.apply_rotary_emb = ApplyRotaryEmb( + enforce_enable=True, + enable_fp32_compute=True, ) - if self.attn_backend not in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.TORCH_SDPA, - AttentionBackendEnum.ROCM_AITER_FA, - }: - raise RuntimeError( - f"Ernie45-VL does not support {self.attn_backend} backend now." - ) - self.is_flash_attn_backend = self.attn_backend in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.ROCM_AITER_FA, - } - def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # [s, b, 3 * head * head_dim] seq_len, bs, _ = qkv.shape @@ -253,58 +194,32 @@ class Ernie4_5_VisionAttention(nn.Module): x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, - max_seqlen: int | None = None, # Only used for Flash Attention + max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] q, k, v = self.split_qkv(x) - batch_size = q.shape[1] q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) if rotary_pos_emb is not None: qk_concat = torch.cat([q, k], dim=0) - qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) + qk_rotated = self.apply_rotary_emb( + qk_concat, + rotary_pos_emb.cos(), + rotary_pos_emb.sin(), + ) q, k = torch.chunk(qk_rotated, 2, dim=0) - if self.is_flash_attn_backend: - q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) - - output = self.flash_attn_varlen_func( - q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0.0, - causal=False, - ) - - context_layer = rearrange( - output, "(b s) h d -> s b (h d)", b=batch_size - ).contiguous() - elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: - # Execute attention entry by entry for speed & less VRAM. - outputs = [] - for i in range(1, len(cu_seqlens)): - start_idx = cu_seqlens[i - 1] - end_idx = cu_seqlens[i] - q_i = q[:, start_idx:end_idx] - k_i = k[:, start_idx:end_idx] - v_i = v[:, start_idx:end_idx] - q_i, k_i, v_i = ( - rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i] - ) - output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) - output_i = rearrange(output_i, "b h s d -> b s h d ") - outputs.append(output_i) - context_layer = torch.cat(outputs, dim=1) - context_layer = rearrange( - context_layer, "b s h d -> s b (h d)" - ).contiguous() + output = self.attn( + query=q, + key=k, + value=v, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + context_layer = rearrange(output, "b s h d -> s b (h d)").contiguous() output, _ = self.proj(context_layer) return output @@ -350,8 +265,8 @@ class Ernie4_5_VisionBlock(nn.Module): act_layer: type[nn.Module] = QuickGELU, norm_layer: Callable[[int], nn.Module] | None = None, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -366,8 +281,8 @@ class Ernie4_5_VisionBlock(nn.Module): num_heads=num_heads, projection_size=dim, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.attn", - attn_backend_override=attn_backend_override, ) self.mlp = Ernie4_5_VisionMLP( @@ -383,7 +298,7 @@ class Ernie4_5_VisionBlock(nn.Module): hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, - max_seqlen: int | None = None, # Only used for Flash Attention + max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention ) -> torch.Tensor: hidden_states = hidden_states + self.attn( self.norm1(hidden_states), @@ -441,8 +356,8 @@ class Ernie4_5_VisionTransformer(nn.Module): vision_config, norm_eps: float = 1e-6, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() patch_size = vision_config.patch_size @@ -477,8 +392,8 @@ class Ernie4_5_VisionTransformer(nn.Module): mlp_ratio=mlp_ratio, norm_layer=norm_layer, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.blocks.{layer_idx}", - attn_backend_override=attn_backend_override, ) for layer_idx in range(depth) ] @@ -489,6 +404,9 @@ class Ernie4_5_VisionTransformer(nn.Module): ) self.ln = nn.LayerNorm(hidden_size, eps=1e-6) + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend if multimodal_config else None + ) self.attn_backend = get_vit_attn_backend( head_size=head_dim, dtype=torch.get_default_dtype(), @@ -535,13 +453,13 @@ class Ernie4_5_VisionTransformer(nn.Module): rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb - def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None: + def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor | None: max_seqlen = None if ( self.attn_backend == AttentionBackendEnum.FLASH_ATTN or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() return max_seqlen def forward( @@ -1254,8 +1172,6 @@ class Ernie4_5_VLDummyInputsBuilder(BaseDummyInputsBuilder[Ernie4_5_VLProcessing class Ernie4_5_VLMoeForConditionalGeneration( nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE ): - merge_by_field_config = True - packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -1306,17 +1222,12 @@ class Ernie4_5_VLMoeForConditionalGeneration( self.config = config self.multimodal_config = multimodal_config - attn_backend_override = ( - multimodal_config.mm_encoder_attn_backend - if multimodal_config is not None - else None - ) self.vision_model = Ernie4_5_VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), quant_config=quant_config, + multimodal_config=multimodal_config, prefix=maybe_prefix(prefix, "vision_model"), - attn_backend_override=attn_backend_override, ) self.language_model = Ernie4_5_VLMoeForCausalLM( diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index acf651ed24988..039e7cf68e52b 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -167,7 +167,6 @@ class ExaoneAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, rope_parameters=config.rope_parameters, is_neox_style=is_neox_style, diff --git a/vllm/model_executor/models/exaone4.py b/vllm/model_executor/models/exaone4.py index cb710a7ec5cf9..b4b7a798fd050 100644 --- a/vllm/model_executor/models/exaone4.py +++ b/vllm/model_executor/models/exaone4.py @@ -176,7 +176,6 @@ class Exaone4Attention(nn.Module): set_default_rope_theta(config, default_theta=1000000) self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, rope_parameters=config.rope_parameters, is_neox_style=is_neox_style, diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 32d9e7b925597..7cdfcae0e718d 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -167,7 +167,6 @@ class FalconAttention(nn.Module): max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, rope_parameters=config.rope_parameters, ) diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index 83ceb9303cfb5..bfb6b1a1f160d 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -242,16 +242,11 @@ class FalconH1AttentionDecoderLayer(nn.Module): self.scaling = self.head_dim**-0.5 self.max_position_embeddings = max_position_embeddings - if hasattr(config, "partial_rotary_factor"): - rotary_dim = self.head_dim * config.partial_rotary_factor - elif hasattr(config, "attn_rotary_emb"): - rotary_dim = config.attn_rotary_emb # for backward compatibility - else: - rotary_dim = self.head_dim # default + rotary_dim = getattr(config, "attn_rotary_emb", self.head_dim) + config.rope_parameters["partial_rotary_factor"] = rotary_dim / self.head_dim self.rotary_emb = get_rope( head_size=self.head_dim, - rotary_dim=rotary_dim, max_position=max_position_embeddings, rope_parameters=config.rope_parameters, is_neox_style=True, diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 269c36ab5b9c7..8a7a3dd771c38 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -260,8 +260,6 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): dummy_inputs=FuyuDummyInputsBuilder, ) class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): - merge_by_field_config = True - hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "model.vision_embed_tokens.": "vision_embed_tokens.", diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index dd5a74c8ed005..7304a728067f4 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -174,7 +174,6 @@ class GemmaAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, rope_parameters=rope_parameters, is_neox_style=True, diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index cb36e04824588..fe6ec5ff83dec 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -152,7 +152,6 @@ class Gemma2Attention(nn.Module): ) self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, rope_parameters=config.rope_parameters, is_neox_style=True, diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index 73176eba95ed5..40f6d100c767e 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -176,7 +176,6 @@ class Gemma3Attention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, rope_parameters=rope_parameters, is_neox_style=True, diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 43c69e5e13992..45dfacd94431c 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -237,8 +237,9 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): ) max_num_crops = images_kwargs["pan_and_scan_max_num_crops"] - # Result in the max possible feature size (h:w = max_num_crops:1) - return ImageSize(height=50 * max_num_crops, width=50) + vision_config = self.get_hf_config().vision_config + native_size = vision_config.image_size + return ImageSize(height=native_size * max_num_crops, width=native_size) class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]): @@ -483,8 +484,6 @@ class Gemma3MultiModalProjector(nn.Module): class Gemma3ForConditionalGeneration( nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA ): - merge_by_field_config = True - packed_modules_mapping = { "qkv_proj": [ "q_proj", diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py index f4427c9fd1d10..4d446f51c2ecb 100644 --- a/vllm/model_executor/models/gemma3n.py +++ b/vllm/model_executor/models/gemma3n.py @@ -384,7 +384,6 @@ class Gemma3nAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, rope_parameters=rope_parameters, is_neox_style=True, diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py index 6ae76976eb46c..7036118ada084 100644 --- a/vllm/model_executor/models/gemma3n_mm.py +++ b/vllm/model_executor/models/gemma3n_mm.py @@ -463,7 +463,6 @@ class Gemma3nMultimodalEmbedder(nn.Module): class Gemma3nForConditionalGeneration( nn.Module, SupportsMultiModal, SupportsTranscription ): - merge_by_field_config = True supported_languages = ISO639_1_SUPPORTED_LANGS packed_modules_mapping = { diff --git a/vllm/model_executor/models/glm.py b/vllm/model_executor/models/glm.py index a6991f8e43fef..26d7c29aae6e2 100644 --- a/vllm/model_executor/models/glm.py +++ b/vllm/model_executor/models/glm.py @@ -10,7 +10,8 @@ from .utils import PPMissingLayer class GlmForCausalLM(LlamaForCausalLM): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - vllm_config.model_config.hf_config.partial_rotary_factor = 0.5 + hf_config = vllm_config.model_config.hf_config + hf_config.rope_parameters["partial_rotary_factor"] = 0.5 super().__init__(vllm_config=vllm_config, prefix=prefix) # Hack Llama model to fit HF format GLM implementation # Attention difference between GLM and Llama: diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py index 002cdb721e1db..2cd11e66c752b 100644 --- a/vllm/model_executor/models/glm4.py +++ b/vllm/model_executor/models/glm4.py @@ -78,10 +78,9 @@ class Glm4Attention(nn.Module): # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. assert tp_size % self.total_num_kv_heads == 0 - partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5) + config.rope_parameters.setdefault("partial_rotary_factor", 0.5) self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.head_dim = head_dim or hidden_size // self.total_num_heads - self.rotary_dim = self.head_dim self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -103,10 +102,8 @@ class Glm4Attention(nn.Module): ) self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.rotary_dim, max_position=max_position, rope_parameters=config.rope_parameters, - partial_rotary_factor=partial_rotary_factor, is_neox_style=False, ) self.attn = Attention( diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 5ba3c0a35928d..84989537da6e2 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -47,8 +47,10 @@ from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor from transformers.video_utils import VideoMetadata from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.attention.layer import maybe_get_vit_flash_attn_backend -from vllm.config import VllmConfig +from vllm.attention.layers.mm_encoder_attention import ( + MMEncoderAttention, +) +from vllm.config import MultiModalConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state from vllm.distributed import utils as dist_utils @@ -63,6 +65,9 @@ from vllm.model_executor.layers.linear import ( ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.rotary_embedding.common import ( + ApplyRotaryEmb, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY @@ -93,7 +98,7 @@ from .interfaces import ( SupportsMultiModal, SupportsPP, ) -from .qwen2_vl import _create_qwen2vl_field_factory, apply_rotary_pos_emb_vision +from .qwen2_vl import _create_qwen2vl_field_factory from .utils import ( AutoWeightsLoader, WeightsMapper, @@ -191,10 +196,15 @@ class Glm4vVisionMLP(nn.Module): hidden_features: int, bias: bool = False, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, ): super().__init__() + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.gate_up_proj = MergedColumnParallelLinear( input_size=in_features, output_sizes=[hidden_features] * 2, @@ -248,12 +258,16 @@ class Glm4vVisionAttention(nn.Module): num_heads: int, projection_size: int, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() # Per attention head and per partition values. + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.tp_size = ( 1 if use_data_parallel else get_tensor_model_parallel_world_size() ) @@ -287,33 +301,13 @@ class Glm4vVisionAttention(nn.Module): disable_tp=use_data_parallel, ) - # Detect attention implementation. - self.attn_backend = get_vit_attn_backend( + self.attn = MMEncoderAttention( + num_heads=self.num_attention_heads_per_partition, head_size=self.hidden_size_per_attention_head, - dtype=torch.get_default_dtype(), - attn_backend_override=attn_backend_override, + multimodal_config=multimodal_config, ) - self.attn_backend, self.flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - attn_backend_override=attn_backend_override, - ) - ) - - if self.attn_backend not in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.TORCH_SDPA, - AttentionBackendEnum.ROCM_AITER_FA, - }: - raise RuntimeError( - f"GLM-4V does not support {self.attn_backend} backend now." - ) - - self.is_flash_attn_backend = self.attn_backend in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.ROCM_AITER_FA, - } + self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True) def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # [s, b, 3 * head * head_dim] @@ -338,61 +332,33 @@ class Glm4vVisionAttention(nn.Module): cu_seqlens: torch.Tensor, rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_sin: torch.Tensor, - max_seqlen: int | None = None, # Only used for Flash Attention + max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] q, k, v = self.split_qkv(x) - batch_size = q.shape[1] q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None: # [2 * b, s, heads, head_dim] qk_concat = torch.cat([q, k], dim=0) - qk_rotated = apply_rotary_pos_emb_vision( - qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin + qk_rotated = self.apply_rotary_emb( + qk_concat, + rotary_pos_emb_cos, + rotary_pos_emb_sin, ) q, k = torch.chunk(qk_rotated, 2, dim=0) - if self.is_flash_attn_backend: - q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) - - output = self.flash_attn_varlen_func( - q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0.0, - causal=False, - ) - - context_layer = rearrange( - output, "(b s) h d -> s b (h d)", b=batch_size - ).contiguous() - elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: - # Execute attention entry by entry for speed & less VRAM. - outputs = [] - for i in range(1, len(cu_seqlens)): - start_idx = cu_seqlens[i - 1] - end_idx = cu_seqlens[i] - q_i = q[:, start_idx:end_idx] - k_i = k[:, start_idx:end_idx] - v_i = v[:, start_idx:end_idx] - q_i, k_i, v_i = ( - rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i] - ) - output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) - output_i = rearrange(output_i, "b h s d -> b s h d ") - outputs.append(output_i) - context_layer = torch.cat(outputs, dim=1) - context_layer = rearrange( - context_layer, "b s h d -> s b (h d)" - ).contiguous() + context_layer = self.attn( + query=q, + key=k, + value=v, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous() output, _ = self.proj(context_layer) return output @@ -406,9 +372,8 @@ class Glm4vVisionBlock(nn.Module): mlp_hidden_dim: int, norm_layer: Callable[[int], nn.Module] | None = None, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() if norm_layer is None: @@ -420,17 +385,16 @@ class Glm4vVisionBlock(nn.Module): num_heads=num_heads, projection_size=dim, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.attn", - use_data_parallel=use_data_parallel, - attn_backend_override=attn_backend_override, ) self.mlp = Glm4vVisionMLP( dim, mlp_hidden_dim, bias=False, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.mlp", - use_data_parallel=use_data_parallel, ) def forward( @@ -489,11 +453,16 @@ class Glm4vPatchMerger(nn.Module): d_model: int, context_dim: int, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, bias: bool = False, prefix: str = "", - use_data_parallel: bool = False, ) -> None: super().__init__() + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.hidden_size = d_model self.proj = ColumnParallelLinear( self.hidden_size, @@ -649,19 +618,19 @@ class Glm4vVisionTransformer(nn.Module): vision_config: Glm4vVisionConfig, norm_eps: float = 1e-6, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() + assert multimodal_config is not None, "multimodal_config must be provided" + patch_size = vision_config.patch_size temporal_patch_size = vision_config.temporal_patch_size in_channels = vision_config.in_channels depth = vision_config.depth self.hidden_size = vision_config.hidden_size self.num_heads = vision_config.num_heads - self.use_data_parallel = use_data_parallel self.patch_size = vision_config.patch_size self.spatial_merge_size = vision_config.spatial_merge_size @@ -678,9 +647,9 @@ class Glm4vVisionTransformer(nn.Module): head_dim = self.hidden_size // self.num_heads self.rotary_pos_emb = get_rope( head_size=head_dim, - rotary_dim=head_dim // 2, max_position=8192, is_neox_style=True, + rope_parameters={"partial_rotary_factor": 0.5}, ) self.blocks = nn.ModuleList( [ @@ -690,9 +659,8 @@ class Glm4vVisionTransformer(nn.Module): mlp_hidden_dim=vision_config.out_hidden_size, norm_layer=norm_layer, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.blocks.{layer_idx}", - use_data_parallel=self.use_data_parallel, - attn_backend_override=attn_backend_override, ) for layer_idx in range(depth) ] @@ -701,9 +669,9 @@ class Glm4vVisionTransformer(nn.Module): d_model=vision_config.out_hidden_size, context_dim=vision_config.intermediate_size, quant_config=quant_config, + multimodal_config=multimodal_config, bias=False, prefix=f"{prefix}.merger", - use_data_parallel=self.use_data_parallel, ) self.embeddings = Glm4vVisionEmbeddings(vision_config) @@ -723,7 +691,7 @@ class Glm4vVisionTransformer(nn.Module): self.attn_backend = get_vit_attn_backend( head_size=head_dim, dtype=torch.get_default_dtype(), - attn_backend_override=attn_backend_override, + attn_backend_override=multimodal_config.mm_encoder_attn_backend, ) @property @@ -775,22 +743,22 @@ class Glm4vVisionTransformer(nn.Module): def compute_attn_mask_seqlen( self, cu_seqlens: torch.Tensor, - ) -> int | None: + ) -> torch.Tensor | None: max_seqlen = None if ( self.attn_backend == AttentionBackendEnum.FLASH_ATTN or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() return max_seqlen def forward( self, x: torch.Tensor, - grid_thw: list[list[int]], + grid_thw: torch.Tensor | list[list[int]], ) -> torch.Tensor: - # Convert grid_thw to tensor (always expecting list format now) - grid_thw = torch.tensor(grid_thw, device=x.device, dtype=torch.long) + if isinstance(grid_thw, list): + grid_thw = torch.tensor(grid_thw, dtype=torch.int32) # patchify x = x.to(device=self.device, dtype=self.dtype) @@ -805,7 +773,8 @@ class Glm4vVisionTransformer(nn.Module): cu_seqlens = torch.repeat_interleave( grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] ).cumsum(dim=0, dtype=torch.int32) - cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) + cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens]) + cu_seqlens = cu_seqlens.to(self.device, non_blocking=True) # pre-compute max_seqlen for attn mask to reduce cuMemcpy operations max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens) @@ -1256,6 +1225,7 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]): ) height = min(height, overrides.height) + num_frames = max(num_frames, 2) # GLM 4.6V requires 2 frames video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8) video_items = [] for i in range(num_videos): @@ -1424,8 +1394,6 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]): class Glm4vForConditionalGeneration( nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE ): - merge_by_field_config = True - packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -1465,18 +1433,12 @@ class Glm4vForConditionalGeneration( self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" - attn_backend_override = ( - multimodal_config.mm_encoder_attn_backend - if multimodal_config is not None - else None - ) self.visual = Glm4vVisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-5), quant_config=quant_config, + multimodal_config=multimodal_config, prefix=maybe_prefix(prefix, "visual"), - use_data_parallel=self.use_data_parallel, - attn_backend_override=attn_backend_override, ) if config.model_type == "glm4v": @@ -1550,7 +1512,6 @@ class Glm4vForConditionalGeneration( ) -> tuple[torch.Tensor, ...]: grid_thw = image_input["image_grid_thw"] assert grid_thw.ndim == 2 - grid_thw_list = grid_thw.tolist() if image_input["type"] == "image_embeds": image_embeds = image_input["image_embeds"].type(self.visual.dtype) @@ -1561,12 +1522,10 @@ class Glm4vForConditionalGeneration( self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d" ) else: - image_embeds = self.visual(pixel_values, grid_thw=grid_thw.tolist()) + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + merge_size = self.visual.spatial_merge_size - sizes = ( - torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) - // (merge_size * merge_size) - ).tolist() + sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() return image_embeds.split(sizes) def _process_video_input( @@ -1574,7 +1533,6 @@ class Glm4vForConditionalGeneration( ) -> tuple[torch.Tensor, ...]: grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 - grid_thw_list = grid_thw.tolist() if video_input["type"] == "video_embeds": video_embeds = video_input["video_embeds"].type(self.visual.dtype) @@ -1590,15 +1548,11 @@ class Glm4vForConditionalGeneration( rope_type="rope_3d", ) else: - video_embeds = self.visual( - pixel_values_videos, grid_thw=grid_thw.tolist() - ) + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) + # Split concatenated embeddings for each video item. merge_size = self.visual.spatial_merge_size - sizes = ( - torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) - // (merge_size * merge_size) - ).tolist() + sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() return video_embeds.split(sizes) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index c99f824e1bd4d..541d3b2beff83 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -282,13 +282,11 @@ class Glm4MoeAttention(nn.Module): prefix=f"{prefix}.o_proj", ) - partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5) + config.rope_parameters.setdefault("partial_rotary_factor", 0.5) self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, rope_parameters=config.rope_parameters, - partial_rotary_factor=partial_rotary_factor, ) self.attn = Attention( self.num_heads, diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 514082cf60ce2..ec5af94e297c1 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -561,8 +561,6 @@ class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]): class GLM4VForCausalLM( ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE ): - merge_by_field_config = True - packed_modules_mapping = { "query_key_value": ["query_key_value"], "dense_h_to_4h": ["dense_h_to_4h"], diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index f0a34c47da54c..f32ac2639435c 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -95,12 +95,13 @@ class GPTJAttention(nn.Module): scaling = self.head_size**-0.5 assert getattr(config, "rotary", True) assert config.rotary_dim % 2 == 0 + rope_parameters = getattr(config, "rope_parameters", {}) + rope_parameters["partial_rotary_factor"] = config.rotary_dim / self.head_size max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.rotary_emb = get_rope( self.head_size, - rotary_dim=config.rotary_dim, max_position=max_position_embeddings, - rope_parameters=getattr(config, "rope_parameters", None), + rope_parameters=rope_parameters, is_neox_style=False, ) self.attn = Attention( diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index b9959682cbcef..c4d11b488f38b 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -89,16 +89,13 @@ class GPTNeoXAttention(nn.Module): quant_config=quant_config, prefix=f"{prefix}.dense", ) - scaling = self.head_size**-0.5 - rotary_dim = int(self.head_size * config.rotary_pct) - assert rotary_dim % 2 == 0 max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.rotary_emb = get_rope( self.head_size, - rotary_dim=rotary_dim, max_position=max_position_embeddings, rope_parameters=config.rope_parameters, ) + scaling = self.head_size**-0.5 self.attn = Attention( self.num_heads, self.head_size, diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index cff16b7a7a8cd..6a92cf1533213 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -67,7 +67,6 @@ class OAIAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=config.max_position_embeddings, dtype=torch.float32, rope_parameters={ diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 76519c4660f15..82c945f5ad5ec 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -160,7 +160,6 @@ class GraniteAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, rope_parameters=config.rope_parameters, ) diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py index 1797adab8d146..a4e50f4086281 100644 --- a/vllm/model_executor/models/granite_speech.py +++ b/vllm/model_executor/models/granite_speech.py @@ -59,8 +59,8 @@ from vllm.multimodal.processing import ( ) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.processor import cached_get_processor -from vllm.transformers_utils.tokenizer import cached_get_tokenizer +from vllm.tokenizers import cached_tokenizer_from_config +from vllm.transformers_utils.processor import cached_processor_from_config from vllm.utils.tensor_schema import TensorSchema, TensorShape from .blip2 import Blip2QFormerModel @@ -564,7 +564,6 @@ class GraniteSpeechForConditionalGeneration( SupportsLoRA, SupportsTranscription, ): - merge_by_field_config = True supported_languages = ISO639_1_SUPPORTED_LANGS packed_modules_mapping = { @@ -862,7 +861,7 @@ class GraniteSpeechForConditionalGeneration( else: raise ValueError(f"Unsupported task type {task_type}") - tokenizer = cached_get_tokenizer(model_config.model) + tokenizer = cached_tokenizer_from_config(model_config) chat = [dict(role="user", content=user_prompt)] prompt = tokenizer.apply_chat_template( chat, @@ -886,7 +885,7 @@ class GraniteSpeechForConditionalGeneration( model_config: ModelConfig, ) -> int | None: """Get the number of audio tokens for an audio duration in sec.""" - processor = cached_get_processor(model_config.model) + processor = cached_processor_from_config(model_config) hop_length = processor.audio_processor.melspec_kwargs["hop_length"] proj_win_size = processor.audio_processor.projector_window_size ds_rate = processor.audio_processor.projector_downsample_rate diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index b038400a1262a..0b1064b6343e3 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -190,7 +190,6 @@ class GraniteMoeAttention(nn.Module): ) self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position, rope_parameters=rope_parameters, is_neox_style=True, diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index 1d9c2f5df4a55..3434716b83789 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -271,7 +271,6 @@ class GraniteMoeHybridAttention(nn.Module): if config.position_embedding_type == "rope": self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=config.max_position_embeddings, rope_parameters=config.rope_parameters, is_neox_style=True, diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 181c4ed2dca5a..2aba626a7c737 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -14,12 +14,10 @@ from vllm.model_executor.layers.pooler import ( PoolerHead, PoolerNormalize, PoolingParamsUpdate, - get_prompt_lens, - get_prompt_token_ids, ) from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.tasks import PoolingTask -from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config +from vllm.tokenizers import cached_tokenizer_from_config from vllm.v1.outputs import PoolerOutput from vllm.v1.pool.metadata import PoolingMetadata @@ -153,11 +151,11 @@ class GritLMMeanPool(nn.Module): hidden_states: torch.Tensor | list[torch.Tensor], pooling_metadata: PoolingMetadata, ) -> list[torch.Tensor] | torch.Tensor: - prompt_lens = get_prompt_lens(hidden_states, pooling_metadata) + prompt_lens = pooling_metadata.prompt_lens instr_lens = torch.tensor( [ self._get_instruction_len(token_ids.cpu().numpy()) - for token_ids in get_prompt_token_ids(pooling_metadata) + for token_ids in pooling_metadata.get_prompt_token_ids() ], device="cpu", ) diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py index 6f62a1d11e52e..0a2e5cf39ffd8 100644 --- a/vllm/model_executor/models/grok1.py +++ b/vllm/model_executor/models/grok1.py @@ -181,7 +181,6 @@ class Grok1Attention(nn.Module): ) self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position, rope_parameters=rope_parameters, is_neox_style=True, diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py index ccdfa3fe175f1..0e82e84c4edbe 100644 --- a/vllm/model_executor/models/hunyuan_v1.py +++ b/vllm/model_executor/models/hunyuan_v1.py @@ -199,7 +199,6 @@ class HunYuanAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, rope_parameters=config.rope_parameters, is_neox_style=True, @@ -305,7 +304,6 @@ class HunYuanCrossAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, rope_parameters=config.rope_parameters, is_neox_style=True, diff --git a/vllm/model_executor/models/hunyuan_vision.py b/vllm/model_executor/models/hunyuan_vision.py index 2950db571e6ee..be084f4ee0f8e 100644 --- a/vllm/model_executor/models/hunyuan_vision.py +++ b/vllm/model_executor/models/hunyuan_vision.py @@ -62,6 +62,7 @@ from vllm.multimodal.inputs import ( from vllm.multimodal.parse import ( DictEmbeddingItems, ImageSize, + ModalityDataItems, MultiModalDataItems, MultiModalDataParser, ) @@ -501,6 +502,7 @@ class HunYuanVisionTransformer(nn.Module): cu_seqlens: list = [0] hidden_states = x.to(device=self.device, dtype=self.dtype) + # embeddings = patch_embeds + patch_pos_embed hidden_states = self.embeddings(hidden_states, grid_thw) for t, h, w in grid_thw: @@ -514,8 +516,14 @@ class HunYuanVisionTransformer(nn.Module): hidden_states = hidden_states.reshape(seq_len, -1) hidden_states = hidden_states.unsqueeze(0) - for layer_num, layer in enumerate(self.layers): - hidden_states = layer(hidden_states) + + # build per-image lengths once + split_lengths = [int(h) * int(w) for (_, h, w) in grid_thw] + for layer in self.layers: + # hidden_states: (1, T_total, D) + parts = hidden_states.split(split_lengths, dim=1) # list of (1, L_i, D) + parts = [layer(p) for p in parts] + hidden_states = torch.cat(parts, dim=1) # adapter split_lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() @@ -562,7 +570,7 @@ def _hunyuan_vl_field_config(hf_inputs: Mapping[str, torch.Tensor]): return dict( pixel_values=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes), image_embeds=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes), - image_grid_thw=MultiModalFieldConfig.batched("image"), + image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True), ) @@ -570,7 +578,7 @@ class HunYuanVLMultiModalDataParser(MultiModalDataParser): def _parse_image_data( self, data: dict[str, torch.Tensor] | ModalityData[ImageItem], - ): + ) -> ModalityDataItems[Any, Any] | None: if isinstance(data, dict): return DictEmbeddingItems( data, @@ -785,8 +793,6 @@ class HunYuanVLForConditionalGeneration( SupportsQuant, SupportsXDRoPE, ): - multimodal_cpu_fields = {"image_grid_thw"} - # To ensure correct weight loading and mapping. hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py index db46353efde5c..3a083870e4b5a 100644 --- a/vllm/model_executor/models/hyperclovax_vision.py +++ b/vllm/model_executor/models/hyperclovax_vision.py @@ -592,8 +592,6 @@ class HCXVisionCAbstractor(nn.Module): dummy_inputs=HCXVisionDummyInputsBuilder, ) class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): - merge_by_field_config = True - packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 9c5f9389e54bb..0eed464487865 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -338,6 +338,7 @@ class Idefics3MultiModalProcessor(BaseMultiModalProcessor[Idefics3ProcessingInfo prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + mm_kwargs = {"input_data_format": "channels_last", **mm_kwargs} processed_outputs = super()._call_hf_processor( prompt, mm_data, @@ -575,8 +576,6 @@ class Idefics3Model(nn.Module): dummy_inputs=Idefics3DummyInputsBuilder, ) class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA): - merge_by_field_config = True - packed_modules_mapping = { "qkv_proj": [ "q_proj", diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index ccd5be42e65a9..cb99d57e8b8c7 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -53,6 +53,22 @@ The output embeddings must be one of the following formats: """ +def _require_is_multimodal(is_multimodal: Tensor | None) -> Tensor: + """ + A helper function to be used in the context of + [vllm.model_executor.models.interfaces.SupportsMultiModal.embed_input_ids][] + to provide a better error message. + """ + if is_multimodal is None: + raise ValueError( + "`embed_input_ids` now requires `is_multimodal` arg, " + "please update your model runner according to " + "https://github.com/vllm-project/vllm/pull/16229." + ) + + return is_multimodal + + @runtime_checkable class SupportsMultiModal(Protocol): """The interface required for all multi-modal models.""" @@ -78,15 +94,15 @@ class SupportsMultiModal(Protocol): `multimodal_config.mm_encoder_tp_mode="data"`. """ - merge_by_field_config: ClassVar[bool] = False + merge_by_field_config: ClassVar[bool | None] = None """ - A flag that indicates which implementation of + [DEPRECATED] A flag that indicates which implementation of `vllm.multimodal.utils.group_mm_kwargs_by_modality` to use. """ - multimodal_cpu_fields: ClassVar[Set[str]] = frozenset() + multimodal_cpu_fields: ClassVar[Set[str] | None] = None """ - A set indicating CPU-only multimodal fields. + [DEPRECATED] A set indicating CPU-only multimodal fields. """ _processor_factory: ClassVar[_ProcessorFactories] @@ -111,13 +127,7 @@ class SupportsMultiModal(Protocol): the appearances of their corresponding multimodal data item in the input prompt. """ - if hasattr(self, "get_multimodal_embeddings"): - logger.warning_once( - "`get_multimodal_embeddings` for vLLM models is deprecated and will be " - "removed in v0.13.0 or v1.0.0, whichever is earlier. Please rename " - "this method to `embed_multimodal`." - ) - return self.get_multimodal_embeddings(**kwargs) + ... def get_language_model(self) -> VllmModel: """ @@ -196,17 +206,10 @@ class SupportsMultiModal(Protocol): if multimodal_embeddings is None or len(multimodal_embeddings) == 0: return inputs_embeds - if is_multimodal is None: - raise ValueError( - "`embed_input_ids` now requires `is_multimodal` arg, " - "please update your model runner according to " - "https://github.com/vllm-project/vllm/pull/16229." - ) - return _merge_multimodal_embeddings( inputs_embeds=inputs_embeds, multimodal_embeddings=multimodal_embeddings, - is_multimodal=is_multimodal, + is_multimodal=_require_is_multimodal(is_multimodal), ) @@ -260,7 +263,35 @@ def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]: ... def supports_multimodal( model: type[object] | object, ) -> TypeIs[type[SupportsMultiModal]] | TypeIs[SupportsMultiModal]: - return getattr(model, "supports_multimodal", False) + res = getattr(model, "supports_multimodal", False) + + if res: + # We can remove this starting from v0.14 + merge_by_field_config = getattr(model, "merge_by_field_config", None) + if merge_by_field_config is False: + raise ValueError( + "`merge_by_field_config=False` is no longer effective, " + "please update your model to consider the new batching logic " + "in `group_mm_kwargs_by_modality` (refer to " + "https://github.com/vllm-project/vllm/issues/26149), " + "and then remove the override from your model." + ) + if merge_by_field_config is True: + logger.warning_once( + "`merge_by_field_config=True` is redundant, " + "please remove the override from your model." + ) + + multimodal_cpu_fields = getattr(model, "multimodal_cpu_fields", None) + if multimodal_cpu_fields is not None: + raise ValueError( + "`multimodal_cpu_fields` is no longer effective, " + "please set `keep_on_cpu=True` in `MultiModalFieldConfig` " + "(refer to https://github.com/vllm-project/vllm/pull/30181), " + "and then remove the override from your model." + ) + + return res def supports_multimodal_raw_input_only(model: type[object] | object) -> bool: @@ -837,6 +868,10 @@ class SupportsTranscription(Protocol): Transcription models can opt out of text generation by setting this to `True`. """ + supports_segment_timestamp: ClassVar[bool] = False + """ + Enables the segment timestamp option for supported models by setting this to `True`. + """ def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index 2c99fce8d918c..134a1d9483804 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -49,13 +49,7 @@ class VllmModel(Protocol[T_co]): def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: """Apply token embeddings to `input_ids`.""" - if hasattr(self, "get_input_embeddings"): - logger.warning_once( - "`get_input_embeddings` for vLLM models is deprecated and will be " - "removed in v0.13.0 or v1.0.0, whichever is earlier. Please rename " - "this method to `embed_input_ids`." - ) - return self.get_input_embeddings(input_ids) + ... def forward(self, input_ids: torch.Tensor, positions: torch.Tensor) -> T_co: ... @@ -68,14 +62,6 @@ def _check_vllm_model_init(model: type[object] | object) -> bool: def _check_vllm_model_embed_input_ids(model: type[object] | object) -> bool: model_embed_input_ids = getattr(model, "embed_input_ids", None) if not callable(model_embed_input_ids): - model_get_input_embeddings = getattr(model, "get_input_embeddings", None) - if callable(model_get_input_embeddings): - logger.warning( - "`get_input_embeddings` for vLLM models is deprecated and will be " - "removed in v0.13.0 or v1.0.0, whichever is earlier. Please rename " - "this method to `embed_input_ids`." - ) - model.embed_input_ids = model_get_input_embeddings logger.warning( "The model (%s) is missing the `embed_input_ids` method.", model, diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index c79934e121447..3ca8864618628 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -140,7 +140,6 @@ class InternLM2Attention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, rope_parameters=rope_parameters, ) diff --git a/vllm/model_executor/models/interns1.py b/vllm/model_executor/models/interns1.py index c2195fd0cb88d..18985cefbf5ea 100644 --- a/vllm/model_executor/models/interns1.py +++ b/vllm/model_executor/models/interns1.py @@ -509,8 +509,6 @@ class InternS1MultiModalProcessor(BaseMultiModalProcessor[InternS1ProcessingInfo class InternS1ForConditionalGeneration( nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA ): - merge_by_field_config = True - # To ensure correct weight loading and mapping. hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index fccddf3a6b293..15f7d4f418e48 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -1074,8 +1074,6 @@ class InternVLMultiModalProcessor( dummy_inputs=InternVLDummyInputsBuilder, ) class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): - merge_by_field_config = True - supports_encoder_tp_data = True @classmethod diff --git a/vllm/model_executor/models/jina_vl.py b/vllm/model_executor/models/jina_vl.py index 05a40837954d8..8bba7b62882f1 100644 --- a/vllm/model_executor/models/jina_vl.py +++ b/vllm/model_executor/models/jina_vl.py @@ -29,7 +29,7 @@ logger = init_logger(__name__) class JinaVLScorer(nn.Module): def __init__(self, model_config: "ModelConfig"): super().__init__() - config = model_config.hf_config + config = model_config.hf_config.get_text_config() head_dtype = model_config.head_dtype self.dense = ColumnParallelLinear( config.hidden_size, config.hidden_size, params_dtype=head_dtype, bias=True diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 8817601558148..fcf88953ba20f 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -9,7 +9,6 @@ from typing import Annotated, Any, Literal, TypeAlias, TypeVar import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F from einops import rearrange from transformers import PretrainedConfig from transformers.activations import GELUActivation @@ -17,11 +16,10 @@ from transformers.feature_extraction_utils import BatchFeature from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.utils import torch_int -from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.attention.layer import ( - maybe_get_vit_flash_attn_backend, +from vllm.attention.layers.mm_encoder_attention import ( + MMEncoderAttention, ) -from vllm.config import VllmConfig +from vllm.config import MultiModalConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -32,6 +30,9 @@ from vllm.model_executor.layers.linear import ( RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding.common import ( + ApplyRotaryEmb, +) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, @@ -61,7 +62,6 @@ from vllm.multimodal.processing import ( PromptUpdate, ) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape @@ -80,7 +80,6 @@ from .utils import ( is_pp_missing_parameter, maybe_prefix, ) -from .vision import get_vit_attn_backend logger = init_logger(__name__) @@ -344,20 +343,14 @@ def apply_rotary_pos_emb_flashatt( cos = cos.chunk(2, dim=-1)[0].contiguous() sin = sin.chunk(2, dim=-1)[0].contiguous() - if current_platform.is_cuda(): - from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb - elif current_platform.is_rocm(): - from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb - else: - # For other platforms, use PyTorch fallback - from vllm.model_executor.layers.rotary_embedding.common import ( - apply_rotary_emb_torch, - ) + apply_rotary_emb = ApplyRotaryEmb( + enforce_enable=True, + enable_fp32_compute=True, + ) - apply_rotary_emb = partial(apply_rotary_emb_torch, is_neox_style=True) + q_embed = apply_rotary_emb(q, cos, sin) + k_embed = apply_rotary_emb(k, cos, sin) - q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q) - k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k) return q_embed, k_embed @@ -369,8 +362,8 @@ class KeyeSiglipAttention(nn.Module): self, config: PretrainedConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -408,34 +401,14 @@ class KeyeSiglipAttention(nn.Module): prefix=f"{prefix}.out_proj", ) - # Detect attention implementation. - self.attn_backend = get_vit_attn_backend( + self.attn = MMEncoderAttention( + num_heads=self.num_heads, head_size=self.head_dim, - dtype=torch.get_default_dtype(), - attn_backend_override=attn_backend_override, + num_kv_heads=self.num_kv_heads, + prefix=f"{prefix}.attn", + multimodal_config=multimodal_config, ) - self.attn_backend, self.flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - attn_backend_override=attn_backend_override, - ) - ) - - if self.attn_backend not in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.TORCH_SDPA, - AttentionBackendEnum.ROCM_AITER_FA, - }: - raise RuntimeError( - f"Keye-VL does not support {self.attn_backend} backend now." - ) - - self.is_flash_attn_backend = self.attn_backend in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.ROCM_AITER_FA, - } - def forward( self, hidden_states: torch.Tensor, @@ -450,8 +423,7 @@ class KeyeSiglipAttention(nn.Module): dim=-1, ) - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - batch_size = q.shape[0] + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() if rope_emb is None: q = q.view(*q.shape[:-1], self.num_heads, self.head_dim) @@ -482,38 +454,14 @@ class KeyeSiglipAttention(nn.Module): self.head_dim, ) - if self.is_flash_attn_backend: - q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) - - output = self.flash_attn_varlen_func( - q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - causal=False, - softmax_scale=self.scale, - ) - context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size) - elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: - outputs = [] - for i in range(1, len(cu_seqlens)): - start_idx = cu_seqlens[i - 1] - end_idx = cu_seqlens[i] - q_i = q[:, start_idx:end_idx] - k_i = k[:, start_idx:end_idx] - v_i = v[:, start_idx:end_idx] - q_i, k_i, v_i = ( - rearrange(x, "b s h d -> b h s d") for x in (q_i, k_i, v_i) - ) - output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) - output_i = rearrange(output_i, "b h s d -> b s h d ") - outputs.append(output_i) - context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0] - - context_layer = rearrange(context_layer, "b s h d -> b s (h d)").contiguous() + context_layer = self.attn( + query=q, + key=k, + value=v, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + context_layer = rearrange(context_layer, "b s h d -> b s (h d)") output, _ = self.out_proj(context_layer) return output @@ -547,8 +495,8 @@ class KeyeSiglipEncoderLayer(nn.Module): self, config: PretrainedConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.embed_dim = config.hidden_size @@ -556,8 +504,8 @@ class KeyeSiglipEncoderLayer(nn.Module): self.self_attn = KeyeSiglipAttention( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.self_attn", - attn_backend_override=attn_backend_override, ) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = SiglipMLP( @@ -601,8 +549,8 @@ class KeyeSiglipEncoder(nn.Module): self, config: PretrainedConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -614,8 +562,8 @@ class KeyeSiglipEncoder(nn.Module): KeyeSiglipEncoderLayer( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.layers.{layer_idx}", - attn_backend_override=attn_backend_override, ) for layer_idx in range(config.num_hidden_layers) ] @@ -696,8 +644,8 @@ class KeyeSiglipVisionTransformer(nn.Module): self, config: PretrainedConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -707,8 +655,8 @@ class KeyeSiglipVisionTransformer(nn.Module): self.encoder = KeyeSiglipEncoder( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.encoder", - attn_backend_override=attn_backend_override, ) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) @@ -779,16 +727,16 @@ class KeyeSiglipVisionModel(nn.Module): self, config: PretrainedConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.vision_model = KeyeSiglipVisionTransformer( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.vision_model", - attn_backend_override=attn_backend_override, ) self.quant_config = quant_config @@ -1000,7 +948,7 @@ class KeyeMultiModalDataParser(MultiModalDataParser): def _parse_image_data( self, data: dict[str, torch.Tensor] | ModalityData[ImageItem], - ) -> ModalityDataItems[Any, Any]: + ) -> ModalityDataItems[Any, Any] | None: if isinstance(data, dict): return DictEmbeddingItems( data, @@ -1017,7 +965,7 @@ class KeyeMultiModalDataParser(MultiModalDataParser): def _parse_video_data( self, data: dict[str, torch.Tensor] | ModalityData[VideoItem], - ) -> ModalityDataItems[Any, Any]: + ) -> ModalityDataItems[Any, Any] | None: if isinstance(data, dict): return DictEmbeddingItems( data, @@ -1292,8 +1240,6 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]): class BaseKeyeModule(nn.Module): - merge_by_field_config = True - packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -1331,16 +1277,11 @@ class BaseKeyeModule(nn.Module): self.config = config self.multimodal_config = multimodal_config - attn_backend_override = ( - multimodal_config.mm_encoder_attn_backend - if multimodal_config is not None - else None - ) self.visual = KeyeSiglipVisionModel( config.vision_config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=maybe_prefix(prefix, "visual"), - attn_backend_override=attn_backend_override, ) self.mlp_AR = self._build_projector( diff --git a/vllm/model_executor/models/keye_vl1_5.py b/vllm/model_executor/models/keye_vl1_5.py index 124e9c2afa217..2b04e3bd4b75b 100644 --- a/vllm/model_executor/models/keye_vl1_5.py +++ b/vllm/model_executor/models/keye_vl1_5.py @@ -333,7 +333,7 @@ class KeyeVL1_5MultiModalDataParser(MultiModalDataParser): def _parse_image_data( self, data: dict[str, torch.Tensor] | ModalityData[ImageItem], - ) -> ModalityDataItems[Any, Any]: + ) -> ModalityDataItems[Any, Any] | None: if isinstance(data, dict): return DictEmbeddingItems( data, @@ -350,7 +350,7 @@ class KeyeVL1_5MultiModalDataParser(MultiModalDataParser): def _parse_video_data( self, data: dict[str, torch.Tensor] | ModalityData[VideoItem], - ) -> ModalityDataItems[Any, Any]: + ) -> ModalityDataItems[Any, Any] | None: if isinstance(data, dict): return DictEmbeddingItems( data, diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index 8167b82f32330..85267ccda8a91 100644 --- a/vllm/model_executor/models/kimi_vl.py +++ b/vllm/model_executor/models/kimi_vl.py @@ -298,8 +298,6 @@ class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]): dummy_inputs=KimiVLDummyInputsBuilder, ) class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - merge_by_field_config = True - supports_encoder_tp_data = True @classmethod diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py index a4a994f97a2f8..142ad3d6d1d1a 100644 --- a/vllm/model_executor/models/lfm2.py +++ b/vllm/model_executor/models/lfm2.py @@ -143,7 +143,6 @@ class Lfm2Attention(nn.Module): ) self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=self.max_position_embeddings, rope_parameters=config.rope_parameters, is_neox_style=True, diff --git a/vllm/model_executor/models/lfm2_moe.py b/vllm/model_executor/models/lfm2_moe.py index c8669de72dd09..70804e0a843e8 100644 --- a/vllm/model_executor/models/lfm2_moe.py +++ b/vllm/model_executor/models/lfm2_moe.py @@ -236,7 +236,6 @@ class Lfm2MoeAttention(nn.Module): ) self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=self.max_position_embeddings, rope_parameters=config.rope_parameters, is_neox_style=True, diff --git a/vllm/model_executor/models/lightonocr.py b/vllm/model_executor/models/lightonocr.py index 9839e4f8f707e..353ee7806b1b1 100644 --- a/vllm/model_executor/models/lightonocr.py +++ b/vllm/model_executor/models/lightonocr.py @@ -28,7 +28,7 @@ from vllm.model_executor.models.utils import ( ) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import BaseMultiModalProcessorCache -from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems from vllm.multimodal.processing import ( BaseMultiModalProcessor, @@ -103,7 +103,7 @@ class LightOnOCRMultiModalProcessor(BaseMultiModalProcessor[Mistral3ProcessingIn self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 8f5a967cd422a..3507a2bc66c17 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -149,8 +149,6 @@ class LlamaAttention(nn.Module): if head_dim is None: head_dim = self.hidden_size // self.total_num_heads self.head_dim = head_dim - # Phi models introduced a partial_rotary_factor parameter in the config - self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -261,11 +259,9 @@ class LlamaAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=self.max_position_embeddings, rope_parameters=getattr(config, "rope_parameters", None), is_neox_style=is_neox_style, - partial_rotary_factor=self.partial_rotary_factor, ) diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 423be45e80149..7b3da3e10ab8a 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -243,7 +243,6 @@ class Llama4Attention(nn.Module): self.rotary_emb = ( get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, rope_parameters=config.rope_parameters, is_neox_style=is_neox_style, diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py index 0146b30579287..02f5b5ff639bd 100644 --- a/vllm/model_executor/models/llama4_eagle.py +++ b/vllm/model_executor/models/llama4_eagle.py @@ -28,7 +28,10 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.torchao import TorchAOConfig -from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama4 import Llama4DecoderLayer, Llama4ForCausalLM from vllm.model_executor.models.utils import extract_layer_index @@ -182,6 +185,12 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM): self.config.vocab_size, scale=logit_scale ) + self.lm_head = ParallelLMHead( + self.config.draft_vocab_size, + self.config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) + # Set MoE hyperparameters self.set_moe_parameters() @@ -211,6 +220,6 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM): loader = AutoWeightsLoader( self, # lm_head is tied with target model (Llama4ForCausalLM) - skip_prefixes=(["lm_head."]), + skip_prefixes=([]), ) loader.load_weights(map(transform, weights)) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index c1fb2d4f4af7d..66a327bb7603d 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -506,8 +506,6 @@ def init_vision_tower_for_llava( dummy_inputs=LlavaDummyInputsBuilder, ) class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - merge_by_field_config = True - packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index b995cac47ac1c..526846d0d9812 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -223,8 +223,6 @@ class LlavaNextMultiModalProcessor( dummy_inputs=LlavaDummyInputsBuilder, ) class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - merge_by_field_config = True - hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ # mapping for new names in checkpoint saved after transformers v4.52 diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 902c598c226f0..cd55cfec6cdec 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -299,8 +299,6 @@ class LlavaNextMultiModalProjector(nn.Module): dummy_inputs=LlavaNextVideoDummyInputsBuilder, ) class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - merge_by_field_config = True - hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ # mapping for new names in checkpoint saved after transformers v4.52 diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 4e243ade68358..5aa8de7dc252e 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -479,8 +479,6 @@ class LlavaOnevisionMultiModalProjector(nn.Module): dummy_inputs=LlavaOnevisionDummyInputsBuilder, ) class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - merge_by_field_config = True - hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ # mapping for new names in checkpoint saved after transformers v4.52 diff --git a/vllm/model_executor/models/midashenglm.py b/vllm/model_executor/models/midashenglm.py index d9b23811730d4..2d506978d266e 100644 --- a/vllm/model_executor/models/midashenglm.py +++ b/vllm/model_executor/models/midashenglm.py @@ -683,8 +683,6 @@ class MiDashengLMMultiModalProcessor( dummy_inputs=MiDashengLMDummyInputsBuilder, ) class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP): - merge_by_field_config = True - packed_modules_mapping = { "qkv_proj": [ "q_proj", diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 67c462f4b25c4..f104018d3aa6c 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -277,7 +277,6 @@ class MiniCPMAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, rope_parameters=rope_parameters, ) diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py index 0a2bcbd7f6084..c7a54cea21544 100644 --- a/vllm/model_executor/models/minicpm3.py +++ b/vllm/model_executor/models/minicpm3.py @@ -120,7 +120,6 @@ class MiniCPM3Attention(nn.Module): self.rotary_emb = get_rope( self.qk_rope_head_dim, - rotary_dim=self.qk_rope_head_dim, max_position=max_position_embeddings, rope_parameters=config.rope_parameters, ) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 6d0ebf5c9825c..c45bdf95e7487 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -1003,8 +1003,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): instantiated. """ - merge_by_field_config = True - supports_encoder_tp_data = True @classmethod diff --git a/vllm/model_executor/models/minimax_m2.py b/vllm/model_executor/models/minimax_m2.py index dd98e36ec0851..ee19288ae6852 100644 --- a/vllm/model_executor/models/minimax_m2.py +++ b/vllm/model_executor/models/minimax_m2.py @@ -199,9 +199,13 @@ class MiniMaxM2Attention(nn.Module): prefix=f"{prefix}.o_proj", ) + if ( + rope_parameters is not None + and "partial_rotary_factor" not in rope_parameters + ): + rope_parameters["partial_rotary_factor"] = rotary_dim / self.head_dim self.rotary_emb = get_rope( self.head_dim, - rotary_dim=rotary_dim, max_position=max_position_embeddings, rope_parameters=rope_parameters, ) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 390de78cc27b4..4bfe3c391c26f 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -187,7 +187,6 @@ class MiniMaxText01Attention(nn.Module): num_heads: int, head_dim: int, num_kv_heads: int, - rotary_dim: int, max_position: int = 4096 * 32, rope_parameters: dict | None = None, sliding_window: int | None = None, @@ -245,7 +244,6 @@ class MiniMaxText01Attention(nn.Module): ) self.rotary_emb = get_rope( head_size=self.head_dim, - rotary_dim=rotary_dim, max_position=max_position, rope_parameters=rope_parameters, is_neox_style=True, @@ -290,6 +288,8 @@ class MiniMaxText01DecoderLayer(nn.Module): head_dim = getattr(config, "head_dim", None) if head_dim is None: head_dim = config.hidden_size // config.num_attention_heads + rotary_dim = getattr(config, "rotary_dim", head_dim) + config.rope_parameters["partial_rotary_factor"] = rotary_dim / head_dim if hasattr(config, "max_model_len") and isinstance(config.max_model_len, int): max_position_embeddings = min( config.max_position_embeddings, config.max_model_len @@ -321,9 +321,6 @@ class MiniMaxText01DecoderLayer(nn.Module): hidden_size=self.hidden_size, num_heads=config.num_attention_heads, head_dim=head_dim, - rotary_dim=config.rotary_dim - if hasattr(config, "rotary_dim") - else head_dim, num_kv_heads=config.num_key_value_heads, max_position=max_position_embeddings, rope_parameters=config.rope_parameters, diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index 0939a72ba53ec..e480454953df8 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -179,8 +179,6 @@ class MiniMaxVL01MultiModalProcessor( dummy_inputs=MiniMaxVL01DummyInputsBuilder, ) class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - merge_by_field_config = True - packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index 1ddb470a0f93d..e9161e69e731b 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -423,8 +423,6 @@ def init_vision_tower_for_llava( class Mistral3ForConditionalGeneration( nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP ): - merge_by_field_config = True - packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], diff --git a/vllm/model_executor/models/mistral_large_3.py b/vllm/model_executor/models/mistral_large_3.py new file mode 100644 index 0000000000000..ff7e9b60c1d3c --- /dev/null +++ b/vllm/model_executor/models/mistral_large_3.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable + +import regex as re +import torch + +from vllm.model_executor.models.deepseek_v2 import DeepseekV3ForCausalLM + + +class MistralLarge3ForCausalLM(DeepseekV3ForCausalLM): + # fmt: off + remapping = { + r"layers\.(\d+)\.attention_norm\.weight": r"model.layers.\1.input_layernorm.weight", # noqa: E501 + r"layers\.(\d+)\.attention\.wq_a\.(\w+)": r"model.layers.\1.self_attn.q_a_proj.\2", # noqa: E501 + r"layers\.(\d+)\.attention\.q_a_norm\.weight": r"model.layers.\1.self_attn.q_a_layernorm.weight", # noqa: E501 + r"layers\.(\d+)\.attention\.wq_b\.(\w+)": r"model.layers.\1.self_attn.q_b_proj.\2", # noqa: E501 + r"layers\.(\d+)\.attention\.wkv_a_with_mqa\.(\w+)": r"model.layers.\1.self_attn.kv_a_proj_with_mqa.\2", # noqa: E501 + r"layers\.(\d+)\.attention\.kv_a_norm\.weight": r"model.layers.\1.self_attn.kv_a_layernorm.weight", # noqa: E501 + r"layers\.(\d+)\.attention\.wkv_b\.(\w+)": r"model.layers.\1.self_attn.kv_b_proj.\2", # noqa: E501 + r"layers\.(\d+)\.attention\.wo\.(\w+)": r"model.layers.\1.self_attn.o_proj.\2", # noqa: E501 + r"layers\.(\d+)\.ffn_norm\.weight": r"model.layers.\1.post_attention_layernorm.weight", # noqa: E501 + r"layers\.(\d+)\.feed_forward\.w1\.(\w+)": r"model.layers.\1.mlp.gate_proj.\2", # noqa: E501 + r"layers\.(\d+)\.feed_forward\.w2\.(\w+)": r"model.layers.\1.mlp.down_proj.\2", # noqa: E501 + r"layers\.(\d+)\.feed_forward\.w3\.(\w+)": r"model.layers.\1.mlp.up_proj.\2", # noqa: E501 + r"layers\.(\d+)\.gate\.weight": r"model.layers.\1.mlp.gate.weight", # noqa: E501 + r"layers\.(\d+)\.shared_experts\.w1\.(\w+)": r"model.layers.\1.mlp.shared_experts.gate_proj.\2", # noqa: E501 + r"layers\.(\d+)\.shared_experts\.w2\.(\w+)": r"model.layers.\1.mlp.shared_experts.down_proj.\2", # noqa: E501 + r"layers\.(\d+)\.shared_experts\.w3\.(\w+)": r"model.layers.\1.mlp.shared_experts.up_proj.\2", # noqa: E501 + r"layers\.(\d+)\.experts\.(\d+)\.w1\.(\w+)": r"model.layers.\1.mlp.experts.\2.gate_proj.\3", # noqa: E501 + r"layers\.(\d+)\.experts\.(\d+)\.w2\.(\w+)": r"model.layers.\1.mlp.experts.\2.down_proj.\3", # noqa: E501 + r"layers\.(\d+)\.experts\.(\d+)\.w3\.(\w+)": r"model.layers.\1.mlp.experts.\2.up_proj.\3", # noqa: E501 + r"norm\.weight": "model.norm.weight", # noqa: E501 + r"tok_embeddings\.weight": "model.embed_tokens.weight", # noqa: E501 + r"output\.weight": "lm_head.weight", # noqa: E501 + } + # fmt: on + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + return super().load_weights(map(self._remap_mistral_to_ds, weights)) + + def _remap_mistral_to_ds( + self, weight: tuple[str, torch.Tensor] + ) -> tuple[str, torch.Tensor]: + """Remap Mistral parameters to DeepseekV2 parameters.""" + name, loaded_weight = weight + + for k, v in self.remapping.items(): + match = re.fullmatch(k, name) + if match: + name = re.sub(k, v, name) + break + else: + raise ValueError(f"Cannot remap {name}") + + # Remapping scale names. We could do this in the regex above but it + # would triple the number of lines for most layers. + if name.endswith(".qscale_act"): + name = re.sub(r"\.qscale_act$", ".input_scale", name) + elif name.endswith(".qscale_weight"): + name = re.sub(r"\.qscale_weight$", ".weight_scale", name) + + return name, loaded_weight diff --git a/vllm/model_executor/models/mistral_large_3_eagle.py b/vllm/model_executor/models/mistral_large_3_eagle.py new file mode 100644 index 0000000000000..37cd4324e53d9 --- /dev/null +++ b/vllm/model_executor/models/mistral_large_3_eagle.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from functools import partial + +import torch +import torch.nn as nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed.parallel_state import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import RowParallelLinear +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.models.deepseek_v2 import ( + DeepseekV2DecoderLayer, + DeepseekV2Model, +) +from vllm.model_executor.models.mistral_large_3 import MistralLarge3ForCausalLM + +from .interfaces import SupportsMultiModal +from .utils import make_empty_intermediate_tensors_factory, maybe_prefix + +logger = init_logger(__name__) + + +@support_torch_compile +class EagleMistralLarge3Model(DeepseekV2Model): + def __init__( + self, *, vllm_config: VllmConfig, prefix: str = "", start_layer_id: int = 0 + ): + nn.Module.__init__(self) + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.vllm_config = vllm_config + + self.vocab_size = config.vocab_size + + assert get_pp_group().world_size == 1 + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + + self.layers = nn.ModuleList( + [ + DeepseekV2DecoderLayer( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), + ) + for i in range(self.config.num_hidden_layers) + ] + ) + self.start_layer = 0 + self.end_layer = self.config.num_hidden_layers + + self.fc = RowParallelLinear( + self.config.hidden_size * 2, + self.config.hidden_size, + bias=False, + input_is_parallel=False, + quant_config=quant_config, + return_bias=False, + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.embed_input_ids(input_ids) + inputs_embeds = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1)) + output = super().forward( + input_ids, positions, intermediate_tensors=None, inputs_embeds=inputs_embeds + ) + assert isinstance(output, torch.Tensor) + return output + + +class EagleMistralLarge3ForCausalLM(MistralLarge3ForCausalLM): + remapping = MistralLarge3ForCausalLM.remapping | { + r"eagle_linear\.weight": r"model.fc.weight", + r"eagle_linear\.qscale_act": r"model.fc.input_scale", + r"eagle_linear\.qscale_weight": r"model.fc.weight_scale", + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + target_layer_num = vllm_config.model_config.get_num_layers( + vllm_config.parallel_config + ) + vllm_config.model_config = vllm_config.speculative_config.draft_model_config + # draft model quantization config may differ from target model + self.quant_config = VllmConfig.get_quantization_config( + vllm_config.speculative_config.draft_model_config, vllm_config.load_config + ) + vllm_config.quant_config = self.quant_config + self.model_cls = partial( + EagleMistralLarge3Model, start_layer_id=target_layer_num + ) + super().__init__(vllm_config=vllm_config, prefix=prefix) + + def get_language_model(self) -> torch.nn.Module: + return self.model + + embed_input_ids = SupportsMultiModal.embed_input_ids # type: ignore + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + hidden_states = self.model(input_ids, positions, hidden_states, inputs_embeds) + return hidden_states, hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # Pretend we've loaded the embedding and lm_head weights + # (later copied from target model) + return super().load_weights(weights) | { + "model.embed_tokens.weight", + "lm_head.weight", + } diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 50ec57e7a8053..e170c530ca29f 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -206,7 +206,6 @@ class MixtralAttention(nn.Module): ) self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position, rope_parameters=config.rope_parameters, is_neox_style=True, diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 286859d188d34..fe963cc6644fb 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -295,11 +295,11 @@ class Llama4VisionAttention(nn.Module): rope_parameters = { "rope_type": "mllama4", "rope_theta": config.rope_parameters["rope_theta"], + "partial_rotary_factor": 0.5, } self.rotary_emb = get_rope( head_size=self.head_dim, - rotary_dim=config.hidden_size // config.num_attention_heads // 2, # number of image patches max_position=(config.image_size // config.patch_size) ** 2, rope_parameters=rope_parameters, @@ -741,8 +741,6 @@ class Llama4ForConditionalGeneration( SupportsEagle3, SupportsLoRA, ): - merge_by_field_config = True - packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 743bc23d9876f..4655ffa7b2f61 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -20,7 +20,7 @@ from vllm.model_executor.layers.pooler import ( PoolingParamsUpdate, PoolingType, ) -from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors @@ -62,19 +62,6 @@ class ModernBertEmbeddings(nn.Module): return embeddings -class ModernBertRotaryEmbedding(RotaryEmbedding): - def __init__(self, config: ModernBertConfig, head_size: int, dim: int, base: float): - super().__init__( - head_size=head_size, - rotary_dim=dim, - max_position_embeddings=config.max_position_embeddings, - base=base, - is_neox_style=True, - dtype=torch.float16, - ) - self.config = config - - class ModernBertAttention(nn.Module): def __init__(self, config: ModernBertConfig, layer_id: int | None = None): super().__init__() @@ -95,19 +82,32 @@ class ModernBertAttention(nn.Module): bias=config.attention_bias, ) - sliding_window = None - if layer_id % config.global_attn_every_n_layers != 0: - sliding_window = config.local_attention // 2 - rope_theta = ( - config.local_rope_theta - if config.local_rope_theta is not None - else config.global_rope_theta - ) + if layer_types := getattr(config, "layer_types", None): + # Transformers v5 + layer_type = layer_types[layer_id] + rope_parameters = config.rope_parameters[layer_type] + sliding_window: int | None = None + if layer_type == "sliding_attention": + sliding_window = config.local_attention // 2 else: - rope_theta = config.global_rope_theta + # Transformers v4 + sliding_window = None + if layer_id % config.global_attn_every_n_layers != 0: + sliding_window = config.local_attention // 2 + rope_theta = ( + config.local_rope_theta + if config.local_rope_theta is not None + else config.global_rope_theta + ) + else: + rope_theta = config.global_rope_theta + rope_parameters = {"rope_type": "default", "rope_theta": rope_theta} - self.rotary_emb = ModernBertRotaryEmbedding( - config=config, head_size=self.head_dim, dim=self.head_dim, base=rope_theta + self.rotary_emb = get_rope( + head_size=self.head_dim, + max_position=config.max_position_embeddings, + rope_parameters=rope_parameters, + dtype=torch.float16, ) self.attn = EncoderOnlyAttention( self.num_heads, diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 7b53299cccbe4..71c6b1aa2e814 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -433,7 +433,6 @@ class MolmoAttention(nn.Module): # Rotary embeddings. self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=self.max_position_embeddings, rope_parameters=config.rope_parameters, ) @@ -1354,8 +1353,6 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): class MolmoForCausalLM( nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsQuant ): - merge_by_field_config = True - hf_to_vllm_mapper = WeightsMapper( orig_to_new_substr={ # vision backbone mapping diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 11beeddabe307..6dfab595e5b92 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -52,7 +52,6 @@ from vllm.multimodal.evs import ( from vllm.multimodal.inputs import ( MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, MultiModalKwargsItems, VideoItem, ) @@ -73,12 +72,8 @@ from vllm.multimodal.processing import ( ) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from vllm.tokenizers import TokenizerLike +from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config from vllm.transformers_utils.configs.radio import RadioConfig -from vllm.transformers_utils.tokenizer import ( - cached_tokenizer_from_config, - encode_tokens, -) from vllm.utils.tensor_schema import TensorSchema, TensorShape from .utils import _merge_multimodal_embeddings @@ -457,14 +452,12 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): # Pre-tokenize special tokens for video processing # to avoid repeated tokenization - self._img_start_token_ids = encode_tokens( - tokenizer, IMG_START, add_special_tokens=False + self._img_start_token_ids = tokenizer.encode( + IMG_START, add_special_tokens=False ) - self._img_end_token_ids = encode_tokens( - tokenizer, IMG_END, add_special_tokens=False - ) - self._img_context_token_ids = encode_tokens( - tokenizer, IMG_CONTEXT, add_special_tokens=False + self._img_end_token_ids = tokenizer.encode(IMG_END, add_special_tokens=False) + self._img_context_token_ids = tokenizer.encode( + IMG_CONTEXT, add_special_tokens=False ) @property @@ -855,17 +848,18 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - if "image_num_patches" in out_mm_kwargs: - image_num_patches = out_mm_kwargs["image_num_patches"] + out_mm_data = out_mm_kwargs.get_data() + if "image_num_patches" in out_mm_data: + image_num_patches = out_mm_data["image_num_patches"] assert isinstance(image_num_patches, torch.Tensor) image_num_patches = image_num_patches.tolist() - elif "image_embeds" in out_mm_kwargs: + elif "image_embeds" in out_mm_data: # to compute num_patches (similar to Qwen2-VL) - image_num_patches = [None] * len(out_mm_kwargs["image_embeds"]) + image_num_patches = [None] * len(out_mm_data["image_embeds"]) else: image_num_patches = [] @@ -1122,8 +1116,6 @@ class NanoNemotronVLDummyInputsBuilder( class NemotronH_Nano_VL_V2( nn.Module, HasInnerState, IsHybrid, SupportsMultiModal, SupportsMultiModalPruning ): - merge_by_field_config = True - @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): @@ -1182,14 +1174,12 @@ class NemotronH_Nano_VL_V2( # Pre-tokenize special tokens for video processing # to avoid repeated tokenization tokenizer = cached_tokenizer_from_config(vllm_config.model_config) - self._img_start_token_ids = encode_tokens( - tokenizer, IMG_START, add_special_tokens=False + self._img_start_token_ids = tokenizer.encode( + IMG_START, add_special_tokens=False ) - self._img_end_token_ids = encode_tokens( - tokenizer, IMG_END, add_special_tokens=False - ) - self._img_context_token_ids = encode_tokens( - tokenizer, IMG_CONTEXT, add_special_tokens=False + self._img_end_token_ids = tokenizer.encode(IMG_END, add_special_tokens=False) + self._img_context_token_ids = tokenizer.encode( + IMG_CONTEXT, add_special_tokens=False ) def pixel_shuffle(self, x, scale_factor=0.5): diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index ffba6c9dfe739..21605015c470b 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -178,7 +178,6 @@ class NemotronAttention(nn.Module): self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 - self.partial_rotary_factor = config.partial_rotary_factor self.max_position_embeddings = max_position_embeddings self.qkv_proj = QKVParallelLinear( @@ -200,10 +199,8 @@ class NemotronAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, rope_parameters=config.rope_parameters, - partial_rotary_factor=self.partial_rotary_factor, ) self.attn = Attention( self.num_heads, diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index baeb901bbb05a..2d9dfbd3e7688 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -83,6 +83,7 @@ class NemotronHMLP(nn.Module): def __init__( self, config: NemotronHConfig, + hidden_size: int, intermediate_size: int, quant_config: QuantizationConfig | None = None, bias: bool = False, @@ -93,7 +94,7 @@ class NemotronHMLP(nn.Module): super().__init__() self.up_proj = ColumnParallelLinear( - input_size=config.hidden_size, + input_size=hidden_size, output_size=intermediate_size, bias=bias, quant_config=quant_config, @@ -102,7 +103,7 @@ class NemotronHMLP(nn.Module): ) self.down_proj = RowParallelLinear( input_size=intermediate_size, - output_size=config.hidden_size, + output_size=hidden_size, bias=bias, quant_config=quant_config, reduce_results=reduce_results, @@ -135,6 +136,10 @@ class NemotronHMoE(nn.Module): self.ep_size = self.ep_group.size() self.n_routed_experts: int = config.n_routed_experts self.n_shared_experts: int = config.n_shared_experts + self.use_latent_moe: bool = getattr(config, "moe_latent_size", None) is not None + self.moe_hidden_size: int = ( + config.moe_latent_size if self.use_latent_moe else config.hidden_size + ) self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe @@ -172,6 +177,7 @@ class NemotronHMoE(nn.Module): self.shared_experts = NemotronHMLP( config=config, + hidden_size=config.hidden_size, intermediate_size=intermediate_size, quant_config=quant_config, reduce_results=False, @@ -180,10 +186,12 @@ class NemotronHMoE(nn.Module): ) self.experts = SharedFusedMoE( - shared_experts=self.shared_experts, + # TODO: make it possible for shared experts to have + # different input in SharedFusedMoE + shared_experts=self.shared_experts if not self.use_latent_moe else None, num_experts=config.n_routed_experts, top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, + hidden_size=self.moe_hidden_size, intermediate_size=config.moe_intermediate_size, reduce_results=False, renormalize=config.norm_topk_prob, @@ -201,6 +209,32 @@ class NemotronHMoE(nn.Module): is_sequence_parallel=self.is_sequence_parallel, ) + if self.use_latent_moe: + # TODO: check if using ReplicatedLinear is better than + # ColumnParallelLinear + all_gather + self.fc1_latent_proj = ColumnParallelLinear( + input_size=config.hidden_size, + output_size=self.moe_hidden_size, + bias=config.mlp_bias, + quant_config=quant_config, + disable_tp=self.is_sequence_parallel, + # We need to gather the output to prepare input for moe + gather_output=True, + prefix=f"{prefix}.fc1_latent_proj", + ) + self.fc2_latent_proj = ReplicatedLinear( + input_size=self.moe_hidden_size, + output_size=config.hidden_size, + bias=config.mlp_bias, + quant_config=quant_config, + disable_tp=self.is_sequence_parallel, + prefix=f"{prefix}.fc2_latent_proj", + ) + + else: + self.fc1_latent_proj = None + self.fc2_latent_proj = None + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) @@ -210,12 +244,20 @@ class NemotronHMoE(nn.Module): # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32)) + shared_output = None + if self.use_latent_moe: + if self.shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + hidden_states, _ = self.fc1_latent_proj(hidden_states) fused_moe_out = self.experts( hidden_states=hidden_states, router_logits=router_logits ) - shared_output, final_hidden_states = fused_moe_out + if self.use_latent_moe: + _, final_hidden_states = fused_moe_out + else: + shared_output, final_hidden_states = fused_moe_out # Fix FP16 overflow # See DeepseekV2DecoderLayer for more details. @@ -225,6 +267,13 @@ class NemotronHMoE(nn.Module): assert shared_output is not None shared_output *= 1.0 / self.routed_scaling_factor + # TODO: currently latent up_proj is done before all-reduce for simplicity. + # if and when shared experts will be part of SharedFusedMoE, + # we should do the up_proj after all-reduce, + # to have the all-reduce in the smaller latent dimension. + if self.use_latent_moe: + final_hidden_states, _ = self.fc2_latent_proj(final_hidden_states) + if self.shared_experts is not None: assert shared_output is not None final_hidden_states += shared_output @@ -268,6 +317,7 @@ class NemotronHMLPDecoderLayer(nn.Module): self.mixer = NemotronHMLP( config, + hidden_size=config.hidden_size, intermediate_size=intermediate_size, quant_config=quant_config, bias=config.mlp_bias, @@ -846,5 +896,5 @@ class NemotronHForCausalLM( return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) + loader = AutoWeightsLoader(self, skip_prefixes=["mtp"]) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py index 9d968dee87114..19a942a5277cc 100644 --- a/vllm/model_executor/models/nemotron_nas.py +++ b/vllm/model_executor/models/nemotron_nas.py @@ -118,11 +118,9 @@ class DeciLMAttention(LlamaAttention): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=self.max_position_embeddings, rope_parameters=config.rope_parameters, is_neox_style=is_neox_style, - partial_rotary_factor=self.partial_rotary_factor, ) diff --git a/vllm/model_executor/models/nemotron_vl.py b/vllm/model_executor/models/nemotron_vl.py index a57668b21fb86..391980fc61f9e 100644 --- a/vllm/model_executor/models/nemotron_vl.py +++ b/vllm/model_executor/models/nemotron_vl.py @@ -358,8 +358,6 @@ class NemotronVLProcessingInfo(BaseInternVLProcessingInfo): dummy_inputs=BaseInternVLDummyInputsBuilder[NemotronVLProcessingInfo], ) class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): - merge_by_field_config = True - @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 3bbb4dd242262..dd7c27f10c531 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -102,7 +102,6 @@ class OlmoAttention(nn.Module): # Rotary embeddings. self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=self.max_position_embeddings, rope_parameters=config.rope_parameters, ) diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index 88e9c2d8541a1..b030c94b54cd5 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -146,7 +146,6 @@ class Olmo2Attention(nn.Module): rope_parameters = {"rope_type": "default", "rope_theta": rope_theta} self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=self.max_position_embeddings, rope_parameters=rope_parameters, ) diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index 1376583a99725..a5a926151c5c9 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -171,7 +171,6 @@ class OlmoeAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, rope_parameters=config.rope_parameters, is_neox_style=True, diff --git a/vllm/model_executor/models/opencua.py b/vllm/model_executor/models/opencua.py index 4338918663378..35a6a78f653ef 100644 --- a/vllm/model_executor/models/opencua.py +++ b/vllm/model_executor/models/opencua.py @@ -23,7 +23,7 @@ from vllm.config import VllmConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( MultiModalFieldConfig, - MultiModalKwargs, + MultiModalKwargsItems, ) from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser from vllm.multimodal.processing import ( @@ -153,7 +153,7 @@ class OpenCUAMultiModalProcessor(BaseMultiModalProcessor[OpenCUAProcessingInfo]) self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) @@ -201,9 +201,6 @@ class OpenCUADummyInputsBuilder(Qwen2VLDummyInputsBuilder): dummy_inputs=OpenCUADummyInputsBuilder, ) class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): - merge_by_field_config = True - multimodal_cpu_fields = {"image_grid_thw"} - packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], @@ -243,18 +240,12 @@ class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): ) if multimodal_config.get_limit_per_prompt("image"): - attn_backend_override = ( - multimodal_config.mm_encoder_attn_backend - if multimodal_config is not None - else None - ) self.visual = OpenCUAVisionTransformer( vision_config=config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), quant_config=self.quant_config, + multimodal_config=self.multimodal_config, prefix=maybe_prefix(prefix, "visual"), - use_data_parallel=self.use_data_parallel, - attn_backend_override=attn_backend_override, ) else: self.visual = None diff --git a/vllm/model_executor/models/openpangu.py b/vllm/model_executor/models/openpangu.py index bddd9fa50957a..47abd7bf0b68a 100644 --- a/vllm/model_executor/models/openpangu.py +++ b/vllm/model_executor/models/openpangu.py @@ -352,7 +352,6 @@ class OpenPanguMLAAttention(nn.Module): } self.rotary_emb = get_rope( qk_rope_head_dim, - rotary_dim=qk_rope_head_dim, max_position=max_position_embeddings, rope_parameters=rope_parameters, is_neox_style=False, @@ -525,7 +524,6 @@ class OpenPanguEmbeddedAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=self.max_position_embeddings, rope_parameters=config.rope_parameters, is_neox_style=is_neox_style, diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 544a44ed54681..9d9066c4ba619 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -135,7 +135,6 @@ class OrionAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, rope_parameters=rope_parameters, ) diff --git a/vllm/model_executor/models/ouro.py b/vllm/model_executor/models/ouro.py index dcae92ed20881..829148b4c1fb7 100644 --- a/vllm/model_executor/models/ouro.py +++ b/vllm/model_executor/models/ouro.py @@ -166,7 +166,6 @@ class OuroAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position, rope_parameters=config.rope_parameters, dual_chunk_attention_config=dual_chunk_attention_config, diff --git a/vllm/model_executor/models/ovis.py b/vllm/model_executor/models/ovis.py index a0fab820720fb..0691bbc615be9 100644 --- a/vllm/model_executor/models/ovis.py +++ b/vllm/model_executor/models/ovis.py @@ -414,8 +414,6 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]): dummy_inputs=OvisDummyInputsBuilder, ) class Ovis(nn.Module, SupportsMultiModal, SupportsPP): - merge_by_field_config = True - @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py index 85f37cfea10b1..945138b5972f7 100644 --- a/vllm/model_executor/models/ovis2_5.py +++ b/vllm/model_executor/models/ovis2_5.py @@ -10,8 +10,7 @@ import torch import torch.nn as nn from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig -from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.config import VllmConfig +from vllm.config import MultiModalConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.quantization import QuantizationConfig @@ -104,18 +103,16 @@ class VisualTokenizer(torch.nn.Module): config: PretrainedConfig, visual_vocab_size: int, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config self.vit = self._init_backbone( config=config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.vit", - use_data_parallel=use_data_parallel, - attn_backend_override=attn_backend_override, ) # reserved tokens for INDICATOR_IDS head_dim = visual_vocab_size - len(INDICATOR_IDS) @@ -133,18 +130,16 @@ class VisualTokenizer(torch.nn.Module): self, config: PretrainedConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: QuantizationConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ): model_type = config.model_type if model_type == "siglip2_navit": return Siglip2NavitModel( config=config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=prefix, - use_data_parallel=use_data_parallel, - attn_backend_override=attn_backend_override, ) raise ValueError(f"Unsupported visual tokenizer model_type: {model_type}") @@ -456,8 +451,6 @@ class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo]) dummy_inputs=Ovis2_5DummyInputsBuilder, ) class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP): - merge_by_field_config = True - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -470,17 +463,12 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP): prefix=maybe_prefix(prefix, "llm"), ) - attn_backend_override = ( - multimodal_config.mm_encoder_attn_backend - if multimodal_config is not None - else None - ) self.visual_tokenizer = VisualTokenizer( config=config.vit_config, visual_vocab_size=config.visual_vocab_size, + multimodal_config=multimodal_config, quant_config=quant_config, prefix=f"{prefix}.visual_tokenizer", - attn_backend_override=attn_backend_override, ) self.vte = VisualEmbedding(config.visual_vocab_size, config.hidden_size) diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py index 5256d8ba7fd86..56565266c0dcc 100644 --- a/vllm/model_executor/models/paddleocr_vl.py +++ b/vllm/model_executor/models/paddleocr_vl.py @@ -22,8 +22,7 @@ from typing import Annotated, Literal import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange, repeat +from einops import rearrange from transformers import BatchFeature, PretrainedConfig from transformers.activations import GELUActivation from transformers.modeling_outputs import ( @@ -32,13 +31,10 @@ from transformers.modeling_outputs import ( from transformers.utils import torch_int from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.attention.layer import ( - maybe_get_vit_flash_attn_backend, +from vllm.attention.layers.mm_encoder_attention import ( + MMEncoderAttention, ) -from vllm.attention.ops.vit_attn_wrappers import ( - vit_flash_attn_wrapper, -) -from vllm.config import VllmConfig +from vllm.config import MultiModalConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils @@ -51,7 +47,7 @@ from vllm.model_executor.layers.linear import ( ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding.common import ( - dispatch_rotary_emb_function, + ApplyRotaryEmb, ) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, @@ -62,7 +58,7 @@ from vllm.multimodal.inputs import ( MultiModalDataDict, MultiModalFeatureSpec, MultiModalFieldConfig, - MultiModalKwargs, + MultiModalKwargsItems, ) from vllm.multimodal.parse import ( ImageProcessorItems, @@ -134,47 +130,6 @@ def smart_resize( return h_bar, w_bar -def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: - if not interleaved: - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2) - - -def apply_rotary_emb_torch( - x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False -) -> torch.Tensor: - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) - """ - ro_dim = cos.shape[-1] * 2 - assert ro_dim <= x.shape[-1] - cos = repeat( - cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" - ) - sin = repeat( - sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" - ) - return torch.cat( - [ - x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, - x[..., ro_dim:], - ], - dim=-1, - ) - - -def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: - rotary_emb_function = dispatch_rotary_emb_function(default=apply_rotary_emb_torch) - t_ = t.float() - cos = freqs.cos() - sin = freqs.sin() - output = rotary_emb_function(t_, cos, sin).type_as(t) - return output - - class PaddleOCRVLProcessingInfo(BaseProcessingInfo): def get_hf_config(self): return self.ctx.get_hf_config() @@ -307,7 +262,7 @@ class PaddleOCRVLMultiModalProcessor( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) hf_config = self.info.get_hf_config() @@ -578,9 +533,8 @@ class SiglipAttention(nn.Module): num_heads: int, projection_size: int, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -608,18 +562,16 @@ class SiglipAttention(nn.Module): quant_config=quant_config, prefix=f"{prefix}.out_proj", ) - - self.attn_backend = attn_backend - self.attn_backend, self.flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - attn_backend_override=attn_backend_override, - ) + self.attn = MMEncoderAttention( + num_heads=self.num_attention_heads_per_partition, + head_size=self.hidden_size_per_attention_head, + multimodal_config=multimodal_config, + prefix=f"{prefix}.attn", + ) + self.apply_rotary_emb = ApplyRotaryEmb( + enforce_enable=True, + enable_fp32_compute=True, ) - self.is_flash_attn_backend = self.attn_backend in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.ROCM_AITER_FA, - } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: seq_len, bs, _ = qkv.shape @@ -662,47 +614,23 @@ class SiglipAttention(nn.Module): if rotary_pos_emb is not None: qk_concat = torch.cat([q, k], dim=0) - qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) + qk_rotated = self.apply_rotary_emb( + qk_concat, + rotary_pos_emb.cos(), + rotary_pos_emb.sin(), + ) q, k = torch.chunk(qk_rotated, 2, dim=0) - if self.is_flash_attn_backend: - if max_seqlen is None: - raise ValueError("Flash attention backend requires max_seqlen.") - context_layer = vit_flash_attn_wrapper( - q, - k, - v, - cu_seqlens, - max_seqlen, - batch_size, - self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA, - ) - elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: - outputs = [] - for i in range(1, len(cu_seqlens)): - start_idx = cu_seqlens[i - 1] - end_idx = cu_seqlens[i] - q_i = q[:, start_idx:end_idx] - k_i = k[:, start_idx:end_idx] - v_i = v[:, start_idx:end_idx] - q_i, k_i, v_i = ( - rearrange(tensor, "b s h d -> b h s d") - for tensor in (q_i, k_i, v_i) - ) - output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) - output_i = rearrange(output_i, "b h s d -> b s h d") - outputs.append(output_i) - context_layer = torch.cat(outputs, dim=1) - context_layer = rearrange( - context_layer, "b s h d -> s b (h d)" - ).contiguous() - else: - raise RuntimeError( - f"PaddleOCR-VL does not support {self.attn_backend} backend now." - ) + context_layer = self.attn( + query=q, + key=k, + value=v, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + context_layer = rearrange(context_layer, "b s h d -> b s (h d)") output, _ = self.out_proj(context_layer) - output = rearrange(output, "s b d -> b s d") return output @@ -774,10 +702,8 @@ class SiglipEncoderLayer(nn.Module): self, config: PretrainedConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - *, - attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.embed_dim = config.hidden_size @@ -787,9 +713,8 @@ class SiglipEncoderLayer(nn.Module): num_heads=config.num_attention_heads, projection_size=config.hidden_size, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.self_attn", - attn_backend=attn_backend, - attn_backend_override=attn_backend_override, ) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = SiglipMLP( @@ -832,14 +757,18 @@ class SiglipEncoder(nn.Module): self, config: PretrainedConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config embed_dim = config.hidden_size num_heads = config.num_attention_heads head_dim = embed_dim // num_heads + + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend if multimodal_config else None + ) self.attn_backend = get_vit_attn_backend( head_size=head_dim, dtype=torch.get_default_dtype(), @@ -858,9 +787,8 @@ class SiglipEncoder(nn.Module): SiglipEncoderLayer( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.layers.{layer_idx}", - attn_backend=self.attn_backend, - attn_backend_override=attn_backend_override, ) for layer_idx in range(config.num_hidden_layers) ] @@ -941,8 +869,8 @@ class SiglipVisionTransformer(nn.Module): self, config: PretrainedConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -952,8 +880,8 @@ class SiglipVisionTransformer(nn.Module): self.encoder = SiglipEncoder( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.encoder", - attn_backend_override=attn_backend_override, ) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) @@ -991,16 +919,16 @@ class SiglipVisionModel(nn.Module): self, config, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.vision_model = SiglipVisionTransformer( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.vision_model", - attn_backend_override=attn_backend_override, ) self.quant_config = quant_config @@ -1103,8 +1031,6 @@ class SiglipVisionModel(nn.Module): dummy_inputs=PaddleOCRVLDummyInputsBuilder, ) class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsMRoPE): - merge_by_field_config = True - hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "model.": "language_model.model.", @@ -1121,17 +1047,11 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support self.config = config self.multimodal_config = multimodal_config - attn_backend_override = ( - multimodal_config.mm_encoder_attn_backend - if multimodal_config is not None - else None - ) - self.visual = SiglipVisionModel( config=config.vision_config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=maybe_prefix(prefix, "visual"), - attn_backend_override=attn_backend_override, ) self.mlp_AR = Projector(config, config.vision_config) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index ec5d0fa6226dd..67240c6e71249 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -40,7 +40,6 @@ from .siglip import SiglipVisionModel from .utils import ( AutoWeightsLoader, WeightsMapper, - flatten_bn, init_vllm_registered_model, maybe_prefix, ) @@ -327,9 +326,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP return None if pixel_values is not None: - pixel_values = flatten_bn(pixel_values, concat=True) - h = w = self.config.vision_config.image_size + return PaliGemmaImagePixelInputs( type="pixel_values", data=pixel_values, @@ -337,8 +335,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP ) if image_embeds is not None: - image_embeds = flatten_bn(image_embeds, concat=True) - return PaliGemmaImageEmbeddingInputs( type="image_embeds", data=image_embeds, diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index 795cd25f16753..b644603c5baa1 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -106,7 +106,6 @@ class PersimmonAttention(nn.Module): self.num_heads = self.total_num_heads // tensor_parallel_world_size self.head_dim = self.hidden_size // self.total_num_heads self.max_position_embeddings = config.max_position_embeddings - self.partial_rotary_factor = config.partial_rotary_factor self.is_causal = True assert (self.head_dim * self.total_num_heads) == self.hidden_size @@ -135,10 +134,8 @@ class PersimmonAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=self.max_position_embeddings, rope_parameters=config.rope_parameters, - partial_rotary_factor=self.partial_rotary_factor, ) self.scaling = self.head_dim**-0.5 self.attn = Attention( diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 70016d9ed246c..e01e9d47c545c 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -84,19 +84,18 @@ class PhiAttention(nn.Module): prefix: str = "", ): super().__init__() - self.total_num_heads = config.num_attention_heads self.hidden_size = config.hidden_size - self.head_size = self.hidden_size // self.total_num_heads + self.head_size = self.hidden_size // config.num_attention_heads tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() - assert self.total_num_heads % tensor_model_parallel_world_size == 0 - self.num_heads = self.total_num_heads // tensor_model_parallel_world_size + assert config.num_attention_heads % tensor_model_parallel_world_size == 0 + self.num_heads = config.num_attention_heads // tensor_model_parallel_world_size # pylint: disable=C0103 self.qkv_proj = QKVParallelLinear( self.hidden_size, self.head_size, - self.total_num_heads, + config.num_attention_heads, bias=True, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", @@ -109,16 +108,10 @@ class PhiAttention(nn.Module): ) scaling = self.head_size**-0.5 - rotary_dim = int( - config.partial_rotary_factor - * (config.hidden_size // config.num_attention_heads) - ) - assert rotary_dim % 2 == 0 max_position_embeddings = getattr(config, "max_position_embeddings", 2048) self.rotary_emb = get_rope( self.head_size, - rotary_dim=rotary_dim, max_position=max_position_embeddings, rope_parameters=config.rope_parameters, ) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 384572217bc19..900b0eade308c 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -64,6 +64,7 @@ from .interfaces import ( SupportsMultiModal, SupportsPP, SupportsQuant, + _require_is_multimodal, ) from .utils import ( AutoWeightsLoader, @@ -562,8 +563,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): dummy_inputs=Phi3VDummyInputsBuilder, ) class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant): - merge_by_field_config = True - hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "model.vision_embed_tokens.wte": "embed_tokens", @@ -689,17 +688,10 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant) if multimodal_embeddings is None or len(multimodal_embeddings) == 0: return inputs_embeds - if is_multimodal is None: - raise ValueError( - "`embed_input_ids` now requires `is_multimodal` arg, " - "please update your model runner according to " - "https://github.com/vllm-project/vllm/pull/16229." - ) - return _merge_multimodal_embeddings( inputs_embeds=inputs_embeds, multimodal_embeddings=multimodal_embeddings, - is_multimodal=is_multimodal, + is_multimodal=_require_is_multimodal(is_multimodal), ) def forward( diff --git a/vllm/model_executor/models/phi4_multimodal.py b/vllm/model_executor/models/phi4_multimodal.py deleted file mode 100644 index 0f1230a55bae6..0000000000000 --- a/vllm/model_executor/models/phi4_multimodal.py +++ /dev/null @@ -1,1447 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import math -from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Any, Literal, TypeAlias - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import ( - BatchFeature, - Phi4MultimodalAudioConfig, - Phi4MultimodalConfig, - Phi4MultimodalFeatureExtractor, - Phi4MultimodalImageProcessorFast, -) -from transformers import Phi4MultimodalProcessor as Phi4MMProcessor -from transformers.models.phi4_multimodal.modeling_phi4_multimodal import ( - Phi4MultimodalAudioConvModule, - Phi4MultimodalAudioNemoConvSubsampling, - Phi4MultimodalAudioRelativeAttentionBias, - adaptive_enc_mask, - unfold_tensor, -) - -from vllm.config import VllmConfig -from vllm.config.multimodal import BaseDummyOptions -from vllm.distributed import ( - divide, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) -from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn -from vllm.model_executor.layers.linear import ( - ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear, -) -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import ( - MultiModalDataDict, - MultiModalFieldConfig, - MultiModalKwargsItems, - NestedTensors, -) -from vllm.multimodal.parse import ( - AudioProcessorItems, - ImageEmbeddingItems, - ImageProcessorItems, - ImageSize, - MultiModalDataItems, - MultiModalDataParser, -) -from vllm.multimodal.processing import ( - BaseMultiModalProcessor, - BaseProcessingInfo, - PromptReplacement, - PromptUpdate, -) -from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.sequence import IntermediateTensors -from vllm.utils.tensor_schema import TensorSchema, TensorShape - -from .idefics2_vision_model import Idefics2VisionTransformer -from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal -from .utils import ( - AutoWeightsLoader, - WeightsMapper, - init_vllm_registered_model, - maybe_prefix, -) - -_AUDIO_MAX_SOUNDFILE_SIZE = 241_000 - - -def _get_padding_size( - orig_width: int, orig_height: int, target_height: int, target_width: int -): - ratio_width = target_width / orig_width - ratio_height = target_height / orig_height - - if ratio_width < ratio_height: - padding_width = 0 - padding_height = target_height - int(orig_height * ratio_width) - else: - padding_width = target_width - int(orig_width * ratio_height) - padding_height = 0 - return padding_height, padding_width - - -class Phi4MMProjector(nn.Module): - def __init__(self, input_size: int, hidden_size: int): - super().__init__() - self.up = ColumnParallelLinear(input_size, hidden_size) - self.down = RowParallelLinear(hidden_size, hidden_size) - self.act = get_act_fn("gelu") - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x, _ = self.up(x) - x = self.act(x) - x, _ = self.down(x) - return x - - -class Phi4MMImageEmbedding(nn.Module): - """Image embedding.""" - - def __init__(self, config: Phi4MultimodalConfig): - super().__init__() - self.config = config - self.layer_idx = config.vision_config.feature_layer - self.crop_size = config.vision_config.crop_size - self.image_dim_out = config.vision_config.hidden_size - - n_patches = config.vision_config.image_size // config.vision_config.patch_size - if n_patches % 2 != 0: - self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1)) - n_patches += 1 - self.num_img_tokens = (n_patches // 2) ** 2 - - num_hidden_layers = ( - config.vision_config.num_hidden_layers + self.layer_idx + 1 - if self.layer_idx < 0 - else self.layer_idx + 1 - ) - self.img_processor = Idefics2VisionTransformer( - config.vision_config, - require_post_norm=False, - num_hidden_layers_override=num_hidden_layers, - ) - self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2) - self.img_projection = Phi4MMProjector(self.image_dim_out, config.hidden_size) - self.global_img_feature_extensor = nn.Parameter( - torch.zeros([1, 1, self.image_dim_out]) - ) - self.sub_img_feature_extensor = nn.Parameter( - torch.zeros([1, 1, 1, self.image_dim_out]) - ) - - def get_img_features( - self, - img_embeds: torch.FloatTensor, - attention_mask: torch.Tensor | None = None, - ) -> torch.FloatTensor: - img_feature = self.img_processor( - img_embeds, patch_attention_mask=attention_mask - ) - - patch_feature = img_feature - # reshape to 2D tensor - width = int(math.sqrt(patch_feature.size(1))) - patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1)) - # convert to NCHW - patch_feature = patch_feature.permute(0, 3, 1, 2) - if getattr(self, "img_processor_padding", None) is not None: - patch_feature = self.img_processor_padding(patch_feature) - patch_feature = self.image_token_compression(patch_feature) - # convert to NHWC - patch_feature = patch_feature.permute(0, 2, 3, 1) - patch_feature = patch_feature.view( - -1, patch_feature.size(1) * patch_feature.size(2), patch_feature.size(-1) - ) - return patch_feature - - def forward( - self, - image_pixel_values: torch.FloatTensor, - image_sizes: torch.Tensor | None = None, - image_attention_mask: torch.Tensor | None = None, - ) -> torch.FloatTensor: - image_pixel_values = image_pixel_values.to( - self.img_processor.embeddings.patch_embedding.weight.dtype - ) - - target_device = self.img_projection.up.bias.device - target_dtype = self.img_projection.up.bias.dtype - - batch_size = image_pixel_values.shape[0] - - img_features = self.get_img_features( - image_pixel_values.flatten(0, 1), - attention_mask=image_attention_mask.flatten(0, 1).to( - dtype=bool, device=target_device - ), - ) - base_feat_size = int(np.sqrt(img_features.shape[1])) - img_features = img_features.view( - batch_size, -1, base_feat_size**2, self.image_dim_out - ) - image_sizes = image_sizes.view(-1, 2) - - output_imgs = [] - for idx in range(batch_size): - height, width = image_sizes[idx] - height_ratio = height // self.crop_size - width_ratio = width // self.crop_size - area_ratio = height_ratio * width_ratio - - global_img = img_features[idx, :1] - global_img = global_img.reshape( - 1, base_feat_size, base_feat_size, self.image_dim_out - ).contiguous() - temporary_extensor = self.sub_img_feature_extensor.repeat( - 1, base_feat_size, 1, 1 - ) - global_img = torch.cat([global_img, temporary_extensor], dim=2).reshape( - 1, -1, self.image_dim_out - ) - - sub_img = img_features[idx, 1:] - sub_img = sub_img[:area_ratio] - sub_img = ( - sub_img.reshape( - height_ratio, - width_ratio, - base_feat_size, - base_feat_size, - self.image_dim_out, - ) - .transpose(1, 2) - .reshape( - 1, - height_ratio * base_feat_size, - width_ratio * base_feat_size, - self.image_dim_out, - ) - .contiguous() - ) - - if image_attention_mask is not None: - reshaped_image_attention_mask = ( - image_attention_mask[idx, 1 : area_ratio + 1, 0::2, 0::2] - .reshape(height_ratio, width_ratio, base_feat_size, base_feat_size) - .transpose(1, 2) - .reshape( - 1, height_ratio * base_feat_size, width_ratio * base_feat_size - ) - ) - useful_height = int(reshaped_image_attention_mask[0, :, 0].sum().item()) - useful_width = int(reshaped_image_attention_mask[0, 0, :].sum().item()) - sub_img = sub_img[:, :useful_height, :useful_width] - temporary_extensor = self.sub_img_feature_extensor.repeat( - 1, useful_height, 1, 1 - ) - else: - temporary_extensor = self.sub_img_feature_extensor.repeat( - 1, height_ratio * base_feat_size, 1, 1 - ) - - sub_img = torch.cat([sub_img, temporary_extensor], dim=2).reshape( - 1, -1, self.image_dim_out - ) - - # Merge global and sub - output_imgs.append( - torch.cat( - [sub_img, self.global_img_feature_extensor, global_img], dim=1 - ) - ) - - img_set_tensor = [] - for output_img in output_imgs: - output_img = output_img.to(device=target_device, dtype=target_dtype) - img_feature_proj = self.img_projection(output_img) - img_set_tensor.append(img_feature_proj.flatten(0, 1)) - - return img_set_tensor - - -class Phi4MultimodalAudioMLP(nn.Module): - def __init__( - self, - config: Phi4MultimodalAudioConfig, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - ): - super().__init__() - self.layer_norm = nn.LayerNorm(config.hidden_size) - self.act_fn = MulAndSilu() - self.gate_up_proj = MergedColumnParallelLinear( - config.hidden_size, - [config.intermediate_size] * 2, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj", - ) - self.down_proj = RowParallelLinear( - config.intermediate_size, - config.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.down_proj", - ) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.layer_norm(hidden_states) - hidden_states, _ = self.gate_up_proj(hidden_states) - hidden_states = self.act_fn(hidden_states) - hidden_states, _ = self.down_proj(hidden_states) - return hidden_states - - -class Phi4MultimodalAudioAttention(nn.Module): - def __init__( - self, - config: Phi4MultimodalAudioConfig, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - ): - super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.total_num_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.total_num_heads - if self.head_dim * self.total_num_heads != self.embed_dim: - raise ValueError( - "embed_dim must be divisible by num_heads " - f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads})." - ) - self.scale = self.head_dim**-0.5 - - self.qkv_proj = QKVParallelLinear( - hidden_size=self.embed_dim, - head_size=self.head_dim, - total_num_heads=self.total_num_heads, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - - self.o_proj = RowParallelLinear( - input_size=self.embed_dim, - output_size=self.embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.out_proj", - ) - - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - self.num_heads = divide(self.total_num_heads, self.tp_size) - - def split_attn_mask(self, attention_mask: torch.Tensor) -> torch.Tensor: - start_idx = self.num_heads * self.tp_rank - end_idx = self.num_heads * (self.tp_rank + 1) - return attention_mask[:, start_idx:end_idx] - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - ) -> torch.Tensor: - qkv_states, _ = self.qkv_proj(hidden_states) - query, key, value = qkv_states.chunk(3, dim=-1) - - bsz, seq_len, _ = query.size() - query = query.view(bsz, seq_len, self.num_heads, self.head_dim) - key = key.view(bsz, seq_len, self.num_heads, self.head_dim) - value = value.view(bsz, seq_len, self.num_heads, self.head_dim) - query, key, value = (x.transpose(1, 2) for x in (query, key, value)) - - attention_mask = self.split_attn_mask(attention_mask) - out = F.scaled_dot_product_attention( - query, - key, - value, - scale=self.scale, - attn_mask=attention_mask, - ) - out = out.transpose(1, 2).reshape(bsz, seq_len, -1) - - attn_output, _ = self.o_proj(out) - - return attn_output - - -class Phi4MultimodalAudioConformerEncoderLayer(nn.Module): - def __init__(self, config: Phi4MultimodalAudioConfig): - super().__init__() - - self.feed_forward_in = Phi4MultimodalAudioMLP(config) - self.self_attn = Phi4MultimodalAudioAttention(config) - self.conv = Phi4MultimodalAudioConvModule(config) - self.feed_forward_out = Phi4MultimodalAudioMLP(config) - self.layer_norm_att = nn.LayerNorm(config.hidden_size) - self.layer_norm = nn.LayerNorm(config.hidden_size) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - ) -> torch.Tensor: - residual = hidden_states + 0.5 * self.feed_forward_in(hidden_states) - hidden_states = self.layer_norm_att(residual) - - hidden_states = residual + self.self_attn(hidden_states, attention_mask) - hidden_states = hidden_states + self.conv(hidden_states) - hidden_states = hidden_states + 0.5 * self.feed_forward_out(hidden_states) - - out = self.layer_norm(hidden_states) - - return out - - -class Phi4MMAudioMeanVarianceNormLayer(nn.Module): - """Mean/variance normalization layer. - - Will subtract mean and multiply input by inverted standard deviation. - Typically used as a very first layer in a model. - - Args: - config: [Phi4MultimodalAudioConfig](https://huggingface.co/docs/transformers/model_doc/phi4_multimodal#transformers.Phi4MultimodalAudioConfig) - object containing model parameters. - """ - - def __init__(self, config: Phi4MultimodalAudioConfig): - super().__init__() - self.global_mean = nn.Parameter(torch.zeros(config.input_size)) - self.global_invstd = nn.Parameter(torch.ones(config.input_size)) - - def forward(self, input_: torch.Tensor) -> torch.Tensor: - """MeanVarianceNormLayer Forward - - Args: - input_: torch.Tensor - input tensor. - """ - return (input_ - self.global_mean) * self.global_invstd - - -class Phi4MultimodalAudioModel(nn.Module): - def __init__(self, config: Phi4MultimodalAudioConfig): - super().__init__() - self.config = config - - self.encoder_embedding = Phi4MMAudioMeanVarianceNormLayer(config) - self.embed = Phi4MultimodalAudioNemoConvSubsampling(config) - self.relative_attention_bias_layer = Phi4MultimodalAudioRelativeAttentionBias( - config - ) - self.encoders = nn.ModuleList( - [ - Phi4MultimodalAudioConformerEncoderLayer(config) - for _ in range(config.num_blocks) - ] - ) - - def _streaming_mask( - self, - seq_len: int, - batch_size: int, - chunk_size: int, - left_chunk: int, - ): - # Create mask matrix for streaming - # S stores start index. if chunksize is 18, s is [0,18,36,....] - chunk_start_idx = np.arange(0, seq_len, chunk_size) - - enc_streaming_mask = ( - adaptive_enc_mask(seq_len, chunk_start_idx, left_window=left_chunk) - .unsqueeze(0) - .expand([batch_size, -1, -1]) - ) - return enc_streaming_mask - - def forward_embeddings( - self, - hidden_states: torch.Tensor, - masks: torch.Tensor, - ): - """Forwarding the inputs through the top embedding layers""" - seq_len = math.ceil(hidden_states.shape[1] / self.config.time_reduction) - if seq_len <= 0: - raise ValueError( - f"Sequence length after time reduction is invalid: {seq_len}." - "Your input feature is too short." - ) - - batch_size = hidden_states.shape[0] - - enc_streaming_mask = self._streaming_mask( - seq_len, batch_size, self.config.chunk_size, self.config.left_chunk - ) - enc_streaming_mask = enc_streaming_mask.to(hidden_states.device) - - hidden_states, masks = self.embed(hidden_states, masks) - - streaming_mask = enc_streaming_mask - if streaming_mask is not None and masks is not None: - hs_mask = masks & streaming_mask - elif masks is not None: - hs_mask = masks - else: - hs_mask = streaming_mask - - return hidden_states, hs_mask, masks - - def calculate_hs_mask( - self, hidden_states: torch.Tensor, device: torch.device, mask: torch.Tensor - ): - max_audio_length = hidden_states.shape[1] - batch_size = hidden_states.shape[0] - enc_streaming_mask = self._streaming_mask( - max_audio_length, batch_size, self.config.chunk_size, self.config.left_chunk - ) - enc_streaming_mask = enc_streaming_mask.to(device) - if mask is None: - return enc_streaming_mask - - feature_lens = mask.sum(1) - padding_length = feature_lens - pad_mask = torch.arange(0, max_audio_length, device=device).expand( - padding_length.size(0), -1 - ) < padding_length.unsqueeze(1) - pad_mask = pad_mask.unsqueeze(1) - pad_mask = pad_mask & enc_streaming_mask - return pad_mask - - def forward(self, hidden_states: torch.Tensor, mask: torch.Tensor | None = None): - hidden_states = self.encoder_embedding(hidden_states) - hidden_states, hs_mask, mask = self.forward_embeddings(hidden_states, mask) - - unfolded = False - bs, seq_len, _ = hidden_states.shape - max_seq_len = 500 # maximum position for absolute positional encoding - if seq_len > max_seq_len: - # audio sequence is longer than max_seq_len, - # unfold it into chunks of max_seq_len - unfolded = True - # the unfold op will drop residual frames, - # pad it to the multiple of max_seq_len - if seq_len % max_seq_len > 0: - chunk_pad_size = max_seq_len - (seq_len % max_seq_len) - else: - chunk_pad_size = 0 - if chunk_pad_size > 0: - hidden_states_pad = F.pad( - hidden_states, (0, 0, 0, chunk_pad_size), "constant", 0 - ) - hidden_states = hidden_states_pad.to(hidden_states.device) - - hidden_states = unfold_tensor(hidden_states, max_seq_len) - masks_unfold = None - if mask is not None: - # revise hs_mask here because the previous calculated hs_mask - # did not consider extra pad - subsampled_pad_mask = mask.squeeze(1) # [bz, subsampled_unmask_seq_len] - extra_padded_subsamlped_pad_mask = F.pad( - subsampled_pad_mask, (0, chunk_pad_size), "constant", False - ) # extra padding to the pad mask - extra_padded_subsamlped_pad_mask = ( - extra_padded_subsamlped_pad_mask.unsqueeze(-1).float() - ) - masks_unfold = unfold_tensor( - extra_padded_subsamlped_pad_mask, max_seq_len - ) # unfold the pad mask like we did to the input tensor - masks_unfold = masks_unfold.squeeze( - -1 - ).bool() # unfold op does not support bool tensor - hs_mask = self.calculate_hs_mask( - hidden_states, hidden_states.device, masks_unfold - ) # calculate hs_mask based on the unfolded pad mask - - relative_attention_bias = self.relative_attention_bias_layer(hidden_states) - attention_mask = hs_mask.unsqueeze(1) + relative_attention_bias - - for layer in self.encoders: - hidden_states = layer(hidden_states, attention_mask) - - if unfolded: - embed_dim = hidden_states.shape[-1] - hidden_states = hidden_states.reshape(bs, -1, embed_dim) - # if we ever padded before unfolding, we need to remove the padding - if chunk_pad_size > 0: - hidden_states = hidden_states[:, :-chunk_pad_size, :] - - return hidden_states - - -class Phi4MMAudioEmbedding(nn.Module): - def __init__(self, config: Phi4MultimodalConfig): - super().__init__() - self.config = config - self.layer_idx = config.audio_config.feature_layer - - self.encoder = Phi4MultimodalAudioModel(config.audio_config) - - audio_config = config.audio_config - proj_input_size = audio_config.hidden_size * audio_config.downsample_rate - self.vision_speech_projection = Phi4MMProjector( - proj_input_size, config.hidden_size - ) - self.speech_projection = Phi4MMProjector(proj_input_size, config.hidden_size) - - def get_projection( - self, - audio_projection_mode: Literal["speech", "vision"], - ) -> Phi4MMProjector: - if audio_projection_mode == "speech": - return self.speech_projection - elif audio_projection_mode == "vision": - return self.vision_speech_projection - - def forward( - self, - audio_input_features: torch.FloatTensor, - audio_embed_sizes=None, - audio_attention_mask=None, - audio_projection_mode="speech", - ) -> torch.FloatTensor: - audio_projection = self.get_projection(audio_projection_mode) - - target_device = audio_projection.up.bias.device - target_dtype = audio_projection.up.bias.dtype - - audio_input_features = audio_input_features.to( - device=target_device, dtype=target_dtype - ) - - audio_encoder_hidden_states = self.encoder( - audio_input_features, audio_attention_mask - ) - audio_embeds = audio_projection(audio_encoder_hidden_states) - - return audio_embeds.flatten(0, 1) - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - - for name, loaded_weight in weights: - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -class Phi4MMImagePixelInputs(TensorSchema): - """ - Dimensions: - - bn: Batch size * number of images - - p: Number of patches (1 + num_patches) - - c: Number of channels (3) - - h: Height of each image patch - - w: Width of each image patch - - nc: Number of crops - - H_mask: Height of attention mask - - W_mask: Width of attention mask - """ - - type: Literal["pixel_values"] - - pixel_values: Annotated[ - torch.Tensor | list[torch.Tensor], - TensorShape( - "bn", "p", 3, "h", "w", dynamic_dims={"p"} - ), # may be different per batch and image - ] - - image_sizes: Annotated[ - torch.Tensor, - TensorShape("bn", 2), # (height, width) - ] - - num_img_tokens: Annotated[ - list[int], - TensorShape("bn"), - ] - - image_attention_mask: Annotated[ - torch.Tensor, - TensorShape("bn", "nc", 32, 32), # H_mask, W_mask - ] - - -class Phi4MMImageEmbeddingInputs(TensorSchema): - """ - Dimensions: - - bn: Batch size * number of images - - f: Image feature size - - h: Hidden size (must match language model backbone) - """ - - type: Literal["image_embeds"] - - data: Annotated[ - torch.Tensor | list[torch.Tensor], - TensorShape("bn", "f", "h"), - ] - - -class Phi4MMAudioFeatureInputs(TensorSchema): - """ - Dimensions: - - bn: Batch size * number of audios - - f: Number of Mel filterbank bins (80) - - t: Time frames (M) - """ - - type: Literal["audio_features"] - - audio_features: Annotated[ - torch.Tensor | list[torch.Tensor], - TensorShape("bn", "t", 80, dynamic_dims={"t"}), - ] - - -class Phi4MMAudioEmbeddingInputs(TensorSchema): - """ - Dimensions: - - b: Batch size - - n: Number of audios - - f: Audio feature size - - h: Hidden size (must match language model backbone) - """ - - type: Literal["audio_embeds"] - - data: Annotated[ - NestedTensors, - TensorShape("b", "n", "f", "h"), - ] - - -Phi4MMImageInput: TypeAlias = Phi4MMImagePixelInputs | Phi4MMImageEmbeddingInputs -Phi4MMAudioInputs: TypeAlias = Phi4MMAudioFeatureInputs | Phi4MMAudioEmbeddingInputs - - -def cat_with_pad(tensors, dim, padding_value=0): - """ - cat along dim, while pad to max for all other dims - """ - ndim = tensors[0].dim() - assert all(t.dim() == ndim for t in tensors[1:]), ( - "All tensors must have the same number of dimensions" - ) - - out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)] - out_size[dim] = sum(t.shape[dim] for t in tensors) - output = tensors[0].new_full(out_size, padding_value) - - index = 0 - for t in tensors: - # Create a slice list where every dimension except dim is full slice - slices = [slice(0, t.shape[d]) for d in range(ndim)] - # Update only the concat dimension slice - slices[dim] = slice(index, index + t.shape[dim]) - - output[slices] = t - index += t.shape[dim] - - return output - - -class Phi4MMProcessingInfo(BaseProcessingInfo): - def get_hf_config(self) -> Phi4MultimodalConfig: - return self.ctx.get_hf_config(Phi4MultimodalConfig) - - def get_hf_processor(self, **kwargs: object) -> Phi4MMProcessor: - return self.ctx.get_hf_processor(Phi4MMProcessor, **kwargs) - - def get_feature_extractor(self, **kwargs: object) -> Phi4MultimodalFeatureExtractor: - return self.get_hf_processor(**kwargs).audio_processor - - def get_image_processor( - self, - processor: Phi4MMProcessor | None = None, - ) -> Phi4MultimodalImageProcessorFast: - if processor is None: - processor = self.get_hf_processor() - return processor.image_processor - - def get_dynamic_hd( - self, - processor: Phi4MMProcessor | None = None, - ) -> int: - return self.get_image_processor(processor).dynamic_hd - - def get_supported_mm_limits(self) -> Mapping[str, int | None]: - return {"audio": None, "image": None} - - def _find_target_aspect_ratio( - self, - orig_width: int, - orig_height: int, - image_size: int, - max_num: int, - min_num: int, - ): - w_crop_num = math.ceil(orig_width / float(image_size)) - h_crop_num = math.ceil(orig_height / float(image_size)) - if w_crop_num * h_crop_num > max_num: - aspect_ratio = orig_width / orig_height - - # calculate the existing image aspect ratio - target_ratios = set( - (i, j) - for i in range(1, max_num + 1) - for j in range(1, max_num + 1) - if i * j <= max_num and i * j >= min_num - ) - target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) - - # find the closest aspect ratio to the target - image_processor = self.get_image_processor() - target_aspect_ratio = image_processor.find_closest_aspect_ratio( - aspect_ratio, - target_ratios, - orig_width, - orig_height, - image_size, - ) - - # calculate the target width and height - target_width = image_size * target_aspect_ratio[0] - target_height = image_size * target_aspect_ratio[1] - else: - target_width = image_size * w_crop_num - target_height = image_size * h_crop_num - target_aspect_ratio = (w_crop_num, h_crop_num) - return target_aspect_ratio, target_height, target_width - - def _compute_num_image_tokens( - self, - orig_width: int, - orig_height: int, - dynamic_hd_size: int, - vit_image_size: int, - vit_patch_size: int, - token_compression_factor: int = 2, - ): - """ - compute the number of tokens an image is expected to take up considering - the image encoder architecture and exclude output features containing - only padding pixels - - for siglip, vit_image_size=448, vit_patch_size=14, so output will be - 32x32 feature map - NOTE right now, Phi4MM uses hard-coded token_compression_factor=2 - """ - assert vit_image_size % vit_patch_size == 0, ( - "vit_image_size must be divisible by vit_patch_size" - ) - assert vit_image_size // vit_patch_size % token_compression_factor == 0, ( - "vit_image_size // vit_patch_size must be divisible by " - "token_compression_factor" - ) - - target_aspect_ratio, target_height, target_width = ( - self._find_target_aspect_ratio( - orig_width, orig_height, vit_image_size, dynamic_hd_size, min_num=1 - ) - ) - assert target_aspect_ratio[0] * vit_image_size == target_width, ( - f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}" - ) - assert target_aspect_ratio[1] * vit_image_size == target_height, ( - f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}" - ) - assert ( - target_height % vit_image_size == 0 and target_width % vit_image_size == 0 - ) - - padding_height, padding_width = _get_padding_size( - orig_width, orig_height, target_height, target_width - ) - assert padding_width == 0 or padding_height == 0, ( - "padding_width or padding_height must be 0" - ) - - target_feat_width = target_width // vit_patch_size - target_feat_height = target_height // vit_patch_size - if padding_width >= vit_patch_size: - assert padding_height == 0, "padding_height not 0" - non_pad_feat_width = target_feat_width - math.floor( - padding_width / vit_patch_size - ) - non_pad_feat_height = target_feat_height - elif padding_height >= vit_patch_size: - assert padding_width == 0, "padding_width not 0" - non_pad_feat_height = target_feat_height - math.floor( - padding_height / vit_patch_size - ) - non_pad_feat_width = target_feat_width - else: - # small padding shorter than a vit patch - non_pad_feat_width = target_feat_width - non_pad_feat_height = target_feat_height - - feat_width = non_pad_feat_width // token_compression_factor - feat_height = non_pad_feat_height // token_compression_factor - # NOTE it's possible that the non-padding feature is not divisible - if non_pad_feat_width % token_compression_factor != 0: - feat_width += 1 - if non_pad_feat_height % token_compression_factor != 0: - feat_height += 1 - num_hd_patch_tokens = feat_width * feat_height - num_hd_newline_tokens = feat_height - vit_feature_size = vit_image_size // vit_patch_size - num_global_image_tokens = (vit_feature_size // token_compression_factor) ** 2 - num_sep_tokens = 1 - num_global_image_newline_tokens = vit_feature_size // token_compression_factor - - return ( - num_global_image_tokens - + num_sep_tokens - + num_hd_patch_tokens - + num_hd_newline_tokens - + num_global_image_newline_tokens - ) - - def get_num_image_tokens( - self, - *, - image_width: int, - image_height: int, - processor: Phi4MMProcessor | None = None, - ) -> int: - hf_config = self.get_hf_config() - vision_config = hf_config.vision_config - vit_image_size = vision_config.image_size - vit_patch_size = vision_config.patch_size - - dynamic_hd_size = self.get_dynamic_hd(processor=processor) - - # we use default `token_compression_factor=2`, - # since it's not in HF vision config. - image_num_tokens = self._compute_num_image_tokens( - image_width, - image_height, - dynamic_hd_size=dynamic_hd_size, - vit_image_size=vit_image_size, - vit_patch_size=vit_patch_size, - ) - - return image_num_tokens - - def get_image_size_with_most_features( - self, - processor: Phi4MMProcessor | None = None, - ) -> ImageSize: - vit_image_size = self.get_hf_config().vision_config.image_size - - max_side = vit_image_size * self.get_dynamic_hd(processor=processor) - return ImageSize(height=max_side, width=vit_image_size) - - def get_audio_num_frames(self, audio_len: int, sr: float) -> int: - """ - Compute the output size of the `extract_features` method. - - Args: - audio_len (int): Length of the input waveform in samples. - sr (float): Sampling rate of the waveform, either 16000 or 8000. - - Returns: - tuple (int, int): Output size as (T, D), where: - T: Number of time frames. - D: Number of Mel filterbank bins (80). - """ - - # Resample to 16000 or 8000 if needed - if sr > 16000: - audio_len //= sr // 16000 - elif 8000 <= sr < 16000: - # We'll resample to 16K from 8K - audio_len *= 2 - elif sr < 8000: - raise RuntimeError(f"Unsupported sample rate {sr}") - - # Spectrogram parameters for 16 kHz - win_length = 400 # Frame length in samples - hop_length = 160 # Frame shift in samples - - # Calculate number of frames (T) - num_frames = (audio_len - win_length) // hop_length + 1 - if num_frames < 1: - raise ValueError("Waveform too short for given parameters.") - - # Return time frames (T) - return num_frames - - def _compute_audio_embed_size(self, audio_frames: int) -> int: - """ - Compute the size of audio embeddings from the number of audio frames. - """ - # `_compute_audio_embed_size` in audio_processor use torch for - # computation, therefore we re-implement it to use pythonic - # numeric computation to avoid extra tensor conversion. - audio_processor = self.get_feature_extractor() - audio_compression_rate = audio_processor.audio_compression_rate - audio_downsample_rate = audio_processor.audio_downsample_rate - - integer = audio_frames // audio_compression_rate - remainder = audio_frames % audio_compression_rate - result = integer + int(remainder > 0) - - integer = result // audio_downsample_rate - remainder = result % audio_downsample_rate - result = integer + int(remainder > 0) # qformer compression - - return result - - -class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: - num_audios = mm_counts.get("audio", 0) - num_images = mm_counts.get("image", 0) - - tokenizer = self.info.get_tokenizer() - image_tokens: str = tokenizer.image_token * num_images - audio_tokens: str = tokenizer.audio_token * num_audios - - return image_tokens + audio_tokens - - def get_dummy_mm_data( - self, - seq_len: int, - mm_counts: Mapping[str, int], - mm_options: Mapping[str, BaseDummyOptions] | None = None, - ) -> MultiModalDataDict: - num_audios = mm_counts.get("audio", 0) - num_images = mm_counts.get("image", 0) - - target_width, target_height = self.info.get_image_size_with_most_features() - - image_overrides = mm_options.get("image") if mm_options else None - audio_overrides = mm_options.get("audio") if mm_options else None - - mm_data = { - "image": self._get_dummy_images( - width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides, - ), - "audio": self._get_dummy_audios( - length=_AUDIO_MAX_SOUNDFILE_SIZE, - num_audios=num_audios, - overrides=audio_overrides, - ), - } - - return mm_data - - -class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): - def _get_data_parser(self) -> MultiModalDataParser: - feature_extractor = self.info.get_feature_extractor() - return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) - - def _call_hf_processor( - self, - prompt: str, - mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object], - tok_kwargs: Mapping[str, object], - ) -> BatchFeature: - if not mm_data: - prompt_ids = self.info.get_tokenizer().encode(prompt) - prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) - return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") - - audio_data = mm_data.pop("audios", []) - if audio_data: - mm_data["audio"] = audio_data - - processed_outputs = super()._call_hf_processor( - prompt, mm_data, mm_kwargs, tok_kwargs - ) - - if "image_pixel_values" in processed_outputs: - num_img_tokens = [ - self.info.get_num_image_tokens( - image_width=img_size[0], image_height=img_size[1] - ) - for img_size in processed_outputs["image_sizes"] - ] - processed_outputs["num_img_tokens"] = num_img_tokens - - if audio_data: - audio_features = processed_outputs["audio_input_features"] - sr = self.info.get_feature_extractor(**mm_kwargs).sampling_rate - feature_sizes = [ - self.info.get_audio_num_frames(len(audio), sr) for audio in audio_data - ] - processed_outputs["audio_input_features"] = [ - audio_features[idx, :size] for idx, size in enumerate(feature_sizes) - ] - - return processed_outputs - - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - return dict( - image_pixel_values=MultiModalFieldConfig.batched("image"), - image_attention_mask=MultiModalFieldConfig.batched("image"), - image_sizes=MultiModalFieldConfig.batched("image"), - num_img_tokens=MultiModalFieldConfig.batched("image"), - audio_input_features=MultiModalFieldConfig.batched("audio"), - ) - - def _get_prompt_updates( - self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargsItems, - ) -> Sequence[PromptUpdate]: - tokenizer = self.info.get_tokenizer() - image_token_id: int = tokenizer.vocab[tokenizer.image_token] - audio_token_id: int = tokenizer.vocab[tokenizer.audio_token] - - hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - audio_processor = self.info.get_feature_extractor(**hf_processor_mm_kwargs) - - def get_image_replacement_phi4mm(item_idx: int): - images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems) - ) - - if isinstance(images, ImageEmbeddingItems): - num_image_tokens = images.get_feature_size(item_idx) - else: - image_size = images.get_image_size(item_idx) - num_image_tokens = self.info.get_num_image_tokens( - image_width=image_size.width, - image_height=image_size.height, - processor=hf_processor, - ) - - return [image_token_id] * num_image_tokens - - def get_audio_replacement_phi4mm(item_idx: int): - audios = mm_items.get_items("audio", AudioProcessorItems) - # TODO(Isotr0py): support embedding inputs - audio_len = audios.get_audio_length(item_idx) - audio_frames = self.info.get_audio_num_frames( - audio_len, audio_processor.sampling_rate - ) - audio_embed_size = self.info._compute_audio_embed_size(audio_frames) - - return [audio_token_id] * audio_embed_size - - return [ - PromptReplacement( - modality="audio", - target=[audio_token_id], - replacement=get_audio_replacement_phi4mm, - ), - PromptReplacement( - modality="image", - target=[image_token_id], - replacement=get_image_replacement_phi4mm, - ), - ] - - -@MULTIMODAL_REGISTRY.register_processor( - Phi4MMMultiModalProcessor, - info=Phi4MMProcessingInfo, - dummy_inputs=Phi4MMDummyInputsBuilder, -) -class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): - """ - Implements the Phi-4-multimodal-instruct model in vLLM. - """ - - merge_by_field_config = True - - packed_modules_mapping = { - "qkv_proj": [ - "qkv_proj", - ], - "gate_up_proj": [ - "gate_up_proj", - ], - } - - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={ - # Multimodal embedding - "model.embed_tokens_extend.": "", - # LLM backbone - "model.": "language_model.model.", - }, - orig_to_new_substr={ - # projection - ".img_projection_": ".img_projection.", - ".up_proj_for_speech.": ".speech_projection.up.", - ".up_proj_for_vision_speech.": ".vision_speech_projection.up.", - ".down_proj_for_speech.": ".speech_projection.down.", - ".down_proj_for_vision_speech.": ".vision_speech_projection.down.", - }, - ) - - @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> str | None: - if modality.startswith("image"): - return "<|image|>" - if modality.startswith("audio"): - return "<|audio|>" - - raise ValueError("Only image or audio modality is supported") - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - multimodal_config = vllm_config.model_config.multimodal_config - self.config = config - self.multimodal_config = multimodal_config - - # TODO: Optionally initializes these for supporting input embeddings. - self.image_embed = Phi4MMImageEmbedding( - config, - # prefix=maybe_prefix(prefix, "image_embed"), - ) - self.audio_embed = Phi4MMAudioEmbedding( - config, - # prefix=maybe_prefix(prefix, "audio_embed"), - ) - - self.language_model = init_vllm_registered_model( - vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model"), - architectures=["Phi3ForCausalLM"], - ) - - self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors - ) - - def _parse_and_validate_audio_input( - self, **kwargs: object - ) -> Phi4MMAudioInputs | None: - """ - Parse and validate the audio input to the model. This handles both - audio features and audio embeddings, but only the former is used for - now. - - Args: - kwargs (object): Keyword arguments. - - Returns: - Optional[Phi4MMAudioInputs]: Parsed and validated audio inputs. - """ - audio_features = kwargs.pop("audio_input_features", None) - audio_embeds = kwargs.pop("audio_embeds", None) - - if audio_features is None and audio_embeds is None: - return None - - if audio_features is not None: - return Phi4MMAudioFeatureInputs( - type="audio_features", - audio_features=audio_features, - ) - - if audio_embeds is not None: - return Phi4MMAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds) - - raise AssertionError("This line should be unreachable.") - - def _process_audio_input( - self, audio_input: Phi4MMAudioInputs, audio_projection_mode: str - ) -> NestedTensors: - """ - Create the audio embeddings from the audio input, where the audio input - is pairs of audio features and audio embed lengths. The audio input is - created by `input_mapper_for_phi4mm_audio`. - - Args: - audio_input (Phi4MMAudioInputs): Audio input. - - Returns: - NestedTensors: Audio embeddings - """ - if audio_input["type"] == "audio_embeds": - return audio_input["data"] - - audio_features = audio_input["audio_features"] - # (e.g. multiple examples) and the second dim is the multi-audio dim - # (e.g. multiple audios in the same example) - - dtype = next(self.audio_embed.parameters()).dtype - audio_embeds = [ - self.audio_embed( - features.unsqueeze(0).to(dtype), - audio_projection_mode=audio_projection_mode, - ) - for features in audio_features - ] - return audio_embeds - - def _parse_and_validate_image_input( - self, **kwargs: object - ) -> Phi4MMImagePixelInputs | None: - pixel_values = kwargs.get("image_pixel_values") - if pixel_values is None: - return None - - image_sizes = kwargs.get("image_sizes") - image_attention_mask = kwargs.get("image_attention_mask") - num_img_tokens = kwargs.get("num_img_tokens") - assert ( - image_sizes is not None - and image_attention_mask is not None - and num_img_tokens is not None - ), "Missing image inputs" - - return Phi4MMImagePixelInputs( - type="pixel_values", - pixel_values=pixel_values, - image_sizes=image_sizes, - image_attention_mask=image_attention_mask, - num_img_tokens=num_img_tokens, - ) - - def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: - modalities = {} - - # Preserve the order of modalities if there are multiple of them - # from the order of kwargs. - for input_key in kwargs: - if ( - input_key in ("image_pixel_values", "image_embeds") - and "images" not in modalities - ): - modalities["images"] = self._parse_and_validate_image_input(**kwargs) - if ( - input_key in ("audio_input_features", "audio_embeds") - and "audios" not in modalities - ): - modalities["audios"] = self._parse_and_validate_audio_input(**kwargs) - - return modalities - - def _process_image_input( - self, image_input: Phi4MMImagePixelInputs - ) -> list[torch.Tensor]: - if image_input["type"] == "image_embeds": - image_embeds = image_input["image_embeds"].type(self.visual.dtype) - else: - dtype = next(self.image_embed.parameters()).dtype - pixel_values = image_input["pixel_values"].to(dtype) - image_sizes = image_input["image_sizes"] - image_attention_mask = image_input["image_attention_mask"] - image_embeds = self.image_embed( - pixel_values, image_sizes, image_attention_mask - ) - return image_embeds - - def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: - modalities = self._parse_and_validate_multimodal_inputs(**kwargs) - if not modalities: - return [] - - # The result multimodal_embeddings is tuple of tensors, with each - # tensor corresponding to a multimodal data item (image or video). - multimodal_embeddings: tuple[torch.Tensor, ...] = () - - # NOTE: It is important to iterate over the keys in this dictionary - # to preserve the order of the modalities. - audio_projection_mode = "speech" - for modality in modalities: - # make sure process images first - if modality == "images": - audio_projection_mode = "vision" - image_input = modalities["images"] - image_embeddings = self._process_image_input(image_input) - multimodal_embeddings += tuple(image_embeddings) - if modality == "audios": - audio_input = modalities["audios"] - audio_embeddings = self._process_audio_input( - audio_input, audio_projection_mode=audio_projection_mode - ) - multimodal_embeddings += tuple(audio_embeddings) - - return multimodal_embeddings - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - **kwargs: object, - ) -> torch.Tensor: - if intermediate_tensors is not None: - inputs_embeds = None - - hidden_states = self.language_model( - input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds, - ) - - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - ) -> torch.Tensor | None: - return self.language_model.compute_logits(hidden_states) - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) - return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) - - def get_mm_mapping(self) -> MultiModelKeys: - """ - Get the module prefix in multimodal models - """ - return MultiModelKeys.from_string_field( - language_model="language_model.", - connector=[ - "img_projection", - "vision_speech_projection", - "speech_projection", - ], - tower_model=["image_embed", "audio_embed"], - ) - - def get_language_model(self) -> torch.nn.Module: - return self.language_model diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index 8425549a7bd20..179d5df869bea 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -984,8 +984,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): Implements the Phi-4-multimodal-instruct model in vLLM. """ - merge_by_field_config = True - packed_modules_mapping = { "qkv_proj": [ "qkv_proj", diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 49530776f8903..14f73d0c64586 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -352,7 +352,6 @@ class PhiMoEAttention(nn.Module): ) self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position, rope_parameters=rope_parameters, is_neox_style=True, diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 54bde75cc0131..555e6ea4b8cb2 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -59,8 +59,8 @@ from vllm.multimodal.processing import ( from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.tokenizers import MistralTokenizer -from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config +from vllm.tokenizers import cached_tokenizer_from_config +from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -366,8 +366,6 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]) dummy_inputs=PixtralDummyInputsBuilder, ) class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - merge_by_field_config = True - @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 472de5590dcf8..6765ee0c5779c 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -574,7 +574,6 @@ class Plamo2AttentionMixer(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position, rope_parameters=config.rope_parameters, ) diff --git a/vllm/model_executor/models/plamo3.py b/vllm/model_executor/models/plamo3.py index 4aeb9d432dcc6..3557104d905cb 100644 --- a/vllm/model_executor/models/plamo3.py +++ b/vllm/model_executor/models/plamo3.py @@ -179,7 +179,6 @@ class Plamo3AttentionMixer(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position, rope_parameters=rope_parameters, ) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 12285cf9c1968..61a6e67805d6a 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -114,7 +114,6 @@ class QWenAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, rope_parameters=rope_parameters, ) @@ -282,6 +281,9 @@ class QWenBaseModel(nn.Module): self.transformer.make_empty_intermediate_tensors ) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.wte(input_ids) + def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 34c31d8deee23..f4c2d3cb75d25 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -122,6 +122,8 @@ class Qwen2Attention(nn.Module): prefix: str = "", attn_type: str = AttentionType.DECODER, dual_chunk_attention_config: dict[str, Any] | None = None, + qk_norm: bool = False, + rms_norm_eps: float = 1e-6, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -144,6 +146,7 @@ class Qwen2Attention(nn.Module): self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.dual_chunk_attention_config = dual_chunk_attention_config + self.qk_norm = qk_norm self.qkv_proj = QKVParallelLinear( hidden_size, @@ -162,9 +165,13 @@ class Qwen2Attention(nn.Module): prefix=f"{prefix}.o_proj", ) + # QK Normalization support (used in BAGEL and some other models) + if self.qk_norm: + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position, rope_parameters=rope_parameters, dual_chunk_attention_config=dual_chunk_attention_config, @@ -198,6 +205,23 @@ class Qwen2Attention(nn.Module): ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # Apply QK normalization if enabled (before RoPE) + if self.qk_norm: + # Reshape to apply per-head normalization + # q shape: (total_tokens, q_size) -> (total_tokens, num_heads, head_dim) + total_tokens = q.shape[0] + q = q.view(total_tokens, self.num_heads, self.head_dim) + k = k.view(total_tokens, self.num_kv_heads, self.head_dim) + + # Apply normalization + q = self.q_norm(q) + k = self.k_norm(k) + + # Reshape back + q = q.view(total_tokens, self.q_size) + k = k.view(total_tokens, self.kv_size) + q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) @@ -228,6 +252,9 @@ class Qwen2DecoderLayer(nn.Module): else: attn_type = AttentionType.ENCODER_ONLY + # Check if QK normalization is enabled (used in BAGEL and some other models) + qk_norm = getattr(config, "qk_norm", False) + self.self_attn = Qwen2Attention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -239,6 +266,8 @@ class Qwen2DecoderLayer(nn.Module): prefix=f"{prefix}.self_attn", attn_type=attn_type, dual_chunk_attention_config=dual_chunk_attention_config, + qk_norm=qk_norm, + rms_norm_eps=config.rms_norm_eps, ) self.mlp = Qwen2MLP( hidden_size=self.hidden_size, @@ -481,6 +510,8 @@ class Qwen2Model(nn.Module): continue if is_pp_missing_parameter(name, self): continue + if name not in params_dict: + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) @@ -503,7 +534,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - config = vllm_config.model_config.hf_config + config = vllm_config.model_config.hf_config.get_text_config() quant_config = vllm_config.quant_config self.config = config diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index 7506ee8656fda..f9bce4bf981b2 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -88,7 +88,6 @@ from vllm.multimodal.processing import ( ) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.tokenizer import encode_tokens from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import ( @@ -591,7 +590,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( tokenization_kwargs=tokenization_kwargs, ) tokenizer = self.info.get_tokenizer() - prompt_ids = encode_tokens(tokenizer, prompt) + prompt_ids = tokenizer.encode(prompt) else: prompt_ids = self._apply_hf_processor_tokens_only(prompt) @@ -774,8 +773,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration( SupportsMRoPE, Qwen2_5OmniConditionalGenerationMixin, ): - merge_by_field_config = True - hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "thinker.lm_head.": "language_model.lm_head.", @@ -848,6 +845,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration( norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), quant_config=quant_config, prefix=maybe_prefix(prefix, "visual"), + multimodal_config=multimodal_config, ) else: self.visual = None diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 6ca490f467634..b730ac0315893 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -42,13 +42,9 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( ) from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.attention.layer import maybe_get_vit_flash_attn_backend -from vllm.attention.ops.vit_attn_wrappers import ( - vit_flash_attn_wrapper, - vit_torch_sdpa_wrapper, -) +from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig +from vllm.config import MultiModalConfig, VllmConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils from vllm.forward_context import set_forward_context @@ -64,6 +60,9 @@ from vllm.model_executor.layers.linear import ( ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.rotary_embedding.common import ( + ApplyRotaryEmb, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.vision import should_torch_compile_mm_vit @@ -77,7 +76,7 @@ from vllm.multimodal.evs import ( from vllm.multimodal.inputs import ( MultiModalFeatureSpec, MultiModalFieldConfig, - MultiModalKwargs, + MultiModalKwargsItems, ) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import PromptReplacement, PromptUpdate @@ -99,7 +98,6 @@ from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder from .qwen2_vl import ( Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo, - apply_rotary_pos_emb_vision, ) from .utils import ( AutoWeightsLoader, @@ -267,10 +265,15 @@ class Qwen2_5_VisionMLP(nn.Module): bias: bool = False, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, ): super().__init__() + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.gate_up_proj = MergedColumnParallelLinear( input_size=in_features, output_sizes=[hidden_features] * 2, # [gate_proj, up_proj] @@ -304,13 +307,16 @@ class Qwen2_5_VisionAttention(nn.Module): num_heads: int, projection_size: int, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() # Per attention head and per partition values. + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.tp_size = ( 1 if use_data_parallel @@ -342,18 +348,14 @@ class Qwen2_5_VisionAttention(nn.Module): prefix=f"{prefix}.proj", disable_tp=use_data_parallel, ) - self.attn_backend = attn_backend - self.attn_backend, self.flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - attn_backend_override=attn_backend_override, - ) + + self.attn = MMEncoderAttention( + num_heads=self.num_attention_heads_per_partition, + head_size=self.hidden_size_per_attention_head, + multimodal_config=multimodal_config, ) - self.is_flash_attn_backend = self.attn_backend in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.ROCM_AITER_FA, - } + self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True) def forward( self, @@ -380,8 +382,10 @@ class Qwen2_5_VisionAttention(nn.Module): qk_reshaped = einops.rearrange( qk, "b s two head head_dim -> (two b) s head head_dim", two=2 ) - qk_rotated = apply_rotary_pos_emb_vision( - qk_reshaped, cos=rotary_pos_emb_cos, sin=rotary_pos_emb_sin + qk_rotated = self.apply_rotary_emb( + qk_reshaped, + rotary_pos_emb_cos, + rotary_pos_emb_sin, ) qk_rotated = qk_rotated.view( 2, @@ -394,32 +398,17 @@ class Qwen2_5_VisionAttention(nn.Module): else: q, k, v = qkv.unbind(dim=2) - if self.is_flash_attn_backend: - context_layer = vit_flash_attn_wrapper( - q, - k, - v, - cu_seqlens, - max_seqlen, - batch_size, - self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA, - ) - elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: - # Execute attention entry by entry for speed & less VRAM. - from vllm.platforms import current_platform + context_layer = self.attn( + query=q, + key=k, + value=v, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) - # Never remove the next contiguous logic - # Without it, hallucinations occur with the backend - if current_platform.is_rocm(): - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - context_layer = vit_torch_sdpa_wrapper( - q, - k, - v, - cu_seqlens, - ) + context_layer = einops.rearrange( + context_layer, "b s h d -> s b (h d)", b=batch_size + ).contiguous() output, _ = self.proj(context_layer) return output @@ -443,10 +432,8 @@ class Qwen2_5_VisionBlock(nn.Module): act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, norm_layer: Callable[[int], nn.Module] | None = None, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() if norm_layer is None: @@ -458,10 +445,8 @@ class Qwen2_5_VisionBlock(nn.Module): num_heads=num_heads, projection_size=dim, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.attn", - use_data_parallel=use_data_parallel, - attn_backend=attn_backend, - attn_backend_override=attn_backend_override, ) self.mlp = Qwen2_5_VisionMLP( dim, @@ -469,8 +454,8 @@ class Qwen2_5_VisionBlock(nn.Module): act_fn=act_fn, bias=True, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.mlp", - use_data_parallel=use_data_parallel, ) def forward( @@ -542,10 +527,15 @@ class Qwen2_5_VisionPatchMerger(nn.Module): norm_layer: Callable[[int], nn.Module] | None = None, spatial_merge_size: int = 2, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, ) -> None: super().__init__() + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.hidden_size = context_dim * (spatial_merge_size**2) if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) @@ -586,9 +576,8 @@ class Qwen2_5_VisionTransformer(nn.Module): vision_config: Qwen2_5_VLVisionConfig, norm_eps: float = 1e-6, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -598,7 +587,6 @@ class Qwen2_5_VisionTransformer(nn.Module): depth = vision_config.depth self.hidden_size = vision_config.hidden_size self.num_heads = vision_config.num_heads - self.use_data_parallel = use_data_parallel self.out_hidden_size = vision_config.out_hidden_size # args for get_window_index_thw @@ -612,7 +600,7 @@ class Qwen2_5_VisionTransformer(nn.Module): # DO NOT MOVE THIS IMPORT from vllm.compilation.backends import set_model_tag - with set_model_tag("Qwen2_5_VisionPatchEmbed"): + with set_model_tag("Qwen2_5_VisionPatchEmbed", is_encoder=True): self.patch_embed = Qwen2_5_VisionPatchEmbed( patch_size=patch_size, temporal_patch_size=temporal_patch_size, @@ -624,24 +612,22 @@ class Qwen2_5_VisionTransformer(nn.Module): head_dim = self.hidden_size // self.num_heads self.rotary_pos_emb = get_rope( head_size=head_dim, - rotary_dim=head_dim // 2, max_position=8192, is_neox_style=True, + rope_parameters={"partial_rotary_factor": 0.5}, ) + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend + if multimodal_config is not None + else None + ) self.attn_backend = get_vit_attn_backend( head_size=head_dim, dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - self.attn_backend, self.flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - attn_backend_override=attn_backend_override, - ) - ) - if self.attn_backend not in { AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.TORCH_SDPA, @@ -651,7 +637,7 @@ class Qwen2_5_VisionTransformer(nn.Module): f"Qwen2.5-VL does not support {self.attn_backend} backend now." ) - with set_model_tag("Qwen2_5_VisionBlock"): + with set_model_tag("Qwen2_5_VisionBlock", is_encoder=True): self.blocks = nn.ModuleList( [ Qwen2_5_VisionBlock( @@ -661,24 +647,22 @@ class Qwen2_5_VisionTransformer(nn.Module): act_fn=get_act_and_mul_fn(vision_config.hidden_act), norm_layer=norm_layer, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.blocks.{layer_idx}", - use_data_parallel=use_data_parallel, - attn_backend=self.attn_backend, - attn_backend_override=attn_backend_override, ) for layer_idx in range(depth) ] ) - with set_model_tag("Qwen2_5_VisionPatchMerger"): + with set_model_tag("Qwen2_5_VisionPatchMerger", is_encoder=True): self.merger = Qwen2_5_VisionPatchMerger( d_model=vision_config.out_hidden_size, context_dim=self.hidden_size, norm_layer=norm_layer, spatial_merge_size=self.spatial_merge_size, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.merger", - use_data_parallel=use_data_parallel, ) @property @@ -973,7 +957,7 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor): self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) @@ -1039,9 +1023,6 @@ class Qwen2_5_VLForConditionalGeneration( SupportsMultiModalPruning, SupportsMRoPE, ): - merge_by_field_config = True - multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"} - packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], @@ -1203,18 +1184,12 @@ class Qwen2_5_VLForConditionalGeneration( if multimodal_config.get_limit_per_prompt( "image" ) or multimodal_config.get_limit_per_prompt("video"): - attn_backend_override = ( - multimodal_config.mm_encoder_attn_backend - if multimodal_config is not None - else None - ) self.visual = Qwen2_5_VisionTransformer( vision_config=config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), quant_config=self.quant_config, prefix=maybe_prefix(prefix, "visual"), - use_data_parallel=self.use_data_parallel, - attn_backend_override=attn_backend_override, + multimodal_config=multimodal_config, ) else: self.visual = None diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 7e883a393aa8d..f84ddfa84f6ab 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -313,8 +313,6 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor[Qwen2AudioProcessing dummy_inputs=Qwen2AudioDummyInputsBuilder, ) class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - merge_by_field_config = True - @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("audio"): diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 5a428740082f6..2750f1864b81a 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -244,7 +244,6 @@ class Qwen2MoeAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, rope_parameters=rope_parameters, dual_chunk_attention_config=dual_chunk_attention_config, @@ -367,6 +366,8 @@ class Qwen2MoeModel(nn.Module): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, @@ -512,6 +513,12 @@ class Qwen2MoeModel(nn.Module): continue else: name = remapped_kv_scale_name + # GGUF: make sure that shared_expert_gate is a 2D tensor. + if ( + "mlp.shared_expert_gate" in name + and len(loaded_weight.shape) == 1 + ): + loaded_weight = loaded_weight[None, :] param = params_dict[name] weight_loader = getattr( param, "weight_loader", default_weight_loader diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 8fbd896223944..321fbd764c0f5 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -25,6 +25,7 @@ # limitations under the License. """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" +import math from collections.abc import Callable, Iterable, Mapping, Sequence from functools import partial from typing import Annotated, Any, Literal, TypeAlias @@ -32,7 +33,6 @@ from typing import Annotated, Any, Literal, TypeAlias import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F from einops import rearrange from transformers import BatchFeature from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor @@ -44,12 +44,10 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.attention.layer import ( - maybe_get_vit_flash_attn_backend, -) -from vllm.config import VllmConfig +from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention +from vllm.config import MultiModalConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions -from vllm.distributed import parallel_state +from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import utils as dist_utils from vllm.logger import init_logger from vllm.model_executor.layers.activation import QuickGELU @@ -61,8 +59,7 @@ from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding.common import ( - apply_rotary_emb_torch, - dispatch_rotary_emb_function, + ApplyRotaryEmb, ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys @@ -250,10 +247,15 @@ class Qwen2VisionMLP(nn.Module): hidden_features: int, act_layer: type[nn.Module] = QuickGELU, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, ): super().__init__() + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.fc1 = ColumnParallelLinear( in_features, hidden_features, @@ -277,16 +279,6 @@ class Qwen2VisionMLP(nn.Module): return x -def apply_rotary_pos_emb_vision( - t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor -) -> torch.Tensor: - rotary_emb_function = dispatch_rotary_emb_function( - default=partial(apply_rotary_emb_torch, is_neox_style=True) - ) - output = rotary_emb_function(t, cos, sin).type_as(t) - return output - - class Qwen2VisionAttention(nn.Module): def __init__( self, @@ -294,12 +286,16 @@ class Qwen2VisionAttention(nn.Module): num_heads: int, projection_size: int, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() # Per attention head and per partition values. + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.tp_size = ( 1 if use_data_parallel @@ -328,41 +324,32 @@ class Qwen2VisionAttention(nn.Module): disable_tp=use_data_parallel, ) - # Detect attention implementation. - self.attn_backend = get_vit_attn_backend( + self.attn = MMEncoderAttention( + num_heads=self.num_attention_heads_per_partition, head_size=self.hidden_size_per_attention_head, - dtype=torch.get_default_dtype(), - attn_backend_override=attn_backend_override, + multimodal_config=multimodal_config, ) - self.attn_backend, self.flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - attn_backend_override=attn_backend_override, - ) - ) - - if self.attn_backend not in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.TORCH_SDPA, - AttentionBackendEnum.ROCM_AITER_FA, - }: - raise RuntimeError( - f"Qwen2-VL does not support {self.attn_backend} backend now." - ) - - self.is_flash_attn_backend = self.attn_backend in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.ROCM_AITER_FA, - } + self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True) def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # [s, b, 3 * head * head_dim] seq_len, bs, _ = qkv.shape + if self.tp_size > 1: + qkv = tensor_model_parallel_all_gather(qkv) # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] q, k, v = qkv.chunk(3, dim=2) + # 3 * [s, b, head * head_dim] + if self.tp_size > 1: + splitter = partial( + dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size + ) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + v = splitter(v)[self.tp_rank] + # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] new_shape = ( seq_len, @@ -386,60 +373,27 @@ class Qwen2VisionAttention(nn.Module): # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] q, k, v = self.split_qkv(x) - batch_size = q.shape[1] q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v)) # [2 * b, s, heads, head_dim] qk_concat = torch.cat([q, k], dim=0) - qk_rotated = apply_rotary_pos_emb_vision( - qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin + qk_rotated = self.apply_rotary_emb( + qk_concat, + rotary_pos_emb_cos, + rotary_pos_emb_sin, ) q, k = torch.chunk(qk_rotated, 2, dim=0) - if self.is_flash_attn_backend: - q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) + context_layer = self.attn( + query=q, + key=k, + value=v, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) - output = self.flash_attn_varlen_func( - q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0.0, - causal=False, - ) - - context_layer = rearrange( - output, "(b s) h d -> s b (h d)", b=batch_size - ).contiguous() - elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: - # Execute attention entry by entry for speed & less VRAM. - from vllm.platforms import current_platform - - if current_platform.is_rocm(): - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - outputs = [] - for i in range(1, len(cu_seqlens)): - start_idx = cu_seqlens[i - 1] - end_idx = cu_seqlens[i] - q_i = q[:, start_idx:end_idx] - k_i = k[:, start_idx:end_idx] - v_i = v[:, start_idx:end_idx] - q_i, k_i, v_i = ( - rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i] - ) - output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) - output_i = rearrange(output_i, "b h s d -> b s h d ") - outputs.append(output_i) - context_layer = torch.cat(outputs, dim=1) - context_layer = rearrange( - context_layer, "b s h d -> s b (h d)" - ).contiguous() + context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous() output, _ = self.proj(context_layer) return output @@ -454,9 +408,8 @@ class Qwen2VisionBlock(nn.Module): act_layer: type[nn.Module] = QuickGELU, norm_layer: Callable[[int], nn.Module] | None = None, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() if norm_layer is None: @@ -470,17 +423,16 @@ class Qwen2VisionBlock(nn.Module): num_heads=num_heads, projection_size=dim, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.attn", - use_data_parallel=use_data_parallel, - attn_backend_override=attn_backend_override, ) self.mlp = Qwen2VisionMLP( dim, mlp_hidden_dim, act_layer=act_layer, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.mlp", - use_data_parallel=use_data_parallel, ) def forward( @@ -540,10 +492,15 @@ class Qwen2VisionPatchMerger(nn.Module): norm_layer: Callable[[int], nn.Module] | None = None, spatial_merge_size: int = 2, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, ) -> None: super().__init__() + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.hidden_size = context_dim * (spatial_merge_size**2) if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) @@ -587,9 +544,8 @@ class Qwen2VisionTransformer(nn.Module): vision_config: Qwen2VLVisionConfig, norm_eps: float = 1e-6, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() @@ -603,7 +559,11 @@ class Qwen2VisionTransformer(nn.Module): num_heads = vision_config.num_heads mlp_ratio = vision_config.mlp_ratio - self.use_data_parallel = use_data_parallel + self.use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.out_hidden_size = vision_config.hidden_size self.spatial_merge_size = spatial_merge_size @@ -621,9 +581,9 @@ class Qwen2VisionTransformer(nn.Module): head_dim = embed_dim // num_heads self.rotary_pos_emb = get_rope( head_size=head_dim, - rotary_dim=head_dim // 2, max_position=8192, is_neox_style=True, + rope_parameters={"partial_rotary_factor": 0.5}, ) self.blocks = nn.ModuleList( @@ -635,8 +595,7 @@ class Qwen2VisionTransformer(nn.Module): norm_layer=norm_layer, quant_config=quant_config, prefix=f"{prefix}.blocks.{layer_idx}", - use_data_parallel=use_data_parallel, - attn_backend_override=attn_backend_override, + multimodal_config=multimodal_config, ) for layer_idx in range(depth) ] @@ -647,7 +606,10 @@ class Qwen2VisionTransformer(nn.Module): norm_layer=norm_layer, quant_config=quant_config, prefix=f"{prefix}.merger", - use_data_parallel=use_data_parallel, + multimodal_config=multimodal_config, + ) + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend if multimodal_config else None ) self.attn_backend = get_vit_attn_backend( head_size=head_dim, @@ -708,7 +670,7 @@ class Qwen2VisionTransformer(nn.Module): AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.ROCM_AITER_FA, }: - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() return max_seqlen def forward( @@ -811,14 +773,14 @@ def _create_qwen2vl_field_factory( image_embeds=MultiModalFieldConfig.flat_from_sizes( "image", image_embed_grid_sizes ), - image_grid_thw=MultiModalFieldConfig.batched("image"), + image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True), pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( "video", video_grid_sizes ), video_embeds=MultiModalFieldConfig.flat_from_sizes( "video", video_embed_grid_sizes ), - video_grid_thw=MultiModalFieldConfig.batched("video"), + video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True), ) return _qwen2vl_field_config @@ -959,13 +921,42 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): return num_video_tokens def get_image_size_with_most_features(self) -> ImageSize: - max_image_size, _ = self._get_vision_info( - image_width=9999999, - image_height=9999999, - num_frames=1, - image_processor=None, - ) - return max_image_size + # NOTE: Simply processing a huge size with _get_vision_info might not give a + # size that maximizes the number of featrues, i.e., the number of (merged) + # patches. This is because the number of patches limits the allowed aspect + # ratios. For example, suppose the maximum number of patches is 1280. A square + # image cannot be broken down into 1280 patches, so feeding a giant square image + # into _get_vision_info will not yield a size that maximizes the number of + # patches. Therefore, we directly factorize the maximum number of patches into + # height and width. The tricky part is to avoid extreme aspect ratios (>200 for + # qwen2-vl). If we can't find a suitable aspect ratio, we decrease the number of + # patches and retry. This is safe because the processor does not accept extreme + # aspect ratios, so there is no valid post-resize image with the number of + # patches that yields extreme aspect ratios. + + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + patch_size = vision_config.patch_size + merge_size = vision_config.spatial_merge_size + image_processor = self.get_image_processor() + max_pixels = image_processor.max_pixels or image_processor.size["longest_edge"] + unit = patch_size * merge_size + max_seq_len = max_pixels // (unit * unit) + + def closest_factor_pair(n: int) -> tuple[int, int]: + # left <= right + for d in range(math.isqrt(n), 0, -1): + if n % d == 0: + return d, n // d + return 1, n + + height_factor, width_factor = 1, max_seq_len + for seq_len in range(max_seq_len, 0, -1): + height_factor, width_factor = closest_factor_pair(seq_len) + if width_factor / height_factor <= 200: + break + + return ImageSize(width=unit * width_factor, height=unit * height_factor) def get_max_image_tokens(self) -> int: target_width, target_height = self.get_image_size_with_most_features() @@ -1131,9 +1122,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]) class Qwen2VLForConditionalGeneration( nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE ): - merge_by_field_config = True - multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"} - # To ensure correct weight loading and mapping. hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ @@ -1286,18 +1274,12 @@ class Qwen2VLForConditionalGeneration( if multimodal_config.get_limit_per_prompt( "image" ) or multimodal_config.get_limit_per_prompt("video"): - attn_backend_override = ( - multimodal_config.mm_encoder_attn_backend - if multimodal_config is not None - else None - ) self.visual = Qwen2VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), quant_config=quant_config, + multimodal_config=multimodal_config, prefix=maybe_prefix(prefix, "visual"), - use_data_parallel=self.use_data_parallel, - attn_backend_override=attn_backend_override, ) else: self.visual = None @@ -1394,9 +1376,11 @@ class Qwen2VLForConditionalGeneration( else: pixel_values_videos = video_input["pixel_values_videos"] if self.use_data_parallel: - grid_thw_list = grid_thw.tolist() return run_dp_sharded_mrope_vision_model( - self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d" + self.visual, + pixel_values_videos, + grid_thw.tolist(), + rope_type="rope_3d", ) else: video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) @@ -1576,15 +1560,6 @@ class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration): } ) - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - # Tarsier2 uses llava as model_type, which will create a Qwen2VLConfig - # as text_config, we need to reconstruct Qwen2VLConfig from LlavaConfig. - config = vllm_config.model_config.hf_config - qwen2vl_config = config.text_config - qwen2vl_config.architectures = config.architectures - vllm_config.model_config.hf_config = qwen2vl_config - super().__init__(vllm_config=vllm_config, prefix=prefix) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: skip_prefixes = [] if self.visual is None: diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index 7d2b3e5f9bc79..0d0da52ed7382 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -111,7 +111,6 @@ class Qwen3Attention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position, rope_parameters=rope_parameters, dual_chunk_attention_config=dual_chunk_attention_config, diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 6f520706a3176..0be81ecc7dd3a 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -269,7 +269,6 @@ class Qwen3MoeAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, rope_parameters=rope_parameters, dual_chunk_attention_config=dual_chunk_attention_config, @@ -403,6 +402,7 @@ class Qwen3MoeModel(nn.Module): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.config = config + self.quant_config = quant_config self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, @@ -505,6 +505,19 @@ class Qwen3MoeModel(nn.Module): loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + assert loaded_weight.numel() == 1, ( + f"KV scale numel {loaded_weight.numel()} != 1" + ) + loaded_weight = loaded_weight.squeeze() + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 661a182151d74..ccf6cc6e5894b 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -747,10 +747,8 @@ class Qwen3NextAttention(nn.Module): self.rotary_emb = get_rope( head_size=self.head_dim, - rotary_dim=self.head_dim, max_position=config.max_position_embeddings, rope_parameters=config.rope_parameters, - partial_rotary_factor=config.partial_rotary_factor, dual_chunk_attention_config=self.dual_chunk_attention_config, ) @@ -1094,6 +1092,8 @@ class Qwen3NextModel(nn.Module): name.endswith(".bias") or name.endswith("_bias") ) and name not in params_dict: continue + if name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader( @@ -1110,6 +1110,11 @@ class Qwen3NextModel(nn.Module): continue if is_pp_missing_parameter(name, self): continue + if name not in params_dict: + logger.warning_once( + f"Parameter {name} not found in params_dict, skip loading" + ) + continue param = params_dict[name] weight_loader = getattr( param, "weight_loader", default_weight_loader diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index 39dd42552ae8f..089129e443c01 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -48,7 +48,7 @@ from transformers.models.whisper import WhisperFeatureExtractor from vllm.attention.backends.registry import AttentionBackendEnum from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig +from vllm.config import MultiModalConfig, VllmConfig from vllm.distributed import get_pp_group from vllm.logger import init_logger from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY @@ -62,6 +62,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2_audio import Qwen2AudioProcessingInfo from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems @@ -191,6 +192,7 @@ class Qwen3_VisionBlock(nn.Module): mlp_hidden_dim: int, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, norm_layer: Callable[[int], nn.Module] | None = None, + multimodal_config: MultiModalConfig | None = None, quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: @@ -204,6 +206,7 @@ class Qwen3_VisionBlock(nn.Module): num_heads=num_heads, projection_size=dim, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.attn", ) self.mlp = Qwen3_VisionMLP( @@ -298,8 +301,8 @@ class Qwen3Omni_VisionTransformer(nn.Module): vision_config, norm_eps: float = 1e-6, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() self.hidden_size = vision_config.hidden_size @@ -332,9 +335,9 @@ class Qwen3Omni_VisionTransformer(nn.Module): head_dim = self.hidden_size // self.num_heads self.rotary_pos_emb = get_rope( head_size=head_dim, - rotary_dim=head_dim // 2, max_position=8192, is_neox_style=True, + rope_parameters={"partial_rotary_factor": 0.5}, ) self.blocks = nn.ModuleList( @@ -346,6 +349,7 @@ class Qwen3Omni_VisionTransformer(nn.Module): act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], norm_layer=norm_layer, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.blocks.{layer_idx}", ) for layer_idx in range(vision_config.depth) @@ -375,6 +379,12 @@ class Qwen3Omni_VisionTransformer(nn.Module): ] ) + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend + if multimodal_config is not None + else None + ) + self.attn_backend = get_vit_attn_backend( head_size=head_dim, dtype=torch.get_default_dtype(), @@ -493,7 +503,10 @@ class Qwen3Omni_VisionTransformer(nn.Module): cu_seqlens: torch.Tensor, ) -> torch.Tensor: max_seqlen = torch.zeros([], device=cu_seqlens.device) - if self.attn_backend == AttentionBackendEnum.FLASH_ATTN: + if self.attn_backend in { + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, + }: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() return max_seqlen @@ -1127,8 +1140,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( SupportsMRoPE, Qwen3OmniMoeConditionalGenerationMixin, ): - merge_by_field_config = True - hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "thinker.lm_head.": "language_model.lm_head.", @@ -1137,6 +1148,18 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( } ) + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): @@ -1174,17 +1197,12 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config) - attn_backend_override = ( - multimodal_config.mm_encoder_attn_backend - if multimodal_config is not None - else None - ) self.visual = Qwen3Omni_VisionTransformer( vision_config=thinker_config.vision_config, norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), quant_config=quant_config, prefix=maybe_prefix(prefix, "visual"), - attn_backend_override=attn_backend_override, + multimodal_config=multimodal_config, ) self.quant_config = quant_config @@ -1763,3 +1781,13 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( mrope_position_delta = llm_positions.max() + 1 - seq_len return llm_positions, mrope_position_delta + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="visual.merger", + tower_model=["visual.", "audio_tower."], + ) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 1d3929b936a9f..c0589986d1fe8 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -50,7 +50,7 @@ from transformers.video_utils import VideoMetadata from vllm.attention.backends.registry import AttentionBackendEnum from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig +from vllm.config import MultiModalConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions from vllm.distributed import get_pp_group from vllm.logger import init_logger @@ -67,12 +67,19 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.evs import ( + compute_mrope_for_media, + compute_retained_tokens_count, + compute_retention_mask, + recompute_mrope_positions, +) from vllm.multimodal.inputs import ( MultiModalDataDict, MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItem, MultiModalKwargsItems, + PlaceholderRange, VideoItem, ) from vllm.multimodal.parse import ImageSize, MultiModalDataItems, MultiModalDataParser @@ -92,7 +99,9 @@ from .interfaces import ( SupportsLoRA, SupportsMRoPE, SupportsMultiModal, + SupportsMultiModalPruning, SupportsPP, + _require_is_multimodal, ) from .qwen2_5_vl import ( Qwen2_5_VisionAttention, @@ -103,7 +112,7 @@ from .qwen2_5_vl import ( Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs, ) -from .qwen2_vl import Qwen2VLProcessingInfo +from .qwen2_vl import Qwen2VLMultiModalDataParser, Qwen2VLProcessingInfo from .qwen3 import Qwen3ForCausalLM, Qwen3Model from .utils import ( AutoWeightsLoader, @@ -160,10 +169,15 @@ class Qwen3_VisionMLP(nn.Module): bias: bool = False, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, ): super().__init__() + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.linear_fc1 = ColumnParallelLinear( in_features, hidden_features, @@ -197,10 +211,9 @@ class Qwen3_VisionBlock(nn.Module): mlp_hidden_dim: int, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, norm_layer: Callable[[int], nn.Module] | None = None, + multimodal_config: MultiModalConfig | None = None, quant_config: QuantizationConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, ) -> None: super().__init__() if norm_layer is None: @@ -212,9 +225,8 @@ class Qwen3_VisionBlock(nn.Module): num_heads=num_heads, projection_size=dim, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.attn", - use_data_parallel=use_data_parallel, - attn_backend=attn_backend, ) self.mlp = Qwen3_VisionMLP( dim, @@ -222,8 +234,8 @@ class Qwen3_VisionBlock(nn.Module): act_fn=act_fn, bias=True, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.mlp", - use_data_parallel=use_data_parallel, ) def forward( @@ -255,10 +267,15 @@ class Qwen3_VisionPatchMerger(nn.Module): spatial_merge_size: int = 2, use_postshuffle_norm: bool = False, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, ) -> None: super().__init__() + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.hidden_size = context_dim * (spatial_merge_size**2) self.use_postshuffle_norm = use_postshuffle_norm @@ -304,9 +321,8 @@ class Qwen3_VisionTransformer(nn.Module): vision_config: Qwen3VLVisionConfig, norm_eps: float = 1e-6, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ) -> None: super().__init__() self.hidden_size = vision_config.hidden_size @@ -317,7 +333,6 @@ class Qwen3_VisionTransformer(nn.Module): self.spatial_merge_unit = self.spatial_merge_size**2 self.temporal_patch_size = vision_config.temporal_patch_size self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes - self.use_data_parallel = use_data_parallel self.num_grid_per_side = int(self.num_position_embeddings**0.5) # NOTE: This is used for creating empty tensor for all_gather for @@ -339,9 +354,9 @@ class Qwen3_VisionTransformer(nn.Module): head_dim = self.hidden_size // self.num_heads self.rotary_pos_emb = get_rope( head_size=head_dim, - rotary_dim=head_dim // 2, max_position=8192, is_neox_style=True, + rope_parameters={"partial_rotary_factor": 0.5}, ) self.merger = Qwen3_VisionPatchMerger( @@ -350,8 +365,8 @@ class Qwen3_VisionTransformer(nn.Module): norm_layer=norm_layer, spatial_merge_size=self.spatial_merge_size, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.merger", - use_data_parallel=use_data_parallel, ) self.deepstack_merger_list = nn.ModuleList( @@ -363,13 +378,16 @@ class Qwen3_VisionTransformer(nn.Module): use_postshuffle_norm=True, norm_layer=norm_layer, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.deepstack_merger_list.{layer_idx}", - use_data_parallel=use_data_parallel, ) for layer_idx in range(len(self.deepstack_visual_indexes)) ] ) + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend if multimodal_config else None + ) self.attn_backend = get_vit_attn_backend( head_size=head_dim, dtype=torch.get_default_dtype(), @@ -393,9 +411,8 @@ class Qwen3_VisionTransformer(nn.Module): act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], norm_layer=norm_layer, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.blocks.{layer_idx}", - use_data_parallel=use_data_parallel, - attn_backend=self.attn_backend, ) for layer_idx in range(vision_config.depth) ] @@ -884,7 +901,10 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]): class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: - return MultiModalDataParser(video_needs_metadata=True) + return Qwen2VLMultiModalDataParser( + self.info.get_hf_config().vision_config.spatial_merge_size, + video_needs_metadata=True, + ) def _call_hf_processor( self, @@ -981,14 +1001,14 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]) image_embeds=MultiModalFieldConfig.flat_from_sizes( "image", image_grid_sizes ), - image_grid_thw=MultiModalFieldConfig.batched("image"), + image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True), pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( "video", video_grid_sizes ), video_embeds=MultiModalFieldConfig.flat_from_sizes( "video", video_grid_sizes ), - video_grid_thw=MultiModalFieldConfig.batched("video"), + video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True), ) def _get_prompt_updates( @@ -1039,13 +1059,39 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]) tokenizer.encode(f"<{curr_time:.1f} seconds>", add_special_tokens=False) for curr_time in timestamps ] - num_tokens_per_frame = int(grid_thw[1:].prod()) // merge_length + tokens_per_frame = int(grid_thw[1:].prod()) // merge_length + per_frame_token_counts = [tokens_per_frame for _ in frames_idx_token] + + video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate + if video_pruning_rate is not None and video_pruning_rate > 0.0: + total_retained = compute_retained_tokens_count( + tokens_per_frame, + len(frames_idx_token), + video_pruning_rate, + ) + if len(frames_idx_token) == 0: + per_frame_token_counts = [] + elif len(frames_idx_token) == 1: + per_frame_token_counts = [tokens_per_frame] + else: + first_frame_tokens = tokens_per_frame + remaining_tokens = max(total_retained - first_frame_tokens, 0) + base = remaining_tokens // (len(frames_idx_token) - 1) + remainder = remaining_tokens % (len(frames_idx_token) - 1) + per_frame_token_counts = [first_frame_tokens] + for frame_idx in range(1, len(frames_idx_token)): + extra = base + (1 if (frame_idx - 1) < remainder else 0) + per_frame_token_counts.append(extra) + placeholder = [] - for frame_idx in frames_idx_token: - placeholder.extend(frame_idx) + for frame_idx, timestamp_tokens in enumerate(frames_idx_token): + placeholder.extend(timestamp_tokens) + tokens_this_frame = per_frame_token_counts[ + frame_idx if frame_idx < len(per_frame_token_counts) else -1 + ] placeholder.extend( [vision_start_token_id] - + [video_token_id] * num_tokens_per_frame + + [video_token_id] * tokens_this_frame + [vision_end_token_id] ) return PromptUpdateDetails.select_token_id(placeholder, video_token_id) @@ -1186,10 +1232,8 @@ class Qwen3VLForConditionalGeneration( SupportsPP, SupportsMRoPE, SupportsEagle3, + SupportsMultiModalPruning, ): - merge_by_field_config = True - multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"} - packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -1231,23 +1275,22 @@ class Qwen3VLForConditionalGeneration( self.config = config self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + self.video_pruning_rate = multimodal_config.video_pruning_rate + self.is_multimodal_pruning_enabled = ( + multimodal_config.is_multimodal_pruning_enabled() + ) + if not multimodal_config.get_limit_per_prompt( "image" ) and not multimodal_config.get_limit_per_prompt("video"): self.visual = None else: - attn_backend_override = ( - multimodal_config.mm_encoder_attn_backend - if multimodal_config is not None - else None - ) self.visual = Qwen3_VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), quant_config=quant_config, + multimodal_config=multimodal_config, prefix=maybe_prefix(prefix, "visual"), - use_data_parallel=self.use_data_parallel, - attn_backend_override=attn_backend_override, ) self.language_model = Qwen3LLMForCausalLM( @@ -1417,6 +1460,109 @@ class Qwen3VLForConditionalGeneration( sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() return video_embeds.split(sizes) + def _postprocess_image_embeds_evs( + self, + image_embeds_split: tuple[torch.Tensor, ...], + image_input: Qwen2_5_VLImageInputs, + ) -> tuple[torch.Tensor, ...]: + """ + Append mrope positions for each for images. + This is necessary to recover correct mrope + positions after video pruning + + Args: + image_embeds_split: Tuple of image embeddings for + each image item. + image_input: Image input data. + + Returns: + Tuple of image embeddings for each image item. + Resulting embeddings will have extra 4 channels for + computed mrope positions. + """ + merge_size = self.visual.spatial_merge_size + grid_thw = image_input["image_grid_thw"] + grid_thw_list = grid_thw.tolist() + image_embeds_out = [] + for emb, size in zip(image_embeds_split, grid_thw_list): + positions = compute_mrope_for_media(size, merge_size).to(emb.device) + emb = torch.cat([emb, positions], dim=1) + image_embeds_out.append(emb) + image_embeds_split = image_embeds_out + return tuple(image_embeds_split) + + def _postprocess_video_embeds_evs( + self, + video_embeds_split: tuple[torch.Tensor, ...], + video_input: Qwen2_5_VLVideoInputs, + ) -> tuple[torch.Tensor, ...]: + """ + Prunes video embeddings via Efficient Video Sampling (EVS) + and then appends mrope positions for each retained embeddings + + Args: + video_embeds_split: Tuple of video embeddings for each video item. + video_input: Video input data. + + Returns: + Tuple of video embeddings for each video item. + Resulting embeddings will have extra 4 channels for + computed mrope positions. + """ + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() + merge_size = self.visual.spatial_merge_size + + # Cast to long to match the original code + # https://github.com/huggingface/transformers/blob/41980ce93e775f6c88500c51c8db7946fc6a2add/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L491 # noqa + second_per_grid_ts = video_input.get("second_per_grid_ts") + if second_per_grid_ts is None: + # For Qwen3-VL, second_per_grid_ts might not be available + # Use default value of 1.0 for each video + second_per_grid_ts = torch.ones(len(grid_thw_list), dtype=torch.long) + else: + second_per_grid_ts = second_per_grid_ts.long() + tokens_per_second = getattr(self.config.vision_config, "tokens_per_second", 1.0) + + video_embeds_out = [] + for emb, size, video_second_per_grid_t in zip( + video_embeds_split, grid_thw_list, second_per_grid_ts + ): + # For each video, we compute retention mask using EVS + retention_mask = compute_retention_mask( + emb, + size, + spatial_merge_size=self.visual.spatial_merge_size, + q=self.video_pruning_rate, + ) + + # Debug logging for EVS pruning + logger.debug( + "EVS: Video tokens pruned from %d to %d (T=%d,H=%d,W=%d, " + "pruning_rate=%.2f, reduction=%.1f%%)", + emb.shape[0], + retention_mask.sum().item(), + size[0], + size[1], + size[2], + self.video_pruning_rate, + (1 - retention_mask.float().mean().item()) * 100, + ) + + positions = compute_mrope_for_media( + size, + merge_size, + tokens_per_second=tokens_per_second, + video_second_per_grid=video_second_per_grid_t.item(), + ).to(emb.device) + + emb = emb[retention_mask] + positions = positions[retention_mask] + emb = torch.cat([emb, positions], dim=1) + video_embeds_out.append(emb) + return tuple(video_embeds_out) + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: mm_input_by_modality = {} for input_key in kwargs: @@ -1439,6 +1585,20 @@ class Qwen3VLForConditionalGeneration( def iter_mm_grid_hw( self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec] ) -> Iterator[tuple[int, int, int]]: + """ + Iterate over multimodal features and yield grid information. + + For videos with EVS (Efficient Video Sampling) enabled, this function + computes the offset based on the pruned token count rather than relying + on input_tokens.index(), which would fail when tokens are pruned. + + Args: + input_tokens: List of token IDs in the prompt + mm_features: List of multimodal feature specifications + + Yields: + Tuple of (offset, grid_h, grid_w) for each frame/image + """ video_token_id = self.config.video_token_id spatial_merge_size = self.config.vision_config.spatial_merge_size for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset): @@ -1451,42 +1611,289 @@ class Qwen3VLForConditionalGeneration( t, h, w = mm_feature.data["video_grid_thw"].data.tolist() llm_grid_h = h // spatial_merge_size llm_grid_w = w // spatial_merge_size - for _ in range(t): - offset = input_tokens.index(video_token_id, offset) - yield offset, llm_grid_h, llm_grid_w - offset += llm_grid_h * llm_grid_w + + # Check if EVS (Efficient Video Sampling) is enabled + is_evs_enabled = ( + hasattr(self, "video_pruning_rate") + and self.video_pruning_rate is not None + and self.video_pruning_rate > 0.0 + ) + + if is_evs_enabled: + frame_offsets = self._extract_frame_offsets_from_mask( + mm_feature.mm_position, t + ) + if frame_offsets is not None: + for rel_offset in frame_offsets: + yield offset + rel_offset, llm_grid_h, llm_grid_w + continue + + # If EVS is enabled but mask is missing, this indicates a bug + # in the prompt processing pipeline. The is_embed mask should + # always be present when video_pruning_rate > 0. + raise RuntimeError( + f"EVS is enabled (pruning_rate={self.video_pruning_rate}) " + "but is_embed mask is missing from mm_position. " + "This indicates a bug in prompt processing." + ) + else: + # Non-EVS mode: Use original logic with input_tokens.index() + for _ in range(t): + offset = input_tokens.index(video_token_id, offset) + yield offset, llm_grid_h, llm_grid_w + offset += llm_grid_h * llm_grid_w else: raise ValueError(f"Unsupported modality: {mm_feature.modality}") + def _get_evs_mask_segments( + self, mm_position: PlaceholderRange, expected_frames: int + ) -> list[torch.Tensor] | None: + """Extract contiguous segments from EVS is_embed mask. + + The EVS (Efficient Video Sampling) mask marks which placeholder + positions should be filled with video embeddings. This method splits + the mask into contiguous segments, where each segment represents one + retained frame. + + This is a pure function - it does not modify any state and always + returns the same output for the same input (idempotent). + + Args: + mm_position: MultiModal position containing the is_embed mask + expected_frames: Expected number of frame segments + + Returns: + List of tensors, each containing indices for one frame segment, + or None if EVS is not enabled or validation fails. + """ + is_embed_mask = getattr(mm_position, "is_embed", None) + if is_embed_mask is None: + return None + + # Find all True positions in the mask + mask_tensor = torch.as_tensor(is_embed_mask, dtype=torch.bool).view(-1) + true_indices = torch.nonzero(mask_tensor, as_tuple=False).flatten() + if true_indices.numel() == 0: + return None + + # Split into contiguous segments (where diff > 1 indicates a gap) + if true_indices.numel() == 1: + segments = [true_indices] + else: + diffs = torch.diff(true_indices) + split_points = torch.nonzero(diffs != 1, as_tuple=False).flatten() + if split_points.numel() == 0: + segments = [true_indices] + else: + segments = torch.tensor_split( + true_indices, split_points.add(1).tolist() + ) + + # Validate segment count matches expected frames + if len(segments) < expected_frames: + logger.debug( + "EVS mask segments (%d) do not match expected frames (%d)", + len(segments), + expected_frames, + ) + return None + + return segments[:expected_frames] + + def _extract_frame_offsets_from_mask( + self, mm_position: PlaceholderRange, expected_frames: int + ) -> list[int] | None: + """Return relative offsets for each EVS-retained frame. + + The prompt processor stores a boolean mask inside ``mm_position`` that + marks which placeholder locations should be populated with video + embeddings. By splitting that mask into contiguous runs we can recover + the start of every retained frame without probing ``input_tokens``. + + Args: + mm_position: MultiModal position containing the is_embed mask + expected_frames: Expected number of frames + + Returns: + List of starting offsets (relative to mm_position) for each frame, + or None if EVS is not enabled. + """ + segments = self._get_evs_mask_segments(mm_position, expected_frames) + if segments is None: + return None + + return [int(segment[0].item()) for segment in segments] + + def _get_actual_frame_token_counts( + self, mm_position: PlaceholderRange, expected_frames: int + ) -> list[int] | None: + """Return actual token count for each EVS-retained frame. + + This function calculates the actual number of tokens per frame by + analyzing the is_embed mask, accounting for EVS pruning. Each frame + may have a different token count due to content-aware pruning. + + Args: + mm_position: MultiModal position containing the is_embed mask + expected_frames: Expected number of frames + + Returns: + List of token counts for each frame, or None if EVS is not enabled. + """ + segments = self._get_evs_mask_segments(mm_position, expected_frames) + if segments is None: + return None + + return [len(seg) for seg in segments] + + def recompute_mrope_positions( + self, + input_ids: list[int], + multimodal_embeddings: tuple[torch.Tensor, ...], + mrope_positions: torch.LongTensor, + num_computed_tokens: int, + ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor, int]: + """ + Update part of input mrope positions (starting with + num_computed_tokens index). Original mrope_positions are computed + for unpruned sequence and becomes incorrect once pruning occurs, + so once we prune media tokens we should reflect this in the + mrope_positions before we feed it to LLM. + + Args: + input_ids: (N,) All input tokens of the prompt (Containing + entire sequence). + multimodal_embeddings: Tuple of multimodal embeddings. + mrope_positions: Existing mrope positions (3, N) for entire + sequence + num_computed_tokens: A number of computed tokens so far. + + Returns: + Tuple of (multimodal_embeddings, mrope_positions, + mrope_position_delta). + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + # Device + device = ( + multimodal_embeddings[0].device + if len(multimodal_embeddings) + else mrope_positions.device + ) + + # Tensors + input_ids_t = torch.as_tensor(input_ids, device=device, dtype=torch.long) + + mm_embeddings_out = [mm[:, :-4] for mm in multimodal_embeddings] + mm_embeddings_pos = [ + mm[:, -4:].permute(1, 0).long() for mm in multimodal_embeddings + ] + + positions, mrope_positions_delta = recompute_mrope_positions( + input_ids_t, + mm_embeddings_pos, + mrope_positions, + num_computed_tokens, + vision_start_token_id, + image_token_id, + video_token_id, + ) + + return tuple(mm_embeddings_out), positions, mrope_positions_delta + def get_mrope_input_positions( self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: + # Pre-collect actual frame token counts for EVS mode + frame_token_counts_map = {} + for mm_feature in mm_features: + if mm_feature.modality == "video": + is_evs_enabled = ( + hasattr(self, "video_pruning_rate") + and self.video_pruning_rate is not None + and self.video_pruning_rate > 0.0 + ) + if is_evs_enabled: + t = mm_feature.data["video_grid_thw"].data.tolist()[0] + token_counts = self._get_actual_frame_token_counts( + mm_feature.mm_position, t + ) + assert token_counts is not None, ( + "EVS enabled but failed to extract frame token counts " + "from is_embed mask" + ) + frame_token_counts_map[mm_feature.mm_position.offset] = token_counts + llm_pos_ids_list = [] st = 0 + frame_counts_idx = {} + for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw( input_tokens, mm_features ): text_len = offset - st st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append( + + # Determine actual token count for this frame + base_offset = None + for feat_offset in frame_token_counts_map: + if offset >= feat_offset: + base_offset = feat_offset + + if base_offset is not None: + # EVS mode: use actual token count from is_embed mask + assert base_offset in frame_token_counts_map, ( + f"Found base_offset {base_offset} but not in frame_token_counts_map" + ) + + if base_offset not in frame_counts_idx: + frame_counts_idx[base_offset] = 0 + + counts = frame_token_counts_map[base_offset] + idx = frame_counts_idx[base_offset] + + assert idx < len(counts), ( + f"EVS frame index {idx} out of range (total frames: {len(counts)})" + ) + + actual_frame_tokens = counts[idx] + frame_counts_idx[base_offset] += 1 + else: + # Non-EVS mode (or image): use theoretical grid size + actual_frame_tokens = llm_grid_h * llm_grid_w + + # Add text segment + text_positions = ( np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx ) + llm_pos_ids_list.append(text_positions) + st_idx += text_len + # Add frame segment with actual token count (not theoretical) grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1) - llm_pos_ids_list.append(grid_indices + text_len + st_idx) - st = offset + llm_grid_h * llm_grid_w + # Only take the first actual_frame_tokens positions + frame_positions = grid_indices[:, :actual_frame_tokens] + st_idx + llm_pos_ids_list.append(frame_positions) + # Update st using actual token count + st = offset + actual_frame_tokens + + # Handle final text segment if st < len(input_tokens): st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 text_len = len(input_tokens) - st - llm_pos_ids_list.append( + final_text_positions = ( np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx ) + llm_pos_ids_list.append(final_text_positions) llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1) mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + return torch.from_numpy(llm_positions), mrope_position_delta def get_language_model(self) -> torch.nn.Module: @@ -1507,9 +1914,17 @@ class Qwen3VLForConditionalGeneration( multimodal_input = mm_input_by_modality[modality] if modality == "image": image_embeddings = self._process_image_input(multimodal_input) + if self.is_multimodal_pruning_enabled: + image_embeddings = self._postprocess_image_embeds_evs( + image_embeddings, multimodal_input + ) multimodal_embeddings += tuple(image_embeddings) if modality == "video": video_embeddings = self._process_video_input(multimodal_input) + if self.is_multimodal_pruning_enabled: + video_embeddings = self._postprocess_video_embeds_evs( + video_embeddings, multimodal_input + ) multimodal_embeddings += tuple(video_embeddings) return multimodal_embeddings @@ -1572,12 +1987,7 @@ class Qwen3VLForConditionalGeneration( if multimodal_embeddings is None or len(multimodal_embeddings) == 0: return inputs_embeds - if is_multimodal is None: - raise ValueError( - "`embed_input_ids` now requires `is_multimodal` arg, " - "please update your model runner according to " - "https://github.com/vllm-project/vllm/pull/16229." - ) + is_multimodal = _require_is_multimodal(is_multimodal) if self.use_deepstack: ( diff --git a/vllm/model_executor/models/qwen3_vl_moe.py b/vllm/model_executor/models/qwen3_vl_moe.py index a054bd5b3831e..3186804488e57 100644 --- a/vllm/model_executor/models/qwen3_vl_moe.py +++ b/vllm/model_executor/models/qwen3_vl_moe.py @@ -419,6 +419,10 @@ class Qwen3VLMoeForConditionalGeneration( self.config = config self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + self.video_pruning_rate = multimodal_config.video_pruning_rate + self.is_multimodal_pruning_enabled = ( + multimodal_config.is_multimodal_pruning_enabled() + ) if not multimodal_config.get_limit_per_prompt( "image" @@ -429,8 +433,8 @@ class Qwen3VLMoeForConditionalGeneration( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), quant_config=quant_config, + multimodal_config=multimodal_config, prefix=maybe_prefix(prefix, "visual"), - use_data_parallel=self.use_data_parallel, ) self.language_model = Qwen3MoeLLMForCausalLM( diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index 55680b8e7ddfd..caac14716782a 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -703,8 +703,6 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): class QwenVLForConditionalGeneration( QWenBaseModel, SupportsPP, SupportsLoRA, SupportsMultiModal ): - merge_by_field_config = True - packed_modules_mapping = { "c_attn": ["c_attn"], "gate_up_proj": [ diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 73a61f1148b50..4575e91e13959 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -145,6 +145,7 @@ _TEXT_GENERATION_MODELS = { "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"), "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), + "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), # transformers's mpt class has lower case "MptForCausalLM": ("mpt", "MPTForCausalLM"), @@ -263,10 +264,15 @@ _CROSS_ENCODER_MODELS = { _MULTIMODAL_MODELS = { # [Decoder-only] "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"), + "AudioFlamingo3ForConditionalGeneration": ( + "audioflamingo3", + "AudioFlamingo3ForConditionalGeneration", + ), "AyaVisionForConditionalGeneration": ( "aya_vision", "AyaVisionForConditionalGeneration", ), + "BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"), "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"), "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"), "ChameleonForConditionalGeneration": ( @@ -373,7 +379,6 @@ _MULTIMODAL_MODELS = { ), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"), - "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"), # noqa: E501 "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501 "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"), # noqa: E501 "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 @@ -424,6 +429,10 @@ _SPECULATIVE_DECODING_MODELS = { "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), + "EagleMistralLarge3ForCausalLM": ( + "mistral_large_3_eagle", + "EagleMistralLarge3ForCausalLM", + ), "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"), @@ -502,6 +511,7 @@ _PREVIOUSLY_SUPPORTED_MODELS = { "MotifForCausalLM": "0.10.2", "Phi3SmallForCausalLM": "0.9.2", "Phi4FlashForCausalLM": "0.10.2", + "Phi4MultimodalForCausalLM": "0.12.0", # encoder-decoder models except whisper # have been removed for V0 deprecation. "BartModel": "0.10.2", diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 31cc645099141..45b6e93307ac3 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -57,12 +57,6 @@ class RobertaEmbedding(nn.Module): torch.arange(config.max_position_embeddings).unsqueeze(0), ) - self.position_embedding_type = config.position_embedding_type - if self.position_embedding_type != "absolute": - raise ValueError( - "Only 'absolute' position_embedding_type" + " is supported" - ) - def forward( self, input_ids: torch.Tensor, @@ -135,12 +129,12 @@ class RobertaEmbeddingModel(BertEmbeddingModel): def _build_model( self, vllm_config: VllmConfig, prefix: str = "" ) -> BertModel | BertWithRope: - if vllm_config.model_config.hf_config.position_embedding_type == "rotary": - return JinaRobertaModel(vllm_config=vllm_config, prefix=prefix) + hf_config = vllm_config.model_config.hf_config + kwargs = dict(vllm_config=vllm_config, prefix=prefix) + if getattr(hf_config, "position_embedding_type", "absolute") == "absolute": + return BertModel(**kwargs, embedding_class=RobertaEmbedding) else: - return BertModel( - vllm_config=vllm_config, prefix=prefix, embedding_class=RobertaEmbedding - ) + return JinaRobertaModel(**kwargs) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): weights_list = list(weights) diff --git a/vllm/model_executor/models/seed_oss.py b/vllm/model_executor/models/seed_oss.py index 267c60157506d..f25223c782552 100644 --- a/vllm/model_executor/models/seed_oss.py +++ b/vllm/model_executor/models/seed_oss.py @@ -161,7 +161,6 @@ class SeedOssAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position, rope_parameters=rope_parameters, ) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 9db1423d98e07..2600dc1c9f79c 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -989,7 +989,6 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): is_pooling_model = True packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} - merge_by_field_config = True @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py index bbce01995412c..efdee255ab5eb 100644 --- a/vllm/model_executor/models/siglip2navit.py +++ b/vllm/model_executor/models/siglip2navit.py @@ -6,14 +6,14 @@ within a vision language model.""" from collections.abc import Iterable import torch -from einops import rearrange, repeat from torch import nn from torch.nn import functional as F from transformers import Siglip2VisionConfig from transformers.configuration_utils import PretrainedConfig from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.attention.layer import maybe_get_vit_flash_attn_backend +from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention +from vllm.config import MultiModalConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.conv import Conv2dLayer @@ -25,11 +25,12 @@ from vllm.model_executor.layers.linear import ( RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding.common import ( + ApplyRotaryEmb, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.platforms import current_platform -from .vision import get_vit_attn_backend - class VisionRotaryEmbedding(nn.Module): def __init__(self, dim: int, theta: float = 10000.0) -> None: @@ -147,40 +148,6 @@ class Siglip2VisionEmbeddings(nn.Module): return patch_embeds -# copy from flash_attn/layers/rotary.py -def rotate_half(x, interleaved=False): - if not interleaved: - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - else: - x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange( - torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 - ) - - -def apply_rotary_emb_torch(x, cos, sin, interleaved=False): - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) - """ - ro_dim = cos.shape[-1] * 2 - assert ro_dim <= x.shape[-1] - cos = repeat( - cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" - ) - sin = repeat( - sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" - ) - return torch.cat( - [ - x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, - x[..., ro_dim:], - ], - dim=-1, - ) - - def apply_rotary_pos_emb( q: torch.Tensor, k: torch.Tensor, @@ -190,14 +157,20 @@ def apply_rotary_pos_emb( ) -> tuple[torch.Tensor, torch.Tensor]: cos = cos.chunk(2, dim=-1)[0].contiguous() sin = sin.chunk(2, dim=-1)[0].contiguous() - if is_flash_attn_backend and not current_platform.is_xpu(): - from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb - apply_rotary_emb_func = apply_rotary_emb + apply_rotary_emb = ApplyRotaryEmb( + enforce_enable=True, + enable_fp32_compute=True, + ) + + if is_flash_attn_backend and not current_platform.is_cuda(): + apply_rotary_emb_func = apply_rotary_emb.forward_cuda else: - apply_rotary_emb_func = apply_rotary_emb_torch - q_embed = apply_rotary_emb_func(q.float(), cos.float(), sin.float()).type_as(q) - k_embed = apply_rotary_emb_func(k.float(), cos.float(), sin.float()).type_as(k) + apply_rotary_emb_func = apply_rotary_emb.forward_native + + q_embed = apply_rotary_emb_func(q, cos, sin) + k_embed = apply_rotary_emb_func(k, cos, sin) + return q_embed, k_embed @@ -208,6 +181,7 @@ class Siglip2Attention(nn.Module): self, config: Siglip2VisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", use_data_parallel: bool = False, attn_backend_override: AttentionBackendEnum | None = None, @@ -227,20 +201,25 @@ class Siglip2Attention(nn.Module): self.dropout = config.attention_dropout self.is_causal = False - # TODO(Isotr0py): Enable data parallel after we support - # disabling TP on parallel linear layer + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.qkv_proj = QKVParallelLinear( hidden_size=self.embed_dim, head_size=self.head_dim, total_num_heads=self.num_heads, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", + disable_tp=use_data_parallel, ) self.out_proj = RowParallelLinear( input_size=self.embed_dim, output_size=self.embed_dim, quant_config=quant_config, prefix=f"{prefix}.out_proj", + disable_tp=use_data_parallel, ) self.tp_size = ( @@ -249,31 +228,13 @@ class Siglip2Attention(nn.Module): self.num_heads_per_partition = divide(self.num_heads, self.tp_size) self.use_rope = config.use_rope - # Detect attention implementation. - self.attn_backend = get_vit_attn_backend( + self.attn = MMEncoderAttention( + num_heads=self.num_heads_per_partition, head_size=self.head_dim, - dtype=torch.get_default_dtype(), - attn_backend_override=attn_backend_override, + prefix=f"{prefix}.attn", + multimodal_config=multimodal_config, ) - self.attn_backend, self.flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - attn_backend_override=attn_backend_override, - ) - ) - - if self.attn_backend not in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.TORCH_SDPA, - AttentionBackendEnum.ROCM_AITER_FA, - }: - self.attn_backend = AttentionBackendEnum.TORCH_SDPA - self.is_flash_attn_backend = self.attn_backend in { - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.ROCM_AITER_FA, - } - def forward( self, hidden_states: torch.Tensor, @@ -298,46 +259,23 @@ class Siglip2Attention(nn.Module): keys.unsqueeze(0), cos, sin, - self.is_flash_attn_backend, + self.attn.is_flash_attn_backend, ) queries = queries.squeeze(0) keys = keys.squeeze(0) - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - if self.is_flash_attn_backend: - attn_output = self.flash_attn_varlen_func( - queries, - keys, - values, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - ).reshape(seq_length, -1) - elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: - # Execute attention entry by entry for speed & less VRAM. - batch_size = cu_seqlens.shape[0] - 1 - outputs = [] - cu = cu_seqlens.tolist() - for i in range(batch_size): - start_idx = cu[i] - end_idx = cu[i + 1] + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_output = self.attn( + query=queries.unsqueeze(0), + key=keys.unsqueeze(0), + value=values.unsqueeze(0), + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + attn_output = attn_output.reshape( + seq_length, self.num_heads_per_partition * self.head_dim + ) - # Each sequence is processed independently. - q_i = queries[start_idx:end_idx].unsqueeze(0) - k_i = keys[start_idx:end_idx].unsqueeze(0) - v_i = values[start_idx:end_idx].unsqueeze(0) - - # (1, seq_len, num_heads, head_dim) -> - # (1, num_heads, seq_len, head_dim) - q_i, k_i, v_i = [x.transpose(1, 2) for x in (q_i, k_i, v_i)] - - output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) - # (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim) - output_i = output_i.transpose(1, 2).reshape(end_idx - start_idx, -1) - outputs.append(output_i) - - attn_output = torch.cat(outputs, dim=0) attn_output, _ = self.out_proj(attn_output) return attn_output @@ -347,25 +285,30 @@ class Siglip2MLP(nn.Module): self, config: Siglip2VisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, ): super().__init__() self.config = config + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.activation_fn = get_act_fn(config.hidden_act) - # TODO(Isotr0py): Enable data parallel after we support - # disabling TP on parallel linear layer self.fc1 = ColumnParallelLinear( config.hidden_size, config.intermediate_size, quant_config=quant_config, prefix=f"{prefix}.fc1", + disable_tp=use_data_parallel, ) self.fc2 = RowParallelLinear( config.intermediate_size, config.hidden_size, quant_config=quant_config, prefix=f"{prefix}.fc2", + disable_tp=use_data_parallel, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -380,9 +323,8 @@ class Siglip2EncoderLayer(nn.Module): self, config: Siglip2VisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.embed_dim = config.hidden_size @@ -390,16 +332,15 @@ class Siglip2EncoderLayer(nn.Module): self.self_attn = Siglip2Attention( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.self_attn", - use_data_parallel=use_data_parallel, - attn_backend_override=attn_backend_override, ) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = Siglip2MLP( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.mlp", - use_data_parallel=use_data_parallel, ) def forward( @@ -444,9 +385,8 @@ class Siglip2Encoder(nn.Module): self, config: Siglip2VisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -455,9 +395,8 @@ class Siglip2Encoder(nn.Module): Siglip2EncoderLayer( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.layers.{idx}", - use_data_parallel=use_data_parallel, - attn_backend_override=attn_backend_override, ) for idx in range(config.num_hidden_layers) ] @@ -630,9 +569,8 @@ class Siglip2VisionTransformer(nn.Module): self, config: Siglip2VisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = config @@ -642,9 +580,8 @@ class Siglip2VisionTransformer(nn.Module): self.encoder = Siglip2Encoder( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.encoder", - use_data_parallel=use_data_parallel, - attn_backend_override=attn_backend_override, ) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) @@ -671,18 +608,16 @@ class Siglip2NavitModel(torch.nn.Module): self, config: Siglip2VisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", - use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.vision_model = Siglip2VisionTransformer( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.vision_model", - use_data_parallel=use_data_parallel, - attn_backend_override=attn_backend_override, ) def forward( diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py index 55c25ce6190fb..f95fbffc1d0b4 100644 --- a/vllm/model_executor/models/skyworkr1v.py +++ b/vllm/model_executor/models/skyworkr1v.py @@ -647,8 +647,6 @@ class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[SkyworkR1VProcessing dummy_inputs=SkyworkR1VDummyInputsBuilder, ) class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): - merge_by_field_config = True - @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 7bef56110cab7..964aa902704b3 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -160,7 +160,6 @@ class SolarAttention(nn.Module): self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embeddings, rope_parameters=config.rope_parameters, ) diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 65092584edced..ea4342882feb4 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -119,9 +119,6 @@ class StablelmAttention(nn.Module): self.num_key_value_heads = max(1, self.total_num_key_value_heads // tp_size) self.head_dim = self.hidden_size // self.total_num_heads self.max_position_embeddings = config.max_position_embeddings - self.partial_rotary_factor = getattr( - config, "rope_pct", getattr(config, "partial_rotary_factor", 1) - ) self.scaling = self.head_dim**-0.5 self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_key_value_heads * self.head_dim @@ -151,10 +148,8 @@ class StablelmAttention(nn.Module): ) self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=self.config.max_position_embeddings, rope_parameters=self.config.rope_parameters, - partial_rotary_factor=self.partial_rotary_factor, ) self.attn = Attention( self.num_heads, diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 46422f303ff43..569ca9b082cfa 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -112,7 +112,6 @@ class Starcoder2Attention(nn.Module): ) self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=self.max_position_embeddings, rope_parameters=config.rope_parameters, is_neox_style=True, diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index 077cce84a98dd..7077f1a22e8d7 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -196,7 +196,6 @@ class Step3TextAttention(nn.Module): ) self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.head_dim, max_position=max_position_embedding, rope_parameters=rope_parameters, ) diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index 3e55ada0ed2e1..e5038e56a2708 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -916,8 +916,6 @@ class Step3VisionTransformer(nn.Module): dummy_inputs=Step3VLDummyInputsBuilder, ) class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - merge_by_field_config = True - hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "model.": "language_model.model.", diff --git a/vllm/model_executor/models/swin.py b/vllm/model_executor/models/swin.py index a74fd80c06d8c..fbf5594851ece 100644 --- a/vllm/model_executor/models/swin.py +++ b/vllm/model_executor/models/swin.py @@ -102,7 +102,6 @@ class SwinSelfAttention(nn.Module): self, hidden_states: torch.Tensor, attention_mask: torch.FloatTensor | None = None, - head_mask: torch.FloatTensor | None = None, output_attentions: bool | None = False, ) -> tuple[torch.Tensor, ...]: batch_size, dim, num_channels = hidden_states.shape @@ -201,12 +200,9 @@ class SwinAttention(nn.Module): self, hidden_states: torch.Tensor, attention_mask: torch.FloatTensor | None = None, - head_mask: torch.FloatTensor | None = None, output_attentions: bool | None = False, ) -> tuple[torch.Tensor]: - self_outputs = self.self( - hidden_states, attention_mask, head_mask, output_attentions - ) + self_outputs = self.self(hidden_states, attention_mask, output_attentions) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] return outputs @@ -339,18 +335,14 @@ class SwinStage(nn.Module): self, hidden_states: torch.Tensor, input_dimensions: tuple[int, int], - head_mask: torch.FloatTensor | None = None, output_attentions: bool | None = False, always_partition: bool | None = False, ) -> tuple[torch.Tensor]: height, width = input_dimensions for i, layer_module in enumerate(self.blocks): - layer_head_mask = head_mask[i] if head_mask is not None else None - layer_outputs = layer_module( hidden_states, input_dimensions, - layer_head_mask, output_attentions, always_partition, ) @@ -425,17 +417,13 @@ class SwinEncoder(nn.Module): self, hidden_states: torch.Tensor, input_dimensions: tuple[int, int], - head_mask: torch.FloatTensor | None = None, output_attentions: bool | None = False, always_partition: bool | None = False, ) -> tuple[torch.Tensor]: for i, layer_module in enumerate(self.layers): - layer_head_mask = head_mask[i] if head_mask is not None else None - layer_outputs = layer_module( hidden_states, input_dimensions, - layer_head_mask, output_attentions, always_partition, ) @@ -473,7 +461,6 @@ class SwinModel(nn.Module): def forward( self, pixel_values: torch.FloatTensor | None = None, - head_mask: torch.FloatTensor | None = None, output_attentions: bool | None = None, ) -> tuple[torch.Tensor]: embedding_output, input_dimensions = self.embeddings(pixel_values) @@ -481,7 +468,6 @@ class SwinModel(nn.Module): encoder_outputs = self.encoder( embedding_output, input_dimensions, - head_mask=head_mask, output_attentions=output_attentions, ) diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py index 4d310712f303e..7e82a4d725a62 100644 --- a/vllm/model_executor/models/tarsier.py +++ b/vllm/model_executor/models/tarsier.py @@ -400,8 +400,6 @@ def init_vision_tower_for_tarsier( dummy_inputs=TarsierDummyInputsBuilder, ) class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - merge_by_field_config = True - packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], diff --git a/vllm/model_executor/models/terratorch.py b/vllm/model_executor/models/terratorch.py index 19052c8d49e44..402081a70631e 100644 --- a/vllm/model_executor/models/terratorch.py +++ b/vllm/model_executor/models/terratorch.py @@ -64,7 +64,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal -from .interfaces_base import default_pooling_type +from .interfaces_base import attn_type logger = init_logger(__name__) @@ -220,14 +220,13 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor): ) -@default_pooling_type("All") +@attn_type("attention_free") @MULTIMODAL_REGISTRY.register_processor( TerratorchMultiModalProcessor, info=TerratorchProcessingInfo, dummy_inputs=TerratorchInputBuilder, ) class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal): - merge_by_field_config = True supports_multimodal_raw_input_only = True is_pooling_model = True diff --git a/vllm/model_executor/models/transformers/base.py b/vllm/model_executor/models/transformers/base.py index f3ebc6da8e302..45e746ac2d356 100644 --- a/vllm/model_executor/models/transformers/base.py +++ b/vllm/model_executor/models/transformers/base.py @@ -36,6 +36,8 @@ from vllm.distributed.utils import get_pp_indices from vllm.logger import init_logger from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.models.interfaces import ( + SupportsEagle, + SupportsEagle3, SupportsLoRA, SupportsPP, SupportsQuant, @@ -92,7 +94,15 @@ def vllm_flash_attention_forward( ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward -class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP): +class Base( + nn.Module, + VllmModel, + SupportsQuant, + SupportsLoRA, + SupportsPP, + SupportsEagle, + SupportsEagle3, +): embedding_modules = ["embed_tokens"] # TODO transformers will have a util to get it hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ @@ -131,17 +141,24 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP): self.pp_group = get_pp_group() self.tp_group = get_tp_group() - # Weights to skip in `self.load_weights` + # Attrs for weight loading (see self.load_weights) self.skip_prefixes: list[str] = [] """Skip loading weights whose qualname starts with these prefixes.""" self.skip_substrs: list[str] = [] """Skip loading weights whose qualname contains these substrings.""" self.ignore_unexpected_prefixes: list[str] = [] - """Ignore unexpected weights whose qualname starts with these prefixes. - """ + """Ignore unexpected weights whose qualname starts with these prefixes.""" self.ignore_unexpected_suffixes: list[str] = [] """Ignore unexpected weights whose qualname ends with these suffixes.""" + # Attrs for Eagle3 (see self.set_aux_hidden_state_layers) + self._target_class: type[nn.Module] = nn.Module + """Target class for Eagle3 aux hidden state recording.""" + self._layer_names: dict[int, str] = {} + """Mapping from layer index to layer name for Eagle3.""" + self._output_aux_hidden_states_kwargs: dict[str, bool] = {} + """Kwargs to pass to model forward for Eagle3 aux hidden states.""" + if self.quant_config: quant_method_name = self.quant_config.get_name() # Check for unsupported quantization methods. @@ -278,6 +295,15 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP): for child_name, child_module in module.named_children(): new_module = child_module qual_name = maybe_prefix(prefix, child_name) + # Populate Eagle3 attrs + if ( + isinstance(module, nn.ModuleList) + and len(module) == self.text_config.num_hidden_layers + ): + self._target_class = type(child_module) + layer_name = qual_name.removeprefix("model.") + self._layer_names[int(child_name)] = layer_name + # Replace modules as needed if isinstance(child_module, nn.Linear): generator = (p for p in tp_plan if re.match(p, qual_name)) pattern = next(generator, None) @@ -425,19 +451,26 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP): else: position_ids = positions[None, ...] - hidden_states = self.model( + outputs = self.model( input_ids=input_ids, inputs_embeds=inputs_embeds, use_cache=False, position_ids=position_ids, attention_instances=self.attention_instances, return_dict=False, + **self._output_aux_hidden_states_kwargs, **kwargs, - )[0][0, ...] # we remove batch dimension for now + ) + # We must remove the batch dimension from these outputs + hidden_states = outputs[0][0, ...] + if self._output_aux_hidden_states_kwargs: + aux_hidden_states = [x[0][0, ...] for x in outputs[1:]] if not self.pp_group.is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) + if self._output_aux_hidden_states_kwargs and len(aux_hidden_states) > 0: + return hidden_states, aux_hidden_states return hidden_states def load_weights( @@ -462,3 +495,24 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP): f"Transformers modeling backend requires transformers>={required} " f"for {feature}, but got {installed}" ) + + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.check_version("5.0.0.dev0", "Eagle3 support") + from transformers.utils.generic import OutputRecorder + + # The default value in PreTrainedModel is None + if self.model._can_record_outputs is None: + self.model._can_record_outputs = {} + + target_class = self._target_class + for layer in layers: + # layer - 1 because we want the input to the layer + layer_name = self._layer_names[layer - 1] + layer_key = f"aux_hidden_state_{layer}" + aux_hidden_state_i = OutputRecorder(target_class, layer_name=layer_name) + self.model._can_record_outputs[layer_key] = aux_hidden_state_i + self._output_aux_hidden_states_kwargs[f"output_{layer_key}"] = True + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + num_layers = self.text_config.num_hidden_layers + return (2, num_layers // 2, num_layers - 3) diff --git a/vllm/model_executor/models/transformers/multimodal.py b/vllm/model_executor/models/transformers/multimodal.py index ccf6053719871..9d77dee2810c3 100644 --- a/vllm/model_executor/models/transformers/multimodal.py +++ b/vllm/model_executor/models/transformers/multimodal.py @@ -264,7 +264,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): class MultiModalMixin(SupportsMultiModal, SupportsMRoPE): supports_multimodal_raw_input_only = True - merge_by_field_config = True + # Backwards compatibility for prev released models. State dicts back then # had different formats and cannot be loaded with `AutoModel` mapping as is hf_to_vllm_mapper = WeightsMapper( diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 26a8355cd22b5..7e1b7c90c9204 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -4,15 +4,22 @@ # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py """PyTorch Ultravox model.""" +import copy +import inspect from collections.abc import Iterable, Mapping, Sequence +from types import SimpleNamespace from typing import Annotated, Any, Literal, TypeAlias import torch from torch import nn from torch.nn import functional as F from transformers import BatchFeature, ProcessorMixin +from transformers.modeling_utils import ModuleUtilsMixin from transformers.models.whisper import WhisperFeatureExtractor -from transformers.models.whisper.modeling_whisper import WhisperEncoder +from transformers.models.whisper.modeling_whisper import ( + WhisperEncoder, + WhisperEncoderLayer, +) from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions @@ -282,7 +289,7 @@ class StackAudioFrames(nn.Module): return audio_embeds -class UltravoxProjector(nn.Module): +class UltravoxFeedForwardProjector(nn.Module): def __init__(self, config: UltravoxConfig): super().__init__() self.hidden_dim = config.hidden_size @@ -310,7 +317,9 @@ class UltravoxProjector(nn.Module): self.ln_mid = nn.Identity() self.ln_post = RMSNorm(dim_out) - def forward(self, audio_features: torch.Tensor) -> torch.Tensor: + def forward( + self, audio_features: torch.Tensor, audio_token_len: torch.Tensor + ) -> torch.Tensor: audio_features = self._pad_and_stack(audio_features) audio_features = self.ln_pre(audio_features) hidden_states = self.linear_1(audio_features) @@ -321,6 +330,76 @@ class UltravoxProjector(nn.Module): return hidden_states +class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin): + def __init__(self, config: UltravoxConfig): + super().__init__() + self.config = SimpleNamespace(is_decoder=False) + + self._pad_and_stack = StackAudioFrames(config.stack_factor) + dim_in = config.audio_config.hidden_size * config.stack_factor + + projector_audio_config = copy.deepcopy(config.audio_config) + + self.ln_pre = RMSNorm(dim_in) + self.linear_in = nn.Linear(dim_in, projector_audio_config.d_model) + + self.embed_positions = nn.Embedding( + projector_audio_config.max_source_positions, + projector_audio_config.d_model, + ) + + self.layers = nn.ModuleList( + [ + WhisperEncoderLayer(projector_audio_config) + for _ in range(config.num_projector_layers) + ] + ) + + self.ln_post = RMSNorm(projector_audio_config.d_model) + self.linear_out = nn.Linear( + projector_audio_config.d_model, config.text_config.hidden_size + ) + + def forward( + self, audio_features: torch.Tensor, audio_token_len: torch.Tensor + ) -> torch.Tensor: + audio_features = self._pad_and_stack(audio_features) + + max_len_stacked = audio_features.shape[1] + attention_mask = torch.arange(max_len_stacked, device=audio_features.device)[ + None, : + ].lt(audio_token_len[:, None]) + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, attention_mask.shape, audio_features.dtype + ) + + hidden_states = self.ln_pre(audio_features) + hidden_states = self.linear_in(hidden_states) + + positions = self.embed_positions( + torch.arange(hidden_states.size(1), device=hidden_states.device) + ) + hidden_states = hidden_states + positions + + # Backward compatibility for Transformers v4 where layer_head_mask + # was a required argument for WhisperEncoderLayer.forward + kwargs = {} + if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters: + kwargs["layer_head_mask"] = None + + for layer in self.layers: + layer_outputs = layer( + hidden_states, + attention_mask=extended_attention_mask, + **kwargs, + ) + hidden_states = layer_outputs[0] + + hidden_states = self.ln_post(hidden_states) + hidden_states = self.linear_out(hidden_states) + return hidden_states + + class ModifiedWhisperEncoder(WhisperEncoder): """ Encoder portion of OpenAI's Whisper model. @@ -407,11 +486,17 @@ class ModifiedWhisperEncoder(WhisperEncoder): attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states) + # Backward compatibility for Transformers v4 where layer_head_mask + # was a required argument for WhisperEncoderLayer.forward + kwargs = {} + if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters: + kwargs["layer_head_mask"] = None + for encoder_layer in self.layers: layer_outputs = encoder_layer( hidden_states, attention_mask, - layer_head_mask=None, + **kwargs, ) hidden_states = layer_outputs[0] @@ -426,8 +511,6 @@ class ModifiedWhisperEncoder(WhisperEncoder): dummy_inputs=UltravoxDummyInputsBuilder, ) class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): - merge_by_field_config = True - packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], @@ -464,7 +547,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): prefix="audio_tower.", ) ) - self.multi_modal_projector = UltravoxProjector(config) + if config.num_projector_layers > 0: + self.multi_modal_projector = UltravoxTransformerProjector(config) + else: + self.multi_modal_projector = UltravoxFeedForwardProjector(config) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, hf_config=config.wrapped_model_config, @@ -496,7 +582,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ) def _audio_features_to_embeddings( - self, input_features: torch.Tensor, audio_lens: torch.Tensor + self, + input_features: torch.Tensor, + audio_lens: torch.Tensor, + audio_token_len: torch.Tensor, ) -> torch.Tensor: audio_features = input_features.to(self.audio_tower.dtype) batch_size = audio_features.size(0) @@ -512,7 +601,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): batch_features = batch_features.to(self.audio_tower.dtype) # Process through projector - batch_embeddings = self.multi_modal_projector(batch_features) + batch_embeddings = self.multi_modal_projector( + batch_features, audio_token_len[start:end] + ) audio_embeddings.append(batch_embeddings) # Concatenate results @@ -559,7 +650,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): audio_lens = audio_input["lens"] audio_token_len = audio_input["token_len"] - embeddings = self._audio_features_to_embeddings(audio_features, audio_lens) + embeddings = self._audio_features_to_embeddings( + audio_features, audio_lens, audio_token_len + ) # We should flatten and concatenate embeddings based on token lengths # For example, with token_len = [4, 2, 3], flattened_embeddings will be diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index e5d70eb7bc2fc..024c50f1207ed 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -88,17 +88,11 @@ def get_vit_attn_backend( """ Get the available attention backend for Vision Transformer. """ - if attn_backend_override is not None: - return attn_backend_override - - # Lazy import to avoid circular dependency - from vllm.attention.selector import get_env_variable_attn_backend - - selected_backend: AttentionBackendEnum | None = get_env_variable_attn_backend() - if selected_backend is not None: - return selected_backend - - return current_platform.get_vit_attn_backend(head_size, dtype) + return current_platform.get_vit_attn_backend( + head_size, + dtype, + backend=attn_backend_override, + ) def should_torch_compile_mm_vit(vllm_config: VllmConfig) -> bool: diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index 0a39ea7ef5bff..331f0c54ecfbc 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -51,8 +51,8 @@ from vllm.multimodal.processing import ( ) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors -from vllm.tokenizers import MistralTokenizer -from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config +from vllm.tokenizers import cached_tokenizer_from_config +from vllm.tokenizers.mistral import MistralTokenizer from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription from .utils import init_vllm_registered_model, maybe_prefix @@ -331,8 +331,6 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]) class VoxtralForConditionalGeneration( nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription ): - merge_by_field_config = True - supported_languages = ISO639_1_SUPPORTED_LANGS packed_modules_mapping = { diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index c72b5e1c091f2..b513e3513b2e2 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -48,7 +48,7 @@ from vllm.multimodal.processing import ( PromptUpdate, ) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.transformers_utils.processor import cached_get_processor +from vllm.transformers_utils.processor import cached_processor_from_config from vllm.utils.jsontree import json_map_leaves from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.torch_utils import set_default_torch_dtype @@ -522,6 +522,7 @@ class WhisperEncoder(nn.Module): def forward(self, input_features: torch.Tensor | list[torch.Tensor]): hidden_states = [] + input_is_batched = False for features in input_features: embeds = nn.functional.gelu(self.conv1(features)) embeds = nn.functional.gelu(self.conv2(embeds)) @@ -530,7 +531,13 @@ class WhisperEncoder(nn.Module): embeds.dtype ) hidden_states.append(embeds) - hidden_states = torch.cat(hidden_states) + input_is_batched = embeds.ndim > 2 + # Input to MHA must be B x T x D + if input_is_batched: + # Models using WhisperEncoder may handle batching internally. + hidden_states = torch.cat(hidden_states) + else: + hidden_states = torch.stack(hidden_states, dim=0) for encoder_layer in self.layers: hidden_states = encoder_layer(hidden_states) @@ -603,8 +610,7 @@ class WhisperModel(nn.Module): positions: torch.Tensor, encoder_outputs: list[torch.Tensor], ) -> torch.Tensor: - assert len(encoder_outputs) in (0, 1) - enc_states = encoder_outputs[0] if len(encoder_outputs) == 1 else None + enc_states = torch.cat(encoder_outputs, dim=0) if len(encoder_outputs) else None decoder_outputs = self.decoder( input_ids=input_ids, positions=positions, @@ -775,7 +781,6 @@ class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo class WhisperForConditionalGeneration( nn.Module, SupportsTranscription, SupportsMultiModal ): - merge_by_field_config = True packed_modules_mapping = { "self_attn.qkv_proj": [ "self_attn.q_proj", @@ -791,6 +796,7 @@ class WhisperForConditionalGeneration( # Whisper only supports audio-conditioned generation. supports_transcription_only = True + supports_segment_timestamp = True supported_languages = ISO639_1_SUPPORTED_LANGS @classmethod @@ -849,7 +855,7 @@ class WhisperForConditionalGeneration( def get_speech_to_text_config( cls, model_config: ModelConfig, task_type: str ) -> SpeechToTextConfig: - processor = cached_get_processor(model_config.model) + processor = cached_processor_from_config(model_config) return SpeechToTextConfig( max_audio_clip_s=processor.feature_extractor.chunk_length, @@ -863,7 +869,7 @@ class WhisperForConditionalGeneration( stt_config: SpeechToTextConfig, model_config: ModelConfig, ) -> int | None: - processor = cached_get_processor(model_config.model) + processor = cached_processor_from_config(model_config) hop_length = processor.feature_extractor.hop_length assert hop_length is not None # NOTE(NickLucche) user can't pass encoder @@ -913,7 +919,10 @@ class WhisperForConditionalGeneration( def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: # Required as part of SupportsMultiModal interface. audio_input = self._parse_and_validate_audio_input(**kwargs) - return [self.model.get_encoder_outputs(audio_input["input_features"])] + # Split concatenated encoder outputs into one tensor per audio input + enc_output = self.model.get_encoder_outputs(audio_input["input_features"]) + # The assumption is we can only process whole mm items (audios) + return enc_output.unbind(dim=0) def embed_input_ids( self, diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index 653b5b9beef7b..fe157887eea91 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -230,7 +230,6 @@ class Zamba2Attention(nn.Module): if config.use_mem_rope: self.rotary_emb = get_rope( head_size=self.attention_head_dim, - rotary_dim=self.attention_head_dim, max_position=config.max_position_embeddings, rope_parameters=config.rope_parameters, is_neox_style=True, diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 8aad59e84ff25..b89371d987541 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -50,6 +50,31 @@ def set_weight_attrs( setattr(weight, key, value) +def replace_parameter(layer: torch.nn.Module, param_name: str, new_data: torch.Tensor): + """ + Replace a parameter of a layer while maintaining the ability to reload the weight. + Called within implementations of the `process_weights_after_loading` method. + + This function should not be called on weights which are tied/shared + + Args: + layer: Layer containing parameter to replace + param_name: Name of parameter to replace + new_data: New data of the new parameter + """ + # should not be used on a tied/shared param + if isinstance(new_data, torch.nn.Parameter): + new_data = new_data.data + new_param = torch.nn.Parameter(new_data, requires_grad=False) + + old_param: torch.nn.Parameter | None = getattr(layer, param_name, None) + if old_param is not None and hasattr(old_param, "weight_loader"): + weight_loader = old_param.weight_loader + set_weight_attrs(new_param, {"weight_loader": weight_loader}) + + setattr(layer, param_name, new_param) + + def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]: parent_map = getattr(model, "packed_modules_mapping", None) parent_map = copy.deepcopy(parent_map) if parent_map is not None else {} diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py index e0c584df8760b..936f6b1e28ce1 100644 --- a/vllm/model_executor/warmup/deep_gemm_warmup.py +++ b/vllm/model_executor/warmup/deep_gemm_warmup.py @@ -89,7 +89,7 @@ def _extract_data_from_linear_base_module( assert m.quant_method.quant_config is not None w = m.weight - ws = m.weight_scale + ws = m.weight_scale_inv if hasattr(m, "weight_scale_inv") else m.weight_scale quant_block_size = m.quant_method.quant_config.weight_block_size assert isinstance(w, torch.Tensor) diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index b93a42ffd24c1..51b8f77f29088 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -11,6 +11,7 @@ import pybase64 import torch from vllm.utils.import_utils import PlaceholderModule +from vllm.utils.serial_utils import tensor2base64 from .base import MediaIO @@ -126,17 +127,21 @@ class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]): def load_bytes(self, data: bytes) -> torch.Tensor: buffer = BytesIO(data) - return torch.load(buffer, weights_only=True) + # Enable sparse tensor integrity checks to prevent out-of-bounds + # writes from maliciously crafted tensors + with torch.sparse.check_sparse_tensor_invariants(): + tensor = torch.load(buffer, weights_only=True) + return tensor.to_dense() def load_base64(self, media_type: str, data: str) -> torch.Tensor: return self.load_bytes(pybase64.b64decode(data, validate=True)) def load_file(self, filepath: Path) -> torch.Tensor: - return torch.load(filepath, weights_only=True) + # Enable sparse tensor integrity checks to prevent out-of-bounds + # writes from maliciously crafted tensors + with torch.sparse.check_sparse_tensor_invariants(): + tensor = torch.load(filepath, weights_only=True) + return tensor.to_dense() def encode_base64(self, media: torch.Tensor) -> str: - buffer = BytesIO() - torch.save(media, buffer) - buffer.seek(0) - binary_data = buffer.read() - return pybase64.b64encode(binary_data).decode("utf-8") + return tensor2base64(media) diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index fef118a93c6cb..53eb4c591ef99 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -2,12 +2,42 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod +from dataclasses import dataclass from pathlib import Path from typing import Generic, TypeVar +import numpy as np + _T = TypeVar("_T") +@dataclass +class MediaWithBytes(Generic[_T]): + """ + Wrapper that couples a media object with its original encoded bytes. + + This ensures the raw bytes and media object remain synchronized, + preventing cache corruption from in-place modifications. + + The wrapper delegates attribute access to the underlying media object, + making it behave transparently like the wrapped type (e.g., PIL.Image). + + NOTE: Currently, this wrapper is used only for the image modality. + """ + + media: _T + original_bytes: bytes + + def __array__(self, *args, **kwargs) -> np.ndarray: + """Allow np.array(obj) to return np.array(obj.media).""" + return np.array(self.media, *args, **kwargs) + + def __getattr__(self, name: str): + """Delegate attribute access to the underlying media object.""" + # This is only called when the attribute is not found on self + return getattr(self.media, name) + + class MediaIO(ABC, Generic[_T]): @abstractmethod def load_bytes(self, data: bytes) -> _T: diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py index c1531cbfdc31d..67bdf5e1557f9 100644 --- a/vllm/multimodal/cache.py +++ b/vllm/multimodal/cache.py @@ -25,7 +25,6 @@ from .inputs import ( MultiModalBatchedField, MultiModalFeatureSpec, MultiModalFieldElem, - MultiModalKwargs, MultiModalKwargsItem, MultiModalKwargsItems, NestedTensors, @@ -90,7 +89,6 @@ MultiModalCacheValue: TypeAlias = ( | MultiModalProcessorCacheItemMetadata | MultiModalKwargsItems | MultiModalKwargsItem - | MultiModalKwargs | Mapping[str, NestedTensors] ) @@ -108,12 +106,7 @@ class MultiModalCache: # These are not subclasses of dict if isinstance( leaf, - ( - MultiModalKwargs, - MultiModalKwargsItems, - MultiModalKwargsItem, - MultiModalFieldElem, - ), + (MultiModalKwargsItems, MultiModalKwargsItem, MultiModalFieldElem), ): return cls.get_item_size(leaf.data) # type: ignore @@ -302,6 +295,19 @@ class BaseMultiModalProcessorCache( """ return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes] + @abstractmethod + def touch_sender_cache_item(self, mm_hash: str) -> None: + """ + Update the cache eviction order for a multi-modal item. + + This is used to touch the item in the cache without changing + its value. + + Args: + mm_hash: The hash of the multi-modal item. + """ + raise NotImplementedError + @abstractmethod def make_stats(self, *, delta: bool = False) -> CacheInfo: """ @@ -353,6 +359,10 @@ class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache): return mm_item + @override + def touch_sender_cache_item(self, mm_hash: str) -> None: + self._cache.touch(mm_hash) + @override def clear_cache(self) -> None: self._cache.clear() @@ -407,6 +417,10 @@ class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache): return mm_item + @override + def touch_sender_cache_item(self, mm_hash: str) -> None: + self._cache.touch(mm_hash) + @override def clear_cache(self) -> None: self._cache.clear() @@ -501,6 +515,12 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache): logger.debug("Failed to cache mm_input with hash %s: %s", mm_hash, e) return mm_item + @override + def touch_sender_cache_item(self, mm_hash: str) -> None: + """Touch the item in shared memory cache to prevent eviction. + Increments writer_flag on sender side.""" + self._shm_cache.touch(mm_hash) + @override def clear_cache(self) -> None: self._shm_cache.clear() @@ -610,11 +630,37 @@ class BaseMultiModalReceiverCache( self, mm_features: list["MultiModalFeatureSpec"], ) -> list["MultiModalFeatureSpec"]: - """Update multimodal features with cached encoder outputs.""" + """ + Update multimodal features with cached encoder outputs. + Touch all identifier at first before update to avoid + item in updated list evict during update. + """ + for feature in mm_features: + self.touch_receiver_cache_item(feature.identifier, feature.data) + for feature in mm_features: feature.data = self.get_and_update_item(feature.data, feature.identifier) return mm_features + @abstractmethod + def touch_receiver_cache_item( + self, + mm_hash: str, + mm_item: MultiModalKwargsItem | None = None, + ) -> None: + """ + Update the cache eviction order for a multi-modal item. + + This is used to touch the item in the cache without changing + its value. + + Args: + mm_hash: The hash of the multi-modal item. + mm_item: The multi-modal item itself. This is optional and + may not be needed by some cache implementations. + """ + raise NotImplementedError + class MultiModalReceiverCache(BaseMultiModalReceiverCache): """ @@ -651,6 +697,14 @@ class MultiModalReceiverCache(BaseMultiModalReceiverCache): self._cache[mm_hash] = mm_item return mm_item + @override + def touch_receiver_cache_item( + self, + mm_hash: str, + mm_item: MultiModalKwargsItem | None = None, + ) -> None: + self._cache.touch(mm_hash) + @override def clear_cache(self) -> None: self._cache.clear() @@ -703,6 +757,20 @@ class ShmObjectStoreReceiverCache(BaseMultiModalReceiverCache): return mm_item + @override + def touch_receiver_cache_item( + self, + mm_hash: str, + mm_item: MultiModalKwargsItem | None = None, + ) -> None: + """Touch the item in shared memory cache to prevent eviction. + Increments reader_count on receiver side.""" + assert mm_item is not None + if "address" in mm_item: + address = cast(int, mm_item["address"].data) + monotonic_id = cast(int, mm_item["monotonic_id"].data) + self._shm_cache.touch(mm_hash, address=address, monotonic_id=monotonic_id) + @override def clear_cache(self) -> None: self._shm_cache.clear() diff --git a/vllm/multimodal/hasher.py b/vllm/multimodal/hasher.py index d0dcbb25fcce8..cc50322fed902 100644 --- a/vllm/multimodal/hasher.py +++ b/vllm/multimodal/hasher.py @@ -12,6 +12,8 @@ from PIL import Image from vllm.logger import init_logger +from .base import MediaWithBytes + logger = init_logger(__name__) @@ -31,14 +33,26 @@ class MultiModalHasher: if Image.ExifTags.Base.ImageID in exif and isinstance( exif[Image.ExifTags.Base.ImageID], uuid.UUID ): - # If the image has exif ImageID tag, use that return (exif[Image.ExifTags.Base.ImageID].bytes,) + data = {"mode": obj.mode, "data": np.asarray(obj)} - if obj.palette is not None: - data["palette"] = obj.palette.palette - if obj.palette.rawmode is not None: - data["palette_rawmode"] = obj.palette.rawmode + palette = obj.palette + if palette is not None: + data["palette"] = palette.palette + if palette.rawmode is not None: + data["palette_rawmode"] = palette.rawmode + return cls.iter_item_to_bytes("image", data) + + if isinstance(obj, MediaWithBytes) and isinstance(obj.media, Image.Image): + exif = obj.media.getexif() + if Image.ExifTags.Base.ImageID in exif and isinstance( + exif[Image.ExifTags.Base.ImageID], uuid.UUID + ): + return (exif[Image.ExifTags.Base.ImageID].bytes,) + + return cls.iter_item_to_bytes("image", obj.original_bytes) + if isinstance(obj, torch.Tensor): tensor_obj: torch.Tensor = obj.cpu() tensor_dtype = tensor_obj.dtype diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 21e8bef97a787..1506ecb8c7aa0 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -8,7 +8,7 @@ import pybase64 import torch from PIL import Image -from .base import MediaIO +from .base import MediaIO, MediaWithBytes def rescale_image_size( @@ -74,8 +74,12 @@ class ImageMediaIO(MediaIO[Image.Image]): ) self.rgba_background_color = rgba_bg - def _convert_image_mode(self, image: Image.Image) -> Image.Image: + def _convert_image_mode( + self, image: Image.Image | MediaWithBytes[Image.Image] + ) -> Image.Image: """Convert image mode with custom background color.""" + if isinstance(image, MediaWithBytes): + image = image.media if image.mode == self.image_mode: return image elif image.mode == "RGBA" and self.image_mode == "RGB": @@ -83,18 +87,18 @@ class ImageMediaIO(MediaIO[Image.Image]): else: return convert_image_mode(image, self.image_mode) - def load_bytes(self, data: bytes) -> Image.Image: + def load_bytes(self, data: bytes) -> MediaWithBytes[Image.Image]: image = Image.open(BytesIO(data)) - image.load() - return self._convert_image_mode(image) + return MediaWithBytes(self._convert_image_mode(image), data) - def load_base64(self, media_type: str, data: str) -> Image.Image: + def load_base64(self, media_type: str, data: str) -> MediaWithBytes[Image.Image]: return self.load_bytes(pybase64.b64decode(data, validate=True)) - def load_file(self, filepath: Path) -> Image.Image: - image = Image.open(filepath) - image.load() - return self._convert_image_mode(image) + def load_file(self, filepath: Path) -> MediaWithBytes[Image.Image]: + with open(filepath, "rb") as f: + data = f.read() + image = Image.open(BytesIO(data)) + return MediaWithBytes(self._convert_image_mode(image), data) def encode_base64( self, @@ -118,13 +122,21 @@ class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]): def load_bytes(self, data: bytes) -> torch.Tensor: buffer = BytesIO(data) - return torch.load(buffer, weights_only=True) + # Enable sparse tensor integrity checks to prevent out-of-bounds + # writes from maliciously crafted tensors + with torch.sparse.check_sparse_tensor_invariants(): + tensor = torch.load(buffer, weights_only=True) + return tensor.to_dense() def load_base64(self, media_type: str, data: str) -> torch.Tensor: return self.load_bytes(pybase64.b64decode(data, validate=True)) def load_file(self, filepath: Path) -> torch.Tensor: - return torch.load(filepath, weights_only=True) + # Enable sparse tensor integrity checks to prevent out-of-bounds + # writes from maliciously crafted tensors + with torch.sparse.check_sparse_tensor_invariants(): + tensor = torch.load(filepath, weights_only=True) + return tensor.to_dense() def encode_base64(self, media: torch.Tensor) -> str: return pybase64.b64encode(media.numpy()).decode("utf-8") diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 7518a023c5f50..6b1cbbe24e2e7 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -32,6 +32,7 @@ if TYPE_CHECKING: from PIL.Image import Image from transformers.feature_extraction_utils import BatchFeature + from .base import MediaWithBytes from .processing import MultiModalHashes else: @@ -59,7 +60,7 @@ Represents a single audio item, which can be passed to a HuggingFace `AudioProcessor`. """ -ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor"] +ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor", "MediaWithBytes[HfImageItem]"] """ A `transformers.image_utils.ImageInput` representing a single image item, which can be passed to a HuggingFace `ImageProcessor`. @@ -174,6 +175,31 @@ class PlaceholderRange: return int(self.is_embed.sum().item()) + def extract_embeds_range(self) -> list[tuple[int, int]]: + """Extract the start and end indices of the embedded region in prompt. + + For example, given `PlaceholderRange(offset=2, length=5)` and + `is_embed = [False, True, False, True, True]`, the output is + `[(1 + offset, 1 + offset), (3 + offset, 4 + offset)]`. + + Returns: + A tuple `(start, end)` representing the start and end + indices (inclusive) of the embedded region. + Returns full placeholder range if `is_embed` is `None`. + """ + if self.is_embed is None: + return [(self.offset, self.offset + self.length)] + + mask_i = self.is_embed.int() + starts = torch.nonzero( + torch.diff(mask_i, prepend=mask_i.new_zeros(1)) == 1 + ).flatten() + ends = torch.nonzero( + torch.diff(mask_i, append=mask_i.new_zeros(1)) == -1 + ).flatten() + ranges = torch.stack((starts, ends), dim=1) + self.offset + return [tuple(x) for x in ranges.tolist()] + def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__): return False @@ -200,8 +226,10 @@ Uses a list instead of a tensor if the dimensions of each element do not match. def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool: - """Equality check between - [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects.""" + """ + Equality check between + [`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects. + """ if isinstance(a, torch.Tensor): return isinstance(b, torch.Tensor) and torch.equal(a, b) elif isinstance(b, torch.Tensor): @@ -220,13 +248,44 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool: return a == b +def _nested_tensors_h2d( + tensors: NestedTensors, + device: torch.types.Device, +) -> NestedTensors: + if device is None: + return tensors + + return json_map_leaves( + ( + lambda x: x.to(device=device, non_blocking=True) + if isinstance(x, torch.Tensor) + else x + ), + tensors, + ) + + BatchedTensorInputs: TypeAlias = dict[str, NestedTensors] """ A dictionary containing nested tensors which have been batched via -[`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch]. +[`MultiModalKwargsItems.get_data`][vllm.multimodal.inputs.MultiModalKwargsItems.get_data]. """ +def batched_tensors_equal(a: BatchedTensorInputs, b: BatchedTensorInputs) -> bool: + """ + Equality check between + [`BatchedTensorInputs`][vllm.multimodal.inputs.BatchedTensorInputs] objects. + """ + for k in a: + if k not in b: + return False + if not nested_tensors_equal(a[k], b[k]): + return False + + return True + + @dataclass class MultiModalFeatureSpec: """ @@ -317,7 +376,7 @@ class MultiModalFieldElem: ) # noqa: E721 -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class BaseMultiModalField(ABC): """ Defines how to interpret tensor data belonging to a keyword argument in @@ -325,6 +384,12 @@ class BaseMultiModalField(ABC): multi-modal items, and vice versa. """ + keep_on_cpu: bool = False + """ + If `True`, then this field is excluded from being moved to the accelerator + when `MultiModalKwargsItems.get_data()` is called to batch the data. + """ + def _field_factory(self, *, modality: str, key: str): f = partial( MultiModalFieldElem, @@ -369,6 +434,7 @@ class BaseMultiModalField(ABC): self, elems: list[MultiModalFieldElem], *, + device: torch.types.Device = None, pin_memory: bool = False, ) -> NestedTensors: """ @@ -382,11 +448,17 @@ class BaseMultiModalField(ABC): if len(set(field_types)) > 1: raise ValueError(f"Cannot merge different {field_types=}") + if device is not None and self.keep_on_cpu: + device = "cpu" + if pin_memory and self.keep_on_cpu: + pin_memory = False + batch = [elem.data for elem in elems] - return self._reduce_data(batch, pin_memory=pin_memory) + out = self._reduce_data(batch, pin_memory=pin_memory) + return _nested_tensors_h2d(out, device=device) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class MultiModalBatchedField(BaseMultiModalField): """ Info: @@ -428,7 +500,7 @@ class MultiModalBatchedField(BaseMultiModalField): return batch -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class MultiModalFlatField(BaseMultiModalField): """ Info: @@ -488,7 +560,7 @@ class MultiModalFlatField(BaseMultiModalField): return [e for elem in batch for e in elem] -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class MultiModalSharedField(BaseMultiModalField): """ Info: @@ -515,9 +587,10 @@ class MultiModalSharedField(BaseMultiModalField): return batch[0] +@dataclass(frozen=True) class MultiModalFieldConfig: @staticmethod - def batched(modality: str): + def batched(modality: str, *, keep_on_cpu: bool = False): """ Defines a field where an element in the batch is obtained by indexing into the first dimension of the underlying data. @@ -525,6 +598,7 @@ class MultiModalFieldConfig: Args: modality: The modality of the multi-modal item that uses this keyword argument. + keep_on_cpu: Whether to keep this field on the CPU for the model inputs. Example: @@ -541,7 +615,7 @@ class MultiModalFieldConfig: ``` """ return MultiModalFieldConfig( - field=MultiModalBatchedField(), + field=MultiModalBatchedField(keep_on_cpu=keep_on_cpu), modality=modality, ) @@ -550,6 +624,8 @@ class MultiModalFieldConfig: modality: str, slices: Sequence[slice] | Sequence[Sequence[slice]], dim: int = 0, + *, + keep_on_cpu: bool = False, ): """ Defines a field where an element in the batch is obtained by @@ -562,6 +638,7 @@ class MultiModalFieldConfig: slices (dim>0) that is used to extract the data corresponding to it. dim: The dimension to extract data, default to 0. + keep_on_cpu: Whether to keep this field on the CPU for the model inputs. Example: @@ -596,12 +673,22 @@ class MultiModalFieldConfig: ``` """ return MultiModalFieldConfig( - field=MultiModalFlatField(slices=slices, dim=dim), + field=MultiModalFlatField( + slices=slices, + dim=dim, + keep_on_cpu=keep_on_cpu, + ), modality=modality, ) @staticmethod - def flat_from_sizes(modality: str, size_per_item: "torch.Tensor", dim: int = 0): + def flat_from_sizes( + modality: str, + size_per_item: "torch.Tensor", + dim: int = 0, + *, + keep_on_cpu: bool = False, + ): """ Defines a field where an element in the batch is obtained by slicing along the first dimension of the underlying data. @@ -612,6 +699,7 @@ class MultiModalFieldConfig: size_per_item: For each multi-modal item, the size of the slice that is used to extract the data corresponding to it. dim: The dimension to slice, default to 0. + keep_on_cpu: Whether to keep this field on the CPU for the model inputs. Example: @@ -659,10 +747,20 @@ class MultiModalFieldConfig: for i in range(len(size_per_item)) ] - return MultiModalFieldConfig.flat(modality, slices, dim=dim) + return MultiModalFieldConfig.flat( + modality, + slices, + dim=dim, + keep_on_cpu=keep_on_cpu, + ) @staticmethod - def shared(modality: str, batch_size: int): + def shared( + modality: str, + batch_size: int, + *, + keep_on_cpu: bool = False, + ): """ Defines a field where an element in the batch is obtained by taking the entirety of the underlying data. @@ -673,6 +771,7 @@ class MultiModalFieldConfig: modality: The modality of the multi-modal item that uses this keyword argument. batch_size: The number of multi-modal items which share this data. + keep_on_cpu: Whether to keep this field on the CPU for the model inputs. Example: @@ -691,18 +790,15 @@ class MultiModalFieldConfig: ``` """ return MultiModalFieldConfig( - field=MultiModalSharedField(batch_size), + field=MultiModalSharedField( + batch_size=batch_size, + keep_on_cpu=keep_on_cpu, + ), modality=modality, ) - def __init__(self, field: BaseMultiModalField, modality: str) -> None: - super().__init__() - - self.field = field - self.modality = modality - - def __repr__(self) -> str: - return f"MultiModalFieldConfig(field={self.field}, modality={self.modality})" + field: BaseMultiModalField + modality: str def build_elems( self, @@ -721,13 +817,13 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]): """ @staticmethod - def dummy(modality: str): + def dummy(modality: str, nbytes: int = 1): """Convenience class for testing.""" mm_elem = MultiModalFieldElem( modality=modality, key="dummy", - data=torch.empty(1), - field=MultiModalSharedField(1), + data=torch.empty(nbytes, dtype=torch.uint8), + field=MultiModalSharedField(batch_size=1), ) return MultiModalKwargsItem.from_elems([mm_elem]) @@ -822,7 +918,13 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]): return self # type: ignore[return-value] - def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs": + def get_data( + self, + *, + device: torch.types.Device = None, + pin_memory: bool = False, + ) -> BatchedTensorInputs: + """Construct a dictionary of keyword arguments to pass to the model.""" elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list) for modality, items in self.items(): for i, item in enumerate(items): @@ -834,12 +936,16 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]): for key, elem in item.items(): elems_by_key[key].append(elem) - return MultiModalKwargs( - { - key: elems[0].field.reduce_data(elems, pin_memory=pin_memory) - for key, elems in elems_by_key.items() - } - ) + data = { + key: elems[0].field.reduce_data( + elems, + device=device, + pin_memory=pin_memory, + ) + for key, elems in elems_by_key.items() + } + + return data MultiModalKwargsOptionalItems: TypeAlias = ( @@ -848,6 +954,7 @@ MultiModalKwargsOptionalItems: TypeAlias = ( ) +@deprecated("`MultiModalKwargs` is deprecated and will be removed in v0.14.") class MultiModalKwargs(UserDict[str, NestedTensors]): """ A dictionary that represents the keyword arguments to @@ -857,7 +964,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): @staticmethod @deprecated( "`MultiModalKwargs.from_hf_inputs` is deprecated and " - "will be removed in v0.13. " + "will be removed in v0.14. " "Please use `MultiModalKwargsItems.from_hf_inputs` and " "access the tensor data using `.get_data()`." ) @@ -870,7 +977,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): @staticmethod @deprecated( "`MultiModalKwargs.from_items` is deprecated and " - "will be removed in v0.13. " + "will be removed in v0.14. " "Please use `MultiModalKwargsItems.from_seq` and " "access the tensor data using `.get_data()`." ) @@ -881,91 +988,6 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): ): return MultiModalKwargsItems.from_seq(items).get_data(pin_memory=pin_memory) - @staticmethod - def _try_stack( - nested_tensors: NestedTensors, pin_memory: bool = False - ) -> NestedTensors: - """ - Stack the inner dimensions that have the same shape in - a nested list of tensors. - - Thus, a dimension represented by a list means that the inner - dimensions are different for each element along that dimension. - """ - if isinstance(nested_tensors, torch.Tensor): - return nested_tensors - - # TODO: Remove these once all models have been migrated - if isinstance(nested_tensors, np.ndarray): - return torch.from_numpy(nested_tensors) - if isinstance(nested_tensors, (int, float)): - return torch.tensor(nested_tensors) - - stacked = [MultiModalKwargs._try_stack(t, pin_memory) for t in nested_tensors] - if not is_list_of(stacked, torch.Tensor, check="all"): - # Only tensors (not lists) can be stacked. - return stacked - - tensors_ = cast(list[torch.Tensor], stacked) - if len(tensors_) == 1: - # An optimization when `tensors_` contains only one tensor: - # - produce exactly same result as `torch.stack(tensors_)` - # - will achieve zero-copy if the tensor is contiguous - return tensors_[0].unsqueeze(0).contiguous() - - if any(t.shape != tensors_[0].shape for t in tensors_): - # The tensors have incompatible shapes and can't be stacked. - return tensors_ - - outputs = torch.empty( - len(tensors_), - *tensors_[0].shape, - dtype=tensors_[0].dtype, - device=tensors_[0].device, - pin_memory=pin_memory, - ) - return torch.stack(tensors_, out=outputs) - - @staticmethod - def batch( - inputs_list: list["MultiModalKwargs"], pin_memory: bool = False - ) -> BatchedTensorInputs: - """ - Batch multiple inputs together into a dictionary. - - The resulting dictionary has the same keys as the inputs. - If the corresponding value from each input is a tensor and they all - share the same shape, the output value is a single batched tensor; - otherwise, the output value is a list containing the original value - from each input. - """ - if len(inputs_list) == 0: - return {} - - # We need to consider the case where each item in the batch - # contains different modalities (i.e. different keys). - item_lists = defaultdict[str, list[NestedTensors]](list) - - for inputs in inputs_list: - for k, v in inputs.items(): - item_lists[k].append(v) - - return { - k: MultiModalKwargs._try_stack(item_list, pin_memory) - for k, item_list in item_lists.items() - } - - @staticmethod - def as_kwargs( - batched_inputs: BatchedTensorInputs, - *, - device: torch.types.Device, - ) -> BatchedTensorInputs: - return json_map_leaves( - lambda x: x.to(device=device, non_blocking=True), - batched_inputs, - ) - def __getitem__(self, key: str): if key not in self: raise KeyError( diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py index 810f29072a0fe..a69afc3176cab 100644 --- a/vllm/multimodal/parse.py +++ b/vllm/multimodal/parse.py @@ -23,6 +23,7 @@ from vllm.utils.collection_utils import is_list_of from vllm.utils.import_utils import LazyLoader from .audio import AudioResampler +from .base import MediaWithBytes from .inputs import ( AudioItem, HfAudioItem, @@ -84,6 +85,12 @@ class ModalityDataItems(ABC, Generic[_T, _I]): """Get all data items.""" return [self.get(idx) for idx in range(self.get_count())] + def get_item_for_hash(self, index: int) -> object: + return self.get(index) + + def get_all_items_for_hash(self) -> list[object]: + return [self.get_item_for_hash(idx) for idx in range(self.get_count())] + @abstractmethod def get_processor_data(self) -> Mapping[str, object]: """Get the data to pass to the HF processor.""" @@ -98,14 +105,22 @@ class ModalityDataItems(ABC, Generic[_T, _I]): class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]): """Base class for data items that are arranged in a list.""" + def _unwrap(self, item: _T | MediaWithBytes[_T]) -> _T: + """Extract media from wrapper if present.""" + return item.media if isinstance(item, MediaWithBytes) else item + def get_count(self) -> int: return len(self.data) def get(self, index: int) -> _T: + return self._unwrap(self.data[index]) + + def get_item_for_hash(self, index: int) -> _T | MediaWithBytes[_T]: + # Return raw item for hashing (preserves original_bytes if present) return self.data[index] def get_processor_data(self) -> Mapping[str, object]: - return {f"{self.modality}s": self.data} + return {f"{self.modality}s": self.get_all()} def get_passthrough_data(self) -> Mapping[str, object]: return {} @@ -119,11 +134,17 @@ class EmbeddingItems( or a list of embedding tensors (one per item). """ + def _unwrap( + self, item: torch.Tensor | MediaWithBytes[torch.Tensor] + ) -> torch.Tensor: + """Extract media from wrapper if present.""" + return item.media if isinstance(item, MediaWithBytes) else item + def get_count(self) -> int: return len(self.data) def get(self, index: int) -> torch.Tensor: - return self.data[index] + return self._unwrap(self.data[index]) def get_processor_data(self) -> Mapping[str, object]: return {} @@ -463,7 +484,7 @@ class MultiModalDataParser: return ImageEmbeddingItems(data) if ( - isinstance(data, PILImage.Image) + isinstance(data, (PILImage.Image, MediaWithBytes)) or isinstance(data, (np.ndarray, torch.Tensor)) and data.ndim == 3 ): diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 912cff2343dd0..0390773783961 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -25,7 +25,6 @@ from typing_extensions import TypeVar, assert_never from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike from vllm.transformers_utils.processor import cached_processor_from_config -from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens from vllm.utils.collection_utils import flatten_2d_lists, full_groupby from vllm.utils.func_utils import get_allowed_kwarg_only_overrides from vllm.utils.jsontree import JSONTree, json_map_leaves @@ -80,9 +79,9 @@ def _cached_encode( tokenizer: TokenizerLike, text: str, *, - add_special_tokens: bool | None = None, + add_special_tokens: bool = True, ) -> list[int]: - return encode_tokens(tokenizer, text, add_special_tokens=add_special_tokens) + return tokenizer.encode(text, add_special_tokens=add_special_tokens) @lru_cache(maxsize=2048) @@ -90,11 +89,9 @@ def _cached_decode( tokenizer: TokenizerLike, token_ids: tuple[int, ...], *, - skip_special_tokens: bool | None = None, + skip_special_tokens: bool = False, ) -> str: - return decode_tokens( - tokenizer, list(token_ids), skip_special_tokens=skip_special_tokens - ) + return tokenizer.decode(list(token_ids), skip_special_tokens=skip_special_tokens) def _seq2text( @@ -110,7 +107,7 @@ def _seq2text( raise ValueError("You cannot decode tokens when `skip_tokenizer_init=True`") if not use_cache: - return decode_tokens(tokenizer, seq) + return tokenizer.decode(seq) return _cached_decode(tokenizer, tuple(seq)) @@ -126,7 +123,7 @@ def _seq2tokens( raise ValueError("You cannot encode text when `skip_tokenizer_init=True`") if not use_cache: - return encode_tokens(tokenizer, seq, add_special_tokens=False) + return tokenizer.encode(seq, add_special_tokens=False) return _cached_encode(tokenizer, seq, add_special_tokens=False) @@ -1248,7 +1245,13 @@ _I = TypeVar("_I", bound=BaseProcessingInfo) MultiModalHashes = dict[str, list[str]] """ -A collection of hashes with a similar structure as +A collection of the multi-modal hash for each item, with a similar structure as +[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems]. +""" + +MultiModalIsCached = dict[str, list[bool]] +""" +A collection of the `is_cached` flag for each item, with a similar structure as [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems]. """ @@ -1681,7 +1684,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): # For None entries, compute a hash; otherwise, use provided ID. computed: list[str] = [] - for i, item in enumerate(items): + for i, item in enumerate(items.get_all_items_for_hash()): item_uuid = mm_uuids_per_modality[i] # NOTE: Even if a item_uuid is provided, we still compute a @@ -1725,7 +1728,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): cache: BaseMultiModalProcessorCache, mm_data_items: MultiModalDataItems, mm_hashes: MultiModalHashes, - ) -> MultiModalDataItems: + ) -> tuple[MultiModalIsCached, MultiModalDataItems]: mm_is_cached = { modality: cache.is_cached(hashes) for modality, hashes in mm_hashes.items() } @@ -1752,7 +1755,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): missing_modality_data.append(data) mm_missing_data[modality] = missing_modality_data - return self._to_mm_items(mm_missing_data) + return mm_is_cached, self._to_mm_items(mm_missing_data) def _recompute_cached_prompt_update( self, @@ -1769,14 +1772,15 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): self, cache: BaseMultiModalProcessorCache, mm_hashes: MultiModalHashes, + mm_is_cached: MultiModalIsCached, mm_missing_kwargs: MultiModalKwargsItems, mm_missing_prompt_updates: MultiModalPromptUpdates, ) -> tuple[MultiModalKwargsOptionalItems, MultiModalPromptUpdates]: - # Need to calculate this at the beginning to avoid skipping cache logic - # for subsequently repeated items in the same modality - mm_is_cached = { - modality: cache.is_cached(hashes) for modality, hashes in mm_hashes.items() - } + # Need to touch all mm hashes before update to avoid hash in updated + # list evict during update + for hashes in mm_hashes.values(): + for item_hash in hashes: + cache.touch_sender_cache_item(item_hash) mm_missing_next_idx = defaultdict[str, int](lambda: 0) @@ -1789,15 +1793,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): missing_prompt_updates = mm_missing_prompt_updates.get(modality, []) for item_idx, item_hash in enumerate(hashes): - kwargs: MultiModalKwargsItem | None if not mm_is_cached[modality][item_idx]: missing_next_idx = mm_missing_next_idx[modality] - kwargs = missing_kwargs[missing_next_idx] - updates = missing_prompt_updates[missing_next_idx] + missing_kwargs_item = missing_kwargs[missing_next_idx] + missing_updates_item = missing_prompt_updates[missing_next_idx] mm_missing_next_idx[modality] += 1 - item = kwargs, updates + item = missing_kwargs_item, missing_updates_item else: item = None @@ -1896,7 +1899,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_uuids=mm_uuids, ) - mm_missing_data_items = self._get_cache_missing_items( + mm_is_cached, mm_missing_data_items = self._get_cache_missing_items( cache=cache, mm_data_items=mm_data_items, mm_hashes=mm_hashes, @@ -1933,6 +1936,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs( cache, mm_hashes=mm_hashes, + mm_is_cached=mm_is_cached, mm_missing_kwargs=mm_missing_kwargs, mm_missing_prompt_updates=mm_missing_prompt_updates, ) @@ -2191,8 +2195,8 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): tokenizer = self.info.get_tokenizer() decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_data) if isinstance(decoder_prompt_raw, str): - decoder_prompt_ids = encode_tokens( - tokenizer, decoder_prompt_raw, add_special_tokens=False + decoder_prompt_ids = tokenizer.encode( + decoder_prompt_raw, add_special_tokens=False ) else: decoder_prompt_ids = decoder_prompt_raw diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 2fdae46e547b0..00a84f9dec4f7 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -6,8 +6,7 @@ from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast from vllm.config.multimodal import BaseDummyOptions from vllm.logger import init_logger -from vllm.tokenizers import TokenizerLike -from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config +from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config from .cache import BaseMultiModalProcessorCache from .processing import ( diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 1020554e2e073..7fd05af583b0a 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -19,7 +19,6 @@ from PIL import Image, UnidentifiedImageError import vllm.envs as envs from vllm.connections import HTTPConnection, global_http_connection from vllm.logger import init_logger -from vllm.utils.jsontree import json_map_leaves from vllm.utils.registry import ExtensionManager from .audio import AudioEmbeddingMediaIO, AudioMediaIO @@ -67,8 +66,9 @@ class MediaConnector: to set num_frames for video, set `--media-io-kwargs '{"video":{"num_frames":40}}'` connection: HTTP connection client to download media contents. - allowed_local_media_path: A local directory to load media files - from. + allowed_local_media_path: A local directory to load media files from. + allowed_media_domains: If set, only media URLs that belong to this + domain can be used for multi-modal inputs. """ super().__init__() @@ -123,16 +123,16 @@ class MediaConnector: "Cannot load local files without `--allowed-local-media-path`." ) - filepath = Path(url2pathname(url_spec.path)) + filepath = Path(url2pathname(url_spec.netloc + url_spec.path)) if allowed_local_media_path not in filepath.resolve().parents: raise ValueError( f"The file path {filepath} must be a subpath " - f"of `--allowed-local-media-path` {allowed_local_media_path}." + f"of `--allowed-local-media-path {allowed_local_media_path}`." ) return media_io.load_file(filepath) - def _assert_url_in_allowed_media_domains(self, url_spec) -> None: + def _assert_url_in_allowed_media_domains(self, url_spec: ParseResult) -> None: if ( self.allowed_media_domains and url_spec.hostname not in self.allowed_media_domains @@ -413,7 +413,7 @@ def group_mm_kwargs_by_modality( device: torch.types.Device = None, pin_memory: bool = False, merge_by_field_config: bool | None = None, - multimodal_cpu_fields: Set[str] = frozenset(), + multimodal_cpu_fields: Set[str] | None = None, ) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]: """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same modality together into the same `MultiModalKwargs` instance. @@ -426,59 +426,28 @@ def group_mm_kwargs_by_modality( Yields: A tuple `(modality, num_items, grouped_kwargs)`. """ - if merge_by_field_config is None: - raise RuntimeError( - "`group_mm_kwargs_by_modality` now requires " - "`merge_by_field_config` arg, please update your model runner " - "according to https://github.com/vllm-project/vllm/pull/25676." - ) - if merge_by_field_config is False: + if merge_by_field_config is not None: logger.warning_once( - "The legacy code for batching multi-modal kwargs is deprecated and " - "will be removed in v0.12. Please update your model with " - "`merge_by_field_config=True` to use the new code defined by " - "`MultiModalFieldConfig`. You can refer to " - "https://github.com/vllm-project/vllm/issues/26149 " - "for some examples on how to do this." + "The `merge_by_field_config` argument of `group_mm_kwargs_by_modality` " + "is deprecated and will be removed in v0.14." + ) + if multimodal_cpu_fields is not None: + logger.warning_once( + "The `multimodal_cpu_fields` argument of `group_mm_kwargs_by_modality` " + "is deprecated and will be removed in v0.14." ) - from vllm.multimodal.inputs import MultiModalKwargs, MultiModalKwargsItems + from vllm.multimodal.inputs import MultiModalKwargsItems for modality, items in groupby(mm_kwargs, key=lambda item: item.modality): items_lst = list(items) + mm_kwargs_items = MultiModalKwargsItems.from_seq(items_lst) + mm_kwargs_data = mm_kwargs_items.get_data( + device=device, + pin_memory=pin_memory, + ) - if merge_by_field_config: - mm_kwargs_group: BatchedTensorInputs = dict( - MultiModalKwargsItems.from_seq(items_lst).get_data( - pin_memory=pin_memory - ) - ) - - if device is not None: - mm_kwargs_group = { - k: json_map_leaves( - lambda x: x.to(device=device, non_blocking=True) - if isinstance(x, torch.Tensor) - else x, - v, - ) - if k not in multimodal_cpu_fields - else v - for k, v in mm_kwargs_group.items() - } - else: - mm_kwargs_group = MultiModalKwargs.as_kwargs( - MultiModalKwargs.batch( - [ - MultiModalKwargsItems.from_seq([item]).get_data() - for item in items_lst - ], - pin_memory=pin_memory, - ), - device=device, - ) - - yield modality, len(items_lst), mm_kwargs_group + yield modality, len(items_lst), mm_kwargs_data def fetch_audio( @@ -489,9 +458,16 @@ def fetch_audio( Args: audio_url: URL of the audio file to fetch. audio_io_kwargs: Additional kwargs passed to handle audio IO. + + Warning: + This method has direct access to local files and is only intended + to be called by user code. Never call this from the online server! """ media_io_kwargs = None if not audio_io_kwargs else {"audio": audio_io_kwargs} - media_connector = MediaConnector(media_io_kwargs=media_io_kwargs) + media_connector = MediaConnector( + media_io_kwargs=media_io_kwargs, + allowed_local_media_path="/", + ) return media_connector.fetch_audio(audio_url) @@ -503,9 +479,16 @@ def fetch_image( Args: image_url: URL of the image file to fetch. image_io_kwargs: Additional kwargs passed to handle image IO. + + Warning: + This method has direct access to local files and is only intended + to be called by user code. Never call this from the online server! """ media_io_kwargs = None if not image_io_kwargs else {"image": image_io_kwargs} - media_connector = MediaConnector(media_io_kwargs=media_io_kwargs) + media_connector = MediaConnector( + media_io_kwargs=media_io_kwargs, + allowed_local_media_path="/", + ) return media_connector.fetch_image(image_url) @@ -517,7 +500,14 @@ def fetch_video( Args: video_url: URL of the video file to fetch. video_io_kwargs: Additional kwargs passed to handle video IO. + + Warning: + This method has direct access to local files and is only intended + to be called by user code. Never call this from the online server! """ media_io_kwargs = None if not video_io_kwargs else {"video": video_io_kwargs} - media_connector = MediaConnector(media_io_kwargs=media_io_kwargs) + media_connector = MediaConnector( + media_io_kwargs=media_io_kwargs, + allowed_local_media_path="/", + ) return media_connector.fetch_video(video_url) diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index 763f90fde7b6d..024252799cf74 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -267,7 +267,7 @@ class OpenCVDynamicVideoBackend(OpenCVVideoBackend): return frames, metadata -class VideoMediaIO(MediaIO[npt.NDArray]): +class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]): def __init__( self, image_io: ImageMediaIO, @@ -283,8 +283,15 @@ class VideoMediaIO(MediaIO[npt.NDArray]): # They can be passed to the underlying # media loaders (e.g. custom implementations) # for flexible control. + + # Allow per-request override of video backend via kwargs. + # This enables users to specify a different backend than the + # global VLLM_VIDEO_LOADER_BACKEND env var, e.g.: + # --media-io-kwargs '{"video": {"video_backend": "torchcodec"}}' + video_loader_backend = ( + kwargs.pop("video_backend", None) or envs.VLLM_VIDEO_LOADER_BACKEND + ) self.kwargs = kwargs - video_loader_backend = envs.VLLM_VIDEO_LOADER_BACKEND self.video_loader = VIDEO_LOADER_REGISTRY.load(video_loader_backend) def load_bytes(self, data: bytes) -> tuple[npt.NDArray, dict[str, Any]]: diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 2b2c2f9cdc571..e1b461d79a655 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -10,6 +10,7 @@ import sys from dataclasses import dataclass from typing import TYPE_CHECKING +import psutil import regex as re import torch @@ -22,6 +23,7 @@ from .interface import CpuArchEnum, Platform, PlatformEnum logger = init_logger(__name__) if TYPE_CHECKING: + from vllm.attention.selector import AttentionSelectorConfig from vllm.config import VllmConfig else: VllmConfig = None @@ -125,20 +127,13 @@ class CpuPlatform(Platform): def get_attn_backend_cls( cls, selected_backend: "AttentionBackendEnum", - head_size: int, - dtype: torch.dtype, - kv_cache_dtype: str | None, - block_size: int, - use_mla: bool, - has_sink: bool, - use_sparse: bool, - attn_type: str | None = None, + attn_selector_config: "AttentionSelectorConfig", ) -> str: if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN: logger.info("Cannot use %s backend on CPU.", selected_backend) - if use_mla: + if attn_selector_config.use_mla: raise NotImplementedError("MLA is not supported on CPU.") - if use_sparse: + if attn_selector_config.use_sparse: raise NotImplementedError("Sparse Attention is not supported on CPU.") return AttentionBackendEnum.CPU_ATTN.get_path() @@ -147,11 +142,21 @@ class CpuPlatform(Platform): from vllm.utils.mem_constants import GiB_bytes kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE + node_dir = "/sys/devices/system/node" if kv_cache_space is None: - kv_cache_space = 4 * GiB_bytes # type: ignore + nodes = ( + [d for d in os.listdir(node_dir) if d.startswith("node")] + if os.path.exists(node_dir) + else [] + ) + num_numa_nodes = len(nodes) or 1 + free_cpu_memory = psutil.virtual_memory().total // num_numa_nodes + DEFAULT_CPU_MEM_UTILIZATION = 0.5 + kv_cache_space = int(free_cpu_memory * DEFAULT_CPU_MEM_UTILIZATION) + kv_cache_space_gib = kv_cache_space / GiB_bytes logger.warning_once( - "Environment variable VLLM_CPU_KVCACHE_SPACE (GiB) " - "for CPU backend is not set, using 4 by default." + "VLLM_CPU_KVCACHE_SPACE not set. Using " + f"{kv_cache_space_gib:.2f} GiB for KV cache." ) else: kv_cache_space *= GiB_bytes @@ -313,10 +318,16 @@ class CpuPlatform(Platform): # We need to find the location of PyTorch's libgomp torch_pkg = os.path.dirname(torch.__file__) site_root = os.path.dirname(torch_pkg) - torch_libs = os.path.join(site_root, "torch.libs") - pytorch_libgomp_so_candidates = glob.glob( - os.path.join(torch_libs, "libgomp-*.so*") - ) + # Search both torch.libs and torch/lib - See: https://github.com/vllm-project/vllm/issues/30470 + torch_libs_paths = [ + os.path.join(site_root, "torch.libs"), + os.path.join(torch_pkg, "lib"), + ] + pytorch_libgomp_so_candidates = [] + for torch_libs in torch_libs_paths: + pytorch_libgomp_so_candidates.extend( + glob.glob(os.path.join(torch_libs, "libgomp*.so*")) + ) if pytorch_libgomp_so_candidates: pytorch_libgomp_so = pytorch_libgomp_so_candidates[0] if ld_preload_str: diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 4bf9401b6b051..2dc4ba5d70cac 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -7,15 +7,13 @@ pynvml. However, it should not initialize cuda context. import os from collections.abc import Callable from functools import cache, wraps -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Optional, TypeVar import torch from typing_extensions import ParamSpec # import custom ops, trigger op registration import vllm._C # noqa -import vllm.envs as envs -from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.registry import AttentionBackendEnum from vllm.logger import init_logger from vllm.utils.import_utils import import_pynvml @@ -24,6 +22,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: + from vllm.attention.selector import AttentionSelectorConfig from vllm.config import VllmConfig from vllm.config.cache import CacheDType else: @@ -149,6 +148,8 @@ class CudaPlatformBase(Platform): @classmethod def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: + from vllm.attention.backends.registry import AttentionBackendEnum + parallel_config = vllm_config.parallel_config model_config = vllm_config.model_config @@ -171,7 +172,7 @@ class CudaPlatformBase(Platform): and cache_config.block_size is not None ): use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk") - # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, + # If `--attention-config.backend` is not set and we are using MLA, # then we default to FlashMLA backend for non-blackwell GPUs, # else we default to CutlassMLA. For each case, we force the # required block_size. @@ -179,23 +180,25 @@ class CudaPlatformBase(Platform): use_cutlass_mla = False use_flashinfer_mla = False - if envs.VLLM_ATTENTION_BACKEND is None: + if vllm_config.attention_config.backend is None: # Default case - if cls.is_device_capability(100): - # Blackwell => Force CutlassMLA. + if cls.is_device_capability_family(100) and not use_sparse: + # Blackwell => Force CutlassMLA (unless sparse, i.e. DSv3.2). use_cutlass_mla = True - # TODO: This does not work, because the - # global_force_attn_backend_context_manager is not set. - # See vllm/attention/selector.py:_cached_get_attn_backend - envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA" + # Set the backend in AttentionConfig so it's used during + # backend selection + vllm_config.attention_config.backend = ( + AttentionBackendEnum.CUTLASS_MLA + ) else: # Not Blackwell use_flashmla = True else: # Forced case - use_flashmla = envs.VLLM_ATTENTION_BACKEND == "FLASHMLA" - use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA" - use_flashinfer_mla = envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA" + backend = vllm_config.attention_config.backend + use_flashmla = backend == AttentionBackendEnum.FLASHMLA + use_cutlass_mla = backend == AttentionBackendEnum.CUTLASS_MLA + use_flashinfer_mla = backend == AttentionBackendEnum.FLASHINFER_MLA from vllm.attention.ops.flashmla import is_flashmla_dense_supported @@ -229,27 +232,20 @@ class CudaPlatformBase(Platform): logger.info( "Forcing kv cache block size to 64 for FlashMLASparse backend." ) - # lazy import to avoid circular import - from vllm.config import CUDAGraphMode - compilation_config = vllm_config.compilation_config + scheduler_config = vllm_config.scheduler_config + # Note: model_config may be None during testing if ( - parallel_config.all2all_backend == "deepep_high_throughput" - and parallel_config.data_parallel_size > 1 - and compilation_config.cudagraph_mode != CUDAGraphMode.NONE + model_config is not None + and model_config.is_mm_prefix_lm + and scheduler_config.is_multimodal_model + and not scheduler_config.disable_chunked_mm_input ): - # TODO: Piecewise Cuda graph might be enabled - # if torch compile cache key issue fixed - # See https://github.com/vllm-project/vllm/pull/25093 - logger.info( - "WideEP: Disabling CUDA Graphs since DeepEP high-throughput " - "kernels are optimized for prefill and are incompatible with " - "CUDA Graphs. " - "In order to use CUDA Graphs for decode-optimized workloads, " - "use --all2all-backend with another option, such as " - "deepep_low_latency, pplx, or allgather_reducescatter." + logger.warning( + "Forcing --disable_chunked_mm_input for models " + "with multimodal-bidirectional attention." ) - compilation_config.cudagraph_mode = CUDAGraphMode.NONE + scheduler_config.disable_chunked_mm_input = True @classmethod def get_current_memory_usage( @@ -259,35 +255,11 @@ class CudaPlatformBase(Platform): torch.cuda.reset_peak_memory_stats(device) return torch.cuda.max_memory_allocated(device) - @classmethod - def get_vit_attn_backend( - cls, head_size: int, dtype: torch.dtype - ) -> "AttentionBackendEnum": - # Try FlashAttention first - if (cc := cls.get_device_capability()) and cc.major >= 8: - try: - backend_class = AttentionBackendEnum.FLASH_ATTN.get_class() - if backend_class.supports_head_size( - head_size - ) and backend_class.supports_dtype(dtype): - return AttentionBackendEnum.FLASH_ATTN - except ImportError: - pass - - return AttentionBackendEnum.TORCH_SDPA - @classmethod def get_valid_backends( cls, - head_size, - dtype, - kv_cache_dtype, - block_size, - use_mla, - has_sink, - use_sparse, - device_capability, - attn_type, + device_capability: DeviceCapability, + attn_selector_config: "AttentionSelectorConfig", ) -> tuple[ list[tuple["AttentionBackendEnum", int]], dict["AttentionBackendEnum", list[str]], @@ -295,20 +267,15 @@ class CudaPlatformBase(Platform): valid_backends_priorities = [] invalid_reasons = {} - backend_priorities = _get_backend_priorities(use_mla, device_capability) + backend_priorities = _get_backend_priorities( + attn_selector_config.use_mla, device_capability + ) for priority, backend in enumerate(backend_priorities): try: backend_class = backend.get_class() invalid_reasons_i = backend_class.validate_configuration( - head_size, - dtype, - kv_cache_dtype, - block_size, - use_mla, - has_sink, - use_sparse, - device_capability, - attn_type, + device_capability=device_capability, + **attn_selector_config._asdict(), ) except ImportError: invalid_reasons_i = ["ImportError"] @@ -323,35 +290,19 @@ class CudaPlatformBase(Platform): def get_attn_backend_cls( cls, selected_backend: "AttentionBackendEnum", - head_size: int, - dtype: torch.dtype, - kv_cache_dtype: "CacheDType | None", - block_size: int | None, - use_mla: bool, - has_sink: bool, - use_sparse: bool, - attn_type: str | None = None, + attn_selector_config: "AttentionSelectorConfig", ) -> str: - if attn_type is None: - attn_type = AttentionType.DECODER - device_capability = cls.get_device_capability() assert device_capability is not None + attn_selector_config = attn_selector_config._replace(block_size=None) # First try checking just the selected backend, if there is one. if selected_backend is not None: try: backend_class = selected_backend.get_class() invalid_reasons = backend_class.validate_configuration( - head_size, - dtype, - kv_cache_dtype, - None, - use_mla, - has_sink, - use_sparse, - device_capability, - attn_type, + device_capability=device_capability, + **attn_selector_config._asdict(), ) except ImportError: invalid_reasons = ["ImportError"] @@ -367,15 +318,8 @@ class CudaPlatformBase(Platform): # No selected backend or the selected backend is invalid, # so we try finding a valid backend. valid_backends_priorities, invalid_reasons = cls.get_valid_backends( - head_size, - dtype, - kv_cache_dtype, - None, - use_mla, - has_sink, - use_sparse, - device_capability, - attn_type, + device_capability=device_capability, + attn_selector_config=attn_selector_config, ) reasons_str = ( "{" @@ -385,11 +329,7 @@ class CudaPlatformBase(Platform): ) + "}" ) - config_str = ( - f"head_size: {head_size}, dtype: {dtype}, " - f"kv_cache_dtype: {kv_cache_dtype}, block_size: {block_size}, " - f"use_mla: {use_mla}, has_sink: {has_sink}, use_sparse: {use_sparse}" - ) + config_str = attn_selector_config.__repr__() logger.debug_once( f"Some attention backends are not valid for {cls.device_name} with " f"{config_str}. Reasons: {reasons_str}." @@ -408,14 +348,50 @@ class CudaPlatformBase(Platform): ) selected_index = sorted_indices[0] selected_backend = valid_backends_priorities[selected_index][0] - logger.info( + logger.info_once( "Using %s attention backend out of potential backends: %s", selected_backend.name, - [b[0].name for b in valid_backends_priorities], + tuple(b[0].name for b in valid_backends_priorities), + scope="local", ) return selected_backend.get_path() + @classmethod + def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]: + return [ + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.FLASH_ATTN, + ] + + @classmethod + def get_vit_attn_backend( + cls, + head_size: int, + dtype: torch.dtype, + backend: Optional["AttentionBackendEnum"] = None, + ) -> "AttentionBackendEnum": + if backend is not None: + assert backend in cls.get_supported_vit_attn_backends(), ( + f"Backend {backend} is not supported for vit attention. " + f"Supported backends are: {cls.get_supported_vit_attn_backends()}" + ) + logger.info_once(f"Using backend {backend} for vit attention") + return backend + + # Try FlashAttention first + if (cc := cls.get_device_capability()) and cc.major >= 8: + try: + backend_class = AttentionBackendEnum.FLASH_ATTN.get_class() + if backend_class.supports_head_size( + head_size + ) and backend_class.supports_dtype(dtype): + return AttentionBackendEnum.FLASH_ATTN + except ImportError: + pass + + return AttentionBackendEnum.TORCH_SDPA + @classmethod def get_punica_wrapper(cls) -> str: return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU" diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 27c6fac09f498..d4b40045df384 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -7,7 +7,7 @@ import platform import random import sys from datetime import timedelta -from typing import TYPE_CHECKING, Any, NamedTuple +from typing import TYPE_CHECKING, Any, NamedTuple, Optional import numpy as np import torch @@ -18,8 +18,8 @@ from vllm.logger import init_logger if TYPE_CHECKING: from torch.distributed import PrefixStore, ProcessGroup + from vllm.attention.selector import AttentionSelectorConfig from vllm.config import VllmConfig - from vllm.config.cache import CacheDType from vllm.inputs import ProcessorInputs, PromptType from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams @@ -222,28 +222,52 @@ class Platform: with contextlib.suppress(ImportError): import vllm._moe_C # noqa: F401 - @classmethod - def get_vit_attn_backend( - cls, head_size: int, dtype: torch.dtype - ) -> "AttentionBackendEnum": - return AttentionBackendEnum.TORCH_SDPA - @classmethod def get_attn_backend_cls( cls, selected_backend: "AttentionBackendEnum", - head_size: int, - dtype: torch.dtype, - kv_cache_dtype: "CacheDType | None", - block_size: int, - use_mla: bool, - has_sink: bool, - use_sparse: bool, - attn_type: str | None = None, + attn_selector_config: "AttentionSelectorConfig", ) -> str: """Get the attention backend class of a device.""" return "" + @classmethod + def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]: + return [ + AttentionBackendEnum.TORCH_SDPA, + ] + + @classmethod + def get_vit_attn_backend( + cls, + head_size: int, + dtype: torch.dtype, + backend: Optional["AttentionBackendEnum"] = None, + ) -> "AttentionBackendEnum": + """ + Get the vision attention backend class of a device. + + NOTE: ViT Attention should be checked and override in the platform-specific + implementation. we should not override this in any other places, like + the model_executor/models/<model_name>.py. + + We check if the backend is None or not: + 1. If not, check if the backend is supported by the platform. + 2. If None, continue to the default selection logic. + """ + if backend is not None: + assert backend in cls.get_supported_vit_attn_backends(), ( + f"Backend {backend} is not supported for vit attention" + f"Supported backends are: {cls.get_supported_vit_attn_backends()}" + ) + logger.info_once(f"Using backend {backend} for vit attention") + return backend + + logger.info_once( + f"Using default backend {AttentionBackendEnum.TORCH_SDPA} for vit attention" + ) + return AttentionBackendEnum.TORCH_SDPA + @classmethod def get_device_capability( cls, @@ -300,6 +324,21 @@ class Platform: return current_capability.to_int() == capability + @classmethod + def is_device_capability_family( + cls, + capability: int, + device_id: int = 0, + ) -> bool: + """ + Returns True if the device capability is any <major>.x. + Mirrors CUDA 13 'family' architecture semantics (e.g. 10.x, 11.x, 12.x). + """ + current_capability = cls.get_device_capability(device_id=device_id) + if current_capability is None: + return False + return (current_capability.to_int() // 10) == (capability // 10) + @classmethod def get_device_name(cls, device_id: int = 0) -> str: """Get the name of a device.""" diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index ccf3446a3a6e5..c237f7cf887c1 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -3,7 +3,7 @@ import os from functools import cache, lru_cache, wraps -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -15,6 +15,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: + from vllm.attention.selector import AttentionSelectorConfig from vllm.config import VllmConfig logger = init_logger(__name__) @@ -123,8 +124,6 @@ def use_rocm_custom_paged_attention( alibi_slopes: torch.Tensor | None = None, sinks: torch.Tensor | None = None, ) -> bool: - from vllm._aiter_ops import rocm_aiter_ops - GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"]) @@ -140,7 +139,6 @@ def use_rocm_custom_paged_attention( and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 128 * 1024 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) - and not (rocm_aiter_ops.is_pa_attn_enabled()) and sinks is None ) @@ -187,41 +185,19 @@ class RocmPlatform(Platform): if not on_gfx9(): supported_quantization += ["bitsandbytes"] - @classmethod - def get_vit_attn_backend( - cls, head_size: int, dtype: torch.dtype - ) -> AttentionBackendEnum: - from importlib.util import find_spec - - from vllm._aiter_ops import rocm_aiter_ops - - if rocm_aiter_ops.is_mha_enabled(): - # Note: AITER FA is only supported for Qwen-VL models. - # TODO: Add support for other VL models in their model class. - return AttentionBackendEnum.ROCM_AITER_FA - - if on_gfx9() and find_spec("flash_attn") is not None: - return AttentionBackendEnum.FLASH_ATTN - - return AttentionBackendEnum.TORCH_SDPA - @classmethod def get_attn_backend_cls( cls, - selected_backend, - head_size, - dtype, - kv_cache_dtype, - block_size, - use_mla, - has_sink, - use_sparse, - attn_type: str | None = None, + selected_backend: "AttentionBackendEnum", + attn_selector_config: "AttentionSelectorConfig", ) -> str: from vllm._aiter_ops import rocm_aiter_ops - if use_sparse: - if kv_cache_dtype.startswith("fp8"): + block_size = attn_selector_config.block_size + kv_cache_dtype = attn_selector_config.kv_cache_dtype + + if attn_selector_config.use_sparse: + if kv_cache_dtype and kv_cache_dtype.startswith("fp8"): raise ValueError( "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype." ) @@ -231,7 +207,7 @@ class RocmPlatform(Platform): logger.info_once("Using Sparse MLA backend on V1 engine.") return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path() - if use_mla: + if attn_selector_config.use_mla: if selected_backend is None: selected_backend = ( AttentionBackendEnum.ROCM_AITER_MLA @@ -321,6 +297,43 @@ class RocmPlatform(Platform): "ROCm. Note that V0 attention backends have been removed." ) + @classmethod + def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]: + return [ + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, + AttentionBackendEnum.TORCH_SDPA, + ] + + @classmethod + def get_vit_attn_backend( + cls, + head_size: int, + dtype: torch.dtype, + backend: Optional["AttentionBackendEnum"] = None, + ) -> "AttentionBackendEnum": + if backend is not None: + assert backend in cls.get_supported_vit_attn_backends(), ( + f"Backend {backend} is not supported for vit attention. " + f"Supported backends are: {cls.get_supported_vit_attn_backends()}" + ) + logger.info_once(f"Using backend {backend} for vit attention") + return backend + + from importlib.util import find_spec + + from vllm._aiter_ops import rocm_aiter_ops + + if rocm_aiter_ops.is_mha_enabled(): + # Note: AITER FA is only supported for Qwen-VL models. + # TODO: Add support for other VL models in their model class. + return AttentionBackendEnum.ROCM_AITER_FA + + if on_gfx9() and find_spec("flash_attn") is not None: + return AttentionBackendEnum.FLASH_ATTN + + return AttentionBackendEnum.TORCH_SDPA + @classmethod def set_device(cls, device: torch.device) -> None: """ @@ -380,11 +393,43 @@ class RocmPlatform(Platform): compilation_config = vllm_config.compilation_config parallel_config = vllm_config.parallel_config is_eager_execution = compilation_config == CUDAGraphMode.NONE - use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled() + use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enaled() + + if compilation_config.cudagraph_mode.has_full_cudagraphs(): + # decode context parallel does not support full cudagraphs + if parallel_config.decode_context_parallel_size > 1: + logger.warning_once( + "Decode context parallel (DCP) is enabled, which is " + "incompatible with full CUDA graphs. " + "Overriding cudagraph_mode to PIECEWISE." + ) + compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE + # prefill context parallel do not support full cudagraphs + elif parallel_config.prefill_context_parallel_size > 1: + logger.warning_once( + "Prefill context parallel (PCP) is enabled, which is " + "incompatible with full CUDA graphs. " + "Overriding cudagraph_mode to PIECEWISE." + ) + compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE if cache_config and cache_config.block_size is None: - cache_config.block_size = 16 + if ( + envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION and envs.VLLM_ROCM_USE_AITER + # NOTE: This block has been deprecated + # or get_env_variable_attn_backend() + # == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN + # TODO: monitor https://github.com/vllm-project/vllm/pull/30396 + # to see how we can transition to the new way of selecting + # attention backends + ): + cache_config.block_size = 64 + logger.warning( + "[ROCM_AITER_UNIFIED_ATTN]: Setting kv cache block size to 64." + ) + else: + cache_config.block_size = 16 if parallel_config.worker_cls == "auto": parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" @@ -396,6 +441,9 @@ class RocmPlatform(Platform): ): compilation_config.custom_ops.append("+rms_norm") + if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops: + compilation_config.custom_ops.append("+quant_fp8") + @classmethod def verify_model_arch(cls, model_arch: str) -> None: if model_arch in _ROCM_UNSUPPORTED_MODELS: diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index cbc0a996f3661..7c479bf2b6a0e 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Optional, cast import torch from tpu_info import device @@ -16,6 +16,7 @@ from .interface import Platform, PlatformEnum if TYPE_CHECKING: from typing import TypeAlias + from vllm.attention.selector import AttentionSelectorConfig from vllm.config import VllmConfig from vllm.config.cache import BlockSize from vllm.pooling_params import PoolingParams @@ -57,16 +58,9 @@ class TpuPlatform(Platform): def get_attn_backend_cls( cls, selected_backend: "AttentionBackendEnum", - head_size: int, - dtype: torch.dtype, - kv_cache_dtype: str | None, - block_size: int, - use_mla: bool, - has_sink, - use_sparse, - attn_type: str | None = None, + attn_selector_config: "AttentionSelectorConfig", ) -> str: - if use_sparse: + if attn_selector_config.use_sparse: raise NotImplementedError("Sparse Attention is not supported on TPU.") if selected_backend != AttentionBackendEnum.PALLAS: logger.info("Cannot use %s backend on TPU.", selected_backend) @@ -74,6 +68,32 @@ class TpuPlatform(Platform): logger.info("Using Pallas V1 backend.") return AttentionBackendEnum.PALLAS.get_path() + @classmethod + def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]: + return [ + AttentionBackendEnum.PALLAS, + ] + + @classmethod + def get_vit_attn_backend( + cls, + head_size: int, + dtype: torch.dtype, + backend: Optional["AttentionBackendEnum"] = None, + ) -> "AttentionBackendEnum": + if backend is not None: + assert backend in cls.get_supported_vit_attn_backends(), ( + f"Backend {backend} is not supported for vit attention" + f"Supported backends are: {cls.get_supported_vit_attn_backends()}." + ) + logger.info_once(f"Using backend {backend} for vit attention.") + return backend + + logger.info_once( + f"Using default backend {AttentionBackendEnum.PALLAS} for vit attention." + ) + return AttentionBackendEnum.PALLAS + @classmethod def set_device(cls, device: torch.device) -> None: """ diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 768714fb16726..af8979af36643 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -3,7 +3,7 @@ import contextlib import os -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -14,6 +14,7 @@ from vllm.logger import init_logger from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: + from vllm.attention.selector import AttentionSelectorConfig from vllm.config import VllmConfig else: VllmConfig = None @@ -42,14 +43,7 @@ class XPUPlatform(Platform): def get_attn_backend_cls( cls, selected_backend: "AttentionBackendEnum", - head_size: int, - dtype: torch.dtype, - kv_cache_dtype: str | None, - block_size: int, - use_mla: bool, - has_sink: bool, - use_sparse, - attn_type: str | None = None, + attn_selector_config: "AttentionSelectorConfig", ) -> str: from vllm.v1.attention.backends.utils import set_kv_cache_layout @@ -59,7 +53,7 @@ class XPUPlatform(Platform): "only NHD layout is supported by XPU attention kernels." ) - if use_sparse: + if attn_selector_config.use_sparse: raise NotImplementedError("Sparse Attention is not supported on XPU.") if selected_backend == AttentionBackendEnum.TRITON_ATTN: logger.info_once("Using Triton backend.") @@ -70,12 +64,40 @@ class XPUPlatform(Platform): elif selected_backend: raise ValueError( f"Invalid attention backend for {cls.device_name}, " - f"with use_mla: {use_mla}" + f"with use_mla: {attn_selector_config.use_mla}" ) logger.info("Using Flash Attention backend.") return AttentionBackendEnum.FLASH_ATTN.get_path() + @classmethod + def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]: + # XPU only supports FLASH_ATTN for vision attention. + return [ + AttentionBackendEnum.FLASH_ATTN, + ] + + @classmethod + def get_vit_attn_backend( + cls, + head_size: int, + dtype: torch.dtype, + backend: Optional["AttentionBackendEnum"] = None, + ) -> "AttentionBackendEnum": + if backend is not None: + assert backend in cls.get_supported_vit_attn_backends(), ( + f"Backend {backend} is not supported for vit attention. " + f"Supported backends are: " + f"{cls.get_supported_vit_attn_backends()}." + ) + logger.info_once(f"Using backend {backend} for vit attention") + return backend + + logger.info_once( + f"Using backend {AttentionBackendEnum.FLASH_ATTN} for vit attention" + ) + return AttentionBackendEnum.FLASH_ATTN + @classmethod def set_device(cls, device: torch.device) -> None: """ @@ -109,12 +131,6 @@ class XPUPlatform(Platform): device_props = torch.xpu.get_device_properties(device_id) return device_props.total_memory - @classmethod - def get_vit_attn_backend( - cls, head_size: int, dtype: torch.dtype - ) -> "AttentionBackendEnum": - return AttentionBackendEnum.FLASH_ATTN - @classmethod def inference_mode(cls): return torch.no_grad() diff --git a/vllm/profiler/gpu_profiler.py b/vllm/profiler/wrapper.py similarity index 69% rename from vllm/profiler/gpu_profiler.py rename to vllm/profiler/wrapper.py index 798c615221b9f..f891a88f90394 100644 --- a/vllm/profiler/gpu_profiler.py +++ b/vllm/profiler/wrapper.py @@ -3,26 +3,27 @@ from abc import ABC, abstractmethod from contextlib import nullcontext +from typing import Literal import torch from typing_extensions import override -import vllm.envs as envs +from vllm.config import ProfilerConfig from vllm.logger import init_logger logger = init_logger(__name__) class WorkerProfiler(ABC): - def __init__(self) -> None: - self._delay_iters = envs.VLLM_PROFILER_DELAY_ITERS + def __init__(self, profiler_config: ProfilerConfig) -> None: + self._delay_iters = profiler_config.delay_iterations if self._delay_iters > 0: logger.info_once( "GPU profiling will start " f"{self._delay_iters} steps after start_profile." ) - self._max_iters = envs.VLLM_PROFILER_MAX_ITERS + self._max_iters = profiler_config.max_iterations if self._max_iters > 0: logger.info_once( "GPU profiling will stop " @@ -60,7 +61,7 @@ class WorkerProfiler(ABC): """Call _stop with error handling but no safeguards.""" try: self._stop() - logger.info("Profiler stopped successfully.") + logger.info_once("Profiler stopped successfully.", scope="local") except Exception as e: logger.warning("Failed to stop profiler: %s", e) self._running = False # Always mark as not running, assume stop worked @@ -90,7 +91,7 @@ class WorkerProfiler(ABC): and self._delay_iters > 0 and self._active_iteration_count == self._delay_iters ): - logger.info("Starting profiler after delay...") + logger.info_once("Starting profiler after delay...", scope="local") self._call_start() if self._running: @@ -104,7 +105,9 @@ class WorkerProfiler(ABC): # Automatically stop the profiler after max iters # will be marked as not running, but leave as active so that stop # can clean up properly - logger.info("Max profiling iterations reached. Stopping profiler...") + logger.info_once( + "Max profiling iterations reached. Stopping profiler...", scope="local" + ) self._call_stop() return @@ -124,7 +127,7 @@ class WorkerProfiler(ABC): def shutdown(self) -> None: """Ensure profiler is stopped when shutting down.""" - logger.info_once("Shutting down profiler") + logger.info_once("Shutting down profiler", scope="local") if self._running: self.stop() @@ -133,38 +136,53 @@ class WorkerProfiler(ABC): return nullcontext() +TorchProfilerActivity = Literal["CPU", "CUDA", "XPU"] +TorchProfilerActivityMap = { + "CPU": torch.profiler.ProfilerActivity.CPU, + "CUDA": torch.profiler.ProfilerActivity.CUDA, + "XPU": torch.profiler.ProfilerActivity.XPU, +} + + class TorchProfilerWrapper(WorkerProfiler): - def __init__(self, worker_name: str, local_rank: int) -> None: - super().__init__() + def __init__( + self, + profiler_config: ProfilerConfig, + worker_name: str, + local_rank: int, + activities: list[TorchProfilerActivity], + ) -> None: + super().__init__(profiler_config) self.local_rank = local_rank - torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + self.profiler_config = profiler_config + torch_profiler_trace_dir = profiler_config.torch_profiler_dir if local_rank in (None, 0): - logger.info( + logger.info_once( "Torch profiling enabled. Traces will be saved to: %s", torch_profiler_trace_dir, + scope="local", ) logger.debug( "Profiler config: record_shapes=%s," "profile_memory=%s,with_stack=%s,with_flops=%s", - envs.VLLM_TORCH_PROFILER_RECORD_SHAPES, - envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY, - envs.VLLM_TORCH_PROFILER_WITH_STACK, - envs.VLLM_TORCH_PROFILER_WITH_FLOPS, + profiler_config.torch_profiler_record_shapes, + profiler_config.torch_profiler_with_memory, + profiler_config.torch_profiler_with_stack, + profiler_config.torch_profiler_with_flops, ) + + self.dump_cpu_time_total = "CPU" in activities and len(activities) == 1 self.profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES, - profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY, - with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, - with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS, + activities=[TorchProfilerActivityMap[activity] for activity in activities], + record_shapes=profiler_config.torch_profiler_record_shapes, + profile_memory=profiler_config.torch_profiler_with_memory, + with_stack=profiler_config.torch_profiler_with_stack, + with_flops=profiler_config.torch_profiler_with_flops, on_trace_ready=torch.profiler.tensorboard_trace_handler( torch_profiler_trace_dir, worker_name=worker_name, - use_gzip=envs.VLLM_TORCH_PROFILER_USE_GZIP, + use_gzip=profiler_config.torch_profiler_use_gzip, ), ) @@ -176,9 +194,10 @@ class TorchProfilerWrapper(WorkerProfiler): def _stop(self) -> None: self.profiler.stop() - if envs.VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL: - rank = self.local_rank - profiler_dir = envs.VLLM_TORCH_PROFILER_DIR + profiler_config = self.profiler_config + rank = self.local_rank + if profiler_config.torch_profiler_dump_cuda_time_total: + profiler_dir = profiler_config.torch_profiler_dir profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt" sort_key = "self_cuda_time_total" table = self.profiler.key_averages().table(sort_by=sort_key) @@ -189,6 +208,12 @@ class TorchProfilerWrapper(WorkerProfiler): # only print profiler results on rank 0 if rank == 0: print(table) + if self.dump_cpu_time_total and rank == 0: + logger.info( + self.profiler.key_averages().table( + sort_by="self_cpu_time_total", row_limit=50 + ) + ) @override def annotate_context_manager(self, name: str): @@ -196,8 +221,8 @@ class TorchProfilerWrapper(WorkerProfiler): class CudaProfilerWrapper(WorkerProfiler): - def __init__(self) -> None: - super().__init__() + def __init__(self, profiler_config: ProfilerConfig) -> None: + super().__init__(profiler_config) # Note: lazy import to avoid dependency issues if CUDA is not available. import torch.cuda.profiler as cuda_profiler diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py index 36e58dba6b497..7b918d2e3b78f 100644 --- a/vllm/reasoning/__init__.py +++ b/vllm/reasoning/__init__.py @@ -44,6 +44,10 @@ _REASONING_PARSERS_TO_REGISTER = { "granite_reasoning_parser", "GraniteReasoningParser", ), + "holo2": ( + "holo2_reasoning_parser", + "Holo2ReasoningParser", + ), "hunyuan_a13b": ( "hunyuan_a13b_reasoning_parser", "HunyuanA13BReasoningParser", diff --git a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py index 4a04292be009e..bf593ca4e52a0 100644 --- a/vllm/reasoning/abs_reasoning_parsers.py +++ b/vllm/reasoning/abs_reasoning_parsers.py @@ -63,6 +63,31 @@ class ReasoningParser: True if the reasoning content ends in the input_ids. """ + def is_reasoning_end_streaming( + self, input_ids: list[int], delta_ids: list[int] + ) -> bool: + """ + Check if the reasoning content ends in the input_ids on a + decode step. + + It is used in structured engines like `xgrammar` to check if the + reasoning content ends in the model output during a decode step. + `input_ids` the entire model output and `delta_ids` are the last few + computed tokens of the model output (like during a decode step). + + Parameters: + input_ids: list[int] + The entire model output. + delta_ids: list[int] + The last few computed tokens of the model output at the current decode step. + + Returns: + bool + True if the reasoning content ends in the `delta_ids` on a + decode step. + """ + return self.is_reasoning_end(input_ids) + @abstractmethod def extract_content_ids(self, input_ids: list[int]) -> list[int]: """ @@ -121,7 +146,7 @@ class ReasoningParser: self, original_tag: str | None, tool_server: ToolServer | None, - ) -> str: + ) -> str | None: """ Instance method that is implemented for preparing the structured tag Otherwise, None is returned @@ -160,7 +185,10 @@ class ReasoningParserManager: if name in cls.lazy_parsers: return cls._load_lazy_parser(name) - raise KeyError(f"Reasoning parser '{name}' not found.") + registered = ", ".join(cls.list_registered()) + raise KeyError( + f"Reasoning parser '{name}' not found. Available parsers: {registered}" + ) @classmethod def list_registered(cls) -> list[str]: diff --git a/vllm/reasoning/basic_parsers.py b/vllm/reasoning/basic_parsers.py index 35084c0e7cc86..43067ca4afe05 100644 --- a/vllm/reasoning/basic_parsers.py +++ b/vllm/reasoning/basic_parsers.py @@ -64,8 +64,21 @@ class BaseThinkingReasoningParser(ReasoningParser): ) def is_reasoning_end(self, input_ids: list[int]) -> bool: + start_token_id = self.start_token_id end_token_id = self.end_token_id - return any(input_id == end_token_id for input_id in reversed(input_ids)) + + for i in range(len(input_ids) - 1, -1, -1): + if input_ids[i] == start_token_id: + return False + if input_ids[i] == end_token_id: + return True + return False + + def is_reasoning_end_streaming( + self, input_ids: list[int], delta_ids: list[int] + ) -> bool: + end_token_id = self.end_token_id + return end_token_id in delta_ids def extract_content_ids(self, input_ids: list[int]) -> list[int]: """ diff --git a/vllm/reasoning/deepseek_v3_reasoning_parser.py b/vllm/reasoning/deepseek_v3_reasoning_parser.py index afdf73262aca0..6604f70badbcf 100644 --- a/vllm/reasoning/deepseek_v3_reasoning_parser.py +++ b/vllm/reasoning/deepseek_v3_reasoning_parser.py @@ -35,6 +35,11 @@ class DeepSeekV3ReasoningParser(ReasoningParser): def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: return self._parser.is_reasoning_end(input_ids) + def is_reasoning_end_streaming( + self, input_ids: list[int], delta_ids: list[int] + ) -> bool: + return self._parser.is_reasoning_end_streaming(input_ids, delta_ids) + def extract_content_ids(self, input_ids: list[int]) -> list[int]: return self._parser.extract_content_ids(input_ids) diff --git a/vllm/reasoning/gptoss_reasoning_parser.py b/vllm/reasoning/gptoss_reasoning_parser.py index 0c1b54d0bd359..e0920ef3160b2 100644 --- a/vllm/reasoning/gptoss_reasoning_parser.py +++ b/vllm/reasoning/gptoss_reasoning_parser.py @@ -5,7 +5,7 @@ from collections.abc import Sequence from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.harmony_utils import parse_chat_output +from vllm.entrypoints.openai.parser.harmony_utils import parse_chat_output from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage from vllm.entrypoints.tool_server import ToolServer from vllm.logger import init_logger @@ -145,7 +145,7 @@ class GptOssReasoningParser(ReasoningParser): # This function prepares the structural tag to format reasoning output def prepare_structured_tag( self, original_tag: str | None, tool_server: ToolServer | None - ) -> str: + ) -> str | None: if original_tag is None: if tool_server is None: return json.dumps(no_func_reaonsing_tag) diff --git a/vllm/reasoning/holo2_reasoning_parser.py b/vllm/reasoning/holo2_reasoning_parser.py new file mode 100644 index 0000000000000..f80190d28d6aa --- /dev/null +++ b/vllm/reasoning/holo2_reasoning_parser.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence + +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage +from vllm.logger import init_logger +from vllm.reasoning import ( + ReasoningParser, +) +from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser +from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser +from vllm.tokenizers import TokenizerLike + +logger = init_logger(__name__) + + +class Holo2ReasoningParser(ReasoningParser): + """ + Reasoning parser for the Holo2 models which are based on Qwen3. + + The Holo2 model uses <think>...</think> tokens to denote reasoning text but <think> + is part of the chat template. This parser extracts the reasoning content until + </think> in the model's output. + + The model provides a switch to enable or disable reasoning + output via the 'thinking=False' parameter. + + Chat template args: + - thinking: Whether to enable reasoning output (default: True) + + + Parsing rules on model output: + - thinking == False + -> Model output is treated as purely the content |content| + - thinking == True + -> Model output is |reasoning_content|</think>|content| + """ + + def __init__(self, tokenizer: TokenizerLike, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + + chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {} + # Deepseek V3 and Holo2 are similar. However, Holo2 models think by default. + # this parser without user specified chat template args is initiated once for + # all requests in the structured output manager. So it is important that without + # user specified chat template args, the default thinking is True. + + enable_thinking = bool(chat_kwargs.get("thinking", True)) + + if enable_thinking: + self._parser = DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs) + else: + self._parser = IdentityReasoningParser(tokenizer, *args, **kwargs) + + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + return self._parser.is_reasoning_end(input_ids) + + def is_reasoning_end_streaming( + self, input_ids: list[int], delta_ids: list[int] + ) -> bool: + return self._parser.is_reasoning_end_streaming(input_ids, delta_ids) + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + return self._parser.extract_content_ids(input_ids) + + def extract_reasoning( + self, model_output: str, request: ChatCompletionRequest + ) -> tuple[str | None, str | None]: + return self._parser.extract_reasoning(model_output, request) + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + return self._parser.extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + ) diff --git a/vllm/reasoning/identity_reasoning_parser.py b/vllm/reasoning/identity_reasoning_parser.py index e92f8add0391a..e998e071efcf6 100644 --- a/vllm/reasoning/identity_reasoning_parser.py +++ b/vllm/reasoning/identity_reasoning_parser.py @@ -32,6 +32,11 @@ class IdentityReasoningParser(ReasoningParser): # Always return True, since we never treat reasoning specially return True + def is_reasoning_end_streaming( + self, input_ids: list[int], delta_ids: list[int] + ) -> bool: + return True + def extract_content_ids(self, input_ids: list[int]) -> list[int]: # Identity: return all tokens as content return input_ids diff --git a/vllm/reasoning/minimax_m2_reasoning_parser.py b/vllm/reasoning/minimax_m2_reasoning_parser.py index 138d1b4e6dacf..a2b9224cb3bff 100644 --- a/vllm/reasoning/minimax_m2_reasoning_parser.py +++ b/vllm/reasoning/minimax_m2_reasoning_parser.py @@ -19,6 +19,10 @@ logger = init_logger(__name__) class MiniMaxM2ReasoningParser(BaseThinkingReasoningParser): """ Reasoning parser for MiniMax M2 model. + + MiniMax M2 models don't generate <think> start token, only </think> end + token. All content before </think> is reasoning, content after is the + actual response. """ @property @@ -31,6 +35,45 @@ class MiniMaxM2ReasoningParser(BaseThinkingReasoningParser): """The token that ends reasoning content.""" return "</think>" + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + """ + Extract reasoning content from a delta message for streaming. + + MiniMax M2 models don't generate <think> start token, so we assume + all content is reasoning until we encounter the </think> end token. + """ + # Skip single end token + if len(delta_token_ids) == 1 and delta_token_ids[0] == self.end_token_id: + return None + + # Check if end token has already appeared in previous tokens + # meaning we're past the reasoning phase + if self.end_token_id in previous_token_ids: + # We're past the reasoning phase, this is content + return DeltaMessage(content=delta_text) + + # Check if end token is in delta tokens + if self.end_token_id in delta_token_ids: + # End token in delta, split reasoning and content + end_index = delta_text.find(self.end_token) + reasoning = delta_text[:end_index] + content = delta_text[end_index + len(self.end_token) :] + return DeltaMessage( + reasoning=reasoning if reasoning else None, + content=content if content else None, + ) + + # No end token yet, all content is reasoning + return DeltaMessage(reasoning=delta_text) + class MiniMaxM2AppendThinkReasoningParser(ReasoningParser): """ diff --git a/vllm/reasoning/mistral_reasoning_parser.py b/vllm/reasoning/mistral_reasoning_parser.py index b61e50c188f8c..de3d1296ec734 100644 --- a/vllm/reasoning/mistral_reasoning_parser.py +++ b/vllm/reasoning/mistral_reasoning_parser.py @@ -3,20 +3,29 @@ from functools import cached_property +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ResponsesRequest, +) from vllm.logger import init_logger from vllm.reasoning import ReasoningParser -from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser -from vllm.tokenizers import MistralTokenizer +from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser +from vllm.tokenizers.mistral import MistralTokenizer logger = init_logger(__name__) -class MistralReasoningParser(DeepSeekR1ReasoningParser): +class MistralReasoningParser(BaseThinkingReasoningParser): """ Reasoning parser for Mistral models. - The Mistral models uses [THINK]...[/THINK] tokens to denote reasoning + The Mistral models uses `[THINK]`...`[/THINK]` tokens to denote reasoning text. This parser extracts the reasoning content from the model output. + + A valid reasoning trace should always start with a `[THINK]` token and end with + a `[/THINK]` token. + + If `[THINK]` token is not generated, then this parser only returns content. """ def __init__(self, tokenizer: MistralTokenizer, *args, **kwargs): @@ -53,3 +62,93 @@ class MistralReasoningParser(DeepSeekR1ReasoningParser): from mistral_common.tokens.tokenizers.base import SpecialTokens return SpecialTokens.end_think + + def is_reasoning_end(self, input_ids: list[int]) -> bool: + has_eot_token = False + + for id in input_ids[::-1]: + if id == self.start_token_id: + # Reasoning ends only if a BOT token is found before a EOT token. + return has_eot_token + elif id == self.end_token_id: + has_eot_token = True + return False + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + """ + Extract the content + """ + has_bot_token = False + has_eot_token = False + bot_token_index = -1 + eot_token_index = -1 + # One for loop instead of multiple lookups + for i, token_id in enumerate(input_ids): + # We filter that we have multiple BOT tokens which should not + # happen for a well prompted trained model + if token_id == self.start_token_id and not has_bot_token: + has_bot_token = True + bot_token_index = i + elif token_id == self.end_token_id: + has_eot_token = True + eot_token_index = i + break + + # 1. Only BOT has been outputted + if has_bot_token and not has_eot_token: + # Should be = [] if model is well prompted and trained. + return input_ids[:bot_token_index] + # 2. Neither BOT or EOT have been outputted + elif not has_bot_token and not has_eot_token: + return input_ids + # 3. Both BOT and EOT have been outputted. + elif has_bot_token and has_eot_token: + return input_ids[:bot_token_index] + input_ids[eot_token_index + 1 :] + # 4. Only EOT has been outputted => this should not have occured for a model + # well prompted and trained. + else: + return input_ids[:eot_token_index] + input_ids[eot_token_index + 1 :] + + def extract_reasoning( + self, model_output: str, request: ChatCompletionRequest | ResponsesRequest + ) -> tuple[str | None, str | None]: + """ + Extract reasoning content from the model output. + """ + if not model_output: + return (None, "") + + # Check if the start token is present in the model output, remove it + # if it is present. + prev_bot_token, bot_token, post_bot_token = model_output.partition( + self.start_token + ) + + has_bot_token = bool(bot_token) + # Valid EOT tokens should follow BOT token + has_valid_eot_token = has_bot_token and self.end_token in post_bot_token + + # 1. If there is BOT token followed by EOT token + if has_bot_token and has_valid_eot_token: + prev_eot_token, _, post_eot_token = post_bot_token.partition(self.end_token) + # If model is well prompted and trained prev_bot_token should be "" + content = prev_bot_token + post_eot_token + return prev_eot_token, content if content else None + # 2. Only BOT token + elif has_bot_token: + # If model is well prompted and trained prev_bot_token should be "" + return post_bot_token, prev_bot_token if prev_bot_token else None + # 3. EOT token has been outputted without BOT or neither has been outputted + else: + has_non_valid_eot_token = self.end_token in prev_bot_token + # 3.a EOT token has been outputted without BOT + # If model is well prompted and trained `has_non_valid_eot_token` should + # be `False` and the parser outputs all tokens as 'content' + if has_non_valid_eot_token: + prev_eot_token, _, post_eot_token = prev_bot_token.partition( + self.end_token + ) + return None, prev_eot_token + post_eot_token + # 3.b neither BOT or EOT have been outputted + else: + return None, prev_bot_token diff --git a/vllm/tokenizers/__init__.py b/vllm/tokenizers/__init__.py index 14f0148cf7ba8..31e74b1a16e20 100644 --- a/vllm/tokenizers/__init__.py +++ b/vllm/tokenizers/__init__.py @@ -1,15 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from .hf import HfTokenizer -from .mistral import MistralTokenizer from .protocol import TokenizerLike -from .registry import TokenizerRegistry, get_tokenizer +from .registry import ( + TokenizerRegistry, + cached_get_tokenizer, + cached_tokenizer_from_config, + get_tokenizer, + init_tokenizer_from_config, +) __all__ = [ "TokenizerLike", - "HfTokenizer", - "MistralTokenizer", "TokenizerRegistry", + "cached_get_tokenizer", "get_tokenizer", + "cached_tokenizer_from_config", + "init_tokenizer_from_config", ] diff --git a/vllm/tokenizers/deepseek_v32.py b/vllm/tokenizers/deepseek_v32.py new file mode 100644 index 0000000000000..bf279a5cf67c5 --- /dev/null +++ b/vllm/tokenizers/deepseek_v32.py @@ -0,0 +1,175 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from pathlib import Path +from typing import Any + +from transformers import BatchEncoding + +from vllm.entrypoints.chat_utils import ChatCompletionMessageParam + +from .deepseek_v32_encoding import encode_messages +from .hf import CachedHfTokenizer +from .protocol import TokenizerLike + + +class DeepseekV32Tokenizer(CachedHfTokenizer): + @classmethod + def from_pretrained( + cls, + path_or_repo_id: str | Path, + *args, + trust_remote_code: bool = False, + revision: str | None = None, + download_dir: str | None = None, + **kwargs, + ) -> "TokenizerLike": + tokenizer = super().from_pretrained( + path_or_repo_id, + *args, + trust_remote_code=trust_remote_code, + revision=revision, + download_dir=download_dir, + **kwargs, + ) + return DeepseekV32Tokenizer(tokenizer) + + def __init__(self, tokenizer: TokenizerLike) -> None: + super().__init__() + + self.tokenizer = tokenizer + self.name_or_path = getattr(tokenizer, "name_or_path", "") + + self._added_vocab = self.tokenizer.get_added_vocab() + self._added_vocab_size = len(self._added_vocab) + + def apply_chat_template( + self, + messages: list["ChatCompletionMessageParam"], + tools: list[dict[str, Any]] | None = None, + **kwargs, + ) -> str | list[int]: + thinking = kwargs.get("thinking", False) + thinking_mode = "thinking" + if not thinking: + thinking_mode = "chat" + conversation = kwargs.get("conversation", messages) + messages = conversation.copy() + if tools is not None and len(tools) > 0: + messages.insert(0, {"role": "system"}) + messages[0]["tools"] = tools # type: ignore[typeddict-unknown-key] + + # Historical reasoning content is dropped when a new user message is introduced + drop_thinking = messages[-1]["role"] == "user" + + encode_config = dict(thinking_mode=thinking_mode, drop_thinking=drop_thinking) + prompt_str = encode_messages(messages, **encode_config) # type: ignore + + if kwargs.get("tokenize", True): + tokenizer_kwargs = { + k: kwargs[k] for k in ("truncation", "max_length") if k in kwargs + } + return self.encode( + prompt_str, + add_special_tokens=False, + **tokenizer_kwargs, + ) + + return prompt_str + + def num_special_tokens_to_add(self) -> int: + return len(self.encode("")) + + @property + def all_special_tokens(self) -> list[str]: + return self.tokenizer.all_special_tokens + + @property + def all_special_ids(self) -> list[int]: + return self.tokenizer.all_special_ids + + @property + def bos_token_id(self) -> int: + return self.tokenizer.bos_token_id + + @property + def eos_token_id(self) -> int: + return self.tokenizer.eos_token_id + + @property + def pad_token_id(self) -> int: + return self.tokenizer.pad_token_id + + @property + def is_fast(self) -> bool: + return self.tokenizer.is_fast + + @property + def vocab_size(self) -> int: + return self.tokenizer.vocab_size + + @property + def max_token_id(self) -> int: + return self.tokenizer.max_token_id + + @property + def truncation_side(self) -> str: + return self.tokenizer.truncation_side + + def __hash__(self) -> int: + return hash(id(self)) + + def __len__(self) -> int: + # </think> is an added token in DeepseekV32 tokenizer + return self.vocab_size + self._added_vocab_size + + def __call__( + self, + text: str | list[str], + text_pair: str | None = None, + add_special_tokens: bool = True, + truncation: bool = False, + max_length: int | None = None, + ) -> "BatchEncoding": + return self.tokenizer( + text, + text_pair=text_pair, + add_special_tokens=add_special_tokens, + truncation=truncation, + max_length=max_length, + ) + + def get_vocab(self) -> dict[str, int]: + return self.tokenizer.get_vocab() + + def get_added_vocab(self) -> dict[str, int]: + return self._added_vocab.copy() + + def encode( + self, + text: str, + truncation: bool | None = None, + max_length: int | None = None, + add_special_tokens: bool = True, + ) -> list[int]: + return self.tokenizer.encode( + text, + truncation=truncation, + max_length=max_length, + add_special_tokens=add_special_tokens, + ) + + def convert_tokens_to_string(self, tokens: list[str]) -> str: + return self.tokenizer.convert_tokens_to_string(tokens) + + def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str: + return self.tokenizer.decode(ids, skip_special_tokens=skip_special_tokens) + + def convert_ids_to_tokens( + self, + ids: list[int], + skip_special_tokens: bool = False, + ) -> list[str]: + return self.tokenizer.convert_ids_to_tokens( + ids, skip_special_tokens=skip_special_tokens + ) diff --git a/vllm/tokenizers/deepseek_v32_encoding.py b/vllm/tokenizers/deepseek_v32_encoding.py new file mode 100644 index 0000000000000..521bd92959312 --- /dev/null +++ b/vllm/tokenizers/deepseek_v32_encoding.py @@ -0,0 +1,459 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +# copy from https://huggingface.co/deepseek-ai/DeepSeek-V3.2/blob/main/encoding/encoding_dsv32.py +import copy +import json +from typing import Any + +import regex as re + +# flake8: noqa: E501 +TOOLS_SYSTEM_TEMPLATE = """## Tools +You have access to a set of tools you can use to answer the user's question. +You can invoke functions by writing a "<{dsml_token}function_calls>" block like the following as part of your reply to the user: +<{dsml_token}function_calls> +<{dsml_token}invoke name="$FUNCTION_NAME"> +<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</{dsml_token}parameter> +... +</{dsml_token}invoke> +<{dsml_token}invoke name="$FUNCTION_NAME2"> +... +</{dsml_token}invoke> +</{dsml_token}function_calls> +String and scalar parameters should be specified as is without any escaping or quotes, while lists and objects should use JSON format. The "string" attribute should be set to "true" for string type parameters and "false" for other types (numbers, booleans, arrays, objects). +If the thinking_mode is enabled, then after function results you should strongly consider outputting a thinking block. Here is an example: +<{dsml_token}function_calls> +... +</{dsml_token}function_calls> +<function_results> +... +</function_results> +{thinking_start_token}...thinking about results{thinking_end_token} +Here are the functions available in JSONSchema format: +<functions> +{tool_schemas} +</functions> +""" + +bos_token: str = "<|begin▁of▁sentence|>" +eos_token: str = "<|end▁of▁sentence|>" +thinking_start_token: str = "<think>" +thinking_end_token: str = "</think>" +dsml_token: str = "|DSML|" +system_msg_template: str = "{content}" +user_msg_template: str = "<|User|>{content}<|Assistant|>" +assistant_msg_template: str = "{reasoning}{content}{tool_calls}<|end▁of▁sentence|>" +thinking_template = "{reasoning_content}" + +response_format_template: str = "## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}" +tool_call_template: str = ( + '<{dsml_token}invoke name="{name}">\n{arguments}\n</{dsml_token}invoke>' +) +tool_calls_template = ( + "<{dsml_token}function_calls>\n{tool_calls}\n</{dsml_token}function_calls>" +) + +tool_output_template: str = "\n<result>{content}</result>" + + +def to_json(value: Any) -> str: + try: + return json.dumps(value, ensure_ascii=False) + except Exception: + return json.dumps(value, ensure_ascii=True) + + +def tools_from_openai_format(tools): + return [tool["function"] for tool in tools] + + +def tool_calls_from_openai_format(tool_calls): + return [ + { + "name": tool_call["function"]["name"], + "arguments": tool_call["function"]["arguments"], + } + for tool_call in tool_calls + ] + + +def tool_calls_to_openai_format(tool_calls): + return [ + { + "type": "function", + "function": { + "name": tool_call["name"], + "arguments": tool_call["arguments"], + }, + } + for tool_call in tool_calls + ] + + +def encode_arguments_to_dsml(tool_call: dict[str, str]) -> str: + p_dsml_template = """<{dsml_token}parameter name="{key}" string="{is_str}">{value}</{dsml_token}parameter>""" + P_dsml_strs = [] + if isinstance(tool_call["arguments"], str): + arguments = json.loads(tool_call["arguments"]) + else: + arguments = tool_call["arguments"] + + for k, v in arguments.items(): + p_dsml_str = p_dsml_template.format( + dsml_token=dsml_token, + key=k, + is_str="true" if isinstance(v, str) else "false", + value=v if isinstance(v, str) else to_json(v), + ) + + P_dsml_strs.append(p_dsml_str) + + return "\n".join(P_dsml_strs) + + +def decode_dsml_to_arguments( + tool_name: str, tool_args: dict[str, tuple[str, str]] +) -> dict[str, str]: + def _decode_value(key: str, value: str, string: str): + if string == "true": + value = to_json(value) + return f"{to_json(key)}: {value}" + + tool_args_json = ( + "{" + + ", ".join( + [_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()] + ) + + "}" + ) + return dict(name=tool_name, arguments=tool_args_json) + + +def render_tools(tools: list[dict[str, str | dict[str, Any]]]) -> str: + tools_json = [to_json(t) for t in tools] + + return TOOLS_SYSTEM_TEMPLATE.format( + tool_schemas="\n".join(tools_json), + dsml_token=dsml_token, + thinking_start_token=thinking_start_token, + thinking_end_token=thinking_end_token, + ) + + +def find_last_user_index(messages: list[dict[str, Any]]) -> int: + last_user_index = -1 + for idx in range(len(messages) - 1, -1, -1): + if messages[idx].get("role") in ["user", "developer"]: + last_user_index = idx + break + return last_user_index + + +def render_message( + index: int, messages: list[dict[str, Any]], thinking_mode: str +) -> str: + assert 0 <= index < len(messages) + assert thinking_mode in ["chat", "thinking"], ( + f"Invalid thinking_mode `{thinking_mode}`" + ) + + prompt = "" + msg = messages[index] + last_user_idx = find_last_user_index(messages) + + role = msg.get("role") + content = msg.get("content") + tools = msg.get("tools") + response_format = msg.get("response_format") + tool_calls = msg.get("tool_calls") + reasoning_content = msg.get("reasoning") or msg.get("reasoning_content") + + if tools: + tools = tools_from_openai_format(tools) + if tool_calls: + tool_calls = tool_calls_from_openai_format(tool_calls) + + if role == "system": + prompt += system_msg_template.format(content=content or "") + if tools: + prompt += "\n\n" + render_tools(tools) + + if response_format: + prompt += "\n\n" + response_format_template.format( + schema=to_json(response_format) + ) + + elif role == "developer": + assert content, f"Invalid message for role `{role}`: {msg}" + content_developer = "" + if tools: + content_developer += "\n\n" + render_tools(tools) + + if response_format: + content_developer += "\n\n" + response_format_template.format( + schema=to_json(response_format) + ) + + content_developer += "\n\n# The user's message is: {}".format(content) + + prompt += user_msg_template.format(content=content_developer) + if index == last_user_idx and thinking_mode == "thinking": + prompt += thinking_start_token + else: + prompt += thinking_end_token + + elif role == "user": + prompt += user_msg_template.format(content=content) + + if index == last_user_idx and thinking_mode == "thinking": + prompt += thinking_start_token + else: + prompt += thinking_end_token + + elif role == "tool": + prev_assistant_idx = index - 1 + assistant_msg = messages[prev_assistant_idx] + while prev_assistant_idx >= 0 and assistant_msg.get("role") == "tool": + prev_assistant_idx -= 1 + assistant_msg = messages[prev_assistant_idx] + + assert ( + index == 0 + or prev_assistant_idx >= 0 + and assistant_msg.get("role") == "assistant" + ), f"Invalid messages at {index}:\n{assistant_msg}" + + tool_call_order = index - prev_assistant_idx + assistant_tool_calls = assistant_msg.get("tool_calls") + assert assistant_tool_calls and len(assistant_tool_calls) >= tool_call_order, ( + "No tool calls but found tool output" + ) + + if tool_call_order == 1: + prompt += "\n\n<function_results>" + + prompt += tool_output_template.format(content=content) + + if tool_call_order == len(assistant_tool_calls): + prompt += "\n</function_results>" + + if index >= last_user_idx and thinking_mode == "thinking": + prompt += "\n\n" + thinking_start_token + else: + prompt += "\n\n" + thinking_end_token + + elif role == "assistant": + prev_assistant_idx = index + thinking_part = "" + + tool_calls_content = "" + if tool_calls: + tool_calls = [ + tool_call_template.format( + dsml_token=dsml_token, + name=tool_call.get("name"), + arguments=encode_arguments_to_dsml(tool_call), + ) + for tool_call in tool_calls + ] + tool_calls_content += "\n\n" + tool_calls_template.format( + dsml_token=dsml_token, tool_calls="\n".join(tool_calls) + ) + + summary_content = content or "" + + if thinking_mode == "thinking" and index > last_user_idx: + assert reasoning_content or tool_calls, ( + f"ThinkingMode: {thinking_mode}, invalid message without reasoning_content/tool_calls `{msg}` after last user message" + ) + thinking_part = ( + thinking_template.format(reasoning_content=reasoning_content or "") + + thinking_end_token + ) + + prompt += assistant_msg_template.format( + reasoning=thinking_part, + content=summary_content, + tool_calls=tool_calls_content, + ) + else: + raise NotImplementedError(f"Unknown role: {role}") + + return prompt + + +def drop_thinking_messages( + messages: list[dict[str, Any]], last_user_idx: int | None = None +) -> list[dict[str, Any]]: + messages_wo_thinking: list[dict[str, Any]] = [] + last_user_idx = ( + find_last_user_index(messages) if last_user_idx is None else last_user_idx + ) + for idx, msg in enumerate(messages): + role = msg.get("role") + if role in ["user", "system", "tool"] or idx >= last_user_idx: + messages_wo_thinking.append(msg) + continue + + elif role == "assistant": + msg_wo_thinking = copy.copy(msg) + msg_wo_thinking.pop("reasoning_content", None) + msg_wo_thinking.pop("reasoning", None) + messages_wo_thinking.append(msg_wo_thinking) + + return messages_wo_thinking + + +def encode_messages( + messages: list[dict[str, Any]], + thinking_mode: str, + context: list[dict[str, Any]] | None = None, + drop_thinking: bool = True, + add_default_bos_token: bool = True, +) -> str: + context = context if context else [] + full_messages = context + messages + + prompt = bos_token if add_default_bos_token and len(context) == 0 else "" + + if thinking_mode == "thinking" and drop_thinking: + full_messages = drop_thinking_messages(full_messages) + + for idx in range(len(messages)): + prompt += render_message( + idx + len(context), full_messages, thinking_mode=thinking_mode + ) + + return prompt + + +def _read_until_stop( + index: int, text: str, stop: list[str] +) -> tuple[int, str, None | str]: + min_pos = len(text) + matched_stop = None + + for s in stop: + pos = text.find(s, index) + if pos != -1 and pos < min_pos: + min_pos = pos + matched_stop = s + + if matched_stop: + content = text[index:min_pos] + return min_pos + len(matched_stop), content, matched_stop + else: + content = text[index:] + return len(text), content, None + + +def parse_tool_calls(index: int, text: str): + tool_calls: list[dict[str, Any]] = [] + stop_token = None + tool_calls_end_token = f"</{dsml_token}function_calls>" + + while index < len(text): + index, _, stop_token = _read_until_stop( + index, text, [f"<{dsml_token}invoke", tool_calls_end_token] + ) + assert _ == ">\n", "Tool call format error" + + if stop_token == tool_calls_end_token: + break + + assert stop_token is not None, "Missing special token" + + index, tool_name_content, stop_token = _read_until_stop( + index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"] + ) + + p_tool_name = re.findall( + r'^\s*name="(.*?)">\n$', tool_name_content, flags=re.DOTALL + ) + assert len(p_tool_name) == 1, "Tool name format error" + tool_name = p_tool_name[0] + + tool_args: dict[str, tuple[str, str]] = {} + while stop_token == f"<{dsml_token}parameter": + index, param_content, stop_token = _read_until_stop( + index, text, [f"/{dsml_token}parameter"] + ) + + param_kv = re.findall( + r'^ name="(.*?)" string="(true|false)">(.*?)<$', + param_content, + flags=re.DOTALL, + ) + assert len(param_kv) == 1, "Parameter format error" + param_name, string, param_value = param_kv[0] + + assert param_name not in tool_args, "Duplicate parameter name" + tool_args[param_name] = (param_value, string) + + index, content, stop_token = _read_until_stop( + index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"] + ) + assert content == ">\n", "Parameter format error" + + tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args) + tool_calls.append(tool_call) + + return index, stop_token, tool_calls + + +# NOTE: This function is designed to parse only correctly +# formatted string and will not attempt to correct malformed output +# that may be generated by the model. +def parse_message_from_completion_text(text: str, thinking_mode: str): + summary_content, reasoning_content, tool_calls = "", "", [] + index, stop_token = 0, None + tool_calls_start_token = f"\n\n<{dsml_token}function_calls" + + is_thinking, is_tool_calling = thinking_mode == "thinking", False + + if is_thinking: + index, content_delta, stop_token = _read_until_stop( + index, text, [thinking_end_token, tool_calls_start_token] + ) + reasoning_content = content_delta + assert stop_token == thinking_end_token, "Invalid thinking format" + + index, content_delta, stop_token = _read_until_stop( + index, text, [eos_token, tool_calls_start_token] + ) + summary_content = content_delta + if stop_token == tool_calls_start_token: + is_tool_calling = True + else: + assert stop_token == eos_token, "Invalid summary format" + + if is_tool_calling: + index, stop_token, tool_calls = parse_tool_calls(index, text) + + index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token]) + assert not tool_ends_text, "Unexpected content after tool calls" + + assert len(text) == index and stop_token in [eos_token, None], ( + "Unexpected content at end" + ) + + for sp_token in [ + bos_token, + eos_token, + thinking_start_token, + thinking_end_token, + dsml_token, + ]: + assert sp_token not in summary_content and sp_token not in reasoning_content, ( + "Unexpected special token in content" + ) + + return { + "role": "assistant", + "content": summary_content, + "reasoning_content": reasoning_content, + "reasoning": reasoning_content, + "tool_calls": tool_calls_to_openai_format(tool_calls), + } diff --git a/vllm/tokenizers/hf.py b/vllm/tokenizers/hf.py index 3445073120387..a7b565dca5d8f 100644 --- a/vllm/tokenizers/hf.py +++ b/vllm/tokenizers/hf.py @@ -3,22 +3,18 @@ import contextlib import copy from pathlib import Path -from typing import TYPE_CHECKING +from typing import TypeAlias -from transformers import AutoTokenizer +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config from .protocol import TokenizerLike -from .registry import TokenizerRegistry -if TYPE_CHECKING: - from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +HfTokenizer: TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast -def get_cached_tokenizer( - tokenizer: "PreTrainedTokenizer | PreTrainedTokenizerFast", -) -> TokenizerLike: +def get_cached_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer: """ By default, transformers will recompute multiple tokenizer properties each time they are called, leading to a significant slowdown. @@ -65,11 +61,10 @@ def get_cached_tokenizer( CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}" cached_tokenizer.__class__ = CachedTokenizer - return cached_tokenizer # type: ignore + return cached_tokenizer -@TokenizerRegistry.register("hf") -class HfTokenizer(TokenizerLike): +class CachedHfTokenizer(TokenizerLike): @classmethod def from_pretrained( cls, @@ -79,7 +74,7 @@ class HfTokenizer(TokenizerLike): revision: str | None = None, download_dir: str | None = None, **kwargs, - ) -> "TokenizerLike": + ) -> HfTokenizer: try: tokenizer = AutoTokenizer.from_pretrained( path_or_repo_id, diff --git a/vllm/tokenizers/mistral.py b/vllm/tokenizers/mistral.py index 7e6745004b01f..534b0da484a5d 100644 --- a/vllm/tokenizers/mistral.py +++ b/vllm/tokenizers/mistral.py @@ -3,10 +3,11 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, cast +from vllm.entrypoints.chat_utils import ChatCompletionMessageParam +from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.logger import init_logger from .protocol import TokenizerLike -from .registry import TokenizerRegistry if TYPE_CHECKING: from mistral_common.protocol.instruct.request import ( @@ -14,12 +15,15 @@ if TYPE_CHECKING: ) from mistral_common.tokens.tokenizers.tekken import Tekkenizer from transformers import BatchEncoding - from transformers.tokenization_mistral_common import ( - MistralCommonTokenizer as TransformersMistralTokenizer, - ) - from vllm.entrypoints.chat_utils import ChatCompletionMessageParam - from vllm.entrypoints.openai.protocol import ChatCompletionRequest + try: + # Transformers v5 + from transformers.tokenization_mistral_common import MistralCommonBackend + except ImportError: + # Transformers v4 + from transformers.tokenization_mistral_common import ( + MistralCommonTokenizer as MistralCommonBackend, + ) logger = init_logger(__name__) @@ -97,6 +101,8 @@ def _prepare_apply_chat_template_tools_and_messages( continue_final_message: bool = False, add_generation_prompt: bool = False, ) -> tuple[list["ChatCompletionMessageParam"], list[dict[str, Any]] | None]: + from mistral_common.protocol.instruct.tool_calls import Function, Tool + if add_generation_prompt and continue_final_message: raise ValueError( "Cannot set both `add_generation_prompt` and " @@ -139,6 +145,33 @@ def _prepare_apply_chat_template_tools_and_messages( if function.get("description") is None: function["description"] = "" + # We filter not supported arguments to avoid throwing an error. + # TODO(juliendenize): remove this once OpenAI API is better supported by + # `mistral-common`. + tools_fields = set(Tool.model_fields.keys()) + function_fields = set(Function.model_fields.keys()) + for tool in tools: + tool_keys = list(tool.keys()) + for tool_key in tool_keys: + if tool_key not in tools_fields: + tool.pop(tool_key) + logger.warning_once( + f"'{tool_key}' is not supported by mistral-common for tools. " + "It has been poped from the tool definition." + ) + if tool["type"] == "function": + function_keys = list(tool["function"].keys()) + for function_key in function_keys: + if function_key not in function_fields: + tool["function"].pop(function_key) + logger.warning_once( + f"'{function_key}' is not supported by mistral-common " + "for function tools. It has been poped from the " + "function definition." + ) + else: + raise ValueError("mistral-common only supports function tools.") + return messages, tools @@ -166,7 +199,6 @@ def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int: return tokenizer.unk_id -@TokenizerRegistry.register("mistral") class MistralTokenizer(TokenizerLike): @classmethod def from_pretrained( @@ -179,11 +211,17 @@ class MistralTokenizer(TokenizerLike): **kwargs, ) -> "MistralTokenizer": from mistral_common.protocol.instruct.validator import ValidationMode - from transformers.tokenization_mistral_common import ( - MistralCommonTokenizer as TransformersMistralTokenizer, - ) - tokenizer = TransformersMistralTokenizer.from_pretrained( + try: + # Transformers v5 + from transformers.tokenization_mistral_common import MistralCommonBackend + except ImportError: + # Transformers v4 + from transformers.tokenization_mistral_common import ( + MistralCommonTokenizer as MistralCommonBackend, + ) + + tokenizer = MistralCommonBackend.from_pretrained( path_or_repo_id, *args, mode=ValidationMode.test, @@ -194,7 +232,7 @@ class MistralTokenizer(TokenizerLike): return cls(tokenizer) - def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None: + def __init__(self, tokenizer: "MistralCommonBackend") -> None: super().__init__() from mistral_common.protocol.instruct.validator import ValidationMode @@ -268,6 +306,9 @@ class MistralTokenizer(TokenizerLike): for i in all_special_ids ] + def num_special_tokens_to_add(self) -> int: + return len(self.encode("")) + # the following attributes are set to fit vLLM's design and are used # by the structured output backends. @property @@ -380,6 +421,7 @@ class MistralTokenizer(TokenizerLike): ) -> list[int]: add_generation_prompt = kwargs.pop("add_generation_prompt", False) continue_final_message = kwargs.get("continue_final_message", False) + tokenize = kwargs.get("tokenize", True) padding = kwargs.get("padding", False) truncation = kwargs.get("truncation", False) max_length = kwargs.get("max_length") @@ -392,7 +434,7 @@ class MistralTokenizer(TokenizerLike): conversation=messages, tools=tools, continue_final_message=continue_final_message, - tokenize=True, + tokenize=tokenize, padding=padding, truncation=truncation, max_length=max_length, @@ -410,6 +452,13 @@ class MistralTokenizer(TokenizerLike): ids, skip_special_tokens=skip_special_tokens ) + def batch_decode( + self, ids: list[list[int]] | list[int], skip_special_tokens: bool = False + ) -> str: + return self.transformers_tokenizer.batch_decode( + ids, skip_special_tokens=skip_special_tokens + ) + def convert_tokens_to_string(self, tokens: list[str]) -> str: from mistral_common.tokens.tokenizers.base import ( SpecialTokenPolicy, diff --git a/vllm/tokenizers/protocol.py b/vllm/tokenizers/protocol.py index 6c807bd998781..28754f9e10d00 100644 --- a/vllm/tokenizers/protocol.py +++ b/vllm/tokenizers/protocol.py @@ -22,6 +22,9 @@ class TokenizerLike(Protocol): ) -> "TokenizerLike": raise NotImplementedError + def num_special_tokens_to_add(self) -> int: + raise NotImplementedError + @property def all_special_tokens(self) -> list[str]: raise NotImplementedError @@ -94,7 +97,7 @@ class TokenizerLike(Protocol): messages: list["ChatCompletionMessageParam"], tools: list[dict[str, Any]] | None = None, **kwargs, - ) -> list[int]: + ) -> str | list[int]: raise NotImplementedError def convert_tokens_to_string(self, tokens: list[str]) -> str: diff --git a/vllm/tokenizers/registry.py b/vllm/tokenizers/registry.py index d5e7899321615..72447ef04e87c 100644 --- a/vllm/tokenizers/registry.py +++ b/vllm/tokenizers/registry.py @@ -1,64 +1,48 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import importlib.util -from collections.abc import Callable +from dataclasses import dataclass, field +from functools import lru_cache from pathlib import Path -from typing import TypeVar, overload +from typing import TYPE_CHECKING import huggingface_hub +from typing_extensions import TypeVar, assert_never, deprecated import vllm.envs as envs from vllm.logger import init_logger -from vllm.transformers_utils.gguf_utils import get_gguf_file_path_from_hf -from vllm.transformers_utils.repo_utils import list_filtered_repo_files -from vllm.transformers_utils.utils import ( +from vllm.transformers_utils.gguf_utils import ( check_gguf_file, + get_gguf_file_path_from_hf, is_gguf, is_remote_gguf, split_remote_gguf, ) +from vllm.transformers_utils.repo_utils import list_filtered_repo_files from vllm.utils.import_utils import resolve_obj_by_qualname from .protocol import TokenizerLike +if TYPE_CHECKING: + from vllm.config.model import ModelConfig, RunnerType + logger = init_logger(__name__) -_T = TypeVar("_T", bound=type[TokenizerLike]) + +_VLLM_TOKENIZERS = { + "deepseek_v32": ("deepseek_v32", "DeepseekV32Tokenizer"), + "hf": ("hf", "CachedHfTokenizer"), + "mistral": ("mistral", "MistralTokenizer"), +} -class TokenizerRegistry: - # Tokenizer name -> tokenizer_cls or (tokenizer module, tokenizer class) - REGISTRY: dict[str, type[TokenizerLike] | tuple[str, str]] = {} +@dataclass +class _TokenizerRegistry: + # Tokenizer mode -> (tokenizer module, tokenizer class) + tokenizers: dict[str, tuple[str, str]] = field(default_factory=dict) - # In-tree tokenizers - @staticmethod - @overload - def register(tokenizer_mode: str) -> Callable[[_T], _T]: ... - - # OOT tokenizers - @staticmethod - @overload - def register(tokenizer_mode: str, module: str, class_name: str) -> None: ... - - @staticmethod - def register( - tokenizer_mode: str, - module: str | None = None, - class_name: str | None = None, - ) -> Callable[[_T], _T] | None: - # In-tree tokenizers - if module is None or class_name is None: - - def wrapper(tokenizer_cls: _T) -> _T: - assert tokenizer_mode not in TokenizerRegistry.REGISTRY - TokenizerRegistry.REGISTRY[tokenizer_mode] = tokenizer_cls - - return tokenizer_cls - - return wrapper - - # OOT tokenizers - if tokenizer_mode in TokenizerRegistry.REGISTRY: + def register(self, tokenizer_mode: str, module: str, class_name: str) -> None: + if tokenizer_mode in self.tokenizers: logger.warning( "%s.%s is already registered for tokenizer_mode=%r. " "It is overwritten by the new one.", @@ -67,36 +51,42 @@ class TokenizerRegistry: tokenizer_mode, ) - TokenizerRegistry.REGISTRY[tokenizer_mode] = (module, class_name) + self.tokenizers[tokenizer_mode] = (module, class_name) return None - @staticmethod - def get_tokenizer(tokenizer_mode: str, *args, **kwargs) -> "TokenizerLike": - if tokenizer_mode not in TokenizerRegistry.REGISTRY: + def load_tokenizer_cls(self, tokenizer_mode: str) -> type[TokenizerLike]: + if tokenizer_mode not in self.tokenizers: raise ValueError(f"No tokenizer registered for {tokenizer_mode=!r}.") - item = TokenizerRegistry.REGISTRY[tokenizer_mode] - if isinstance(item, type): - return item.from_pretrained(*args, **kwargs) - - module, class_name = item + module, class_name = self.tokenizers[tokenizer_mode] logger.debug_once(f"Loading {class_name} for {tokenizer_mode=!r}") - class_ = resolve_obj_by_qualname(f"{module}.{class_name}") - return class_.from_pretrained(*args, **kwargs) + return resolve_obj_by_qualname(f"{module}.{class_name}") + + def load_tokenizer(self, tokenizer_mode: str, *args, **kwargs) -> TokenizerLike: + tokenizer_cls = self.load_tokenizer_cls(tokenizer_mode) + return tokenizer_cls.from_pretrained(*args, **kwargs) -def get_tokenizer( +TokenizerRegistry = _TokenizerRegistry( + { + mode: (f"vllm.tokenizers.{mod_relname}", cls_name) + for mode, (mod_relname, cls_name) in _VLLM_TOKENIZERS.items() + } +) + + +def resolve_tokenizer_args( tokenizer_name: str | Path, *args, + runner_type: "RunnerType" = "generate", tokenizer_mode: str = "auto", - trust_remote_code: bool = False, - revision: str | None = None, - download_dir: str | None = None, **kwargs, -) -> TokenizerLike: - """Gets a tokenizer for the given model name via HuggingFace or ModelScope.""" +): + revision: str | None = kwargs.get("revision") + download_dir: str | None = kwargs.get("download_dir") + if envs.VLLM_USE_MODELSCOPE: # download model from ModelScope hub, # lazy import so that modelscope is not required for normal use. @@ -120,16 +110,6 @@ def get_tokenizer( ) tokenizer_name = tokenizer_path - if tokenizer_mode == "slow": - if kwargs.get("use_fast", False): - raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") - - tokenizer_mode = "hf" - kwargs["use_fast"] = False - - if "truncation_side" not in kwargs: - kwargs["truncation_side"] = "left" - # Separate model folder from file path for GGUF models if is_gguf(tokenizer_name): if check_gguf_file(tokenizer_name): @@ -145,6 +125,21 @@ def get_tokenizer( ) kwargs["gguf_file"] = gguf_file + if "truncation_side" not in kwargs: + if runner_type == "generate" or runner_type == "draft": + kwargs["truncation_side"] = "left" + elif runner_type == "pooling": + kwargs["truncation_side"] = "right" + else: + assert_never(runner_type) + + if tokenizer_mode == "slow": + if kwargs.get("use_fast", False): + raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") + + tokenizer_mode = "hf" + kwargs["use_fast"] = False + # Try to use official Mistral tokenizer if possible if tokenizer_mode == "auto" and importlib.util.find_spec("mistral_common"): allow_patterns = ["tekken.json", "tokenizer.model.v*"] @@ -160,38 +155,79 @@ def get_tokenizer( if tokenizer_mode == "auto": tokenizer_mode = "hf" - tokenizer_args = (tokenizer_name, *args) - tokenizer_kwargs = dict( + return tokenizer_mode, tokenizer_name, args, kwargs + + +cached_resolve_tokenizer_args = lru_cache(resolve_tokenizer_args) + + +def tokenizer_args_from_config(config: "ModelConfig", **kwargs): + return cached_resolve_tokenizer_args( + config.tokenizer, + runner_type=config.runner_type, + tokenizer_mode=config.tokenizer_mode, + revision=config.tokenizer_revision, + trust_remote_code=config.trust_remote_code, + **kwargs, + ) + + +_T = TypeVar("_T", bound=TokenizerLike, default=TokenizerLike) + + +def get_tokenizer( + tokenizer_name: str | Path, + *args, + tokenizer_cls: type[_T] = TokenizerLike, # type: ignore[assignment] + trust_remote_code: bool = False, + revision: str | None = None, + download_dir: str | None = None, + **kwargs, +) -> _T: + """Gets a tokenizer for the given model name via HuggingFace or ModelScope.""" + tokenizer_mode, tokenizer_name, args, kwargs = cached_resolve_tokenizer_args( + tokenizer_name, + *args, trust_remote_code=trust_remote_code, revision=revision, download_dir=download_dir, **kwargs, ) - if tokenizer_mode == "custom": - logger.warning_once( - "TokenizerRegistry now uses `tokenizer_mode` as the registry key " - "instead of `tokenizer_name`. " - "Please update the definition of `.from_pretrained` in " - "your custom tokenizer to accept `args=%s`, `kwargs=%s`. " - "Then, you can pass `tokenizer_mode=%r` instead of " - "`tokenizer_mode='custom'` when initializing vLLM.", - tokenizer_args, - str(tokenizer_kwargs), - tokenizer_mode, - ) + if tokenizer_cls == TokenizerLike: + tokenizer_cls_ = TokenizerRegistry.load_tokenizer_cls(tokenizer_mode) + else: + tokenizer_cls_ = tokenizer_cls - tokenizer_mode = str(tokenizer_name) - - tokenizer = TokenizerRegistry.get_tokenizer( - tokenizer_mode, - *tokenizer_args, - **tokenizer_kwargs, - ) + tokenizer = tokenizer_cls_.from_pretrained(tokenizer_name, *args, **kwargs) if not tokenizer.is_fast: logger.warning( "Using a slow tokenizer. This might cause a significant " "slowdown. Consider using a fast tokenizer instead." ) - return tokenizer + return tokenizer # type: ignore + + +cached_get_tokenizer = lru_cache(get_tokenizer) + + +def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs): + if model_config.skip_tokenizer_init: + return None + + return cached_get_tokenizer( + model_config.tokenizer, + runner_type=model_config.runner_type, + tokenizer_mode=model_config.tokenizer_mode, + revision=model_config.tokenizer_revision, + trust_remote_code=model_config.trust_remote_code, + **kwargs, + ) + + +@deprecated( + "Renamed to `cached_tokenizer_from_config`. The old name will be removed in v0.14." +) +def init_tokenizer_from_config(model_config: "ModelConfig"): + return cached_tokenizer_from_config(model_config) diff --git a/vllm/tool_parsers/__init__.py b/vllm/tool_parsers/__init__.py new file mode 100644 index 0000000000000..181d8bcba9553 --- /dev/null +++ b/vllm/tool_parsers/__init__.py @@ -0,0 +1,150 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, + ToolParserManager, +) + +__all__ = ["ToolParser", "ToolParserManager"] + + +""" +Register a lazy module mapping. + +Example: + ToolParserManager.register_lazy_module( + name="kimi_k2", + module_path="vllm.tool_parsers.kimi_k2_parser", + class_name="KimiK2ToolParser", + ) +""" + + +_TOOL_PARSERS_TO_REGISTER = { + "deepseek_v3": ( # name + "deepseekv3_tool_parser", # filename + "DeepSeekV3ToolParser", # class_name + ), + "deepseek_v31": ( + "deepseekv31_tool_parser", + "DeepSeekV31ToolParser", + ), + "deepseek_v32": ( + "deepseekv32_tool_parser", + "DeepSeekV32ToolParser", + ), + "ernie45": ( + "ernie45_tool_parser", + "Ernie45ToolParser", + ), + "glm45": ( + "glm4_moe_tool_parser", + "Glm4MoeModelToolParser", + ), + "granite-20b-fc": ( + "granite_20b_fc_tool_parser", + "Granite20bFCToolParser", + ), + "granite": ( + "granite_tool_parser", + "GraniteToolParser", + ), + "hermes": ( + "hermes_tool_parser", + "Hermes2ProToolParser", + ), + "hunyuan_a13b": ( + "hunyuan_a13b_tool_parser", + "HunyuanA13BToolParser", + ), + "internlm": ( + "internlm2_tool_parser", + "Internlm2ToolParser", + ), + "jamba": ( + "jamba_tool_parser", + "JambaToolParser", + ), + "kimi_k2": ( + "kimi_k2_tool_parser", + "KimiK2ToolParser", + ), + "llama3_json": ( + "llama_tool_parser", + "Llama3JsonToolParser", + ), + "llama4_json": ( + "llama_tool_parser", + "Llama3JsonToolParser", + ), + "llama4_pythonic": ( + "llama4_pythonic_tool_parser", + "Llama4PythonicToolParser", + ), + "longcat": ( + "longcat_tool_parser", + "LongcatFlashToolParser", + ), + "minimax_m2": ( + "minimax_m2_tool_parser", + "MinimaxM2ToolParser", + ), + "minimax": ( + "minimax_tool_parser", + "MinimaxToolParser", + ), + "mistral": ( + "mistral_tool_parser", + "MistralToolParser", + ), + "olmo3": ( + "olmo3_tool_parser", + "Olmo3PythonicToolParser", + ), + "openai": ( + "openai_tool_parser", + "OpenAIToolParser", + ), + "phi4_mini_json": ( + "phi4mini_tool_parser", + "Phi4MiniJsonToolParser", + ), + "pythonic": ( + "pythonic_tool_parser", + "PythonicToolParser", + ), + "qwen3_coder": ( + "qwen3coder_tool_parser", + "Qwen3CoderToolParser", + ), + "qwen3_xml": ( + "qwen3xml_tool_parser", + "Qwen3XMLToolParser", + ), + "seed_oss": ( + "seed_oss_tool_parser", + "SeedOssToolParser", + ), + "step3": ( + "step3_tool_parser", + "Step3ToolParser", + ), + "xlam": ( + "xlam_tool_parser", + "xLAMToolParser", + ), + "gigachat3": ( + "gigachat3_tool_parser", + "GigaChat3ToolParser", + ), +} + + +def register_lazy_tool_parsers(): + for name, (file_name, class_name) in _TOOL_PARSERS_TO_REGISTER.items(): + module_path = f"vllm.tool_parsers.{file_name}" + ToolParserManager.register_lazy_module(name, module_path, class_name) + + +register_lazy_tool_parsers() diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/tool_parsers/abstract_tool_parser.py similarity index 98% rename from vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py rename to vllm/tool_parsers/abstract_tool_parser.py index 87ef2e0786a94..e2ccb1dad9907 100644 --- a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +++ b/vllm/tool_parsers/abstract_tool_parser.py @@ -17,12 +17,12 @@ from vllm.entrypoints.openai.protocol import ( ResponsesRequest, ResponseTextConfig, ) -from vllm.entrypoints.openai.tool_parsers.utils import get_json_schema_from_tools from vllm.logger import init_logger from vllm.sampling_params import ( StructuredOutputsParams, ) from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.utils import get_json_schema_from_tools from vllm.utils.collection_utils import is_list_of from vllm.utils.import_utils import import_from_path @@ -203,7 +203,7 @@ class ToolParserManager: Example: ToolParserManager.register_lazy_module( name="kimi_k2", - module_path="vllm.entrypoints.openai.tool_parsers.kimi_k2_parser", + module_path="vllm.tool_parsers.kimi_k2_parser", class_name="KimiK2ToolParser", ) """ diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py b/vllm/tool_parsers/deepseekv31_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py rename to vllm/tool_parsers/deepseekv31_tool_parser.py index 10de3dabf985c..33383e1bc0739 100644 --- a/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py +++ b/vllm/tool_parsers/deepseekv31_tool_parser.py @@ -15,11 +15,9 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ToolParser logger = init_logger(__name__) diff --git a/vllm/tool_parsers/deepseekv32_tool_parser.py b/vllm/tool_parsers/deepseekv32_tool_parser.py new file mode 100644 index 0000000000000..db081178fdeae --- /dev/null +++ b/vllm/tool_parsers/deepseekv32_tool_parser.py @@ -0,0 +1,591 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +import uuid +from collections.abc import Sequence +from typing import Any + +import regex as re + +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) +from vllm.logger import init_logger +from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) + +logger = init_logger(__name__) + + +class DeepSeekV32ToolParser(ToolParser): + """ + example tool call content: + <|DSML|function_calls> + <|DSML|invoke name="get_weather"> + <|DSML|parameter name="location" string="true">杭州</|DSML|parameter> + <|DSML|parameter name="date" string="true">2024-01-16</|DSML|parameter> + </|DSML|invoke> + <|DSML|invoke name="get_weather"> + <|DSML|parameter name="location" string="true">北京</|DSML|parameter> + <|DSML|parameter name="date" string="true">2024-01-16</|DSML|parameter> + </|DSML|invoke> + </|DSML|function_calls> + """ + + def __init__(self, tokenizer: TokenizerLike): + super().__init__(tokenizer) + + self.prev_tool_call_arr: list[dict] = [] + + # Sentinel tokens + self.dsml_token: str = "|DSML|" + self.dsml_start_check: str = "<" + self.dsml_token + self.tool_call_start_token: str = "<|DSML|function_calls>" + self.tool_call_end_token: str = "</|DSML|function_calls>" + self.invoke_start_prefix: str = "<|DSML|invoke name=" + self.invoke_end_token: str = "</|DSML|invoke>" + self.parameter_prefix: str = "<|DSML|parameter name=" + self.parameter_end_token: str = "</|DSML|parameter>" + + # Streaming state variables + self.current_tool_name_sent: bool = False + # Override base class type - we use string IDs for tool calls + self.current_tool_id: str | None = None # type: ignore + self.streamed_args_for_tool: list[str] = [] + self.is_tool_call_started: bool = False + self.failed_count: int = 0 + + # Initialize streaming state variables + self.current_tool_index: int = 0 + self.invoke_index: int = 0 + self.header_sent: bool = False + self.current_function_name: str | None = None + self.current_param_name: str | None = None + self.current_param_value: str = "" + self.param_count: int = 0 + self.in_param: bool = False + self.in_function: bool = False + self.json_started: bool = False + self.json_closed: bool = False + self.accumulated_params: dict = {} + self.streaming_request: ChatCompletionRequest | None = None + + # Enhanced streaming state - reset for each new message + self._reset_streaming_state() + + # Regex patterns for complete parsing + self.tool_call_complete_regex = re.compile( + r"<|DSML|function_calls>(.*?)</|DSML|function_calls>", re.DOTALL + ) + self.invoke_complete_regex = re.compile( + r'<|DSML|invoke\s+name="([^"]+)"\s*>(.*?)</|DSML|invoke>', re.DOTALL + ) + self.parameter_complete_regex = re.compile( + r'<|DSML|parameter\s+name="([^"]+)"\s+string="(?:true|false)"\s*>(.*?)</|DSML|parameter>', + re.DOTALL, + ) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction." + ) + + logger.debug( + "vLLM Successfully import tool parser %s !", self.__class__.__name__ + ) + + def _generate_tool_call_id(self) -> str: + """Generate a unique tool call ID.""" + return f"call_{uuid.uuid4().hex[:24]}" + + def _reset_streaming_state(self): + """Reset all streaming state.""" + self.current_tool_index = 0 + self.invoke_index = 0 + self.is_tool_call_started = False + self.header_sent = False + self.current_tool_id = None + self.current_function_name = None + self.current_param_name = None + self.current_param_value = "" + self.param_count = 0 + self.in_param = False + self.in_function = False + self.json_started = False + self.json_closed = False + # Store accumulated parameters for type conversion + self.accumulated_params = {} + self.streaming_request = None + # Clear previous tool call history to avoid state pollution + self.prev_tool_call_arr.clear() + + def _parse_invoke_params(self, invoke_str: str) -> dict | None: + param_dict = dict() + for param_name, param_val in self.parameter_complete_regex.findall(invoke_str): + param_dict[param_name] = param_val + return param_dict + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + """Extract tool calls from complete model output (non-streaming).""" + # Quick check + if self.tool_call_start_token not in model_output: + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + try: + tool_calls = [] + + # Find all complete tool_call blocks + for tool_call_match in self.tool_call_complete_regex.findall(model_output): + # Find all invokes within this tool_call + for invoke_name, invoke_content in self.invoke_complete_regex.findall( + tool_call_match + ): + param_dict = self._parse_invoke_params(invoke_content) + tool_calls.append( + ToolCall( + type="function", + function=FunctionCall( + name=invoke_name, + arguments=json.dumps(param_dict, ensure_ascii=False), + ), + ) + ) + + if not tool_calls: + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + # Extract content before first tool call + first_tool_idx = model_output.find(self.tool_call_start_token) + content = model_output[:first_tool_idx] if first_tool_idx > 0 else None + + return ExtractedToolCallInformation( + tools_called=True, tool_calls=tool_calls, content=content + ) + + except Exception: + logger.exception("Error extracting tool calls") + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + def _extract_name(self, name_str: str) -> str: + """Extract name from quoted string.""" + name_str = name_str.strip() + if ( + name_str.startswith('"') + and name_str.endswith('"') + or name_str.startswith("'") + and name_str.endswith("'") + ): + return name_str[1:-1] + return name_str + + def _extract_param_name(self, input_str: str) -> str: + """Extract param name""" + start = input_str.find('"') + 1 + end = input_str.find('"', start) + return input_str[start:end] if start > 0 and end > start else input_str + + def _convert_param_value(self, value: str, param_type: str) -> Any: + """Convert parameter value to the correct type.""" + if value.lower() == "null": + return None + + param_type = param_type.lower() + if param_type in ["string", "str", "text"]: + return value + elif param_type in ["integer", "int"]: + try: + return int(value) + except (ValueError, TypeError): + return value + elif param_type in ["number", "float"]: + try: + val = float(value) + return val if val != int(val) else int(val) + except (ValueError, TypeError): + return value + elif param_type in ["boolean", "bool"]: + return value.lower() in ["true", "1"] + elif param_type in ["object", "array"]: + try: + return json.loads(value) + except json.JSONDecodeError: + return value + else: + # Try JSON parse first, fallback to string + try: + return json.loads(value) + except json.JSONDecodeError: + return value + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], # pylint: disable=unused-argument + current_token_ids: Sequence[int], # pylint: disable=unused-argument + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> DeltaMessage | None: + """Extract tool calls from streaming model output.""" + + # Store request for type conversion + if not previous_text: + self._reset_streaming_state() + self.streaming_request = request + + # If no delta text, return None unless it's an EOS token after tools + if not delta_text: + # Check if this is an EOS token after all tool calls are complete + if delta_token_ids: + # Count complete tool calls + complete_calls = len( + self.tool_call_complete_regex.findall(current_text) + ) + + # If we have completed tool calls and populated prev_tool_call_arr + if complete_calls > 0 and len(self.prev_tool_call_arr) > 0: + # Check if all tool calls are closed + open_calls = current_text.count( + self.tool_call_start_token + ) - current_text.count(self.tool_call_end_token) + if open_calls == 0: + # Return empty delta for finish_reason processing + return DeltaMessage(content="") + elif not self.is_tool_call_started and current_text: + # This is a regular content response that's now complete + return DeltaMessage(content="") + return None + + # Check if we need to advance to next tool + if self.json_closed and not self.in_function: + # Check if this tool call has ended + invoke_ends = current_text.count(self.invoke_end_token) + if invoke_ends > self.current_tool_index: + # This tool has ended, advance to next + self.current_tool_index += 1 + self.header_sent = False + self.param_count = 0 + self.json_started = False + self.json_closed = False + self.in_function = False # Now we can safely set this to False + self.accumulated_params = {} + # Continue processing next tool + return None + + # Handle normal content before tool calls + if not self.is_tool_call_started: + # Check if tool call is starting + if self.dsml_token in current_text: + self.is_tool_call_started = True + # Return any content before the tool call + if self.dsml_start_check in delta_text: + content_before = delta_text[ + : delta_text.index(self.dsml_start_check) + ] + if content_before: + return DeltaMessage(content=content_before) + return None + else: + # Check if we're between tool calls - skip whitespace + if ( + current_text.rstrip().endswith(self.tool_call_end_token) + and delta_text.strip() == "" + ): + # We just ended a tool call, skip whitespace + return None + # Normal content, no tool call + if delta_text.endswith("<"): + return DeltaMessage(content=delta_text[:-1]) + if previous_text and previous_text.endswith("<"): + return DeltaMessage(content="<" + delta_text) + return DeltaMessage(content=delta_text) + + # Check if we're between tool calls (waiting for next one) + invoke_starts_count = current_text.count(self.invoke_start_prefix) + if self.current_tool_index >= invoke_starts_count: + # We're past all tool calls, shouldn't be here + return None + + # Find the current tool call portion + invoke_start_positions: list[int] = [] + idx = 0 + while True: + idx = current_text.find(self.invoke_start_prefix, idx) + if idx == -1: + break + invoke_start_positions.append(idx) + idx += len(self.invoke_start_prefix) + + if self.current_tool_index >= len(invoke_start_positions): + # No more tool calls to process yet + return None + + invoke_start_idx = invoke_start_positions[self.current_tool_index] + # Find where this tool call ends (or current position if not ended yet) + invoke_end_idx = current_text.find(self.invoke_end_token, invoke_start_idx) + if invoke_end_idx == -1: + tool_text = current_text[invoke_start_idx:] + else: + tool_text = current_text[ + invoke_start_idx : invoke_end_idx + len(self.invoke_end_token) + ] + + # Looking for function header + if not self.header_sent: + if self.invoke_start_prefix in tool_text: + func_start = tool_text.find(self.invoke_start_prefix) + len( + self.invoke_start_prefix + ) + # Find the end quote for the function name + func_end = tool_text.find(">", func_start) + + if func_end != -1: + # Found complete function name + function_name_raw = tool_text[func_start:func_end] + self.current_function_name = self._extract_name(function_name_raw) + self.current_tool_id = self._generate_tool_call_id() + self.header_sent = True + self.in_function = True + + # Add to prev_tool_call_arr immediately when we detect a tool call + # Each tool call should be recorded regardless of function name + # Ensure we don't add the same tool call index multiple times + if len(self.prev_tool_call_arr) <= self.current_tool_index: + self.prev_tool_call_arr.append( + { + "name": self.current_function_name, + "arguments": "{}", # Placeholder, will be updated later + } + ) + + # Send header with function info + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + id=self.current_tool_id, + function=DeltaFunctionCall( + name=self.current_function_name, arguments="" + ), + type="function", + ) + ] + ) + return None + + # We've sent header, now handle function body + if self.in_function: + # Send opening brace if not sent yet + if self.in_function and not self.json_started: + self.json_started = True + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments="{"), + ) + ] + ) + + # Make sure json_started is set if we're processing parameters + if not self.json_started: + self.json_started = True + + # Check for function end in accumulated text + if not self.json_closed and self.invoke_end_token in tool_text: + # Count total parameters in the tool text + total_param_count = tool_text.count(self.parameter_prefix) + + # Only close JSON if all parameters have been processed + if self.param_count >= total_param_count: + # Close JSON + self.json_closed = True + + # Extract complete tool call + # Find the invoke content + invoke_start = tool_text.find(self.invoke_start_prefix) + len( + self.invoke_start_prefix + ) + invoke_content_end = tool_text.find( + self.invoke_end_token, invoke_start + ) + if invoke_content_end != -1: + invoke_content = tool_text[invoke_start:invoke_content_end] + # Parse to get the complete arguments + try: + invoke_params = self._parse_invoke_params(invoke_content) + if invoke_params and self.current_tool_index < len( + self.prev_tool_call_arr + ): + # Update existing entry in prev_tool_call_arr + self.prev_tool_call_arr[self.current_tool_index][ + "arguments" + ] = json.dumps(invoke_params, ensure_ascii=False) + except Exception: + pass # Ignore parsing errors during streaming + + result = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments="}"), + ) + ] + ) + + # Reset state for next tool + self.json_closed = True + self.in_function = False + self.accumulated_params = {} + + logger.debug("[M2_STREAMING] Tool call completed") + + return result + else: + # Don't close JSON yet, continue processing parameters + return None + + # Look for parameters + # Find all parameter starts + param_starts = [] + idx = 0 + while True: + idx = tool_text.find(self.parameter_prefix, idx) + if idx == -1: + break + param_starts.append(idx) + idx += len(self.parameter_prefix) + + # Check if we should start a new parameter + if ( + not self.in_param + and self.param_count < len(param_starts) + and len(param_starts) > self.param_count + ): + # Process the next parameter + param_idx = param_starts[self.param_count] + param_start = param_idx + len(self.parameter_prefix) + remaining = tool_text[param_start:] + + if ">" in remaining: + # We have the complete parameter name + name_end = remaining.find(">") + param_name_raw = remaining[:name_end] + self.current_param_name = self._extract_param_name(param_name_raw) + + # Find the parameter value + value_start = param_start + name_end + 1 + value_text = tool_text[value_start:] + if value_text.startswith("\n"): + value_text = value_text[1:] + + # Find where this parameter ends + param_end_idx = value_text.find(self.parameter_end_token) + if param_end_idx == -1: + # No closing tag, look for next parameter or function end + next_param_idx = value_text.find(self.parameter_prefix) + func_end_idx = value_text.find(self.invoke_end_token) + + if next_param_idx != -1 and ( + func_end_idx == -1 or next_param_idx < func_end_idx + ): + param_end_idx = next_param_idx + elif func_end_idx != -1: + param_end_idx = func_end_idx + else: + # Neither found, check if tool call is complete + if self.invoke_end_token in tool_text: + # Tool call and parameter is complete + param_end_idx = len(value_text) + else: + # Still streaming, wait for more content + return None + + if param_end_idx != -1: + # Complete parameter found + param_value = value_text[:param_end_idx] + if param_value.endswith("\n"): + param_value = param_value[:-1] + + # Store raw value for later processing + self.accumulated_params[self.current_param_name] = param_value + + # Get parameter configuration for type conversion + param_config = {} + if self.streaming_request and self.streaming_request.tools: + for tool in self.streaming_request.tools: + if ( + hasattr(tool, "function") + and tool.function.name == self.current_function_name + and hasattr(tool.function, "parameters") + ): + params = tool.function.parameters + if ( + isinstance(params, dict) + and "properties" in params + ): + param_config = params["properties"] + break + + # Get parameter type + param_type = "string" + if ( + self.current_param_name in param_config + and isinstance(param_config[self.current_param_name], dict) + and "type" in param_config[self.current_param_name] + ): + param_type = param_config[self.current_param_name]["type"] + + # Convert param value to appropriate type + converted_value = self._convert_param_value( + param_value, param_type + ) + + # Build JSON fragment based on the converted type + # Use json.dumps to properly serialize the value + serialized_value = json.dumps( + converted_value, ensure_ascii=False + ) + + if self.param_count == 0: + json_fragment = ( + f'"{self.current_param_name}": {serialized_value}' + ) + else: + json_fragment = ( + f', "{self.current_param_name}": {serialized_value}' + ) + + self.param_count += 1 + + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments=json_fragment), + ) + ] + ) + + return None diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py b/vllm/tool_parsers/deepseekv3_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py rename to vllm/tool_parsers/deepseekv3_tool_parser.py index 66b14875dce25..f8cf559f2284a 100644 --- a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py +++ b/vllm/tool_parsers/deepseekv3_tool_parser.py @@ -15,11 +15,11 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py b/vllm/tool_parsers/ernie45_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py rename to vllm/tool_parsers/ernie45_tool_parser.py index d054d8e4b8651..79193787b3b3b 100644 --- a/vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py +++ b/vllm/tool_parsers/ernie45_tool_parser.py @@ -15,11 +15,11 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) logger = init_logger(__name__) diff --git a/vllm/tool_parsers/gigachat3_tool_parser.py b/vllm/tool_parsers/gigachat3_tool_parser.py new file mode 100644 index 0000000000000..27a6bc1a7bad8 --- /dev/null +++ b/vllm/tool_parsers/gigachat3_tool_parser.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +from collections.abc import Sequence + +import regex as re + +from vllm.entrypoints.chat_utils import make_tool_call_id +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) +from vllm.logger import init_logger +from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ToolParser + +logger = init_logger(__name__) + +REGEX_FUNCTION_CALL = re.compile( + r"function call(?:<\|role_sep\|>\n)?(\{.*)", + re.DOTALL, +) + +NAME_REGEX = re.compile( + r'"name"\s*:\s*"([^"]*)"', + re.DOTALL, +) + +ARGS_REGEX = re.compile( + r'"arguments"\s*:\s*(.*)', + re.DOTALL, +) + + +class GigaChat3ToolParser(ToolParser): + def __init__(self, tokenizer: TokenizerLike): + super().__init__(tokenizer) + self.tool_started: bool = False + self.tool_name_sent: bool = False + self.tool_id: str | None = None + self.prev_tool_call_arr: list[dict] = [] + self.content_buffer: str = "" + self.trigger_start = "function call{" + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + match = REGEX_FUNCTION_CALL.search(model_output) + if not match: + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=model_output, + ) + json_candidate = match.group(1).strip() + try: + data = json.loads(json_candidate) + except json.JSONDecodeError: + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=model_output, + ) + if not (isinstance(data, dict) and "name" in data and "arguments" in data): + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=model_output, + ) + name = data["name"] + args = data["arguments"] + if not isinstance(args, str): + args = json.dumps(args, ensure_ascii=False) + + tool_calls = [ + ToolCall( + type="function", + function=FunctionCall( + name=name, + arguments=args, + ), + ) + ] + prefix = model_output[: match.start()] + content = prefix.rstrip() if prefix and prefix.strip() else None + + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content, + ) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> DeltaMessage | None: + func_name = None + cur_args = None + if not self.tool_started: + match = REGEX_FUNCTION_CALL.search(current_text) + if match: + self.tool_started = True + self.content_buffer = "" + else: + self.content_buffer += delta_text + clean_buffer = self.content_buffer.lstrip() + is_prefix = self.trigger_start.startswith(clean_buffer) + starts_with_trigger = clean_buffer.startswith(self.trigger_start) + if is_prefix or starts_with_trigger: + return None + else: + flush_text = self.content_buffer + self.content_buffer = "" + return DeltaMessage(content=flush_text) + + match = REGEX_FUNCTION_CALL.search(current_text) + if not match: + return None + json_tail = match.group(1).strip() + name_match = NAME_REGEX.search(json_tail) + if name_match: + func_name = name_match.group(1) + args_match = ARGS_REGEX.search(json_tail) + if args_match: + cur_args = args_match.group(1).strip() + if cur_args.endswith("}"): # last '}' end of json + try: + candidate = cur_args[:-1].strip() + json.loads(candidate) + cur_args = candidate + except json.JSONDecodeError: + pass + if not self.prev_tool_call_arr: + self.prev_tool_call_arr.append({}) + if not self.tool_name_sent: + if not func_name: + return None + self.tool_name_sent = True + self.tool_id = make_tool_call_id() + self.prev_tool_call_arr[0]["name"] = func_name + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=0, + id=self.tool_id, + type="function", + function=DeltaFunctionCall( + name=func_name, + ).model_dump(exclude_none=True), + ) + ], + content=None, + ) + if cur_args is None: + return None + prev_args = self.prev_tool_call_arr[0].get("arguments", "") + if not prev_args: + delta_args = cur_args + elif cur_args.startswith(prev_args): + delta_args = cur_args[len(prev_args) :] + else: + return None + if not delta_args: + return None + self.prev_tool_call_arr[0]["arguments"] = cur_args + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=0, + function=DeltaFunctionCall( + arguments=delta_args, + ).model_dump(exclude_none=True), + ) + ], + content=None, + ) diff --git a/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py b/vllm/tool_parsers/glm4_moe_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py rename to vllm/tool_parsers/glm4_moe_tool_parser.py index 165346adb3d93..d254fcb5240a5 100644 --- a/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py +++ b/vllm/tool_parsers/glm4_moe_tool_parser.py @@ -18,11 +18,11 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py b/vllm/tool_parsers/granite_20b_fc_tool_parser.py similarity index 98% rename from vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py rename to vllm/tool_parsers/granite_20b_fc_tool_parser.py index df1b590526b1a..d841fb57ac87e 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py +++ b/vllm/tool_parsers/granite_20b_fc_tool_parser.py @@ -19,17 +19,17 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( +from vllm.logger import init_logger +from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( ToolParser, ) -from vllm.entrypoints.openai.tool_parsers.utils import ( +from vllm.tool_parsers.utils import ( consume_space, find_common_prefix, is_complete_json, partial_json_loads, ) -from vllm.logger import init_logger -from vllm.tokenizers import TokenizerLike logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py b/vllm/tool_parsers/granite_tool_parser.py similarity index 98% rename from vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py rename to vllm/tool_parsers/granite_tool_parser.py index 14b0ca0abe357..7abfdf72849d9 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py +++ b/vllm/tool_parsers/granite_tool_parser.py @@ -17,17 +17,17 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( +from vllm.logger import init_logger +from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( ToolParser, ) -from vllm.entrypoints.openai.tool_parsers.utils import ( +from vllm.tool_parsers.utils import ( consume_space, find_common_prefix, is_complete_json, partial_json_loads, ) -from vllm.logger import init_logger -from vllm.tokenizers import TokenizerLike logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/tool_parsers/hermes_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py rename to vllm/tool_parsers/hermes_tool_parser.py index 19c1c83268ed4..4b1dea7edf27a 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/tool_parsers/hermes_tool_parser.py @@ -18,11 +18,12 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( +from vllm.logger import init_logger +from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.mistral import MistralTokenizer +from vllm.tool_parsers.abstract_tool_parser import ( ToolParser, ) -from vllm.logger import init_logger -from vllm.tokenizers import MistralTokenizer, TokenizerLike logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py b/vllm/tool_parsers/hunyuan_a13b_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py rename to vllm/tool_parsers/hunyuan_a13b_tool_parser.py index d2419b5d84ead..c739821368042 100644 --- a/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py +++ b/vllm/tool_parsers/hunyuan_a13b_tool_parser.py @@ -17,12 +17,12 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) -from vllm.entrypoints.openai.tool_parsers.utils import consume_space from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) +from vllm.tool_parsers.utils import consume_space from vllm.utils import random_uuid logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/tool_parsers/internlm2_tool_parser.py similarity index 98% rename from vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py rename to vllm/tool_parsers/internlm2_tool_parser.py index 67788358543e9..e87efe3275a71 100644 --- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +++ b/vllm/tool_parsers/internlm2_tool_parser.py @@ -17,12 +17,12 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) -from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) +from vllm.tool_parsers.utils import extract_intermediate_diff logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py b/vllm/tool_parsers/jamba_tool_parser.py similarity index 98% rename from vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py rename to vllm/tool_parsers/jamba_tool_parser.py index 4655da8dd4542..7f3de0b38a33c 100644 --- a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py +++ b/vllm/tool_parsers/jamba_tool_parser.py @@ -18,10 +18,11 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers import ToolParser -from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.logger import init_logger -from vllm.tokenizers import MistralTokenizer, TokenizerLike +from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.mistral import MistralTokenizer +from vllm.tool_parsers import ToolParser +from vllm.tool_parsers.utils import extract_intermediate_diff logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py b/vllm/tool_parsers/kimi_k2_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py rename to vllm/tool_parsers/kimi_k2_tool_parser.py index 07db52ebd5af1..c215b7978854e 100644 --- a/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py +++ b/vllm/tool_parsers/kimi_k2_tool_parser.py @@ -15,11 +15,11 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py b/vllm/tool_parsers/llama4_pythonic_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py rename to vllm/tool_parsers/llama4_pythonic_tool_parser.py index 1d6de9244066e..3c5409bbfaf42 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py +++ b/vllm/tool_parsers/llama4_pythonic_tool_parser.py @@ -18,10 +18,10 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( +from vllm.logger import init_logger +from vllm.tool_parsers.abstract_tool_parser import ( ToolParser, ) -from vllm.logger import init_logger logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/tool_parsers/llama_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py rename to vllm/tool_parsers/llama_tool_parser.py index e1fe6e90dfd0b..b0dfe05c8e556 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +++ b/vllm/tool_parsers/llama_tool_parser.py @@ -20,15 +20,15 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( +from vllm.logger import init_logger +from vllm.tool_parsers.abstract_tool_parser import ( ToolParser, ) -from vllm.entrypoints.openai.tool_parsers.utils import ( +from vllm.tool_parsers.utils import ( find_common_prefix, is_complete_json, partial_json_loads, ) -from vllm.logger import init_logger logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/longcat_tool_parser.py b/vllm/tool_parsers/longcat_tool_parser.py similarity index 93% rename from vllm/entrypoints/openai/tool_parsers/longcat_tool_parser.py rename to vllm/tool_parsers/longcat_tool_parser.py index 76d76a4aa35a1..72f13559a9222 100644 --- a/vllm/entrypoints/openai/tool_parsers/longcat_tool_parser.py +++ b/vllm/tool_parsers/longcat_tool_parser.py @@ -3,8 +3,8 @@ import regex as re -from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser class LongcatFlashToolParser(Hermes2ProToolParser): diff --git a/vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py b/vllm/tool_parsers/minimax_m2_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py rename to vllm/tool_parsers/minimax_m2_tool_parser.py index b595a98f35555..dcb2b64f6e73c 100644 --- a/vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py +++ b/vllm/tool_parsers/minimax_m2_tool_parser.py @@ -17,11 +17,11 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py b/vllm/tool_parsers/minimax_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py rename to vllm/tool_parsers/minimax_tool_parser.py index 1025041037c6e..86e1433c6e360 100644 --- a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py +++ b/vllm/tool_parsers/minimax_tool_parser.py @@ -17,12 +17,12 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) -from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) +from vllm.tool_parsers.utils import extract_intermediate_diff logger = init_logger(__name__) diff --git a/vllm/tool_parsers/mistral_tool_parser.py b/vllm/tool_parsers/mistral_tool_parser.py new file mode 100644 index 0000000000000..49a175f69f434 --- /dev/null +++ b/vllm/tool_parsers/mistral_tool_parser.py @@ -0,0 +1,585 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +from collections.abc import Sequence +from enum import Enum, auto +from random import choices +from string import ascii_letters, digits +from typing import Any + +import ijson +import regex as re +from pydantic import Field + +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) +from vllm.logger import init_logger +from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.mistral import MistralTokenizer +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) + +logger = init_logger(__name__) + +ALPHANUMERIC = ascii_letters + digits + + +class StreamingState(Enum): + """Enum for tracking the current streaming parsing state.""" + + WAITING_FOR_TOOL_START = auto() + WAITING_FOR_TOOL_KEY = ( + auto() + ) # waiting for the "name" or "arguments" key to be complete + PARSING_NAME = auto() + PARSING_NAME_COMPLETED = auto() + WAITING_FOR_ARGUMENTS_START = auto() + PARSING_ARGUMENTS = auto() + PARSING_ARGUMENTS_COMPLETED = auto() + TOOL_COMPLETE = auto() + ALL_TOOLS_COMPLETE = auto() + + +class MistralToolCall(ToolCall): + id: str = Field(default_factory=lambda: MistralToolCall.generate_random_id()) + + @staticmethod + def generate_random_id(): + # Mistral Tool Call Ids must be alphanumeric with a length of 9. + # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299 + return "".join(choices(ALPHANUMERIC, k=9)) + + @staticmethod + def is_valid_id(id: str) -> bool: + return id.isalnum() and len(id) == 9 + + +def _is_pre_v11_tokeniser(model_tokenizer: TokenizerLike) -> bool: + return not ( + isinstance(model_tokenizer, MistralTokenizer) and model_tokenizer.version >= 11 + ) + + +class MistralToolParser(ToolParser): + """ + Tool call parser for Mistral 7B Instruct v0.3, intended for use with + - [`mistral_common`](https://github.com/mistralai/mistral-common/) + - the examples/tool_chat_template_mistral.jinja template. + + Used when --enable-auto-tool-choice --tool-call-parser mistral are all set + """ + + def __init__(self, tokenizer: TokenizerLike): + super().__init__(tokenizer) + + if not isinstance(self.model_tokenizer, MistralTokenizer): + logger.info("Non-Mistral tokenizer detected when using a Mistral model...") + + # initialize properties used for state when parsing tool calls in + # streaming mode + self.prev_tool_call_arr: list[dict[str, Any]] = [] + self.current_tool_id: int = -1 + self.streaming_state: StreamingState = StreamingState.WAITING_FOR_TOOL_START + + # For streaming pre v11 tokenizer tool calls + self.current_tool_name: str | None = None + self.current_tool_mistral_id: str | None = None + self.starting_new_tool = False + if _is_pre_v11_tokeniser(self.model_tokenizer): + self.parse_coro = ijson.parse_coro( + self.update_stream_state_pre_v11_tokenizer() + ) + + self.bot_token = "[TOOL_CALLS]" + self.bot_token_id = self.vocab.get(self.bot_token) + self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL) + self._is_pre_v11 = _is_pre_v11_tokeniser(self.model_tokenizer) + + if self.bot_token_id is None: + raise RuntimeError( + "Mistral Tool Parser could not locate the tool call token in " + "the tokenizer!" + ) + + def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + request = super().adjust_request(request) + if ( + not isinstance(self.model_tokenizer, MistralTokenizer) + and request.tools + and request.tool_choice != "none" + ): + # Do not skip special tokens when using chat template + # with Mistral parser as TOOL_CALL token is needed + # for tool detection. + # Note: we don't want skip_special_tokens=False + # with MistralTokenizer as it is incompatible + request.skip_special_tokens = False + return request + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. Requires + find-and-replacing single quotes with double quotes for JSON parsing, + make sure your tool call arguments don't ever include quotes! + """ + + # case -- if a tool call token is not present, return a text response + if self.bot_token not in model_output: + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + # first remove the BOT token + tool_content = model_output.replace(self.bot_token, "").strip() + + try: + try: + if not self._is_pre_v11: + function_call_arr = [] + for single_tool_content in model_output.split(self.bot_token): + if "{" not in single_tool_content: + continue + + end_name = single_tool_content.find("{") + fn_name, args = ( + single_tool_content[:end_name], + single_tool_content[end_name:], + ) + + # fn_name is encoded outside serialized json dump + # only arguments are serialized + function_call_arr.append( + {"name": fn_name, "arguments": json.loads(args)} + ) + else: + function_call_arr = json.loads(tool_content) + except json.JSONDecodeError: + # use a regex to find the part corresponding to the tool call. + # NOTE: This use case should not happen if the model is trained + # correctly. It's an easy possible fix so it's included, but + # can be brittle for very complex / highly nested tool calls + raw_tool_call = self.tool_call_regex.findall(tool_content)[0] + function_call_arr = json.loads(raw_tool_call) + + # Tool Call + tool_calls: list[MistralToolCall] = [ + MistralToolCall( + type="function", + function=FunctionCall( + name=raw_function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps( + raw_function_call["arguments"], ensure_ascii=False + ), + ), + ) + for raw_function_call in function_call_arr + ] + + # get any content before the tool call + content = model_output.split(self.bot_token)[0] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if len(content) > 0 else None, + ) + + except Exception: + logger.exception("Error in extracting tool call from response.") + # return information to just treat the tool call as regular JSON + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=tool_content + ) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> DeltaMessage | None: + if self.bot_token_id not in current_token_ids: + # if the tool call token is not in the tokens generated so far, + # append output to contents since it's not a tool + return DeltaMessage(content=delta_text) + + # if the tool call token IS in the tokens generated so far, that + # means we're parsing as tool calls now + try: + if _is_pre_v11_tokeniser(self.model_tokenizer): + return self._extract_tool_calls_streaming_pre_v11_tokenizer( + delta_text=delta_text, + delta_token_ids=delta_token_ids, + ) + else: + return self._extract_tool_calls_streaming( + delta_text=delta_text, delta_token_ids=delta_token_ids + ) + except Exception: + logger.exception("Error trying to handle streaming tool call.") + return None + + def _extract_tool_calls_streaming( + self, + delta_text: str, + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + """ + Extracts tool calls for Mistral models + doing tool calls of the following format: + `[TOOL_CALLS]add{"a": 3.5, "b": 4}` + """ + additional_content: str = "" + if self.streaming_state == StreamingState.WAITING_FOR_TOOL_START: + # this is the first tool call + assert self.bot_token_id in delta_token_ids + if not delta_text.startswith(self.bot_token): + additional_content += delta_text.split(self.bot_token)[0] + delta_text = self.bot_token + "".join( + delta_text.split(self.bot_token)[1:] + ) + + delta_tool_calls = self._generate_delta_tool_call(delta_text) + if not additional_content and len(delta_tool_calls) == 0: + if self.streaming_state in [ + StreamingState.PARSING_ARGUMENTS, + StreamingState.PARSING_ARGUMENTS_COMPLETED, + StreamingState.TOOL_COMPLETE, + StreamingState.ALL_TOOLS_COMPLETE, + ]: + # Return an empty DeltaMessage once the tool calls are all done + # so that finish_reason gets set. + return DeltaMessage() + else: + # return None when the tool is not likely to be finished + # This can occur when the name is being parsed for example + # and we wait for the name to be complete + # before sending the function name + return None + + delta = DeltaMessage() + if additional_content: + delta.content = additional_content + if len(delta_tool_calls) > 0: + delta.tool_calls = delta_tool_calls + + # HACK: serving_chat.py inspects the internal state of tool parsers + # when determining its final streaming delta, automatically + # adding autocompleted JSON. + # These two lines avoid that nonsense while ensuring finish_reason + # is set to tool_calls when at least one tool is called. + if delta_tool_calls and not self.prev_tool_call_arr: + self.prev_tool_call_arr = [{"arguments": {}}] + return delta + + def _generate_delta_tool_call(self, delta_text: str) -> list[DeltaToolCall]: + if delta_text == "" or delta_text is None: + return [] + delta_function_name = None + tool_id = None + if self.streaming_state not in [ + StreamingState.PARSING_NAME, + StreamingState.PARSING_ARGUMENTS, + ] and delta_text.startswith(self.bot_token): + self.current_tool_id += 1 + self.streaming_state = StreamingState.PARSING_NAME + delta_text = delta_text.replace(self.bot_token, "", 1) + if self.streaming_state == StreamingState.PARSING_NAME: + if self.current_tool_name is None: + self.current_tool_name = "" + # The name stops where the arguments start + # And the arguments start with the `{` char + if "{" in delta_text: + tool_id = MistralToolCall.generate_random_id() + delta_function_name = delta_text.split("{")[0] + self.current_tool_name += delta_function_name + delta_text = delta_text[len(delta_function_name) :] + self.streaming_state = StreamingState.PARSING_ARGUMENTS + else: + # we want to send the tool name once it's complete + self.current_tool_name += delta_text + return [] + if self.streaming_state == StreamingState.PARSING_ARGUMENTS: + next_function_text = None + if self.bot_token in delta_text: + # current tool call is over + delta_arguments = "" + delta_arguments += delta_text.split(self.bot_token)[0] + next_function_text = delta_text[len(delta_arguments) :] + self.streaming_state = StreamingState.TOOL_COMPLETE + else: + delta_arguments = delta_text + ret = [] + if self.current_tool_name or delta_arguments: + ret += [ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=tool_id, + function=DeltaFunctionCall( + name=self.current_tool_name, arguments=delta_arguments + ).model_dump(exclude_none=True), + ) + ] + self.current_tool_name = None + if next_function_text: + ret += self._generate_delta_tool_call(next_function_text) + return ret + # Should not happen + return [] + + @ijson.coroutine + def update_stream_state_pre_v11_tokenizer(self): + while True: + (prefix, event, value) = yield + + if prefix == "item" and event == "start_map": + self.streaming_state = StreamingState.WAITING_FOR_TOOL_KEY + if prefix == "item" and event == "map_key" and value == "name": + self.streaming_state = StreamingState.PARSING_NAME + if prefix == "item.name" and event == "string": + self.current_tool_name = value + self.streaming_state = StreamingState.PARSING_NAME_COMPLETED + if prefix == "item" and event == "map_key" and value == "arguments": + self.streaming_state = StreamingState.WAITING_FOR_ARGUMENTS_START + if prefix == "item.arguments" and event == "start_map": + self.streaming_state = StreamingState.PARSING_ARGUMENTS + if prefix == "item.arguments" and event == "end_map": + self.streaming_state = StreamingState.PARSING_ARGUMENTS_COMPLETED + if prefix == "item" and event == "end_map": + self.streaming_state = StreamingState.TOOL_COMPLETE + if prefix == "" and event == "end_array": + self.streaming_state = StreamingState.ALL_TOOLS_COMPLETE + + def _extract_tool_calls_streaming_pre_v11_tokenizer( + self, + delta_text: str, + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + """ + Extracts tool calls for Mistral models + doing tool calls of the following format: + `[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}` + """ + assert self.parse_coro is not None + content = None + delta_tool_calls: list[DeltaToolCall] = [] + current_tool_call: DeltaToolCall = DeltaToolCall( + index=self.current_tool_id, type="function" + ) + current_tool_call_modified = False + if self.bot_token_id in delta_token_ids: + # this is the first tool call + if not delta_text.startswith(self.bot_token): + content = delta_text.split(self.bot_token)[0] + delta_text = "".join(delta_text.split(self.bot_token)[1:]) + + # Cut smartly the delta text to catch the ijson events + # as ijson does not give us the index in the text at each event. + # We need to cut so that we know + # where in the text the events are emitted from. + while len(delta_text) > 0: + streaming_state_before_parse = self.streaming_state + + if self.streaming_state == StreamingState.WAITING_FOR_TOOL_START: + delta_to_be_parsed, delta_text = self._split_delta( + delta_text=delta_text, + stop_after_opening_curly_braces=1, + ) + elif self.streaming_state == StreamingState.WAITING_FOR_TOOL_KEY: + # Wait until another key is sent + # or the current tool is completed + delta_to_be_parsed, delta_text = self._split_delta( + delta_text=delta_text, + stop_after_colon=1, + stop_after_opening_curly_braces=1, + # if the tool ends, we want to separate + # at the start of the next tool + ) + elif self.streaming_state == StreamingState.PARSING_NAME: + delta_to_be_parsed, delta_text = self._split_delta( + delta_text=delta_text, + stop_after_comma=1, + stop_after_closing_brackets=1, + ) + elif self.streaming_state == StreamingState.WAITING_FOR_ARGUMENTS_START: + delta_to_be_parsed, delta_text = self._split_delta( + delta_text=delta_text, + stop_after_opening_curly_braces=1, + ) + elif self.streaming_state == StreamingState.PARSING_ARGUMENTS: + delta_to_be_parsed, delta_text = self._split_delta( + delta_text=delta_text, + stop_after_closing_curly_braces=1, + # we could be more clever + # by listening to item.arguments.* start_map events + # and know how many curly braces we can allow + ) + elif self.streaming_state in [ + StreamingState.PARSING_ARGUMENTS_COMPLETED, + StreamingState.PARSING_NAME_COMPLETED, + ]: + delta_to_be_parsed, delta_text = self._split_delta( + delta_text=delta_text, + stop_after_closing_curly_braces=1, + stop_after_closing_brackets=1, + ) + elif self.streaming_state == StreamingState.TOOL_COMPLETE: + delta_to_be_parsed, delta_text = self._split_delta( + delta_text=delta_text, + stop_after_opening_curly_braces=1, + stop_after_closing_brackets=1, + ) + elif self.streaming_state == StreamingState.ALL_TOOLS_COMPLETE: + content = delta_text + delta_text = "" + else: + delta_to_be_parsed = delta_text + delta_text = "" + + if self.streaming_state != StreamingState.ALL_TOOLS_COMPLETE: + self.parse_coro.send(delta_to_be_parsed.encode("utf-8")) + + # Given the parsed text and the possible streaming state change, + # let's add to the tool delta + if ( + (streaming_state_before_parse != self.streaming_state) + and streaming_state_before_parse + in [StreamingState.WAITING_FOR_TOOL_START, StreamingState.TOOL_COMPLETE] + and self.streaming_state + not in [ + StreamingState.ALL_TOOLS_COMPLETE, + StreamingState.TOOL_COMPLETE, + StreamingState.WAITING_FOR_TOOL_START, + ] + ): + # starting a new tool call + if current_tool_call_modified: + if self.current_tool_mistral_id is not None: + current_tool_call.id = self.current_tool_mistral_id + self.current_tool_mistral_id = None + delta_tool_calls.append(current_tool_call) + current_tool_call_modified = False + self.current_tool_id += 1 + self.current_tool_mistral_id = MistralToolCall.generate_random_id() + current_tool_call = DeltaToolCall( + index=self.current_tool_id, + type="function", + ) + if current_tool_call.function is None: + current_tool_call.function = DeltaFunctionCall() + + if self.current_tool_name is not None: + # we have the complete tool name + current_tool_call_modified = True + current_tool_call.function.name = self.current_tool_name + self.current_tool_name = None + if self.streaming_state == StreamingState.PARSING_NAME_COMPLETED: + self.streaming_state = StreamingState.WAITING_FOR_TOOL_KEY + if self.streaming_state in [ + StreamingState.PARSING_ARGUMENTS, + StreamingState.PARSING_ARGUMENTS_COMPLETED, + ]: + if self.streaming_state == StreamingState.PARSING_ARGUMENTS_COMPLETED: + self.streaming_state = StreamingState.WAITING_FOR_TOOL_KEY + # the delta_to_be_parsed is part of arguments. + current_tool_call_modified = True + if current_tool_call.function.arguments is None: + current_tool_call.function.arguments = delta_to_be_parsed + else: + current_tool_call.function.arguments += delta_to_be_parsed + if streaming_state_before_parse != StreamingState.PARSING_ARGUMENTS: + # It's the first chunk of arg. let's lstrip it + current_tool_call.function.arguments = ( + current_tool_call.function.arguments.lstrip() + ) + + if current_tool_call_modified: + if self.current_tool_mistral_id is not None: + current_tool_call.id = self.current_tool_mistral_id + self.current_tool_mistral_id = None + delta_tool_calls.append(current_tool_call) + + # HACK: serving_chat.py inspects the internal state of tool parsers + # when determining it's final streaming delta, automatically + # adding autocompleted JSON. + # These two lines avoid that nonsense while ensuring finish_reason + # is set to tool_calls when at least one tool is called. + if delta_tool_calls and not self.prev_tool_call_arr: + self.prev_tool_call_arr = [{"arguments": {}}] + + if content or len(delta_tool_calls) > 0: + delta_message = DeltaMessage() + if content: + delta_message.content = content + if len(delta_tool_calls) > 0: + delta_message.tool_calls = delta_tool_calls + return delta_message + else: + if self.streaming_state == StreamingState.ALL_TOOLS_COMPLETE: + return DeltaMessage() + else: + return None + + def _split_delta( + self, + delta_text: str, + stop_after_quotes: int = -1, + stop_after_opening_curly_braces: int = -1, + stop_after_closing_curly_braces: int = -1, + stop_after_closing_brackets: int = -1, + stop_after_colon: int = -1, + stop_after_comma=-1, + ) -> tuple[str, str]: + delta_to_be_parsed = "" + for i, c in enumerate(delta_text): + if c in ['"', "'"]: + delta_to_be_parsed += c + stop_after_quotes -= 1 + if stop_after_quotes == 0: + return (delta_to_be_parsed, delta_text[i + 1 :]) + elif c == "{": + delta_to_be_parsed += c + stop_after_opening_curly_braces -= 1 + if stop_after_opening_curly_braces == 0: + return (delta_to_be_parsed, delta_text[i + 1 :]) + elif c == "}": + delta_to_be_parsed += c + stop_after_closing_curly_braces -= 1 + if stop_after_closing_curly_braces == 0: + return (delta_to_be_parsed, delta_text[i + 1 :]) + elif c == "]": + delta_to_be_parsed += c + stop_after_closing_brackets -= 1 + if stop_after_closing_brackets == 0: + return (delta_to_be_parsed, delta_text[i + 1 :]) + elif c == ":": + delta_to_be_parsed += c + stop_after_colon -= 1 + if stop_after_colon == 0: + return (delta_to_be_parsed, delta_text[i + 1 :]) + elif c == ",": + delta_to_be_parsed += c + stop_after_comma -= 1 + if stop_after_comma == 0: + return (delta_to_be_parsed, delta_text[i + 1 :]) + else: + delta_to_be_parsed += c + + return (delta_to_be_parsed, "") diff --git a/vllm/entrypoints/openai/tool_parsers/olmo3_tool_parser.py b/vllm/tool_parsers/olmo3_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/olmo3_tool_parser.py rename to vllm/tool_parsers/olmo3_tool_parser.py index baff33bd7e8ac..8cd6a84a9f6b1 100644 --- a/vllm/entrypoints/openai/tool_parsers/olmo3_tool_parser.py +++ b/vllm/tool_parsers/olmo3_tool_parser.py @@ -18,10 +18,10 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( +from vllm.logger import init_logger +from vllm.tool_parsers.abstract_tool_parser import ( ToolParser, ) -from vllm.logger import init_logger logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py b/vllm/tool_parsers/openai_tool_parser.py similarity index 86% rename from vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py rename to vllm/tool_parsers/openai_tool_parser.py index 8bdf35d408805..db92ea8982d70 100644 --- a/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py +++ b/vllm/tool_parsers/openai_tool_parser.py @@ -4,7 +4,7 @@ import json from collections.abc import Sequence from typing import TYPE_CHECKING -from vllm.entrypoints.harmony_utils import parse_output_into_messages +from vllm.entrypoints.openai.parser.harmony_utils import parse_output_into_messages from vllm.entrypoints.openai.protocol import ( ChatCompletionRequest, DeltaMessage, @@ -12,10 +12,10 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( +from vllm.logger import init_logger +from vllm.tool_parsers.abstract_tool_parser import ( ToolParser, ) -from vllm.logger import init_logger if TYPE_CHECKING: from vllm.tokenizers import TokenizerLike @@ -43,6 +43,7 @@ class OpenAIToolParser(ToolParser): parser = parse_output_into_messages(token_ids) tool_calls = [] final_content = None + commentary_content = None if len(parser.messages) > 0: for msg in parser.messages: @@ -75,11 +76,15 @@ class OpenAIToolParser(ToolParser): ) elif msg.channel == "final": final_content = msg_text + elif msg.channel == "commentary" and not msg.recipient: + commentary_content = msg_text return ExtractedToolCallInformation( tools_called=len(tool_calls) > 0, tool_calls=tool_calls, - content=final_content, + # prefer final content over commentary content if both are present + # commentary content is tool call preambles meant to be shown to the user + content=final_content or commentary_content, ) def extract_tool_calls_streaming( diff --git a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py b/vllm/tool_parsers/phi4mini_tool_parser.py similarity index 98% rename from vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py rename to vllm/tool_parsers/phi4mini_tool_parser.py index acb25ea2768e1..9003429d8c6f2 100644 --- a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py +++ b/vllm/tool_parsers/phi4mini_tool_parser.py @@ -16,10 +16,10 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( +from vllm.logger import init_logger +from vllm.tool_parsers.abstract_tool_parser import ( ToolParser, ) -from vllm.logger import init_logger logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py b/vllm/tool_parsers/pythonic_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py rename to vllm/tool_parsers/pythonic_tool_parser.py index abeb923b93227..476a62d5f5273 100644 --- a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py +++ b/vllm/tool_parsers/pythonic_tool_parser.py @@ -19,10 +19,10 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( +from vllm.logger import init_logger +from vllm.tool_parsers.abstract_tool_parser import ( ToolParser, ) -from vllm.logger import init_logger logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py b/vllm/tool_parsers/qwen3coder_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py rename to vllm/tool_parsers/qwen3coder_tool_parser.py index d49b14690ef03..d1a3cbeaafc7d 100644 --- a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py +++ b/vllm/tool_parsers/qwen3coder_tool_parser.py @@ -18,11 +18,11 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py b/vllm/tool_parsers/qwen3xml_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py rename to vllm/tool_parsers/qwen3xml_tool_parser.py index 03862ff432a5d..107f791654a1a 100644 --- a/vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py +++ b/vllm/tool_parsers/qwen3xml_tool_parser.py @@ -19,11 +19,11 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py b/vllm/tool_parsers/seed_oss_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py rename to vllm/tool_parsers/seed_oss_tool_parser.py index c7947faad1923..206072e65b10f 100644 --- a/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py +++ b/vllm/tool_parsers/seed_oss_tool_parser.py @@ -21,11 +21,11 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py b/vllm/tool_parsers/step3_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py rename to vllm/tool_parsers/step3_tool_parser.py index 9213d6859dd93..acd99bf56d0b6 100644 --- a/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py +++ b/vllm/tool_parsers/step3_tool_parser.py @@ -17,11 +17,11 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, -) from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) from vllm.utils import random_uuid logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/utils.py b/vllm/tool_parsers/utils.py similarity index 100% rename from vllm/entrypoints/openai/tool_parsers/utils.py rename to vllm/tool_parsers/utils.py diff --git a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py b/vllm/tool_parsers/xlam_tool_parser.py similarity index 99% rename from vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py rename to vllm/tool_parsers/xlam_tool_parser.py index effd2bd08b42a..9c2b585fe9fdb 100644 --- a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py +++ b/vllm/tool_parsers/xlam_tool_parser.py @@ -17,7 +17,7 @@ from vllm.entrypoints.openai.protocol import ( FunctionCall, ToolCall, ) -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( +from vllm.tool_parsers.abstract_tool_parser import ( ToolParser, ) from vllm.logger import init_logger diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 1bb5791e19016..a11d37b4b2edf 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -26,8 +26,15 @@ from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME from vllm import envs from vllm.logger import init_logger +from vllm.transformers_utils.utils import parse_safetensors_file_metadata from .config_parser_base import ConfigParserBase +from .gguf_utils import ( + check_gguf_file, + is_gguf, + is_remote_gguf, + split_remote_gguf, +) from .repo_utils import ( _get_hf_token, file_or_path_exists, @@ -36,13 +43,6 @@ from .repo_utils import ( try_get_local_file, with_retry, ) -from .utils import ( - check_gguf_file, - is_gguf, - is_remote_gguf, - parse_safetensors_file_metadata, - split_remote_gguf, -) if envs.VLLM_USE_MODELSCOPE: from modelscope import AutoConfig @@ -66,6 +66,7 @@ class LazyConfigDict(dict): _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( afmoe="AfmoeConfig", + bagel="BagelConfig", chatglm="ChatGLMConfig", deepseek_vl_v2="DeepseekVLV2Config", deepseek_v32="DeepseekV3Config", @@ -89,6 +90,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( step3_text="Step3TextConfig", qwen3_next="Qwen3NextConfig", lfm2_moe="Lfm2MoeConfig", + tarsier2="Tarsier2Config", ) _CONFIG_ATTRS_MAPPING: dict[str, str] = { @@ -127,6 +129,9 @@ class HFConfigParser(ConfigParserBase): if config_dict.get("speculators_config") is not None else model_type ) + # Allow hf_overrides to override model_type before checking _CONFIG_REGISTRY + if (hf_overrides := kwargs.pop("hf_overrides", None)) is not None: + model_type = hf_overrides.get("model_type", model_type) if model_type in _CONFIG_REGISTRY: config_class = _CONFIG_REGISTRY[model_type] @@ -300,17 +305,40 @@ def set_default_rope_theta(config: PretrainedConfig, default_theta: float) -> No def patch_rope_parameters(config: PretrainedConfig) -> None: """Provide backwards compatibility for RoPE.""" + from vllm.config.utils import getattr_iter + + # Older custom models may use non-standard field names + # which need patching for both Transformers v4 and v5. + names = ["rope_theta", "rotary_emb_base"] + rope_theta = getattr_iter(config, names, None, warn=True) + names = ["partial_rotary_factor", "rotary_pct", "rotary_emb_fraction"] + partial_rotary_factor = getattr_iter(config, names, None, warn=True) + if Version(version("transformers")) < Version("5.0.0.dev0"): # Transformers v4 installed, legacy config fields may be present if (rope_scaling := getattr(config, "rope_scaling", None)) is not None: config.rope_parameters = rope_scaling - if (rope_theta := getattr(config, "rope_theta", None)) is not None: + if rope_theta is not None: if not hasattr(config, "rope_parameters"): config.rope_parameters = {"rope_type": "default"} config.rope_parameters["rope_theta"] = rope_theta + if partial_rotary_factor is not None: + if not hasattr(config, "rope_parameters"): + config.rope_parameters = {"rope_type": "default"} + config.rope_parameters["partial_rotary_factor"] = partial_rotary_factor + elif rope_theta is not None or hasattr(config, "rope_parameters"): + # Transformers v5 installed + # Patch these fields in case they used non-standard names + if rope_theta is not None: + config.rope_theta = rope_theta + if partial_rotary_factor is not None: + config.partial_rotary_factor = partial_rotary_factor + # Standardize and validate RoPE parameters + config.standardize_rope_params() + config.validate_rope() # No RoPE parameters to patch - if not hasattr(config, "rope_parameters"): + if getattr(config, "rope_parameters", None) is None: return # Add original_max_position_embeddings if present @@ -351,7 +379,10 @@ def patch_rope_parameters_dict(rope_parameters: dict[str, Any]) -> None: rope_parameters["rope_type"] = "longrope" logger.warning("Replacing legacy rope_type 'su' with 'longrope'") elif rope_parameters["rope_type"] == "mrope": - assert "mrope_section" in rope_parameters + if "mrope_section" not in rope_parameters: + raise ValueError( + "Legacy rope_type 'mrope' requires 'mrope_section' in rope_parameters" + ) rope_parameters["rope_type"] = "default" logger.warning("Replacing legacy rope_type 'mrope' with 'default'") @@ -584,8 +615,31 @@ def get_config( trust_remote_code=trust_remote_code, revision=revision, code_revision=code_revision, + hf_overrides=hf_overrides_kw, **kwargs, ) + + # Patching defaults for GGUF models + if _is_gguf: + # Some models have different default values between GGUF and HF. + def apply_gguf_default(key: str, gguf_default: Any): + """ + Apply GGUF defaults unless explicitly configured. + + This function reads/writes external `config` and `config_dict`. + If the specified `key` is not in `config_dict` (i.e. not explicitly + configured and the default HF value is used), it updates the + corresponding `config` value to `gguf_default`. + """ + if key not in config_dict: + config.update({key: gguf_default}) + + # Apply architecture-specific GGUF defaults. + if config.model_type in {"qwen3_moe"}: + # Qwen3 MoE: norm_topk_prob is always true. + # Note that, this parameter is always false (HF default) on Qwen2 MoE. + apply_gguf_default("norm_topk_prob", True) + # Special architecture mapping check for GGUF models if _is_gguf: if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: @@ -915,11 +969,13 @@ def get_hf_text_config(config: PretrainedConfig): """ text_config = config.get_text_config() - if text_config is not config: - # The code operates under the assumption that text_config should have - # `num_attention_heads` (among others). Assert here to fail early - # if transformers config doesn't align with this assumption. - assert hasattr(text_config, "num_attention_heads") + if text_config is not config and not hasattr(text_config, "num_attention_heads"): + raise ValueError( + "The text_config extracted from the model config does not have " + "`num_attention_heads` attribute. This indicates a mismatch " + "between the model config and vLLM's expectations. Please " + "ensure that the model config is compatible with vLLM." + ) return text_config @@ -930,6 +986,13 @@ def try_get_generation_config( revision: str | None = None, config_format: str | ConfigFormat = "auto", ) -> GenerationConfig | None: + # GGUF files don't have generation_config.json - their config is embedded + # in the file header. Skip all filesystem lookups to avoid re-reading the + # memory-mapped file, which can hang in multi-process scenarios when the + # EngineCore process already has the file mapped. + if is_gguf(model): + return None + try: return GenerationConfig.from_pretrained( model, diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 109f2b6986514..54fe1b8d7b523 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -10,48 +10,52 @@ Model configs may be defined in this directory for the following reasons: deepseek-ai/DeepSeek-V3.2-Exp. """ -from transformers import DeepseekV3Config +from __future__ import annotations -from vllm.transformers_utils.configs.afmoe import AfmoeConfig -from vllm.transformers_utils.configs.chatglm import ChatGLMConfig -from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config -from vllm.transformers_utils.configs.dotsocr import DotsOCRConfig -from vllm.transformers_utils.configs.eagle import EAGLEConfig +import importlib -# RWConfig is for the original tiiuae/falcon-40b(-instruct) and -# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the -# `FalconConfig` class from the official HuggingFace transformers library. -from vllm.transformers_utils.configs.falcon import RWConfig -from vllm.transformers_utils.configs.flex_olmo import FlexOlmoConfig -from vllm.transformers_utils.configs.hunyuan_vl import ( - HunYuanVLConfig, - HunYuanVLTextConfig, - HunYuanVLVisionConfig, -) -from vllm.transformers_utils.configs.jais import JAISConfig -from vllm.transformers_utils.configs.kimi_linear import KimiLinearConfig -from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig -from vllm.transformers_utils.configs.lfm2_moe import Lfm2MoeConfig -from vllm.transformers_utils.configs.medusa import MedusaConfig -from vllm.transformers_utils.configs.midashenglm import MiDashengLMConfig -from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig -from vllm.transformers_utils.configs.moonvit import MoonViTConfig -from vllm.transformers_utils.configs.nemotron import NemotronConfig -from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig -from vllm.transformers_utils.configs.olmo3 import Olmo3Config -from vllm.transformers_utils.configs.ovis import OvisConfig -from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig -from vllm.transformers_utils.configs.radio import RadioConfig -from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig -from vllm.transformers_utils.configs.step3_vl import ( - Step3TextConfig, - Step3VisionEncoderConfig, - Step3VLConfig, -) -from vllm.transformers_utils.configs.ultravox import UltravoxConfig +_CLASS_TO_MODULE: dict[str, str] = { + "AfmoeConfig": "vllm.transformers_utils.configs.afmoe", + "BagelConfig": "vllm.transformers_utils.configs.bagel", + "ChatGLMConfig": "vllm.transformers_utils.configs.chatglm", + "DeepseekVLV2Config": "vllm.transformers_utils.configs.deepseek_vl2", + "DotsOCRConfig": "vllm.transformers_utils.configs.dotsocr", + "EAGLEConfig": "vllm.transformers_utils.configs.eagle", + "FlexOlmoConfig": "vllm.transformers_utils.configs.flex_olmo", + "HunYuanVLConfig": "vllm.transformers_utils.configs.hunyuan_vl", + "HunYuanVLTextConfig": "vllm.transformers_utils.configs.hunyuan_vl", + "HunYuanVLVisionConfig": "vllm.transformers_utils.configs.hunyuan_vl", + # RWConfig is for the original tiiuae/falcon-40b(-instruct) and + # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the + # `FalconConfig` class from the official HuggingFace transformers library. + "RWConfig": "vllm.transformers_utils.configs.falcon", + "JAISConfig": "vllm.transformers_utils.configs.jais", + "Lfm2MoeConfig": "vllm.transformers_utils.configs.lfm2_moe", + "MedusaConfig": "vllm.transformers_utils.configs.medusa", + "MiDashengLMConfig": "vllm.transformers_utils.configs.midashenglm", + "MLPSpeculatorConfig": "vllm.transformers_utils.configs.mlp_speculator", + "MoonViTConfig": "vllm.transformers_utils.configs.moonvit", + "KimiLinearConfig": "vllm.transformers_utils.configs.kimi_linear", + "KimiVLConfig": "vllm.transformers_utils.configs.kimi_vl", + "NemotronConfig": "vllm.transformers_utils.configs.nemotron", + "NemotronHConfig": "vllm.transformers_utils.configs.nemotron_h", + "Olmo3Config": "vllm.transformers_utils.configs.olmo3", + "OvisConfig": "vllm.transformers_utils.configs.ovis", + "RadioConfig": "vllm.transformers_utils.configs.radio", + "SpeculatorsConfig": "vllm.transformers_utils.configs.speculators.base", + "UltravoxConfig": "vllm.transformers_utils.configs.ultravox", + "Step3VLConfig": "vllm.transformers_utils.configs.step3_vl", + "Step3VisionEncoderConfig": "vllm.transformers_utils.configs.step3_vl", + "Step3TextConfig": "vllm.transformers_utils.configs.step3_vl", + "Qwen3NextConfig": "vllm.transformers_utils.configs.qwen3_next", + "Tarsier2Config": "vllm.transformers_utils.configs.tarsier2", + # Special case: DeepseekV3Config is from HuggingFace Transformers + "DeepseekV3Config": "transformers", +} __all__ = [ "AfmoeConfig", + "BagelConfig", "ChatGLMConfig", "DeepseekVLV2Config", "DeepseekV3Config", @@ -81,4 +85,18 @@ __all__ = [ "Step3VisionEncoderConfig", "Step3TextConfig", "Qwen3NextConfig", + "Tarsier2Config", ] + + +def __getattr__(name: str): + if name in _CLASS_TO_MODULE: + module_name = _CLASS_TO_MODULE[name] + module = importlib.import_module(module_name) + return getattr(module, name) + + raise AttributeError(f"module 'configs' has no attribute '{name}'") + + +def __dir__(): + return sorted(list(__all__)) diff --git a/vllm/transformers_utils/configs/bagel.py b/vllm/transformers_utils/configs/bagel.py new file mode 100644 index 0000000000000..53347ef452138 --- /dev/null +++ b/vllm/transformers_utils/configs/bagel.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from transformers import PretrainedConfig, SiglipVisionConfig +from transformers.models.qwen2 import Qwen2Config + + +class BagelConfig(PretrainedConfig): + """Configuration class for BAGEL model.""" + + model_type = "bagel" + + def __init__( + self, + visual_gen: bool = True, + visual_und: bool = True, + llm_config: dict | Qwen2Config | None = None, + vit_config: dict | SiglipVisionConfig | None = None, + vae_config: dict | None = None, + latent_patch_size: int = 2, + max_latent_size: int = 32, + vit_max_num_patch_per_side: int = 70, + connector_act: str = "gelu_pytorch_tanh", + interpolate_pos: bool = False, + timestep_shift: float = 1.0, + **kwargs, + ): + super().__init__(**kwargs) + self.visual_gen = visual_gen + self.visual_und = visual_und + + # Convert dict configs to proper config objects + if isinstance(llm_config, dict): + self.llm_config = Qwen2Config(**llm_config) + else: + self.llm_config = llm_config or Qwen2Config() + + if isinstance(vit_config, dict): + self.vit_config = SiglipVisionConfig(**vit_config) + else: + self.vit_config = vit_config or SiglipVisionConfig() + + self.vae_config = vae_config or {"z_channels": 16, "downsample": 8} + self.latent_patch_size = latent_patch_size + self.max_latent_size = max_latent_size + self.vit_max_num_patch_per_side = vit_max_num_patch_per_side + self.connector_act = connector_act + self.interpolate_pos = interpolate_pos + self.timestep_shift = timestep_shift + + @property + def hidden_size(self) -> int: + """Return the hidden size of the language model.""" + return self.llm_config.hidden_size diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py index f5dc9ddfbc575..ce428e567c844 100644 --- a/vllm/transformers_utils/configs/eagle.py +++ b/vllm/transformers_utils/configs/eagle.py @@ -82,3 +82,9 @@ class EAGLEConfig(PretrainedConfig): pretrained_model_name_or_path, **kwargs ) return cls.from_dict(config_dict, **kwargs) + + def to_json_string(self, use_diff: bool = True) -> str: + # we override use_diff to False as initializing + # EAGLEConfig with default arguments is not supported + del use_diff + return super().to_json_string(use_diff=False) diff --git a/vllm/transformers_utils/configs/mistral.py b/vllm/transformers_utils/configs/mistral.py index 966737aad0867..d59169d95f0c9 100644 --- a/vllm/transformers_utils/configs/mistral.py +++ b/vllm/transformers_utils/configs/mistral.py @@ -18,9 +18,31 @@ def adapt_config_dict( if bool(config_dict.get("quantization")): config_dict = _remap_mistral_quantization_args(config_dict) + is_moe = bool(config_dict.get("moe")) + is_mistral_large_3 = ( + is_moe and (config_dict["moe"].get("num_shared_experts") or 0) > 0 + ) if config_dict.get("model_type") == "mamba": config_dict["architectures"] = ["Mamba2ForCausalLM"] - elif bool(config_dict.get("moe")): + elif is_moe and is_mistral_large_3: + config_dict = _remap_moe_args(config_dict) + config_dict["model_type"] = "deepseek_v3" + config_dict["architectures"] = ["MistralLarge3ForCausalLM"] + + assert "llama_4_scaling" in config_dict, ( + "MistralLarge3 expect llama4 scaling config." + ) + llama_4_scaling_config_keys = ["original_max_position_embeddings", "beta"] + assert all( + [ + key in config_dict["llama_4_scaling"] + for key in llama_4_scaling_config_keys + ] + ), ( + "llama_4_scaling config should define the keys: " + f"{','.join(llama_4_scaling_config_keys)}" + ) + elif is_moe: config_dict["architectures"] = ["MixtralForCausalLM"] else: config_dict["architectures"] = ["MistralForCausalLM"] @@ -140,17 +162,20 @@ def _remap_general_mistral_args(config: dict) -> dict: def _remap_mistral_quantization_args(config: dict) -> dict: - quantization = config.get("quantization", {}) - if quantization.get("qformat_weight") == "fp8_e4m3": - # This maps to the FP8 static per-tensor quantization scheme - quantization_config = {"quant_method": "fp8", "activation_scheme": "static"} - elif quantization.get("quant_method") == "compressed-tensors": - # Pass through the quantization config to compressed-tensors - quantization_config = quantization - else: - raise ValueError(f"Found unknown quantization='{quantization}' in config") - - config["quantization_config"] = quantization_config + if config.get("quantization"): + quantization = config.pop("quantization", {}) + if quantization.get("qformat_weight") == "fp8_e4m3": + qscheme_act = quantization.get("qscheme_act") + assert qscheme_act in ("NO_SCALES", "TENSOR", None), ( + "Only NO_SCALES and TENSOR (default) are supported for qscheme_act" + ) + is_dynamic = qscheme_act == "NO_SCALES" + config["quantization_config"] = { + "quant_method": "fp8", + "activation_scheme": "dynamic" if is_dynamic else "static", + } + else: + raise ValueError(f"Found unknown quantization='{quantization}' in config") return config @@ -183,3 +208,28 @@ def _remap_mistral_audio_args(config: dict) -> dict: if quant_config: config["quantization_config"] = quant_config return config + + +def _remap_moe_args(config: dict) -> dict: + moe_config_map = { + "route_every_n": "moe_layer_freq", + "first_k_dense_replace": "first_k_dense_replace", + "num_experts_per_tok": "num_experts_per_tok", + "num_experts": "n_routed_experts", + "expert_hidden_dim": "moe_intermediate_size", + "routed_scale": "routed_scaling_factor", + "num_shared_experts": "n_shared_experts", + "num_expert_groups": "n_group", + "num_expert_groups_per_tok": "topk_group", + } + moe_config = config.get("moe", {}) + for old_name, new_name in moe_config_map.items(): + if old_name in moe_config: + value = moe_config.pop(old_name) + config[new_name] = value + + config["topk_method"] = None + config["norm_topk_prob"] = True + config["scoring_func"] = "softmax" + + return config diff --git a/vllm/transformers_utils/configs/nemotron.py b/vllm/transformers_utils/configs/nemotron.py index d112c71d7d20b..62f52703029b7 100644 --- a/vllm/transformers_utils/configs/nemotron.py +++ b/vllm/transformers_utils/configs/nemotron.py @@ -89,9 +89,14 @@ class NemotronConfig(PretrainedConfig): tie_word_embeddings (`bool`, *optional*, defaults to `False`): Whether to tie weight embeddings rope_parameters (`dict`, *optional*): - The parameters of the RoPE embeddings. - partial_rotary_factor (`float`, *optional*, defaults to 0.5): - Percentage of the query and keys which will have rotary embedding. + The parameters of the RoPE embeddings. Expected contents: + `rope_theta` (`float`): The base period of the RoPE embeddings. + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', + 'dynamic', 'yarn', 'longrope', 'llama3'], with 'default' being the + original RoPE implementation. + `partial_rotary_factor` (`float`, *optional*, defaults to 0.5): + Percentage of the query and keys which will have rotary embedding. attention_bias (`bool`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. @@ -133,7 +138,6 @@ class NemotronConfig(PretrainedConfig): eos_token_id=3, tie_word_embeddings=False, rope_parameters=None, - partial_rotary_factor=0.5, attention_bias=False, attention_dropout=0.0, mlp_bias=False, @@ -165,14 +169,16 @@ class NemotronConfig(PretrainedConfig): rope_theta = kwargs.pop("rope_theta", 10000.0) if "rope_theta" not in rope_parameters: rope_parameters["rope_theta"] = rope_theta - self.rope_parameters = rope_parameters # for backward compatibility partial_rotary_factor = ( kwargs.get("rope_percent") or kwargs.get("rope_percentage") - or partial_rotary_factor + or kwargs.get("partial_rotary_factor") + or 0.5 ) - self.partial_rotary_factor = partial_rotary_factor + if "partial_rotary_factor" not in rope_parameters: + rope_parameters["partial_rotary_factor"] = partial_rotary_factor + self.rope_parameters = rope_parameters self._rope_parameters_validation() self.attention_bias = attention_bias self.attention_dropout = attention_dropout diff --git a/vllm/transformers_utils/configs/nemotron_h.py b/vllm/transformers_utils/configs/nemotron_h.py index 68c40002098c8..86c117fd9d59f 100644 --- a/vllm/transformers_utils/configs/nemotron_h.py +++ b/vllm/transformers_utils/configs/nemotron_h.py @@ -189,6 +189,7 @@ class NemotronHConfig(PretrainedConfig): n_shared_experts=1, moe_intermediate_size=7688, moe_shared_expert_intermediate_size=7688, + moe_latent_size=None, num_experts_per_tok=2, routed_scaling_factor=1.0, n_group=1, @@ -254,6 +255,7 @@ class NemotronHConfig(PretrainedConfig): self.n_shared_experts = n_shared_experts self.moe_intermediate_size = moe_intermediate_size self.moe_shared_expert_intermediate_size = moe_shared_expert_intermediate_size # noqa: E501 + self.moe_latent_size = moe_latent_size self.num_experts_per_tok = num_experts_per_tok self.routed_scaling_factor = routed_scaling_factor self.n_group = n_group diff --git a/vllm/transformers_utils/configs/qwen3_next.py b/vllm/transformers_utils/configs/qwen3_next.py index fd36b49245f56..8230a18343c5e 100644 --- a/vllm/transformers_utils/configs/qwen3_next.py +++ b/vllm/transformers_utils/configs/qwen3_next.py @@ -103,8 +103,8 @@ class Qwen3NextConfig(PretrainedConfig): Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE `high_freq_factor` (`float`, *optional*): Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE - partial_rotary_factor (`float`, *optional*, defaults to 0.25): - Percentage of the query and keys which will have rotary embedding. + `partial_rotary_factor` (`float`, *optional*, defaults to 0.25): + Percentage of the query and keys which will have rotary embedding. attention_bias (`bool`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): @@ -198,7 +198,6 @@ class Qwen3NextConfig(PretrainedConfig): use_cache=True, tie_word_embeddings=False, rope_parameters=None, - partial_rotary_factor=0.25, attention_bias=False, attention_dropout=0.0, head_dim=256, @@ -239,6 +238,9 @@ class Qwen3NextConfig(PretrainedConfig): rope_theta = kwargs.pop("rope_theta", 10000.0) if "rope_theta" not in rope_parameters: rope_parameters["rope_theta"] = rope_theta + partial_rotary_factor = kwargs.pop("partial_rotary_factor", 0.25) + if "partial_rotary_factor" not in rope_parameters: + rope_parameters["partial_rotary_factor"] = partial_rotary_factor self.rope_parameters = rope_parameters self.partial_rotary_factor = partial_rotary_factor self.attention_bias = attention_bias diff --git a/vllm/transformers_utils/configs/tarsier2.py b/vllm/transformers_utils/configs/tarsier2.py new file mode 100644 index 0000000000000..12ebb4b7f602d --- /dev/null +++ b/vllm/transformers_utils/configs/tarsier2.py @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from transformers import Qwen2VLConfig + + +class Tarsier2Config(Qwen2VLConfig): + """ + Tarsier2's config.json is written such that AutoConfig.from_pretrained will create + a deeply nested config consisting of: + + - LlavaConfig + - Qwen2VLConfig + - Qwen2VLTextConfig + - Qwen2VLVisionConfig + - Qwen2VLConfig + - Qwen2VLTextConfig + - Qwen2VLVisionConfig + + When it should really just be a single Qwen2VLConfig. + + This class is a hack to stop AutoConfig from creating the nested config structure. + """ + + model_type = "tarsier2" diff --git a/vllm/transformers_utils/configs/ultravox.py b/vllm/transformers_utils/configs/ultravox.py index fc0360a9ecb4e..395b3130d40af 100644 --- a/vllm/transformers_utils/configs/ultravox.py +++ b/vllm/transformers_utils/configs/ultravox.py @@ -61,6 +61,7 @@ class UltravoxConfig(transformers.PretrainedConfig): norm_init: float = 0.4, projector_act: str = "swiglu", projector_ln_mid: bool = False, + num_projector_layers: int = 0, **kwargs, ): self.ignore_index = ignore_index @@ -71,6 +72,7 @@ class UltravoxConfig(transformers.PretrainedConfig): self.norm_init = norm_init self.projector_act = projector_act self.projector_ln_mid = projector_ln_mid + self.num_projector_layers = num_projector_layers # N.B. May set the wrapped_model_config below. self.text_model_id = text_model_id diff --git a/vllm/transformers_utils/gguf_utils.py b/vllm/transformers_utils/gguf_utils.py index cb1fc2d092e01..f3fd43c6ace51 100644 --- a/vllm/transformers_utils/gguf_utils.py +++ b/vllm/transformers_utils/gguf_utils.py @@ -2,10 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """GGUF utility functions.""" +from functools import cache +from os import PathLike from pathlib import Path import gguf +import regex as re from gguf.constants import Keys, VisionProjectorType +from gguf.quants import GGMLQuantizationType from transformers import Gemma3Config, PretrainedConfig, SiglipVisionConfig from vllm.logger import init_logger @@ -15,6 +19,73 @@ from .repo_utils import list_filtered_repo_files logger = init_logger(__name__) +@cache +def check_gguf_file(model: str | PathLike) -> bool: + """Check if the file is a GGUF model.""" + model = Path(model) + if not model.is_file(): + return False + elif model.suffix == ".gguf": + return True + + try: + with model.open("rb") as f: + header = f.read(4) + + return header == b"GGUF" + except Exception as e: + logger.debug("Error reading file %s: %s", model, e) + return False + + +@cache +def is_remote_gguf(model: str | Path) -> bool: + """Check if the model is a remote GGUF model.""" + pattern = r"^[a-zA-Z0-9][a-zA-Z0-9._-]*/[a-zA-Z0-9][a-zA-Z0-9._-]*:[A-Za-z0-9_+-]+$" + model = str(model) + if re.fullmatch(pattern, model): + _, quant_type = model.rsplit(":", 1) + return is_valid_gguf_quant_type(quant_type) + return False + + +def is_valid_gguf_quant_type(gguf_quant_type: str) -> bool: + """Check if the quant type is a valid GGUF quant type.""" + return getattr(GGMLQuantizationType, gguf_quant_type, None) is not None + + +def split_remote_gguf(model: str | Path) -> tuple[str, str]: + """Split the model into repo_id and quant type.""" + model = str(model) + if is_remote_gguf(model): + parts = model.rsplit(":", 1) + return (parts[0], parts[1]) + raise ValueError( + f"Wrong GGUF model or invalid GGUF quant type: {model}.\n" + "- It should be in repo_id:quant_type format.\n" + f"- Valid GGMLQuantizationType values: {GGMLQuantizationType._member_names_}", + ) + + +def is_gguf(model: str | Path) -> bool: + """Check if the model is a GGUF model. + + Args: + model: Model name, path, or Path object to check. + + Returns: + True if the model is a GGUF model, False otherwise. + """ + model = str(model) + + # Check if it's a local GGUF file + if check_gguf_file(model): + return True + + # Check if it's a remote GGUF model (repo_id:quant_type format) + return is_remote_gguf(model) + + def detect_gguf_multimodal(model: str) -> Path | None: """Check if GGUF model has multimodal projector file. diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py index 63cdf63370342..e9864b0c1531d 100644 --- a/vllm/transformers_utils/processor.py +++ b/vllm/transformers_utils/processor.py @@ -18,7 +18,8 @@ from transformers.processing_utils import ProcessorMixin from transformers.video_processing_utils import BaseVideoProcessor from typing_extensions import TypeVar -from vllm.transformers_utils.utils import convert_model_repo_to_path, is_gguf +from vllm.transformers_utils.gguf_utils import is_gguf +from vllm.transformers_utils.utils import convert_model_repo_to_path from vllm.utils.func_utils import get_allowed_kwarg_only_overrides if TYPE_CHECKING: diff --git a/vllm/transformers_utils/processors/__init__.py b/vllm/transformers_utils/processors/__init__.py index b49fdbe9ce776..af25dbe4ccdfe 100644 --- a/vllm/transformers_utils/processors/__init__.py +++ b/vllm/transformers_utils/processors/__init__.py @@ -8,6 +8,7 @@ reasons: - There is a need to override the existing processor to support vLLM. """ +from vllm.transformers_utils.processors.bagel import BagelProcessor from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor from vllm.transformers_utils.processors.hunyuan_vl_image import HunYuanVLImageProcessor @@ -15,6 +16,7 @@ from vllm.transformers_utils.processors.ovis import OvisProcessor from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor __all__ = [ + "BagelProcessor", "DeepseekVLV2Processor", "HunYuanVLProcessor", "HunYuanVLImageProcessor", diff --git a/vllm/transformers_utils/processors/bagel.py b/vllm/transformers_utils/processors/bagel.py new file mode 100644 index 0000000000000..850e64f2fad1e --- /dev/null +++ b/vllm/transformers_utils/processors/bagel.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 Bytedance Ltd. and/or its affiliates. +"""BAGEL processor for image and text inputs.""" + +from transformers import AutoProcessor +from transformers.image_utils import ImageInput +from transformers.processing_utils import ProcessorMixin +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput + + +class BagelProcessor(ProcessorMixin): + """ + Constructs a BAGEL processor which wraps a + SigLIP image processor and a Qwen2 tokenizer. + """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "SiglipImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __call__( + self, + text: TextInput + | PreTokenizedInput + | list[TextInput] + | list[PreTokenizedInput] = None, + images: ImageInput = None, + **kwargs, + ): + """ + Main method to prepare for the model one or several sequences(s) and image(s). + """ + if images is not None: + # Process images with the image processor + # Ensure return_tensors is set to "pt" for PyTorch tensors + image_kwargs = {**kwargs} + if "return_tensors" not in image_kwargs: + image_kwargs["return_tensors"] = "pt" + pixel_values = self.image_processor(images, **image_kwargs) + else: + pixel_values = None + + text_inputs = self.tokenizer(text, **kwargs) if text is not None else None + + if pixel_values is not None and text_inputs is not None: + text_inputs["pixel_values"] = pixel_values["pixel_values"] + return text_inputs + elif pixel_values is not None: + return pixel_values + else: + return text_inputs + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's batch_decode. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's decode. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +AutoProcessor.register("BagelProcessor", BagelProcessor) diff --git a/vllm/transformers_utils/processors/hunyuan_vl.py b/vllm/transformers_utils/processors/hunyuan_vl.py index 615a8bff85912..f32ce115c866d 100644 --- a/vllm/transformers_utils/processors/hunyuan_vl.py +++ b/vllm/transformers_utils/processors/hunyuan_vl.py @@ -123,7 +123,7 @@ class HunYuanVLProcessor(ProcessorMixin): attention_mask = input_ids.ne(self.pad_id) text_inputs["attention_mask"] = attention_mask - text_inputs["imgs_pos"] = [self.get_imgs_pos(input_ids)] + text_inputs["imgs_pos"] = [self.get_imgs_pos(e) for e in input_ids] # image_inputs["imgs"] = [[image_inputs["pixel_values"]]] return_tensors = kwargs.pop("return_tensors", None) diff --git a/vllm/transformers_utils/runai_utils.py b/vllm/transformers_utils/runai_utils.py index eac4294bb59cd..041056720a96b 100644 --- a/vllm/transformers_utils/runai_utils.py +++ b/vllm/transformers_utils/runai_utils.py @@ -18,9 +18,7 @@ SUPPORTED_SCHEMES = ["s3://", "gs://"] try: from runai_model_streamer import list_safetensors as runai_list_safetensors from runai_model_streamer import pull_files as runai_pull_files -except (ImportError, OSError): - # see https://github.com/run-ai/runai-model-streamer/issues/26 - # OSError will be raised on arm64 platform +except ImportError: runai_model_streamer = PlaceholderModule("runai_model_streamer") # type: ignore[assignment] runai_pull_files = runai_model_streamer.placeholder_attr("pull_files") runai_list_safetensors = runai_model_streamer.placeholder_attr("list_safetensors") diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 0911848c02e14..90af573535d3b 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -2,17 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import warnings -from functools import lru_cache -from typing import TYPE_CHECKING, Any +from typing import Any -from typing_extensions import assert_never +from typing_extensions import deprecated from vllm.logger import init_logger -from vllm.tokenizers import TokenizerLike, get_tokenizer - -if TYPE_CHECKING: - from vllm.config import ModelConfig - +from vllm.tokenizers import TokenizerLike logger = init_logger(__name__) @@ -22,28 +17,65 @@ def __getattr__(name: str): warnings.warn( "`vllm.transformers_utils.tokenizer.AnyTokenizer` has been moved to " "`vllm.tokenizers.TokenizerLike`. " - "The old name will be removed in v0.13.", + "The old name will be removed in v0.14.", DeprecationWarning, stacklevel=2, ) return TokenizerLike - if name == "get_cached_tokenizer": - from vllm.tokenizers.hf import get_cached_tokenizer + if name == "get_tokenizer": + from vllm.tokenizers import get_tokenizer warnings.warn( - "`vllm.transformers_utils.tokenizer.get_cached_tokenizer` " - "has been moved to `vllm.tokenizers.hf.get_cached_tokenizer`. " - "The old name will be removed in v0.13.", + "`vllm.transformers_utils.tokenizer.get_tokenizer` " + "has been moved to `vllm.tokenizers.get_tokenizer`. " + "The old name will be removed in v0.14.", DeprecationWarning, stacklevel=2, ) - return get_cached_tokenizer + return get_tokenizer + if name == "cached_get_tokenizer": + from vllm.tokenizers import cached_get_tokenizer + + warnings.warn( + "`vllm.transformers_utils.tokenizer.cached_get_tokenizer` " + "has been moved to `vllm.tokenizers.cached_get_tokenizer`. " + "The old name will be removed in v0.14.", + DeprecationWarning, + stacklevel=2, + ) + + return cached_get_tokenizer + if name == "cached_tokenizer_from_config": + from vllm.tokenizers import cached_tokenizer_from_config + + warnings.warn( + "`vllm.transformers_utils.tokenizer.cached_tokenizer_from_config` " + "has been moved to `vllm.tokenizers.cached_tokenizer_from_config`. " + "The old name will be removed in v0.14.", + DeprecationWarning, + stacklevel=2, + ) + + return cached_tokenizer_from_config + if name == "init_tokenizer_from_configs": + from vllm.tokenizers import cached_tokenizer_from_config + + warnings.warn( + "`vllm.transformers_utils.tokenizer.init_tokenizer_from_configs` " + "has been moved to `vllm.tokenizers.cached_tokenizer_from_config`. " + "The old name will be removed in v0.14.", + DeprecationWarning, + stacklevel=2, + ) + + return cached_tokenizer_from_config raise AttributeError(f"module {__name__!r} has no attribute {name!r}") +@deprecated("Will be removed in v0.14. Please use `tokenizer.decode()` instead.") def decode_tokens( tokenizer: TokenizerLike, token_ids: list[int], @@ -65,6 +97,7 @@ def decode_tokens( return tokenizer.decode(token_ids, **kw_args) +@deprecated("Will be removed in v0.14. Please use `tokenizer.encode()` instead.") def encode_tokens( tokenizer: TokenizerLike, text: str, @@ -92,37 +125,3 @@ def encode_tokens( kw_args["add_special_tokens"] = add_special_tokens return tokenizer.encode(text, **kw_args) - - -cached_get_tokenizer = lru_cache(get_tokenizer) - - -def cached_tokenizer_from_config( - model_config: "ModelConfig", - **kwargs: Any, -): - return cached_get_tokenizer( - model_config.tokenizer, - tokenizer_mode=model_config.tokenizer_mode, - revision=model_config.tokenizer_revision, - trust_remote_code=model_config.trust_remote_code, - **kwargs, - ) - - -def init_tokenizer_from_configs(model_config: "ModelConfig"): - runner_type = model_config.runner_type - if runner_type == "generate" or runner_type == "draft": - truncation_side = "left" - elif runner_type == "pooling": - truncation_side = "right" - else: - assert_never(runner_type) - - return get_tokenizer( - model_config.tokenizer, - tokenizer_mode=model_config.tokenizer_mode, - trust_remote_code=model_config.trust_remote_code, - revision=model_config.tokenizer_revision, - truncation_side=truncation_side, - ) diff --git a/vllm/transformers_utils/tokenizer_base.py b/vllm/transformers_utils/tokenizer_base.py index 78fb6edc8b9ed..3dfd4b4f2f6c1 100644 --- a/vllm/transformers_utils/tokenizer_base.py +++ b/vllm/transformers_utils/tokenizer_base.py @@ -11,7 +11,7 @@ def __getattr__(name: str): warnings.warn( "`vllm.transformers_utils.tokenizer_base.TokenizerBase` has been " "moved to `vllm.tokenizers.TokenizerLike`. " - "The old name will be removed in v0.13.", + "The old name will be removed in v0.14.", DeprecationWarning, stacklevel=2, ) @@ -23,7 +23,7 @@ def __getattr__(name: str): warnings.warn( "`vllm.transformers_utils.tokenizer_base.TokenizerRegistry` has been " "moved to `vllm.tokenizers.TokenizerRegistry`. " - "The old name will be removed in v0.13.", + "The old name will be removed in v0.14.", DeprecationWarning, stacklevel=2, ) diff --git a/vllm/transformers_utils/utils.py b/vllm/transformers_utils/utils.py index 45a873c9f7001..96f292f4c949e 100644 --- a/vllm/transformers_utils/utils.py +++ b/vllm/transformers_utils/utils.py @@ -9,8 +9,6 @@ from os import PathLike from pathlib import Path from typing import Any -from gguf import GGMLQuantizationType - import vllm.envs as envs from vllm.logger import init_logger @@ -29,76 +27,6 @@ def is_cloud_storage(model_or_path: str) -> bool: return is_s3(model_or_path) or is_gcs(model_or_path) -@cache -def check_gguf_file(model: str | PathLike) -> bool: - """Check if the file is a GGUF model.""" - model = Path(model) - if not model.is_file(): - return False - elif model.suffix == ".gguf": - return True - - try: - with model.open("rb") as f: - header = f.read(4) - - return header == b"GGUF" - except Exception as e: - logger.debug("Error reading file %s: %s", model, e) - return False - - -@cache -def is_remote_gguf(model: str | Path) -> bool: - """Check if the model is a remote GGUF model.""" - model = str(model) - return ( - (not is_cloud_storage(model)) - and (not model.startswith(("http://", "https://"))) - and ("/" in model and ":" in model) - and is_valid_gguf_quant_type(model.rsplit(":", 1)[1]) - ) - - -def is_valid_gguf_quant_type(gguf_quant_type: str) -> bool: - """Check if the quant type is a valid GGUF quant type.""" - return getattr(GGMLQuantizationType, gguf_quant_type, None) is not None - - -def split_remote_gguf(model: str | Path) -> tuple[str, str]: - """Split the model into repo_id and quant type.""" - model = str(model) - if is_remote_gguf(model): - parts = model.rsplit(":", 1) - return (parts[0], parts[1]) - raise ValueError( - "Wrong GGUF model or invalid GGUF quant type: %s.\n" - "- It should be in repo_id:quant_type format.\n" - "- Valid GGMLQuantizationType values: %s", - model, - GGMLQuantizationType._member_names_, - ) - - -def is_gguf(model: str | Path) -> bool: - """Check if the model is a GGUF model. - - Args: - model: Model name, path, or Path object to check. - - Returns: - True if the model is a GGUF model, False otherwise. - """ - model = str(model) - - # Check if it's a local GGUF file - if check_gguf_file(model): - return True - - # Check if it's a remote GGUF model (repo_id:quant_type format) - return is_remote_gguf(model) - - def modelscope_list_repo_files( repo_id: str, revision: str | None = None, diff --git a/vllm/utils/argparse_utils.py b/vllm/utils/argparse_utils.py index 555fcfea491e2..87ee6f54c0c9b 100644 --- a/vllm/utils/argparse_utils.py +++ b/vllm/utils/argparse_utils.py @@ -244,9 +244,8 @@ class FlexibleArgumentParser(ArgumentParser): else: key = pattern.sub(repl, arg, count=1) processed_args.append(key) - elif arg.startswith("-O") and arg != "-O" and arg[2] != ".": + elif arg.startswith("-O") and arg != "-O": # allow -O flag to be used without space, e.g. -O3 or -Odecode - # -O.<...> handled later # also handle -O=<optimization_level> here optimization_level = arg[3:] if arg[2] == "=" else arg[2:] processed_args += ["--optimization-level", optimization_level] @@ -257,17 +256,6 @@ class FlexibleArgumentParser(ArgumentParser): ): # Convert -O <n> to --optimization-level <n> processed_args.append("--optimization-level") - elif arg.startswith("-O."): - # Handle -O.* dotted syntax - ALL dotted syntax is deprecated - logger.warning_once( - "The -O.* dotted syntax for --compilation-config is " - "deprecated and will be removed in v0.13.0 or v1.0.0" - ", whichever is earlier. Please use -cc.* instead. " - "Example: -cc.backend=eager instead of " - "-O.backend=eager." - ) - converted_arg = arg.replace("-O", "-cc", 1) - processed_args.append(converted_arg) else: processed_args.append(arg) diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index b25c1e3e1ece3..3d4f8449ad3b6 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -38,7 +38,7 @@ class DeepGemmQuantScaleFMT(Enum): return DeepGemmQuantScaleFMT.FLOAT32 return ( DeepGemmQuantScaleFMT.UE8M0 - if current_platform.is_device_capability(100) + if current_platform.is_device_capability_family(100) else DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0 ) @@ -50,7 +50,7 @@ def is_deep_gemm_supported() -> bool: """ is_supported_arch = current_platform.is_cuda() and ( current_platform.is_device_capability(90) - or current_platform.is_device_capability(100) + or current_platform.is_device_capability_family(100) ) return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch @@ -383,6 +383,7 @@ def should_use_deepgemm_for_fp8_linear( __all__ = [ "calc_diff", + "DeepGemmQuantScaleFMT", "fp8_gemm_nt", "m_grouped_fp8_gemm_nt_contiguous", "fp8_m_grouped_gemm_nt_masked", diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 9f9976d52b4ae..5019b771f4a14 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -264,24 +264,23 @@ def supports_trtllm_attention() -> bool: return False # Requires SM100 and NVIDIA artifactory to be accessible to download cubins - return current_platform.is_device_capability(100) and has_nvidia_artifactory() - - -@functools.cache -def _force_use_trtllm_attention(env_value: bool | None) -> bool | None: - """Cache the env value for VLLM_USE_TRTLLM_ATTENTION""" - if env_value is not None: - logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value) - return env_value + return ( + current_platform.is_device_capability_family(100) and has_nvidia_artifactory() + ) def force_use_trtllm_attention() -> bool | None: """ - Return `None` if VLLM_USE_TRTLLM_ATTENTION is not set, + This function should only be called during initialization stage when vllm config + is set. + Return `None` if --attention-config.use_trtllm_attention is not set, return `True` if TRTLLM attention is forced to be used, return `False` if TRTLLM attention is forced to be not used. """ - return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION) + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() + return vllm_config.attention_config.use_trtllm_attention def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool: @@ -301,13 +300,14 @@ def use_trtllm_attention( kv_cache_dtype: str, q_dtype: torch.dtype, is_prefill: bool, + # None means auto-detection, True means force on, False means force off + force_use_trtllm: bool | None = None, has_sinks: bool = False, has_spec: bool = False, ) -> bool: """Return `True` if TRTLLM attention is used.""" - force_use_trtllm = force_use_trtllm_attention() - # Environment variable is set to 0 - respect it + # CLI argument is set to 0 - respect it if force_use_trtllm is not None and not force_use_trtllm: return False @@ -324,7 +324,7 @@ def use_trtllm_attention( if force_use_trtllm: logger.warning_once( "TRTLLM attention is not supported on this platform, " - "but VLLM_USE_TRTLLM_ATTENTION is set to 1" + "but --attention-config.use_trtllm_attention is set to 1" ) return False @@ -333,7 +333,8 @@ def use_trtllm_attention( if force_use_trtllm: logger.warning_once( "TRTLLM attention is not supported for this combination of " - "query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1" + "query and key heads, but --attention-config.use_trtllm_attention is " + "set to 1" ) return False @@ -354,7 +355,7 @@ def use_trtllm_attention( return True if force_use_trtllm is None: - # Environment variable not set - use auto-detection + # CLI argument not set - use auto-detection if is_prefill: # Prefill auto-detection use_trtllm = kv_cache_dtype == "auto" @@ -367,8 +368,10 @@ def use_trtllm_attention( logger.warning_once("Using TRTLLM decode attention (auto-detected).") return use_trtllm - # Environment variable is set to 1 - respect it - logger.info_once("Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)") + # CLI argument is set to 1 - respect it + logger.info_once( + "Using TRTLLM attention (--attention-config.use_trtllm_attention is set to 1)" + ) return True @@ -500,12 +503,6 @@ def flashinfer_scaled_fp8_mm( return output -@functools.cache -def flashinfer_disable_q_quantization() -> bool: - """Cache result which only depends on the environment""" - return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION - - __all__ = [ "has_flashinfer", "flashinfer_trtllm_fp8_block_scale_moe", @@ -526,7 +523,6 @@ __all__ = [ "supports_trtllm_attention", "can_use_trtllm_attention", "use_trtllm_attention", - "flashinfer_disable_q_quantization", "flashinfer_scaled_fp4_mm", "flashinfer_scaled_fp8_mm", ] diff --git a/vllm/utils/hashing.py b/vllm/utils/hashing.py index edf1e9cb34e56..f01c6b074ffeb 100644 --- a/vllm/utils/hashing.py +++ b/vllm/utils/hashing.py @@ -11,6 +11,17 @@ from typing import Any import cbor2 +try: + # It is important that this remains an optional dependency. + # It would not be allowed in environments with strict security controls, + # so it's best not to have it installed when not in use. + import xxhash as _xxhash + + if not hasattr(_xxhash, "xxh3_128_digest"): + _xxhash = None +except ImportError: # pragma: no cover + _xxhash = None + def sha256(input: Any) -> bytes: """Hash any picklable Python object using SHA-256. @@ -47,6 +58,27 @@ def sha256_cbor(input: Any) -> bytes: return hashlib.sha256(input_bytes).digest() +def _xxhash_digest(input_bytes: bytes) -> bytes: + if _xxhash is None: + raise ModuleNotFoundError( + "xxhash is required for the 'xxhash' prefix caching hash algorithms. " + "Install it via `pip install xxhash`." + ) + return _xxhash.xxh3_128_digest(input_bytes) + + +def xxhash(input: Any) -> bytes: + """Hash picklable objects using xxHash.""" + input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) + return _xxhash_digest(input_bytes) + + +def xxhash_cbor(input: Any) -> bytes: + """Hash objects serialized with CBOR using xxHash.""" + input_bytes = cbor2.dumps(input, canonical=True) + return _xxhash_digest(input_bytes) + + def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]: """Get a hash function by name, or raise an error if the function is not found. @@ -60,6 +92,10 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]: return sha256 if hash_fn_name == "sha256_cbor": return sha256_cbor + if hash_fn_name == "xxhash": + return xxhash + if hash_fn_name == "xxhash_cbor": + return xxhash_cbor raise ValueError(f"Unsupported hash function: {hash_fn_name}") diff --git a/vllm/utils/nvtx_pytorch_hooks.py b/vllm/utils/nvtx_pytorch_hooks.py new file mode 100644 index 0000000000000..39e2a9a136e63 --- /dev/null +++ b/vllm/utils/nvtx_pytorch_hooks.py @@ -0,0 +1,286 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from contextlib import contextmanager + +import torch +import torch.cuda.nvtx as nvtx + + +def print_tensor(tensor_obj, prefix, tensor_list=None): + """Descends iterators that contains Tensors and prints the Tensor. + Recursive function that descends iterator type arguments until + it finds a Tensor object. + """ + if tensor_list is None: + tensor_list = [] + + if isinstance(tensor_obj, (list, tuple)): + for ten in tensor_obj: + tensor_list = print_tensor(ten, prefix, tensor_list) + elif isinstance(tensor_obj, torch.Tensor): + tensor_dims = list(tensor_obj.size()) + tensor_list.append(tensor_dims) + return tensor_list + + +def process_layer_params(module_obj): + """Extract the static parameters from LLM and VLM relevant layer types""" + param_info = {} + # Extract parameters for layers commonly used in LLMs and VLMs + if isinstance(module_obj, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)): + conv_params = {} + conv_params["in_chan"] = module_obj.in_channels + conv_params["out_chan"] = module_obj.out_channels + conv_params["filter_dim"] = module_obj.kernel_size + conv_params["stride"] = module_obj.stride + conv_params["padding"] = module_obj.padding + conv_params["dilation"] = module_obj.dilation + conv_params["transposed"] = module_obj.transposed + conv_params["output_padding"] = module_obj.output_padding + conv_params["groups"] = module_obj.groups + conv_params["padding_mode"] = module_obj.padding_mode + param_info = conv_params + elif isinstance( + module_obj, + ( + torch.nn.ConvTranspose1d, + torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d, + ), + ): + convtranspose_params = {} + convtranspose_params["in_chan"] = module_obj.in_channels + convtranspose_params["out_chan"] = module_obj.out_channels + convtranspose_params["filter_dim"] = module_obj.kernel_size + convtranspose_params["stride"] = module_obj.stride + convtranspose_params["padding"] = module_obj.padding + convtranspose_params["dilation"] = module_obj.dilation + convtranspose_params["transposed"] = module_obj.transposed + convtranspose_params["output_padding"] = module_obj.output_padding + convtranspose_params["groups"] = module_obj.groups + convtranspose_params["padding_mode"] = module_obj.padding_mode + param_info = convtranspose_params + elif isinstance( + module_obj, (torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d) + ): + + def _handle_int_or_tuple(parameter): + if isinstance(parameter, tuple): + return list(parameter) + elif isinstance(parameter, int): + return [parameter, parameter] + + pooling_params = {} + pooling_params["filter_dim"] = _handle_int_or_tuple(module_obj.kernel_size) + pooling_params["stride"] = _handle_int_or_tuple(module_obj.stride) + pooling_params["padding"] = _handle_int_or_tuple(module_obj.padding) + pooling_params["dilation"] = _handle_int_or_tuple(module_obj.dilation) + param_info = pooling_params + elif isinstance( + module_obj, (torch.nn.AvgPool1d, torch.nn.AvgPool2d, torch.nn.AvgPool3d) + ): + pooling_params = {} + pooling_params["filter_dim"] = [ + module_obj.kernel_size, + module_obj.kernel_size, + ] + pooling_params["stride"] = [module_obj.stride, module_obj.stride] + pooling_params["padding"] = [module_obj.padding, module_obj.padding] + pooling_params["ceil_mode"] = module_obj.ceil_mode + pooling_params["count_include_pad"] = module_obj.count_include_pad + param_info = pooling_params + elif isinstance( + module_obj, + ( + torch.nn.AdaptiveAvgPool1d, + torch.nn.AdaptiveAvgPool2d, + torch.nn.AdaptiveAvgPool3d, + ), + ): + pooling_params = {} + pooling_params["output_size"] = [ + module_obj.output_size, + module_obj.output_size, + ] + param_info = pooling_params + elif isinstance(module_obj, torch.nn.Linear): + param_info["in_features"] = module_obj.in_features + param_info["out_features"] = module_obj.out_features + elif isinstance( + module_obj, + (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d), + ): + param_info["num_features"] = module_obj.num_features + param_info["epsilon"] = module_obj.eps + param_info["momentum"] = module_obj.momentum + elif isinstance(module_obj, torch.nn.ReLU): + param_info["in_place"] = module_obj.inplace + elif isinstance(module_obj, torch.nn.Dropout): + param_info["p"] = module_obj.p + param_info["in_place"] = module_obj.inplace + elif isinstance(module_obj, torch.nn.Embedding): + param_info["num_embeddings"] = module_obj.num_embeddings + param_info["embedding_dim"] = module_obj.embedding_dim + elif isinstance( + module_obj, + ( + torch.nn.Upsample, + torch.nn.UpsamplingNearest2d, + torch.nn.UpsamplingBilinear2d, + ), + ): + param_info["scale_factor"] = module_obj.scale_factor + + return param_info + + +def construct_marker_dict_and_push( + module_name, module_obj, in_tensor, kwargs=None, out_tensor=None +): + marker_dict = {} + marker_dict["Module"] = module_name + + ## Get trainable parameters like weights and bias + module_params = module_obj.named_parameters(recurse=False) + for idx, (param_name, param_obj) in enumerate(module_params): + if idx == 0: + marker_dict["TrainableParams"] = {} + marker_dict["TrainableParams"][param_name] = list(param_obj.size()) + + in_tensor_list = print_tensor(in_tensor, "Input") + if in_tensor_list: + marker_dict["Inputs"] = in_tensor_list + + out_tensor_list = print_tensor(out_tensor, "Output") + if out_tensor_list: + marker_dict["Outputs"] = out_tensor_list + + ## Get Kwargs like input_ids and positions for the top module + if kwargs: + for key, value in kwargs.items(): + if isinstance(value, (torch.Tensor, list, tuple)): + tensor_list = print_tensor(value, key) + if tensor_list: + marker_dict[key] = tensor_list + + param_info = process_layer_params(module_obj) + if param_info: + marker_dict["StaticParams"] = param_info + nvtx.range_push("{}".format(marker_dict)) + + +class ResultHolder: + """Holder for storing results from within a context manager.""" + + result = None + + +@contextmanager +def layerwise_nvtx_marker_context(module_name, module_obj, in_tensor=None, kwargs=None): + """Context manager for NVTX markers that automatically pushes on enter + and pops on exit. + + Example: + with nvtx_marker_context("Module:MyModule", module, in_tensor=args, + kwargs=kwargs) as ctx: + ctx.result = module(*args, **kwargs) + return ctx.result + """ + holder = ResultHolder() + + # Push input marker + construct_marker_dict_and_push( + module_name, + module_obj, + in_tensor=in_tensor, + kwargs=kwargs, + ) + try: + yield holder + finally: + # Pop input marker + nvtx.range_pop() + # Push and pop output marker + output_name = module_name.replace("(input)", "(output)") + construct_marker_dict_and_push( + output_name, + module_obj, + in_tensor=None, + kwargs=None, + out_tensor=holder.result, + ) + nvtx.range_pop() + + +class PytHooks: + """This module contains all the code needed to enable forward hooks + in a pytorch network. + + To register the hooks for a given network, the user needs to instantiate + a PytHook object. Then call the register_hooks method. + + Example: + + my_hook = PytHook() + my_hook.register_hooks(my_network_model) + """ + + def __init__(self): + """Initialize module variables.""" + super().__init__() + self.module_to_name_map = {} + + def _process_layer_params(self, module_obj): + return process_layer_params(module_obj) + + def module_fwd_hook(self, module_obj, in_tensor, out_tensor): + """Callback function that ends the NVTX marker. + Records the module name and tensor information. + Called after the module executes the forward method. + """ + nvtx.range_pop() + module_name = self.module_to_name_map.get(module_obj, "unknown") + construct_marker_dict_and_push( + module_name, module_obj, in_tensor=None, kwargs=None, out_tensor=out_tensor + ) + nvtx.range_pop() + return + + def module_fwd_pre_hook(self, module_obj, in_tensor, kwargs): + """Creates an NVTX marker with the module name in it. + This function is called before the module executes. + """ + module_name = self.module_to_name_map.get(module_obj, "unknown") + construct_marker_dict_and_push( + module_name, module_obj, in_tensor=in_tensor, kwargs=kwargs, out_tensor=None + ) + return + + def register_hooks(self, network_model, module_prefix="top"): + """User level function that activates all the hooks. + The user needs to call this method from the network source code. + The code descends all the modules in the network and registers their + respective hooks. + """ + # Module types to skip (simple operations that don't need detailed profiling) + skip_types = ( + torch.nn.Identity, + torch.nn.Dropout, + torch.nn.Dropout1d, + torch.nn.Dropout2d, + torch.nn.Dropout3d, + ) + + for name, module in network_model.named_modules(prefix=module_prefix): + # Skip certain module types to reduce profiling overhead + if isinstance(module, skip_types): + continue + + module.register_forward_pre_hook(self.module_fwd_pre_hook, with_kwargs=True) + module.register_forward_hook(self.module_fwd_hook) + if module not in self.module_to_name_map: + self.module_to_name_map[module] = name + else: + raise ValueError("Module instance {} is not unique ".format(module)) + return diff --git a/vllm/utils/serial_utils.py b/vllm/utils/serial_utils.py index b89fa6ce4db66..07db5eaf74c8d 100644 --- a/vllm/utils/serial_utils.py +++ b/vllm/utils/serial_utils.py @@ -1,15 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import base64 +import io +import math import sys from dataclasses import dataclass -from typing import Literal +from typing import TYPE_CHECKING, Any, Literal import numpy as np import torch from typing_extensions import assert_never -from vllm import PoolingRequestOutput +if TYPE_CHECKING: + from vllm import PoolingRequestOutput +else: + PoolingRequestOutput = Any sys_byteorder = sys.byteorder @@ -26,6 +31,14 @@ EMBED_DTYPE_TO_TORCH_DTYPE = { "fp8_e5m2": torch.float8_e5m2, } +EMBED_DTYPE_TO_N_BYTES = { + "float32": 4, + "float16": 2, + "bfloat16": 2, + "fp8_e4m3": 1, + "fp8_e5m2": 1, +} + EMBED_DTYPE_TO_TORCH_DTYPE_VIEW = { "float32": torch.float32, @@ -49,7 +62,16 @@ ENDIANNESS = ["native", "big", "little"] EmbedDType = Literal["float32", "float16", "bfloat16", "fp8_e4m3", "fp8_e5m2"] Endianness = Literal["native", "big", "little"] -EncodingFormat = Literal["float", "base64", "bytes"] +EncodingFormat = Literal["float", "base64", "bytes", "bytes_only"] + + +def tensor2base64(x: torch.Tensor) -> str: + with io.BytesIO() as buf: + torch.save(x, buf) + buf.seek(0) + binary_data = buf.read() + + return base64.b64encode(binary_data).decode("utf-8") def tensor2binary( @@ -104,7 +126,7 @@ def encode_pooling_output( elif encoding_format == "base64": embedding_bytes = tensor2binary(output.outputs.data, embed_dtype, endianness) return base64.b64encode(embedding_bytes).decode("utf-8") - elif encoding_format == "bytes": + elif encoding_format == "bytes" or encoding_format == "bytes_only": return tensor2binary(output.outputs.data, embed_dtype, endianness) assert_never(encoding_format) @@ -119,6 +141,29 @@ class MetadataItem: shape: tuple[int, ...] +def build_metadata_items( + embed_dtype: EmbedDType, + endianness: Endianness, + shape: tuple[int, ...], + n_request: int, +): + n_bytes = EMBED_DTYPE_TO_N_BYTES[embed_dtype] + size = math.prod(shape) + items = [ + MetadataItem( + index=i, + embed_dtype=embed_dtype, + endianness=endianness, + start=i * size * n_bytes, + end=(i + 1) * size * n_bytes, + shape=shape, + ) + for i in range(n_request) + ] + + return items + + def encode_pooling_bytes( pooling_outputs: list[PoolingRequestOutput], embed_dtype: EmbedDType, diff --git a/vllm/utils/system_utils.py b/vllm/utils/system_utils.py index a4eb8f4d4fd7d..76cac59c18098 100644 --- a/vllm/utils/system_utils.py +++ b/vllm/utils/system_utils.py @@ -204,6 +204,10 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: def decorate_logs(process_name: str | None = None) -> None: """Decorate stdout/stderr with process name and PID prefix.""" + # Respect VLLM_CONFIGURE_LOGGING environment variable + if not envs.VLLM_CONFIGURE_LOGGING: + return + if process_name is None: process_name = get_mp_context().current_process().name diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index f5c49ac169f0c..c97efce312b56 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -28,6 +28,7 @@ else: STR_DTYPE_TO_TORCH_DTYPE = { "float32": torch.float32, "half": torch.half, + "float16": torch.float16, "bfloat16": torch.bfloat16, "float": torch.float, "fp8": torch.uint8, diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index fed7dcdf293bd..394d0c2f67136 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -21,7 +21,7 @@ from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, split_decodes_and_prefills, ) -from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec logger = init_logger(__name__) @@ -50,11 +50,13 @@ class CPUAttentionBackend(AttentionBackend): @classmethod def supports_attn_type(cls, attn_type: str) -> bool: - """CPU attention supports decoder and encoder-only attention.""" + """CPU attention supports decoder, + encoder-only and encoder-decoder attention.""" return attn_type in ( AttentionType.DECODER, AttentionType.ENCODER, AttentionType.ENCODER_ONLY, + AttentionType.ENCODER_DECODER, ) @staticmethod @@ -136,6 +138,7 @@ class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata] self.window_size = -1 self.block_size = vllm_config.cache_config.block_size self.isa = _get_attn_isa(self.dtype, self.block_size) + self.is_cross_attention = isinstance(kv_cache_spec, CrossAttentionSpec) def build( self, @@ -151,7 +154,7 @@ class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata] seq_lens = common_attn_metadata.seq_lens block_table_tensor = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping - causal = common_attn_metadata.causal + causal = False if self.is_cross_attention else common_attn_metadata.causal sdpa_start_loc = query_start_loc num_decode_tokens = 0 @@ -171,22 +174,19 @@ class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata] query_start_loc = query_start_loc[: num_decodes + 1] block_table_tensor = block_table_tensor[:num_decodes] - sheduler_metadata = None - if causal: - # for decode batch, use the custom kernel - sheduler_metadata = ops.cpu_attn_get_scheduler_metadata( - num_reqs=num_reqs, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, - seq_lens=seq_lens, - dtype=self.dtype, - query_start_loc=query_start_loc, - causal=causal, - sliding_window_size=self.window_size, - isa=self.isa, - enable_kv_split=True, - ) + sheduler_metadata = ops.cpu_attn_get_scheduler_metadata( + num_reqs=num_reqs, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + seq_lens=seq_lens, + dtype=self.dtype, + query_start_loc=query_start_loc, + causal=causal, + sliding_window_size=self.window_size, + isa=self.isa, + enable_kv_split=True, + ) attn_metadata = CPUAttentionMetadata( isa=self.isa, diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index fb080b0b33bc0..f5ad98cf2125c 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -8,7 +8,6 @@ from typing import ClassVar import numpy as np import torch -from vllm import envs from vllm.attention.backends.abstract import ( AttentionBackend, AttentionImpl, @@ -264,6 +263,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad self.parallel_config = vllm_config.parallel_config self.cache_config = vllm_config.cache_config self.compilation_config = vllm_config.compilation_config + self.attention_config = vllm_config.attention_config self.num_heads_q = self.model_config.get_num_attention_heads( self.parallel_config @@ -304,7 +304,9 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad # When using cuda graph, we need to set the upper bound of the # number of splits so that large enough intermediate buffers are # pre-allocated during capture. - self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH + self.max_num_splits = ( + self.attention_config.flash_attn_max_num_splits_for_cuda_graph + ) # Sliding window size to be used with the AOT scheduler will be # populated on first build() call. @@ -554,8 +556,7 @@ class FlashAttentionImpl(AttentionImpl): "heads in the layer" ) - def supports_quant_query_input(self) -> bool: - return True + self.supports_quant_query_input = True def forward( self, diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 69a6a5e5fae82..2740a6916fd97 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -26,7 +26,7 @@ from vllm.attention.backends.abstract import ( ) from vllm.attention.ops.common import cp_lse_ag_out_rs from vllm.attention.ops.merge_attn_states import merge_attn_states -from vllm.config import CUDAGraphMode, VllmConfig +from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config from vllm.config.cache import CacheDType from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger @@ -43,7 +43,6 @@ from vllm.platforms.interface import DeviceCapability from vllm.triton_utils import tl, triton from vllm.utils.flashinfer import ( can_use_trtllm_attention, - flashinfer_disable_q_quantization, use_trtllm_attention, ) from vllm.utils.math_utils import cdiv @@ -362,7 +361,8 @@ class FlashInferBackend(AttentionBackend): supports_trtllm_attention, ) - # Respect explicit disable flag (e.g., VLLM_USE_TRTLLM_ATTENTION=0) + # Respect explicit disable flag (e.g., + # --attention-config.use_trtllm_attention=0) if force_use_trtllm_attention() is False: return False @@ -429,6 +429,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.cache_config = vllm_config.cache_config self.model_config = vllm_config.model_config + self.attention_config = vllm_config.attention_config self._workspace_buffer = None self._prefill_wrapper: ( BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None @@ -482,9 +483,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): self.dcp_rank = 0 self.dcp_kv_cache_interleave_size = 1 - self.num_qo_heads = ( - self.model_config.get_num_attention_heads(self.vllm_config.parallel_config) - * self.dcp_world_size + self.num_qo_heads = self.model_config.get_num_attention_heads( + self.vllm_config.parallel_config ) self.num_kv_heads = self.kv_cache_spec.num_kv_heads @@ -501,11 +501,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): self.kv_cache_dtype = self.kv_cache_spec.dtype # Use model dtype as q dtype when TRTLLM attn is not supported, or - # VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to - # use fp8 q if kv cache is fp8, and will fall back to model dtype + # --attention-config.disable_flashinfer_q_quantization is set to 1. Otherwise, + # try to use fp8 q if kv cache is fp8, and will fall back to model dtype # if TRTLLM attention kernel is not used when building attn metadata can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads) - if can_use_trtllm and not flashinfer_disable_q_quantization(): + if ( + can_use_trtllm + and not vllm_config.attention_config.disable_flashinfer_q_quantization + ): self.q_data_type = self.kv_cache_dtype else: self.q_data_type = self.model_config.dtype @@ -561,7 +564,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ) self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy() - if self.head_dim == 256 and current_platform.is_device_capability(100): + if self.head_dim == 256 and current_platform.is_device_capability_family(100): # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that # head size 256 and block size 16 is not supported on blackwell. assert kv_cache_spec.block_size != 16, ( @@ -777,6 +780,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): self.cache_dtype, self.q_data_type, is_prefill=True, + force_use_trtllm=self.attention_config.use_trtllm_attention, has_sinks=self.has_sinks, has_spec=uses_spec_reorder, ) @@ -1036,6 +1040,11 @@ class FlashInferImpl(AttentionImpl): self.sinks = sinks self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads) + vllm_config = get_current_vllm_config() + self.supports_quant_query_input = ( + self.support_trtllm_attn + and not vllm_config.attention_config.disable_flashinfer_q_quantization + ) self.bmm1_scale: float | None = None self.bmm2_scale: float | None = None self.o_sf_scale: float | None = None @@ -1047,12 +1056,6 @@ class FlashInferImpl(AttentionImpl): and quant_key in (kFp8StaticTensorSym, kNvfp4Quant) ) - def supports_quant_query_input(self) -> bool: - if flashinfer_disable_q_quantization(): - return False - - return self.support_trtllm_attn - # FlashInfer requires attention sinks to be float32 def process_weights_after_loading(self, act_dtype: torch.dtype): if self.sinks is not None and self.sinks.dtype != torch.float32: diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index 8de0a0a11471f..d8dbe4cbae013 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -4,6 +4,7 @@ import math from dataclasses import dataclass +from functools import cached_property from typing import ClassVar import torch @@ -16,6 +17,7 @@ from torch.nn.attention.flex_attention import ( and_masks, create_block_mask, flex_attention, + or_masks, ) from vllm.attention.backends.abstract import ( @@ -30,6 +32,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) +from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.v1.attention.backends.utils import ( @@ -40,6 +43,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) +torch._dynamo.config.recompile_limit = 16 create_block_mask_compiled = torch.compile( create_block_mask, fullgraph=True, mode="reduce-overhead" ) @@ -89,6 +93,11 @@ class FlexAttentionBackend(AttentionBackend): """FlexAttention supports both decoder and encoder-only attention.""" return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY) + @classmethod + def supports_mm_prefix(cls) -> bool: + """FlexAttention supports full attention for image tokens.""" + return True + @staticmethod def get_impl_cls() -> type["FlexAttentionImpl"]: return FlexAttentionImpl @@ -314,6 +323,15 @@ class FlexAttentionMetadata: kv_block_size: int = 16 transformed_score_mod: _score_mod_signature | None = None sliding_window: int | None = None + mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None + + @cached_property + def logical_block_ids(self): + return torch.arange( + cdiv(self.max_seq_len, self.block_size), + device=self.block_table.device, + dtype=torch.long, + ) def _convert_physical_to_logical( self, @@ -433,6 +451,45 @@ class FlexAttentionMetadata: return final_mask_mod if self.causal else sliding_window_mask_mod + def get_prefix_lm_mask_mod(self) -> _mask_mod_signature: + """Creates the prefix LM mask_mod function for FlexAttention.""" + + assert self.doc_ids is not None + request_lookup = self.doc_ids + + def prefix_lm_mask_mod( + b: torch.Tensor, + h: torch.Tensor, + cu_q_idx: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + ): + mask = torch.zeros_like(q_idx, dtype=torch.bool) + for req, doc_range_lst in (self.mm_prefix_range or {}).items(): + req_mask = request_lookup[cu_q_idx] == req + for start, end in doc_range_lst: + doc_mask_q = (q_idx >= start) & (q_idx <= end) + doc_mask_kv = (kv_idx >= start) & (kv_idx <= end) + mask = mask | (req_mask & doc_mask_q & doc_mask_kv) + return mask + + def final_mask_mod( + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + physical_kv_idx: torch.Tensor, + ) -> torch.Tensor: + (is_valid, logical_q_idx, logical_kv_idx) = ( + self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx) + ) + return torch.where( + is_valid, + prefix_lm_mask_mod(b, h, q_idx, logical_q_idx, logical_kv_idx), + False, + ) + + return final_mask_mod + def get_mask_mod(self): # Stage-1: initialize the base mask_mod # (causal mask for decoder or bidirectional mask for encoder) @@ -446,6 +503,10 @@ class FlexAttentionMetadata: # Add sliding window mask for sliding window attention sliding_window_mask_mod = self.get_sliding_window_mask_mod() mask_mod = and_masks(mask_mod, sliding_window_mask_mod) + if self.mm_prefix_range: + # Add prefix LM mask for vision-language prefix LM attention + prefix_lm_mask_mod = self.get_prefix_lm_mask_mod() + mask_mod = or_masks(mask_mod, prefix_lm_mask_mod) return mask_mod def get_transformed_score_mod(self) -> _score_mod_signature | None: @@ -493,6 +554,7 @@ class FlexAttentionMetadata: The direct path works as follows: 1. For each query token, fetch blocks from block_table using max_seq_len + and exclude out of sliding window blocks if needed. (this fetches more blocks than needed for shorter sequences) 2. Group query tokens into chunks of q_block_size 3. For each group, deduplicate the blocks using unique_static_unsorted @@ -517,6 +579,23 @@ class FlexAttentionMetadata: used_pages = self.block_table[ self.doc_ids, : cdiv(self.max_seq_len, self.block_size) ] + + if self.sliding_window and self.causal: + device = used_pages.device + assert self.doc_ids is not None + token_indices = torch.arange( + self.doc_ids.shape[0], device=device, dtype=torch.long + ) + logical_q_idx = ( + token_indices + - self.query_start_loc[self.doc_ids] + + self.decode_offset[self.doc_ids] + ) + min_kv_idx = torch.clamp(logical_q_idx - (self.sliding_window - 1), min=0) + min_block_idx = min_kv_idx // self.block_size + sliding_mask = self.logical_block_ids >= min_block_idx[:, None] + used_pages.masked_fill_(~sliding_mask, 0) + used_pages_padded = pad_to_multiple( used_pages, multiple=self.q_block_size, dim=0 ) @@ -681,6 +760,7 @@ class FlexAttentionImpl(AttentionImpl): sliding_window: int | None alibi_slopes: torch.Tensor | None logits_soft_cap: float | None + mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None def __init__( self, @@ -782,17 +862,21 @@ class FlexAttentionImpl(AttentionImpl): num_actual_tokens = attn_metadata.num_actual_tokens + needs_rebuild_block_mask = False if attn_metadata.sliding_window != self.sliding_window: attn_metadata.sliding_window = self.sliding_window if attn_metadata.direct_build: - # TODO: Support skipping the computation of sliding window - # in direct block mask building code path. - logger.warning_once( - "Using direct block mask building with sliding window, " - "which is suboptimal now. Performance may be degraded." - ) # update mask mod in attention metadata attn_metadata.mask_mod = attn_metadata.get_mask_mod() + needs_rebuild_block_mask = True + + if self.mm_prefix_range != getattr(attn_metadata, "mm_prefix_range", None): + self.mm_prefix_range = attn_metadata.mm_prefix_range + attn_metadata.mask_mod = attn_metadata.get_mask_mod() + needs_rebuild_block_mask = True + + if needs_rebuild_block_mask: + if attn_metadata.direct_build and attn_metadata.causal: attn_metadata.block_mask = attn_metadata._build_block_mask_direct() else: attn_metadata.block_mask = attn_metadata.build_block_mask() @@ -906,7 +990,18 @@ def get_kernel_options( if torch.cuda.is_available(): device_props = torch.cuda.get_device_properties() - max_shared_memory = device_props.shared_memory_per_block_optin + # ROCm doesn't expose shared_memory_per_block_optin attribute + # AMD GPUs typically have 64KB LDS (Local Data Share) per workgroup + if hasattr(device_props, "shared_memory_per_block_optin"): + max_shared_memory = device_props.shared_memory_per_block_optin + elif current_platform.is_rocm(): + # ROCm fallback: use 64KB + max_shared_memory = 65536 + else: + raise RuntimeError( + "Unable to determine shared memory size on this hardware." + ) + if max_shared_memory < 144 * 1024: block_m_candidate = ensure_divisible( max(1, block_m_candidate // 2), block_m diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index 69b5a6fb48564..ace2cbb0564c8 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -211,7 +211,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] spec_token_masks = torch.repeat_interleave( spec_sequence_masks, query_lens ) - index = torch.argsort(spec_token_masks) + index = torch.argsort(spec_token_masks, stable=True) num_non_spec_tokens = num_prefill_tokens + num_decode_tokens non_spec_token_indx = index[:num_non_spec_tokens] spec_token_indx = index[num_non_spec_tokens:] @@ -254,17 +254,11 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] ) else: has_initial_state = None - num_actual_tokens = ( - num_prefill_tokens + num_decode_tokens + num_spec_decode_tokens - ) - # prepare tensors for cudagraph - # - # With speculative decoding, the xgrammar backend may rollback tokens - # and causing some sequences has less draft tokens than self.num_spec. - # - # In above cases, the max possible batch size for n tokens, can be - # min(n, cudagraph_max_bs). + # Prepare tensors for cudagraph + # Note: m.num_actual_tokens is already padded by the model runner for CUDAGraph + batch_size = m.num_actual_tokens + if ( self.use_full_cuda_graph and num_prefills == 0 @@ -272,9 +266,6 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] and num_spec_decodes <= self.decode_cudagraph_max_bs and num_spec_decode_tokens <= self.decode_cudagraph_max_bs ): - num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens) - batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens) - self.spec_state_indices_tensor[:num_spec_decodes].copy_( spec_state_indices_tensor, non_blocking=True ) @@ -319,9 +310,6 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] and num_spec_decodes == 0 and num_decodes <= self.decode_cudagraph_max_bs ): - num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens) - batch_size = num_actual_tokens - self.non_spec_state_indices_tensor[:num_decodes].copy_( non_spec_state_indices_tensor, non_blocking=True ) @@ -344,7 +332,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] num_decode_tokens=num_decode_tokens, num_spec_decodes=num_spec_decodes, num_spec_decode_tokens=num_spec_decode_tokens, - num_actual_tokens=num_actual_tokens, + num_actual_tokens=m.num_actual_tokens, has_initial_state=has_initial_state, spec_query_start_loc=spec_query_start_loc, non_spec_query_start_loc=non_spec_query_start_loc, @@ -382,6 +370,6 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] num_accepted_tokens = torch.diff(m.query_start_loc) num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu() - m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu() + m._num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu() return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu) diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index 8e949e53330c1..fcda6134016ba 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -31,7 +31,6 @@ class Mamba1AttentionMetadata: num_prefill_tokens: int num_decodes: int num_decode_tokens: int - num_padded_decodes: int block_idx_last_scheduled_token: torch.Tensor # shape: [batch,] block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,] @@ -68,7 +67,6 @@ class Mamba1AttentionMetadataBuilder( has_initial_states_p = None query_start_loc_p = None - padded_decodes = num_decodes num_computed_tokens, num_computed_tokens_p = None, None block_idx_first_scheduled_token = None block_idx_first_scheduled_token_p = None @@ -125,11 +123,10 @@ class Mamba1AttentionMetadataBuilder( and num_decodes <= self.decode_cudagraph_max_bs and self.compilation_config.cudagraph_mode.has_full_cudagraphs() ): - padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes) self.state_indices_tensor[:num_decodes].copy_( state_indices_tensor, non_blocking=True ) - state_indices_tensor = self.state_indices_tensor[:padded_decodes] + state_indices_tensor = self.state_indices_tensor[:num_decode_tokens] state_indices_tensor[num_decodes:] = PAD_SLOT_ID if self.vllm_config.cache_config.enable_prefix_caching: @@ -137,17 +134,15 @@ class Mamba1AttentionMetadataBuilder( block_idx_last_scheduled_token, non_blocking=True ) block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[ - :padded_decodes + :num_decode_tokens ] - block_idx_last_scheduled_token[num_decodes:] = 0 self.block_idx_last_computed_token[:num_decodes].copy_( block_idx_last_computed_token, non_blocking=True ) block_idx_last_computed_token = self.block_idx_last_computed_token[ - :padded_decodes + :num_decode_tokens ] - block_idx_last_computed_token[num_decodes:] = 0 return Mamba1AttentionMetadata( query_start_loc_p=query_start_loc_p, @@ -157,7 +152,6 @@ class Mamba1AttentionMetadataBuilder( num_prefill_tokens=num_prefill_tokens, num_decodes=num_decodes, num_decode_tokens=num_decode_tokens, - num_padded_decodes=padded_decodes, block_idx_last_scheduled_token=block_idx_last_scheduled_token, block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p, block_idx_last_computed_token=block_idx_last_computed_token, diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 888734e5d2b6b..bf1d8f09ab0ac 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -10,7 +10,6 @@ from vllm.config import VllmConfig from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( - PAD_SLOT_ID, CommonAttentionMetadata, compute_causal_conv1d_metadata, split_decodes_and_prefills, @@ -304,30 +303,25 @@ class Mamba2AttentionMetadataBuilder( num_decodes <= self.decode_cudagraph_max_bs and self.compilation_config.cudagraph_mode.has_full_cudagraphs() ): - # Pad state tensor for CUDA graph - num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes) self.state_indices_tensor[:num_decodes].copy_( state_indices_tensor, non_blocking=True ) - state_indices_tensor = self.state_indices_tensor[:num_input_tokens] - state_indices_tensor[num_decodes:] = PAD_SLOT_ID + state_indices_tensor = self.state_indices_tensor[:num_decode_tokens] if self.vllm_config.cache_config.enable_prefix_caching: self.block_idx_last_scheduled_token[:num_decodes].copy_( block_idx_last_scheduled_token, non_blocking=True ) block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[ - :num_input_tokens + :num_decode_tokens ] - block_idx_last_scheduled_token[num_decodes:] = 0 self.block_idx_last_computed_token[:num_decodes].copy_( block_idx_last_computed_token, non_blocking=True ) block_idx_last_computed_token = self.block_idx_last_computed_token[ - :num_input_tokens + :num_decode_tokens ] - block_idx_last_computed_token[num_decodes:] = 0 attn_metadata = Mamba2AttentionMetadata( num_prefills=num_prefills, diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 180625b6ce897..fea482493635f 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -438,30 +438,39 @@ A = TypeVar("A") def use_flashinfer_prefill() -> bool: # For blackwell default to flashinfer prefill if it's available since # it is faster than FA2. + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() return ( - not envs.VLLM_DISABLE_FLASHINFER_PREFILL + not vllm_config.attention_config.disable_flashinfer_prefill and flashinfer_available - and not envs.VLLM_USE_CUDNN_PREFILL - and not envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL - and current_platform.is_device_capability(100) + and not vllm_config.attention_config.use_cudnn_prefill + and not vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill + and current_platform.is_device_capability_family(100) ) def use_cudnn_prefill() -> bool: + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() return ( flashinfer_available - and envs.VLLM_USE_CUDNN_PREFILL - and current_platform.is_device_capability(100) + and vllm_config.attention_config.use_cudnn_prefill + and current_platform.is_device_capability_family(100) and has_nvidia_artifactory() ) def use_trtllm_ragged_deepseek_prefill() -> bool: """Check if TRT-LLM ragged DeepSeek prefill should be used.""" + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() return ( flashinfer_available - and envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL - and current_platform.is_device_capability(100) + and vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill + and current_platform.is_device_capability_family(100) ) @@ -1645,6 +1654,33 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): # Convert from (L, N, P) to (N, P, L) self.W_UK_T = W_UK.permute(1, 2, 0) + def _concat_k_nope_k_pe( + self, k_nope: torch.Tensor, k_pe: torch.Tensor + ) -> torch.Tensor: + """ + Efficiently concatenate k_nope and k_pe tensors along the last dimension. + + This function avoids the performance penalty of torch.cat with expanded + non-contiguous tensors by pre-allocating the output and using direct copies. + + Args: + k_nope: Tensor of shape [..., nope_dim] + k_pe: Tensor to broadcast and concatenate, typically shape [..., 1, pe_dim] + or [..., pe_dim] + + Returns: + Tensor of shape [..., nope_dim + pe_dim] + """ + k = torch.empty( + (*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]), + dtype=k_nope.dtype, + device=k_nope.device, + ) + # Direct copies with efficient broadcasting + k[..., : k_nope.shape[-1]] = k_nope + k[..., k_nope.shape[-1] :] = k_pe + return k + def _compute_prefill_context( self, q: torch.Tensor, @@ -1681,7 +1717,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ) k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + k = self._concat_k_nope_k_pe(k_nope, k_pe) attn_output, attn_softmax_lse = self._run_prefill_context_chunk( prefill=prefill_metadata, @@ -1785,7 +1821,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim ) k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + k = self._concat_k_nope_k_pe(k_nope, k_pe) attn_output, attn_softmax_lse = self._run_prefill_context_chunk( prefill=prefill_metadata, @@ -1834,7 +1870,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ) k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + k = self._concat_k_nope_k_pe(k_nope, k_pe) output_prefill = self._run_prefill_new_tokens( prefill=attn_metadata.prefill, @@ -2028,21 +2064,30 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): if fp8_attention: ql_nope_shape = decode_ql_nope.shape - decode_ql_nope, _ = ops.scaled_fp8_quant( - decode_ql_nope.reshape( - [ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2]] - ), - layer._q_scale, - ) - decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape) q_pe_shape = decode_q_pe.shape - decode_q_pe, _ = ops.scaled_fp8_quant( - decode_q_pe.reshape([q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]), + assert decode_ql_nope.shape[0] == decode_q_pe.shape[0] + assert decode_ql_nope.shape[1] == decode_q_pe.shape[1] + decode_q_shape = ( + ql_nope_shape[0], + ql_nope_shape[1], + ql_nope_shape[2] + q_pe_shape[2], + ) + # Using empty and copy since torch.cat introduces significant overhead. + decode_q0 = torch.empty( + decode_q_shape, + device=decode_ql_nope.device, + dtype=decode_ql_nope.dtype, + ) + decode_q0[..., : ql_nope_shape[2]].copy_(decode_ql_nope) + decode_q0[..., ql_nope_shape[2] :].copy_(decode_q_pe) + + decode_q, _ = ops.scaled_fp8_quant( + decode_q0.view(decode_q_shape[0], -1), layer._q_scale, ) - decode_q_pe = decode_q_pe.reshape(q_pe_shape) - - decode_q = (decode_ql_nope, decode_q_pe) + decode_q = decode_q.view(decode_q_shape) + else: + decode_q = (decode_ql_nope, decode_q_pe) if self.dcp_world_size > 1: assert not fp8_attention, "DCP not support fp8 kvcache now." # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index d369814c10b6f..b28814aceada9 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -6,7 +6,6 @@ from typing import ClassVar import torch -from vllm import envs from vllm.attention.backends.abstract import ( AttentionLayer, AttentionType, @@ -106,13 +105,14 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata] vllm_config: VllmConfig, device: torch.device, ): + interleave_size = vllm_config.parallel_config.cp_kv_cache_interleave_size super().__init__( kv_cache_spec, layer_names, vllm_config, device, FlashAttnMLAMetadata, - supports_dcp_with_varlen=True, + supports_dcp_with_varlen=(interleave_size == 1), ) self.max_num_splits = 0 # No upper bound on the number of splits. self.fa_aot_schedule = get_flash_attn_version() == 3 @@ -131,7 +131,9 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata] # When using cuda graph, we need to set the upper bound of the # number of splits so that large enough intermediate buffers are # pre-allocated during capture. - self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH + self.max_num_splits = ( + vllm_config.attention_config.flash_attn_max_num_splits_for_cuda_graph + ) if vllm_is_batch_invariant(): self.max_num_splits = 1 diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 1eee1d225293b..0818078da0364 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -18,7 +18,7 @@ from vllm.attention.ops.flashmla import ( flash_mla_with_kvcache, get_mla_metadata, ) -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_current_vllm_config from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.platforms import current_platform @@ -30,13 +30,31 @@ from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + reshape_attn_output_for_spec_decode, + reshape_query_for_spec_decode, + split_decodes_and_prefills, + split_prefill_chunks, ) from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.workspace import current_workspace_manager if TYPE_CHECKING: from vllm.model_executor.models.deepseek_v2 import Indexer logger = init_logger(__name__) + +# For FP8 sparse attention we have two impelementations: +# 1. Mixed batch mode: use the FP8 decode kernel for both prefill and decode this is +# done by treating all tokens as single batch. +# 2. Separate prefill and decode mode: use the BF16 prefill kernel for prefill +# (upconverting the FP8 cache to BF16 then calling the prefill kernel) and using +# the FP8 decode kernel for decode. +# Currently we use #1 when the number of heads per rank is low (i.e. TP) since the BF16 +# prefill kernel requires padding the numer of heads to 128 while the decode does not +# so when the per ranke head count is below MIN_HEADS_FOR_BF16_PREFILL we use the mixed +# batch mode (#2). +MIN_HEADS_FOR_BF16_PREFILL = 32 + """ NOTE: FlashMLA Sparse uses an fp8 cache with the following format @@ -127,19 +145,72 @@ class FlashMLASparseMetadata: dummy_block_table: torch.Tensor cache_lens: torch.Tensor - fp8_extra_metadata: FP8KernelMetadata | None = None + @dataclass + class FP8SeperatePrefillDecode: + @dataclass + class Decode: + kernel_metadata: "FlashMLASparseMetadata.FP8KernelMetadata" + decode_query_len: int # needed for reshape in spec decode + + @dataclass + class Prefill: + # Sequence lengths (context + query) for prefill requests + # Shape: [num_prefill_reqs] + seq_lens: torch.Tensor + + # Request ID for each token: -1 for decode tokens, request index + # (0, 1, 2, ...) for prefill tokens. + # Shape: [num_actual_tokens] + request_ids: torch.Tensor + + # Workspace start offsets for all prefill requests + # Shape: [num_prefill_reqs], adjusted in-place per chunk to be + # 0-indexed within each chunk. Used to map prefill tokens to workspace + # offsets in convert_logical_index_to_physical_index + workspace_starts: torch.Tensor + + @dataclass + class Chunk: + """Metadata for a chunk of prefill requests. + + Prefill requests may be chunked to fit within the fixed workspace size. + """ + + seq_lens: torch.Tensor + tokens_slice: slice + block_table: torch.Tensor + req_start_idx: int + workspace_starts: torch.Tensor + chunk_tot_seqlen: int + + chunks: list[Chunk] + + num_prefills: int = 0 + num_decodes: int = 0 + num_prefill_tokens: int = 0 + num_decode_tokens: int = 0 + + decode: Decode | None = None + prefill: Prefill | None = None + + fp8_extra_metadata: FP8SeperatePrefillDecode | FP8KernelMetadata | None = None + fp8_use_mixed_batch: bool = False +# Kernel with prefill workspace support @triton.jit def _convert_req_index_to_global_index_kernel( req_id_ptr, # int32 [num_tokens] block_table_ptr, # int32 [num_requests, max_num_blocks_per_req] token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] + prefill_request_id_ptr, # int32 [num_tokens], -1 for decode, >=0 for prefill + workspace_starts_ptr, # int32 [num_prefill_reqs+1] or nullptr # shapes (compile-time where possible) max_num_blocks_per_req: tl.constexpr, BLOCK_SIZE: tl.constexpr, BLOCK_N: tl.constexpr, # tile width along columns + HAS_PREFILL: tl.constexpr, # strides (in elements) bt_stride0, bt_stride1, @@ -165,7 +236,10 @@ def _convert_req_index_to_global_index_kernel( # Only token == -1 should propagate as -1 is_invalid_tok = tok < 0 - + is_prefill = False + if HAS_PREFILL: + prefill_req_id = tl.load(prefill_request_id_ptr + token_id) + is_prefill = prefill_req_id >= 0 # Compute block id and in-block offset block_id = tok // BLOCK_SIZE inblock_off = tok % BLOCK_SIZE @@ -173,12 +247,18 @@ def _convert_req_index_to_global_index_kernel( # Guard block_table access valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0) bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1 - base = tl.load(bt_ptr, mask=valid_block, other=0) + is_invalid_tok |= ~valid_block + base = tl.load(bt_ptr, mask=valid_block & ~is_prefill, other=0) + out_val = base * BLOCK_SIZE + inblock_off - # If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset - out_val = tl.where( - is_invalid_tok | (~valid_block), -1, base * BLOCK_SIZE + inblock_off - ) + # Override with prefill output if prefill is enabled + if HAS_PREFILL: + workspace_start = tl.load( + workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0 + ) + prefill_out = workspace_start + tok + out_val = tl.where(is_prefill, prefill_out, out_val) + out_val = tl.where(is_invalid_tok, -1, out_val) # Store results out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1 @@ -192,6 +272,9 @@ def triton_convert_req_index_to_global_index( BLOCK_SIZE: int = 64, NUM_TOPK_TOKENS: int = 2048, BLOCK_N: int = 128, # tile width along columns + HAS_PREFILL_WORKSPACE: bool = False, + prefill_workspace_request_ids: torch.Tensor | None = None, + prefill_workspace_starts: torch.Tensor | None = None, ): """ out[token_id, indice_id] = @@ -202,17 +285,32 @@ def triton_convert_req_index_to_global_index( Only when token_indices[token_id, indice_id] == -1 do we output -1. For safety, we also output -1 if the derived block_id would be out-of-bounds. + + When HAS_PREFILL_WORKSPACE is True, prefill tokens are mapped to workspace offsets + instead of global cache slots. prefill_workspace_request_ids and + prefill_workspace_starts must be provided. + + prefill_workspace_request_ids: int32 [num_tokens], -1 for decode else + prefill request index (maps to prefill_workspace_starts) + prefill_workspace_starts: int32 [num_prefills], 0-indexed workspace + starts for each prefill request """ assert req_id.dtype == torch.int32 assert block_table.dtype == torch.int32 assert token_indices.dtype == torch.int32 assert token_indices.shape[1] == NUM_TOPK_TOKENS assert NUM_TOPK_TOKENS % BLOCK_N == 0, ( - f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible byBLOCK_N ({BLOCK_N})" + f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})" ) + if HAS_PREFILL_WORKSPACE: + assert prefill_workspace_request_ids is not None + assert prefill_workspace_starts is not None + assert prefill_workspace_request_ids.dtype == torch.int32 + assert prefill_workspace_starts.dtype == torch.int32 + num_tokens = req_id.shape[0] - num_requests, max_num_blocks_per_req = block_table.shape + max_num_blocks_per_req = block_table.shape[1] tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N # Ensure contiguous tensors on the same device @@ -226,6 +324,13 @@ def triton_convert_req_index_to_global_index( ti_stride0, ti_stride1 = token_indices_c.stride() out_stride0, out_stride1 = out.stride() + # Prepare prefill pointers + if HAS_PREFILL_WORKSPACE: + assert prefill_workspace_request_ids is not None # for mypy + assert prefill_workspace_starts is not None # for mypy + assert prefill_workspace_request_ids.is_contiguous() + assert prefill_workspace_starts.is_contiguous() + # Exact 2D grid: tokens × column tiles grid = (num_tokens, tiles_per_row) @@ -234,10 +339,13 @@ def triton_convert_req_index_to_global_index( block_table_c, token_indices_c, out, + prefill_workspace_request_ids, + prefill_workspace_starts, # shapes / constexprs max_num_blocks_per_req, BLOCK_SIZE, BLOCK_N, + HAS_PREFILL_WORKSPACE, # strides bt_stride0, bt_stride1, @@ -249,7 +357,16 @@ def triton_convert_req_index_to_global_index( return out -@dataclass +def get_prefill_workspace_size(max_model_len: int): + # NOTE(Lucas): 5 is a magic number for controlling the prefill buffer size. + # May be tuned later. + # Memory usage: 5 * max_model_len * 576 * 2 bytes + # Example: DeepSeek-V3.2 with max_model_len=163840 -> + # 5 * 163840 * 576 * 2 = ~900 MB + # This fits nicely below the typical MoE workspace size of >2GB so this is "free" + return max_model_len * 5 + + class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]): _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH @@ -259,29 +376,42 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad layer_names: list[str], vllm_config: VllmConfig, device: torch.device, - ): + ) -> None: + self.vllm_config = vllm_config + self.layer_names = layer_names cache_config = vllm_config.cache_config self.kv_cache_spec = kv_cache_spec self.model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config self.device = device + # Treat requests with query length <= 1 as decodes to match the + # DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2) + self._init_reorder_batch_threshold(1, supports_spec_as_decode=True) + props = torch.cuda.get_device_properties(device) sm_count = props.multi_processor_count self.num_heads = self.model_config.get_num_attention_heads(parallel_config) self.mla_dims = get_mla_dims(self.model_config) + self.topk_tokens = vllm_config.model_config.hf_config.index_topk self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla" - self.topk_tokens_tensor = torch.tensor( - [self.topk_tokens], device=device, dtype=torch.int32 + max_num_seqs = vllm_config.scheduler_config.max_num_seqs + # Shape: [max_num_seqs], all elements = topk_tokens (constant for full-CG) + self.topk_tokens_tensor = torch.full( + (max_num_seqs,), self.topk_tokens, device=device, dtype=torch.int32 ) - self.max_model_len_tensor = torch.tensor( - [self.model_config.max_model_len], device=device, dtype=torch.int32 + # Shape: [max_num_seqs], all elements = max_model_len + self.max_model_len_tensor = torch.full( + (max_num_seqs,), + self.model_config.max_model_len, + device=device, + dtype=torch.int32, ) # this is ignored by `flash_mla_with_kvcache` if indices not None self.dummy_block_table = torch.empty( - (1, 1), dtype=torch.int32, device=self.device + (max_num_seqs, 1), dtype=torch.int32, device=self.device ) # Equation taken from FlashMLA/csrc/pybind.cpp @@ -290,7 +420,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad max_num_sm_parts = int( max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1) ) - if current_platform.is_device_capability(100): + if current_platform.is_device_capability_family(100): max_num_sm_parts *= 2 self.tile_scheduler_metadata_buffer = torch.empty( # TileSchedulerMetaDataSize = 8 @@ -299,10 +429,9 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad dtype=torch.int32, device=device, ) + # Sized for per-request batching (num_decodes + 1) self.num_splits_buffer = torch.empty( - # We pack all the tokens into one batch for sparse attention. - # Otherwise, we can exceed the sm of `get_mla_metadata`. - (2,), + (max_num_seqs + 1,), dtype=torch.int32, device=device, ) @@ -312,30 +441,171 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad device=device, ) - def build( + def _build_fp8_mixed_decode_prefill( self, - common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False, - ) -> FlashMLASparseMetadata: - num_tokens = common_attn_metadata.num_actual_tokens - starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32) - seg_lengths = np.diff(starts) - req_id_per_token = np.repeat( - np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths - ) - # Zero-fill for cudagraphs - self.req_id_per_token_buffer.fill_(0) - self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_( - torch.from_numpy(req_id_per_token), non_blocking=True - ) - req_id_per_token = self.req_id_per_token_buffer[:num_tokens] + ) -> "FlashMLASparseMetadata.FP8KernelMetadata": + """Build FP8 metadata treating all tokens as one mixed batch. + + This matches main branch's approach and avoids the BF16 prefill kernel + which has head padding overhead when num_heads is small (high TP case). + """ + num_tokens = common_attn_metadata.num_actual_tokens + + # Build metadata for all tokens as a single batch + tile_scheduler_metadata, num_splits = get_mla_metadata( + cache_seqlens=self.topk_tokens_tensor[:1], # Single batch + num_q_tokens_per_head_k=num_tokens * self.num_heads, + topk=self.topk_tokens, + num_heads_q=self.num_heads, + num_heads_k=1, + is_fp8_kvcache=True, + ) + + num_sm_parts = tile_scheduler_metadata.size(0) + tile_scheduler_metadata_buffer = self.tile_scheduler_metadata_buffer[ + :num_sm_parts + ] + tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata) + num_splits_view = self.num_splits_buffer[:2] + num_splits_view.copy_(num_splits) + + fp8_metadata = FlashMLASparseMetadata.FP8KernelMetadata( + scheduler_metadata=tile_scheduler_metadata_buffer, + num_splits=num_splits_view, + cache_lens=self.max_model_len_tensor[:1], + dummy_block_table=self.dummy_block_table[:1], + ) + + return fp8_metadata + + def _build_fp8_separate_prefill_decode( + self, + common_attn_metadata: CommonAttentionMetadata, + ) -> "FlashMLASparseMetadata.FP8SeperatePrefillDecode": + num_tokens = common_attn_metadata.num_actual_tokens + + (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = ( + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold or 1, + require_uniform=True, + ) + ) + + FP8Meta = FlashMLASparseMetadata.FP8SeperatePrefillDecode + fp8_metadata = FP8Meta( + num_decodes=num_decodes, + num_prefills=num_prefills, + num_decode_tokens=num_decode_tokens, + num_prefill_tokens=num_prefill_tokens, + ) + + # Extract prefill sequence lengths (context + query, not just query) + # Decode requests come first in the batch, prefill requests follow + prefill_seq_lens = None + prefill_request_id = None + prefill_workspace_starts = None + prefill_chunks = None + + # For pure decode batches, prefill_request_id will be None + # For mixed batches, it will have -1 for decode and request_id for prefill + if num_prefills > 0: + seq_lens_cpu = common_attn_metadata.seq_lens_cpu + seq_lens = common_attn_metadata.seq_lens + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + + prefill_seq_lens_cpu = seq_lens_cpu[num_decodes:] + prefill_seq_lens = seq_lens[num_decodes:] + + # Build prefill_request_id: -1 for decode, request index for + # prefill. This enables a single + # convert_logical_index_to_physical_index call for all tokens + prefill_request_id = torch.full( + (num_tokens,), -1, dtype=torch.int32, device=self.device + ) + # Map prefill tokens to their request IDs (0, 1, 2, ...) + for req_idx in range(num_prefills): + # Get query token range for this prefill request + global_req_idx = num_decodes + req_idx + req_query_start = query_start_loc_cpu[global_req_idx] + req_query_end = query_start_loc_cpu[global_req_idx + 1] + prefill_request_id[req_query_start:req_query_end] = req_idx + + # will be adjusted by chunk loop + prefill_workspace_starts_cpu = torch.zeros( + num_prefills, dtype=torch.int32, pin_memory=True + ) + prefill_workspace_starts_cpu[1:] = torch.cumsum( + prefill_seq_lens_cpu[:-1], dim=0 + ) + # populated by non-blocking copy after prefill_workspace_starts_cpu is + # updated by each chunk + prefill_workspace_starts = torch.empty( + num_prefills, dtype=torch.int32, device=self.device + ) + + # Chunk prefill requests to fit within workspace size + max_prefill_buffer_size = get_prefill_workspace_size( + self.vllm_config.model_config.max_model_len + ) + chunk_bounds = split_prefill_chunks( + prefill_seq_lens_cpu, max_prefill_buffer_size + ) + + prefill_chunks = [] + for chunk_start, chunk_end in chunk_bounds: + # Adjust workspace_starts in-place per chunk to be + # 0-indexed within each chunk + # Example: seq_lens=[10,15,20,5], chunks=[[0,2],[2,4]] + # Initial: workspace_starts=[0,10,25,45] + # After: workspace_starts=[0,10,0,20] + # (chunk 0 starts at 0, chunk 1 starts at 0) + offset = prefill_workspace_starts_cpu[chunk_start].item() + prefill_workspace_starts_cpu[chunk_start:chunk_end] -= offset + + chunk_seq_lens = prefill_seq_lens[chunk_start:chunk_end] + chunk_tot_seqlen = prefill_seq_lens_cpu[chunk_start:chunk_end].sum() + token_start = query_start_loc_cpu[num_decodes + chunk_start].item() + token_end = query_start_loc_cpu[num_decodes + chunk_end].item() + tokens_slice = slice(token_start, token_end) + + # Create chunk view of gpu tensor + chunk_workspace_starts = prefill_workspace_starts[chunk_start:chunk_end] + chunk_block_table = common_attn_metadata.block_table_tensor[ + num_decodes + chunk_start : num_decodes + chunk_end + ] + + prefill_chunks.append( + FP8Meta.Prefill.Chunk( + seq_lens=chunk_seq_lens, + tokens_slice=tokens_slice, + block_table=chunk_block_table, + req_start_idx=chunk_start, + workspace_starts=chunk_workspace_starts, + chunk_tot_seqlen=chunk_tot_seqlen, + ) + ) + + prefill_workspace_starts.copy_( + prefill_workspace_starts_cpu, non_blocking=True + ) + + fp8_metadata.prefill = FP8Meta.Prefill( + seq_lens=prefill_seq_lens, + request_ids=prefill_request_id, + workspace_starts=prefill_workspace_starts, + chunks=prefill_chunks, + ) + + if num_decodes > 0: + # Compute decode_query_len for spec decode (uniform due to require_uniform) + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + decode_query_len = (query_start_loc_cpu[1] - query_start_loc_cpu[0]).item() - fp8_extra_metadata = None - if self.use_fp8_kv_cache: tile_scheduler_metadata, num_splits = get_mla_metadata( - cache_seqlens=self.topk_tokens_tensor, - num_q_tokens_per_head_k=num_tokens * self.num_heads, + cache_seqlens=self.topk_tokens_tensor[:num_decodes], + num_q_tokens_per_head_k=decode_query_len * self.num_heads, topk=self.topk_tokens, num_heads_q=self.num_heads, num_heads_k=1, @@ -348,33 +618,70 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad :num_sm_parts ] tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata) - self.num_splits_buffer.copy_(num_splits) + # num_splits has size [num_decodes + 1] + num_splits_view = self.num_splits_buffer[: num_decodes + 1] + num_splits_view.copy_(num_splits) - fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata( + kernel_meta = FlashMLASparseMetadata.FP8KernelMetadata( scheduler_metadata=tile_scheduler_metadata_buffer, - num_splits=self.num_splits_buffer, - # cache_lens and block_table are basically unused in sparse case - # but the decode kernel will treat -1 and indices >= cache_lens - # as invalid so we make sure cache_lens is large enough to not - # accidentally mark indices invalid, we will use -1 exclusively - # to mark invalid indices - cache_lens=self.max_model_len_tensor, - dummy_block_table=self.dummy_block_table, + num_splits=num_splits_view, + dummy_block_table=self.dummy_block_table[:num_decodes], + cache_lens=self.max_model_len_tensor[:num_decodes], + ) + fp8_metadata.decode = FP8Meta.Decode( + kernel_metadata=kernel_meta, + decode_query_len=decode_query_len, ) + return fp8_metadata + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> FlashMLASparseMetadata: + cm = common_attn_metadata + num_tokens = cm.num_actual_tokens + starts = np.asarray(cm.query_start_loc_cpu, dtype=np.int32) + seg_lengths = np.diff(starts) + req_id_per_token = np.repeat( + np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths + ) + # Zero-fill for cudagraphs + self.req_id_per_token_buffer.fill_(0) + self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_( + torch.from_numpy(req_id_per_token), non_blocking=True + ) + req_id_per_token = self.req_id_per_token_buffer[:num_tokens] + + fp8_extra_metadata: ( + FlashMLASparseMetadata.FP8SeperatePrefillDecode + | FlashMLASparseMetadata.FP8KernelMetadata + | None + ) = None + fp8_use_mixed_batch = self.num_heads < MIN_HEADS_FOR_BF16_PREFILL + if self.use_fp8_kv_cache: + if fp8_use_mixed_batch: + fp8_extra_metadata = self._build_fp8_mixed_decode_prefill(cm) + else: + fp8_extra_metadata = self._build_fp8_separate_prefill_decode(cm) + metadata = FlashMLASparseMetadata( - num_reqs=common_attn_metadata.num_reqs, - max_query_len=common_attn_metadata.max_query_len, - max_seq_len=common_attn_metadata.max_seq_len, - num_actual_tokens=common_attn_metadata.num_actual_tokens, - query_start_loc=common_attn_metadata.query_start_loc, - slot_mapping=common_attn_metadata.slot_mapping, - block_table=common_attn_metadata.block_table_tensor, + num_reqs=cm.num_reqs, + max_query_len=cm.max_query_len, + max_seq_len=cm.max_seq_len, + num_actual_tokens=cm.num_actual_tokens, + query_start_loc=cm.query_start_loc, + slot_mapping=cm.slot_mapping, + block_table=cm.block_table_tensor, req_id_per_token=req_id_per_token, block_size=self.kv_cache_spec.block_size, topk_tokens=self.topk_tokens, fp8_extra_metadata=fp8_extra_metadata, + fp8_use_mixed_batch=fp8_use_mixed_batch, ) + return metadata @@ -412,7 +719,21 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): self.softmax_scale = scale assert indexer is not None self.topk_indices_buffer = indexer.topk_indices_buffer - self.padding = 128 if current_platform.is_device_capability(100) else 64 + self.padding = 128 if current_platform.is_device_capability_family(100) else 64 + + if kv_cache_dtype == "fp8_ds_mla": + # Reserve workspace during initialization + vllm_config = get_current_vllm_config() + assert vllm_config is not None and vllm_config.model_config is not None + prefill_workspace_size = get_prefill_workspace_size( + vllm_config.model_config.max_model_len + ) + self.prefill_workspace_shape = (prefill_workspace_size, head_size) + (self.prefill_bf16_workspace,) = ( + current_workspace_manager().get_simultaneous( + (self.prefill_workspace_shape, torch.bfloat16) + ) + ) def _forward_bf16_kv( self, @@ -420,6 +741,184 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): kv_c_and_k_pe_cache: torch.Tensor, topk_indices: torch.Tensor, attn_metadata: FlashMLASparseMetadata, + ) -> torch.Tensor: + # Convert per-request indices to global slots (decode) or workspace + # offsets (prefill). + topk_indices = triton_convert_req_index_to_global_index( + attn_metadata.req_id_per_token, + attn_metadata.block_table, + topk_indices, + BLOCK_SIZE=attn_metadata.block_size, + NUM_TOPK_TOKENS=topk_indices.shape[1], + ) + + return self._bf16_flash_mla_kernel(q, kv_c_and_k_pe_cache, topk_indices) + + def _forward_fp8_kv_separate_prefill_decode( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + attn_metadata: FlashMLASparseMetadata, + ) -> torch.Tensor: + fp8_metadata = attn_metadata.fp8_extra_metadata + assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeperatePrefillDecode) + num_decodes = fp8_metadata.num_decodes + + prefill_request_ids = None + prefill_workspace_starts = None + has_prefill_workspace = False + if fp8_metadata.prefill is not None: + prefill_request_ids = fp8_metadata.prefill.request_ids + prefill_workspace_starts = fp8_metadata.prefill.workspace_starts + has_prefill_workspace = True + + # Convert per-request indices to global slots (decode) or workspace + # offsets (prefill). + # For FP8 cache: prefill uses workspace mapping (upconverted to BF16) + # For BF16 cache: always use global cache slots (no workspace) + # prefill_workspace_starts has been adjusted in-place per chunk so + # prefill indices automatically come out chunk-local + topk_indices = triton_convert_req_index_to_global_index( + attn_metadata.req_id_per_token, + attn_metadata.block_table, + topk_indices, + BLOCK_SIZE=attn_metadata.block_size, + NUM_TOPK_TOKENS=topk_indices.shape[1], + HAS_PREFILL_WORKSPACE=has_prefill_workspace, + prefill_workspace_request_ids=prefill_request_ids, + prefill_workspace_starts=prefill_workspace_starts, + ) + + fp8_metadata = attn_metadata.fp8_extra_metadata + assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeperatePrefillDecode) + + def _fp8_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor: + # Reshape q: (num_decode_tokens, num_heads, head_dim) + # -> (num_decodes, seq_len, num_heads, head_dim) + q = reshape_query_for_spec_decode(q, num_decodes) + seq_len = q.shape[1] + # Reshape topk_indices: (num_decode_tokens, topk) + # -> (num_decodes, seq_len, topk) + topk_indices = topk_indices.view(num_decodes, seq_len, -1) + assert fp8_metadata.decode is not None + attn_out, _ = self._fp8_flash_mla_kernel( + q=q, + kv_c_and_k_pe_cache=kv_c_and_k_pe_cache, + topk_indices=topk_indices, + kernel_metadata=fp8_metadata.decode.kernel_metadata, + ) + # Reshape output: (num_decodes, seq_len, num_heads, head_dim_v) + # -> (num_decode_tokens, num_heads, head_dim_v) + return reshape_attn_output_for_spec_decode(attn_out) + + num_decode_tokens = fp8_metadata.num_decode_tokens + num_prefill_tokens = fp8_metadata.num_prefill_tokens + + # Pure decode: direct call without allocation + if num_decode_tokens > 0 and num_prefill_tokens == 0: + assert fp8_metadata.decode is not None + attn_out = _fp8_decode(q, topk_indices) + else: + # Mixed or pure prefill: allocate output tensor + attn_out = q.new_empty( + (attn_metadata.num_actual_tokens, self.num_heads, self.kv_lora_rank), + dtype=q.dtype, + device=q.device, + ) + + if num_decode_tokens > 0: + attn_out[:num_decode_tokens] = _fp8_decode( + q[:num_decode_tokens], topk_indices[:num_decode_tokens] + ) + + assert fp8_metadata.prefill is not None + for chunk in fp8_metadata.prefill.chunks: + chunk_workspace = self.prefill_bf16_workspace[: chunk.chunk_tot_seqlen] + ops.cp_gather_and_upconvert_fp8_kv_cache( + kv_c_and_k_pe_cache, + chunk_workspace, + chunk.block_table, + chunk.seq_lens, + chunk.workspace_starts, + len(chunk.block_table), + ) + + chunk_q = q[chunk.tokens_slice] + chunk_topk_indices_workspace = topk_indices[chunk.tokens_slice] + + attn_out[chunk.tokens_slice] = self._bf16_flash_mla_kernel( + chunk_q, + chunk_workspace, + chunk_topk_indices_workspace, + ) + + return attn_out + + def _forward_fp8_kv_mixed_batch( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + attn_metadata: FlashMLASparseMetadata, + ) -> torch.Tensor: + """Mixed batch FP8 forward path that treats all tokens as one batch. + + This is equivalent to main branch's approach and avoids the BF16 + prefill kernel which has head padding overhead when num_heads is small. + Used when use_mixed_batch is True. + """ + # Convert per-request indices to global slots (decode) or workspace + # offsets (prefill). + topk_indices = triton_convert_req_index_to_global_index( + attn_metadata.req_id_per_token, + attn_metadata.block_table, + topk_indices, + BLOCK_SIZE=attn_metadata.block_size, + NUM_TOPK_TOKENS=topk_indices.shape[1], + ) + + assert attn_metadata.fp8_extra_metadata is not None + assert isinstance( + attn_metadata.fp8_extra_metadata, FlashMLASparseMetadata.FP8KernelMetadata + ) + fp8_metadata = attn_metadata.fp8_extra_metadata + + _attn_out, _ = self._fp8_flash_mla_kernel( + q=q.unsqueeze(0), # unsqueeze to add batch_dim: (T, H, D) -> (1, T, H, D) + kv_c_and_k_pe_cache=kv_c_and_k_pe_cache, + topk_indices=topk_indices.unsqueeze(0), # (T, topk) -> (1, T, topk) + kernel_metadata=fp8_metadata, + ) + + # Output is (1, T, H, D_v), squeeze back to (T, H, D_v) + return _attn_out.squeeze(0) + + def _fp8_flash_mla_kernel( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + kernel_metadata: FlashMLASparseMetadata.FP8KernelMetadata, + ) -> torch.Tensor: + return flash_mla_with_kvcache( + q=q, + k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2), + block_table=kernel_metadata.dummy_block_table, + head_dim_v=512, + cache_seqlens=kernel_metadata.cache_lens, + tile_scheduler_metadata=kernel_metadata.scheduler_metadata, + num_splits=kernel_metadata.num_splits, + is_fp8_kvcache=True, + indices=topk_indices, + softmax_scale=self.softmax_scale, + ) + + def _bf16_flash_mla_kernel( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, ) -> torch.Tensor: num_tokens = q.shape[0] kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view( @@ -445,31 +944,6 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): output = output[:, : self.num_heads, :] return output - def _forward_fp8_kv( - self, - q: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - topk_indices: torch.Tensor, - attn_metadata: FlashMLASparseMetadata, - ) -> torch.Tensor: - assert attn_metadata.fp8_extra_metadata is not None - extra_metadata = attn_metadata.fp8_extra_metadata - - _attn_out, _ = flash_mla_with_kvcache( - q=q.unsqueeze(0), # unsqueeze to add batch_dim - k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2), - block_table=extra_metadata.dummy_block_table, - head_dim_v=512, - cache_seqlens=extra_metadata.cache_lens, - tile_scheduler_metadata=extra_metadata.scheduler_metadata, - num_splits=extra_metadata.num_splits, - is_fp8_kvcache=True, - indices=topk_indices.unsqueeze(0), # unsqueeze to add batch_dim - softmax_scale=self.softmax_scale, - ) - - return _attn_out - def forward( self, layer: AttentionLayer, @@ -477,7 +951,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): k_c_normed: torch.Tensor, # key in unified attn k_pe: torch.Tensor, # value in unified attn kv_cache: torch.Tensor, - attn_metadata: FlashMLASparseMetadata, + attn_metadata: FlashMLASparseMetadata | None, output: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, @@ -493,6 +967,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): ) if attn_metadata is None: + # Dummy run - no need to allocate buffers # The zero fill is required when used with DP + EP # to ensure all ranks within a DP group compute the # same expert outputs. @@ -505,6 +980,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): q = q[:num_actual_toks, ...] k_c_normed = k_c_normed[:num_actual_toks, ...] k_pe = k_pe[:num_actual_toks, ...] + topk_indices = self.topk_indices_buffer[:num_actual_toks] q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) # Convert from (B, N, P) to (N, B, P) @@ -514,16 +990,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): # Convert from (N, B, L) to (B, N, L) ql_nope = ql_nope.transpose(0, 1) - topk_indices = self.topk_indices_buffer[:num_actual_toks] - - # TODO: handle index / kv_cache correctly - topk_indices_global = triton_convert_req_index_to_global_index( - attn_metadata.req_id_per_token, - attn_metadata.block_table, - topk_indices, - BLOCK_SIZE=attn_metadata.block_size, - NUM_TOPK_TOKENS=attn_metadata.topk_tokens, - ) + use_fp8_cache = self.kv_cache_dtype == "fp8_ds_mla" q = torch.cat([ql_nope, q_pe], dim=-1) @@ -538,13 +1005,15 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): scale=layer._k_scale, ) - if self.kv_cache_dtype != "fp8_ds_mla": - attn_out = self._forward_bf16_kv( - q, kv_cache, topk_indices_global, attn_metadata + if not use_fp8_cache: + attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices, attn_metadata) + elif attn_metadata.fp8_use_mixed_batch: + attn_out = self._forward_fp8_kv_mixed_batch( + q, kv_cache, topk_indices, attn_metadata ) else: - attn_out = self._forward_fp8_kv( - q, kv_cache, topk_indices_global, attn_metadata + attn_out = self._forward_fp8_kv_separate_prefill_decode( + q, kv_cache, topk_indices, attn_metadata ) self._v_up_proj(attn_out, out=output[:num_actual_toks]) diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 77f1ba00d5b04..d0696f60a08c7 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -18,6 +18,7 @@ from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, split_decodes_and_prefills, + split_prefill_chunks, ) logger = init_logger(__name__) @@ -176,40 +177,15 @@ def kv_spans_from_batches( def get_max_prefill_buffer_size(vllm_config: VllmConfig): max_model_len = vllm_config.model_config.max_model_len - # NOTE(Chen): 2 is a magic number for controlling the prefill buffer size. - # May be tuned later. - return max_model_len * 2 - - -def split_prefill_chunks( - seq_lens_cpu: torch.Tensor, max_prefill_buffer_size: int, reqs_start: int -) -> list[tuple[int, int]]: - """ - Split the prefill chunks into a list of tuples of (reqs_start, reqs_end) - such that the total sequence length of each chunk is less than the - maximum prefill buffer size. - - Args: - seq_lens_cpu: The sequence lengths of the prefill requests. - max_prefill_buffer_size: The maximum prefill buffer size. - reqs_start: The start index of the prefill requests. - - Returns: - A list of tuples of (reqs_start, reqs_end). - """ - chunk_seq_ids = [] - total_seq_lens = 0 - for i in range(reqs_start, len(seq_lens_cpu)): - cur_seq_len = seq_lens_cpu[i].item() - assert cur_seq_len <= max_prefill_buffer_size - total_seq_lens += cur_seq_len - if total_seq_lens > max_prefill_buffer_size: - chunk_seq_ids.append((reqs_start, i)) - reqs_start = i - total_seq_lens = cur_seq_len - if total_seq_lens > 0: - chunk_seq_ids.append((reqs_start, len(seq_lens_cpu))) - return chunk_seq_ids + # NOTE(Chen): 40 is a magic number for controlling the prefill buffer size. + # Each entry is 128 fp8 bytes and 4 scale bytes for a total of 132 bytes. + # The flashmla_sparse backend uses a workspace size of 5 * max_model_len. + # The memory usage of the workspace there is 576 * 2 bytes; so we size this as + # (576 * 2 // 132) * 5 = 40 to maximize this workspace size while still fitting + # within the flashmla_sparse workspace. + # For DeepSeek-V3.2, the max_model_len is 163840. + # 40 * 163840 * 132 = 865075200 bytes = 825 MB + return max_model_len * 40 class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): @@ -302,9 +278,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): prefill_metadata = None if num_prefills > 0: chunk_seq_ids = split_prefill_chunks( - common_attn_metadata.seq_lens_cpu, + common_attn_metadata.seq_lens_cpu[num_decodes:], self.max_prefill_buffer_size, - num_decodes, + request_offset=num_decodes, ) chunks = [ self.build_one_prefill_chunk( diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 00a0a77a1c2f7..589d6ef2f6348 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -15,6 +15,7 @@ from vllm.v1.attention.backends.mla.common import ( MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder, + QueryLenSupport, ) from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec @@ -51,6 +52,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata): qo_indptr: torch.Tensor | None = None # The dtype of MLA out tensor attn_out_dtype: torch.dtype = torch.bfloat16 + # The max query output length: int + max_qo_len: int | None = None class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): @@ -60,9 +63,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): # TODO(luka, lucas): audit this as part of: # https://github.com/vllm-project/vllm/issues/22945 - _cudagraph_support: ClassVar[AttentionCGSupport] = ( - AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE - ) + _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM def __init__( self, @@ -97,8 +99,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): max_num_reqs, dtype=torch.int32, device=device ) - self.qo_indptr = torch.arange( - 0, max_num_reqs + 1, dtype=torch.int32, device=device + self.qo_indptr = torch.zeros( + max_num_reqs + 1, dtype=torch.int32, device=device ) def _build_decode( @@ -128,6 +130,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): seq_lens_device.cumsum(dim=0, dtype=torch.int32), ] ) + qo_len = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + max_qo_len = qo_len.max().item() if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): num_actual_pages = paged_kv_indices.size(0) @@ -150,6 +154,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): self.paged_kv_last_page_len[num_reqs:].fill_(1) paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs] + self.qo_indptr[: 1 + num_reqs].copy_( + query_start_loc_device, non_blocking=True + ) + self.qo_indptr[1 + num_reqs :] = query_start_loc_device[-1] qo_indptr = self.qo_indptr[: 1 + num_reqs] else: @@ -165,6 +173,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): paged_kv_last_page_len=paged_kv_last_page_len, qo_indptr=qo_indptr, dcp_tot_seq_lens=dcp_tot_seq_lens_device, + max_qo_len=max_qo_len, attn_out_dtype=self.decode_attn_out_dtype, ) @@ -255,16 +264,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) - # max_seqlen_qo must be 1 except for MTP - # TODO: Find the best value for MTP - max_seqlen_qo = 1 rocm_aiter_ops.mla_decode_fwd( q, kv_buffer, o, self.scale, attn_metadata.decode.qo_indptr, - max_seqlen_qo, + attn_metadata.decode.max_qo_len, attn_metadata.decode.paged_kv_indptr, attn_metadata.decode.paged_kv_indices, attn_metadata.decode.paged_kv_last_page_len, diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 868143cc192e7..e2410a70b1a63 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -165,7 +165,7 @@ class RocmAttentionBackend(AttentionBackend): raise ValueError( f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {cls.get_supported_head_sizes()}. " - "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " + "Set --attention-config.backend=FLEX_ATTENTION to use " "FlexAttention backend which supports all head sizes." ) diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py index de0cb73db0917..c8fe0faf71088 100644 --- a/vllm/v1/attention/backends/short_conv_attn.py +++ b/vllm/v1/attention/backends/short_conv_attn.py @@ -83,11 +83,10 @@ class ShortConvAttentionMetadataBuilder( and num_decodes <= self.decode_cudagraph_max_bs and self.compilation_config.cudagraph_mode.has_full_cudagraphs() ): - num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes) self.state_indices_tensor[:num_decodes].copy_( state_indices_tensor, non_blocking=True ) - state_indices_tensor = self.state_indices_tensor[:num_input_tokens] + state_indices_tensor = self.state_indices_tensor[:num_decode_tokens] state_indices_tensor[num_decodes:] = PAD_SLOT_ID attn_metadata = ShortConvAttentionMetadata( diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index d051a89f03bb4..7bea3862a03f9 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -17,7 +17,7 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import ( triton_reshape_and_cache_flash, ) from vllm.attention.ops.triton_unified_attention import unified_attention -from vllm.config import VllmConfig +from vllm.config import CUDAGraphMode, VllmConfig from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -26,6 +26,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ) from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability +from vllm.utils.math_utils import next_power_of_2 from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, @@ -36,6 +37,11 @@ from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) +# constants +MIN_LAUNCH_GRID_SIZE_2D = 128 # Minimum launch grid size of 2D kernel +NUM_PAR_SOFTMAX_SEGMENTS = 16 # Number of parallel tiled softmax segments + + @dataclass class TritonAttentionMetadata: # NOTE(sang): Definition of context_len, query_len, and seq_len. @@ -54,6 +60,12 @@ class TritonAttentionMetadata: block_table: torch.Tensor slot_mapping: torch.Tensor + seq_threshold_3D: int + num_par_softmax_segments: int + softmax_segm_output: torch.Tensor + softmax_segm_max: torch.Tensor + softmax_segm_expsum: torch.Tensor + # For cascade attention. use_cascade: bool common_prefix_len: int @@ -87,6 +99,60 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config) self.headdim = model_config.get_head_size() + # Check if CUDA Graphs are enabled for decode + self.decode_cudagraph_enabled = ( + self.vllm_config.compilation_config.cudagraph_mode + in ( + CUDAGraphMode.FULL_AND_PIECEWISE, + CUDAGraphMode.FULL_DECODE_ONLY, + CUDAGraphMode.FULL, + ) + ) + + # The launch grid for the 2D kernel is defined as (num_q_blocks, num_heads_kv). + # A lower bound for num_q_blocks is the number of sequences. + # To ensure the minimum launch grid size is achieved, the number of sequences + # must be at least equal to the threshold below. + # If this threshold is not reached (i.e., the batch size is not large enough), + # the 3D kernel will be selected instead. + self.seq_threshold_3D = MIN_LAUNCH_GRID_SIZE_2D // self.num_heads_kv + + # Modify the threshold if needed. + if self.decode_cudagraph_enabled: + capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes + assert capture_sizes, "CUDA Graphs enabled but no capture sizes specified." + + # Select the CUDA Graph capture size closest to self.seq_threshold_3D + # as threshold. This ensures that each captured graph covers the + # correct execution path. + self.seq_threshold_3D = min( + capture_sizes, + key=lambda x: abs(x - self.seq_threshold_3D), + ) + + self.num_par_softmax_segments = NUM_PAR_SOFTMAX_SEGMENTS + headdim_padded = next_power_of_2(self.headdim) + self.softmax_segm_output = torch.empty( + ( + self.seq_threshold_3D, + self.num_heads_q, + self.num_par_softmax_segments, + headdim_padded, + ), + dtype=torch.float32, + device=device, + ) + self.softmax_segm_max = torch.empty( + (self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments), + dtype=torch.float32, + device=device, + ) + self.softmax_segm_expsum = torch.empty( + (self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments), + dtype=torch.float32, + device=device, + ) + def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata ) -> TritonAttentionMetadata: @@ -143,6 +209,11 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet prefix_kv_lens=prefix_kv_lens, suffix_kv_lens=suffix_kv_lens, prefix_scheduler_metadata=prefix_scheduler_metadata, + seq_threshold_3D=self.seq_threshold_3D, + num_par_softmax_segments=self.num_par_softmax_segments, + softmax_segm_output=self.softmax_segm_output, + softmax_segm_max=self.softmax_segm_max, + softmax_segm_expsum=self.softmax_segm_expsum, ) return attn_metadata @@ -210,9 +281,6 @@ class TritonAttentionImpl(AttentionImpl): def fused_output_quant_supported(self, quant_key: QuantKey): return quant_key == kFp8StaticTensorSym - def supports_quant_query_input(self) -> bool: - return current_platform.is_cuda() - def __init__( self, num_heads: int, @@ -262,6 +330,8 @@ class TritonAttentionImpl(AttentionImpl): f"num_heads: {num_heads}." ) + self.supports_quant_query_input = current_platform.is_cuda() + def forward( self, layer: torch.nn.Module, @@ -350,6 +420,12 @@ class TritonAttentionImpl(AttentionImpl): max_seqlen_k = attn_metadata.max_seq_len block_table = attn_metadata.block_table + seq_threshold_3D = attn_metadata.seq_threshold_3D + num_par_softmax_segments = attn_metadata.num_par_softmax_segments + softmax_segm_output = attn_metadata.softmax_segm_output + softmax_segm_max = attn_metadata.softmax_segm_max + softmax_segm_expsum = attn_metadata.softmax_segm_expsum + descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2]) unified_attention( @@ -370,6 +446,11 @@ class TritonAttentionImpl(AttentionImpl): q_descale=None, # Not supported k_descale=layer._k_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape), + seq_threshold_3D=seq_threshold_3D, + num_par_softmax_segments=num_par_softmax_segments, + softmax_segm_output=softmax_segm_output, + softmax_segm_max=softmax_segm_max, + softmax_segm_expsum=softmax_segm_expsum, sinks=self.sinks, output_scale=output_scale, ) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 8edfbb5140bc9..1cbe929fc57a8 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -18,7 +18,7 @@ from typing import ( import numpy as np import torch -from typing_extensions import runtime_checkable +from typing_extensions import deprecated, runtime_checkable from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.utils.math_utils import cdiv @@ -66,11 +66,6 @@ class CommonAttentionMetadata: """(batch_size + 1,), the start location of each request in query Tensor""" seq_lens: torch.Tensor - seq_lens_cpu: torch.Tensor - """(batch_size,), the length of each request including both computed tokens - and newly scheduled tokens""" - - num_computed_tokens_cpu: torch.Tensor """(batch_size,), the number of computed tokens for each request""" num_reqs: int @@ -81,7 +76,7 @@ class CommonAttentionMetadata: max_query_len: int """Longest query in batch""" max_seq_len: int - """Longest context length in batch""" + """Longest context length (may be an upper bound)""" block_table_tensor: torch.Tensor slot_mapping: torch.Tensor @@ -100,6 +95,40 @@ class CommonAttentionMetadata: dcp_local_seq_lens_cpu: torch.Tensor | None = None """Sequence lengths of the local rank in decode context parallelism world""" + # WARNING: Deprecated fields. Will be removed in a future release (v0.14.0) + _seq_lens_cpu: torch.Tensor | None = None + _num_computed_tokens_cpu: torch.Tensor | None = None + + @property + @deprecated( + """ + Prefer using device seq_lens directly to avoid implicit H<>D sync. + If a CPU copy is needed, use `seq_lens.cpu()` instead. + Will be removed in a future release (v0.14.0) + """ + ) + def seq_lens_cpu(self) -> torch.Tensor: + if self._seq_lens_cpu is None: + self._seq_lens_cpu = self.seq_lens.to("cpu") + return self._seq_lens_cpu + + @property + @deprecated( + """ + Prefer using device seq_lens directly to avoid implicit H<>D sync which breaks full + async scheduling. If a CPU copy is needed, it can be derived from + query_start_loc_cpu and seq_lens. + Will be removed in a future release (v0.14.0) + """ + ) + def num_computed_tokens_cpu(self) -> torch.Tensor: + if self._num_computed_tokens_cpu is None: + query_seq_lens = ( + self.query_start_loc_cpu[1:] - self.query_start_loc_cpu[:-1] + ) + self._num_computed_tokens_cpu = self.seq_lens_cpu - query_seq_lens + return self._num_computed_tokens_cpu + # TODO(lucas): remove once we have FULL-CG spec-decode support def unpadded( self, num_actual_tokens: int, num_actual_reqs: int @@ -109,8 +138,12 @@ class CommonAttentionMetadata: query_start_loc=self.query_start_loc[: num_actual_reqs + 1], query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1], seq_lens=self.seq_lens[:num_actual_reqs], - seq_lens_cpu=self.seq_lens_cpu[:num_actual_reqs], - num_computed_tokens_cpu=self.num_computed_tokens_cpu[:num_actual_reqs], + _seq_lens_cpu=self._seq_lens_cpu[:num_actual_reqs] + if self._seq_lens_cpu is not None + else None, + _num_computed_tokens_cpu=self._num_computed_tokens_cpu[:num_actual_reqs] + if self._num_computed_tokens_cpu is not None + else None, num_reqs=num_actual_reqs, num_actual_tokens=num_actual_tokens, max_query_len=self.max_query_len, @@ -168,10 +201,11 @@ def _make_metadata_with_slice( ) # NOTE: last token can be outside of the last request if we have CG padding. - # If the "middle" request has tokens in both ubatches, we have to split it. - # If ubatch_slice is the first ubatch then we will be splitting the last - # request. If it's the second microbatch, then we will be splitting the - # first request + # If the request is split across ubatches, we have to adjust the metadata. + # splits_first_request: The first request in this slice is the continuation of + # a request that started in a previous slice. + # splits_last_request: The last request in this slice continues into the + # next slice. splits_first_request = first_tok > start_locs[first_req] splits_last_request = last_tok < start_locs[last_req + 1] - 1 @@ -192,7 +226,10 @@ def _make_metadata_with_slice( seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice] if splits_last_request: - tokens_skipped = query_start_loc_cpu[-1] - token_slice.stop + # NOTE: We use start_locs (the original query_start_loc_cpu) to calculate + # the tokens skipped because query_start_loc_cpu might have been modified + # if splits_first_request is True. + tokens_skipped = start_locs[last_req + 1] - token_slice.stop query_start_loc[-1] -= tokens_skipped query_start_loc_cpu[-1] -= tokens_skipped @@ -224,14 +261,14 @@ def _make_metadata_with_slice( query_start_loc=query_start_loc, query_start_loc_cpu=query_start_loc_cpu, seq_lens=seq_lens, - seq_lens_cpu=seq_lens_cpu, - num_computed_tokens_cpu=num_computed_tokens_cpu, num_reqs=num_requests, num_actual_tokens=num_actual_tokens, max_query_len=max_query_len, max_seq_len=max_seq_len, block_table_tensor=block_table_tensor, slot_mapping=slot_mapping, + _seq_lens_cpu=seq_lens_cpu, + _num_computed_tokens_cpu=num_computed_tokens_cpu, ) @@ -689,9 +726,7 @@ def make_local_attention_virtual_batches( return CommonAttentionMetadata( query_start_loc_cpu=query_start_loc_cpu, query_start_loc=query_start_loc_cpu.to(device=device, non_blocking=True), - seq_lens_cpu=seq_lens_cpu, seq_lens=seq_lens_cpu.to(device=device, non_blocking=True), - num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local), num_reqs=len(seq_lens_cpu), num_actual_tokens=common_attn_metadata.num_actual_tokens, max_query_len=seqlens_q_local.max(), @@ -699,6 +734,8 @@ def make_local_attention_virtual_batches( block_table_tensor=block_table_local, slot_mapping=common_attn_metadata.slot_mapping, causal=True, + _seq_lens_cpu=seq_lens_cpu, + _num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local), ) @@ -719,7 +756,6 @@ def make_kv_sharing_fast_prefill_common_attn_metadata( logits_indices = logits_indices_padded[:num_logits_indices] num_reqs = common_attn_metadata.num_reqs query_start_loc = common_attn_metadata.query_start_loc - seq_lens = common_attn_metadata.seq_lens # Example inputs # num_reqs: 3 # generation_indices: [14, 18, 19, 27] @@ -748,9 +784,7 @@ def make_kv_sharing_fast_prefill_common_attn_metadata( common_attn_metadata = CommonAttentionMetadata( query_start_loc=decode_query_start_loc, query_start_loc_cpu=decode_query_start_loc.to("cpu", non_blocking=True), - seq_lens=seq_lens, - seq_lens_cpu=seq_lens.to("cpu", non_blocking=True), - num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, + seq_lens=common_attn_metadata.seq_lens, num_reqs=num_reqs, num_actual_tokens=total_num_decode_tokens, max_query_len=decode_max_query_len, @@ -758,6 +792,8 @@ def make_kv_sharing_fast_prefill_common_attn_metadata( block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping, causal=True, + _seq_lens_cpu=common_attn_metadata._seq_lens_cpu, + _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu, ) return common_attn_metadata @@ -883,11 +919,15 @@ def split_decodes_and_prefills( return 0, num_reqs, 0, num_tokens if require_uniform: + # check if we are in a padded uniform batch; this is used for full-CGs, some + # requests may have a query length of 0 but since they are padding its fine + # to treat them as decodes (ensures num_decodes matches the captured size) + if torch.all((query_lens == query_lens[0]) | (query_lens == 0)): + assert num_reqs * query_lens[0] == num_tokens, "tokens not padded correctly" + return num_reqs, 0, num_tokens, 0 # all decodes is_prefill = query_lens != query_lens[0] else: - # 0-query len indicates a padded request; leave this at the back - # of the batch with the prefills - is_prefill = (query_lens > decode_threshold) | (query_lens == 0) + is_prefill = query_lens > decode_threshold if not torch.any(is_prefill): return num_reqs, 0, num_tokens, 0 @@ -901,6 +941,33 @@ def split_decodes_and_prefills( return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) +def split_prefill_chunks( + seq_lens_cpu: torch.Tensor, workspace_size: int, request_offset: int = 0 +) -> list[tuple[int, int]]: + """ + Split the prefill requests into chunks such that the total sequence length + of each chunk is less than or equal to the workspace size. + + Args: + seq_lens_cpu: The sequence lengths of the prefill requests on CPU. + workspace_size: The maximum workspace size (in tokens) per chunk. + request_offset: The offset to add to the request indices. + Returns: + A list of tuples of (reqs_start, reqs_end) representing chunk boundaries. + """ + chunk_bounds = [] + i, n = 0, len(seq_lens_cpu) + assert torch.all(seq_lens_cpu <= workspace_size).item() + + while i < n: + start, chunk_total = i, 0 + while i < n and (chunk_total + (s := seq_lens_cpu[i].item())) <= workspace_size: + chunk_total += s + i += 1 + chunk_bounds.append((start + request_offset, i + request_offset)) + return chunk_bounds + + def reorder_batch_to_split_decodes_and_prefills( input_batch: "InputBatch", scheduler_output: "SchedulerOutput", diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index ece9e8dfb2744..bcd872c2f29a2 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -11,6 +11,7 @@ from vllm.distributed.kv_events import ( KVCacheEvent, ) from vllm.logger import init_logger +from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector from vllm.v1.core.kv_cache_utils import ( BlockHash, BlockHashList, @@ -140,6 +141,7 @@ class BlockPool: where different KV cache groups have different block sizes, the actual block size can be a multiple of hash_block_size. enable_kv_cache_events: Whether to enable kv cache events. + metrics_collector: Optional metrics collector for tracking block residency. """ def __init__( @@ -148,6 +150,7 @@ class BlockPool: enable_caching: bool, hash_block_size: int, enable_kv_cache_events: bool = False, + metrics_collector: KVCacheMetricsCollector | None = None, ): assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0 self.num_gpu_blocks = num_gpu_blocks @@ -174,6 +177,8 @@ class BlockPool: self.enable_kv_cache_events = enable_kv_cache_events self.kv_event_queue: list[KVCacheEvent] = [] + self.metrics_collector = metrics_collector + def get_cached_block( self, block_hash: BlockHash, kv_cache_group_ids: list[int] ) -> list[KVCacheBlock] | None: @@ -308,10 +313,14 @@ class BlockPool: self._maybe_evict_cached_block(block) assert block.ref_cnt == 0 block.ref_cnt += 1 + if self.metrics_collector: + self.metrics_collector.on_block_allocated(block) else: for block in ret: assert block.ref_cnt == 0 block.ref_cnt += 1 + if self.metrics_collector: + self.metrics_collector.on_block_allocated(block) return ret def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: @@ -325,6 +334,10 @@ class BlockPool: Returns: True if the block is evicted, False otherwise. """ + # Clean up metrics tracking first to prevent leaks + if self.metrics_collector: + self.metrics_collector.on_block_evicted(block) + block_hash = block.block_hash if block_hash is None: # The block doesn't have hash, eviction is not needed @@ -365,6 +378,8 @@ class BlockPool: if block.ref_cnt == 0 and not block.is_null: self.free_block_queue.remove(block) block.ref_cnt += 1 + if self.metrics_collector: + self.metrics_collector.on_block_accessed(block) def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: """Free a list of blocks. The blocks should be ordered by their @@ -382,6 +397,25 @@ class BlockPool: [block for block in blocks_list if block.ref_cnt == 0 and not block.is_null] ) + def evict_blocks(self, block_ids: set[int]) -> None: + """evict blocks from the prefix cache by their block IDs. + + only evicts blocks that are currently cached (have a hash). blocks + with ref_cnt > 0 are not freed from the block pool, only evicted + from the prefix cache hash table. + + Args: + block_ids: Set of block IDs to evict from cache. + """ + for block_id in block_ids: + assert block_id < len(self.blocks), ( + f"Invalid block_id {block_id} >= {len(self.blocks)}. " + f"This indicates a bug in the KV connector - workers should " + f"only report block IDs that were allocated by the scheduler." + ) + block = self.blocks[block_id] + self._maybe_evict_cached_block(block) + def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF flows to invalid prefix caching after the weights are updated, @@ -407,6 +441,9 @@ class BlockPool: for block in self.blocks: block.reset_hash() + if self.metrics_collector: + self.metrics_collector.reset() + logger.info("Successfully reset prefix cache") if self.enable_kv_cache_events: diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index 3959e9a59a53b..50f738713590b 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -341,3 +341,56 @@ def compute_mm_encoder_budget( ) return encoder_compute_budget, encoder_cache_size + + +# NOTE (NickLucche): Temporary implementation for encoder-decoder models that only +# use the manager for scheduling purposes. Encoder-decoder models will eventually +# utilize the cache and this class will fold into EncoderCacheManager, as +# differences with MM models shrink. +class EncoderDecoderCacheManager(EncoderCacheManager): + def __init__(self, cache_size: int): + self.cache_size = cache_size + self.num_free_slots = cache_size + self.freed: list[str] = [] + + def check_and_update_cache(self, request: Request, input_id: int) -> bool: + return False + + def can_allocate( + self, + request: Request, + input_id: int, + encoder_compute_budget: int, + num_tokens_to_schedule: int, + ) -> bool: + num_tokens = request.get_num_encoder_tokens(input_id) + # Not enough compute budget + if num_tokens > encoder_compute_budget: + return False + + num_tokens += num_tokens_to_schedule + # Enough free slots + return num_tokens <= self.num_free_slots + + def allocate(self, request: Request, input_id: int) -> None: + num_encoder_tokens = request.get_num_encoder_tokens(input_id) + self.num_free_slots -= num_encoder_tokens + + mm_hash = request.mm_features[input_id].identifier + self.freed.append(mm_hash) + + def free(self, request: Request) -> None: + for input_id in range(len(request.mm_features)): + self.free_encoder_input(request, input_id) + + def get_cached_input_ids(self, request: Request) -> set[int]: + return set(range(len(request.mm_features))) + + def get_freed_mm_hashes(self) -> list[str]: + freed = self.freed + self.freed = [] + return freed + + def free_encoder_input(self, request: Request, input_id: int) -> None: + num_tokens = request.get_num_encoder_tokens(input_id) + self.num_free_slots += num_tokens diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index fd1ec8e27fba2..4b09b76c1c591 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -5,6 +5,7 @@ from collections.abc import Sequence from math import lcm from vllm.v1.core.block_pool import BlockPool +from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector from vllm.v1.core.kv_cache_utils import ( BlockHash, BlockHashList, @@ -39,6 +40,7 @@ class KVCacheCoordinator(ABC): dcp_world_size: int, pcp_world_size: int, hash_block_size: int, + metrics_collector: KVCacheMetricsCollector | None = None, ): self.kv_cache_config = kv_cache_config self.max_model_len = max_model_len @@ -49,6 +51,7 @@ class KVCacheCoordinator(ABC): enable_caching, hash_block_size, enable_kv_cache_events, + metrics_collector, ) # Needs special handling for find_longest_cache_hit if eagle is enabled @@ -228,6 +231,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator): dcp_world_size: int, pcp_world_size: int, hash_block_size: int, + metrics_collector: KVCacheMetricsCollector | None = None, ): super().__init__( kv_cache_config, @@ -238,6 +242,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator): dcp_world_size=dcp_world_size, pcp_world_size=pcp_world_size, hash_block_size=hash_block_size, + metrics_collector=metrics_collector, ) self.num_single_type_manager = len(self.single_type_managers) @@ -272,6 +277,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): dcp_world_size: int, pcp_world_size: int, hash_block_size: int, + metrics_collector: KVCacheMetricsCollector | None = None, ): super().__init__( kv_cache_config, @@ -282,6 +288,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): dcp_world_size=dcp_world_size, pcp_world_size=pcp_world_size, hash_block_size=hash_block_size, + metrics_collector=metrics_collector, ) self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec self.block_size = self.kv_cache_spec.block_size @@ -338,6 +345,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): dcp_world_size: int, pcp_world_size: int, hash_block_size: int, + metrics_collector: KVCacheMetricsCollector | None = None, ): super().__init__( kv_cache_config, @@ -348,6 +356,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): dcp_world_size=dcp_world_size, pcp_world_size=pcp_world_size, hash_block_size=hash_block_size, + metrics_collector=metrics_collector, ) # hash_block_size: the block size used to compute block hashes. # The actual block size usually equals hash_block_size, but in cases where @@ -523,6 +532,7 @@ def get_kv_cache_coordinator( dcp_world_size: int, pcp_world_size: int, hash_block_size: int, + metrics_collector: KVCacheMetricsCollector | None = None, ) -> KVCacheCoordinator: if not enable_caching: return KVCacheCoordinatorNoPrefixCache( @@ -530,9 +540,10 @@ def get_kv_cache_coordinator( max_model_len, use_eagle, enable_kv_cache_events, - dcp_world_size, - pcp_world_size, - hash_block_size, + dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, + hash_block_size=hash_block_size, + metrics_collector=metrics_collector, ) if len(kv_cache_config.kv_cache_groups) == 1: return UnitaryKVCacheCoordinator( @@ -541,9 +552,10 @@ def get_kv_cache_coordinator( use_eagle, enable_caching, enable_kv_cache_events, - dcp_world_size, - pcp_world_size, - hash_block_size, + dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, + hash_block_size=hash_block_size, + metrics_collector=metrics_collector, ) return HybridKVCacheCoordinator( kv_cache_config, @@ -551,7 +563,8 @@ def get_kv_cache_coordinator( use_eagle, enable_caching, enable_kv_cache_events, - dcp_world_size, - pcp_world_size, - hash_block_size, + dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, + hash_block_size=hash_block_size, + metrics_collector=metrics_collector, ) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index b061e5cc831dd..13086a66f6ea6 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -9,6 +9,7 @@ from typing import Literal, overload from vllm.distributed.kv_events import KVCacheEvent from vllm.logger import init_logger from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator +from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector from vllm.v1.core.kv_cache_utils import KVCacheBlock from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import PrefixCacheStats @@ -102,12 +103,14 @@ class KVCacheManager: enable_kv_cache_events: bool = False, dcp_world_size: int = 1, pcp_world_size: int = 1, + metrics_collector: KVCacheMetricsCollector | None = None, ) -> None: self.max_model_len = max_model_len self.enable_caching = enable_caching self.use_eagle = use_eagle self.log_stats = log_stats + self.metrics_collector = metrics_collector # FIXME: make prefix cache stats conditional on log_stats. We still need # this comment because when the log stats is enabled there are still # potential configs we could expose in the future. @@ -122,6 +125,7 @@ class KVCacheManager: dcp_world_size=dcp_world_size, pcp_world_size=pcp_world_size, hash_block_size=hash_block_size, + metrics_collector=self.metrics_collector, ) self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) self.block_pool = self.coordinator.block_pool @@ -226,6 +230,9 @@ class KVCacheManager: delay_cache_blocks: Whether to skip caching the blocks. This is used by P/D when allocating blocks used in a KV transfer which will complete in a future step. + num_encoder_tokens: The number of encoder tokens to allocate for + cross-attention in encoder-decoder models(e.g., Whisper). + For decoder-only models, this should be 0. Blocks layout: ``` @@ -326,6 +333,14 @@ class KVCacheManager: """ self.coordinator.free(request.request_id) + def evict_blocks(self, block_ids: set[int]) -> None: + """evict blocks from the prefix cache by their block IDs. + + Args: + block_ids: Set of block IDs to evict from cache. + """ + self.block_pool.evict_blocks(block_ids) + def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF flows to invalidate prefix caching after the weights are updated, diff --git a/vllm/v1/core/kv_cache_metrics.py b/vllm/v1/core/kv_cache_metrics.py new file mode 100644 index 0000000000000..a6dbf5b1e4034 --- /dev/null +++ b/vllm/v1/core/kv_cache_metrics.py @@ -0,0 +1,96 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""KV cache metrics tracking.""" + +import random +import time +from collections import deque +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from vllm.v1.core.kv_cache_utils import KVCacheBlock + +from vllm.v1.metrics.stats import KVCacheEvictionEvent + + +class BlockMetricsState: + """Tracks lifecycle metrics for a single KV cache block.""" + + def __init__(self): + now_ns = time.monotonic_ns() + self.birth_time_ns = now_ns + self.last_access_ns = now_ns + # Bounded to prevent unbounded growth if a block is accessed many times. + self.access_history: deque[int] = deque(maxlen=4) + + def record_access(self) -> None: + now_ns = time.monotonic_ns() + self.last_access_ns = now_ns + self.access_history.append(now_ns) + + def get_lifetime_seconds(self) -> float: + now_ns = time.monotonic_ns() + return (now_ns - self.birth_time_ns) / 1e9 + + def get_idle_time_seconds(self) -> float: + now_ns = time.monotonic_ns() + return (now_ns - self.last_access_ns) / 1e9 + + def get_reuse_gaps_seconds(self) -> list[float]: + if len(self.access_history) < 2: + return [] + history = list(self.access_history) + return [(history[i] - history[i - 1]) / 1e9 for i in range(1, len(history))] + + +class KVCacheMetricsCollector: + """Collects KV cache residency metrics with sampling.""" + + def __init__(self, sample_rate: float = 0.01): + assert 0 < sample_rate <= 1.0, ( + f"sample_rate must be in (0, 1.0], got {sample_rate}" + ) + self.sample_rate = sample_rate + + self.block_metrics: dict[int, BlockMetricsState] = {} + + self._eviction_events: list[KVCacheEvictionEvent] = [] + + def should_sample_block(self) -> bool: + return random.random() < self.sample_rate + + def on_block_allocated(self, block: "KVCacheBlock") -> None: + if self.should_sample_block(): + self.block_metrics[block.block_id] = BlockMetricsState() + + def on_block_accessed(self, block: "KVCacheBlock") -> None: + metrics = self.block_metrics.get(block.block_id) + if metrics: + metrics.record_access() + + def on_block_evicted(self, block: "KVCacheBlock") -> None: + metrics = self.block_metrics.pop(block.block_id, None) + if not metrics: + return + + lifetime = metrics.get_lifetime_seconds() + idle_time = metrics.get_idle_time_seconds() + reuse_gaps = tuple(metrics.get_reuse_gaps_seconds()) + + self._eviction_events.append( + KVCacheEvictionEvent( + lifetime_seconds=lifetime, + idle_seconds=idle_time, + reuse_gaps_seconds=reuse_gaps, + ) + ) + + def reset(self) -> None: + """Clear all state on cache reset.""" + self.block_metrics.clear() + self._eviction_events.clear() + + def drain_events(self) -> list[KVCacheEvictionEvent]: + events = self._eviction_events + self._eviction_events = [] + return events diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 602eb81beb010..e4360de3717d1 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -12,7 +12,7 @@ from typing import Any, NewType, TypeAlias, overload from vllm import envs from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.utils.hashing import sha256_cbor +from vllm.utils.hashing import sha256_cbor, xxhash_cbor from vllm.utils.math_utils import cdiv from vllm.utils.mem_constants import GiB_bytes from vllm.v1.kv_cache_interface import ( @@ -83,18 +83,19 @@ logger = init_logger(__name__) # # The function `init_none_hash` initializes this variable globally. NONE_HASH: BlockHash +_CBOR_HASH_FUNCTIONS = frozenset({sha256_cbor, xxhash_cbor}) def init_none_hash(hash_fn: Callable[[Any], bytes]): global NONE_HASH hash_seed = os.getenv("PYTHONHASHSEED") - if hash_seed is None and hash_fn is sha256_cbor: + if hash_seed is None and hash_fn in _CBOR_HASH_FUNCTIONS: logger.warning( "PYTHONHASHSEED is not set. This will lead to non-reproducible " - "block-hashes when using sha256_cbor as the hash function." - "Consider setting PYTHONHASHSEED to a fixed value for " - "reproducibility." + "block-hashes when using CBOR-based hash functions such as " + "sha256_cbor or xxhash_cbor. Consider setting PYTHONHASHSEED to a " + "fixed value for reproducibility." ) if hash_seed is None: @@ -686,7 +687,9 @@ def check_enough_kv_cache_memory( raise ValueError( "No available memory for the cache blocks. " "Try increasing `gpu_memory_utilization` when " - "initializing the engine." + "initializing the engine. " + "See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ " + "for more details." ) max_model_len = vllm_config.model_config.max_model_len @@ -710,8 +713,10 @@ def check_enough_kv_cache_memory( f"cache is needed, which is larger than the available KV cache " f"memory ({available_memory / GiB_bytes:.2f} GiB). " f"{estimated_msg} " - f"Try increasing `gpu_memory_utilization` or decreasing " - f"`max_model_len` when initializing the engine." + f"Try increasing `gpu_memory_utilization` or decreasing `max_model_len` " + f"when initializing the engine. " + f"See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ " + f"for more details." ) diff --git a/vllm/v1/core/sched/async_scheduler.py b/vllm/v1/core/sched/async_scheduler.py index 3214f65a09728..df61eebb395e5 100644 --- a/vllm/v1/core/sched/async_scheduler.py +++ b/vllm/v1/core/sched/async_scheduler.py @@ -33,7 +33,7 @@ class AsyncScheduler(Scheduler): # in this scheduling step. request.num_output_placeholders += 1 + cur_num_spec_tokens # Add placeholders for the new tokens in spec_token_ids. - # Wwe will update the actual spec token ids in the worker process. + # We will update the actual spec token ids in the worker process. request.spec_token_ids = [-1] * self.num_spec_tokens scheduler_output.pending_structured_output_tokens = ( @@ -45,6 +45,12 @@ class AsyncScheduler(Scheduler): request: Request, new_token_ids: list[int], ) -> tuple[list[int], bool]: + if request.discard_latest_async_tokens: + # If the request is force preempted in reset_prefix_cache, we + # should discard the latest async token. + request.discard_latest_async_tokens = False + return [], False + status_before_update = request.status new_token_ids, stopped = super()._update_request_with_output( request, new_token_ids diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index 88d99d9402821..596ab05ad320a 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -152,10 +152,18 @@ class SchedulerInterface(ABC): return self.has_unfinished_requests() or self.has_finished_requests() @abstractmethod - def reset_prefix_cache(self) -> bool: + def reset_prefix_cache( + self, reset_running_requests: bool = False, reset_connector: bool = False + ) -> bool: """Reset the prefix cache for KV cache. This is particularly required when the model weights are live-updated. + + Args: + reset_running_requests: If True, all the running requests will be + preempted and moved to the waiting queue. Otherwise, this method + will only reset the KV prefix cache when there is no running request + taking KV cache. """ raise NotImplementedError diff --git a/vllm/v1/core/sched/request_queue.py b/vllm/v1/core/sched/request_queue.py index 7bc1010db23a2..a00ca1912b0f3 100644 --- a/vllm/v1/core/sched/request_queue.py +++ b/vllm/v1/core/sched/request_queue.py @@ -137,31 +137,30 @@ class PriorityRequestQueue(RequestQueue): """ A priority queue that supports heap operations. - Requests with a smaller value of `priority` are processed first. + Respects the ordering defined in the Request class, where + requests with a smaller value of `priority` are processed first. If multiple requests have the same priority, the one with the earlier `arrival_time` is processed first. """ def __init__(self) -> None: - self._heap: list[tuple[int, float, Request]] = [] + self._heap: list[Request] = [] def add_request(self, request: Request) -> None: """Add a request to the queue according to priority policy.""" - heapq.heappush(self._heap, (request.priority, request.arrival_time, request)) + heapq.heappush(self._heap, request) def pop_request(self) -> Request: """Pop a request from the queue according to priority policy.""" if not self._heap: raise IndexError("pop from empty heap") - _, _, request = heapq.heappop(self._heap) - return request + return heapq.heappop(self._heap) def peek_request(self) -> Request: """Peek at the next request in the queue without removing it.""" if not self._heap: raise IndexError("peek from empty heap") - _, _, request = self._heap[0] - return request + return self._heap[0] def prepend_request(self, request: Request) -> None: """Add a request to the queue according to priority policy. @@ -180,15 +179,13 @@ class PriorityRequestQueue(RequestQueue): def remove_request(self, request: Request) -> None: """Remove a specific request from the queue.""" - self._heap = [(p, t, r) for p, t, r in self._heap if r != request] + self._heap.remove(request) heapq.heapify(self._heap) def remove_requests(self, requests: Iterable[Request]) -> None: """Remove multiple specific requests from the queue.""" - requests_to_remove = set(requests) - self._heap = [ - (p, t, r) for p, t, r in self._heap if r not in requests_to_remove - ] + requests_to_remove = requests if isinstance(requests, set) else set(requests) + self._heap = [r for r in self._heap if r not in requests_to_remove] heapq.heapify(self._heap) def __bool__(self) -> bool: @@ -203,8 +200,7 @@ class PriorityRequestQueue(RequestQueue): """Iterate over the queue according to priority policy.""" heap_copy = self._heap[:] while heap_copy: - _, _, request = heapq.heappop(heap_copy) - yield request + yield heapq.heappop(heap_copy) def __reversed__(self) -> Iterator[Request]: """Iterate over the queue in reverse priority order.""" diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 4e38b991326d3..754e0b9d08316 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -7,6 +7,7 @@ from collections.abc import Iterable from typing import Any from vllm import envs +from vllm.compilation.cuda_graph import CUDAGraphStat from vllm.config import VllmConfig from vllm.distributed.ec_transfer.ec_connector.base import ( ECConnectorMetadata, @@ -26,9 +27,11 @@ from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.encoder_cache_manager import ( EncoderCacheManager, + EncoderDecoderCacheManager, compute_encoder_budget, ) from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager +from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import ( CachedRequestData, @@ -40,7 +43,10 @@ from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_qu from vllm.v1.core.sched.utils import check_stop, remove_all from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs from vllm.v1.kv_cache_interface import KVCacheConfig -from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats +from vllm.v1.metrics.stats import ( + PrefixCacheStats, + SchedulerStats, +) from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.spec_decode.metrics import SpecDecodingStats @@ -69,6 +75,12 @@ class Scheduler(SchedulerInterface): self.kv_events_config = vllm_config.kv_events_config self.parallel_config = vllm_config.parallel_config self.log_stats = log_stats + self.observability_config = vllm_config.observability_config + self.kv_metrics_collector: KVCacheMetricsCollector | None = None + if self.observability_config.kv_cache_metrics: + self.kv_metrics_collector = KVCacheMetricsCollector( + self.observability_config.kv_cache_metrics_sample, + ) self.structured_output_manager = structured_output_manager self.is_encoder_decoder = vllm_config.model_config.is_encoder_decoder @@ -95,6 +107,7 @@ class Scheduler(SchedulerInterface): # KV Connector pushes/pull of remote KVs for P/D and offloading. self.connector = None self.connector_prefix_cache_stats: PrefixCacheStats | None = None + self.recompute_kv_load_failures = True if self.vllm_config.kv_transfer_config is not None: assert not self.is_encoder_decoder, ( "Encoder-decoder models are not currently supported with KV connectors" @@ -106,6 +119,10 @@ class Scheduler(SchedulerInterface): ) if self.log_stats: self.connector_prefix_cache_stats = PrefixCacheStats() + kv_load_failure_policy = ( + self.vllm_config.kv_transfer_config.kv_load_failure_policy + ) + self.recompute_kv_load_failures = kv_load_failure_policy == "recompute" self.kv_event_publisher = EventPublisherFactory.create( self.kv_events_config, @@ -165,7 +182,17 @@ class Scheduler(SchedulerInterface): # NOTE: For the models without encoder (e.g., text-only models), # the encoder cache will not be initialized because cache size is 0 # for these models. - self.encoder_cache_manager = EncoderCacheManager(cache_size=encoder_cache_size) + self.encoder_cache_manager = ( + EncoderDecoderCacheManager(cache_size=encoder_cache_size) + if self.is_encoder_decoder + else EncoderCacheManager(cache_size=encoder_cache_size) + ) + # For encoder-decoder models, allocate the maximum number of tokens for Cross + # Attn blocks, as for Whisper its input is always padded to the maximum length. + # TODO (NickLucche): Generalize to models with variable-length encoder inputs. + self._num_encoder_max_input_tokens = ( + MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(vllm_config.model_config) + ) speculative_config = vllm_config.speculative_config self.use_eagle = False @@ -187,6 +214,7 @@ class Scheduler(SchedulerInterface): dcp_world_size=self.dcp_world_size, pcp_world_size=self.pcp_world_size, hash_block_size=self.block_size, + metrics_collector=self.kv_metrics_collector, ) self.use_pp = self.parallel_config.pipeline_parallel_size > 1 self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER @@ -225,6 +253,22 @@ class Scheduler(SchedulerInterface): while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] + if ( + request.num_output_placeholders > 0 + # This is (num_computed_tokens + 1) - (num_output_placeholders - 1). + # Since output placeholders are also included in the computed tokens + # count, we subtract (num_output_placeholders - 1) to remove any draft + # tokens, so that we can be sure no further steps are needed even if + # they are all rejected. + and request.num_computed_tokens + 2 - request.num_output_placeholders + >= request.num_prompt_tokens + request.max_tokens + ): + # Async scheduling: Avoid scheduling an extra step when we are sure that + # the previous step has reached request.max_tokens. We don't schedule + # partial draft tokens since this prevents uniform decode optimizations. + req_index += 1 + continue + num_new_tokens = ( request.num_tokens_with_spec + request.num_output_placeholders @@ -234,18 +278,10 @@ class Scheduler(SchedulerInterface): num_new_tokens = self.scheduler_config.long_prefill_token_threshold num_new_tokens = min(num_new_tokens, token_budget) - num_spec_placeholders = max(0, request.num_output_placeholders - 1) - max_total_tokens = min( - # Avoid scheduling tokens that we're sure won't will be needed based on - # request.max_tokens. For this calculation we assume placeholder - # speculated output tokens are rejected. - request.num_prompt_tokens + request.max_tokens + num_spec_placeholders, - # Make sure the input position does not exceed the max model len. - # This is necessary when using spec decoding. - self.max_model_len, - ) + # Make sure the input position does not exceed the max model len. + # This is necessary when using spec decoding. num_new_tokens = min( - num_new_tokens, max_total_tokens - 1 - request.num_computed_tokens + num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens ) # Schedule encoder inputs. @@ -328,17 +364,7 @@ class Scheduler(SchedulerInterface): else: preempted_req = self.running.pop() - self.kv_cache_manager.free(preempted_req) - self.encoder_cache_manager.free(preempted_req) - preempted_req.status = RequestStatus.PREEMPTED - preempted_req.num_computed_tokens = 0 - preempted_req.num_preemptions += 1 - if self.log_stats: - preempted_req.record_event( - EngineCoreEventType.PREEMPTED, scheduled_timestamp - ) - - self.waiting.prepend_request(preempted_req) + self._preempt_request(preempted_req, scheduled_timestamp) preempted_reqs.append(preempted_req) if preempted_req == request: # No more request to preempt. Cannot schedule this request. @@ -548,17 +574,11 @@ class Scheduler(SchedulerInterface): 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens ) - # Determine if we need to allocate cross-attention blocks. - if self.is_encoder_decoder and request.has_encoder_inputs: - # TODO(russellb): For Whisper, we know that the input is - # always padded to the maximum length. If we support other - # encoder-decoder models, this will need to be updated if we - # want to only allocate what is needed. - num_encoder_tokens = ( - self.scheduler_config.max_num_encoder_input_tokens - ) - else: - num_encoder_tokens = 0 + num_encoder_tokens = ( + self._num_encoder_max_input_tokens + if self.is_encoder_decoder and request.has_encoder_inputs + else 0 + ) new_blocks = self.kv_cache_manager.allocate_slots( request, @@ -597,7 +617,6 @@ class Scheduler(SchedulerInterface): self._update_connector_prefix_cache_stats(request) - req_index += 1 self.running.append(request) if self.log_stats: request.record_event( @@ -737,6 +756,30 @@ class Scheduler(SchedulerInterface): self._update_after_schedule(scheduler_output) return scheduler_output + def _preempt_request( + self, + request: Request, + timestamp: float, + ) -> None: + """Preempt a request and put it back to the waiting queue. + + NOTE: The request should be popped from the running queue outside of this + method. + """ + assert request.status == RequestStatus.RUNNING, ( + "Only running requests can be preempted" + ) + self.kv_cache_manager.free(request) + self.encoder_cache_manager.free(request) + request.status = RequestStatus.PREEMPTED + request.num_computed_tokens = 0 + request.num_preemptions += 1 + if self.log_stats: + request.record_event(EngineCoreEventType.PREEMPTED, timestamp) + + # Put the request back to the waiting queue. + self.waiting.prepend_request(request) + def _update_after_schedule( self, scheduler_output: SchedulerOutput, @@ -788,15 +831,15 @@ class Scheduler(SchedulerInterface): for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)): req_id = req.request_id req_ids.append(req_id) - num_tokens = num_scheduled_tokens[req_id] - len( - spec_decode_tokens.get(req_id, ()) - ) if self.use_pp: # When using PP, the scheduler sends the sampled tokens back, # because there's no direct communication between the first- # stage worker and the last-stage worker. Otherwise, we don't # need to send the sampled tokens back because the model runner # will cache them. + num_tokens = num_scheduled_tokens[req_id] - len( + spec_decode_tokens.get(req_id, ()) + ) token_ids = req.all_token_ids[ req.num_computed_tokens : req.num_computed_tokens + num_tokens ] @@ -1004,6 +1047,7 @@ class Scheduler(SchedulerInterface): pooler_outputs = model_runner_output.pooler_output num_nans_in_logits = model_runner_output.num_nans_in_logits kv_connector_output = model_runner_output.kv_connector_output + cudagraph_stats = model_runner_output.cudagraph_stats outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: SpecDecodingStats | None = None @@ -1032,7 +1076,7 @@ class Scheduler(SchedulerInterface): for req_id, num_tokens_scheduled in num_scheduled_tokens.items(): assert num_tokens_scheduled > 0 if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids: - # Skip requests that were recovered from KV load failure + # skip failed or rescheduled requests from KV load failure continue request = self.requests.get(req_id) if request is None: @@ -1073,6 +1117,7 @@ class Scheduler(SchedulerInterface): stopped = False new_logprobs = None new_token_ids = generated_token_ids + pooler_output = pooler_outputs[req_index] if pooler_outputs else None kv_transfer_params = None status_before_stop = request.status @@ -1081,12 +1126,10 @@ class Scheduler(SchedulerInterface): new_token_ids, stopped = self._update_request_with_output( request, new_token_ids ) - - # Stop checking for pooler models. - pooler_output = None - if pooler_outputs: - pooler_output = pooler_outputs[req_index] - stopped = check_stop(request, self.max_model_len, pooler_output) + elif request.pooling_params and pooler_output is not None: + # Pooling stops as soon as there is output. + request.status = RequestStatus.FINISHED_STOPPED + stopped = True if stopped: kv_transfer_params = self._free_request(request) @@ -1143,6 +1186,21 @@ class Scheduler(SchedulerInterface): # This is a rare case and unlikely to impact performance. self.waiting.remove_requests(stopped_preempted_reqs) + if failed_kv_load_req_ids and not self.recompute_kv_load_failures: + requests = [self.requests[req_id] for req_id in failed_kv_load_req_ids] + self.finish_requests(failed_kv_load_req_ids, RequestStatus.FINISHED_ERROR) + for request in requests: + outputs[request.client_index].append( + EngineCoreOutput( + request_id=request.request_id, + new_token_ids=[], + finish_reason=request.get_finished_reason(), + events=request.take_events(), + trace_headers=request.trace_headers, + num_cached_tokens=request.num_cached_tokens, + ) + ) + # KV Connector: update state for finished KV Transfers. if kv_connector_output: self._update_from_kv_xfer_finished(kv_connector_output) @@ -1186,7 +1244,9 @@ class Scheduler(SchedulerInterface): finished_req_ids.clear() if ( - stats := self.make_stats(spec_decoding_stats, kv_connector_stats) + stats := self.make_stats( + spec_decoding_stats, kv_connector_stats, cudagraph_stats + ) ) is not None: # Return stats to only one of the front-ends. if (eco := next(iter(engine_core_outputs.values()), None)) is None: @@ -1343,27 +1403,96 @@ class Scheduler(SchedulerInterface): def has_finished_requests(self) -> bool: return len(self.finished_req_ids) > 0 - def reset_prefix_cache(self) -> bool: - return self.kv_cache_manager.reset_prefix_cache() + def reset_prefix_cache( + self, reset_running_requests: bool = False, reset_connector: bool = False + ) -> bool: + """Reset the KV prefix cache. + + If reset_running_requests is True, all the running requests will be + preempted and moved to the waiting queue. + Otherwise, this method will only reset the KV prefix cache when there + is no running requests taking KV cache. + """ + if reset_running_requests: + # For logging. + timestamp = time.monotonic() + # Invalidate all the current running requests KV's by pushing them to + # the waiting queue. In this case, we can reduce the ref count of all + # the kv blocks to 0 and thus we can make sure the reset is successful. + # Preempt in reverse order so the requests will be added back to the + # running queue in FIFO order. + while self.running: + request = self.running.pop() + self._preempt_request(request, timestamp) + # NOTE(zhuohan): For async scheduling, we need to discard the latest + # output token on the fly to avoid a redundant repetitive output token. + request.num_output_placeholders = 0 + request.discard_latest_async_tokens = True + + # Clear scheduled request ids cache. Since we are forcing preemption + # + resumption in the same step, we must act as if these requests were + # not scheduled in the prior step. They will be flushed from the + # persistent batch in the model runner. + self.prev_step_scheduled_req_ids.clear() + + reset_successful = self.kv_cache_manager.reset_prefix_cache() + if reset_running_requests and not reset_successful: + raise RuntimeError( + "Failed to reset KV cache even when all the running requests are " + "preempted and moved to the waiting queue. This is likely due to " + "the presence of running requests waiting for remote KV transfer, " + "which is not supported yet." + ) + + if reset_connector: + reset_successful = self.reset_connector_cache() and reset_successful + + return reset_successful + + def reset_connector_cache(self) -> bool: + if self.connector is None: + logger.warning("reset_connector called but no KV connector is configured.") + return False + + if self.connector.reset_cache() is False: + return False + + if self.log_stats: + assert self.connector_prefix_cache_stats is not None + self.connector_prefix_cache_stats.reset = True + + return True def make_stats( self, spec_decoding_stats: SpecDecodingStats | None = None, kv_connector_stats: KVConnectorStats | None = None, + cudagraph_stats: CUDAGraphStat | None = None, ) -> SchedulerStats | None: if not self.log_stats: return None prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats() assert prefix_cache_stats is not None connector_prefix_cache_stats = self._make_connector_prefix_cache_stats() + eviction_events = ( + self.kv_metrics_collector.drain_events() + if self.kv_metrics_collector is not None + else [] + ) + spec_stats = spec_decoding_stats + connector_stats_payload = ( + kv_connector_stats.data if kv_connector_stats else None + ) return SchedulerStats( num_running_reqs=len(self.running), num_waiting_reqs=len(self.waiting), kv_cache_usage=self.kv_cache_manager.usage, prefix_cache_stats=prefix_cache_stats, connector_prefix_cache_stats=connector_prefix_cache_stats, - spec_decoding_stats=spec_decoding_stats, - kv_connector_stats=kv_connector_stats.data if kv_connector_stats else None, + kv_cache_eviction_events=eviction_events, + spec_decoding_stats=spec_stats, + kv_connector_stats=connector_stats_payload, + cudagraph_stats=cudagraph_stats, ) def make_spec_decoding_stats( @@ -1505,8 +1634,11 @@ class Scheduler(SchedulerInterface): self._free_blocks(self.requests[req_id]) def _update_requests_with_invalid_blocks( - self, requests: Iterable[Request], invalid_block_ids: set[int] - ) -> tuple[set[str], int]: + self, + requests: Iterable[Request], + invalid_block_ids: set[int], + evict_blocks: bool = True, + ) -> tuple[set[str], int, set[int]]: """ Identify and update requests affected by invalid KV cache blocks. @@ -1518,16 +1650,21 @@ class Scheduler(SchedulerInterface): Args: requests: The set of requests to scan for invalid blocks. invalid_block_ids: IDs of invalid blocks. + evict_blocks: Whether to collect blocks for eviction (False for + async requests which aren't cached yet). Returns: tuple: - affected_req_ids (set[str]): IDs of requests impacted by invalid blocks. - total_affected_tokens (int): Total number of tokens that must - be recomputed across all affected requests (for observability). + be recomputed across all affected requests. + - blocks_to_evict (set[int]): Block IDs to evict from cache, + including invalid blocks and downstream dependent blocks. """ affected_req_ids: set[str] = set() total_affected_tokens = 0 + blocks_to_evict: set[int] = set() # If a block is invalid and shared by multiple requests in the batch, # these requests must be rescheduled, but only the first will recompute # it. This set tracks blocks already marked for recomputation. @@ -1585,6 +1722,9 @@ class Scheduler(SchedulerInterface): ) total_affected_tokens += num_affected_tokens request.num_external_computed_tokens -= num_affected_tokens + # collect invalid block and all downstream dependent blocks + if evict_blocks: + blocks_to_evict.update(req_block_ids[idx:]) if is_affected: if not marked_invalid_block: @@ -1600,47 +1740,70 @@ class Scheduler(SchedulerInterface): affected_req_ids.add(request.request_id) - return affected_req_ids, total_affected_tokens + return affected_req_ids, total_affected_tokens, blocks_to_evict def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]: - total_requests_to_reschedule = 0 - total_tokens_to_reschedule = 0 + """ + Handle requests affected by invalid KV cache blocks. - # --- Handle async KV loads (WAITING_FOR_REMOTE_KVS) --- + Returns: + Set of affected request IDs to skip in update_from_output main loop. + """ + should_fail = not self.recompute_kv_load_failures + + # handle async KV loads (not cached yet, evict_blocks=False) async_load_reqs = ( req for req in self.waiting if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS ) - async_affected_req_ids, num_tokens_to_reschedule = ( + async_failed_req_ids, num_failed_tokens, _ = ( self._update_requests_with_invalid_blocks( - async_load_reqs, invalid_block_ids + async_load_reqs, invalid_block_ids, evict_blocks=False ) ) - total_requests_to_reschedule += len(async_affected_req_ids) - total_tokens_to_reschedule += num_tokens_to_reschedule + total_failed_requests = len(async_failed_req_ids) + total_failed_tokens = num_failed_tokens - # Mark requests with async KV load failures; they will be rescheduled - # once loading completes. - self.failed_recving_kv_req_ids |= async_affected_req_ids - - # --- Handle sync KV loads (running requests) --- - sync_affected_req_ids, num_tokens_to_reschedule = ( - self._update_requests_with_invalid_blocks(self.running, invalid_block_ids) + # handle sync loads (may be cached, collect blocks for eviction) + sync_failed_req_ids, num_failed_tokens, sync_blocks_to_evict = ( + self._update_requests_with_invalid_blocks( + self.running, invalid_block_ids, evict_blocks=True + ) ) - total_requests_to_reschedule += len(sync_affected_req_ids) - total_tokens_to_reschedule += num_tokens_to_reschedule + total_failed_requests += len(sync_failed_req_ids) + total_failed_tokens += num_failed_tokens - if total_requests_to_reschedule: - logger.warning( - "Recovered from KV load failure: " - "%d request(s) rescheduled (%d tokens affected).", - total_requests_to_reschedule, - total_tokens_to_reschedule, + if not total_failed_requests: + return set() + + # evict invalid blocks and downstream dependent blocks from cache + # only when not using recompute policy (where blocks will be recomputed + # and reused by other requests sharing them) + if sync_blocks_to_evict and not self.recompute_kv_load_failures: + self.kv_cache_manager.evict_blocks(sync_blocks_to_evict) + + if should_fail: + all_failed_req_ids = async_failed_req_ids | sync_failed_req_ids + logger.error( + "Failing %d request(s) due to KV load failure " + "(failure_policy=fail, %d tokens affected). Request IDs: %s", + total_failed_requests, + total_failed_tokens, + all_failed_req_ids, ) + return all_failed_req_ids - # Return the IDs of affected running requests to skip in - # update_from_output. - return sync_affected_req_ids + logger.warning( + "Recovered from KV load failure: " + "%d request(s) rescheduled (%d tokens affected).", + total_failed_requests, + total_failed_tokens, + ) + + # Mark async requests with KV load failures for retry once loading completes + self.failed_recving_kv_req_ids |= async_failed_req_ids + # Return sync affected IDs to skip in update_from_output + return sync_failed_req_ids diff --git a/vllm/v1/core/sched/utils.py b/vllm/v1/core/sched/utils.py index 82166dc978396..6319731883225 100644 --- a/vllm/v1/core/sched/utils.py +++ b/vllm/v1/core/sched/utils.py @@ -2,8 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib -import torch - from vllm.v1.request import Request, RequestStatus @@ -39,14 +37,8 @@ def remove_all(lst: list, items_to_remove: set) -> list: return [item for item in lst if item not in items_to_remove] -def check_stop( - request: Request, max_model_len: int, pooler_output: torch.Tensor | None = None -) -> bool: - if request.pooling_params: - if pooler_output is not None: - request.status = RequestStatus.FINISHED_STOPPED - return True - return False +def check_stop(request: Request, max_model_len: int) -> bool: + assert not request.pooling_params sampling_params = request.sampling_params assert sampling_params is not None diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index ef0f8d9e67452..8a3500c0aac6b 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -145,7 +145,7 @@ class CudagraphDispatcher: num_tokens: int, uniform_decode: bool, has_lora: bool, - use_cascade_attn: bool = False, + disable_full: bool = False, ) -> tuple[CUDAGraphMode, BatchDescriptor]: """ Given conditions(e.g.,batch descriptor and if using cascade attention), @@ -165,7 +165,7 @@ class CudagraphDispatcher: ) relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs() - if not use_cascade_attn: + if not disable_full: # check if key exists for full cudagraph if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]: return CUDAGraphMode.FULL, batch_desc diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index ce2aae77108da..4f54d12f4b8d0 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -19,24 +19,27 @@ from vllm.v1.serial_utils import UtilityResult # These are possible values of RequestOutput.finish_reason, # so form part of the external API. -FINISH_REASON_STRINGS = ("stop", "length", "abort") +FINISH_REASON_STRINGS = ("stop", "length", "abort", "error") class FinishReason(enum.IntEnum): """ - Reason a request finished - stop, length, or abort. + Reason a request finished - stop, length, abort, or error. Int rather than Str for more compact serialization. stop - a stop string was emitted length - max_tokens was consumed, or max_model_len was reached - abort - aborted for another reason + abort - aborted by client + error - retryable request-level internal error (e.g., KV load failure). + Invariant: always converted to 500 Internal Server Error. """ STOP = 0 LENGTH = 1 ABORT = 2 + ERROR = 3 def __str__(self): return FINISH_REASON_STRINGS[self.value] diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index d0708a8a046d1..a6ee241c41151 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -26,10 +26,9 @@ from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.tasks import SupportedTask -from vllm.tokenizers import TokenizerLike +from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config from vllm.tracing import init_tracer from vllm.transformers_utils.config import maybe_register_config_serialize_by_value -from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext from vllm.utils.async_utils import cancel_task_threadsafe from vllm.utils.collection_utils import as_list @@ -112,7 +111,7 @@ class AsyncLLM(EngineClient): if self.model_config.skip_tokenizer_init: tokenizer = None else: - tokenizer = init_tokenizer_from_configs(self.model_config) + tokenizer = cached_tokenizer_from_config(self.model_config) self.input_processor = InputProcessor(self.vllm_config, tokenizer) self.io_processor = get_io_processor( @@ -167,32 +166,24 @@ class AsyncLLM(EngineClient): pass if ( - envs.VLLM_TORCH_PROFILER_DIR - and not envs.VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM + vllm_config.profiler_config.profiler == "torch" + and not vllm_config.profiler_config.ignore_frontend ): + profiler_dir = vllm_config.profiler_config.torch_profiler_dir logger.info( "Torch profiler enabled. AsyncLLM CPU traces will be collected under %s", # noqa: E501 - envs.VLLM_TORCH_PROFILER_DIR, + profiler_dir, ) - if envs.VLLM_PROFILER_MAX_ITERS > 0 or envs.VLLM_PROFILER_DELAY_ITERS > 0: - logger.warning_once( - "Torch profiler received max_iters or delay_iters setting. These " - "are not compatible with the AsyncLLM profiler and will be ignored " - "for the AsyncLLM process. Engine process profiling will still " - "respect these settings. Consider setting " - "VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM=1 to disable " - "AsyncLLM profiling." - ) worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm" self.profiler = torch.profiler.profile( activities=[ torch.profiler.ProfilerActivity.CPU, ], - with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, + with_stack=vllm_config.profiler_config.torch_profiler_with_stack, on_trace_ready=torch.profiler.tensorboard_trace_handler( - envs.VLLM_TORCH_PROFILER_DIR, + profiler_dir, worker_name=worker_name, - use_gzip=envs.VLLM_TORCH_PROFILER_USE_GZIP, + use_gzip=vllm_config.profiler_config.torch_profiler_use_gzip, ), ) else: @@ -201,7 +192,7 @@ class AsyncLLM(EngineClient): @property @deprecated( "`AsyncLLM.processor` has been renamed to `AsyncLLM.input_processor`. " - "The old name will be removed in v0.13." + "The old name will be removed in v0.14." ) def processor(self): return self.input_processor @@ -710,10 +701,6 @@ class AsyncLLM(EngineClient): def tokenizer(self) -> TokenizerLike | None: return self.input_processor.tokenizer - @tokenizer.setter - def tokenizer(self, tokenizer: TokenizerLike | None) -> None: - self.input_processor.tokenizer = tokenizer - async def get_tokenizer(self) -> TokenizerLike: if self.tokenizer is None: raise ValueError( @@ -750,8 +737,12 @@ class AsyncLLM(EngineClient): self.input_processor.clear_mm_cache() await self.engine_core.reset_mm_cache_async() - async def reset_prefix_cache(self) -> None: - await self.engine_core.reset_prefix_cache_async() + async def reset_prefix_cache( + self, reset_running_requests: bool = False, reset_connector: bool = False + ) -> bool: + return await self.engine_core.reset_prefix_cache_async( + reset_running_requests, reset_connector + ) async def sleep(self, level: int = 1) -> None: await self.reset_prefix_cache() diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index e3a5f51a8fc56..0045b8c1dd3e7 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -204,11 +204,16 @@ class EngineCore: ) self.async_scheduling = vllm_config.scheduler_config.async_scheduling + self.aborts_queue = queue.Queue[list[str]]() + # Mark the startup heap as static so that it's ignored by GC. # Reduces pause times of oldest generation collections. freeze_gc_heap() # If enable, attach GC debugger after static variable freeze. maybe_attach_gc_debug_callback() + # Enable environment variable cache (e.g. assume no more + # environment variable overrides after this point) + enable_envs_cache() def _initialize_kv_caches( self, vllm_config: VllmConfig @@ -347,6 +352,9 @@ class EngineCore: if model_output is None: model_output = self.model_executor.sample_tokens(grammar_output) + # Before processing the model output, process any aborts that happened + # during the model execution. + self._process_aborts_queue() engine_core_outputs = self.scheduler.update_from_output( scheduler_output, model_output ) @@ -440,6 +448,9 @@ class EngineCore: with self.log_error_detail(scheduler_output): model_output = future.result() + # Before processing the model output, process any aborts that happened + # during the model execution. + self._process_aborts_queue() engine_core_outputs = self.scheduler.update_from_output( scheduler_output, model_output ) @@ -458,6 +469,18 @@ class EngineCore: return engine_core_outputs, model_executed + def _process_aborts_queue(self): + if not self.aborts_queue.empty(): + request_ids = [] + while not self.aborts_queue.empty(): + ids = self.aborts_queue.get_nowait() + if isinstance(ids, str): + # Should be a list here, but also handle string just in case. + ids = (ids,) + request_ids.extend(ids) + # More efficient to abort all as a single batch. + self.abort_requests(request_ids) + def shutdown(self): self.structured_output_manager.clear_backend() if self.model_executor: @@ -483,8 +506,12 @@ class EngineCore: self.model_executor.reset_mm_cache() - def reset_prefix_cache(self): - self.scheduler.reset_prefix_cache() + def reset_prefix_cache( + self, reset_running_requests: bool = False, reset_connector: bool = False + ) -> bool: + return self.scheduler.reset_prefix_cache( + reset_running_requests, reset_connector + ) def sleep(self, level: int = 1): self.model_executor.sleep(level) @@ -648,10 +675,6 @@ class EngineCoreProc(EngineCore): assert addresses.coordinator_input is not None logger.info("Waiting for READY message from DP Coordinator...") - # Enable environment variable cache (e.g. assume no more - # environment variable overrides after this point) - enable_envs_cache() - @contextmanager def _perform_handshakes( self, @@ -871,9 +894,13 @@ class EngineCoreProc(EngineCore): and not self.scheduler.has_requests() and not self.batch_queue ): - if logger.isEnabledFor(DEBUG) and self.input_queue.empty(): - logger.debug("EngineCore waiting for work.") - waited = True + if self.input_queue.empty(): + # Drain aborts queue; all aborts are also processed via input_queue. + with self.aborts_queue.mutex: + self.aborts_queue.queue.clear() + if logger.isEnabledFor(DEBUG): + logger.debug("EngineCore waiting for work.") + waited = True req = self.input_queue.get() self._handle_client_request(*req) @@ -1027,6 +1054,13 @@ class EngineCoreProc(EngineCore): else: request = generic_decoder.decode(data_frames) + if request_type == EngineCoreRequestType.ABORT: + # Aborts are added to *both* queues, allows us to eagerly + # process aborts while also ensuring ordering in the input + # queue to avoid leaking requests. This is ok because + # aborting in the scheduler is idempotent. + self.aborts_queue.put_nowait(request) + # Push to input queue for core busy loop. self.input_queue.put_nowait((request_type, request)) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 9b440505bd9dc..c936646aa7993 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -138,7 +138,9 @@ class EngineCoreClient(ABC): def reset_mm_cache(self) -> None: raise NotImplementedError - def reset_prefix_cache(self) -> None: + def reset_prefix_cache( + self, reset_running_requests: bool = False, reset_connector: bool = False + ) -> bool: raise NotImplementedError def sleep(self, level: int = 1) -> None: @@ -208,7 +210,9 @@ class EngineCoreClient(ABC): async def reset_mm_cache_async(self) -> None: raise NotImplementedError - async def reset_prefix_cache_async(self) -> None: + async def reset_prefix_cache_async( + self, reset_running_requests: bool = False, reset_connector: bool = False + ) -> bool: raise NotImplementedError async def sleep_async(self, level: int = 1) -> None: @@ -287,8 +291,12 @@ class InprocClient(EngineCoreClient): def reset_mm_cache(self) -> None: self.engine_core.reset_mm_cache() - def reset_prefix_cache(self) -> None: - self.engine_core.reset_prefix_cache() + def reset_prefix_cache( + self, reset_running_requests: bool = False, reset_connector: bool = False + ) -> bool: + return self.engine_core.reset_prefix_cache( + reset_running_requests, reset_connector + ) def sleep(self, level: int = 1) -> None: self.engine_core.sleep(level) @@ -751,8 +759,12 @@ class SyncMPClient(MPClient): def reset_mm_cache(self) -> None: self.call_utility("reset_mm_cache") - def reset_prefix_cache(self) -> None: - self.call_utility("reset_prefix_cache") + def reset_prefix_cache( + self, reset_running_requests: bool = False, reset_connector: bool = False + ) -> bool: + return self.call_utility( + "reset_prefix_cache", reset_running_requests, reset_connector + ) def add_lora(self, lora_request: LoRARequest) -> bool: return self.call_utility("add_lora", lora_request) @@ -955,8 +967,12 @@ class AsyncMPClient(MPClient): async def reset_mm_cache_async(self) -> None: await self.call_utility_async("reset_mm_cache") - async def reset_prefix_cache_async(self) -> None: - await self.call_utility_async("reset_prefix_cache") + async def reset_prefix_cache_async( + self, reset_running_requests: bool = False, reset_connector: bool = False + ) -> bool: + return await self.call_utility_async( + "reset_prefix_cache", reset_running_requests, reset_connector + ) async def sleep_async(self, level: int = 1) -> None: await self.call_utility_async("sleep", level) diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py index e6a94f4e3de5d..65e0c845b0afa 100644 --- a/vllm/v1/engine/input_processor.py +++ b/vllm/v1/engine/input_processor.py @@ -19,7 +19,8 @@ from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.multimodal.utils import argsort_mm_positions from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -from vllm.tokenizers import MistralTokenizer, TokenizerLike +from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.v1.engine import EngineCoreRequest from vllm.v1.metrics.stats import MultiModalCacheStats @@ -64,10 +65,6 @@ class InputProcessor: def tokenizer(self) -> TokenizerLike | None: return self.input_preprocessor.tokenizer - @tokenizer.setter - def tokenizer(self, tokenizer: TokenizerLike | None) -> None: - self.input_preprocessor.tokenizer = tokenizer - def _validate_logprobs( self, params: SamplingParams, @@ -192,29 +189,39 @@ class InputProcessor: def _validate_single_prompt(single_prompt: dict | str) -> None: if not isinstance(single_prompt, dict): return + mm_data = single_prompt.get("multi_modal_data") mm_uuids = single_prompt.get("multi_modal_uuids") if not mm_data or not mm_uuids: return + import torch + + def _get_len(items: object): + if isinstance(items, dict): # Embedding inputs + return _get_len(next(iter(items.values()))) if items else 1 + + if isinstance(items, list): + return len(items) + if isinstance(items, torch.Tensor): + # To keep backwards compatibility for single item embedding input + return 1 if getattr(items, "_is_single_item", False) else len(items) + + return 1 + for modality, items in mm_data.items(): if modality in mm_uuids: - data_len = len(items) if isinstance(items, list) else 1 - uuid_len = ( - len(mm_uuids[modality]) - if isinstance(mm_uuids[modality], list) - else 1 - ) + data_len = _get_len(items) + uuid_len = _get_len(mm_uuids[modality]) if uuid_len != data_len: raise ValueError( - f"multi_modal_uuids for modality '{modality}' " + f"multi_modal_uuids for modality {modality!r} " "must have same length as data: got " - f"{uuid_len} uuids vs " - f"{data_len} items." + f"{uuid_len} uuids vs {data_len} items." ) else: raise ValueError( - f"multi_modal_uuids for modality '{modality}' must " + f"multi_modal_uuids for modality {modality!r} must " "be provided if multi_modal_data is provided." ) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index a3bde7ba8d64d..1011317b706d3 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -23,9 +23,8 @@ from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.tasks import SupportedTask -from vllm.tokenizers import TokenizerLike +from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config from vllm.tracing import init_tracer -from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core_client import EngineCoreClient @@ -87,7 +86,7 @@ class LLMEngine: if self.model_config.skip_tokenizer_init: tokenizer = None else: - tokenizer = init_tokenizer_from_configs(self.model_config) + tokenizer = cached_tokenizer_from_config(self.model_config) self.input_processor = InputProcessor(self.vllm_config, tokenizer) self.io_processor = get_io_processor( @@ -140,7 +139,7 @@ class LLMEngine: @property @deprecated( "`LLMEngine.processor` has been renamed to `LLMEngine.input_processor`. " - "The old name will be removed in v0.13." + "The old name will be removed in v0.14." ) def processor(self): return self.input_processor @@ -329,8 +328,12 @@ class LLMEngine: self.input_processor.clear_mm_cache() self.engine_core.reset_mm_cache() - def reset_prefix_cache(self): - self.engine_core.reset_prefix_cache() + def reset_prefix_cache( + self, reset_running_requests: bool = False, reset_connector: bool = False + ) -> bool: + return self.engine_core.reset_prefix_cache( + reset_running_requests, reset_connector + ) def sleep(self, level: int = 1): self.engine_core.sleep(level) @@ -355,10 +358,6 @@ class LLMEngine: def tokenizer(self) -> TokenizerLike | None: return self.input_processor.tokenizer - @tokenizer.setter - def tokenizer(self, tokenizer: TokenizerLike | None) -> None: - self.input_processor.tokenizer = tokenizer - def get_tokenizer(self) -> TokenizerLike: if self.tokenizer is None: raise ValueError( @@ -410,8 +409,6 @@ class LLMEngine: return self.collective_rpc("apply_model", args=(func,)) def __del__(self): - if ( - dp_group := getattr(self, "dp_group", None) - and not self.external_launcher_dp - ): + dp_group = getattr(self, "dp_group", None) + if dp_group is not None and not self.external_launcher_dp: stateless_destroy_torch_distributed_process_group(dp_group) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index e85fbb4ee0fb0..8f7d8a71f1a2e 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -8,6 +8,7 @@ from typing import Any, cast import torch +from vllm.lora.request import LoRARequest from vllm.outputs import ( CompletionOutput, PoolingOutput, @@ -93,7 +94,7 @@ class RequestState: request_id: str, parent_req: ParentRequest | None, request_index: int, - lora_name: str | None, + lora_request: LoRARequest | None, output_kind: RequestOutputKind, prompt: str | None, prompt_token_ids: list[int] | None, @@ -112,7 +113,8 @@ class RequestState: self.request_id = request_id self.parent_req = parent_req self.request_index = request_index - self.lora_name = lora_name + self.lora_request = lora_request + self.lora_name = lora_request.lora_name if lora_request is not None else None self.output_kind = output_kind self.prompt = prompt self.prompt_token_ids = prompt_token_ids @@ -178,9 +180,7 @@ class RequestState: request_id=request.request_id, parent_req=parent_req, request_index=request_index, - lora_name=( - request.lora_request.name if request.lora_request is not None else None - ), + lora_request=request.lora_request, output_kind=output_kind, prompt=prompt, prompt_token_ids=request.prompt_token_ids, @@ -289,6 +289,7 @@ class RequestState: return RequestOutput( request_id=request_id, + lora_request=self.lora_request, prompt=self.prompt, prompt_token_ids=prompt_token_ids, prompt_logprobs=prompt_logprobs, @@ -650,6 +651,7 @@ class OutputProcessor: ), max_tokens_param=req_state.max_tokens_param, req_stats=req_state.stats, + num_cached_tokens=req_state.num_cached_tokens, ) self.lora_states.request_finished(req_state.request_id, req_state.lora_name) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index bc5c7fc400fde..a8c93499299d3 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -10,7 +10,7 @@ def __getattr__(name: str): warnings.warn( "`vllm.v1.engine.processor.Processor` has been moved to " "`vllm.v1.engine.input_processor.InputProcessor`. " - "The old name will be removed in v0.13.", + "The old name will be removed in v0.14.", DeprecationWarning, stacklevel=2, ) diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index db8303fcec501..8ada52435edae 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -219,7 +219,7 @@ class Executor(ABC): def sample_tokens( self, grammar_output: GrammarOutput | None, non_block: bool = False - ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]: + ) -> ModelRunnerOutput | Future[ModelRunnerOutput]: output = self.collective_rpc( # type: ignore[call-overload] "sample_tokens", args=(grammar_output,), non_block=non_block ) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 7e8ebe25c4603..649875fe8b7c1 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -124,9 +124,7 @@ class MultiprocExecutor(Executor): # Set multiprocessing envs set_multiprocessing_worker_envs() - # Multiprocessing-based executor does not support multi-node setting. - # Since it only works for single node, we can use the loopback address - # get_loopback_ip() for communication. + # use the loopback address get_loopback_ip() for communication. distributed_init_method = get_distributed_init_method( get_loopback_ip(), get_open_port() ) @@ -294,8 +292,8 @@ class MultiprocExecutor(Executor): kwargs: dict | None = None, non_block: bool = False, unique_reply_rank: int | None = None, - kv_output_aggregator: KVOutputAggregator = None, - ) -> Any | list[Any] | Future[Any | list[Any]]: + kv_output_aggregator: KVOutputAggregator | None = None, + ) -> Any: """Returns single result if unique_reply_rank and/or kv_output_aggregator is provided, otherwise list.""" assert self.rpc_broadcast_mq is not None, ( @@ -476,6 +474,8 @@ class WorkerProc: """Wrapper that runs one Worker in a separate process.""" READY_STR = "READY" + rpc_broadcast_mq: MessageQueue | None + worker_response_mq: MessageQueue | None def _init_message_queues( self, input_shm_handle: Handle, vllm_config: VllmConfig @@ -487,7 +487,7 @@ class WorkerProc: ) # Initializes a message queue for sending the model output - self.worker_response_mq: MessageQueue = MessageQueue(1, 1) + self.worker_response_mq = MessageQueue(1, 1) self.peer_response_handles = [] else: # Initialize remote MessageQueue for receiving SchedulerOutput across nodes @@ -706,7 +706,7 @@ class WorkerProc: death_pipe.recv() except EOFError: # Parent process has exited, terminate this worker - logger.info("Parent process exited, terminating worker") + logger.info_once("Parent process exited, terminating worker") # Send signal to self to trigger clean shutdown shutdown_event.set() except Exception as e: @@ -720,6 +720,7 @@ class WorkerProc: try: reader.close() worker = WorkerProc(*args, **kwargs) + assert worker.worker_response_mq is not None # Send READY once we know everything is loaded ready_writer.send( @@ -804,6 +805,7 @@ class WorkerProc: def worker_busy_loop(self, cancel: threading.Event | None = None): """Main busy loop for Multiprocessing Workers""" + assert self.rpc_broadcast_mq is not None while True: method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue( cancel=cancel, indefinite=True diff --git a/vllm/v1/executor/ray_executor.py b/vllm/v1/executor/ray_executor.py index 406eafcd339b0..2fd64e5c2277c 100644 --- a/vllm/v1/executor/ray_executor.py +++ b/vllm/v1/executor/ray_executor.py @@ -413,7 +413,7 @@ class RayDistributedExecutor(Executor): self, grammar_output: "GrammarOutput | None", non_block: bool = False, - ) -> ModelRunnerOutput | Future[ModelRunnerOutput]: + ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]: """Execute the model on the Ray workers. The scheduler output to use should have been provided in @@ -428,7 +428,7 @@ class RayDistributedExecutor(Executor): """ scheduler_output = self.scheduler_output if scheduler_output is None: - return COMPLETED_NONE_FUTURE if non_block else None # noqa + return COMPLETED_NONE_FUTURE if non_block else None self.scheduler_output = None @@ -439,7 +439,7 @@ class RayDistributedExecutor(Executor): scheduler_output: SchedulerOutput, grammar_output: "GrammarOutput | None", non_block: bool = False, - ) -> ModelRunnerOutput | Future[ModelRunnerOutput]: + ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]: # Build the compiled DAG for the first time. if self.forward_dag is None: # type: ignore self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) diff --git a/vllm/v1/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py index 095d3d1dac21b..b8ca922554304 100644 --- a/vllm/v1/executor/uniproc_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -67,7 +67,7 @@ class UniProcExecutor(Executor): kwargs: dict | None = None, non_block: bool = False, single_value: bool = False, - ) -> Any | list[Any] | Future[Any | list[Any]]: + ) -> Any: if kwargs is None: kwargs = {} @@ -79,10 +79,13 @@ class UniProcExecutor(Executor): result = run_method(self.driver_worker, method, args, kwargs) if isinstance(result, AsyncModelRunnerOutput): if (async_thread := self.async_output_thread) is not None: - get_output = result.get_output - if not single_value: - get_output = lambda go=result.get_output: [go()] - return async_thread.submit(get_output) + if single_value: + return async_thread.submit(result.get_output) + + def get_output_list() -> list[Any]: + return [result.get_output()] + + return async_thread.submit(get_output_list) result = result.get_output() future = Future[Any]() future.set_result(result if single_value else [result]) diff --git a/vllm/v1/kv_offload/cpu.py b/vllm/v1/kv_offload/cpu.py index 2f2e85c0ff332..e1cf7b14a785c 100644 --- a/vllm/v1/kv_offload/cpu.py +++ b/vllm/v1/kv_offload/cpu.py @@ -13,7 +13,7 @@ from vllm.v1.kv_offload.backends.cpu import CPUBackend from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec from vllm.v1.kv_offload.spec import OffloadingSpec -from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandler +from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandlers from vllm.v1.kv_offload.worker.worker import OffloadingHandler @@ -32,7 +32,7 @@ class CPUOffloadingSpec(OffloadingSpec): self._manager: OffloadingManager | None = None # worker-side - self._handler: OffloadingHandler | None = None + self._handlers: CpuGpuOffloadingHandlers | None = None self.eviction_policy: str = self.extra_config.get("eviction_policy", "lru") @@ -67,13 +67,13 @@ class CPUOffloadingSpec(OffloadingSpec): kv_caches: dict[str, torch.Tensor], attn_backends: dict[str, type[AttentionBackend]], ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]: - if not self._handler: + if not self._handlers: if not current_platform.is_cuda_alike(): raise Exception( "CPU Offloading is currently only supported on CUDA-alike GPUs" ) - self._handler = CpuGpuOffloadingHandler( + self._handlers = CpuGpuOffloadingHandlers( attn_backends=attn_backends, gpu_block_size=self.gpu_block_size, cpu_block_size=self.offloaded_block_size, @@ -81,6 +81,6 @@ class CPUOffloadingSpec(OffloadingSpec): gpu_caches=kv_caches, ) - assert self._handler is not None - yield GPULoadStoreSpec, CPULoadStoreSpec, self._handler - yield CPULoadStoreSpec, GPULoadStoreSpec, self._handler + assert self._handlers is not None + yield GPULoadStoreSpec, CPULoadStoreSpec, self._handlers.gpu_to_cpu_handler + yield CPULoadStoreSpec, GPULoadStoreSpec, self._handlers.cpu_to_gpu_handler diff --git a/vllm/v1/kv_offload/worker/cpu_gpu.py b/vllm/v1/kv_offload/worker/cpu_gpu.py index 461458c1f6ce8..42ae4f1413ad0 100644 --- a/vllm/v1/kv_offload/worker/cpu_gpu.py +++ b/vllm/v1/kv_offload/worker/cpu_gpu.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections import deque import numpy as np import torch @@ -8,7 +9,7 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger from vllm.utils.platform_utils import is_pin_memory_available -from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec +from vllm.v1.kv_offload.mediums import BlockIDsLoadStoreSpec from vllm.v1.kv_offload.worker.worker import ( OffloadingHandler, TransferResult, @@ -51,7 +52,123 @@ def expand_block_ids( output_idx = output_end_idx -class CpuGpuOffloadingHandler(OffloadingHandler): +class SingleDirectionOffloadingHandler(OffloadingHandler): + """ + SingleDirectionOffloadingHandler handles transfers for a single direction, + either CPU->GPU or GPU->CPU. + Transfers are guaranteed to be executed in order of their submission. + Each transfer uses a unique CUDA stream, and its stream will start + executing only after the streams of previous transfers have finished. + """ + + def __init__( + self, + src_tensors: list[torch.Tensor], + dst_tensors: list[torch.Tensor], + kv_dim_before_num_blocks: list[bool], + src_block_size_factor: int, + dst_block_size_factor: int, + priority: int, + ): + """ + Initialize a SingleDirectionOffloadingHandler. + + Args: + src_tensors: list of KV cache tensors to copy from. + dst_tensors: list of KV cache tensors to copy to. + Order should match src_tensors. + kv_dim_before_num_blocks: list of bools, indicating + whether the respective KV cache tensor has a KV + dimension before its num_blocks dimension. + e.g. (2, num_blocks, ...) + src_block_size_factor: The number of kernel blocks + per KV block in a source tensor. + dst_block_size_factor: The number of kernel blocks + per KV block in a destination tensor. + priority: The priority of the backing CUDA streams. + Lower numbers indicate higher priority. + """ + assert len(src_tensors) == len(dst_tensors) == len(kv_dim_before_num_blocks) + + self.src_tensors: list[torch.Tensor] = src_tensors + self.dst_tensors: list[torch.Tensor] = dst_tensors + self.kv_dim_before_num_blocks: list[bool] = kv_dim_before_num_blocks + self.src_block_size_factor: int = src_block_size_factor + self.dst_block_size_factor: int = dst_block_size_factor + self.priority = priority + + # queue of transfers (job_id, stream, event) + self._transfers: deque[tuple[int, torch.cuda.Stream, torch.Event]] = deque() + # list of CUDA streams available for re-use + self._stream_pool: list[torch.cuda.Stream] = [] + # list of CUDA events available for re-use + self._event_pool: list[torch.Event] = [] + + def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool: + src_spec, dst_spec = transfer_spec + assert isinstance(src_spec, BlockIDsLoadStoreSpec) + assert isinstance(dst_spec, BlockIDsLoadStoreSpec) + + src_blocks = src_spec.block_ids + dst_blocks = dst_spec.block_ids + assert src_blocks.ndim == 1 + assert dst_blocks.ndim == 1 + + src_sub_block_count = src_blocks.size * self.src_block_size_factor + dst_sub_block_count = dst_blocks.size * self.dst_block_size_factor + src_sub_blocks_to_skip = -dst_blocks.size % self.src_block_size_factor + + assert dst_sub_block_count == src_sub_block_count - src_sub_blocks_to_skip + + src_to_dst = np.empty((dst_sub_block_count, 2), dtype=np.int64) + expand_block_ids( + src_blocks, + self.src_block_size_factor, + src_to_dst[:, 0], + skip_count=src_sub_blocks_to_skip, + ) + expand_block_ids(dst_blocks, self.dst_block_size_factor, src_to_dst[:, 1]) + src_to_dst_tensor = torch.from_numpy(src_to_dst) + + stream = ( + self._stream_pool.pop() + if self._stream_pool + else torch.cuda.Stream(priority=self.priority) + ) + event = self._event_pool.pop() if self._event_pool else torch.Event() + if self._transfers: + _, _, last_event = self._transfers[-1] + # assure job will start only after the previous one completes + stream.wait_event(last_event) + with torch.cuda.stream(stream): + for src_tensor, dst_tensor, kv_dim in zip( + self.src_tensors, self.dst_tensors, self.kv_dim_before_num_blocks + ): + if kv_dim: + src_key_cache, src_value_cache = src_tensor + dst_key_cache, dst_value_cache = dst_tensor + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst_tensor) + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst_tensor) + else: + ops.swap_blocks(src_tensor, dst_tensor, src_to_dst_tensor) + event.record(stream) + + self._transfers.append((job_id, stream, event)) + + # success + return True + + def get_finished(self) -> list[TransferResult]: + results: list[TransferResult] = [] + while self._transfers and self._transfers[0][2].query(): + job_id, stream, event = self._transfers.popleft() + results.append((job_id, True)) + self._stream_pool.append(stream) + self._event_pool.append(event) + return results + + +class CpuGpuOffloadingHandlers: def __init__( self, gpu_block_size: int, @@ -60,27 +177,20 @@ class CpuGpuOffloadingHandler(OffloadingHandler): gpu_caches: dict[str, torch.Tensor], attn_backends: dict[str, type[AttentionBackend]], ): + assert gpu_caches assert cpu_block_size % gpu_block_size == 0 - self.block_size_factor = cpu_block_size // gpu_block_size - - # cuda streams for gpu->cpu and cpu->gpu - self.d2h_stream = torch.cuda.Stream() - self.h2d_stream = torch.cuda.Stream() - - # job_id -> transfer cuda event - self.transfer_events: dict[int, torch.Event] = {} - # list of cuda events available for re-use - self.events_pool: list[torch.Event] = [] + block_size_factor = cpu_block_size // gpu_block_size pin_memory = is_pin_memory_available() # allocate cpu tensors logger.info("Allocating %d CPU tensors...", len(gpu_caches)) - self.gpu_tensors: list[torch.Tensor] = [] - self.cpu_tensors: list[torch.Tensor] = [] - self.kv_dim_before_num_blocks: list[bool] = [] + gpu_tensors: list[torch.Tensor] = [] + cpu_tensors: list[torch.Tensor] = [] + kv_dim_before_num_blocks: list[bool] = [] + kernel_block_size: int | None = None for layer_name, gpu_tensor in gpu_caches.items(): - self.gpu_tensors.append(gpu_tensor) + gpu_tensors.append(gpu_tensor) gpu_shape = gpu_tensor.shape attn_backend = attn_backends[layer_name] @@ -88,16 +198,21 @@ class CpuGpuOffloadingHandler(OffloadingHandler): num_blocks=1234, block_size=16, num_kv_heads=8, head_size=256 ) + has_layers_dim = False if len(gpu_shape) != len(test_shape): # cross-layers tensor # shape is (num_blocks, ...) assert len(gpu_shape) == len(test_shape) + 1 num_blocks_idx = 0 - self.kv_dim_before_num_blocks.append(False) + has_layers_dim = True + kv_dim_before_num_blocks.append(False) + + # prepend a dummy num_layers=80 to test_shape + test_shape = (80,) + test_shape elif test_shape[0] == 1234: # shape is (num_blocks, ...) num_blocks_idx = 0 - self.kv_dim_before_num_blocks.append(False) + kv_dim_before_num_blocks.append(False) else: # shape should be (2, num_blocks, ...) assert test_shape[0] == 2 @@ -105,13 +220,32 @@ class CpuGpuOffloadingHandler(OffloadingHandler): assert gpu_shape[0] == 2 num_blocks_idx = 1 - self.kv_dim_before_num_blocks.append(True) + kv_dim_before_num_blocks.append(True) + + try: + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( + include_num_layers_dimension=has_layers_dim + ) + assert len(kv_cache_stride_order) == len(gpu_shape) + except (AttributeError, NotImplementedError): + kv_cache_stride_order = tuple(range(len(gpu_shape))) + + # permute test_shape according to stride_order + test_shape = tuple(test_shape[i] for i in kv_cache_stride_order) + + # find block_size (16) dimension index + block_size_idx = test_shape.index(16) + if kernel_block_size is not None: + assert kernel_block_size == gpu_shape[block_size_idx] + else: + kernel_block_size = gpu_shape[block_size_idx] + assert gpu_block_size % kernel_block_size == 0 cpu_shape = list(gpu_shape) - cpu_shape[num_blocks_idx] = num_cpu_blocks * self.block_size_factor + cpu_shape[num_blocks_idx] = num_cpu_blocks * block_size_factor logger.debug("Allocating CPU tensor of shape %r", cpu_shape) - self.cpu_tensors.append( + cpu_tensors.append( torch.zeros( cpu_shape, dtype=gpu_tensor.dtype, @@ -120,72 +254,27 @@ class CpuGpuOffloadingHandler(OffloadingHandler): ) ) - def transfer_async(self, job_id: int, spec: TransferSpec) -> bool: - src_spec, dst_spec = spec - if isinstance(src_spec, CPULoadStoreSpec): - assert isinstance(dst_spec, GPULoadStoreSpec) - stream = self.h2d_stream - src_tensors = self.cpu_tensors - dst_tensors = self.gpu_tensors - src_block_size_factor = self.block_size_factor - dst_block_size_factor = 1 - else: - assert isinstance(src_spec, GPULoadStoreSpec) - assert isinstance(dst_spec, CPULoadStoreSpec) - stream = self.d2h_stream - src_tensors = self.gpu_tensors - dst_tensors = self.cpu_tensors - src_block_size_factor = 1 - dst_block_size_factor = self.block_size_factor + assert kernel_block_size is not None + gpu_block_size_factor = gpu_block_size // kernel_block_size + cpu_block_size_factor = cpu_block_size // kernel_block_size - src_blocks = src_spec.block_ids - dst_blocks = dst_spec.block_ids - assert src_blocks.ndim == 1 - assert dst_blocks.ndim == 1 + # TODO (orozery): adapt swap_blocks to support gpu_block_size_factor + assert gpu_block_size_factor == 1 - src_sub_block_count = src_blocks.size * src_block_size_factor - dst_sub_block_count = dst_blocks.size * dst_block_size_factor - src_sub_blocks_to_skip = -dst_blocks.size % src_block_size_factor - - assert dst_sub_block_count == src_sub_block_count - src_sub_blocks_to_skip - - src_to_dst = np.empty((dst_sub_block_count, 2), dtype=np.int64) - expand_block_ids( - src_blocks, - src_block_size_factor, - src_to_dst[:, 0], - skip_count=src_sub_blocks_to_skip, + self.gpu_to_cpu_handler = SingleDirectionOffloadingHandler( + src_tensors=gpu_tensors, + dst_tensors=cpu_tensors, + kv_dim_before_num_blocks=kv_dim_before_num_blocks, + src_block_size_factor=gpu_block_size_factor, + dst_block_size_factor=cpu_block_size_factor, + priority=1, ) - expand_block_ids(dst_blocks, dst_block_size_factor, src_to_dst[:, 1]) - src_to_dst_tensor = torch.from_numpy(src_to_dst) - event = self.events_pool.pop() if self.events_pool else torch.Event() - with torch.cuda.stream(stream): - for src_tensor, dst_tensor, kv_dim in zip( - src_tensors, dst_tensors, self.kv_dim_before_num_blocks - ): - if kv_dim: - src_key_cache = src_tensor[0] - dst_key_cache = dst_tensor[0] - ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst_tensor) - src_value_cache = src_tensor[1] - dst_value_cache = dst_tensor[1] - ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst_tensor) - else: - ops.swap_blocks(src_tensor, dst_tensor, src_to_dst_tensor) - event.record(stream) - - self.transfer_events[job_id] = event - - # success - return True - - def get_finished(self) -> list[TransferResult]: - results: list[TransferResult] = [] - for job_id, event in self.transfer_events.items(): - if event.query(): - results.append((job_id, True)) - self.events_pool.append(event) - for job_id, _ in results: - del self.transfer_events[job_id] - return results + self.cpu_to_gpu_handler = SingleDirectionOffloadingHandler( + src_tensors=cpu_tensors, + dst_tensors=gpu_tensors, + kv_dim_before_num_blocks=kv_dim_before_num_blocks, + src_block_size_factor=cpu_block_size_factor, + dst_block_size_factor=gpu_block_size_factor, + priority=-1, + ) diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 429cee3b5af10..9eaee1bb97bb9 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -10,6 +10,7 @@ from typing import TypeAlias from prometheus_client import Counter, Gauge, Histogram import vllm.envs as envs +from vllm.compilation.cuda_graph import CUDAGraphLogging from vllm.config import SupportsMetricsInfo, VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( KVConnectorLogging, @@ -106,6 +107,12 @@ class LoggingStatLogger(StatLoggerBase): self.spec_decoding_logging = SpecDecodingLogging() kv_transfer_config = self.vllm_config.kv_transfer_config self.kv_connector_logging = KVConnectorLogging(kv_transfer_config) + self.cudagraph_logging = None + if self.vllm_config.observability_config.cudagraph_metrics: + self.cudagraph_logging = CUDAGraphLogging( + self.vllm_config.compilation_config.cudagraph_mode, + self.vllm_config.compilation_config.cudagraph_capture_sizes, + ) self.last_prompt_throughput: float = 0.0 self.last_generation_throughput: float = 0.0 self.engine_is_idle = False @@ -161,6 +168,11 @@ class LoggingStatLogger(StatLoggerBase): self.spec_decoding_logging.observe(scheduler_stats.spec_decoding_stats) if kv_connector_stats := scheduler_stats.kv_connector_stats: self.kv_connector_logging.observe(kv_connector_stats) + if ( + self.cudagraph_logging is not None + and scheduler_stats.cudagraph_stats is not None + ): + self.cudagraph_logging.observe(scheduler_stats.cudagraph_stats) if not self.aggregated: self.last_scheduler_stats = scheduler_stats if mm_cache_stats: @@ -240,6 +252,8 @@ class LoggingStatLogger(StatLoggerBase): self.spec_decoding_logging.log(log_fn=log_fn) self.kv_connector_logging.log(log_fn=log_fn) + if self.cudagraph_logging is not None: + self.cudagraph_logging.log(log_fn=log_fn) def log_engine_initialized(self): if self.vllm_config.cache_config.num_gpu_blocks: @@ -375,6 +389,9 @@ class PrometheusStatLogger(AggregateStatLoggerBase): # Use this flag to hide metrics that were deprecated in # a previous release and which will be removed future self.show_hidden_metrics = vllm_config.observability_config.show_hidden_metrics + self.kv_cache_metrics_enabled = ( + vllm_config.observability_config.kv_cache_metrics + ) labelnames = ["model_name", "engine"] model_name = vllm_config.model_config.served_model_name @@ -853,6 +870,92 @@ class PrometheusStatLogger(AggregateStatLoggerBase): histogram_decode_time_request, engine_indexes, model_name ) + histogram_prefill_kv_computed_request = self._histogram_cls( + name="vllm:request_prefill_kv_computed_tokens", + documentation=( + "Histogram of new KV tokens computed during prefill " + "(excluding cached tokens)." + ), + buckets=build_1_2_5_buckets(max_model_len), + labelnames=labelnames, + ) + self.histogram_prefill_kv_computed_request = make_per_engine( + histogram_prefill_kv_computed_request, engine_indexes, model_name + ) + + # + # KV Cache residency metrics + # + if self.kv_cache_metrics_enabled: + kv_cache_residency_buckets = [ + 0.001, + 0.002, + 0.005, + 0.01, + 0.02, + 0.05, + 0.1, + 0.2, + 0.5, + 1, + 2, + 5, + 10, + 20, + 30, + 60, + 120, + 300, + 600, + 1200, + 1800, + ] + + histogram_kv_block_lifetime = self._histogram_cls( + name="vllm:kv_block_lifetime_seconds", + documentation=( + "Histogram of KV cache block lifetime from allocation to eviction. " + "Sampled metrics (controlled by --kv-cache-metrics-sample)." + ), + buckets=kv_cache_residency_buckets, + labelnames=labelnames, + ) + self.histogram_kv_block_lifetime = make_per_engine( + histogram_kv_block_lifetime, engine_indexes, model_name + ) + + histogram_kv_block_idle_before_evict = self._histogram_cls( + name="vllm:kv_block_idle_before_evict_seconds", + documentation=( + "Histogram of idle time before KV cache block eviction. " + "Sampled metrics (controlled by --kv-cache-metrics-sample)." + ), + buckets=kv_cache_residency_buckets, + labelnames=labelnames, + ) + self.histogram_kv_block_idle_before_evict = make_per_engine( + histogram_kv_block_idle_before_evict, engine_indexes, model_name + ) + + histogram_kv_block_reuse_gap = self._histogram_cls( + name="vllm:kv_block_reuse_gap_seconds", + documentation=( + "Histogram of time gaps between consecutive KV cache block " + "accesses. Only the most recent accesses are recorded " + "(ring buffer). Sampled metrics (controlled by " + "--kv-cache-metrics-sample)." + ), + buckets=kv_cache_residency_buckets, + labelnames=labelnames, + ) + self.histogram_kv_block_reuse_gap = make_per_engine( + histogram_kv_block_reuse_gap, engine_indexes, model_name + ) + else: + self.histogram_kv_block_lifetime = {} + self.histogram_kv_block_idle_before_evict = {} + self.histogram_kv_block_reuse_gap = {} + # # LoRA metrics # @@ -862,7 +965,10 @@ class PrometheusStatLogger(AggregateStatLoggerBase): self.gauge_lora_info: Gauge | None = None if vllm_config.lora_config is not None: if len(self.engine_indexes) > 1: - raise NotImplementedError("LoRA in DP mode is not supported yet.") + logger.warning( + "vllm:lora_requests_info prometheus metrics may be " + "incorrect/misleading with data parallel deployments." + ) self.labelname_max_lora = "max_lora" self.labelname_waiting_lora_adapters = "waiting_lora_adapters" self.labelname_running_lora_adapters = "running_lora_adapters" @@ -944,6 +1050,20 @@ class PrometheusStatLogger(AggregateStatLoggerBase): scheduler_stats.kv_connector_stats, engine_idx ) + if ( + self.kv_cache_metrics_enabled + and scheduler_stats.kv_cache_eviction_events + ): + lifetime_hist = self.histogram_kv_block_lifetime[engine_idx] + idle_hist = self.histogram_kv_block_idle_before_evict[engine_idx] + reuse_hist = self.histogram_kv_block_reuse_gap[engine_idx] + + for event in scheduler_stats.kv_cache_eviction_events: + lifetime_hist.observe(event.lifetime_seconds) + idle_hist.observe(event.idle_seconds) + for gap in event.reuse_gaps_seconds: + reuse_hist.observe(gap) + if self.gauge_lora_info is not None: running_lora_adapters = ",".join( scheduler_stats.running_lora_adapters.keys() @@ -1011,6 +1131,13 @@ class PrometheusStatLogger(AggregateStatLoggerBase): self.histogram_decode_time_request[engine_idx].observe( finished_request.decode_time ) + # Calculate prefill KV compute (excludes cached tokens) + prefill_kv_computed = finished_request.num_prompt_tokens - max( + finished_request.num_cached_tokens, 0 + ) + self.histogram_prefill_kv_computed_request[engine_idx].observe( + prefill_kv_computed + ) self.histogram_num_prompt_tokens_request[engine_idx].observe( finished_request.num_prompt_tokens ) diff --git a/vllm/v1/metrics/ray_wrappers.py b/vllm/v1/metrics/ray_wrappers.py index a319ffb1d2573..4b46669d5d3bf 100644 --- a/vllm/v1/metrics/ray_wrappers.py +++ b/vllm/v1/metrics/ray_wrappers.py @@ -7,37 +7,55 @@ from vllm.v1.metrics.loggers import PrometheusStatLogger from vllm.v1.spec_decode.metrics import SpecDecodingProm try: + from ray import serve as ray_serve from ray.util import metrics as ray_metrics from ray.util.metrics import Metric except ImportError: ray_metrics = None + ray_serve = None import regex as re +def _get_replica_id() -> str | None: + """Get the current Ray Serve replica ID, or None if not in a Serve context.""" + if ray_serve is None: + return None + try: + return ray_serve.get_replica_context().replica_id.unique_id + except ray_serve.exceptions.RayServeException: + return None + + class RayPrometheusMetric: def __init__(self): if ray_metrics is None: raise ImportError("RayPrometheusMetric requires Ray to be installed.") - self.metric: Metric = None + @staticmethod + def _get_tag_keys(labelnames: list[str] | None) -> tuple[str, ...]: + labels = list(labelnames) if labelnames else [] + labels.append("ReplicaId") + return tuple(labels) + def labels(self, *labels, **labelskwargs): + if labels: + # -1 because ReplicaId was added automatically + expected = len(self.metric._tag_keys) - 1 + if len(labels) != expected: + raise ValueError( + "Number of labels must match the number of tag keys. " + f"Expected {expected}, got {len(labels)}" + ) + labelskwargs.update(zip(self.metric._tag_keys, labels)) + + labelskwargs["ReplicaId"] = _get_replica_id() or "" + if labelskwargs: for k, v in labelskwargs.items(): if not isinstance(v, str): labelskwargs[k] = str(v) - self.metric.set_default_tags(labelskwargs) - - if labels: - if len(labels) != len(self.metric._tag_keys): - raise ValueError( - "Number of labels must match the number of tag keys. " - f"Expected {len(self.metric._tag_keys)}, got {len(labels)}" - ) - - self.metric.set_default_tags(dict(zip(self.metric._tag_keys, labels))) - return self @staticmethod @@ -71,10 +89,14 @@ class RayGaugeWrapper(RayPrometheusMetric): # "mostrecent", "all", "sum" do not apply. This logic can be manually # implemented at the observability layer (Prometheus/Grafana). del multiprocess_mode - labelnames_tuple = tuple(labelnames) if labelnames else None + + tag_keys = self._get_tag_keys(labelnames) name = self._get_sanitized_opentelemetry_name(name) + self.metric = ray_metrics.Gauge( - name=name, description=documentation, tag_keys=labelnames_tuple + name=name, + description=documentation, + tag_keys=tag_keys, ) def set(self, value: int | float): @@ -95,10 +117,12 @@ class RayCounterWrapper(RayPrometheusMetric): documentation: str | None = "", labelnames: list[str] | None = None, ): - labelnames_tuple = tuple(labelnames) if labelnames else None + tag_keys = self._get_tag_keys(labelnames) name = self._get_sanitized_opentelemetry_name(name) self.metric = ray_metrics.Counter( - name=name, description=documentation, tag_keys=labelnames_tuple + name=name, + description=documentation, + tag_keys=tag_keys, ) def inc(self, value: int | float = 1.0): @@ -118,13 +142,14 @@ class RayHistogramWrapper(RayPrometheusMetric): labelnames: list[str] | None = None, buckets: list[float] | None = None, ): - labelnames_tuple = tuple(labelnames) if labelnames else None + tag_keys = self._get_tag_keys(labelnames) name = self._get_sanitized_opentelemetry_name(name) + boundaries = buckets if buckets else [] self.metric = ray_metrics.Histogram( name=name, description=documentation, - tag_keys=labelnames_tuple, + tag_keys=tag_keys, boundaries=boundaries, ) diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 4e9db98db0bc2..a0cc58d0a64e8 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -7,6 +7,7 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any import vllm.envs as envs +from vllm.compilation.cuda_graph import CUDAGraphStat from vllm.v1.spec_decode.metrics import SpecDecodingStats if TYPE_CHECKING: @@ -150,6 +151,15 @@ class MultiModalCacheStats(BaseCacheStats): """ +@dataclass +class KVCacheEvictionEvent: + """Single KV cache block eviction sample.""" + + lifetime_seconds: float + idle_seconds: float + reuse_gaps_seconds: tuple[float, ...] + + @dataclass class SchedulerStats: """Stats associated with the scheduler.""" @@ -166,12 +176,16 @@ class SchedulerStats: prefix_cache_stats: PrefixCacheStats = field(default_factory=PrefixCacheStats) connector_prefix_cache_stats: PrefixCacheStats | None = None + kv_cache_eviction_events: list[KVCacheEvictionEvent] = field(default_factory=list) + spec_decoding_stats: SpecDecodingStats | None = None kv_connector_stats: dict[str, Any] | None = None waiting_lora_adapters: dict[str, int] = field(default_factory=dict) running_lora_adapters: dict[str, int] = field(default_factory=dict) + cudagraph_stats: CUDAGraphStat | None = None + @dataclass class RequestStateStats: @@ -210,6 +224,7 @@ class FinishedRequestStats: decode_time: float = 0.0 mean_time_per_output_token: float = 0.0 is_corrupted: bool = False + num_cached_tokens: int = 0 class IterationStats: @@ -316,6 +331,7 @@ class IterationStats: num_prompt_tokens: int, max_tokens_param: int | None, req_stats: RequestStateStats, + num_cached_tokens: int = 0, ): e2e_latency = self._time_since(req_stats.arrival_time) @@ -353,6 +369,7 @@ class IterationStats: decode_time=decode_time, mean_time_per_output_token=mean_time_per_output_token, is_corrupted=req_stats.is_corrupted, + num_cached_tokens=num_cached_tokens, ) self.finished_requests.append(finished_req) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 8110deb5a610b..bea9e5846de13 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -8,12 +8,15 @@ from typing import TYPE_CHECKING, NamedTuple import numpy as np import torch +from vllm.compilation.cuda_graph import CUDAGraphStat from vllm.v1.core.sched.output import SchedulerOutput if TYPE_CHECKING: + from vllm.distributed.kv_events import KVConnectorKVEvents from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats else: KVConnectorStats = object + KVConnectorKVEvents = object class LogprobsLists(NamedTuple): @@ -88,7 +91,7 @@ class LogprobsTensors(NamedTuple): # [num_reqs, <dynamic>] # The shape of each element depends on the pooler used -PoolerOutput = torch.Tensor | list[torch.Tensor] +PoolerOutput = list[torch.Tensor | None] | torch.Tensor | None @dataclass @@ -107,6 +110,7 @@ class KVConnectorOutput: finished_sending: set[str] | None = None finished_recving: set[str] | None = None kv_connector_stats: KVConnectorStats | None = None + kv_cache_events: KVConnectorKVEvents | None = None # IDs of externally computed KV blocks that failed to load. # Requests referencing these blocks should be rescheduled to recompute them invalid_block_ids: set[int] = field(default_factory=set) @@ -122,6 +126,7 @@ class KVConnectorOutput: not self.finished_sending and not self.finished_recving and not self.kv_connector_stats + and not self.kv_cache_events and not self.invalid_block_ids ) @@ -169,6 +174,9 @@ class ModelRunnerOutput: # req_id -> num_nans_in_logits num_nans_in_logits: dict[str, int] | None = None + # information related to cudagraph execution + cudagraph_stats: CUDAGraphStat | None = None + # ModelRunnerOutput wrapper for async scheduling. class AsyncModelRunnerOutput(ABC): diff --git a/vllm/v1/pool/metadata.py b/vllm/v1/pool/metadata.py index 7bd2c7415dafe..acd1a00e87553 100644 --- a/vllm/v1/pool/metadata.py +++ b/vllm/v1/pool/metadata.py @@ -5,6 +5,7 @@ from dataclasses import dataclass import torch from vllm.pooling_params import PoolingParams +from vllm.tasks import PoolingTask from vllm.utils.platform_utils import is_pin_memory_available pin_memory = is_pin_memory_available() @@ -16,6 +17,7 @@ class PoolingCursor: first_token_indices_gpu: torch.Tensor last_token_indices_gpu: torch.Tensor prompt_lens_cpu: torch.Tensor + seq_lens_cpu: torch.Tensor num_scheduled_tokens_cpu: torch.Tensor def __getitem__(self, indices: slice): @@ -24,12 +26,25 @@ class PoolingCursor: first_token_indices_gpu=self.first_token_indices_gpu[indices], last_token_indices_gpu=self.last_token_indices_gpu[indices], prompt_lens_cpu=self.prompt_lens_cpu[indices], + seq_lens_cpu=self.seq_lens_cpu[indices], num_scheduled_tokens_cpu=self.num_scheduled_tokens_cpu[indices], ) def is_partial_prefill(self): return not torch.all(self.prompt_lens_cpu == self.num_scheduled_tokens_cpu) + def is_finished(self): + return self.prompt_lens_cpu == self.seq_lens_cpu + + +class PoolingStates: + def __init__(self): + # for chunked prefill with ALL pooling + self.hidden_states_cache: list[torch.Tensor] = [] + + def clean(self): + self.hidden_states_cache.clear() + @dataclass class PoolingMetadata: @@ -38,8 +53,21 @@ class PoolingMetadata: prompt_lens: torch.Tensor # CPU Tensor prompt_token_ids: torch.Tensor | None pooling_params: list[PoolingParams] + pooling_states: list[PoolingStates] pooling_cursor: PoolingCursor | None = None + def __post_init__(self) -> None: + pooling_params = self.pooling_params + + tasks: list[PoolingTask] = [ + task + for pooling_param in pooling_params + if (task := pooling_param.task) is not None + ] + assert len(pooling_params) == len(tasks) + + self.tasks = tasks + def __getitem__(self, indices: slice): return PoolingMetadata( prompt_lens=self.prompt_lens[indices], @@ -47,21 +75,36 @@ class PoolingMetadata: if self.prompt_token_ids is None else self.prompt_token_ids[indices], pooling_params=self.pooling_params[indices], + pooling_states=self.pooling_states[indices], pooling_cursor=None if self.pooling_cursor is None else self.pooling_cursor[indices], ) + def get_prompt_token_ids(self) -> list[torch.Tensor]: + prompt_token_ids = self.prompt_token_ids + assert prompt_token_ids is not None, ( + "Please set `requires_token_ids=True` in `get_pooling_updates`" + ) + + return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)] + def build_pooling_cursor( - self, num_scheduled_tokens: list[int], device: torch.device + self, + num_scheduled_tokens: list[int], + seq_lens_cpu: torch.Tensor, + device: torch.device, ): self.pooling_cursor = build_pooling_cursor( - num_scheduled_tokens, self.prompt_lens, device + num_scheduled_tokens, seq_lens_cpu, self.prompt_lens, device ) def build_pooling_cursor( - num_scheduled_tokens: list[int], prompt_lens: torch.Tensor, device: torch.device + num_scheduled_tokens: list[int], + seq_lens_cpu: torch.Tensor, + prompt_lens: torch.Tensor, + device: torch.device, ): assert len(prompt_lens) == len(num_scheduled_tokens) @@ -78,5 +121,6 @@ def build_pooling_cursor( first_token_indices_gpu=cumsum[:n_seq], last_token_indices_gpu=cumsum[1:] - 1, prompt_lens_cpu=prompt_lens, + seq_lens_cpu=seq_lens_cpu, num_scheduled_tokens_cpu=num_scheduled_tokens_cpu, ) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 366cdadf5a583..a775e840e841c 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -93,7 +93,12 @@ class Request: if self.prompt_token_ids is not None else [0] * self.num_prompt_tokens ) - self.num_output_placeholders = 0 # Used in async scheduling. + + # Used in async scheduling. + self.num_output_placeholders = 0 + # Used in forced preemption (reset_prefix_cache) with async scheduling. + self.discard_latest_async_tokens = False + self.spec_token_ids: list[int] = [] self.num_computed_tokens = 0 self.cache_salt: str | None = cache_salt @@ -222,6 +227,19 @@ class Request: events, self.events = self.events, [] return events + def __lt__(self, other: "Request") -> bool: + """ + Compare two requests based on priority, arrival time, and request ID. + Used in priority scheduling. + """ + if self.priority != other.priority: + return self.priority < other.priority + if self.arrival_time != other.arrival_time: + return self.arrival_time < other.arrival_time + if self.request_id != other.request_id: + return self.request_id < other.request_id + return id(self) < id(other) + class RequestStatus(enum.IntEnum): """Status of a request.""" @@ -237,6 +255,7 @@ class RequestStatus(enum.IntEnum): FINISHED_LENGTH_CAPPED = enum.auto() FINISHED_ABORTED = enum.auto() FINISHED_IGNORED = enum.auto() + FINISHED_ERROR = enum.auto() def __str__(self): return self.name @@ -259,4 +278,5 @@ _FINISHED_REASON_MAP = { RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH, RequestStatus.FINISHED_ABORTED: FinishReason.ABORT, RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH, + RequestStatus.FINISHED_ERROR: FinishReason.ERROR, } diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index 4ee7dc2880c8c..82743f72b0310 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -110,7 +110,7 @@ class MinPLogitsProcessor(LogitsProcessor): # Identify valid tokens using threshold comparison invalid_token_mask = probability_values < adjusted_min_p # Apply mask using boolean indexing - logits[invalid_token_mask] = -float("inf") + logits.masked_fill_(invalid_token_mask, -float("inf")) return logits @@ -178,6 +178,10 @@ class MinTokensLogitsProcessor(LogitsProcessor): self._device_tensor([], torch.int32), ) + self.neg_inf_tensor = torch.tensor( + -float("inf"), dtype=torch.float32, device=self.device + ) + def is_argmax_invariant(self) -> bool: """By censoring stop tokens, min-tokens can change the outcome of the argmax operation in greedy sampling.""" @@ -229,7 +233,7 @@ class MinTokensLogitsProcessor(LogitsProcessor): def apply(self, logits: torch.Tensor) -> torch.Tensor: if self.min_toks: # Inhibit EOS token for requests which have not reached min length - logits[self.logits_slice] = -float("inf") + logits.index_put_(self.logits_slice, self.neg_inf_tensor) return logits diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index ccaf07e18c468..50b91d8292ee8 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -145,7 +145,7 @@ class RejectionSampler(nn.Module): ) logprobs_tensors = None - if sampling_metadata.max_num_logprobs: + if sampling_metadata.max_num_logprobs is not None: logprobs_tensors = self._get_logprobs_tensors( sampling_metadata.max_num_logprobs, metadata, diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 0a6806390451d..a3c30e368b828 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -27,7 +27,6 @@ from vllm.multimodal.inputs import ( MultiModalFieldConfig, MultiModalFieldElem, MultiModalFlatField, - MultiModalKwargs, MultiModalKwargsItem, MultiModalKwargsItems, MultiModalSharedField, @@ -176,9 +175,6 @@ class MsgpackEncoder: if isinstance(obj, MultiModalKwargsItems): return self._encode_mm_items(obj) - if isinstance(obj, MultiModalKwargs): - return self._encode_mm_kwargs(obj) - if isinstance(obj, UtilityResult): result = obj.result if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION: @@ -259,11 +255,6 @@ class MsgpackEncoder: "field": self._encode_mm_field(elem.field), } - def _encode_mm_kwargs(self, kw: MultiModalKwargs) -> dict[str, Any]: - return { - modality: self._encode_nested_tensors(data) for modality, data in kw.items() - } - def _encode_nested_tensors(self, nt: NestedTensors) -> Any: if isinstance(nt, torch.Tensor): return self._encode_tensor(nt) @@ -278,10 +269,11 @@ class MsgpackEncoder: name = MMF_CLASS_TO_FACTORY.get(field.__class__) if not name: raise TypeError(f"Unsupported field type: {field.__class__}") + # We just need to copy all of the field values in order # which will be then used to reconstruct the field. - field_values = (getattr(field, f.name) for f in dataclasses.fields(field)) - return name, *field_values + factory_kw = {f.name: getattr(field, f.name) for f in dataclasses.fields(field)} + return name, factory_kw class MsgpackDecoder: @@ -325,8 +317,6 @@ class MsgpackDecoder: return self._decode_mm_item(obj) if issubclass(t, MultiModalKwargsItems): return self._decode_mm_items(obj) - if issubclass(t, MultiModalKwargs): - return self._decode_mm_kwargs(obj) if t is UtilityResult: return self._decode_utility_result(obj) return obj @@ -403,25 +393,17 @@ class MsgpackDecoder: obj["data"] = self._decode_nested_tensors(obj["data"]) # Reconstruct the field processor using MultiModalFieldConfig - factory_meth_name, *field_args = obj["field"] + factory_meth_name, factory_kw = obj["field"] factory_meth = getattr(MultiModalFieldConfig, factory_meth_name) # Special case: decode the union "slices" field of # MultiModalFlatField if factory_meth_name == "flat": - field_args[0] = self._decode_nested_slices(field_args[0]) + factory_kw["slices"] = self._decode_nested_slices(factory_kw["slices"]) - obj["field"] = factory_meth(None, *field_args).field + obj["field"] = factory_meth("", **factory_kw).field return MultiModalFieldElem(**obj) - def _decode_mm_kwargs(self, obj: dict[str, Any]) -> MultiModalKwargs: - return MultiModalKwargs( - { - modality: self._decode_nested_tensors(data) - for modality, data in obj.items() - } - ) - def _decode_nested_tensors(self, obj: Any) -> NestedTensors: if isinstance(obj, (int, float)): # Although it violates NestedTensors type, MultiModalKwargs diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index d7111d52dd8a1..65a0a88ec0f5d 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -178,6 +178,12 @@ class EagleProposer: ) rocm_types.append(AiterFlashAttentionMetadata) + + # TRITON_MLA backend support for MLA models (e.g., DeepSeek) + from vllm.v1.attention.backends.mla.common import MLACommonMetadata + + rocm_types.append(MLACommonMetadata) + self.allowed_attn_types = tuple(rocm_types) # Parse the speculative token tree. @@ -440,16 +446,16 @@ class EagleProposer: # of main model. # Increment the sequence lengths. common_attn_metadata.seq_lens += 1 - # This is an out-of-place operation to avoid modifying the original tensor. - common_attn_metadata.seq_lens_cpu = common_attn_metadata.seq_lens_cpu + 1 # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. - common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) - common_attn_metadata.num_computed_tokens_cpu = ( - common_attn_metadata.seq_lens_cpu - 1 - ) + # Also update the CPU-side shadow; NOTE: this is hacky and should be + # removed in when common_attn_metadata.seq_lens_cpu is deprecated. + if common_attn_metadata._seq_lens_cpu is not None: + common_attn_metadata._seq_lens_cpu += 1 + if common_attn_metadata._num_computed_tokens_cpu is not None: + common_attn_metadata._num_computed_tokens_cpu += 1 # Compute the slot mapping. if self.uses_mrope: @@ -656,8 +662,8 @@ class EagleProposer: query_start_loc=common_attn_metadata.query_start_loc, seq_lens=common_attn_metadata.seq_lens, query_start_loc_cpu=query_start_loc_cpu, - seq_lens_cpu=common_attn_metadata.seq_lens_cpu, - num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, + _seq_lens_cpu=common_attn_metadata._seq_lens_cpu, + _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu, num_reqs=common_attn_metadata.num_reqs, num_actual_tokens=total_num_tokens, max_query_len=new_query_len_per_req.max().item(), @@ -932,8 +938,8 @@ class EagleProposer: query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True), seq_lens=new_seq_lens_cpu.to(device, non_blocking=True), query_start_loc_cpu=new_query_start_loc_cpu, - seq_lens_cpu=new_seq_lens_cpu, - num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, + _seq_lens_cpu=new_seq_lens_cpu, + _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu, num_reqs=common_attn_metadata.num_reqs, num_actual_tokens=total_num_tokens, max_query_len=new_query_len_per_req.max().item(), @@ -1016,6 +1022,10 @@ class EagleProposer: "Qwen3VLForConditionalGeneration", ]: self.model.config.image_token_index = target_model.config.image_token_id + elif self.get_model_name(target_model) == "PixtralForConditionalGeneration": + self.model.config.image_token_index = ( + target_model.config.vision_config.image_token_id + ) else: self.model.config.image_token_index = ( target_model.config.image_token_index @@ -1254,7 +1264,7 @@ class EagleProposer: num_tokens_padded: int, ) -> tuple[int, torch.Tensor]: # TODO(Flechman): support DBO ubatching - ubatch_slices, num_toks_across_dp = coordinate_batch_across_dp( + should_ubatch, num_toks_across_dp, _ = coordinate_batch_across_dp( num_tokens_unpadded=num_tokens_unpadded, parallel_config=self.vllm_config.parallel_config, allow_microbatching=False, @@ -1263,7 +1273,7 @@ class EagleProposer: uniform_decode=None, num_scheduled_tokens_per_request=None, ) - assert ubatch_slices is None, "DBO ubatching not implemented for EAGLE" + assert not should_ubatch, "DBO ubatching not implemented for EAGLE" num_tokens_dp_padded = num_tokens_padded if num_toks_across_dp is not None: diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py index 12b903ccaca97..989478f348161 100644 --- a/vllm/v1/spec_decode/medusa.py +++ b/vllm/v1/spec_decode/medusa.py @@ -38,16 +38,16 @@ class MedusaProposer: self, target_hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> list[list[int]]: + ) -> torch.Tensor: # Generate blocks and compute logits blocks = self.model(target_hidden_states) logits = self.model.compute_logits(blocks) - # Get draft tokens and transpose the result - # TODO(woosuk): OPTIMIZATION: Return GPU tensor without GPU-CPU - # synchronization. - draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits] - return [list(row) for row in zip(*draft_tokens)] + # Compute argmax for each Medusa head and stack into a single tensor + # Shape: [batch_size, num_heads] + draft_tokens = torch.stack([logit.argmax(dim=-1) for logit in logits], dim=1) + + return draft_tokens def load_model(self, target_model: nn.Module) -> None: from vllm.compilation.backends import set_model_tag diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 10b3f0aa040e5..1273ca12c3600 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -196,9 +196,9 @@ def batch_propose_numba( k=k, ) - valid_ngram_num_drafts[i] = drafter_output.shape[0] + valid_ngram_num_drafts[idx] = drafter_output.shape[0] if len(drafter_output): - valid_ngram_draft[i, : drafter_output.shape[0]] = drafter_output + valid_ngram_draft[idx, : drafter_output.shape[0]] = drafter_output @jit(nopython=True) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 029129cf1a475..79ee4161e9dfa 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.reasoning import ReasoningParserManager -from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs +from vllm.tokenizers import cached_tokenizer_from_config from vllm.utils.import_utils import LazyLoader from vllm.v1.structured_output.backend_guidance import GuidanceBackend from vllm.v1.structured_output.backend_types import ( @@ -40,6 +40,16 @@ class StructuredOutputManager: self.reasoner: ReasoningParser | None = None self.vllm_config = vllm_config + # When in external_launcher mode, async grammar compilation causes deadlocks + # due to external_launcher mode having a scheduler for each TP rank. + # Async grammar compilation causes the WAITING_FOR_FSM → WAITING transition to + # happen at different times on different TP ranks, + # breaking the determinism assumption that external_launcher relies on. + self._use_async_grammar_compilation = ( + vllm_config.parallel_config.distributed_executor_backend + != "external_launcher" + ) + self._grammar_bitmask: torch.Tensor | None = None self._full_mask = torch.tensor(-1, dtype=torch.int32) @@ -61,7 +71,7 @@ class StructuredOutputManager: # of CPUs. max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) self.executor = ThreadPoolExecutor(max_workers=max_workers) - self.tokenizer = init_tokenizer_from_configs( + self.tokenizer = cached_tokenizer_from_config( model_config=self.vllm_config.model_config ) reasoning_parser = ( @@ -138,10 +148,13 @@ class StructuredOutputManager: else: raise ValueError(f"Unsupported structured output backend: {backend}") - grammar = self.executor.submit(self._async_create_grammar, request) + if self._use_async_grammar_compilation: + grammar = self.executor.submit(self._create_grammar, request) + else: + grammar = self._create_grammar(request) # type: ignore[assignment] request.structured_output_request.grammar = grammar # type: ignore[assignment] - def _async_create_grammar( + def _create_grammar( self, request: Request, ) -> StructuredOutputGrammar: @@ -326,7 +339,9 @@ class StructuredOutputManager: return True # Check if reasoning ends in *this* step - if self.reasoner.is_reasoning_end(request.all_token_ids): + if self.reasoner.is_reasoning_end_streaming( + request.all_token_ids, request.all_token_ids[request.num_computed_tokens :] + ): # Reasoning just ended, so we shouldn't advance til # next pass structured_req.reasoning_ended = True diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index f8a2df43dd90e..9dd506880389a 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -10,7 +10,8 @@ import torch import vllm.envs from vllm.logger import init_logger from vllm.sampling_params import SamplingParams -from vllm.tokenizers import MistralTokenizer +from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer +from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils.import_utils import LazyLoader from vllm.v1.structured_output.backend_types import ( StructuredOutputBackend, @@ -56,6 +57,27 @@ class XgrammarBackend(StructuredOutputBackend): stop_token_ids=stop_token_ids, add_prefix_space=True, ) + elif isinstance(self.tokenizer, DeepseekV32Tokenizer): + # copy from xgr.TokenizerInfo.from_huggingface() + # because we are using a custom tokenizer wrapper here. + vocab_dict = self.tokenizer.get_vocab() + tokenizer_vocab_size = max(len(vocab_dict), self.tokenizer.max_token_id + 1) + vocab_size = self.vocab_size or tokenizer_vocab_size + # maintain tokenizer's indexing + encoded_vocab = [""] * vocab_size + for token, idx in vocab_dict.items(): + if idx < vocab_size: + encoded_vocab[idx] = token + stop_token_ids = [self.tokenizer.eos_token_id] + backend_str = self.tokenizer.tokenizer.backend_tokenizer.to_str() + metadata = xgr.TokenizerInfo._detect_metadata_from_hf(backend_str) + tokenizer_info = xgr.TokenizerInfo( + encoded_vocab=encoded_vocab, + vocab_type=metadata["vocab_type"], + vocab_size=vocab_size, + stop_token_ids=stop_token_ids, + add_prefix_space=metadata["add_prefix_space"], + ) else: tokenizer_info = xgr.TokenizerInfo.from_huggingface( self.tokenizer, @@ -246,13 +268,7 @@ def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool: # Unsupported keywords for objects if obj.get("type") == "object" and any( - key in obj - for key in ( - "minProperties", - "maxProperties", - "propertyNames", - "patternProperties", - ) + key in obj for key in ("patternProperties", "propertyNames") ): return True diff --git a/vllm/v1/structured_output/utils.py b/vllm/v1/structured_output/utils.py index ae42b33f80f88..cb5ad99cfbdf7 100644 --- a/vllm/v1/structured_output/utils.py +++ b/vllm/v1/structured_output/utils.py @@ -21,8 +21,8 @@ from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput if TYPE_CHECKING: import outlines_core as oc import transformers.file_utils as file_utils - import transformers.models.gpt2.tokenization_gpt2 as tokenization_gpt2 import xgrammar as xgr + from transformers.convert_slow_tokenizer import bytes_to_unicode from vllm.tokenizers import TokenizerLike from vllm.v1.worker.gpu_input_batch import InputBatch @@ -30,10 +30,8 @@ else: xgr = LazyLoader("xgr", globals(), "xgrammar") oc = LazyLoader("oc", globals(), "outlines_core") file_utils = LazyLoader("file_utils", globals(), "transformers.file_utils") - tokenization_gpt2 = LazyLoader( - "tokenization_gpt2", - globals(), - "transformers.models.gpt2.tokenization_gpt2", + bytes_to_unicode = LazyLoader( + "bytes_to_unicode", globals(), "transformers.convert_slow_tokenizer" ) TokenizerLike = object @@ -204,7 +202,7 @@ def _reduced_vocabulary( A Dict of token string -> equivalent token ids """ - unicode_to_bytes = {v: k for k, v in tokenization_gpt2.bytes_to_unicode().items()} + unicode_to_bytes = {v: k for k, v in bytes_to_unicode().items()} def convert_token_to_string(token: str) -> str: string = tokenizer.convert_tokens_to_string([token]) diff --git a/vllm/v1/worker/cp_utils.py b/vllm/v1/worker/cp_utils.py new file mode 100644 index 0000000000000..f666c739b0be7 --- /dev/null +++ b/vllm/v1/worker/cp_utils.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING, Any, cast + +from vllm.config import VllmConfig, get_layers_from_vllm_config + +if TYPE_CHECKING: + from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +else: + AttentionLayerBase = object + + +def check_attention_cp_compatibility(vllm_config: VllmConfig) -> None: + pcp_size = vllm_config.parallel_config.prefill_context_parallel_size + dcp_size = vllm_config.parallel_config.decode_context_parallel_size + interleave_size = vllm_config.parallel_config.cp_kv_cache_interleave_size + if pcp_size * dcp_size > 1: + layer_type = cast(type[Any], AttentionLayerBase) + layers = get_layers_from_vllm_config(vllm_config, layer_type) + for layer in layers.values(): + layer_impl = getattr(layer, "impl", None) + if layer_impl is None: + continue + if vllm_config.speculative_config is not None and interleave_size > 1: + assert layer_impl.supports_mtp_with_cp_non_trivial_interleave_size, ( + "MTP with cp_kv_cache_interleave_size > 1 is not " + f"supported in {layer_impl.__class__.__name__}." + ) + if dcp_size > 1: + assert layer_impl.need_to_return_lse_for_decode, ( + "DCP requires attention impls to return" + " the softmax lse for decode, but the impl " + f"{layer_impl.__class__.__name__} " + "does not return the softmax lse for decode." + ) + + if pcp_size > 1: + assert layer_impl.supports_pcp, ( + "PCP requires attention impls' support, " + f"but the impl {layer_impl.__class__.__name__} " + "does not support PCP." + ) diff --git a/vllm/v1/worker/cpu_worker.py b/vllm/v1/worker/cpu_worker.py index b080fea1d2dd6..e54b995ab908f 100644 --- a/vllm/v1/worker/cpu_worker.py +++ b/vllm/v1/worker/cpu_worker.py @@ -13,6 +13,7 @@ from vllm.logger import init_logger from vllm.model_executor.utils import set_random_seed from vllm.platforms import CpuArchEnum, current_platform from vllm.platforms.cpu import CpuPlatform, LogicalCPUInfo +from vllm.profiler.wrapper import TorchProfilerWrapper from vllm.v1.worker.cpu_model_runner import CPUModelRunner from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment @@ -38,30 +39,17 @@ class CPUWorker(Worker): self.parallel_config.disable_custom_all_reduce = True - # Torch profiler. Enabled and configured through env vars: - # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + # Torch profiler. Enabled and configured through profiler_config. self.profiler: Any | None = None - if envs.VLLM_TORCH_PROFILER_DIR: - torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + profiler_config = vllm_config.profiler_config + if profiler_config.profiler == "torch": worker_name = f"{vllm_config.instance_id}-rank-{self.rank}" - logger.info( - "Profiling enabled. Traces will be saved to: %s", - torch_profiler_trace_dir, + self.profiler = TorchProfilerWrapper( + profiler_config, + worker_name=worker_name, + local_rank=self.local_rank, + activities=["CPU"], ) - self.profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - ], - record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES, - profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY, - with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, - with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS, - on_trace_ready=torch.profiler.tensorboard_trace_handler( - torch_profiler_trace_dir, worker_name=worker_name, use_gzip=False - ), - ) - else: - self.profiler = None def init_device(self): # Setup OpenMP threads affinity. @@ -202,9 +190,3 @@ class CPUWorker(Worker): self.profiler.start() else: self.profiler.stop() - if self.local_rank == 0: - logger.info( - self.profiler.key_averages().table( - sort_by="self_cpu_time_total", row_limit=50 - ) - ) diff --git a/vllm/v1/worker/dp_utils.py b/vllm/v1/worker/dp_utils.py index 6539d72d81cb7..82de0cba9194b 100644 --- a/vllm/v1/worker/dp_utils.py +++ b/vllm/v1/worker/dp_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import numpy as np import torch import torch.distributed as dist @@ -9,11 +10,8 @@ from vllm.config import ParallelConfig from vllm.distributed.parallel_state import get_dp_group from vllm.logger import init_logger from vllm.v1.worker.ubatch_utils import ( - UBatchSlice, - UBatchSlices, check_ubatch_thresholds, - create_ubatch_slices, - is_second_ubatch_empty, + is_last_ubatch_empty, ) logger = init_logger(__name__) @@ -42,21 +40,23 @@ def _run_ar( should_dp_pad: bool, orig_num_tokens_per_ubatch: int, padded_num_tokens_per_ubatch: int, + cudagraph_mode: int, parallel_config: ParallelConfig, ) -> torch.Tensor: dp_size = parallel_config.data_parallel_size dp_rank = parallel_config.data_parallel_rank device, group = _get_device_and_group(parallel_config) - tensor = torch.zeros(4, dp_size, device=device, dtype=torch.int32) + tensor = torch.zeros(5, dp_size, device=device, dtype=torch.int32) tensor[0][dp_rank] = orig_num_tokens_per_ubatch tensor[1][dp_rank] = padded_num_tokens_per_ubatch tensor[2][dp_rank] = 1 if should_ubatch else 0 tensor[3][dp_rank] = 1 if should_dp_pad else 0 + tensor[4][dp_rank] = cudagraph_mode dist.all_reduce(tensor, group=group) return tensor -def _post_process_ubatch(tensor: torch.Tensor) -> bool: +def _post_process_ubatch(tensor: torch.Tensor, num_ubatches: int) -> bool: orig_num_tokens_tensor = tensor[0, :] padded_num_tokens_tensor = tensor[1, :] @@ -68,7 +68,7 @@ def _post_process_ubatch(tensor: torch.Tensor) -> bool: # there are no "empty" second ubatches orig_min_num_tokens = int(orig_num_tokens_tensor.min().item()) padded_max_num_tokens = int(padded_num_tokens_tensor.max().item()) - if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens): + if is_last_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens, num_ubatches): logger.debug( "Aborting ubatching %s %s", orig_min_num_tokens, padded_max_num_tokens ) @@ -91,18 +91,13 @@ def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch return num_tokens_across_dp.cpu() -# This just pads the second ubatch slice out to the total number of tokens -# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding. -def _pad_out_ubatch_slice( - ubatch_slices: UBatchSlices, num_total_tokens: int -) -> UBatchSlices: - padded_second_token_slice = slice( - ubatch_slices[1].token_slice.start, num_total_tokens - ) - ubatch_slices[1] = UBatchSlice( - ubatch_slices[1].request_slice, padded_second_token_slice - ) - return ubatch_slices +def _post_process_cudagraph_mode(tensor: torch.Tensor) -> int: + """ + Synchronize cudagraph_mode across DP ranks by taking the minimum. + If any rank has NONE (0), all ranks use NONE. + This ensures all ranks send consistent values (all padded or all unpadded). + """ + return int(tensor[4, :].min().item()) def _synchronize_dp_ranks( @@ -110,8 +105,9 @@ def _synchronize_dp_ranks( num_tokens_padded: int, should_attempt_ubatching: bool, should_attempt_dp_padding: bool, + cudagraph_mode: int, parallel_config: ParallelConfig, -) -> tuple[bool, torch.Tensor | None]: +) -> tuple[bool, torch.Tensor | None, int]: """ 1. Decides if each DP rank is going to microbatch. Either all ranks run with microbatching or none of them do. @@ -120,10 +116,13 @@ def _synchronize_dp_ranks( When running microbatched or if should_attempt_dp_padding is True, all ranks will be padded out so that the run with the same number of tokens + 3. Synchronizes cudagraph_mode across ranks by taking the minimum. + Returns: tuple[ should_ubatch: Are all DP ranks going to microbatch num_tokens_after_padding: A tensor containing the total number of tokens per-microbatch for each DP rank including any DP padding. + synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks) ] """ @@ -137,6 +136,7 @@ def _synchronize_dp_ranks( should_dp_pad=should_attempt_dp_padding, orig_num_tokens_per_ubatch=num_tokens_unpadded, padded_num_tokens_per_ubatch=num_tokens_padded, + cudagraph_mode=cudagraph_mode, parallel_config=parallel_config, ) @@ -146,7 +146,7 @@ def _synchronize_dp_ranks( assert should_attempt_dp_padding == should_dp_pad # Check conditions for microbatching - should_ubatch = _post_process_ubatch(tensor) + should_ubatch = _post_process_ubatch(tensor, parallel_config.num_ubatches) if should_ubatch and not should_dp_pad: logger.debug_once( @@ -164,7 +164,10 @@ def _synchronize_dp_ranks( should_dp_pad, ) - return should_ubatch, num_tokens_after_padding + # Synchronize cudagraph_mode across ranks (take min) + synced_cudagraph_mode = _post_process_cudagraph_mode(tensor) + + return should_ubatch, num_tokens_after_padding, synced_cudagraph_mode def coordinate_batch_across_dp( @@ -175,7 +178,8 @@ def coordinate_batch_across_dp( num_tokens_padded: int | None = None, uniform_decode: bool | None = None, num_scheduled_tokens_per_request: np.ndarray | None = None, -) -> tuple[UBatchSlices | None, torch.Tensor | None]: + cudagraph_mode: int = 0, +) -> tuple[bool, torch.Tensor | None, int]: """ Coordinates amongst all DP ranks to determine if and how the full batch should be split into microbatches. @@ -191,6 +195,7 @@ def coordinate_batch_across_dp( only contains single token decodes num_scheduled_tokens_per_request: Only used if allow_microbatching is True. The number of tokens per request. + cudagraph_mode: The cudagraph mode for this rank (0=NONE, 1=PIECEWISE, 2=FULL) Returns: tuple[ ubatch_slices: if this is set then all DP ranks have agreed to @@ -199,12 +204,13 @@ def coordinate_batch_across_dp( tokens per-microbatch for each DP rank including padding. Will be padded up to the max value across all DP ranks when allow_dp_padding is True. + synced_cudagraph_mode: The synchronized cudagraph mode (min across ranks) ] """ if parallel_config.data_parallel_size == 1: # Early exit. - return None, None + return False, None, cudagraph_mode # If the caller has explicitly enabled microbatching. should_attempt_ubatching = False @@ -220,31 +226,15 @@ def coordinate_batch_across_dp( if num_tokens_padded is None: num_tokens_padded = num_tokens_unpadded - (should_ubatch, num_tokens_after_padding) = _synchronize_dp_ranks( - num_tokens_unpadded, - num_tokens_padded, - should_attempt_ubatching, - allow_dp_padding, - parallel_config, + (should_ubatch, num_tokens_after_padding, synced_cudagraph_mode) = ( + _synchronize_dp_ranks( + num_tokens_unpadded, + num_tokens_padded, + should_attempt_ubatching, + allow_dp_padding, + cudagraph_mode, + parallel_config, + ) ) - # Don't microbatch unless every other DP worker is also microbatching - if not should_ubatch: - return (None, num_tokens_after_padding) - - # This doesn't actually pad the ubatch slices. It just initializes the - # split point to the padded value so that padding can be applied - # to the second ubatch in pad_out_ubatch_slice after attention - # metadata creation - assert num_tokens_after_padding is not None - num_tokens_padded = int(num_tokens_after_padding[0].item()) - token_split_point = int(num_tokens_padded) // 2 - - assert num_scheduled_tokens_per_request is not None - ubatch_slices = create_ubatch_slices( - num_scheduled_tokens_per_request, token_split_point - ) - ubatch_slices = _pad_out_ubatch_slice(ubatch_slices, num_tokens_padded) - assert sum(s.num_tokens for s in ubatch_slices) == num_tokens_padded - - return (ubatch_slices, num_tokens_after_padding) + return (should_ubatch, num_tokens_after_padding, synced_cudagraph_mode) diff --git a/vllm/v1/worker/gpu/async_utils.py b/vllm/v1/worker/gpu/async_utils.py index f6bc607c1ae67..a2e3decad0486 100644 --- a/vllm/v1/worker/gpu/async_utils.py +++ b/vllm/v1/worker/gpu/async_utils.py @@ -2,14 +2,15 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from contextlib import contextmanager +import numpy as np import torch from vllm.v1.outputs import ( AsyncModelRunnerOutput, LogprobsTensors, ModelRunnerOutput, - SamplerOutput, ) +from vllm.v1.worker.gpu.sample.output import SamplerOutput class AsyncOutput(AsyncModelRunnerOutput): @@ -34,29 +35,18 @@ class AsyncOutput(AsyncModelRunnerOutput): with torch.cuda.stream(self.copy_stream): self.copy_stream.wait_stream(default_stream) - # NOTE(woosuk): We must ensure that CPU tensors are not freed - # before the device-to-host copy is fully completed. For instance, - # operations like - # self.sampled_token_np = ...to("cpu", non_blocking=True).numpy() - # are unsafe because the underlying CPU tensor can be prematurely freed and - # reused by other tensors before the asynchronous copy finishes, potentially - # causing race conditions. To prevent this, we delay freeing by holding - # references until the copy event signals completion. - # Likewise, we also need to keep the reference to the GPU tensors. - # This is done by keeping the reference to sampler_output and - # model_runner_output. - self.sampled_token_ids = sampler_output.sampled_token_ids.to( - "cpu", non_blocking=True - ) + self.sampled_token_ids = async_copy_to_np(sampler_output.sampled_token_ids) if sampler_output.logprobs_tensors is not None: self.logprobs_tensors: LogprobsTensors | None = ( sampler_output.logprobs_tensors.to_cpu_nonblocking() ) else: self.logprobs_tensors = None - self.num_sampled_tokens_cpu = num_sampled_tokens.to( - "cpu", non_blocking=True - ) + if sampler_output.num_nans is not None: + self.num_nans = async_copy_to_np(sampler_output.num_nans) + else: + self.num_nans = None + self.num_sampled_tokens_np = async_copy_to_np(num_sampled_tokens) self.prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {} if self.model_runner_output.prompt_logprobs_dict: for k, v in self.model_runner_output.prompt_logprobs_dict.items(): @@ -68,7 +58,6 @@ class AsyncOutput(AsyncModelRunnerOutput): def get_output(self) -> ModelRunnerOutput: self.copy_event.synchronize() - num_sampled_tokens_np = self.num_sampled_tokens_cpu.numpy() # NOTE(woosuk): The following code is to ensure compatibility with # the existing model runner. @@ -76,10 +65,18 @@ class AsyncOutput(AsyncModelRunnerOutput): # rather than Python lists. sampled_token_ids: list[list[int]] = self.sampled_token_ids.tolist() num_reqs = len(sampled_token_ids) + num_sampled_tokens = self.num_sampled_tokens_np.tolist() for i in range(num_reqs): - del sampled_token_ids[i][num_sampled_tokens_np[i] :] + del sampled_token_ids[i][num_sampled_tokens[i] :] self.model_runner_output.sampled_token_ids = sampled_token_ids + if self.num_nans is not None: + num_nans = self.num_nans.tolist() + self.model_runner_output.num_nans_in_logits = { + req_id: num_nans[i] + for i, req_id in enumerate(self.model_runner_output.req_ids) + } + if self.logprobs_tensors is not None: self.model_runner_output.logprobs = self.logprobs_tensors.tolists() self.model_runner_output.prompt_logprobs_dict = self.prompt_logprobs_dict @@ -95,3 +92,7 @@ def async_barrier(event: torch.cuda.Event | None): finally: if event is not None: event.record() + + +def async_copy_to_np(x: torch.Tensor) -> np.ndarray: + return x.to("cpu", non_blocking=True).numpy() diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py index 5aa1a33d851cc..6386f1a08b446 100644 --- a/vllm/v1/worker/gpu/attn_utils.py +++ b/vllm/v1/worker/gpu/attn_utils.py @@ -168,9 +168,9 @@ def build_attn_metadata( query_start_loc=query_start_loc_gpu, query_start_loc_cpu=query_start_loc_cpu, seq_lens=seq_lens, - seq_lens_cpu=seq_lens_cpu, + _seq_lens_cpu=seq_lens_cpu, max_seq_len=max_seq_len, - num_computed_tokens_cpu=num_computed_tokens_cpu, + _num_computed_tokens_cpu=num_computed_tokens_cpu, num_reqs=num_reqs, num_actual_tokens=num_tokens, max_query_len=max_query_len, diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py index 8ae887fe82cfe..1b78734fba78f 100644 --- a/vllm/v1/worker/gpu/input_batch.py +++ b/vllm/v1/worker/gpu/input_batch.py @@ -354,6 +354,55 @@ def combine_sampled_and_draft_tokens( return logits_indices +@triton.jit +def _get_num_sampled_and_rejected_kernel( + num_sampled_ptr, + num_rejected_ptr, + seq_lens_ptr, + cu_num_logits_ptr, + idx_mapping_ptr, + prefill_len_ptr, +): + batch_idx = tl.program_id(0) + req_state_idx = tl.load(idx_mapping_ptr + batch_idx) + + seq_len = tl.load(seq_lens_ptr + batch_idx) + prefill_len = tl.load(prefill_len_ptr + req_state_idx) + is_chunked_prefilling = seq_len < prefill_len + + num_sampled = tl.load(num_sampled_ptr + batch_idx) + num_sampled = tl.where(is_chunked_prefilling, 0, num_sampled) + tl.store(num_sampled_ptr + batch_idx, num_sampled) + + logits_start = tl.load(cu_num_logits_ptr + batch_idx) + logits_end = tl.load(cu_num_logits_ptr + batch_idx + 1) + num_logits = logits_end - logits_start + + num_rejected = num_logits - num_sampled + num_rejected = tl.where(is_chunked_prefilling, 0, num_rejected) + tl.store(num_rejected_ptr + batch_idx, num_rejected) + + +def get_num_sampled_and_rejected( + num_sampled: torch.Tensor, + seq_lens: torch.Tensor, + cu_num_logits: torch.Tensor, + idx_mapping: torch.Tensor, + prefill_len: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + num_reqs = idx_mapping.shape[0] + num_rejected = torch.empty_like(num_sampled) + _get_num_sampled_and_rejected_kernel[(num_reqs,)]( + num_sampled, + num_rejected, + seq_lens, + cu_num_logits, + idx_mapping, + prefill_len, + ) + return num_sampled, num_rejected + + @triton.jit def _post_update_kernel( idx_mapping_ptr, diff --git a/vllm/v1/worker/gpu/metrics/__init__.py b/vllm/v1/worker/gpu/metrics/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/v1/worker/gpu/metrics/logits.py b/vllm/v1/worker/gpu/metrics/logits.py new file mode 100644 index 0000000000000..fd7b30beaa1f8 --- /dev/null +++ b/vllm/v1/worker/gpu/metrics/logits.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +from torch._inductor.runtime.triton_helpers import libdevice + +from vllm.triton_utils import tl, triton + + +@triton.jit +def _num_nans_kernel( + logits_ptr, + logits_stride, + num_nans_ptr, + vocab_size, + BLOCK_SIZE: tl.constexpr, +): + req_idx = tl.program_id(0) + num_nans = 0 + for i in range(0, vocab_size, BLOCK_SIZE): + block = i + tl.arange(0, BLOCK_SIZE) + mask = block < vocab_size + logits = tl.load( + logits_ptr + req_idx * logits_stride + block, mask=mask, other=0 + ) + logits = logits.to(tl.float32) + is_nan = libdevice.isnan(logits).to(tl.int1) + num_nans += tl.sum(is_nan).to(tl.int32) + tl.store(num_nans_ptr + req_idx, num_nans) + + +def get_num_nans(logits: torch.Tensor) -> torch.Tensor: + num_reqs, vocab_size = logits.shape + BLOCK_SIZE = 8192 + num_nans = torch.empty(num_reqs, dtype=torch.int32, device=logits.device) + _num_nans_kernel[(num_reqs,)]( + logits, + logits.stride(0), + num_nans, + vocab_size, + BLOCK_SIZE=BLOCK_SIZE, + ) + return num_nans diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 9bf345053c30c..9f4c6edfb6aa9 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -25,7 +25,6 @@ from vllm.v1.outputs import ( LogprobsTensors, ModelRunnerOutput, ) -from vllm.v1.sample.sampler import SamplerOutput from vllm.v1.worker.gpu.async_utils import AsyncOutput, async_barrier from vllm.v1.worker.gpu.attn_utils import ( build_attn_metadata, @@ -43,6 +42,7 @@ from vllm.v1.worker.gpu.input_batch import ( InputBatch, InputBuffers, combine_sampled_and_draft_tokens, + get_num_sampled_and_rejected, post_update, prepare_pos_seq_lens, prepare_prefill_inputs, @@ -52,12 +52,10 @@ from vllm.v1.worker.gpu.sample.metadata import ( SamplingMetadata, expand_sampling_metadata, ) +from vllm.v1.worker.gpu.sample.output import SamplerOutput from vllm.v1.worker.gpu.sample.sampler import Sampler from vllm.v1.worker.gpu.spec_decode import init_speculator -from vllm.v1.worker.gpu.spec_decode.rejection_sample import ( - get_num_rejected, - rejection_sample, -) +from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample from vllm.v1.worker.gpu.states import RequestState from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin @@ -621,16 +619,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Sample tokens and compute logprobs (if needed). sampler_output = self.sampler(logits, sampling_metadata) - # Get the number of sampled tokens. - prefill_len = self.req_states.prefill_len.gpu[input_batch.idx_mapping] - is_chunked_prefilling = input_batch.seq_lens < prefill_len if input_batch.num_draft_tokens == 0: # No draft tokens (common case). - # 0 if chunked-prefilling, 1 if not. - num_sampled = (~is_chunked_prefilling).int() - num_rejected = torch.zeros_like(num_sampled) + num_sampled = torch.ones( + input_batch.num_reqs, dtype=torch.int32, device=self.device + ) else: - # Draft tokens for spec decoding. + # Rejection sampling for spec decoding. input_ids = input_batch.input_ids[input_batch.logits_indices] sampled_tokens, num_sampled = rejection_sample( sampler_output.sampled_token_ids, @@ -638,13 +633,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): input_batch.cu_num_logits, self.num_speculative_steps, ) - num_sampled *= ~is_chunked_prefilling - num_rejected = get_num_rejected( - input_batch.cu_num_logits, - num_sampled, - ) sampler_output.sampled_token_ids = sampled_tokens - # TODO(woosuk): Support logprobs with spec decoding. + + # Get the number of sampled and rejected tokens. + # For chunked prefills, num_sampled and num_rejected are both 0. + num_sampled, num_rejected = get_num_sampled_and_rejected( + num_sampled, + input_batch.seq_lens, + input_batch.cu_num_logits, + input_batch.idx_mapping, + self.req_states.prefill_len.gpu, + ) return sampler_output, num_sampled, num_rejected def compute_prompt_logprobs( diff --git a/vllm/v1/worker/gpu/sample/metadata.py b/vllm/v1/worker/gpu/sample/metadata.py index 040771c051bb4..f10c72049cbae 100644 --- a/vllm/v1/worker/gpu/sample/metadata.py +++ b/vllm/v1/worker/gpu/sample/metadata.py @@ -13,6 +13,7 @@ class SamplingMetadata: top_p: torch.Tensor | None top_k: torch.Tensor | None + min_p: torch.Tensor | None repetition_penalty: torch.Tensor frequency_penalty: torch.Tensor @@ -44,6 +45,7 @@ class SamplingMetadata: # top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device) top_p = None top_k = None + min_p = torch.zeros(num_reqs, dtype=torch.float32, device=device) # NOTE(woosuk): We must set penalties to their default values to make sure # the penalties kernel does not touch the placeholder bin_counts tensors. repetition_penalty = torch.ones(num_reqs, dtype=torch.float32, device=device) @@ -64,6 +66,7 @@ class SamplingMetadata: temperature=temperature, top_p=top_p, top_k=top_k, + min_p=min_p, repetition_penalty=repetition_penalty, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, @@ -85,6 +88,8 @@ def _expand_sampling_metadata_kernel( expanded_top_p_ptr, top_k_ptr, expanded_top_k_ptr, + min_p_ptr, + expanded_min_p_ptr, rep_penalty_ptr, expanded_rep_penalty_ptr, freq_penalty_ptr, @@ -115,6 +120,10 @@ def _expand_sampling_metadata_kernel( top_k = tl.load(top_k_ptr + req_idx) tl.store(expanded_top_k_ptr + start_idx + block, top_k, mask=mask) + if min_p_ptr is not None: + min_p = tl.load(min_p_ptr + req_idx) + tl.store(expanded_min_p_ptr + start_idx + block, min_p, mask=mask) + rep_penalty = tl.load(rep_penalty_ptr + req_idx) tl.store(expanded_rep_penalty_ptr + start_idx + block, rep_penalty, mask=mask) @@ -138,6 +147,7 @@ def expand_sampling_metadata( expanded_temp = create_empty(sampling_metadata.temperature) expanded_top_p = create_empty(sampling_metadata.top_p) expanded_top_k = create_empty(sampling_metadata.top_k) + expanded_min_p = create_empty(sampling_metadata.min_p) expanded_repetition_penalty = create_empty(sampling_metadata.repetition_penalty) expanded_frequency_penalty = create_empty(sampling_metadata.frequency_penalty) expanded_presence_penalty = create_empty(sampling_metadata.presence_penalty) @@ -151,6 +161,8 @@ def expand_sampling_metadata( expanded_top_p, sampling_metadata.top_k, expanded_top_k, + sampling_metadata.min_p, + expanded_min_p, sampling_metadata.repetition_penalty, expanded_repetition_penalty, sampling_metadata.frequency_penalty, @@ -166,6 +178,7 @@ def expand_sampling_metadata( temperature=expanded_temp, top_p=expanded_top_p, top_k=expanded_top_k, + min_p=expanded_min_p, seeds=expanded_seeds, repetition_penalty=expanded_repetition_penalty, frequency_penalty=expanded_frequency_penalty, diff --git a/vllm/v1/worker/gpu/sample/min_p.py b/vllm/v1/worker/gpu/sample/min_p.py new file mode 100644 index 0000000000000..c98a42cb2b1bb --- /dev/null +++ b/vllm/v1/worker/gpu/sample/min_p.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm.triton_utils import tl, triton + + +@triton.jit +def _min_p_kernel( + logits_ptr, + logits_stride, + min_p_ptr, + vocab_size, + BLOCK_SIZE: tl.constexpr, +): + req_idx = tl.program_id(0) + min_p = tl.load(min_p_ptr + req_idx).to(tl.float32) + if min_p == 0.0: + return + + max_val = float("-inf") + for i in range(0, vocab_size, BLOCK_SIZE): + block = i + tl.arange(0, BLOCK_SIZE) + mask = block < vocab_size + logits = tl.load( + logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf") + ) + max_val = tl.max(tl.maximum(logits, max_val)) + max_val = max_val.to(tl.float32) # type: ignore + + threshold = max_val + tl.log(min_p) + for i in range(0, vocab_size, BLOCK_SIZE): + block = i + tl.arange(0, BLOCK_SIZE) + mask = block < vocab_size + logits = tl.load( + logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf") + ) + logits = tl.where(logits < threshold, float("-inf"), logits) + tl.store(logits_ptr + req_idx * logits_stride + block, logits, mask=mask) + + +def apply_min_p(logits: torch.Tensor, min_p: torch.Tensor) -> None: + num_reqs, vocab_size = logits.shape + BLOCK_SIZE = 1024 + _min_p_kernel[(num_reqs,)]( + logits, + logits.stride(0), + min_p, + vocab_size, + BLOCK_SIZE=BLOCK_SIZE, + ) diff --git a/vllm/v1/worker/gpu/sample/output.py b/vllm/v1/worker/gpu/sample/output.py new file mode 100644 index 0000000000000..13e8cf1d6c1ec --- /dev/null +++ b/vllm/v1/worker/gpu/sample/output.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass + +import torch + +from vllm.v1.outputs import LogprobsTensors + + +@dataclass +class SamplerOutput: + sampled_token_ids: torch.Tensor + logprobs_tensors: LogprobsTensors | None + num_nans: torch.Tensor | None diff --git a/vllm/v1/worker/gpu/sample/penalties.py b/vllm/v1/worker/gpu/sample/penalties.py index c8d4b7d81841d..b4fcc822ecfce 100644 --- a/vllm/v1/worker/gpu/sample/penalties.py +++ b/vllm/v1/worker/gpu/sample/penalties.py @@ -62,6 +62,7 @@ def _penalties_and_temperature_kernel( mask=packed_block < tl.cdiv(vocab_size, 32), ) prompt_bin_mask = (packed_mask[:, None] >> (tl.arange(0, 32)[None, :])) & 1 + prompt_bin_mask = prompt_bin_mask.to(tl.int1) prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE) # If token appears in prompt or output, apply, otherwise use 1.0 for no-op. diff --git a/vllm/v1/worker/gpu/sample/sampler.py b/vllm/v1/worker/gpu/sample/sampler.py index 3429dd3e4d0fb..84a3e18671b2c 100644 --- a/vllm/v1/worker/gpu/sample/sampler.py +++ b/vllm/v1/worker/gpu/sample/sampler.py @@ -3,12 +3,15 @@ import torch +import vllm.envs as envs from vllm.config.model import LogprobsMode -from vllm.v1.outputs import SamplerOutput from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p +from vllm.v1.worker.gpu.metrics.logits import get_num_nans from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata +from vllm.v1.worker.gpu.sample.min_p import apply_min_p +from vllm.v1.worker.gpu.sample.output import SamplerOutput from vllm.v1.worker.gpu.sample.penalties import apply_penalties_and_temperature @@ -20,12 +23,16 @@ class Sampler: if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]: raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}") self.logprobs_mode = logprobs_mode + self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS # False by default. def __call__( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> SamplerOutput: + # NOTE(woosuk): We intentionally compute num_nans before sampling to make clear + # that num_nans is computed before applying penalties and temperature. + num_nans = get_num_nans(logits) if self.compute_nans else None sampled, processed_logits = self.sample(logits, sampling_metadata) if sampling_metadata.max_num_logprobs is not None: logits = ( @@ -48,6 +55,7 @@ class Sampler: # token per request. sampled_token_ids=sampled.view(-1, 1), logprobs_tensors=logprobs_tensors, + num_nans=num_nans, ) return sampler_output @@ -61,6 +69,10 @@ class Sampler: # Apply penalties and temperature in place. apply_penalties_and_temperature(logits, sampling_metadata) + # Apply min_p in place. + if sampling_metadata.min_p is not None: + apply_min_p(logits, sampling_metadata.min_p) + # Apply top_k and/or top_p. This might return a new tensor. logits = apply_top_k_top_p( logits, sampling_metadata.top_k, sampling_metadata.top_p ) diff --git a/vllm/v1/worker/gpu/spec_decode/rejection_sample.py b/vllm/v1/worker/gpu/spec_decode/rejection_sample.py index 43c6ac518bccc..8a7bf28bacbd4 100644 --- a/vllm/v1/worker/gpu/spec_decode/rejection_sample.py +++ b/vllm/v1/worker/gpu/spec_decode/rejection_sample.py @@ -69,15 +69,3 @@ def rejection_sample( num_warps=1, ) return sampled, num_sampled - - -@torch.compile(dynamic=True) -def get_num_rejected( - cu_num_logits: torch.Tensor, - num_sampled: torch.Tensor, -) -> torch.Tensor: - num_logits = cu_num_logits[1:] - cu_num_logits[:-1] - num_rejected = num_logits - num_sampled - # No token is rejected for chunked prefills. - num_rejected *= num_sampled > 0 - return num_rejected diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py index 367348c4a18f7..6823c0c8ee5c7 100644 --- a/vllm/v1/worker/gpu/states.py +++ b/vllm/v1/worker/gpu/states.py @@ -87,6 +87,7 @@ class RequestState: self.temperature = self._make_param(self.max_num_reqs, torch.float32) self.top_p = self._make_param(self.max_num_reqs, torch.float32) self.top_k = self._make_param(self.max_num_reqs, torch.int32) + self.min_p = self._make_param(self.max_num_reqs, torch.float32) self.repetition_penalty = self._make_param(self.max_num_reqs, torch.float32) self.frequency_penalty = self._make_param(self.max_num_reqs, torch.float32) self.presence_penalty = self._make_param(self.max_num_reqs, torch.float32) @@ -162,6 +163,7 @@ class RequestState: else: top_k = self.vocab_size self.top_k.np[req_idx] = top_k + self.min_p.np[req_idx] = sampling_params.min_p self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty self.presence_penalty.np[req_idx] = sampling_params.presence_penalty @@ -217,6 +219,10 @@ class RequestState: no_top_k = np.all(top_k == self.vocab_size) top_k = self.top_k.copy_np_to_gpu(top_k) if not no_top_k else None + min_p = self.min_p.np[idx_mapping_np] + no_min_p = np.all(min_p == 0.0) + min_p = self.min_p.copy_np_to_gpu(min_p) if not no_min_p else None + rep_penalty = self.repetition_penalty.np[idx_mapping_np] rep_penalty = self.repetition_penalty.copy_np_to_gpu(rep_penalty) freq_penalty = self.frequency_penalty.np[idx_mapping_np] @@ -236,6 +242,7 @@ class RequestState: temperature=temperature, top_p=top_p, top_k=top_k, + min_p=min_p, repetition_penalty=rep_penalty, frequency_penalty=freq_penalty, presence_penalty=pres_penalty, diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index e7991baeaa1b8..ead7a3619dea5 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -15,7 +15,7 @@ from vllm.sampling_params import SamplingParams, SamplingType from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.utils.collection_utils import swap_dict_values from vllm.v1.outputs import LogprobsTensors -from vllm.v1.pool.metadata import PoolingMetadata +from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates from vllm.v1.sample.logits_processor import ( BatchUpdateBuilder, LogitsProcessors, @@ -33,7 +33,6 @@ class CachedRequestState: prompt_token_ids: list[int] | None mm_features: list[MultiModalFeatureSpec] sampling_params: SamplingParams | None - pooling_params: PoolingParams | None generator: torch.Generator | None block_ids: tuple[list[int], ...] @@ -51,11 +50,18 @@ class CachedRequestState: # Used when both async_scheduling and spec_decode are enabled. prev_num_draft_len: int = 0 + # for pooling models + pooling_params: PoolingParams | None = None + pooling_states: PoolingStates | None = None + def __post_init__(self): self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( self.prompt_token_ids, self.prompt_embeds ) + if self.pooling_params is not None: + self.pooling_states = PoolingStates() + @property def num_tokens(self) -> int: return self.num_prompt_tokens + len(self.output_token_ids) @@ -255,7 +261,9 @@ class InputBatch: # This is updated each time the batch constituents change. self.sampling_metadata = self._make_sampling_metadata() + # for pooling models self.pooling_params: dict[str, PoolingParams] = {} + self.pooling_states: dict[str, PoolingStates] = {} # Cached reference to the GPU tensor of previously sampled tokens self.prev_sampled_token_ids: torch.Tensor | None = None @@ -413,7 +421,11 @@ class InputBatch: sampling_params.bad_words_token_ids ) elif pooling_params := request.pooling_params: + pooling_states = request.pooling_states + assert pooling_states is not None + self.pooling_params[req_id] = pooling_params + self.pooling_states[req_id] = pooling_states self.logits_processing_needs_token_ids[req_index] = ( pooling_params.requires_token_ids ) @@ -469,6 +481,7 @@ class InputBatch: if self.is_pooling_model: self.pooling_params.pop(req_id, None) + self.pooling_states.pop(req_id, None) return req_index self.greedy_reqs.discard(req_id) @@ -482,6 +495,8 @@ class InputBatch: self.generators.pop(req_index, None) self.num_logprobs.pop(req_id, None) self.in_progress_prompt_logprobs_cpu.pop(req_id, None) + if self.prev_req_id_to_index is not None: + self.prev_req_id_to_index.pop(req_id, None) self.has_allowed_token_ids.discard(req_id) if self.allowed_token_ids_mask_cpu_tensor is not None: @@ -835,13 +850,19 @@ class InputBatch: assert len(self.req_ids) == len(self.pooling_params) return [self.pooling_params[req_id] for req_id in self.req_ids] + def get_pooling_states(self) -> list[PoolingStates]: + assert len(self.req_ids) == len(self.pooling_states) + return [self.pooling_states[req_id] for req_id in self.req_ids] + def get_pooling_metadata(self) -> PoolingMetadata: pooling_params = self.get_pooling_params() + pooling_states = self.get_pooling_states() return PoolingMetadata( prompt_lens=torch.from_numpy(self.num_prompt_tokens[: self.num_reqs]), prompt_token_ids=self.sampling_metadata.prompt_token_ids, pooling_params=pooling_params, + pooling_states=pooling_states, ) def _make_prompt_token_ids_tensor(self) -> torch.Tensor: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2218e4f023f92..1aa2ec6bb655c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools import gc import itertools import time @@ -27,7 +28,7 @@ from vllm.attention.backends.abstract import ( ) from vllm.attention.layer import Attention, MLAAttention from vllm.compilation.counter import compilation_counter -from vllm.compilation.cuda_graph import CUDAGraphWrapper +from vllm.compilation.cuda_graph import CUDAGraphStat, CUDAGraphWrapper from vllm.compilation.monitor import set_cudagraph_capturing_enabled from vllm.config import ( CompilationMode, @@ -48,7 +49,10 @@ from vllm.distributed.parallel_state import ( is_global_first_rank, prepare_communication_buffer_for_model, ) -from vllm.forward_context import BatchDescriptor, set_forward_context +from vllm.forward_context import ( + BatchDescriptor, + set_forward_context, +) from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.rotary_embedding import ( @@ -88,6 +92,7 @@ from vllm.utils.jsontree import json_map_leaves from vllm.utils.math_utils import cdiv, round_up from vllm.utils.mem_constants import GiB_bytes from vllm.utils.mem_utils import DeviceMemoryProfiler +from vllm.utils.nvtx_pytorch_hooks import PytHooks from vllm.utils.platform_utils import is_pin_memory_available from vllm.utils.torch_utils import ( get_dtype_size, @@ -131,7 +136,7 @@ from vllm.v1.outputs import ( SamplerOutput, make_empty_encoder_model_runner_output, ) -from vllm.v1.pool.metadata import PoolingMetadata +from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs from vllm.v1.sample.logits_processor.interface import LogitsProcessor from vllm.v1.sample.metadata import SamplingMetadata @@ -144,6 +149,7 @@ from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext +from vllm.v1.worker.cp_utils import check_attention_cp_compatibility from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -153,8 +159,10 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.ubatch_utils import ( UBatchSlices, check_ubatch_thresholds, + maybe_create_ubatch_slices, ) from vllm.v1.worker.utils import is_residual_scattered_for_sp +from vllm.v1.worker.workspace import lock_workspace from .utils import ( AttentionGroup, @@ -257,6 +265,7 @@ class ExecuteModelState(NamedTuple): sample_hidden_states: torch.Tensor aux_hidden_states: list[torch.Tensor] | None ec_connector_output: ECConnectorOutput | None + cudagraph_stats: CUDAGraphStat | None class GPUModelRunner( @@ -289,6 +298,7 @@ class GPUModelRunner( self.device = device self.pin_memory = is_pin_memory_available() self.dtype = self.model_config.dtype + self.kv_cache_dtype = kv_cache_dtype_str_to_dtype( cache_config.cache_dtype, self.model_config ) @@ -326,6 +336,7 @@ class GPUModelRunner( self.use_alibi = model_config.uses_alibi self.cascade_attn_enabled = not self.model_config.disable_cascade_attn + self.is_mm_prefix_lm = self.model_config.is_mm_prefix_lm # Multi-modal data support self.mm_registry = MULTIMODAL_REGISTRY @@ -597,6 +608,7 @@ class GPUModelRunner( # Ephemeral state transferred between execute_model() and sample_tokens(). self.execute_model_state: ExecuteModelState | None = None self.kv_connector_output: KVConnectorOutput | None = None + self.layerwise_nvtx_hooks_registered = False def reset_mm_cache(self) -> None: if self.mm_budget: @@ -774,7 +786,14 @@ class GPUModelRunner( # they will be scheduled again sometime in the future. scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() cached_req_ids = self.input_batch.req_id_to_index.keys() - unscheduled_req_ids = cached_req_ids - scheduled_req_ids + resumed_req_ids = scheduler_output.scheduled_cached_reqs.resumed_req_ids + # NOTE(zhuohan): cached_req_ids and resumed_req_ids are usually disjoint, + # so `(scheduled_req_ids - resumed_req_ids) == scheduled_req_ids` holds + # apart from the forced-preemption case in reset_prefix_cache. And in + # that case we include the resumed_req_ids in the unscheduled set so + # that they get cleared from the persistent batch before being re-scheduled + # in the normal resumed request path. + unscheduled_req_ids = cached_req_ids - (scheduled_req_ids - resumed_req_ids) # NOTE(woosuk): The persistent batch optimization assumes that # consecutive batches contain mostly the same requests. If batches # have low request overlap (e.g., alternating between two distinct @@ -1086,7 +1105,6 @@ class GPUModelRunner( device=self.device, pin_memory=self.pin_memory, merge_by_field_config=model.merge_by_field_config, - multimodal_cpu_fields=model.multimodal_cpu_fields, ): mm_kwargs_combined.update(mm_kwargs_group) @@ -1253,6 +1271,8 @@ class GPUModelRunner( if not isinstance(kv_cache_spec, CrossAttentionSpec): return None, None + # Zero out buffer for padding requests that are not actually scheduled (CGs) + self.encoder_seq_lens.np[:num_reqs] = 0 # Build encoder_seq_lens array mapping request indices to # encoder lengths for inputs scheduled in this batch for req_id in num_scheduled_tokens: @@ -1516,28 +1536,13 @@ class GPUModelRunner( """ :return: tuple[attn_metadata, spec_decode_common_attn_metadata] """ + # Attention metadata is not needed for attention free models + if len(self.kv_cache_config.kv_cache_groups) == 0: + return {}, None + num_tokens_padded = num_tokens_padded or num_tokens num_reqs_padded = num_reqs_padded or num_reqs - - logits_indices_padded = None - num_logits_indices = None - if logits_indices is not None: - num_logits_indices = logits_indices.size(0) - if self.cache_config.kv_sharing_fast_prefill: - logits_indices_padded = self._prepare_kv_sharing_fast_prefill( - logits_indices - ) - - # update seq_lens of decode reqs under DCP. - if self.dcp_world_size > 1: - self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens( - self.seq_lens.cpu[:num_reqs], - self.dcp_world_size, - self.dcp_rank, - self.parallel_config.cp_kv_cache_interleave_size, - ) - self.dcp_local_seq_lens.cpu[num_reqs:].fill_(0) - self.dcp_local_seq_lens.copy_to_gpu(num_reqs_padded) + assert num_reqs_padded is not None and num_tokens_padded is not None attn_metadata: PerLayerAttnMetadata = {} if ubatch_slices is not None: @@ -1558,36 +1563,12 @@ class GPUModelRunner( self.num_accepted_tokens.np[num_reqs:].fill(1) self.num_accepted_tokens.copy_to_gpu() - # Used in the below loop, uses padded shapes - query_start_loc = self.query_start_loc.gpu[: num_reqs_padded + 1] - query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs_padded + 1] - seq_lens = self.seq_lens.gpu[:num_reqs_padded] - seq_lens_cpu = self.seq_lens.cpu[:num_reqs_padded] - num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[ - :num_reqs_padded - ] + kv_cache_groups = self.kv_cache_config.kv_cache_groups - dcp_local_seq_lens, dcp_local_seq_lens_cpu = None, None - if self.dcp_world_size > 1: - dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs_padded] - dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[:num_reqs_padded] - - spec_decode_common_attn_metadata = None - - # Prepare the attention metadata for each KV cache group and make layers - # in the same group share the same metadata. - for kv_cache_gid, kv_cache_group in enumerate( - self.kv_cache_config.kv_cache_groups - ): - encoder_seq_lens, encoder_seq_lens_cpu = self._get_encoder_seq_lens( - num_scheduled_tokens or {}, - kv_cache_group.kv_cache_spec, - num_reqs_padded, - ) - - if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec): - # Encoder-only layers do not have KV cache, so we need to - # create a dummy block table and slot mapping for them. + def _get_block_table_and_slot_mapping(kv_cache_gid: int): + assert num_reqs_padded is not None and num_tokens_padded is not None + kv_cache_spec = kv_cache_groups[kv_cache_gid].kv_cache_spec + if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec): blk_table_tensor = torch.zeros( (num_reqs_padded, 1), dtype=torch.int32, @@ -1603,92 +1584,149 @@ class GPUModelRunner( blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded) slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded] - # Fill unused with -1. Needed for reshape_and_cache in full cuda - # graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID - slot_mapping[num_tokens:num_tokens_padded].fill_(-1) - blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1) + # Fill unused with -1. Needed for reshape_and_cache in full cuda + # graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID + slot_mapping[num_tokens:num_tokens_padded].fill_(-1) + blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1) - common_attn_metadata = CommonAttentionMetadata( - query_start_loc=query_start_loc, - query_start_loc_cpu=query_start_loc_cpu, - seq_lens=seq_lens, - seq_lens_cpu=seq_lens_cpu, - num_computed_tokens_cpu=num_computed_tokens_cpu, - num_actual_tokens=num_tokens_padded, - num_reqs=num_reqs_padded, - max_query_len=max_query_len, - max_seq_len=max_seq_len, - block_table_tensor=blk_table_tensor, - slot_mapping=slot_mapping, - logits_indices_padded=logits_indices_padded, - num_logits_indices=num_logits_indices, - causal=True, - encoder_seq_lens=encoder_seq_lens, - encoder_seq_lens_cpu=encoder_seq_lens_cpu, - dcp_local_seq_lens=dcp_local_seq_lens, - dcp_local_seq_lens_cpu=dcp_local_seq_lens_cpu, + return blk_table_tensor, slot_mapping + + block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0) + cm_base = CommonAttentionMetadata( + query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1], + query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1], + seq_lens=self.seq_lens.gpu[:num_reqs_padded], + _seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded], + _num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[ + :num_reqs_padded + ], + num_reqs=num_reqs_padded, + num_actual_tokens=num_tokens_padded, + max_query_len=max_query_len, + max_seq_len=max_seq_len, + block_table_tensor=block_table_gid_0, + slot_mapping=slot_mapping_gid_0, + causal=True, + ) + + if self.dcp_world_size > 1: + self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens( + self.seq_lens.cpu[:num_reqs], + self.dcp_world_size, + self.dcp_rank, + self.parallel_config.cp_kv_cache_interleave_size, ) + self.dcp_local_seq_lens.cpu[num_reqs:].fill_(0) + self.dcp_local_seq_lens.copy_to_gpu(num_reqs_padded) + + cm_base.dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs_padded] + cm_base.dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[ + :num_reqs_padded + ] + + if logits_indices is not None and self.cache_config.kv_sharing_fast_prefill: + cm_base.num_logits_indices = logits_indices.size(0) + cm_base.logits_indices_padded = self._prepare_kv_sharing_fast_prefill( + logits_indices + ) + + def _build_attn_group_metadata( + kv_cache_gid: int, + attn_gid: int, + common_attn_metadata: CommonAttentionMetadata, + ubid: int | None = None, + ) -> None: + attn_group = self.attn_groups[kv_cache_gid][attn_gid] + cascade_attn_prefix_len = ( + cascade_attn_prefix_lens[kv_cache_gid][attn_gid] + if cascade_attn_prefix_lens + else 0 + ) + + builder = attn_group.get_metadata_builder(ubid or 0) + extra_attn_metadata_args = {} + if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder): + assert ubid is None, "UBatching not supported with GDN yet" + extra_attn_metadata_args = dict( + num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs_padded], + num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[ + :num_reqs_padded + ], + ) + + if for_cudagraph_capture: + attn_metadata_i = builder.build_for_cudagraph_capture( + common_attn_metadata + ) + else: + attn_metadata_i = builder.build( + common_prefix_len=cascade_attn_prefix_len, + common_attn_metadata=common_attn_metadata, + **extra_attn_metadata_args, + ) + + if ubid is None: + assert isinstance(attn_metadata, dict) + attn_metadata_dict = attn_metadata + else: + assert isinstance(attn_metadata, list) + attn_metadata_dict = attn_metadata[ubid] + + for layer_name in attn_group.layer_names: + attn_metadata_dict[layer_name] = attn_metadata_i + + # Prepare the attention metadata for each KV cache group and make layers + # in the same group share the same metadata. + spec_decode_common_attn_metadata = None + for kv_cache_gid, kv_cache_group in enumerate(kv_cache_groups): + cm = copy(cm_base) # shallow copy + + # Basically only the encoder seq_lens, block_table and slot_mapping change + # for each kv_cache_group. + cm.encoder_seq_lens, cm.encoder_seq_lens_cpu = self._get_encoder_seq_lens( + num_scheduled_tokens or {}, + kv_cache_group.kv_cache_spec, + num_reqs_padded, + ) + if kv_cache_gid > 0: + cm.block_table_tensor, cm.slot_mapping = ( + _get_block_table_and_slot_mapping(kv_cache_gid) + ) if self.speculative_config and spec_decode_common_attn_metadata is None: if isinstance(self.drafter, EagleProposer): if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names: - spec_decode_common_attn_metadata = common_attn_metadata + spec_decode_common_attn_metadata = cm else: - spec_decode_common_attn_metadata = common_attn_metadata - - for attn_gid, attn_group in enumerate(self.attn_groups[kv_cache_gid]): - cascade_attn_prefix_len = ( - cascade_attn_prefix_lens[kv_cache_gid][attn_gid] - if cascade_attn_prefix_lens - else 0 - ) - builder = attn_group.get_metadata_builder() - - extra_attn_metadata_args = {} - if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder): - extra_attn_metadata_args = dict( - num_accepted_tokens=self.num_accepted_tokens.gpu[ - :num_reqs_padded - ], - num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[ - :num_reqs_padded - ], - ) + spec_decode_common_attn_metadata = cm + for attn_gid in range(len(self.attn_groups[kv_cache_gid])): if ubatch_slices is not None: - common_attn_metadata_list = split_attn_metadata( - ubatch_slices, common_attn_metadata - ) - for ubid, common_attn_metadata in enumerate( - common_attn_metadata_list - ): - builder = attn_group.get_metadata_builder(ubatch_id=ubid) - if for_cudagraph_capture: - attn_metadata_i = builder.build_for_cudagraph_capture( - common_attn_metadata - ) - else: - attn_metadata_i = builder.build( - common_prefix_len=cascade_attn_prefix_len, - common_attn_metadata=common_attn_metadata, - ) - for layer_name in kv_cache_group.layer_names: - assert type(attn_metadata) is list - attn_metadata[ubid][layer_name] = attn_metadata_i + for ubid, _cm in enumerate(split_attn_metadata(ubatch_slices, cm)): + _build_attn_group_metadata(kv_cache_gid, attn_gid, _cm, ubid) + else: - assert isinstance(attn_metadata, dict) - if for_cudagraph_capture: - attn_metadata_i = builder.build_for_cudagraph_capture( - common_attn_metadata - ) - else: - attn_metadata_i = builder.build( - common_prefix_len=cascade_attn_prefix_len, - common_attn_metadata=common_attn_metadata, - **extra_attn_metadata_args, - ) - for layer_name in attn_group.layer_names: - attn_metadata[layer_name] = attn_metadata_i + _build_attn_group_metadata(kv_cache_gid, attn_gid, cm) + + if self.is_mm_prefix_lm: + req_doc_ranges = {} + for req_id in self.input_batch.req_ids: + image_doc_ranges = [] + req_state = self.requests[req_id] + for mm_feature in req_state.mm_features: + pos_info = mm_feature.mm_position + img_doc_range = pos_info.extract_embeds_range() + image_doc_ranges.extend(img_doc_range) + req_idx = self.input_batch.req_id_to_index[req_id] + req_doc_ranges[req_idx] = image_doc_ranges + + if isinstance(attn_metadata, list): + for ub_metadata in attn_metadata: + for _metadata in ub_metadata.values(): + _metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined] + else: + for _metadata in attn_metadata.values(): + _metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined] if spec_decode_common_attn_metadata is not None and ( num_reqs != num_reqs_padded or num_tokens != num_tokens_padded @@ -2098,8 +2136,6 @@ class GPUModelRunner( mm_kwargs, device=self.device, pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, - multimodal_cpu_fields=model.multimodal_cpu_fields, ): curr_group_outputs: list[torch.Tensor] = [] @@ -2125,8 +2161,6 @@ class GPUModelRunner( [video_mm_kwargs_item], device=self.device, pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, - multimodal_cpu_fields=model.multimodal_cpu_fields, ) ) @@ -2283,20 +2317,6 @@ class GPUModelRunner( supported_tasks = list(model.pooler.get_supported_tasks()) - if self.scheduler_config.enable_chunked_prefill: - if "token_embed" in supported_tasks: - supported_tasks.remove("token_embed") - if "token_classify" in supported_tasks: - supported_tasks.remove("token_classify") - - logger.debug_once( - "Chunked prefill is not supported with " - "token_embed and token_classify tasks " - "which using ALL pooling. " - "Please turn off chunked prefill by " - "`--no-enable-chunked-prefill` before using it." - ) - if "score" in supported_tasks: num_labels = getattr(self.model_config.hf_config, "num_labels", 0) if num_labels != 1: @@ -2373,11 +2393,12 @@ class GPUModelRunner( ) hidden_states = hidden_states[:num_scheduled_tokens] + seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs] + pooling_metadata = self.input_batch.get_pooling_metadata() pooling_metadata.build_pooling_cursor( - num_scheduled_tokens_np.tolist(), device=hidden_states.device + num_scheduled_tokens_np.tolist(), seq_lens_cpu, device=hidden_states.device ) - seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs] model = cast(VllmModelForPooling, self.model) raw_pooler_output: PoolerOutput = model.pooler( @@ -2385,7 +2406,7 @@ class GPUModelRunner( pooling_metadata=pooling_metadata, ) raw_pooler_output = json_map_leaves( - lambda x: x.to("cpu", non_blocking=True), + lambda x: x.to("cpu", non_blocking=True) if x is not None else x, raw_pooler_output, ) self._sync_device() @@ -2410,10 +2431,7 @@ class GPUModelRunner( # Pad tokens to multiple of tensor_parallel_size when # enabled collective fusion for SP tp_size = self.vllm_config.parallel_config.tensor_parallel_size - if ( - self.compilation_config.pass_config.enable_sequence_parallelism - and tp_size > 1 - ): + if self.compilation_config.pass_config.enable_sp and tp_size > 1: return round_up(num_scheduled_tokens, tp_size) return num_scheduled_tokens @@ -2432,16 +2450,13 @@ class GPUModelRunner( ]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens is_first_rank = get_pp_group().is_first_rank + is_encoder_decoder = self.model_config.is_encoder_decoder # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order ec_connector_output = None - if ( - self.supports_mm_inputs - and is_first_rank - and not self.model_config.is_encoder_decoder - ): + if self.supports_mm_inputs and is_first_rank and not is_encoder_decoder: # Run the multimodal encoder if any. with self.maybe_get_ec_connector_output( scheduler_output, @@ -2519,10 +2534,7 @@ class GPUModelRunner( num_input_tokens, intermediate_tensors, True ) - if ( - self.model_config.is_encoder_decoder - and scheduler_output.scheduled_encoder_inputs - ): + if is_encoder_decoder and scheduler_output.scheduled_encoder_inputs: # Run the encoder, just like we do with other multimodal inputs. # For an encoder-decoder model, our processing here is a bit # simpler, because the outputs are just passed to the decoder. @@ -2756,8 +2768,13 @@ class GPUModelRunner( # be improved in model runner v2) force_uniform_decode: bool | None = None, force_has_lora: bool | None = None, + num_encoder_reqs: int = 0, ) -> tuple[ - CUDAGraphMode, BatchDescriptor, UBatchSlices | None, torch.Tensor | None + CUDAGraphMode, + BatchDescriptor, + bool, + torch.Tensor | None, + CUDAGraphStat | None, ]: num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens) uniform_decode = ( @@ -2768,6 +2785,11 @@ class GPUModelRunner( if force_uniform_decode is None else force_uniform_decode ) + # Encoder-decoder models only support CG for decoder_step > 0 (no enc_output + # is present). Also, chunked-prefill is disabled, so batch are uniform. + has_encoder_output = ( + self.model_config.is_encoder_decoder and num_encoder_reqs > 0 + ) has_lora = ( len(self.input_batch.lora_id_to_lora_request) > 0 @@ -2776,22 +2798,24 @@ class GPUModelRunner( ) dispatch_cudagraph = ( - lambda num_tokens: self.cudagraph_dispatcher.dispatch( + lambda num_tokens, disable_full: self.cudagraph_dispatcher.dispatch( num_tokens=num_tokens, has_lora=has_lora, - use_cascade_attn=use_cascade_attn, uniform_decode=uniform_decode, + disable_full=disable_full, ) if not force_eager else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded)) ) - cudagraph_mode, batch_descriptor = dispatch_cudagraph(num_tokens_padded) + cudagraph_mode, batch_descriptor = dispatch_cudagraph( + num_tokens_padded, use_cascade_attn or has_encoder_output + ) num_tokens_padded = batch_descriptor.num_tokens # Extra coordination when running data-parallel since we need to coordinate # across ranks - ubatch_slices, num_tokens_across_dp = None, None + should_ubatch, num_tokens_across_dp = False, None if self.vllm_config.parallel_config.data_parallel_size > 1: # Disable DP padding when running eager to avoid excessive padding when # running prefills. This lets us set cudagraph_mode="NONE" on the prefiller @@ -2801,28 +2825,84 @@ class GPUModelRunner( self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE ) - ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp( - num_tokens_unpadded=num_tokens_padded, - parallel_config=self.parallel_config, - allow_microbatching=allow_microbatching, - allow_dp_padding=allow_dp_padding, - num_tokens_padded=num_tokens_padded, - uniform_decode=uniform_decode, - num_scheduled_tokens_per_request=num_scheduled_tokens_np, + should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = ( + coordinate_batch_across_dp( + num_tokens_unpadded=num_tokens, + parallel_config=self.parallel_config, + allow_microbatching=allow_microbatching, + allow_dp_padding=allow_dp_padding, + num_tokens_padded=num_tokens_padded, + uniform_decode=uniform_decode, + num_scheduled_tokens_per_request=num_scheduled_tokens_np, + cudagraph_mode=cudagraph_mode.value, + ) ) - # Extract DP padding if there is any + # Extract DP-synced values if num_tokens_across_dp is not None: dp_rank = self.parallel_config.data_parallel_rank num_tokens_padded = int(num_tokens_across_dp[dp_rank].item()) - - # Re-dispatch with DP padding - cudagraph_mode, batch_descriptor = dispatch_cudagraph(num_tokens_padded) + # Re-dispatch with DP padding so we have the correct batch_descriptor + cudagraph_mode, batch_descriptor = dispatch_cudagraph( + num_tokens_padded, + disable_full=synced_cudagraph_mode <= CUDAGraphMode.PIECEWISE.value, + ) # Assert to make sure the agreed upon token count is correct otherwise # num_tokens_across_dp will no-longer be valid assert batch_descriptor.num_tokens == num_tokens_padded - return cudagraph_mode, batch_descriptor, ubatch_slices, num_tokens_across_dp + cudagraph_stats = None + if self.vllm_config.observability_config.cudagraph_metrics: + cudagraph_stats = CUDAGraphStat( + num_unpadded_tokens=num_tokens, + num_padded_tokens=batch_descriptor.num_tokens, + num_paddings=batch_descriptor.num_tokens - num_tokens, + runtime_mode=str(cudagraph_mode), + ) + + return ( + cudagraph_mode, + batch_descriptor, + should_ubatch, + num_tokens_across_dp, + cudagraph_stats, + ) + + def _register_layerwise_nvtx_hooks(self) -> None: + """ + Register layerwise NVTX hooks if --enable-layerwise-nvtx-tracing is enabled + to trace detailed information of each layer or module in the model. + """ + + if ( + self.vllm_config.observability_config.enable_layerwise_nvtx_tracing + and not self.layerwise_nvtx_hooks_registered + ): + if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: + logger.debug_once( + "layerwise NVTX tracing is not supported when CUDA graph is " + "turned off; you may observe part or all of the model " + "missing NVTX markers" + ) + + # In STOCK_TORCH_COMPILE mode, after registering hooks here, + # the __call__ function of nn.module will be recompiled with + # fullgraph=True. Since nvtx.range_push/pop are not traceable + # by torch dynamo, we can't register hook functions here + # because hook functions will also be traced by torch dynamo. + if ( + self.vllm_config.compilation_config.mode + == CompilationMode.STOCK_TORCH_COMPILE + ): + logger.debug_once( + "layerwise NVTX tracing is not supported when " + "CompilationMode is STOCK_TORCH_COMPILE, skipping " + "function hooks registration" + ) + else: + pyt_hooks = PytHooks() + pyt_hooks.register_hooks(self.model, self.model.__class__.__name__) + self.layerwise_nvtx_hooks_registered = True @torch.inference_mode() def execute_model( @@ -2907,7 +2987,7 @@ class GPUModelRunner( cascade_attn_prefix_lens = None # Disable cascade attention when using microbatching (DBO) - if self.cascade_attn_enabled and not self.parallel_config.enable_dbo: + if self.cascade_attn_enabled and not self.parallel_config.use_ubatching: # Pre-compute cascade attention prefix lengths cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens( num_scheduled_tokens_np, @@ -2918,22 +2998,24 @@ class GPUModelRunner( ( cudagraph_mode, batch_desc, - ubatch_slices, + should_ubatch, num_tokens_across_dp, + cudagraph_stats, ) = self._determine_batch_execution_and_padding( num_tokens=num_tokens_unpadded, num_reqs=num_reqs, num_scheduled_tokens_np=num_scheduled_tokens_np, max_num_scheduled_tokens=max_num_scheduled_tokens, use_cascade_attn=cascade_attn_prefix_lens is not None, + num_encoder_reqs=len(scheduler_output.scheduled_encoder_inputs), ) logger.debug( "Running batch with cudagraph_mode: %s, batch_descriptor: %s, " - "ubatch_slices: %s, num_tokens_across_dp: %s", + "should_ubatch: %s, num_tokens_across_dp: %s", cudagraph_mode, batch_desc, - ubatch_slices, + should_ubatch, num_tokens_across_dp, ) @@ -2941,9 +3023,24 @@ class GPUModelRunner( num_reqs_padded = ( batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs ) + ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( + should_ubatch, + num_scheduled_tokens_np, + num_tokens_padded, + num_reqs_padded, + self.parallel_config.num_ubatches, + ) + + logger.debug( + "ubatch_slices: %s, ubatch_slices_padded: %s", + ubatch_slices, + ubatch_slices_padded, + ) + + pad_attn = cudagraph_mode == CUDAGraphMode.FULL use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 - pad_attn = cudagraph_mode == CUDAGraphMode.FULL + ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices (attn_metadata, spec_decode_common_attn_metadata) = ( self._build_attention_metadata( @@ -2952,7 +3049,7 @@ class GPUModelRunner( num_reqs=num_reqs, num_reqs_padded=num_reqs_padded if pad_attn else None, max_query_len=max_num_scheduled_tokens, - ubatch_slices=ubatch_slices, + ubatch_slices=ubatch_slices_attn, logits_indices=logits_indices, use_spec_decode=use_spec_decode, num_scheduled_tokens=scheduler_output.num_scheduled_tokens, @@ -2989,7 +3086,7 @@ class GPUModelRunner( num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_mode, batch_descriptor=batch_desc, - ubatch_slices=ubatch_slices, + ubatch_slices=ubatch_slices_padded, ), record_function_or_nullcontext("gpu_model_runner: forward"), self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, @@ -3069,6 +3166,7 @@ class GPUModelRunner( sample_hidden_states, aux_hidden_states, ec_connector_output, + cudagraph_stats, ) self.kv_connector_output = kv_connector_output return None @@ -3104,6 +3202,7 @@ class GPUModelRunner( sample_hidden_states, aux_hidden_states, ec_connector_output, + cudagraph_stats, ) = self.execute_model_state # Clear ephemeral state. self.execute_model_state = None @@ -3219,6 +3318,7 @@ class GPUModelRunner( if self.supports_mm_inputs else None, num_nans_in_logits=num_nans_in_logits, + cudagraph_stats=cudagraph_stats, ) if not self.use_async_scheduling: @@ -3480,74 +3580,89 @@ class GPUModelRunner( if self.parallel_config.enable_eplb: self.eplb_state = EplbState(self.parallel_config, self.device) eplb_models = 0 - with DeviceMemoryProfiler() as m: - time_before_load = time.perf_counter() - model_loader = get_model_loader(self.load_config) - self.model = model_loader.load_model( - vllm_config=self.vllm_config, model_config=self.model_config - ) - if self.lora_config: - self.model = self.load_lora_model( - self.model, self.vllm_config, self.device + + try: + with DeviceMemoryProfiler() as m: + time_before_load = time.perf_counter() + model_loader = get_model_loader(self.load_config) + self.model = model_loader.load_model( + vllm_config=self.vllm_config, model_config=self.model_config ) - if hasattr(self, "drafter"): - logger.info_once("Loading drafter model...") - self.drafter.load_model(self.model) - if ( - hasattr(self.drafter, "model") - and is_mixture_of_experts(self.drafter.model) - and self.parallel_config.enable_eplb - ): - spec_config = self.vllm_config.speculative_config - assert spec_config is not None - assert spec_config.draft_model_config is not None - logger.info_once( - "EPLB is enabled for drafter model %s.", - spec_config.draft_model_config.model, + if self.lora_config: + self.model = self.load_lora_model( + self.model, self.vllm_config, self.device ) + if hasattr(self, "drafter"): + logger.info_once("Loading drafter model...") + self.drafter.load_model(self.model) + if ( + hasattr(self.drafter, "model") + and is_mixture_of_experts(self.drafter.model) + and self.parallel_config.enable_eplb + ): + spec_config = self.vllm_config.speculative_config + assert spec_config is not None + assert spec_config.draft_model_config is not None + logger.info_once( + "EPLB is enabled for drafter model %s.", + spec_config.draft_model_config.model, + ) - global_expert_load = ( - global_expert_loads[eplb_models] - if global_expert_loads - else None - ) - old_global_expert_indices = ( - old_global_expert_indices_per_model[eplb_models] - if old_global_expert_indices_per_model - else None - ) - if self.eplb_state is None: - self.eplb_state = EplbState(self.parallel_config, self.device) - self.eplb_state.add_model( - self.drafter.model, - spec_config.draft_model_config, - global_expert_load, - old_global_expert_indices, - rank_mapping, - ) - eplb_models += 1 + global_expert_load = ( + global_expert_loads[eplb_models] + if global_expert_loads + else None + ) + old_global_expert_indices = ( + old_global_expert_indices_per_model[eplb_models] + if old_global_expert_indices_per_model + else None + ) + if self.eplb_state is None: + self.eplb_state = EplbState( + self.parallel_config, self.device + ) + self.eplb_state.add_model( + self.drafter.model, + spec_config.draft_model_config, + global_expert_load, + old_global_expert_indices, + rank_mapping, + ) + eplb_models += 1 - if self.use_aux_hidden_state_outputs: - if not supports_eagle3(self.get_model()): - raise RuntimeError( - "Model does not support EAGLE3 interface but " - "aux_hidden_state_outputs was requested" - ) + if self.use_aux_hidden_state_outputs: + if not supports_eagle3(self.get_model()): + raise RuntimeError( + "Model does not support EAGLE3 interface but " + "aux_hidden_state_outputs was requested" + ) - # Try to get auxiliary layers from speculative config, - # otherwise use model's default layers - aux_layers = self._get_eagle3_aux_layers_from_config() - if aux_layers: - logger.info( - "Using auxiliary layers from speculative config: %s", - aux_layers, - ) - else: - aux_layers = self.model.get_eagle3_aux_hidden_state_layers() + # Try to get auxiliary layers from speculative config, + # otherwise use model's default layers + aux_layers = self._get_eagle3_aux_layers_from_config() + if aux_layers: + logger.info( + "Using auxiliary layers from speculative config: %s", + aux_layers, + ) + else: + aux_layers = self.model.get_eagle3_aux_hidden_state_layers() - self.model.set_aux_hidden_state_layers(aux_layers) - time_after_load = time.perf_counter() - self.model_memory_usage = m.consumed_memory + self.model.set_aux_hidden_state_layers(aux_layers) + time_after_load = time.perf_counter() + self.model_memory_usage = m.consumed_memory + except torch.cuda.OutOfMemoryError as e: + msg = ( + "Failed to load model - not enough GPU memory. " + "Try lowering --gpu-memory-utilization to free memory for weights, " + "increasing --tensor-parallel-size, or using --quantization. " + "See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ " + "for more tips." + ) + combined_msg = f"{msg} (original error: {e})" + logger.error(combined_msg) + raise e logger.info_once( "Model loading took %.4f GiB memory and %.6f seconds", self.model_memory_usage / GiB_bytes, @@ -3602,11 +3717,14 @@ class GPUModelRunner( # wrap the model with full cudagraph wrapper if needed. cudagraph_mode = self.compilation_config.cudagraph_mode assert cudagraph_mode is not None - if cudagraph_mode.has_full_cudagraphs() and not self.parallel_config.enable_dbo: + if ( + cudagraph_mode.has_full_cudagraphs() + and not self.parallel_config.use_ubatching + ): self.model = CUDAGraphWrapper( self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL ) - elif self.parallel_config.enable_dbo: + elif self.parallel_config.use_ubatching: if cudagraph_mode.has_full_cudagraphs(): self.model = UBatchWrapper( self.model, self.vllm_config, CUDAGraphMode.FULL, self.device @@ -3785,19 +3903,21 @@ class GPUModelRunner( return {} @contextmanager - def maybe_randomize_inputs(self, input_ids: torch.Tensor): + def maybe_randomize_inputs( + self, input_ids: torch.Tensor | None, inputs_embeds: torch.Tensor | None + ): """ Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set. This is to help balance expert-selection - during profile_run - during DP rank dummy run """ + dp_size = self.vllm_config.parallel_config.data_parallel_size randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1 if not randomize_inputs: yield - else: - import functools + elif input_ids is not None: @functools.cache def rand_input_ids() -> torch.Tensor: @@ -3805,13 +3925,27 @@ class GPUModelRunner( self.input_ids.gpu, low=0, high=self.model_config.get_vocab_size(), - dtype=input_ids.dtype, ) - logger.debug_once("Randomizing dummy data for DP Rank") + logger.debug_once("Randomizing dummy input_ids for DP Rank") input_ids.copy_(rand_input_ids()[: input_ids.size(0)], non_blocking=True) yield input_ids.fill_(0) + else: + + @functools.cache + def rand_inputs_embeds() -> torch.Tensor: + return torch.randn_like( + self.inputs_embeds.gpu, + ) + + assert inputs_embeds is not None + logger.debug_once("Randomizing dummy inputs_embeds for DP Rank") + inputs_embeds.copy_( + rand_inputs_embeds()[: inputs_embeds.size(0)], non_blocking=True + ) + yield + inputs_embeds.fill_(0) def _get_mm_dummy_batch( self, @@ -3833,15 +3967,12 @@ class GPUModelRunner( dummy_mm_item = dummy_mm_data[modality][0] dummy_mm_items = [dummy_mm_item] * max_items_per_batch - model = cast(SupportsMultiModal, self.model) return next( mm_kwargs_group for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( dummy_mm_items, device=self.device, pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, - multimodal_cpu_fields=model.multimodal_cpu_fields, ) ) @@ -3939,7 +4070,7 @@ class GPUModelRunner( num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) - _cudagraph_mode, batch_desc, ubatch_slices, num_tokens_across_dp = ( + _cudagraph_mode, batch_desc, should_ubatch, num_tokens_across_dp, _ = ( self._determine_batch_execution_and_padding( num_tokens=num_tokens_unpadded, num_reqs=num_reqs, @@ -3973,6 +4104,18 @@ class GPUModelRunner( num_reqs_padded = ( batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs ) + ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( + should_ubatch, + num_scheduled_tokens, + num_tokens_padded, + num_reqs_padded, + self.vllm_config.parallel_config.num_ubatches, + ) + logger.debug( + "ubatch_slices: %s, ubatch_slices_padded: %s", + ubatch_slices, + ubatch_slices_padded, + ) attn_metadata: PerLayerAttnMetadata | None = None @@ -3994,12 +4137,13 @@ class GPUModelRunner( self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens self.query_start_loc.copy_to_gpu() + pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL attn_metadata, _ = self._build_attention_metadata( num_tokens=num_tokens_unpadded, num_reqs=num_reqs_padded, max_query_len=max_query_len, - ubatch_slices=ubatch_slices, - for_cudagraph_capture=True, + ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices, + for_cudagraph_capture=is_graph_capturing, ) with self.maybe_dummy_run_with_lora( @@ -4050,16 +4194,16 @@ class GPUModelRunner( num_tokens_padded, None, False ) - if ubatch_slices is not None: + if ubatch_slices_padded is not None: # Adjust values to reflect a single ubatch. # TODO(sage,lucas): this is cruft that should be addressed in # the padding refactor. - num_tokens_padded = ubatch_slices[0].num_tokens + num_tokens_padded = ubatch_slices_padded[0].num_tokens if num_tokens_across_dp is not None: num_tokens_across_dp[:] = num_tokens_padded with ( - self.maybe_randomize_inputs(input_ids), + self.maybe_randomize_inputs(input_ids, inputs_embeds), set_forward_context( attn_metadata, self.vllm_config, @@ -4067,7 +4211,7 @@ class GPUModelRunner( num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_desc, - ubatch_slices=ubatch_slices, + ubatch_slices=ubatch_slices_padded, ), ): outputs = self.model( @@ -4085,10 +4229,19 @@ class GPUModelRunner( if self.speculative_config and self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) + # Eagle currently only supports PIECEWISE cudagraphs. + # Therefore only use cudagraphs if the main model uses PIECEWISE + # NOTE(lucas): this is a hack, need to clean up. use_cudagraphs = ( - cudagraph_runtime_mode.has_mode(CUDAGraphMode.PIECEWISE) - and not self.speculative_config.enforce_eager - ) + ( + is_graph_capturing + and cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE + ) + or ( + not is_graph_capturing + and cudagraph_runtime_mode != CUDAGraphMode.NONE + ) + ) and not self.speculative_config.enforce_eager # Note(gnovack) - We need to disable cudagraphs for one of the two # lora cases when cudagraph_specialize_lora is enabled. This is a @@ -4103,6 +4256,17 @@ class GPUModelRunner( is_graph_capturing=is_graph_capturing, ) + # We register layerwise NVTX hooks here after the first dynamo tracing is + # done to avoid nvtx operations in hook functions being traced by + # torch dynamo and causing graph breaks. + # Note that for DYNAMO_ONCE and VLLM_COMPILE mode, + # compiled model's dynamo tracing is only done once and the compiled model's + # __call__ function is replaced by calling the compiled function. + # So it's safe to register hooks here. Hooks will be registered to + # both compiled and uncompiled models but they will never + # be called on the compiled model execution path. + self._register_layerwise_nvtx_hooks() + # This is necessary to avoid blocking DP. # For dummy runs, we typically skip EPLB since we don't have any real # requests to process. @@ -4226,10 +4390,13 @@ class GPUModelRunner( prompt_lens=dummy_prompt_lens, prompt_token_ids=dummy_token_ids, pooling_params=[dummy_pooling_params] * num_reqs, + pooling_states=[PoolingStates() for i in range(num_reqs)], ) dummy_metadata.build_pooling_cursor( - num_scheduled_tokens_list, device=hidden_states.device + num_scheduled_tokens_list, + seq_lens_cpu=dummy_prompt_lens, + device=hidden_states.device, ) try: @@ -4256,22 +4423,12 @@ class GPUModelRunner( supported_pooling_tasks = self.get_supported_pooling_tasks() if not supported_pooling_tasks: - if self.scheduler_config.enable_chunked_prefill: - raise RuntimeError( - f"Model {self.model_config.model} does not support " - "any pooling tasks with chunked prefill enabled. " - "Please add --no-enable-chunked-prefill to your " - "config or CLI args. See " - "https://docs.vllm.ai/en/latest/models/pooling_models.html " - "to learn more." - ) - else: - raise RuntimeError( - f"Model {self.model_config.model} does not support " - "any pooling tasks. See " - "https://docs.vllm.ai/en/latest/models/pooling_models.html " - "to learn more." - ) + raise RuntimeError( + f"Model {self.model_config.model} does not support " + "any pooling tasks. See " + "https://docs.vllm.ai/en/latest/models/pooling_models.html " + "to learn more." + ) output_size = dict[PoolingTask, float]() for task in supported_pooling_tasks: @@ -4461,6 +4618,10 @@ class GPUModelRunner( # after here. set_cudagraph_capturing_enabled(False) + # Lock workspace to prevent resizing during execution. + # Max workspace sizes should have been captured during warmup/profiling. + lock_workspace() + end_time = time.perf_counter() elapsed_time = end_time - start_time cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory @@ -4502,7 +4663,7 @@ class GPUModelRunner( # is above the threshold. Otherwise we just capture a non-ubatched # version of the graph allow_microbatching = ( - self.parallel_config.enable_dbo + self.parallel_config.use_ubatching and cudagraph_runtime_mode == CUDAGraphMode.FULL and uniform_decode and check_ubatch_thresholds( @@ -4616,6 +4777,9 @@ class GPUModelRunner( attention_backend_list, kv_cache_config.kv_cache_groups ) + # Check if attention backend supports PCP&DCP and related features. + check_attention_cp_compatibility(self.vllm_config) + for i, attn_backend_map in enumerate(attention_backend_maps): self.attn_groups.append(create_attn_groups(attn_backend_map, i)) @@ -4634,8 +4798,8 @@ class GPUModelRunner( if kv_cache_group_id < len(kernel_block_sizes) else None, num_metadata_builders=1 - if not self.parallel_config.enable_dbo - else 2, + if not self.parallel_config.use_ubatching + else self.parallel_config.num_ubatches, ) # Calculate reorder batch threshold (if needed) # Note (tdoublep): do this *after* constructing builders, @@ -4775,7 +4939,7 @@ class GPUModelRunner( # we need to adjust the cudagraph sizes to be a multiple of the uniform # decode query length to avoid: https://github.com/vllm-project/vllm/issues/28207 # temp-fix: https://github.com/vllm-project/vllm/issues/28207#issuecomment-3504004536 - # Will be removed in the near future when we have seperate cudagraph capture + # Will be removed in the near future when we have separate cudagraph capture # sizes for decode and mixed prefill-decode. if ( cudagraph_mode.decode_mode() == CUDAGraphMode.FULL @@ -5274,20 +5438,6 @@ class GPUModelRunner( kv_transfer_group.register_kv_caches(kv_caches) kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks) - if self.dcp_world_size > 1: - layer_type = cast(type[Any], AttentionLayerBase) - layers = get_layers_from_vllm_config(self.vllm_config, layer_type) - for layer in layers.values(): - layer_impl = getattr(layer, "impl", None) - if layer_impl is None: - continue - assert layer_impl.need_to_return_lse_for_decode, ( - "DCP requires attention impls to return" - " the softmax lse for decode, but the impl " - f"{layer_impl.__class__.__name__} " - "does not return the softmax lse for decode." - ) - def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: """ Add encoder-only layers to the KV cache config. diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py index 2ce2b64512560..af09129e67b1e 100644 --- a/vllm/v1/worker/gpu_ubatch_wrapper.py +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -103,8 +103,10 @@ class UBatchWrapper: self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config self.comm_stream = torch.cuda.Stream(device=device) - # Two ubatch threads plus the main thread - self.ready_barrier = threading.Barrier(3) + # Ubatch threads plus the main thread + self.ready_barrier = threading.Barrier( + self.vllm_config.parallel_config.num_ubatches + 1 + ) self.cudagraphs: dict[int, CUDAGraphMetaData] = {} @@ -309,7 +311,7 @@ class UBatchWrapper: create_forward_context( attn_metadata[i] if attn_metadata is not None else None, self.vllm_config, - dp_metadata=dp_metadata, + dp_metadata=dp_metadata[i], batch_descriptor=batch_descriptor, cudagraph_runtime_mode=cudagraph_runtime_mode, ) @@ -417,18 +419,19 @@ class UBatchWrapper: # We shouldn't be here unless we are running with multiple DP ranks assert dp_metadata is not None - num_tokens_per_ubatch = ( - ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start - ) - dp_size = self.vllm_config.parallel_config.data_parallel_size - ubatch_num_tokens_across_dp = torch.tensor( - [num_tokens_per_ubatch] * dp_size, device="cpu", dtype=torch.int32 - ) - ubatch_dp_metadata = DPMetadata.make( - self.vllm_config.parallel_config, - num_tokens_per_ubatch, - ubatch_num_tokens_across_dp, - ) + ubatch_dp_metadata = [] + for ubatch_slice in ubatch_slices: + dp_size = self.vllm_config.parallel_config.data_parallel_size + ubatch_num_tokens_across_dp = torch.tensor( + [ubatch_slice.num_tokens] * dp_size, device="cpu", dtype=torch.int32 + ) + ubatch_dp_metadata.append( + DPMetadata.make( + self.vllm_config.parallel_config, + ubatch_slice.num_tokens, + ubatch_num_tokens_across_dp, + ) + ) if ( num_tokens not in self.cudagraphs @@ -464,7 +467,7 @@ class UBatchWrapper: intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, compute_stream=compute_stream, - dp_metadata=dp_metadata, + dp_metadata=ubatch_dp_metadata, batch_descriptor=batch_descriptor, cudagraph_runtime_mode=CUDAGraphMode.NONE, ) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index ed6fb32bcb2f6..1e13650cd083e 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -15,6 +15,7 @@ import torch.nn as nn import vllm.envs as envs from vllm.config import CUDAGraphMode, VllmConfig +from vllm.config.compilation import CompilationMode from vllm.distributed import ( ensure_model_parallel_initialized, init_distributed_environment, @@ -37,7 +38,7 @@ from vllm.model_executor import set_random_seed from vllm.model_executor.models.interfaces import is_mixture_of_experts from vllm.model_executor.warmup.kernel_warmup import kernel_warmup from vllm.platforms import current_platform -from vllm.profiler.gpu_profiler import CudaProfilerWrapper, TorchProfilerWrapper +from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper from vllm.sequence import IntermediateTensors from vllm.tasks import SupportedTask from vllm.utils.mem_constants import GiB_bytes @@ -51,14 +52,15 @@ from vllm.v1.outputs import ( ModelRunnerOutput, ) from vllm.v1.utils import report_usage_stats -from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.utils import is_residual_scattered_for_sp from vllm.v1.worker.worker_base import WorkerBase +from vllm.v1.worker.workspace import init_workspace_manager logger = init_logger(__name__) if TYPE_CHECKING: from vllm.model_executor.model_loader.tensorizer import TensorizerConfig + from vllm.v1.worker.gpu_model_runner import GPUModelRunner class Worker(WorkerBase): @@ -78,6 +80,10 @@ class Worker(WorkerBase): is_driver_worker=is_driver_worker, ) + # configure float32 matmul precision according to vLLM env. + precision = envs.VLLM_FLOAT32_MATMUL_PRECISION + torch.backends.cuda.matmul.fp32_precision = precision + if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils.import_utils import init_cached_hf_modules @@ -87,17 +93,19 @@ class Worker(WorkerBase): # Buffers saved before sleep self._sleep_saved_buffers: dict[str, torch.Tensor] = {} - # Torch/CUDA profiler. Enabled and configured through env vars: - # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace - # VLLM_TORCH_CUDA_PROFILE=1 + # Torch/CUDA profiler. Enabled and configured through profiler_config. self.profiler: Any | None = None - if envs.VLLM_TORCH_PROFILER_DIR: + profiler_config = vllm_config.profiler_config + if profiler_config.profiler == "torch": worker_name = f"{vllm_config.instance_id}-rank-{self.rank}" self.profiler = TorchProfilerWrapper( - worker_name=worker_name, local_rank=self.local_rank + profiler_config, + worker_name=worker_name, + local_rank=self.local_rank, + activities=["CPU", "CUDA"], ) - elif envs.VLLM_TORCH_CUDA_PROFILE: - self.profiler = CudaProfilerWrapper() + elif profiler_config.profiler == "cuda": + self.profiler = CudaProfilerWrapper(profiler_config) else: self.profiler = None @@ -248,6 +256,10 @@ class Worker(WorkerBase): else: raise RuntimeError(f"Not support device type: {self.device_config.device}") + # Initialize workspace manager + num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1 + init_workspace_manager(self.device, num_ubatches) + # Construct the model runner if self.use_v2_model_runner: from vllm.v1.worker.gpu.model_runner import ( @@ -259,7 +271,11 @@ class Worker(WorkerBase): self.vllm_config, self.device ) else: - self.model_runner = GPUModelRunner(self.vllm_config, self.device) + from vllm.v1.worker.gpu_model_runner import ( + GPUModelRunner as GPUModelRunnerV1, + ) + + self.model_runner = GPUModelRunnerV1(self.vllm_config, self.device) if self.rank == 0: # If usage stat is enabled, collect relevant info. @@ -403,15 +419,31 @@ class Worker(WorkerBase): self.model_runner.initialize_kv_cache(kv_cache_config) def compile_or_warm_up_model(self) -> None: - # warm up sizes that are not in cudagraph capture sizes, - # but users still want to compile for better performance, - # e.g. for the max-num-batched token size in chunked prefill. - compile_sizes = self.vllm_config.compilation_config.compile_sizes - warmup_sizes = compile_sizes.copy() if compile_sizes is not None else [] - if not self.model_config.enforce_eager: - capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes - if capture_sizes is not None: - warmup_sizes = [x for x in warmup_sizes if x not in capture_sizes] + warmup_sizes = [] + + if self.vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE: + # warm up sizes that are not in cudagraph capture sizes, + # but users still want to compile for better performance, + # e.g. for the max-num-batched token size in chunked prefill. + compile_sizes = self.vllm_config.compilation_config.compile_sizes + warmup_sizes = compile_sizes.copy() if compile_sizes is not None else [] + cg_capture_sizes: list[int] = [] + + if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: + cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes + cg_capture_sizes = [] if cg_sizes is None else cg_sizes + warmup_sizes = [x for x in warmup_sizes if x not in cg_capture_sizes] + + compile_ranges = self.vllm_config.compilation_config.get_compile_ranges() + # For each compile_range, if none of the batch sizes + # in warmup_sizes or cudagraph_capture_sizes are in the range, + # add the end of the range to ensure compilation/warmup. + all_sizes = set(cg_capture_sizes) + all_sizes.update([x for x in warmup_sizes if isinstance(x, int)]) + for compile_range in compile_ranges: + if not any(x in compile_range for x in all_sizes): + warmup_sizes.append(compile_range.end) + # We skip EPLB here since we don't want to record dummy metrics for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) @@ -552,11 +584,11 @@ class Worker(WorkerBase): if ( parallel_config.pipeline_parallel_size > 1 - and compilation_config.pass_config.enable_sequence_parallelism + and compilation_config.pass_config.enable_sp and forward_pass ): # currently only supported by V1 GPUModelRunner - assert isinstance(self.model_runner, GPUModelRunner) + assert not self.use_v2_model_runner num_scheduled_tokens_np = np.array( list(scheduler_output.num_scheduled_tokens.values()), dtype=np.int32, @@ -564,7 +596,7 @@ class Worker(WorkerBase): # TODO(lucas): This is pretty gross; ideally we should only ever call # `_determine_batch_execution_and_padding` once (will get called again # in `execute_model`) but this requires a larger refactor of PP. - _, batch_desc, _, _ = ( + _, batch_desc, _, _, _ = ( self.model_runner._determine_batch_execution_and_padding( num_tokens=num_scheduled_tokens, num_reqs=len(num_scheduled_tokens_np), @@ -899,10 +931,11 @@ def init_worker_distributed_environment( backend: str = "nccl", ) -> None: """Initialize the distributed environment.""" + attention_config = vllm_config.attention_config parallel_config = vllm_config.parallel_config from vllm.model_executor.layers.batch_invariant import init_batch_invariance - init_batch_invariance() + init_batch_invariance(attention_config.backend) set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) init_method = distributed_init_method or "env://" diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index b799f1be73d9c..2bcc87b63bcdf 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -22,7 +22,6 @@ from vllm.distributed.kv_transfer import ( has_kv_transfer_group, ) from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase -from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.forward_context import get_forward_context, set_forward_context from vllm.logger import init_logger from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig @@ -138,16 +137,10 @@ class KVConnectorModelRunnerMixin: ) output.invalid_block_ids = kv_connector.get_block_ids_with_load_errors() - output.kv_connector_stats = ( - KVConnectorModelRunnerMixin.get_kv_connector_stats() - ) - kv_connector.clear_connector_metadata() + output.kv_connector_stats = kv_connector.get_kv_connector_stats() + output.kv_cache_events = kv_connector.get_kv_connector_kv_cache_events() - @staticmethod - def get_kv_connector_stats() -> KVConnectorStats | None: - if has_kv_transfer_group(): - return get_kv_transfer_group().get_kv_connector_stats() - return None + kv_connector.clear_connector_metadata() @staticmethod def use_uniform_kv_cache( diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index f3dd9aa96d2ae..283f21b779e38 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -969,8 +969,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): mm_kwargs, device=self.device, pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, - multimodal_cpu_fields=model.multimodal_cpu_fields, ): # Run the encoder. # `curr_group_outputs` is either of the following: @@ -2051,15 +2049,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): dummy_mm_item = dummy_mm_data[modality][0] dummy_mm_items = [dummy_mm_item] * max_items_per_batch - model = cast(SupportsMultiModal, self.model) return next( grouped_mm_kwargs for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality( dummy_mm_items, device=self.device, pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, - multimodal_cpu_fields=model.multimodal_cpu_fields, ) ) diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index ce18ca6c37165..5f6136b178b46 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -10,7 +10,7 @@ import torch import torch.nn as nn import vllm.envs as envs -from vllm.config import VllmConfig +from vllm.config import VllmConfig, set_current_vllm_config from vllm.distributed import ( ensure_model_parallel_initialized, init_distributed_environment, @@ -98,10 +98,10 @@ class TPUWorker: # MP runtime is initialized. self.profiler = None self.profile_dir = None - if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1: + if vllm_config.profiler_config.profiler == "torch" and self.rank < 1: # For TPU, we can only have 1 active profiler session for 1 profiler # server. So we only profile on rank0. - self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR + self.profile_dir = vllm_config.profiler_config.torch_profiler_dir logger.info( "Profiling enabled. Traces will be saved to: %s", self.profile_dir ) @@ -207,7 +207,8 @@ class TPUWorker: # one compiled bytecode. Having one FX graph/cached bytecode per # compiled model is required for `support_torch_compile` decorator to # skip dynamo guard. - self.model_runner.reset_dynamo_cache() + with set_current_vllm_config(self.vllm_config): + self.model_runner.reset_dynamo_cache() # Get the maximum amount of memory used by the model weights and # intermediate activations. diff --git a/vllm/v1/worker/ubatch_utils.py b/vllm/v1/worker/ubatch_utils.py index 33a1921d2d98e..f6889173578d6 100644 --- a/vllm/v1/worker/ubatch_utils.py +++ b/vllm/v1/worker/ubatch_utils.py @@ -27,14 +27,16 @@ class UBatchSlice: UBatchSlices: TypeAlias = list[UBatchSlice] -def is_second_ubatch_empty(orig_num_tokens: int, padded_num_tokens: int) -> bool: - return (padded_num_tokens // 2) >= orig_num_tokens +def is_last_ubatch_empty( + orig_num_tokens: int, padded_num_tokens: int, num_ubatches: int +) -> bool: + return (padded_num_tokens // num_ubatches) * (num_ubatches - 1) >= orig_num_tokens def check_ubatch_thresholds( config: ParallelConfig, num_tokens: int, uniform_decode: bool ) -> bool: - if not config.enable_dbo: + if not config.use_ubatching: return False if uniform_decode: return num_tokens >= config.dbo_decode_token_threshold @@ -42,32 +44,69 @@ def check_ubatch_thresholds( return num_tokens >= config.dbo_prefill_token_threshold -def create_ubatch_slices( - num_scheduled_tokens: np.ndarray, split_point: int +# This pads the last ubatch slice out to the total number of tokens +# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding. +def _pad_out_ubatch_slices( + ubatch_slices: UBatchSlices, num_total_tokens: int, num_reqs_padded: int ) -> UBatchSlices: + last_slice = ubatch_slices[-1] + padded_last_request_slice = slice(last_slice.request_slice.start, num_reqs_padded) + padded_last_token_slice = slice(last_slice.token_slice.start, num_total_tokens) + + return ubatch_slices[:-1] + [ + UBatchSlice(padded_last_request_slice, padded_last_token_slice) + ] + + +def maybe_create_ubatch_slices( + should_ubatch: bool, + num_scheduled_tokens: np.ndarray, + num_tokens_padded: int, + num_reqs_padded: int, + num_ubatches: int, + split_point: list[int] | int | None = None, +) -> tuple[UBatchSlices | None, UBatchSlices | None]: + if not should_ubatch: + return None, None + + if split_point is None: + split_point = int(num_tokens_padded) // num_ubatches + + token_split_points = [split_point * i for i in range(1, num_ubatches)] + # TODO(lucas): Refactor the gpu_model_runner.py so we can pass # in cu_num_tokens directly (i.e. query_start_loc) cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32) np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:]) - first_ubatch_token_slice = slice(0, split_point) - second_ubatch_token_slice = slice(split_point, cu_num_tokens[-1]) + ubatch_slices = [] + start_token = 0 - # Determine request slices using exclusive stop semantics - # First ubatch includes requests whose tokens overlap [0, split_point) - first_ubatch_req_stop = int( - np.searchsorted(cu_num_tokens, split_point, side="left") + # Add the end point to the split points to make iteration easier + all_points = token_split_points + [cu_num_tokens[-1]] + + for end_token in all_points: + token_slice = slice(start_token, end_token) + + # Determine request slices using exclusive stop semantics + # Ubatch includes requests whose tokens overlap [start_token, end_token) + + # Start at the request that contains the start_token + # or the request starting exactly at start_token (if on boundary) + req_start = int(np.searchsorted(cu_num_tokens, start_token, side="right") - 1) + + # Stop at the request that starts at or after end_token + req_stop = int(np.searchsorted(cu_num_tokens, end_token, side="left")) + + req_slice = slice(req_start, req_stop) + ubatch_slices.append(UBatchSlice(req_slice, token_slice)) + + start_token = end_token + + ubatch_slices_padded = _pad_out_ubatch_slices( + ubatch_slices, num_tokens_padded, num_reqs_padded ) - first_ubatch_req_slice = slice(0, first_ubatch_req_stop) - # Second ubatch starts at the request that contains the split_point - # or the request starting exactly at split_point (if on boundary) - second_ubatch_req_start = int( - np.searchsorted(cu_num_tokens, split_point, side="right") - 1 - ) - second_ubatch_req_slice = slice(second_ubatch_req_start, len(cu_num_tokens) - 1) + assert sum(s.num_tokens for s in ubatch_slices_padded) == num_tokens_padded - return [ - UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice), - UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice), - ] + return ubatch_slices, ubatch_slices_padded diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index be8326e2fdbc1..e7a947f2ea8ca 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -7,10 +7,15 @@ import torch from vllm import forward_context from vllm.forward_context import ForwardContext +from vllm.logger import init_logger from vllm.utils.torch_utils import current_stream +logger = init_logger(__name__) + _THREAD_ID_TO_CONTEXT: dict = {} -_CURRENT_CONTEXTS: list[Optional["UBatchContext"]] = [None, None] +# Here we hardcode the number of microbatches to 2 for default. +_NUM_UBATCHES: int = 2 +_CURRENT_CONTEXTS: list[Optional["UBatchContext"]] = [] class UBatchContext: @@ -48,6 +53,7 @@ class UBatchContext: global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT _THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id _CURRENT_CONTEXTS[self.id] = self + # _NUM_UBATCHES is set in make_ubatch_contexts self.ready_barrier.wait() self.cpu_wait_event.wait() @@ -181,7 +187,7 @@ dbo_switch_to_compute_sync = _register_ubatch_function( def dbo_register_recv_hook(recv_hook): if len(_THREAD_ID_TO_CONTEXT) > 0: ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()] - next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % 2] + next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % _NUM_UBATCHES] next_ctx.recv_hook = recv_hook @@ -202,7 +208,14 @@ def make_ubatch_contexts( ready_barrier: threading.Barrier, schedule: str = "default", ) -> list[UBatchContext]: - assert num_micro_batches == 2, "only been tested with 2 micro-batches" + global _NUM_UBATCHES, _CURRENT_CONTEXTS + assert num_micro_batches > 1, "num_micro_batches must be greater than 1" + + _NUM_UBATCHES = num_micro_batches + # Ensure the global context list is large enough + if len(_CURRENT_CONTEXTS) < num_micro_batches: + _CURRENT_CONTEXTS.extend([None] * (num_micro_batches - len(_CURRENT_CONTEXTS))) + """ Create a context manager for micro-batching synchronization. """ @@ -210,8 +223,6 @@ def make_ubatch_contexts( gpu_comm_done_events = [torch.Event() for _ in range(num_micro_batches)] gpu_compute_done_events = [torch.Event() for _ in range(num_micro_batches)] - assert len(forward_contexts) == 2 - ctxs = [] for i in range(num_micro_batches): ctx = UBatchContext( diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index bd88cb1b253f8..e9c48223d58b9 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -135,7 +135,7 @@ class AttentionGroup: kv_cache_spec: KVCacheSpec kv_cache_group_id: int # When ubatching is enabled we will have a metadata builder for each ubatch - # so that if they use internal persistant buffers for cudagraphs, and they + # so that if they use internal persistent buffers for cudagraphs, and they # won't have to worry about conflicting with the other ubatches. metadata_builders: list[AttentionMetadataBuilder] = field( default_factory=lambda: [] @@ -313,8 +313,12 @@ def bind_kv_cache( # TODO - analyze where runner_kv_caches is used and the right # way to ensure it properly reflects multiple attention layers # in the same decoder block. - if current_platform.is_cuda_alike() or current_platform.is_xpu(): - # We know that the GPU runner is not impacted by this + if ( + current_platform.is_cuda_alike() + or current_platform.is_xpu() + or current_platform.is_cpu() + ): + # We know that the GPU / CPU runner is not impacted by this # case. Some test code depends on runner_kv_caches, but # not in a way that's impacted by ignoring this. pass @@ -337,12 +341,12 @@ def is_residual_scattered_for_sp( The residual tensor is scattered across tensor parallel ranks when sequence parallelism and tensor parallelism is enabled. - This follows the same logic as SequenceParallelismPass.is_applicable(): + This follows the same logic as SequenceParallelismPass.is_applicable_for_range(): - In full-graph compilation mode (no splitting ops or using inductor graph partition), SP is always applied - Otherwise, SP is only applied for specific shapes in compile_sizes """ - if not vllm_config.compilation_config.pass_config.enable_sequence_parallelism: + if not vllm_config.compilation_config.pass_config.enable_sp: return False tp = vllm_config.parallel_config.tensor_parallel_size diff --git a/vllm/v1/worker/workspace.py b/vllm/v1/worker/workspace.py new file mode 100644 index 0000000000000..bbbd7705d54e4 --- /dev/null +++ b/vllm/v1/worker/workspace.py @@ -0,0 +1,253 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import inspect +import os +from itertools import accumulate +from math import prod +from typing import Optional + +import torch + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.utils.math_utils import round_up +from vllm.v1.worker.ubatching import dbo_current_ubatch_id + +logger = init_logger(__name__) + + +def _compute_bytes(shape: tuple[int, ...], dtype: torch.dtype) -> int: + return prod(shape) * dtype.itemsize + + +# Constants +_MB = 1024**2 +_GiB = 1024**3 + +# Global workspace manager instance +_manager: Optional["WorkspaceManager"] = None + + +class WorkspaceManager: + """Manager for workspace allocation. + + Manages workspace buffers for DBO (Dual Batch Overlap) execution. + Can be locked to prevent further growth during execution. + """ + + def __init__(self, device: torch.device, num_ubatches: int | None = None): + self._device = device + # Cache num ubatches at init based on configuration (default to 1) + self._num_ubatches = num_ubatches if num_ubatches is not None else 1 + self._current_workspaces: list[torch.Tensor | None] = [None, None] + self._locked: bool = False + + @staticmethod + def _workspace_size_bytes(workspace: torch.Tensor | None) -> int: + """Get size of workspace in bytes.""" + if workspace is None: + return 0 + return workspace.numel() * workspace.element_size() + + def lock(self) -> None: + """Lock the workspace to prevent further growth. + + After locking, any attempt to allocate a larger workspace will raise + an assertion error. This ensures workspace size is fixed during execution. + """ + self._locked = True + if envs.VLLM_DEBUG_WORKSPACE: + logger.info( + "[WORKSPACE DEBUG] Workspace locked. Current sizes: %s", + [ + self._workspace_size_bytes(ws) / _MB + for ws in self._current_workspaces + if ws is not None + ], + ) + + def is_locked(self) -> bool: + """Check if workspace is locked.""" + return self._locked + + def get_simultaneous( + self, *shapes_and_dtypes: tuple[tuple[int, ...], torch.dtype] + ) -> list[torch.Tensor]: + """Get multiple workspace tensors simultaneously from a single allocation. + + Args: + *shapes_and_dtypes: One or more (shape, dtype) tuples. + + Returns: + List of tensor views into the workspace buffer, one per shape/dtype pair. + """ + actual_bytes = [_compute_bytes(s, d) for s, d in shapes_and_dtypes] + aligned_bytes = [round_up(actual, 256) for actual in actual_bytes] + total_bytes = sum(aligned_bytes) + + # Calculate cumulative offsets using itertools.accumulate + offsets = list(accumulate([0] + aligned_bytes[:-1])) + + current_workspace = self._ensure_workspace_size(total_bytes) + + return [ + current_workspace[offsets[i] : offsets[i] + actual_bytes[i]] + .view(shapes_and_dtypes[i][1]) + .reshape(shapes_and_dtypes[i][0]) + for i in range(len(shapes_and_dtypes)) + ] + + def _ensure_workspace_size(self, required_bytes: int) -> torch.Tensor: + """Ensure workspace is allocated and large enough, return current workspace. + + Args: + required_bytes: The number of bytes required. + + Returns: + The current workspace tensor. + """ + ubatch_id = dbo_current_ubatch_id() + current_workspace = self._current_workspaces[ubatch_id] + current_size = self._workspace_size_bytes(current_workspace) + + if current_size < required_bytes: + + def get_caller_info() -> str: + """Find first frame outside WorkspaceManager.""" + curr_frame = inspect.currentframe() + if curr_frame is None: + return "unknown" + # Walk up the stack skipping WorkspaceManager frames + curr_frame = curr_frame.f_back + while curr_frame is not None: + # TODO: This only catches instance methods (self), missing + # classmethods and staticmethods. Once Python 3.11+ is the + # minimum supported version, use co_qualname instead: + # qualname = curr_frame.f_code.co_qualname + # if qualname.startswith("WorkspaceManager."): + if isinstance(curr_frame.f_locals.get("self"), WorkspaceManager): + curr_frame = curr_frame.f_back + continue + filename = os.path.basename(curr_frame.f_code.co_filename) + return ( + f"{filename}:{curr_frame.f_lineno}:{curr_frame.f_code.co_name}" + ) + return "unknown" + + if self._locked: + raise AssertionError( + f"Workspace is locked but allocation from '{get_caller_info()}' " + f"requires {required_bytes / _MB:.2f} MB, current size is " + f"{current_size / _MB:.2f} MB. " + "Workspace growth is not allowed after locking." + ) + + for ubatch_id in range(self._num_ubatches): + current_workspace = self._current_workspaces[ubatch_id] + if ( + current_workspace is None + or self._workspace_size_bytes(current_workspace) < required_bytes + ): + # Delete old tensor before allocating new one to avoid + # memory spike from resize_(). resize_() allocates new + # memory before freeing old, which can cause OOM. + # Must clear the list reference first since local var + # is just a copy of the reference. + self._current_workspaces[ubatch_id] = None + del current_workspace + self._current_workspaces[ubatch_id] = torch.empty( + (required_bytes,), dtype=torch.uint8, device=self._device + ) + + if envs.VLLM_DEBUG_WORKSPACE: + logger.info( + "[WORKSPACE DEBUG] Resized workspace from '%s': %.2f MB -> " + "%.2f MB (%d ubatches, total memory %.2f MB)", + get_caller_info(), + current_size / _MB, + required_bytes / _MB, + self._num_ubatches, + required_bytes * self._num_ubatches / _MB, + ) + + current_workspace = self._current_workspaces[dbo_current_ubatch_id()] + + return current_workspace + + +def is_workspace_manager_initialized() -> bool: + """Check if workspace manager has been initialized. + + Returns: + True if workspace manager is initialized, False otherwise. + """ + return _manager is not None + + +def current_workspace_manager() -> "WorkspaceManager": + """Get the current workspace manager instance. + + Raises: + AssertionError: If workspace manager has not been initialized. + """ + assert _manager is not None, ( + "WorkspaceManager not initialized. Call init_workspace_manager() " + "with a device before using workspace functions." + ) + return _manager + + +def init_workspace_manager( + device: torch.device, num_ubatches: int | None = None +) -> None: + """Initialize the workspace manager with a device. + + Must be called before using any workspace functions. Typically called + from GPUModelRunner.__init__. + + Args: + device: The device to allocate workspace on. + num_ubatches: Number of micro-batches. Defaults to 1. + """ + global _manager + if _manager is not None: + logger.warning( + "WorkspaceManager already initialized on device %s, " + "reinitializing on device %s", + _manager._device, + device, + ) + _manager = WorkspaceManager(device, num_ubatches) + + +def lock_workspace() -> None: + """Lock the workspace to prevent further growth. + + After calling this function, any attempt to allocate a workspace larger + than the current size will raise an AssertionError. This ensures that + workspace size is fixed during execution and prevents unexpected memory + allocations in the hot path. + + Example: + # During initialization + init_workspace_manager(device) + reserve_workspace(shape1, dtype1) + reserve_workspace(shape2, dtype2) + + # Lock after warmup/profiling + lock_workspace() + + # Now all get_workspace calls must fit in pre-allocated size + """ + current_workspace_manager().lock() + + +def reset_workspace_manager() -> None: + """Reset the workspace manager to uninitialized state. + + This is primarily intended for testing purposes to allow tests + to reinitialize the workspace manager cleanly. + """ + global _manager + _manager = None diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py index 267369c730368..1faa1a24ff0ea 100644 --- a/vllm/v1/worker/xpu_worker.py +++ b/vllm/v1/worker/xpu_worker.py @@ -6,12 +6,12 @@ from typing import Any import torch import torch.distributed -import vllm.envs as envs from vllm.config import VllmConfig from vllm.distributed import get_world_group from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.platforms import current_platform +from vllm.profiler.wrapper import TorchProfilerWrapper from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment from vllm.v1.worker.xpu_model_runner import XPUModelRunner @@ -36,41 +36,17 @@ class XPUWorker(Worker): assert device_config.device_type == "xpu" assert current_platform.is_xpu() - # Torch profiler. Enabled and configured through env vars: - # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + # Torch profiler. Enabled and configured through profiler_config. self.profiler: Any | None = None - if envs.VLLM_TORCH_PROFILER_DIR: - torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + profiler_config = vllm_config.profiler_config + if profiler_config.profiler == "torch": worker_name = f"{vllm_config.instance_id}-rank-{self.rank}" - logger.info( - "Profiling enabled. Traces will be saved to: %s", - torch_profiler_trace_dir, + self.profiler = TorchProfilerWrapper( + profiler_config, + worker_name=worker_name, + local_rank=self.local_rank, + activities=["CPU", "XPU"], ) - logger.debug( - "Profiler config: record_shapes=%s," - "profile_memory=%s,with_stack=%s,with_flops=%s", - envs.VLLM_TORCH_PROFILER_RECORD_SHAPES, - envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY, - envs.VLLM_TORCH_PROFILER_WITH_STACK, - envs.VLLM_TORCH_PROFILER_WITH_FLOPS, - ) - self.profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.XPU, - ], - record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES, - profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY, - with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, - with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS, - on_trace_ready=torch.profiler.tensorboard_trace_handler( - torch_profiler_trace_dir, - worker_name=worker_name, - use_gzip=envs.VLLM_TORCH_PROFILER_USE_GZIP, - ), - ) - else: - self.profiler = None # we provide this function due to `torch.xpu.mem_get_info()` doesn't # return correct free_gpu_memory on intel client GPU. We need to